# Contextual Cross Attention Meta-RL

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import gymnasium as gym
import copy
import random

from torch.utils.tensorboard import SummaryWriter

import warnings


### L O G G I N G

In [2]:
writer = SummaryWriter(log_dir = './runs/CTX')


### **PREFERANCES**

In [3]:

warnings.filterwarnings('ignore', category = UserWarning)


### D E V I C E 

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device : {device}')


Device : cuda


### M E T A - E N V


In [5]:
class Meta_env:
    
    def __init__(self):
        
        self.base_env = gym.make('Walker2d-v5')
        self.target_velocity = 1.0
        
    def sample_tasks(self, num_tasks):
        
        self.tasks = []
        
        for _ in range(num_tasks):
            
            gravity = np.random.uniform(5.0, 15.0)
            torso_mass = np.random.uniform(1.0, 5.0)
            target_velocity = np.random.uniform(0.5, 3.0)
            
            self.tasks.append((gravity, torso_mass, target_velocity))
        
        return self.tasks
    
    def set_task(self, task):
        
        raw_env = self.base_env.unwrapped
        raw_env.model.opt.gravity[-1] = -task[0]
        raw_env.model.body_mass[1] = task[1]
        self.target_velocity = task[2]
        
    def reset(self):
        
        obs, _ = self.base_env.reset()
        
        return obs
    
    def step(self, action):
        
        next_obs, reward, termination, timeouts, info = self.base_env.step(action)

        done = termination | timeouts
        raw_env = self.base_env.unwrapped
        vel = raw_env.data.qvel[0]
        reward -= 0.5 * abs(vel - self.target_velocity)
        
        return next_obs, reward, done
    
    def get_number(self):
        
        state_dim = self.base_env.observation_space.shape[0]
        action_dim = self.base_env.action_space.shape[0]
        max_action = self.base_env.action_space.high[0]
        reward_dim = 1
        
        return state_dim, action_dim, max_action, reward_dim
    
    def close(self):
        
        self.base_env.close()


### S E T U P 

In [6]:
env = Meta_env()

obs = env.reset()

state_dim, action_dim, max_action, reward_dim = env.get_number()

print(f'state shape: {obs.shape}')
print(f'state dim: {state_dim} | action dim: {action_dim} | max action: {max_action}')



state shape: (17,)
state dim: 17 | action dim: 6 | max action: 1.0


### A S S E M B L Y


In [7]:
head_1 = 64
head_2 = 128
head_3 = 128
head_4 = 64

embed_dim = 64

hidden_size = 128
hidden_size_2 = 256


### H E L P E R 

In [8]:
def safe_tensor(x):
    
    return x if torch.is_tensor(x) else torch.tensor(x, dtype = torch.float32).to(device)


### C O N T E X T - E N C O D E R 

In [9]:
class context_encoder(nn.Module):
    
    def __init__(self, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, state_dim = state_dim, action_dim = action_dim, reward_dim = reward_dim, embed_dim = embed_dim):
        super(context_encoder, self).__init__()
        
        input_dim = 2 * state_dim + action_dim + reward_dim
        
        self.encode = nn.Sequential(
            
            nn.Linear(input_dim, head_1),
            nn.LayerNorm(head_1),
            nn.SiLU(),
            
            nn.Linear(head_1, head_2),
            nn.LayerNorm(head_2),
            nn.SiLU(),
            
            nn.Linear(head_2, head_3),
            nn.LayerNorm(head_3),
            nn.SiLU(),
            
            nn.Linear(head_3, head_4),
            nn.LayerNorm(head_4),
            nn.SiLU()
        )
        
        self.out = nn.Linear(head_4, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, state, action, reward, next_state):
        
        cat = torch.cat([state, action, reward, next_state], dim = -1)
        
        encode = self.encode(cat)
    
        out = F.silu(self.norm(self.out(encode)))
        
        return out        
    

### S E T U P 

In [10]:
CONTEXT_ENCODER = context_encoder().to(device)

print(CONTEXT_ENCODER)


context_encoder(
  (encode): Sequential(
    (0): Linear(in_features=41, out_features=64, bias=True)
    (1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=64, out_features=128, bias=True)
    (4): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
    (7): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (8): SiLU()
    (9): Linear(in_features=128, out_features=64, bias=True)
    (10): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (11): SiLU()
  )
  (out): Linear(in_features=64, out_features=64, bias=True)
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)


### H Y P E R - X

In [11]:
class hyper_x(nn.Module):
    
    def __init__(self, embed_dim = embed_dim, hidden_size = hidden_size, hidden_size_2 = hidden_size_2):
        super(hyper_x, self).__init__()
        
        self.hype = nn.Sequential(
            
            nn.Linear(embed_dim, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU(),
            
            nn.Linear(hidden_size, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            nn.Linear(hidden_size_2, hidden_size_2),
            nn.LayerNorm(hidden_size_2),
            nn.SiLU(),
            
            nn.Linear(hidden_size_2, hidden_size),
            nn.LayerNorm(hidden_size),
            nn.SiLU()
        )
        
        self.hype_out = nn.Linear(hidden_size, embed_dim)
        self.norm = nn.LayerNorm(embed_dim)
        
    def forward(self, encoded_z):
        
        hype = self.hype(encoded_z)
        
        out = F.silu(self.norm(self.hype_out(hype)))
        
        return out


### S E T U P 

In [12]:
HYPER_X = hyper_x().to(device)

print(HYPER_X)


hyper_x(
  (hype): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (2): SiLU()
    (3): Linear(in_features=128, out_features=256, bias=True)
    (4): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (5): SiLU()
    (6): Linear(in_features=256, out_features=256, bias=True)
    (7): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    (8): SiLU()
    (9): Linear(in_features=256, out_features=128, bias=True)
    (10): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (11): SiLU()
  )
  (hype_out): Linear(in_features=128, out_features=64, bias=True)
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
)


### C R O S S - A T T E N T I O N

In [13]:
class cross_attention(nn.Module):
    
    def __init__(self, state_dim = state_dim, embed_dim = embed_dim, hidden_size = hidden_size):
        super(cross_attention, self).__init__()
        
        # Q
        
        self.query = nn.Linear(state_dim, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)
        
        # K 
        
        self.key = nn.Linear(embed_dim, hidden_size)
        self.normk = nn.LayerNorm(hidden_size)
        
        # V 
        
        self.value = nn.Linear(embed_dim, hidden_size)
        self.normv = nn.LayerNorm(hidden_size)
        
        # attn out
        
        self.out = nn.Linear(hidden_size, hidden_size)
        self.norm_out = nn.LayerNorm(hidden_size)
        
    def forward(self, query, encoded_z):
        
        query = query.unsqueeze(0)    # [batch, 1, dim] as encoded is [ batch, context, dim ]
        
        Q = F.silu(self.norm(self.query(query)))         
        K = F.silu(self.normk(self.key(encoded_z)))            
        V = F.silu(self.normv(self.value(encoded_z)))
        
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (K.size(-1) ** 0.5)
        attn_weights = F.softmax(attn_scores, dim = -1)
        
        attn_out = torch.matmul(attn_weights, V)
        
        attn_out = attn_out.squeeze(1)              # [ batch, dim ]
        
        out = F.silu(self.norm_out(self.out(attn_out)))
        
        return out        
        

### S E T U P

In [14]:
CROSS_ATTENTION = cross_attention().to(device)

print(CROSS_ATTENTION)


cross_attention(
  (query): Linear(in_features=17, out_features=128, bias=True)
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (key): Linear(in_features=64, out_features=128, bias=True)
  (normk): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (value): Linear(in_features=64, out_features=128, bias=True)
  (normv): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (out): Linear(in_features=128, out_features=128, bias=True)
  (norm_out): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)


### A C T O R 

In [15]:
class actor_net(nn.Module):
    
    def __init__(self, state_dim = state_dim, hidden_size = hidden_size, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4, max_action = max_action):
        super(actor_net, self).__init__()
        
        # first layer
        
        input_dim = state_dim + hidden_size
        
        self.fc1 = nn.Linear(input_dim, head_1)
        self.norm = nn.LayerNorm(head_1)
        
        # parchio
        
        self.mod = nn.Linear(hidden_size, head_2)
        self.mod_norm = nn.LayerNorm(head_2)
        
        # second layer
        
        self.fc2 = nn.Linear(head_1, head_2)
        self.norm2 = nn.LayerNorm(head_2)
        
        # third layer
        
        self.fc3 = nn.Linear(head_2, head_3)
        self.norm3 = nn.LayerNorm(head_3)
        
        # forth layer
        
        self.fc4 = nn.Linear(head_3, head_4)
        self.norm4 = nn.LayerNorm(head_4)
        
        # mu and log std head
        
        self.mu = nn.Linear(head_4, action_dim)
        self.log_std = nn.Linear(head_4, action_dim)
        
        # max action
        
        self.max_action = max_action
        
    def forward(self, state, attn_out):
        
        cat = torch.cat([state, attn_out], dim = -1)
        
        x = F.silu(self.norm(self.fc1(cat)))
        
        x = F.silu(self.norm2(self.fc2(x)))
        
        mod = F.silu(self.mod_norm(self.mod(attn_out)))
                
        x = x + mod
        
        x = F.silu(self.norm3(self.fc3(x)))
        
        x = F.silu(self.norm4(self.fc4(x)))
        
        # mu and log std
        
        mu = self.mu(x)
        
        action = torch.tanh(mu) * self.max_action
        
        return action   


### S E T U P 

In [16]:
ACTOR_NETWORK = actor_net().to(device)

TARGET_ACTOR = copy.deepcopy(ACTOR_NETWORK).to(device)

print(ACTOR_NETWORK)


actor_net(
  (fc1): Linear(in_features=145, out_features=64, bias=True)
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (mod): Linear(in_features=128, out_features=128, bias=True)
  (mod_norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc2): Linear(in_features=64, out_features=128, bias=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc3): Linear(in_features=128, out_features=128, bias=True)
  (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (norm4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (mu): Linear(in_features=64, out_features=6, bias=True)
  (log_std): Linear(in_features=64, out_features=6, bias=True)
)


### C R I T I C 

In [17]:
class critic_net(nn.Module):
    
    def __init__(self, state_dim = state_dim, action_dim = action_dim, hidden_size = hidden_size, head_1 = head_1, head_2 = head_2, head_3 = head_3, head_4 = head_4):
        super(critic_net, self).__init__()
        
        # first layer
        
        input_dim = state_dim + action_dim
        
        self.fc1 = nn.Linear(input_dim, head_1)
        self.norm = nn.LayerNorm(head_1)
        
        # second layer
        
        self.fc2 = nn.Linear(head_1, head_2)
        self.norm2 = nn.LayerNorm(head_2)
        
        # modulation
        
        self.mod = nn.Linear(hidden_size, head_2)
        self.norm_mod = nn.LayerNorm(head_2)
        
        # third layer
        
        self.fc3 = nn.Linear(head_2, head_3)
        self.norm3 = nn.LayerNorm(head_3)
        
        # forth layer
        
        self.fc4 = nn.Linear(head_3, head_4)
        self.norm4 = nn.LayerNorm(head_4)
        
        # critic head
        
        self.critic_head = nn.Linear(head_4, 1)
        self.critic_head_2 = nn.Linear(head_4, 1)
        
    def forward(self, state, action, attn_out):
        
        # cat
        
        cat = torch.cat([state, action], dim = -1)
        
        x = F.silu(self.norm(self.fc1(cat)))
        
        x = F.silu(self.norm2(self.fc2(x)))
        
        # modulation
        
        mod = F.silu(self.norm_mod(self.mod(attn_out)))
        x = x + mod
        
        # continue non linearity
        
        x = F.silu(self.norm3(self.fc3(x)))
        
        x = F.silu(self.norm4(self.fc4(x)))
        
        # critic head
        
        q1 = self.critic_head(x)
        q2 = self.critic_head_2(x)
        
        return q1, q2        


### S E T U P

In [18]:
CRITIC_NETWORK = critic_net().to(device)

TARGET_CRITIC = copy.deepcopy(CRITIC_NETWORK).to(device)

print(CRITIC_NETWORK)


critic_net(
  (fc1): Linear(in_features=23, out_features=64, bias=True)
  (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (fc2): Linear(in_features=64, out_features=128, bias=True)
  (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (mod): Linear(in_features=128, out_features=128, bias=True)
  (norm_mod): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc3): Linear(in_features=128, out_features=128, bias=True)
  (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (fc4): Linear(in_features=128, out_features=64, bias=True)
  (norm4): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
  (critic_head): Linear(in_features=64, out_features=1, bias=True)
  (critic_head_2): Linear(in_features=64, out_features=1, bias=True)
)


### O P T I M I Z E R

In [19]:
# lr

context_lr = 3e-5
critic_lr = 1e-4
actor_lr = 3e-4
hyper_x_lr = 3e-5
cross_attention_lr = 1e-5

T_max = 2800

# param

actor_params = ACTOR_NETWORK.parameters()
critic_params = CRITIC_NETWORK.parameters()
context_params = CONTEXT_ENCODER.parameters()
hyper_x_params = HYPER_X.parameters()
cross_attention_params = CROSS_ATTENTION.parameters()

# optimizer

CRITIC_OPTIMIZER = optim.AdamW([
    {'params': critic_params, 'lr': critic_lr, 'weight_decay': 1e-5}
])

OPTIMIZER = optim.AdamW([
    
    {'params': context_params, 'lr': context_lr, 'weight_decay': 1e-5},
    
    {'params': hyper_x_params, 'lr': hyper_x_lr, 'weight_decay': 1e-5},
    
    {'params': cross_attention_params, 'lr': cross_attention_lr, 'weight_decay': 1e-5},
    
    {'params': actor_params, 'lr': actor_lr, 'weight_decay': 0}  
    
])

# scheduler

CRITIC_SCHEDULER = optim.lr_scheduler.CosineAnnealingLR(
    CRITIC_OPTIMIZER,
    T_max=200,
    eta_min=1e-5
)


SCHEDULER = optim.lr_scheduler.CosineAnnealingLR(OPTIMIZER, T_max, eta_min = 1e-5)


### B U F F E R

In [20]:
class meta_buffer:
    
    def __init__(self, max_episodes):
        
        self.max_episodes = max_episodes
        self.current_episode = []
        self.episodes = []
        
    def add(self, state, action, reward, done, next_state):
        
        # safe tensor all
        
        state = safe_tensor(state)
        action = safe_tensor(action)
        reward = safe_tensor(reward)
        done = safe_tensor(done)
        next_state = safe_tensor(next_state)
        
        # add 
        
        self.current_episode.append({
            
            'states': state,
            'actions': action,
            'rewards': reward,
            'dones': done,
            'next_states': next_state
        })
        
        if done.item()  == 1:
            
            self.episodes.append(self.current_episode)
            self.current_episode = []
            
        if len(self.episodes) > self.max_episodes:
            
            for _ in range(5):
            
                self.episodes.pop(0)
            
    def sample(self, batch_size, fixed_length):
        
        sampled_ep = random.sample(self.episodes, k=min(batch_size, len(self.episodes)))
        
        segments = []
        masks = []
        
        for ep in sampled_ep:
            
            if len(ep) >= fixed_length:
                
                seg = ep[:fixed_length]
                mask = torch.ones(fixed_length, dtype = torch.float32)
                
            else:
                
                seg, mask = self.pad_episode(ep, fixed_length)
                
            segments.append(seg)
            masks.append(mask)
            
            
        def stack_field(x):
        
            return torch.stack([torch.stack([step[x] for step in seg]) for seg in segments])
        
        batch = {
            
            'states': stack_field('states'),
            'actions': stack_field('actions'),
            'rewards': stack_field('rewards'),
            'dones': stack_field('dones'),
            'next_states': stack_field('next_states'),
            'masks': torch.stack(masks).to(device)
            
        }
        
        for k in batch:
            
            batch[k] = batch[k].to(device)
        
        return batch
            
    def pad_episode(self, ep, fixed_length):
        
        pad_length = fixed_length - len(ep)
        
        last_step = ep[-1]
        
        pad_step = {}
        
        for k, v in last_step.items():
            
            if torch.is_tensor(v):
                
                pad_step[k] = v.clone()
                
            else:
                
                pad_step[k] = torch.zeros_like(v)
                
        
        mask = torch.cat([
            
          torch.ones(len(ep), dtype = torch.float32),
          torch.zeros(pad_length, dtype = torch.float32)  
            
        ])
        
        return ep + [pad_step] * pad_length, mask
    
    def clear(self):
        
        self.episodes.clear()
                

### S E T U P

In [21]:
max_episodes = 300

buffer = meta_buffer(max_episodes)


### M E T A - R U N N E R 

In [22]:
class meta_runner:
    
    def __init__(self, max_episode_length, env = env, buffer = buffer, CONTEXT_ENCODER = CONTEXT_ENCODER, HYPER_X = HYPER_X, CROSS_ATTENTION = CROSS_ATTENTION, ACTOR_NETWORK = ACTOR_NETWORK):
    
        # network
        
        self.encoder = CONTEXT_ENCODER
        self.hyper = HYPER_X
        self.attention = CROSS_ATTENTION
        self.actor =  ACTOR_NETWORK
        
        # buffer
        
        self.buffer = buffer
        
        # episodes
        
        self.max_ep_length = max_episode_length
        
        # env
                
        self.env = env
        
    def run(self, num_tasks):
        
        tasks = self.env.sample_tasks(num_tasks)
        
        for task in tasks:
        
            self.env.set_task(task)
            obs = self.env.reset()
            obs = safe_tensor(obs).unsqueeze(0)
            
            context = []
            
            for _ in range(self.max_ep_length):
                
                latent_z = self.get_latent(context)    
                
                hyper = self.hyper.forward(latent_z)
                    
                attn_out = self.attention.forward(obs, hyper)
                    
                action = self.actor.forward(obs, attn_out)
                    
                action_np = action.detach().cpu().numpy()[0]
                    
                next_state, reward, done = self.env.step(action_np)
                    
                next_state = safe_tensor(next_state).unsqueeze(0)
                reward = safe_tensor([reward]).unsqueeze(0)
                    
                self.buffer.add(obs.squeeze(0), action.squeeze(0), reward.squeeze(0), float(done), next_state.squeeze(0))
                    
                context.append({
                    
                    'states': obs,
                    'actions': action,
                    'rewards': reward,
                    'next_states': next_state

                })    
                    
                obs = next_state
                    
                if done:
                        
                    break
                    
    def build_context(self, context):
        
        return {
            
            'states': torch.cat([e['states'] for e in context], dim = 0),
            'actions': torch.cat([e['actions'] for e in context], dim = 0),
            'rewards': torch.cat([e['rewards'] for e in context], dim = 0),
            'next_states': torch.cat([e['next_states'] for e in context], dim = 0)
        }
        
    def get_latent(self, context):
        
        if len(context) != 0 :
            
            ctx = self.build_context(context)
            latent_z = self.encoder.forward(ctx['states'], ctx['actions'], ctx['rewards'], ctx['next_states'])
            
        else:
            
            latent_z = torch.zeros((1, embed_dim)).to(device)
        
        return latent_z
        


### S E T U P 

In [23]:
META_RUNNER = meta_runner(max_episode_length = 512)


### L O S S - F U N C 

In [24]:
class loss_func:
    
    def __init__(self, tau, gamma, noise_clip, policy_noise, policy_delay, max_action = max_action, ACTOR_NETWORK = ACTOR_NETWORK, CRITIC_NETWORK = CRITIC_NETWORK, HYPER_X = HYPER_X, CROSS_ATTENTION = CROSS_ATTENTION, CONTEXT_ENCODER = CONTEXT_ENCODER, OPTIMIZER = OPTIMIZER, SCHEDULER = SCHEDULER, buffer = buffer, CRITIC_OPTIMIZER = CRITIC_OPTIMIZER, CRITIC_SCHEDULER = CRITIC_SCHEDULER, TARGET_ACTOR = TARGET_ACTOR, TARGET_CRITIC = TARGET_CRITIC):
        
        # networks
        
        self.actor = ACTOR_NETWORK
        self.critic = CRITIC_NETWORK
        self.hyper = HYPER_X
        self.attention = CROSS_ATTENTION
        self.encoder = CONTEXT_ENCODER
        
        self.target_actor = TARGET_ACTOR
        self.target_critic = TARGET_CRITIC
        
        # buffer
        
        self.buffer = buffer
        
        # optimizer
        
        self.optimizer = OPTIMIZER
        self.schduler = SCHEDULER
        self.critic_opt = CRITIC_OPTIMIZER
        self.critic_sch = CRITIC_SCHEDULER
        
        # hyper params
        
        self.step = 0
        self.actor_update = 0
        self.tau = tau
        self.gamma = gamma
        self.noise_clip = noise_clip
        self.policy_noise = policy_noise
        self.max_action = max_action
        self.policy_delay = policy_delay
        
    def soft_update(self, source, target):
        
        with torch.no_grad():        
            
            for param, target_param in zip(source.parameters(), target.parameters()):
                
                target_param.data.copy_(param * self.tau + target_param * (1 - self.tau))
                
    def critic_loss(self, states, actions, rewards, dones, next_states, hyper, attn_out):
        
        with torch.no_grad():
            
            next_attn = self.attention.forward(next_states, hyper)
            next_attn = next_attn.squeeze(0)
            next_action = self.target_actor.forward(next_states, next_attn) 
            noise = (torch.randn_like(next_action) * self.policy_noise).clamp(-self.noise_clip, self.noise_clip)
            next_action = (next_action + noise).clamp(-self.max_action, self.max_action)
            target_1, target_2 = self.target_critic.forward(next_states, next_action, next_attn)
            target_val = torch.min(target_1, target_2)
            target_q = rewards + self.gamma * (1 - dones) * target_val
        
        c1, c2 = self.critic.forward(states, actions.detach(), attn_out.detach())
        
        loss = F.mse_loss(c1, target_q) + F.mse_loss(c2, target_q)
        
        return loss
    
    def actor_loss(self, states, attn_out):
        
        new_action = self.actor.forward(states, attn_out.detach())
        q1, _ = self.critic.forward(states, new_action.detach(), attn_out.detach())
        
        actor_loss = -q1.mean()
        
        return actor_loss
    
    def update(self, batch_size, fixed_length):
        
        self.step += 1

        batch = self.buffer.sample(batch_size, fixed_length)

        states = batch['states']
        actions = batch['actions']
        rewards = batch['rewards']
        dones = batch['dones'].unsqueeze(2)
        next_states = batch['next_states']
        
        # Encoder
        
        latent_z_critic = self.encoder.forward(states, actions, rewards, next_states)
        
        # hyper
        
        hyper_critic = self.hyper.forward(latent_z_critic)
        
        # attention
        
        attn_critic = self.attention.forward(states, latent_z_critic.detach())
        
        attn_critic = attn_critic.squeeze(0)
        
        # critic loss
        
        critic_loss = self.critic_loss(states, actions, rewards, dones, next_states, hyper_critic, attn_critic)
        
        # optimize critic
        
        self.critic_opt.zero_grad()
        critic_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm = 0.5)
        self.critic_opt.step()
        self.schduler.step()
        
        # actor loss
        
        actor_loss = torch.tensor(0.0)
        
        if self.step % self.policy_delay == 0:
            
            latent_Z_actor = self.encoder.forward(states, actions, rewards, next_states)
            
            hyper_Actor = self.hyper.forward(latent_Z_actor)
            
            attn_out = self.attention.forward(states, hyper_Actor.detach())
            
            attn_out = attn_out.squeeze(0)
            
            actor_loss = self.actor_loss(states, attn_out.detach())
            
            # optimize
            
            self.optimizer.zero_grad()
            actor_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm = 0.5)
            torch.nn.utils.clip_grad_norm_(self.hyper.parameters(), max_norm = 0.5)
            torch.nn.utils.clip_grad_norm_(self.attention.parameters(), max_norm = 0.5)
            torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), max_norm = 0.5)
            self.optimizer.step()
            self.schduler.step()
            
            # soft update
            
            self.soft_update(self.actor, self.target_actor)
            
        self.soft_update(self.critic, self.target_critic)
        
        return actor_loss, critic_loss, self.step
            

### S E T U P

In [25]:
# hyper params

tau = 0.005
gamma = 0.99
noise_clip = 0.3
policy_noise = 0.2
policy_delay = 2


# setup

LOSS_FUNCTION = loss_func(tau, gamma, noise_clip, policy_noise, policy_delay = policy_delay)


### T R A I N I N G

In [26]:
def train(epochs = 20, mini_batch = 64, num_tasks = 10, batch_size = 512, fixed_length = 512, buffer = buffer):


    total_actor_loss, total_critic_loss = 0.0, 0.0

    for epoch in range(epochs):
        
        total_actor_loss, total_critic_loss = 0.0, 0.0
        
        LOSS_FUNCTION.step = 0
        
        for _ in range(mini_batch):
            
            buffer.clear()
            
            META_RUNNER.run(num_tasks)
            
            actor_loss, critic_loss, actor_update = LOSS_FUNCTION.update(batch_size, fixed_length)
            
            total_actor_loss += actor_loss.item()
            total_critic_loss += critic_loss.item()
            
        avg_actor_loss = total_actor_loss / actor_update
        avg_critic_loss = total_critic_loss / mini_batch
        
        writer.add_scalar('Actor loss', avg_actor_loss, epoch)
        writer.add_scalar('Critic loss', avg_critic_loss, epoch)
        
        writer.flush()
        
        print(f'epoch: {epoch} | avg actor loss: {avg_actor_loss:.3f} | avg critic loss: {avg_critic_loss:.3f}')


In [27]:
train()


epoch: 0 | avg actor loss: 0.457 | avg critic loss: 6.416
epoch: 1 | avg actor loss: 0.744 | avg critic loss: 2.297
epoch: 2 | avg actor loss: 0.860 | avg critic loss: 1.715
epoch: 3 | avg actor loss: 0.967 | avg critic loss: 1.157
epoch: 4 | avg actor loss: 0.909 | avg critic loss: 0.759
epoch: 5 | avg actor loss: 0.578 | avg critic loss: 0.527
epoch: 6 | avg actor loss: 0.426 | avg critic loss: 0.414
epoch: 7 | avg actor loss: 0.351 | avg critic loss: 0.311
epoch: 8 | avg actor loss: 0.367 | avg critic loss: 0.288
epoch: 9 | avg actor loss: 0.333 | avg critic loss: 0.234
epoch: 10 | avg actor loss: 0.373 | avg critic loss: 0.203
epoch: 11 | avg actor loss: 0.302 | avg critic loss: 0.152
epoch: 12 | avg actor loss: 0.242 | avg critic loss: 0.148
epoch: 13 | avg actor loss: 0.189 | avg critic loss: 0.113
epoch: 14 | avg actor loss: 0.116 | avg critic loss: 0.141
epoch: 15 | avg actor loss: 0.049 | avg critic loss: 0.089
epoch: 16 | avg actor loss: 0.005 | avg critic loss: 0.080
epoch: 