# F I R S T - O R D E R - M A M L

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

import numpy as np
import copy

import gymnasium as gym


### D E V I C E 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### H E L P E R 

In [None]:
def safe_tensor(x):
    
    return x if torch.is_tensor(x) else torch.tensor(x, dtype = torch.float32).to(device)


### L O G G I N G

In [None]:
writer = SummaryWriter(log_dir = './runs/FOMAML')


### M E T A - E N V

In [None]:
class meta_car_continuous:
    
    def __init__(self):
        
        self.base_env = gym.make('MountainCarContinuous-v0')
        self.org_gravity = 0.0025
        self.org_goal_position = 0.45
        self.tasks = []
        self.current_task = (self.org_goal_position, self.org_gravity)
        
    def sample_task(self, num_tasks):
        
        self.tasks = []
        
        for _ in range(num_tasks):
            goal_position = np.random.uniform(0.45, 0.55)
            gravity = np.random.uniform(0.0025, 0.006)
            self.tasks.append((goal_position, gravity))
            
        return self.tasks

    def set_task(self, task):
        
        self.current_task = task
        
        self.base_env.env.gravity = task[1]
        self.base_env.env.goal_position = task[0]

    def reset(self):
        
        obs = self.base_env.reset()
        
        if isinstance(obs, tuple):
            
            obs = obs[0]
            
        return obs

    def step(self, action):
        
        step_output = self.base_env.step(action)
        
        if len(step_output) == 5: 
            
            next_obs, reward, terminated, truncated, info = step_output
            done = terminated or truncated
            
        else:
            
            next_obs, reward, done, info = step_output
            
        return next_obs, reward, done, info

    def get_number(self):
        
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = float(self.base_env.action_space.high[0])
        reward_dim = 1
        
        return state_dim, action_dim, max_action, reward_dim

    def close(self):
        
        self.base_env.close()


### S E T U P 

In [6]:
# Initiate env

env = meta_car_continuous()

tasks = env.sample_task(2)

for task in tasks:
    
    env.set_task(task)

    obs = env.reset()

    print(f'obs: {obs.shape}')

# get number

state_dim, action_dim, max_action, reward_dim = env.get_number()

print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action} | reward dim: {reward_dim}')


obs: (2,)
obs: (2,)
state dim: 2 | action dim: 1 | max action: 1.0 | reward dim: 1


### A S S E M B L Y 

In [None]:
head_1 = 32
head_2 = 64
head_3 = 64
head_4 = 32

hidden_size = 32
hidden_size_2 = 64


### H Y P E R - X 

In [8]:
class hyper_x(nn.Module):
    
    def __init__(self, state_dim = state_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2):
        super(hyper_x, self).__init__()
        
        # input dim
        
        input_dim = state_dim
        
        # mlp
        
        self.hyper = nn.Sequential(
            
            nn.Linear(input_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU(),
            
            nn.Linear(hidden_size, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            
            nn.Linear(hidden_size_2, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            
            nn.Linear(hidden_size_2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU(),
            
        )
        
    def forward(self, state):
        
        hyper = self.hyper(state)
        
        return hyper
    

### P O L I C Y

In [9]:
class policy_net(nn.Module):
    
    def __init__(self, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, action_dim = action_dim, max_action = max_action):
        super(policy_net, self).__init__()
        
        # max action
        
        self.max_action = max_action
        
        # hyper
        
        self.hyper = hyper_x()
        
        # mlp : [ input -> state -> hyper -> output -> hidden  [ 64 ] -> input -> policy -> head 1 [ 64 ] -> output -> action ]
        
        self.policy = nn.Sequential(
            
            nn.Linear(head_1, head_2),
            nn.LayerNorm(head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
            
        )
        
        # mu and log std head
        
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)
        
    def forward(self, state):
        
        # state to hyper
        
        hyper = self.hyper.forward(state)
        
        # input from hyper to policy
        
        x = self.policy(hyper)
        
        # mu  and log std
        
        mu = self.mu(x)
        log_std = self.log_std(x)
        log_std = log_std.clamp(-10, 2)
        std = torch.exp(log_std)
        
        dist = torch.distributions.Normal(mu, std)
        z = dist.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        
        log_prob = dist.log_prob(z)
        log_prob = log_prob - (1 - tanh_z.pow(2) + 1e-6).log()
        log_prob = log_prob.sum(dim = -1, keepdim = True)
        
        return action, log_prob
        

### S E T U P 

In [None]:
POLICY_NETWORK = policy_net().to(device)

print(POLICY_NETWORK)


policy_net(
  (hyper): hyper_x(
    (hyper): Sequential(
      (0): Linear(in_features=2, out_features=32, bias=True)
      (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=32, out_features=64, bias=True)
      (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=64, out_features=64, bias=True)
      (7): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (8): SiLU()
      (9): Linear(in_features=64, out_features=32, bias=True)
      (10): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (11): SiLU()
    )
  )
  (policy): Sequential(
    (0): Linear(in_features=32, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=64, out_features=64, bias=True)
    (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=64, out_features=32, bias=Tr

### O P T I M I Z E R 

In [None]:
# lr

meta_lr = 1e-4
meta_iteration = 50

# optimizer

OPTIMIZER = optim.AdamW(POLICY_NETWORK.parameters(), meta_lr, weight_decay = 0)
SCHEDULER = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max = meta_iteration, eta_min = 1e-5)


### B U F F E R

In [None]:
class roller_buffer:
    
    def __init__(self):
        
        self.buffer = []
        
    def add(self, state, action, log_prob, reward, next_state):
    
        # add to buffer
        
        self.buffer.append({
            
            'states': safe_tensor(state),
            'actions': safe_tensor(action),
            'log_probs': safe_tensor(log_prob),
            'rewards': safe_tensor(reward),
            'next_states': safe_tensor(next_state)
        })
        
    def safe_stack(self, x):
        
        return torch.stack(x).to(device)
        
    def sample(self):
        
        # take trajectories
        
        states = [i['states'] for i in self.buffer]
        actions = [i['actions'] for i in self.buffer]
        log_probs = [i['log_probs'] for i in self.buffer]
        rewards = [i['rewards'] for i in self.buffer]
        next_states = [i['next_states'] for i in self.buffer]
        
        # form batch
        
        batch = {
            
            'states': self.safe_stack(states),
            'actions': self.safe_stack(actions),
            'log_probs': self.safe_stack(log_probs),
            'rewards': self.safe_stack(rewards),
            'next_states': self.safe_stack(next_states)
            
        }
        
        return batch
    
    def clear(self):
        
        self.buffer.clear()


### S E T U P 

In [None]:
buffer = roller_buffer()


### C O L L E C T - T R A J

In [14]:
class collect_traj:
    
    def __init__(self, buffer = buffer, env = env, POLICY_NETWORK = POLICY_NETWORK):
        
        self.buffer = buffer
        self.env = env
        self.policy = POLICY_NETWORK
        
    def run(self, steps):
            
        obs = self.env.reset()
            
        obs = safe_tensor(obs).unsqueeze(0)
            
        total_reward = 0.0
            
        for _ in range(steps):
                
            action, log_prob = self.policy.forward(obs)
                
            action_np = action.detach().cpu().numpy()[0]
                
            next_obs, reward, done, _ = self.env.step(action_np)
                
            total_reward += reward
                
            next_obs = safe_tensor(next_obs).unsqueeze(0)
                
            self.buffer.add(obs.squeeze(0), action.squeeze(0), log_prob.squeeze(0), reward, next_obs.squeeze(0))
            
            obs = next_obs
                
            if done:
                    
                break
                

### S E T U P 

In [None]:
COLLECT_TRAJECTORY = collect_traj()


### A D A P T - P O L I C Y

In [None]:
def adapt_policy(learning_rate, POLICY_NETWORK = POLICY_NETWORK, buffer = buffer):
    
    # clone the policy
    
    adapted_policy = copy.deepcopy(POLICY_NETWORK)
    
    # Get batch from buffer
    
    batch = buffer.sample()
    
    states = batch['states']
    actions = batch['actions']
    rewards = batch['rewards']
    
    rewards = rewards.unsqueeze(1)
    
    # compute log probs from cloned policy
    
    _, log_probs = adapted_policy.forward(states)
    
    # compute reinforce loss
    
    loss = - (log_probs.squeeze() * rewards.squeeze()).mean()
    
    # compute gradients
    
    grad = torch.autograd.grad(loss, adapted_policy.parameters(), create_graph = True)
    
    # gradient step
    
    for p, g in zip(adapted_policy.parameters(), grad):
        
        p.data = p.data - learning_rate * g
        
    return adapted_policy


### C O L L E C T - V A L 

In [None]:
def collect_val(adapted_policy, task, steps, env = env):
      
    env.set_task(task)
        
    obs = env.reset()
    obs = safe_tensor(obs).unsqueeze(0)
        
    total_loss = 0.0
    
    for _ in range(steps):
            
        action, log_prob = adapted_policy.forward(obs)
            
        action_np = action.detach().cpu().numpy()[0]
            
        next_obs, reward, done, _ = env.step(action_np)
            
        next_obs = safe_tensor(next_obs).unsqueeze(0)
            
        reward = safe_tensor(reward)
            
        log_prob = log_prob.squeeze()
            
        total_loss += ( - log_prob * reward)
            
        obs = next_obs
            
        if done:
                
            break
            
    val_loss = total_loss / steps
    return val_loss 


### T R A I N I N G - L O O P

In [18]:
def meta_runner_loop(num_tasks, steps, learning_rate, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER, meta_iteration = meta_iteration):
    
    for iteration in tqdm(range(1, meta_iteration + 1), desc = 'FOMAML'):
        
        OPTIMIZER.zero_grad()
        total_val_loss = 0.0
        
        tasks = env.sample_task(num_tasks)
        
        for task_num in tasks:
            
            env.set_task(task_num)
            
            buffer.clear()
            
            COLLECT_TRAJECTORY.run(steps = steps)
            
            adapted_policy = adapt_policy(learning_rate)
            
            val_loss = collect_val(adapted_policy, task_num, steps)
            
            val_loss.backward()
            torch.nn.utils.clip_grad_norm_(POLICY_NETWORK.parameters(), max_norm = 0.5)
            
            total_val_loss += val_loss.item()
            
        OPTIMIZER.step()
        SCHEDULER.step()
        
        avg_val_loss = total_val_loss / num_tasks
        
        writer.add_scalar('VALIDATION Loss', avg_val_loss, iteration)
        
        print(f'epoch: {iteration} | avg_val_loss: {avg_val_loss:.4f}')
            

### S E T U P 

In [None]:
num_tasks = 5
steps = 64
learning_rate = 1e-5

meta_runner_loop(num_tasks, steps, learning_rate)


FOMAML:   2%|▏         | 1/50 [00:03<03:00,  3.69s/it]

epoch: 1 | avg_val_loss: -0.0005


FOMAML:   4%|▍         | 2/50 [00:06<02:32,  3.18s/it]

epoch: 2 | avg_val_loss: -0.0052


FOMAML:   6%|▌         | 3/50 [00:10<02:39,  3.40s/it]

epoch: 3 | avg_val_loss: -0.0035


FOMAML:   8%|▊         | 4/50 [00:12<02:21,  3.08s/it]

epoch: 4 | avg_val_loss: -0.0027


FOMAML:  10%|█         | 5/50 [00:15<02:06,  2.80s/it]

epoch: 5 | avg_val_loss: -0.0005


FOMAML:  12%|█▏        | 6/50 [00:19<02:26,  3.33s/it]

epoch: 6 | avg_val_loss: -0.0019


FOMAML:  14%|█▍        | 7/50 [00:22<02:19,  3.24s/it]

epoch: 7 | avg_val_loss: -0.0058


FOMAML:  16%|█▌        | 8/50 [00:24<02:03,  2.93s/it]

epoch: 8 | avg_val_loss: -0.0042


FOMAML:  18%|█▊        | 9/50 [00:27<01:56,  2.83s/it]

epoch: 9 | avg_val_loss: -0.0047


FOMAML:  20%|██        | 10/50 [00:30<01:52,  2.81s/it]

epoch: 10 | avg_val_loss: -0.0023


FOMAML:  22%|██▏       | 11/50 [00:33<01:57,  3.01s/it]

epoch: 11 | avg_val_loss: -0.0000


FOMAML:  24%|██▍       | 12/50 [00:36<01:48,  2.86s/it]

epoch: 12 | avg_val_loss: -0.0041


FOMAML:  26%|██▌       | 13/50 [00:38<01:38,  2.67s/it]

epoch: 13 | avg_val_loss: -0.0079


FOMAML:  28%|██▊       | 14/50 [00:40<01:32,  2.56s/it]

epoch: 14 | avg_val_loss: -0.0043


FOMAML:  30%|███       | 15/50 [00:42<01:27,  2.49s/it]

epoch: 15 | avg_val_loss: -0.0076


FOMAML:  32%|███▏      | 16/50 [00:45<01:23,  2.46s/it]

epoch: 16 | avg_val_loss: -0.0016


FOMAML:  34%|███▍      | 17/50 [00:47<01:20,  2.45s/it]

epoch: 17 | avg_val_loss: -0.0039


FOMAML:  36%|███▌      | 18/50 [00:49<01:15,  2.35s/it]

epoch: 18 | avg_val_loss: -0.0036


FOMAML:  38%|███▊      | 19/50 [00:52<01:11,  2.31s/it]

epoch: 19 | avg_val_loss: -0.0046


FOMAML:  40%|████      | 20/50 [00:54<01:10,  2.34s/it]

epoch: 20 | avg_val_loss: 0.0003


FOMAML:  42%|████▏     | 21/50 [00:57<01:09,  2.40s/it]

epoch: 21 | avg_val_loss: -0.0024


FOMAML:  44%|████▍     | 22/50 [00:59<01:05,  2.34s/it]

epoch: 22 | avg_val_loss: -0.0019


FOMAML:  46%|████▌     | 23/50 [01:01<01:04,  2.40s/it]

epoch: 23 | avg_val_loss: -0.0073


FOMAML:  48%|████▊     | 24/50 [01:05<01:12,  2.80s/it]

epoch: 24 | avg_val_loss: -0.0079


FOMAML:  50%|█████     | 25/50 [01:07<01:05,  2.64s/it]

epoch: 25 | avg_val_loss: -0.0057


FOMAML:  52%|█████▏    | 26/50 [01:10<01:01,  2.57s/it]

epoch: 26 | avg_val_loss: -0.0032


FOMAML:  54%|█████▍    | 27/50 [01:12<00:56,  2.44s/it]

epoch: 27 | avg_val_loss: -0.0018


FOMAML:  56%|█████▌    | 28/50 [01:14<00:51,  2.34s/it]

epoch: 28 | avg_val_loss: -0.0045


FOMAML:  58%|█████▊    | 29/50 [01:16<00:48,  2.32s/it]

epoch: 29 | avg_val_loss: -0.0014


FOMAML:  60%|██████    | 30/50 [01:19<00:46,  2.33s/it]

epoch: 30 | avg_val_loss: -0.0023


FOMAML:  62%|██████▏   | 31/50 [01:21<00:44,  2.32s/it]

epoch: 31 | avg_val_loss: -0.0054


FOMAML:  64%|██████▍   | 32/50 [01:23<00:41,  2.32s/it]

epoch: 32 | avg_val_loss: -0.0018


FOMAML:  66%|██████▌   | 33/50 [01:25<00:38,  2.28s/it]

epoch: 33 | avg_val_loss: -0.0052


FOMAML:  68%|██████▊   | 34/50 [01:27<00:35,  2.21s/it]

epoch: 34 | avg_val_loss: -0.0053


FOMAML:  70%|███████   | 35/50 [01:30<00:33,  2.22s/it]

epoch: 35 | avg_val_loss: -0.0054


FOMAML:  72%|███████▏  | 36/50 [01:32<00:31,  2.25s/it]

epoch: 36 | avg_val_loss: -0.0034


FOMAML:  74%|███████▍  | 37/50 [01:34<00:29,  2.25s/it]

epoch: 37 | avg_val_loss: -0.0037


FOMAML:  76%|███████▌  | 38/50 [01:37<00:27,  2.25s/it]

epoch: 38 | avg_val_loss: -0.0026


FOMAML:  78%|███████▊  | 39/50 [01:39<00:24,  2.27s/it]

epoch: 39 | avg_val_loss: -0.0018


FOMAML:  80%|████████  | 40/50 [01:41<00:22,  2.25s/it]

epoch: 40 | avg_val_loss: -0.0058


FOMAML:  82%|████████▏ | 41/50 [01:43<00:20,  2.25s/it]

epoch: 41 | avg_val_loss: -0.0058


FOMAML:  84%|████████▍ | 42/50 [01:46<00:18,  2.26s/it]

epoch: 42 | avg_val_loss: -0.0046


FOMAML:  86%|████████▌ | 43/50 [01:48<00:15,  2.22s/it]

epoch: 43 | avg_val_loss: -0.0025


FOMAML:  88%|████████▊ | 44/50 [01:50<00:13,  2.20s/it]

epoch: 44 | avg_val_loss: -0.0028


FOMAML:  90%|█████████ | 45/50 [01:52<00:10,  2.18s/it]

epoch: 45 | avg_val_loss: -0.0045


FOMAML:  92%|█████████▏| 46/50 [01:54<00:08,  2.18s/it]

epoch: 46 | avg_val_loss: -0.0058


FOMAML:  94%|█████████▍| 47/50 [01:56<00:06,  2.16s/it]

epoch: 47 | avg_val_loss: -0.0070


FOMAML:  96%|█████████▌| 48/50 [01:58<00:04,  2.16s/it]

epoch: 48 | avg_val_loss: -0.0023


FOMAML:  98%|█████████▊| 49/50 [02:01<00:02,  2.16s/it]

epoch: 49 | avg_val_loss: -0.0025


FOMAML: 100%|██████████| 50/50 [02:03<00:00,  2.47s/it]

epoch: 50 | avg_val_loss: -0.0038



