# **Almost No Inner Loop** Optimization

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.tensorboard import SummaryWriter

import gymnasium as gym

import numpy as np

import copy


### **LOGGING**

In [None]:
# TensorBoard writer used by the training loop; event files are written under ./runs/ANIL.
writer = SummaryWriter(log_dir='./runs/ANIL')


### **DEVICE**

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')


Device: cuda


### **META - ENV**


In [None]:
class meta_car_continuous:
    """Task distribution over gymnasium's ``MountainCarContinuous-v0``.

    A task is a ``(goal_position, gravity)`` tuple. ``set_task`` writes both
    values onto the core (unwrapped) environment. ``reset`` and ``step`` forward
    to the wrapped env and normalise the old-gym / gymnasium return formats.
    """

    def __init__(self):
        self.base_env = gym.make('MountainCarContinuous-v0')
        # Reference task values; also used as the initial current_task.
        self.org_gravity = 0.0025
        self.org_goal_position = 0.45
        self.tasks = []
        self.current_task = (self.org_goal_position, self.org_gravity)

    def sample_task(self, num_tasks):
        """Sample ``num_tasks`` tasks uniformly.

        The goal is drawn from [0.45, 0.55) and gravity from [0.0025, 0.006).
        The list is stored on ``self.tasks`` and returned.
        """
        self.tasks = [
            (np.random.uniform(0.45, 0.55), np.random.uniform(0.0025, 0.006))
            for _ in range(num_tasks)
        ]
        return self.tasks

    def set_task(self, task):
        """Make ``task = (goal_position, gravity)`` the active task.

        ``gym.make`` returns the environment inside wrapper objects. Assigning an
        attribute on a wrapper (e.g. ``base_env.env``) only stores it on that
        wrapper, so the values are written on ``base_env.unwrapped`` instead,
        which is the environment whose ``step`` actually runs.
        """
        self.current_task = task
        core_env = self.base_env.unwrapped
        # NOTE(review): gymnasium's Continuous_MountainCarEnv appears to hard-code
        # its 0.0025 gravity term in step() rather than read a ``gravity``
        # attribute, so this assignment may not change the dynamics. Confirm
        # against the installed gymnasium version.
        core_env.gravity = task[1]
        core_env.goal_position = task[0]

    def reset(self):
        """Reset the env and return only the observation (drops gymnasium's info dict)."""
        obs = self.base_env.reset()
        if isinstance(obs, tuple):
            obs = obs[0]
        return obs

    def step(self, action):
        """Step the env and return ``(next_obs, reward, done, info)``.

        With the 5-tuple gymnasium API, ``done`` is ``terminated or truncated``.
        """
        step_output = self.base_env.step(action)
        if len(step_output) == 5:
            next_obs, reward, terminated, truncated, info = step_output
            done = terminated or truncated
        else:
            next_obs, reward, done, info = step_output
        return next_obs, reward, done, info

    def get_number(self):
        """Return ``(state_dim, action_dim, max_action, reward_dim)`` of the base env."""
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = float(self.base_env.action_space.high[0])
        reward_dim = 1
        return state_dim, action_dim, max_action, reward_dim

    def close(self):
        """Close the underlying environment."""
        self.base_env.close()


### **SET UP**

In [None]:
META_ENV = meta_car_continuous()

tasks = META_ENV.sample_task(num_tasks=5)
obs = META_ENV.reset()
state_dim, action_dim, max_action, reward_dim = META_ENV.get_number()

# Sanity report: sampled tasks, observation shape and env dimensions.
print(f'Tasks: {tasks}', end='\n\n')
print(f'obs: {obs.shape}', end='\n\n')
print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action}')


Tasks: [(0.49934588688337816, 0.004526556866872781), (0.512269297349774, 0.0044080949375128325), (0.5375930700427028, 0.005983689153189032), (0.5159142616008038, 0.005719457998977788), (0.5225756374044043, 0.005078740307298664)]

obs: (2,)

state dim: 2 | action dim: 1 | max action: 1.0


### **HELPER FUNCTION**

In [None]:
def safe_tensor(x):
    """Convert ``x`` to a float32 tensor on ``device``.

    Objects that are already tensors are returned untouched: no dtype cast, no
    device move, and any autograd graph is preserved.
    """
    if torch.is_tensor(x):
        return x
    converted = torch.tensor(x, dtype=torch.float32)
    return converted.to(device)


### **ASSEMBLY**


In [None]:
# Widths of the four policy_mlp stages. head_1 must equal hidden_size because
# the hyper_x encoder's output feeds the first stage.
head_1, head_2, head_3, head_4 = 32, 64, 64, 32

# hyper_x encoder widths: outer layers use hidden_size, inner layers hidden_size_2.
hidden_size, hidden_size_2 = 32, 64


### **HYPER X**

In [None]:
class hyper_x(nn.Module):
    """State encoder made of four ``Linear -> LayerNorm -> SiLU`` stages.

    Widths run state_dim -> hidden_size -> hidden_size_2 -> hidden_size_2 -> hidden_size,
    so the output has ``hidden_size`` features in its last dimension.
    """

    def __init__(self, hidden_size = hidden_size, hidden_size_2 = hidden_size_2, state_dim = state_dim):
        super().__init__()

        widths = [state_dim, hidden_size, hidden_size_2, hidden_size_2, hidden_size]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(fan_in, fan_out), nn.LayerNorm(fan_out), nn.SiLU()]

        # A single Sequential keeps the parameter names hyper.0 ... hyper.11.
        self.hyper = nn.Sequential(*stages)

    def forward(self, state):
        """Encode ``state`` through the stacked stages."""
        return self.hyper(state)


### **POLICY**

In [None]:
class policy_net(nn.Module):
    """Tanh-squashed Gaussian policy: ``hyper_x`` encoder -> ``policy_mlp`` body -> ``mu`` / ``log_std`` heads.

    Args:
        action_dim: number of action components produced by each head.
        hidden_size: output width of the ``hyper_x`` encoder. Must equal ``head_1``
            because the encoder output feeds the first ``policy_mlp`` layer.
        head_1, head_2, head_3, head_4: layer widths of ``policy_mlp``.
        max_action: scale applied to the tanh-squashed sample.
    """

    def __init__(self, action_dim = action_dim, hidden_size = hidden_size, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action):
        super(policy_net, self).__init__()

        assert hidden_size == head_1, f'hidden_size ({hidden_size}) must equal head_1 ({head_1})'

        # Pass the requested width through. A bare ``hyper_x()`` silently uses the
        # module-level default and produces a shape mismatch for any other value.
        self.hyper_x = hyper_x(hidden_size=hidden_size)

        self.policy_mlp = nn.Sequential(
            nn.Linear(head_1, head_2),
            nn.LayerNorm(head_2),
            nn.SiLU(),

            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),

            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
        )

        # Gaussian heads: mean and (clamped) log standard deviation.
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)

        # Weight initialisation (orthogonal weights, zero biases) for every
        # nn.Linear in the module tree, including those inside hyper_x.
        self.apply(self.init_weight)

        self.max_action = max_action

    def init_weight(self, m):
        """Orthogonal-initialise ``nn.Linear`` weights and zero their biases; other modules are untouched."""
        if isinstance(m, nn.Linear):
            nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, state, *, reparam_log_prob = False):
        """Sample an action and return it with its log-density.

        Args:
            state: observation tensor whose last dimension is ``state_dim``
                (the callers in this file pass shape (1, state_dim)).
            reparam_log_prob: if True, ``log_prob`` keeps the gradient path
                through the reparameterised sample (pathwise / SAC-style). The
                default False evaluates the density on the detached sample, which
                gives grad log pi(a|s) with the action held fixed. That is what
                the REINFORCE-style ``-log_prob * reward`` losses require.

        Returns:
            ``(action, log_prob)``. ``action = max_action * tanh(z)`` with
            ``z ~ Normal(mu, std)`` drawn via ``rsample``. ``log_prob`` is summed
            over the last dimension with ``keepdim=True``.
        """
        features = self.policy_mlp(self.hyper_x(state))

        mu = self.mu(features)
        # Bounds std to [e^-20, e^2].
        log_std = torch.clamp(self.log_std(features), -20, 2)
        std = torch.exp(log_std)

        dist = torch.distributions.Normal(mu, std)
        z = dist.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action

        if not reparam_log_prob:
            # Treat the sampled action as data for the density evaluation.
            z = z.detach()
            tanh_z = tanh_z.detach()

        # Change of variables for action = max_action * tanh(z); 1e-6 keeps the log finite at saturation.
        log_prob = dist.log_prob(z) - torch.log(self.max_action * (1 - tanh_z.pow(2)) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)

        return action, log_prob


### **SET UP**

In [None]:
# Instantiate the shared meta-policy; Module.to moves its parameters in place.
POLICY_NET = policy_net()
POLICY_NET.to(device)

print(POLICY_NET)


policy_net(
  (hyper_x): hyper_x(
    (hyper): Sequential(
      (0): Linear(in_features=2, out_features=32, bias=True)
      (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=32, out_features=64, bias=True)
      (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=64, out_features=64, bias=True)
      (7): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (8): SiLU()
      (9): Linear(in_features=64, out_features=32, bias=True)
      (10): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (11): SiLU()
    )
  )
  (policy_mlp): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=64, out_features=32, b

### **BUFFER**

In [11]:
class buffer:
    """Growable list of per-step transitions that can be stacked into one batch.

    Each entry is a dict with keys 'states', 'actions', 'log_probs' and
    'rewards'. Values pass through ``safe_tensor``, so tensors are stored as-is
    together with any autograd graph they carry.
    """

    def __init__(self):
        self.buffer = []

    def add(self, state, action, log_prob, reward):
        """Append one transition."""
        entry = dict(
            states=safe_tensor(state),
            actions=safe_tensor(action),
            log_probs=safe_tensor(log_prob),
            rewards=safe_tensor(reward),
        )
        self.buffer.append(entry)

    def safe_stack(self, x):
        """Stack equally-shaped tensors along a new leading dim and move the result to ``device``."""
        stacked = torch.stack(x)
        return stacked.to(device)

    def sample(self):
        """Return every stored transition (no subsampling) as a dict of stacked tensors."""
        keys = ('states', 'actions', 'log_probs', 'rewards')
        return {key: self.safe_stack([entry[key] for entry in self.buffer]) for key in keys}

    def clear(self):
        """Remove all stored transitions, mutating the existing list in place."""
        del self.buffer[:]
        

### **SET UP**

In [None]:
# Shared trajectory store: collect_traj fills it, adapted_policy reads it by default.
META_BUFFER = buffer()


### **OPTIMIZER**

In [None]:
# Learning rates: inner_lr for the head adaptation step, meta_lr for the outer AdamW update.
inner_lr, meta_lr = 3e-4, 1e-4

# Schedule: `warmup` linear-warmup steps, then cosine decay over T_max - warmup steps.
# If training runs fewer than T_max scheduler steps, the decay is only partly applied.
T_max, warmup = 50, 20

OPTIMIZER = optim.AdamW(POLICY_NET.parameters(), lr=meta_lr, weight_decay=0)

warmup_sch = optim.lr_scheduler.LinearLR(OPTIMIZER, start_factor=0.1, total_iters=warmup)
cosine_sch = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max=T_max - warmup, eta_min=1e-5)

# Hand over from warmup to cosine once `warmup` scheduler steps have been taken.
SCHEDULER = optim.lr_scheduler.SequentialLR(
    OPTIMIZER,
    schedulers=[warmup_sch, cosine_sch],
    milestones=[warmup],
)


### **COLLECT TRAJ**

In [14]:
class collect_traj:
    """Roll out ``POLICY_NET`` in ``META_ENV`` and push every step into ``META_BUFFER``."""

    def __init__(self, META_ENV = META_ENV, META_BUFFER = META_BUFFER, POLICY_NET = POLICY_NET):
        self.network = POLICY_NET
        self.env = META_ENV
        self.buffer = META_BUFFER

    def run(self, steps):
        """Play at most ``steps`` environment steps from a fresh reset.

        Each step stores ``(state, action, log_prob, [reward])`` in the buffer.
        The action and log_prob tensors are stored exactly as the network
        returned them (not detached). Stops early once the env reports done.
        """
        def as_batch(observation):
            # Add a leading batch dimension of size 1.
            return safe_tensor(observation).unsqueeze(0)

        state = as_batch(self.env.reset())

        for _ in range(steps):
            action, log_prob = self.network.forward(state)
            env_action = action.detach().cpu().numpy()[0]
            next_obs, reward, done, _info = self.env.step(env_action)
            next_state = as_batch(next_obs)

            self.buffer.add(state.squeeze(0), action.squeeze(0), log_prob.squeeze(0), [reward])

            if done:
                break
            state = next_state
            

### **SET UP**

In [None]:
# Rollout helper bound, through its default arguments, to META_ENV, META_BUFFER and POLICY_NET.
COLLECT_TRAJ = collect_traj()


### **ADAPTED POLICY**

In [None]:
def _head_log_prob(network, head, states, actions, eps = 1e-6):
    """Log-density of stored actions under ``network`` with its mu/log_std heads replaced by ``head``.

    Mirrors the head math of ``policy_net.forward``: log_std clamped to [-20, 2],
    then tanh squashing scaled by ``max_action``. Keep the two in sync.

    Args:
        network: ``policy_net``-like module exposing ``hyper_x``, ``policy_mlp``
            and ``max_action``. Its own (body) parameters are used for features.
        head: mapping with keys 'mu.weight', 'log_std.weight' and optionally
            'mu.bias' / 'log_std.bias'.
        states: stacked observations from the buffer.
        actions: stacked squashed actions from the buffer, treated as data.
        eps: keeps ``atanh`` and the log-Jacobian finite for saturated actions.

    Returns:
        Log-probs summed over the action dimension (``keepdim=True``).
    """
    linear = torch.nn.functional.linear
    features = network.policy_mlp(network.hyper_x(states))

    mu = linear(features, head['mu.weight'], head.get('mu.bias'))
    log_std = torch.clamp(linear(features, head['log_std.weight'], head.get('log_std.bias')), -20, 2)
    dist = torch.distributions.Normal(mu, torch.exp(log_std))

    # Invert action = max_action * tanh(z). Clamp away from +-1 so atanh stays finite.
    y = torch.clamp(actions / network.max_action, -1 + eps, 1 - eps)
    log_prob = dist.log_prob(torch.atanh(y)) - torch.log(network.max_action * (1 - y.pow(2)) + eps)
    return log_prob.sum(dim=-1, keepdim=True)


def adapted_policy(inner_lr, network = POLICY_NET, buffer = META_BUFFER):
    """ANIL inner loop: one gradient step on the policy's ``mu`` / ``log_std`` heads only.

    The support data are all entries currently in ``buffer``. Log-probs are
    recomputed from the stored states and (detached) actions under ``network``.
    As a result, the inner gradient is the score-function gradient and does not
    depend on autograd graphs saved at collection time.

    The adapted head parameters are ``p - inner_lr * grad`` with
    ``create_graph=True``, i.e. differentiable functions of ``network``'s own
    parameters. A loss on the adapted policy's outputs therefore
    back-propagates (second order) into ``network``: the body through its
    features and the heads through the update. ``network`` itself is not
    modified.

    Args:
        inner_lr: step size of the single inner-loop update.
        network: the shared meta-policy.
        buffer: object whose ``sample()`` returns a dict with 'states',
            'actions' and 'rewards' tensors.

    Returns:
        Callable ``state -> (action, log_prob)`` that runs ``network.forward``
        with the ``mu`` / ``log_std`` parameters swapped for their adapted values.
    """
    try:
        from torch.func import functional_call
    except ImportError:  # torch < 2.0
        from torch.nn.utils.stateless import functional_call

    # ANIL: only the output heads are adapted in the inner loop.
    head = {
        name: param
        for name, param in network.named_parameters()
        if name.split('.')[0] in ('mu', 'log_std')
    }

    batch = buffer.sample()
    log_probs = _head_log_prob(network, head, batch['states'], batch['actions'].detach())
    rewards = batch['rewards'].view(-1, 1)

    # NOTE(review): each log-prob is weighted by its immediate reward, not the
    # return-to-go; kept as in the original objective.
    loss = (-log_probs * rewards).mean()

    names = list(head)
    grads = torch.autograd.grad(
        loss,
        [head[name] for name in names],
        create_graph=True,
        allow_unused=True,
    )

    adapted = {
        name: head[name] if grad is None else head[name] - inner_lr * grad
        for name, grad in zip(names, grads)
    }

    def adapted_forward(state):
        # Body parameters come from `network`; head parameters from `adapted`.
        return functional_call(network, adapted, (state,))

    return adapted_forward


### **COLLECT - VAL**

In [None]:
def collect_val(adapted_policy, task, steps, META_ENV = META_ENV):
    """Roll out ``adapted_policy`` on ``task`` and return its mean per-step REINFORCE loss.

    Args:
        adapted_policy: callable mapping a (1, state_dim) observation tensor to
            ``(action, log_prob)``.
        task: ``(goal_position, gravity)`` tuple passed to ``META_ENV.set_task``.
        steps: maximum number of environment steps; must be >= 1.
        META_ENV: the meta-environment to roll out in.

    Returns:
        Scalar tensor: the mean of ``-(log_prob * reward)`` over the steps actually
        taken. It keeps the log-probs' autograd graph so the caller can back-propagate.

    Raises:
        ValueError: if ``steps < 1``.
    """
    if steps < 1:
        raise ValueError(f'steps must be >= 1, got {steps}')

    META_ENV.set_task(task)
    obs = safe_tensor(META_ENV.reset()).unsqueeze(0)

    step_losses = []
    for _ in range(steps):
        action, log_prob = adapted_policy(obs)

        action_np = action.detach().cpu().numpy()[0]
        next_obs, reward, done, _info = META_ENV.step(action_np)

        reward = safe_tensor([reward])
        # NOTE(review): each log-prob is weighted by its immediate reward, not the
        # return-to-go; kept as in the original objective.
        step_losses.append(-(log_prob.squeeze() * reward).mean())

        obs = safe_tensor(next_obs).unsqueeze(0)
        if done:
            break

    # Average over the steps that actually happened, so an episode that
    # terminates early is not diluted by steps that never ran.
    return torch.stack(step_losses).mean()


### **TRAINING LOOP**

In [18]:
def train_loop(meta_iteration, steps, num_tasks, inner_lr = inner_lr, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER, META_BUFFER = META_BUFFER, META_ENV = META_ENV):
    """Run ANIL meta-training and log the average validation loss of each iteration.

    Each meta-iteration samples ``num_tasks`` tasks. For every task it:
    (1) clears the buffer so adaptation only sees this task's data,
    (2) collects one support rollout (at most ``steps`` steps) with the shared policy,
    (3) adapts the policy heads with ``adapted_policy``, and
    (4) rolls the adapted policy out on the same task with ``collect_val`` and
        back-propagates the validation loss scaled by 1 / number of tasks.
    The accumulated meta-gradient is then clipped once (max norm 0.5), and one
    optimizer step and one scheduler step are taken.

    Args:
        meta_iteration: number of outer iterations.
        steps: maximum environment steps per rollout.
        num_tasks: tasks sampled per iteration.
        inner_lr: inner-loop step size forwarded to ``adapted_policy``.
        OPTIMIZER, SCHEDULER: outer-loop optimizer and LR scheduler.
        META_BUFFER: buffer cleared before each task's collection.
        META_ENV: meta-environment used for task sampling and selection.
    """
    for iteration in range(meta_iteration):

        total_val_loss = 0.0

        tasks = META_ENV.sample_task(num_tasks)
        n_tasks = max(len(tasks), 1)

        OPTIMIZER.zero_grad()

        for task in tasks:

            # Adaptation must only use this task's trajectories.
            META_BUFFER.clear()

            META_ENV.set_task(task)

            COLLECT_TRAJ.run(steps)

            adapt_policy = adapted_policy(inner_lr)

            val_loss = collect_val(adapt_policy, task, steps)

            # Scale so the accumulated gradient is the mean over tasks.
            (val_loss / n_tasks).backward()

            total_val_loss += val_loss.item()

        # Clip the fully accumulated meta-gradient once. Clipping inside the task
        # loop would rescale partial sums and weight tasks unevenly.
        torch.nn.utils.clip_grad_norm_(POLICY_NET.parameters(), max_norm=0.5)

        OPTIMIZER.step()
        SCHEDULER.step()

        # Release the stored rollouts and the autograd graphs their tensors hold.
        META_BUFFER.clear()

        avg_val_loss = total_val_loss / n_tasks

        writer.add_scalar('Validation loss', avg_val_loss, iteration)
        writer.flush()

        print(f'epoch: {iteration} | avg_val_loss: {avg_val_loss:.4f}')
            

### **SET UP**

In [None]:
# Outer-loop budget: 30 meta-iterations of 10 tasks each, at most 64 env steps per rollout.
meta_iteration, steps, num_tasks = 30, 64, 10

train_loop(meta_iteration=meta_iteration, steps=steps, num_tasks=num_tasks)


epoch: 0 | avg_val_loss: -0.0072
epoch: 1 | avg_val_loss: -0.0048
epoch: 2 | avg_val_loss: -0.0084
epoch: 3 | avg_val_loss: -0.0061
epoch: 4 | avg_val_loss: -0.0068
epoch: 5 | avg_val_loss: -0.0040
epoch: 6 | avg_val_loss: -0.0069
epoch: 7 | avg_val_loss: -0.0027
epoch: 8 | avg_val_loss: -0.0089
epoch: 9 | avg_val_loss: -0.0083
epoch: 10 | avg_val_loss: -0.0059
epoch: 11 | avg_val_loss: -0.0037
epoch: 12 | avg_val_loss: -0.0052
epoch: 13 | avg_val_loss: -0.0066
epoch: 14 | avg_val_loss: -0.0064
epoch: 15 | avg_val_loss: -0.0051
epoch: 16 | avg_val_loss: -0.0055
epoch: 17 | avg_val_loss: -0.0048
epoch: 18 | avg_val_loss: -0.0060




epoch: 19 | avg_val_loss: -0.0072
epoch: 20 | avg_val_loss: -0.0062
epoch: 21 | avg_val_loss: -0.0059
epoch: 22 | avg_val_loss: -0.0050
epoch: 23 | avg_val_loss: -0.0080
epoch: 24 | avg_val_loss: -0.0062
epoch: 25 | avg_val_loss: -0.0073
epoch: 26 | avg_val_loss: -0.0061
epoch: 27 | avg_val_loss: -0.0087
epoch: 28 | avg_val_loss: -0.0071
epoch: 29 | avg_val_loss: -0.0039
