# **REPTILE**

In [1]:
import torch
torch.set_float32_matmul_precision('high')
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import gymnasium as gym

import numpy as np

from torch.utils.tensorboard import SummaryWriter


### **LOGGING**

In [None]:
writer = SummaryWriter(log_dir = './runs/Reptile')


### **DEVICE**

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')


Device: cuda


### **META - ENV**

In [None]:
class meta_env:
    
    def __init__(self, env_name):
        
        self.base_env = gym.make(env_name)
        
    def sample_task(self, num_tasks):
        
        self.tasks = []
        
        for _ in range(num_tasks):
            
            goal_position = np.random.uniform(0.45, 0.55)
            gravity = np.random.uniform(0.0025, 0.006)
            
            self.tasks.append((goal_position, gravity))
            
        return self.tasks
    
    def set_task(self, task):
        
        self.base_env.env.goal_position = task[0]
        self.base_env.env.gravity = task[1]
        
    def reset(self):
        
        obs = self.base_env.reset()
        
        if isinstance(obs, tuple):
            
            obs = obs[0]
            
        return obs
    
    def step(self, action_np):
        
        next_state, reward, termination, timeout, info = self.base_env.step(action_np)
        done = termination or timeout
        
        return next_state, reward, done, info
    
    def get_number(self):
        
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = self.base_env.action_space.high[0]
        reward_dim = 1
        
        return state_dim, action_dim, max_action, reward_dim
    
    def close(self):
        
        self.base_env.close()


### **SET UP**

In [None]:
META_ENV = meta_env('MountainCarContinuous-v0')

state_dim, action_dim, max_action, reward_dim = META_ENV.get_number()


### **HELPER**

In [None]:
def safe_tensor(x):
    
    return x if torch.is_tensor(x) else torch.tensor(x, dtype = torch.float32).to(device)


### **ASSEMBLY**

In [None]:
head_1 = 32
head_2 = 64
head_3 = 64
head_4 = 32

hidden_size = 32
hidden_size_2 = 64
hidden_size_3 = 128


### **HYPER X**

In [None]:
class hyper_x(nn.Module):
    
    def __init__(self, state_dim = state_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2, hidden_size_3 = hidden_size_3):
        super(hyper_x, self).__init__()
        
        # input dim
        
        input_dim = state_dim
        
        # hyper mlp
        
        self.hyper = nn.Sequential(
            
            nn.Linear(input_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU(),
            
            nn.Linear(hidden_size, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            nn.Linear(hidden_size_2, hidden_size_3),
            nn.LayerNorm(hidden_size_3),
            nn.SiLU(),
            
            nn.Linear(hidden_size_3, hidden_size_3),
            nn.LayerNorm(hidden_size_3),
            nn.SiLU(),
            
            nn.Linear(hidden_size_3, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            nn.Linear(hidden_size_2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU()
        )
                
    def forward(self, state):
        
        hyper = self.hyper(state)
        
        return hyper


### **POLICY**

In [None]:
class policy_net(nn.Module):
    
    def __init__(self, action_dim = action_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action):
        super(policy_net, self).__init__()
        
        # max action
        
        self.max_action = max_action
        
        # hyper
        
        self.hyper = hyper_x()
        
        # norm
        
        self.hyper_norm = nn.LayerNorm(head_1)
        
        # post mlp
        
        self.post_process = nn.Sequential(
            
            nn.Linear(head_1, head_2),
            nn.LayerNorm(head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
        )
        
        # mu and log std head
        
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)
        
        # add normalization
        
        self.apply(self.init_weights)
        
    def init_weights(self, m):
        
        if isinstance(m, nn.Linear):
            
            nn.init.xavier_normal_(m.weight)
            
            if m.bias is not None:
                
                nn.init.zeros_(m.bias)
                
    def forward(self, state, deterministic = False):
        
        # state -> hyper
        
        hyper = self.hyper(state)
        
        # hyper -> norm
        
        norm = self.hyper_norm(hyper)
        
        # norm -> post process
        
        post_process = self.post_process(norm)
        
        # mu and log std
        
        mu = self.mu(post_process)
        
        if deterministic:
            
            return mu
        
        log_std = self.log_std(post_process)
        log_std = torch.clamp(log_std, -20, 2)
        std = torch.exp(log_std)
        
        # reparameterization
        
        dist = torch.distributions.Normal(mu, std)
        z = dist.rsample()
        tanh_z = torch.tanh(z)
        action = tanh_z * self.max_action
        
        log_prob = dist.log_prob(z)
        squash = torch.log(1 - tanh_z.pow(2) + 1e-6)
        log_prob = log_prob - squash
        
        log_prob = log_prob.sum(dim = -1, keepdim = True)
        
        entropy = dist.entropy().sum(dim = -1).mean()
        
        return action, log_prob, entropy


### **CRITIC**

In [None]:
class critic_net(nn.Module):
    
    def __init__(self, action_dim = action_dim, state_dim = state_dim, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(critic_net, self).__init__()
        
        # input dim
        
        input_dim = state_dim + action_dim
        
        # pre process 
        
        self.pre_process = nn.Linear(input_dim, head_1)
        self.norm = nn.LayerNorm(head_1)
        
        # post process
        
        self.post_process = nn.Sequential(
            
            nn.Linear(head_1, head_2),
            nn.LayerNorm(head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
            
        )
        
        # critic head
        
        self.critic = nn.Linear(head_4, 1)
        
    def forward(self, state, action):
        
        # cat
        
        cat = torch.cat([state, action], dim = -1)
        
        # pre process
        
        pre = self.pre_process(cat)
        
        # post 
        
        post = self.post_process(pre)
        
        # critic
        
        q = self.critic(post)
        
        return q


### **SET UP**

In [None]:
# policy network

POLICY_NET = policy_net().to(device)

print(POLICY_NET)

print('-' * 200)

# critic network

CRITIC_NET = critic_net().to(device)

print(CRITIC_NET)


policy_net(
  (hyper): hyper_x(
    (hyper): Sequential(
      (0): Linear(in_features=2, out_features=32, bias=True)
      (1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (2): SiLU()
      (3): Linear(in_features=32, out_features=64, bias=True)
      (4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (5): SiLU()
      (6): Linear(in_features=64, out_features=128, bias=True)
      (7): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (8): SiLU()
      (9): Linear(in_features=128, out_features=128, bias=True)
      (10): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (11): SiLU()
      (12): Linear(in_features=128, out_features=64, bias=True)
      (13): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (14): SiLU()
      (15): Linear(in_features=64, out_features=32, bias=True)
      (16): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (17): SiLU()
    )
  )
  (hyper_norm): LayerNorm((32,), eps=1e-05, elementwise_

### **OPTIMIZER**

In [12]:
# lr

policy_lr = 1e-4

critic_lr = 3e-4

meta_lr = 0.005

T_max = 10

warmup_epochs = 5

# optimizer

OPTIMIZER = optim.AdamW([
    
    {'params': CRITIC_NET.parameters(), 'lr': critic_lr, 'weight_decay': 1e-6},
    {'params': POLICY_NET.parameters(), 'lr': policy_lr, 'weight_decay': 0}
])


# scheduler

warmup_scheduler = optim.lr_scheduler.LinearLR(OPTIMIZER, start_factor = 0.1, total_iters = warmup_epochs)
cosine_scheduler = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max = T_max - warmup_epochs, eta_min = 1e-5)


SCHEDULER = optim.lr_scheduler.SequentialLR(OPTIMIZER, [warmup_scheduler, cosine_scheduler], milestones = [warmup_epochs])


### **BUFFER**

In [None]:
class meta_buffer:
    
    def __init__(self):
        
        self.buffer = []
        
    def add(self, state, action, log_prob, reward, done, next_state, entropy):
        
        # add
        
        self.buffer.append({
            
            'states': safe_tensor(state),
            'actions': safe_tensor(action),
            'log_probs': safe_tensor(log_prob),
            'rewards': safe_tensor(reward),
            'dones': safe_tensor(done),
            'next_states': safe_tensor(next_state),
            'entropy': safe_tensor(entropy)
        })
        
    def safe_stack(self, x):
        
        return torch.stack(x).to(device)
        
    def sample(self):
        
        # stack
        
        states = [i['states'] for i in self.buffer]
        actions = [i['actions'] for i in self.buffer]
        log_probs = [i['log_probs'] for i in self.buffer]
        rewards = [i['rewards'] for i in self.buffer]
        dones = [i['dones'] for i in self.buffer]
        next_states = [i['next_states'] for i in self.buffer]
        entropy = [i['entropy'] for i in self.buffer]
        
        # batch
        
        batch = {
            
            'states': self.safe_stack(states),
            'actions': self.safe_stack(actions),
            'log_probs': self.safe_stack(log_probs),
            'rewards': self.safe_stack(rewards),
            'dones': self.safe_stack(dones),
            'next_states': self.safe_stack(next_states),
            'entropy': self.safe_stack(entropy)
            
        }
        
        return batch
    
    def clear(self):
        
        self.buffer.clear()


### **SET UP**

In [None]:
META_BUFFER = meta_buffer()


### **COMPUTE LOSS**

In [None]:
class loss_func:
    
    def __init__(self, gamma, gae_lam, entropy_coef, POLICY_NET = POLICY_NET, CRITIC_NET = CRITIC_NET, META_BUFFER = META_BUFFER):
        
        # hyper param
        
        self.gamma = gamma
        self.gae_lam = gae_lam
        self.entropy_coef = entropy_coef
        
        # network
        
        self.policy = POLICY_NET
        self.critic = CRITIC_NET
        
        # buffer
        
        self.buffer = META_BUFFER

    def compute_gae(self, rewards, dones, value, last_value):
        
        values = torch.cat([value, last_value], dim = 0)
        
        advantages = []
        gae = 0
        
        for step in reversed(range(len(rewards))):
            
            delta = rewards[step] + self.gamma * (1 - dones[step]) * values[step + 1] - values[step]
            gae = delta + self.gamma * self.gae_lam * (1 - dones[step]) * gae
            
            advantages.insert(0, gae)
            
        advantages = safe_tensor(advantages)
        
        returns = [adv + val for adv, val in zip(advantages, value)]
        
        returns = safe_tensor(returns)
        
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-7)
        
        return advantages, returns
    
    def critic_loss(self, value, returns):
        
        value = value.squeeze(1)
        
        returns = returns.detach()
        
        critic_loss = F.mse_loss(value, returns)
        
        return critic_loss
    
    def policy_loss(self, log_probs, advantages, entropy):
        
        advantages = advantages.detach()
        
        policy_loss = (- log_probs * advantages).mean() - self.entropy_coef * entropy.mean()
        
        return policy_loss

    def compute_loss(self):
        
        # sample
        
        batch = self.buffer.sample()
        
        # unpack
        
        states = batch['states']
        actions = batch['actions']
        log_probs = batch['log_probs']
        rewards = batch['rewards']
        dones = batch['dones']
        next_states = batch['next_states']
        entropy = batch['entropy']
        
        # shape check
        
        rewards = rewards.view(-1, 1)
        dones = dones.view(-1, 1)
        
        # get value from critic through policy action
        
        value = self.critic.forward(states, actions)    
        
        # get last value from critic through next state
        
        with torch.no_grad():
            
            last_state = next_states[-1:]
            
            last_action, _, _ = self.policy.forward(last_state)
            
            last_value = self.critic.forward(last_state, last_action)
            
        # compute gae
        
        advantages, returns = self.compute_gae(rewards, dones, value, last_value)
        
        # compute critic loss
        
        critic_loss = self.critic_loss(value, returns)
        
        # compute policy loss
        
        policy_loss = self.policy_loss(log_probs, advantages, entropy)
        
        total_loss = policy_loss + critic_loss
        
        return total_loss, policy_loss, critic_loss


### **SET UP**

In [None]:
# hyper param

gamma = 0.99
gae_lam = 0.95
entropy_coef = 0.01

# setup

LOSS_FUNCTION = loss_func(gamma, gae_lam, entropy_coef)


### **HELPER 2**

In [None]:
def get_current_params(agent):
    
    return [p.clone().detach() for p in agent.parameters()]


In [18]:
def meta_update(meta_params, adapted_params, meta_lr):
    
    for param, adapted in zip(meta_params, adapted_params):
        
        param.data = param.data + meta_lr * (adapted.data - param.data)


### **TRAINING**

In [19]:
def TRAINING_LOOP(meta_iteration, inner_steps, batch_size, meta_lr, META_ENV = META_ENV, POLICY_NET = POLICY_NET, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER, LOSS_FUNCTION = LOSS_FUNCTION):
    
    for iteration in range(meta_iteration):
        
        meta_params = get_current_params(POLICY_NET)
        
        total_agent_loss, total_policy_loss, total_critic_loss = 0.0, 0.0, 0.0
        
        for task in META_ENV.sample_task(batch_size):
        
            META_ENV.set_task(task)
            
            obs = META_ENV.reset()
            
            obs = safe_tensor(obs).unsqueeze(0)
            
            META_BUFFER.clear()
            
            # inner loop
            
            for step in range(inner_steps):
                
                action, log_prob, entropy = POLICY_NET.forward(obs)
                
                action_np = action.detach().cpu().numpy()[0]
                
                next_state, reward, done, _ = META_ENV.step(action_np)
                
                next_state = safe_tensor(next_state).unsqueeze(0)
                
                META_BUFFER.add(obs.squeeze(0), action.squeeze(0), log_prob.squeeze(0), reward, done, next_state.squeeze(0), entropy)
                
                obs = next_state
                
                if done:
                    
                    break
                
                
            # compute loss
            
            agent_loss, policy_loss, critic_loss = LOSS_FUNCTION.compute_loss()
            
            OPTIMIZER.zero_grad()
            agent_loss.backward()
            torch.nn.utils.clip_grad_norm_(POLICY_NET.parameters(), max_norm = 0.5)
            torch.nn.utils.clip_grad_norm_(CRITIC_NET.parameters(), max_norm = 0.5)
            OPTIMIZER.step()
            SCHEDULER.step()
            
            total_agent_loss += agent_loss.item()
            total_policy_loss += policy_loss.item()
            total_critic_loss += critic_loss.item()
            
        avg_agent_loss = total_agent_loss / batch_size
        avg_policy_loss = total_policy_loss / batch_size
        avg_critic_loss = total_critic_loss / batch_size
        
        writer.add_scalar('Agent loss', avg_agent_loss, iteration)
        writer.add_scalar('Policy loss', policy_loss, iteration)
        writer.add_scalar('Critic loss', critic_loss, iteration)
        
        adapted_params = get_current_params(POLICY_NET)
        meta_update(meta_params, adapted_params, meta_lr)
        
        print(f'epoch: {iteration} | agent loss: {avg_agent_loss:.3f} | policy loss: {avg_policy_loss:.3f} | critic loss: {avg_critic_loss:.3f}')
        

### **SET UP**

In [None]:
# param

meta_iteration = 10
inner_steps = 64
batch_size = 256

TRAINING_LOOP(meta_iteration, inner_steps, batch_size, meta_lr)




epoch: 0 | agent loss: 1.142 | policy loss: -0.016 | critic loss: 1.158
epoch: 1 | agent loss: 0.933 | policy loss: -0.019 | critic loss: 0.952
epoch: 2 | agent loss: 0.607 | policy loss: -0.024 | critic loss: 0.632
epoch: 3 | agent loss: 0.311 | policy loss: -0.034 | critic loss: 0.345
epoch: 4 | agent loss: 0.139 | policy loss: -0.034 | critic loss: 0.173
epoch: 5 | agent loss: 0.026 | policy loss: -0.034 | critic loss: 0.060
epoch: 6 | agent loss: -0.026 | policy loss: -0.034 | critic loss: 0.008
epoch: 7 | agent loss: -0.030 | policy loss: -0.034 | critic loss: 0.004
epoch: 8 | agent loss: -0.030 | policy loss: -0.034 | critic loss: 0.004
epoch: 9 | agent loss: -0.030 | policy loss: -0.034 | critic loss: 0.004
