In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count, chain
import tqdm
import copy

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.autograd import Variable


from nosaveddata.nsd_utils.save_hypers import Hypers, nsd_Module
from nosaveddata.nsd_utils.nsd_csv import add_to_csv
from nosaveddata.nsd_utils.networks import params_count, params_and_grad_norm, seed_np_torch
from nosaveddata.nsd_utils.einstein import Rearrange
from nosaveddata.nsd_utils.dreamer import symlog, symexp, two_hot, two_hot_view, two_hot_view_no_symlog, ReturnsNormalizer

from nosaveddata.builders.mlp import *
from nosaveddata.builders.weight_init import *
from nosaveddata.builders.resnet import IMPALA_Resnet, DQN_Conv, IMPALA_YY, Residual_Block

from utils.experience_replay import *


import locale
# Force the reported preferred encoding to UTF-8 by monkey-patching the locale module
# (presumably a workaround for notebook hosts whose locale is not UTF-8 — confirm).
locale.getpreferredencoding = lambda: "UTF-8"

import wandb


# Environment configuration
#env_name = 'Kangaroo'
#SEED = 8712

env_name = 'Assault'
SEED = 7783


wandb.init(
    project="Atari-100k-Efficient Zero",
    name=f"EffZ-{env_name}",

    #id='rotdmtc5',
    #resume='must',

    config={
        # NOTE(review): this logged value matches neither `lr` below nor the SGD learning
        # rate set in the optimizer cell; it is metadata only — confirm what to record.
        "learning_rate": 1e-4,
        "architecture": "Efficient Zero",
        # Log the environment actually trained on (previously hard-coded to "Assault",
        # which went stale whenever `env_name` was switched above).
        "dataset": env_name,
        "epochs": 100,
    },

    reinit=False
)


# Optimization
batch_size = 256
lr=3e-4

eps=1e-8


# Target network EMA rate (used as the default decay of `target_model_ema`)
#critic_ema_decay=0.995
critic_ema_decay=0.98
#critic_ema_decay=0.95


# Return function: discounts are stored as log(1 - gamma); the optimize step
# interpolates between them in that space and maps back with 1 - exp(.).
initial_gamma=torch.tensor(1-0.97).log()
final_gamma=torch.tensor(1-0.997).log()

#initial_n = 10
#final_n = 3

initial_n = 5
final_n = 5

# Number of categorical bins (matches the default n_atoms of the model heads).
num_buckets=51

# Search budget: simulations per MCTS call and number of root candidate actions.
n_sim = 16
topk_actions = 8


# Reset Schedule and Buffer
#reset_every=40000 # grad steps, not steps.
reset_every=100000 # grad steps, not steps.
schedule_max_step=reset_every//4
total_steps=102000

prefetch_cap=4 # actually, no prefetch is being done



Transition = namedtuple('Transition',
                        ('state', 'reward', 'action', 'c_flag'))
# NOTE(review): the first positional argument is presumably the buffer capacity, while
# `total_steps=` here receives the schedule length, not `total_steps` — confirm.
memory = PrioritizedReplay_nSteps_Sqrt(total_steps+5, total_steps=schedule_max_step, prefetch_cap=prefetch_cap, alpha=1, beta=1)





returns_normalizer = ReturnsNormalizer(0.99)


def save_checkpoint(net, model_target, optimizer, step, path):
    """Serialize online/target model weights, optimizer state and step to ``path``.

    The checkpoint is a dict with keys ``model_state_dict``,
    ``model_target_state_dict``, ``optimizer_state_dict`` and ``step``.
    """
    checkpoint = dict()
    checkpoint['model_state_dict'] = net.state_dict()
    checkpoint['model_target_state_dict'] = model_target.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['step'] = step
    torch.save(checkpoint, path)


  torch.utils._pytree._register_pytree_node(
[34m[1mwandb[0m: Currently logged in as: [33msnykralafk[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…

In [2]:
# Adapted from: https://github.com/weipu-zhang/STORM/blob/main/env_wrapper.py
class MaxLast2FrameSkipWrapper(Hypers, gym.Wrapper):
    """Repeat each action ``skip`` times and max-pool the last two raw frames.

    ``Hypers`` presumably stores the constructor arguments as attributes: the
    methods below read ``self.skip``, ``self.noops`` and ``self.seed``.

    Args:
        env: environment to wrap (used here with a gym vector env).
        skip: number of times each action is repeated.
        noops: upper bound (inclusive) on the random number of NOOPs in ``noop_steps``.
        seed: seed for the action space and for the first ``reset``.
    """
    def __init__(self, env, skip=4, noops=30, seed=0):
        super().__init__(env=env)
        self.env.action_space.seed(seed)
        # The env is seeded on the first reset only; re-seeding on every reset would
        # replay the same RNG stream each episode.
        self._seed_pending = True

    def reset(self, **kwargs):
        """Reset the wrapped env, passing ``self.seed`` on the first call only.

        A ``seed`` explicitly supplied by the caller is honoured.
        """
        if self._seed_pending:
            kwargs.setdefault("seed", self.seed)
            self._seed_pending = False
        obs, info = self.env.reset(**kwargs)

        return obs, info

    def noop_steps(self, states):
        """Take 0..``noops`` (inclusive) NOOP steps (action 0) and append each frame.

        Every resulting observation is passed through ``preprocess`` (defined in a
        later notebook cell, so that cell must have run first) and appended to
        ``states``, which is also returned.
        """
        noops = random.randint(0, self.noops)

        for _ in range(noops):
            state = self.step(np.array([0]))[0]
            state = preprocess(state)
            states.append(state)
        return states

    def step(self, action):
        """Repeat ``action`` up to ``skip`` times, summing rewards.

        Stops early once any env reports ``done`` or ``truncated``. The returned
        observation is the element-wise max of the last two frames of this window
        (or the single frame when the window ended after one step). Life loss is
        handled by the training code, not here.
        """
        total_reward = 0
        # Recreated every call: the max is over frames of this skip window only.
        self.obs_buffer = deque(maxlen=2)
        for _ in range(self.skip):
            obs, reward, done, truncated, info = self.env.step(action)
            self.obs_buffer.append(obs)
            total_reward += reward

            # `.any()` keeps this well-defined for vector envs with several envs,
            # where `done or truncated` on arrays would raise.
            if np.logical_or(done, truncated).any():
                break
        if len(self.obs_buffer) == 1:
            obs = self.obs_buffer[0]
        else:
            obs = np.max(np.stack(self.obs_buffer), axis=0)
        return obs, total_reward, done, truncated, info
        # Life loss is calculated on the training code

# Build a single-env vector env (observations then presumably carry a leading env axis)
# and apply the frame-skip / max-pool wrapper.
env = gym.vector.make(f"{env_name}NoFrameskip-v4", num_envs=1)
env = MaxLast2FrameSkipWrapper(env,seed=SEED)


#n_actions = env.action_space.n
# The vector env's action space is batched; index 0 recovers the per-env space.
n_actions = env.action_space[0].n

state, info = env.reset()
# NOTE(review): with a vector env, len(state) is presumably the number of envs (1),
# not an observation size — confirm before relying on n_observations.
n_observations = len(state)



# NOTE(review): numpy/torch are seeded only after the env has been created and reset.
seed_np_torch(SEED)

print(f"N° of actions: {n_actions}.")

  gym.logger.warn(


N° of actions: 7.


In [3]:


class EffZ_Perception(nsd_Module):
    """Convolutional representation encoder: stacked observations -> latent feature map.

    Downsampling comes from conv1 (stride 2), conv3 (stride=2), and the two stride-2
    average pools, i.e. a 16x spatial reduction. For 96x96 inputs this presumably
    yields a 64x6x6 map, matching the 2304 = 64*6*6 projection input in
    EfficientZero — confirm.

    Args:
        n_actions: accepted but not used to build any layer here.
        scale_width: accepted but not used to size any layer here.
        act: activation module shared by every layer. The default instance is created
            once at definition time and shared across all instances.
    """
    def __init__(self, n_actions, scale_width=1, act=nn.ReLU()):
        super().__init__()
        
        
        # Stem: 12 input channels (presumably 4 stacked RGB frames — TODO confirm).
        self.conv1 = nn.Sequential(nn.Conv2d(12, 32, 3, stride=2, padding=1, bias=False),
                                   nn.BatchNorm2d(32),
                                   act)
        self.conv2 = Residual_Block(32, 32, act=act, out_act=act)
        self.conv3 = Residual_Block(32, 64, act=act, out_act=act, stride=2)
        self.conv4 = Residual_Block(64, 64, act=act, out_act=act)
        self.pool1 = nn.AvgPool2d(3, stride=2, padding=1)
        self.conv5 = Residual_Block(64, 64, act=act, out_act=act)
        self.pool2 = nn.AvgPool2d(3, stride=2, padding=1)
        self.conv6 = Residual_Block(64, 64, act=act, out_act=act)
        
        # Only the stem is re-initialised here; the residual blocks keep their own init.
        self.conv1.apply(init_xavier)
        
        # The same submodules are registered a second time under `conv`, so the
        # state_dict contains both `convN.*` and `conv.K.*` keys for them.
        self.conv = nn.Sequential(self.conv1, self.conv2, self.conv3, self.conv4, self.pool1,
                                   self.conv5, self.pool2, self.conv6)
        
    def forward(self, X):
        # X: (N, 12, H, W) image batch -> (N, 64, H/16, W/16) feature map.
        X = self.conv(X)
        return X

class _1conv_residual(nn.Module):
    def __init__(self, hiddens, act=nn.ReLU()):
        super().__init__()
        
        self.net = nn.Sequential(nn.Conv2d(hiddens+1, hiddens, 3, padding=1, bias=False),
                                        nn.BatchNorm2d(hiddens))
        
    def forward(self, x):
        proj = x[:,:-1]
        x = self.net(x)
        
        return x+proj
        
class RewardPred(nsd_Module):
    """Reward head: 1x1 conv -> LSTMCell over unrolled steps -> MLP ending in a softmax.

    nsd_Module presumably stores the constructor arguments as attributes
    (``self.hiddens`` and ``self.k`` are read below).

    Args:
        in_channels, out_channels: channels of the 1x1 conv applied to each latent map.
        in_hiddens: flattened conv output size (LSTM input size).
        hiddens: LSTM hidden size.
        bottleneck: hidden width of the output MLP.
        out_dim: number of output bins (the MLP ends with a softmax).
        act: activation used after the conv and inside the MLP.
        k: number of unrolled steps consumed by ``forward``.
    """
    def __init__(self, in_channels, out_channels, in_hiddens, hiddens, bottleneck=32, out_dim=51, act=nn.ReLU(), k=5):
        super().__init__()

        self.conv = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1),
                                  nn.BatchNorm2d(out_channels),
                                  act)

        self.lstm = nn.LSTMCell(in_hiddens, hiddens)
        # NOTE(review): not used by forward/transition_one_step. Kept so the parameter
        # set and state_dict keys stay unchanged.
        self.norm_relu = nn.Sequential(nn.BatchNorm1d(hiddens))

        self.mlp = MLP_LayerNorm(hiddens, bottleneck, out_dim, layers=2, in_act=act, init=init_xavier, last_init=init_zeros, out_act=nn.Softmax(-1))

    def forward(self, x):
        """Predict reward distributions for the first ``k`` steps of a sequence.

        Args:
            x: latent maps of shape (bs, seq, C, H, W) with seq >= k.
        Returns:
            Tensor (bs, k, out_dim) of reward probabilities. The LSTM starts from a
            zero state on every call.
        """
        bs, seq = x.shape[:2]

        feats = self.conv(x.view(bs*seq, *x.shape[-3:])).view(bs, seq, -1)

        # Zero initial state on the input's device (was hard-coded to 'cuda').
        ht = torch.zeros(bs, self.hiddens, device=feats.device)
        ct = torch.zeros_like(ht)

        hs = []
        for i in range(self.k):
            ht, ct = self.lstm(feats[:, i], (ht, ct))
            hs.append(ht)

        return self.mlp(torch.stack(hs, 1))

    def transition_one_step(self, x, ht):
        """Advance the reward LSTM by a single step.

        Args:
            x: latent maps of shape (batch, C, H, W).
            ht: ``(h, c)`` LSTM state tuple.
        Returns:
            ``(reward_probs, (h, c))`` for the new step.
        """
        feats = self.conv(x).view(x.shape[0], -1)

        h, c = self.lstm(feats, ht)

        return self.mlp(h), (h, c)
        

class ActorCritic(nsd_Module):
    """Prediction head: shared residual block, then separate policy and value heads.

    Each head is a 1x1 conv + BatchNorm + act, flattened, followed by an MLP. The
    policy MLP outputs ``out_policy`` logits. The value MLP ends with a softmax over
    ``out_value`` bins.

    Args:
        in_channels: channels of the incoming latent map.
        out_channels: channels after each head's 1x1 conv.
        in_hiddens: flattened size after the 1x1 conv (MLP input size).
        bottleneck: hidden width of the head MLPs.
        out_value: number of value bins.
        out_policy: number of actions.
        act: activation module (presumably also stored as ``self.act`` by nsd_Module,
            which the residual block reads).
    """
    def __init__(self, in_channels, out_channels, in_hiddens, bottleneck=32, out_value=51, out_policy=1, act=nn.ReLU()):
        super().__init__()

        self.residual = Residual_Block(in_channels, in_channels, act=self.act, out_act=self.act)

        conv_policy = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1),
                                    nn.BatchNorm2d(out_channels),
                                    act)
        conv_value = nn.Sequential(nn.Conv2d(in_channels, out_channels, 1),
                                   nn.BatchNorm2d(out_channels),
                                   act)

        self.policy = nn.Sequential(conv_policy,
                                    nn.Flatten(-3, -1),
                                    MLP_LayerNorm(in_hiddens, bottleneck, out_policy, layers=2, in_act=act, init=init_xavier, last_init=init_zeros))
        # Bug fix: the value head previously reused `conv_policy`, so `conv_value` was
        # never used and both heads shared one conv+BN. State-dict keys
        # (`value.0.*`) are unchanged, so older checkpoints still load.
        self.value = nn.Sequential(conv_value,
                                   nn.Flatten(-3, -1),
                                   MLP_LayerNorm(in_hiddens, bottleneck, out_value, layers=2, in_act=act,
                                                 out_act=nn.Softmax(dim=-1), init=init_xavier, last_init=init_zeros))

    def _heads(self, x):
        # Run both heads on a flat (N, C, H, W) batch: (logits, softmax probs, value probs).
        logits = self.policy(x)
        return logits, F.softmax(logits, -1), self.value(x)

    def forward(self, x):
        """Heads over a sequence of latents.

        Args:
            x: (bs, seq, C, H, W) latent maps.
        Returns:
            ``(logits, probs, value_probs)``, each shaped (bs, seq, -1).
        """
        bs, seq = x.shape[:2]

        x = self.residual(x.view(-1, *x.shape[-3:]))
        logits, probs, value_probs = self._heads(x)

        return logits.view(bs, seq, -1), probs.view(bs, seq, -1), value_probs.view(bs, seq, -1)

    def one_step(self, x):
        """Heads over a single batch of latents (batch, C, H, W); outputs are (batch, -1)."""
        bs = x.shape[0]

        x = self.residual(x.view(-1, *x.shape[-3:]))
        logits, probs, value_probs = self._heads(x)

        return logits.view(bs, -1), probs.view(bs, -1), value_probs.view(bs, -1)
        
    
class EfficientZero(nsd_Module):
    """EfficientZero-style model: encoder, latent dynamics, reward/policy/value heads.

    - ``encoder_cnn`` maps stacked frames to a 64-channel latent map.
    - ``transition`` advances a latent by one step given an action plane.
    - ``reward_pred`` / ``ac`` predict reward bins and policy/value.
    - ``projection`` / ``prediction`` map latents to embeddings (used by
      ``forward`` / ``get_transition`` for the projected outputs).

    ``support`` is the value-bin support (expectations use ``symexp(support)``);
    ``reward_support`` is the reward-bin support. ``mlp_layers`` is accepted but unused.
    """
    def __init__(self, n_actions, hiddens=512, mlp_layers=1, scale_width=1,
                 n_atoms=51, Vmin=-20, Vmax=20):
        super().__init__()
        self.support = torch.linspace(Vmin, Vmax, n_atoms).cuda()
        self.reward_support = torch.linspace(-2, 2, n_atoms).cuda()

        self.hiddens = hiddens
        self.scale_width = scale_width
        self.act = nn.ReLU()

        # NOTE: submodule construction order fixes RNG consumption (initialization)
        # and state_dict layout — keep it as is.
        self.encoder_cnn = EffZ_Perception(n_actions, scale_width)

        # 2304 = 64 * 6 * 6 (flattened encoder map for 96x96 inputs).
        self.projection = MLP_LayerNorm(2304*scale_width, hiddens, hiddens*2,
                                        init=init_xavier, last_init=init_xavier, layers=3, in_act=self.act,
                                        add_last_norm=False)
        self.prediction = MLP_LayerNorm(hiddens*2, hiddens, hiddens*2, layers=2,
                                        init=init_xavier, last_init=init_xavier,
                                        in_act=self.act, add_last_norm=False)

        # Input: 64 latent channels + 1 action plane.
        self.transition = nn.Sequential(_1conv_residual(64, self.act),
                                        Residual_Block(64, 64, act=self.act, out_act=self.act))

        self.reward_pred = RewardPred(64, 16, 16*((96//16)**2), hiddens)

        self.ac = ActorCritic(64, 16, 16*((96//16)**2), out_policy=n_actions)

        params_count(self, 'Efficient Zero Network')

    def forward(self, X, y_action):
        """Training forward pass.

        Encodes ``X`` (batch, seq, C, H, W), runs the heads on every encoded step, and
        unrolls the dynamics from the first step with the 5 actions in ``y_action``.

        Returns:
            ``(z_proj, z_proj_pred, reward_pred, logits, probs, value_probs)``.
        """
        z_proj, z = self.encode(X)

        logits, probs, value_probs = self.ac(z)

        z_proj_pred, reward_pred = self.get_transition(z[:, 0][:, None], y_action)

        return z_proj, z_proj_pred, reward_pred, logits, probs, value_probs

    def get_root(self, X):
        """Encode a single-step batch and return ``(z, logits, probs, value_probs)`` with the seq axis squeezed."""
        z = self.encode_z(X)
        logits, probs, value_probs = self.ac(z)

        return z.squeeze(1), logits.squeeze(1), probs.squeeze(1), value_probs.squeeze(1)

    def encode(self, X):
        """Encode (batch, seq, C, H, W) frames.

        Returns:
            ``(projected embeddings (batch, seq, 2*hiddens), latent maps (batch, seq, C', H', W'))``.
            Also records ``self.batch`` / ``self.seq`` (via ``encode_z``).
        """
        X = self.encode_z(X)
        z = X.clone()

        X = self.projection(X.flatten(-3, -1))
        return X, z

    def encode_z(self, X):
        """Run the CNN encoder over (batch, seq, C, H, W) frames; returns (batch, seq, C', H', W')."""
        batch, seq = X.shape[:2]
        self.batch = batch
        self.seq = seq
        X = self.encoder_cnn(X.contiguous().view(self.batch*self.seq, *(X.shape[2:])))
        X = X.contiguous().view(self.batch, self.seq, *X.shape[-3:])

        return X

    def env_step(self, X):
        """Sample an action from the policy for the given frames (no gradients)."""
        with torch.no_grad():
            z = self.encode_z(X)
            _, probs, _ = self.ac(z)

            return torch.multinomial(probs.squeeze(), 1)

    def get_zero_ht(self, batch_size):
        """Zero ``(h, c)`` LSTM state for the reward head, on the model's device."""
        ht = torch.zeros(batch_size, self.hiddens, device=self.support.device)
        ct = torch.zeros_like(ht)
        return (ht, ct)

    def transition_one_step(self, z, action, ht):
        """Advance latents ``z`` (batch, C, H, W) one step with ``action`` (batch,).

        The action is encoded as a constant plane ``action / n_actions`` concatenated
        as an extra channel.

        Returns:
            ``(z_pred, logits, probs, value_probs, reward_pred, ht)``.
        """
        z = z.contiguous().view(-1, *z.shape[-3:])

        plane = torch.ones((z.shape[0], z.shape[2], z.shape[3]), device=action.device, dtype=torch.float)
        plane = (action[:, None, None] * plane / self.n_actions)[:, None]

        z_pred = self.transition(torch.cat((z, plane), 1))

        reward_pred, ht = self.reward_pred.transition_one_step(z_pred, ht)

        logits, probs, value_probs = self.ac.one_step(z_pred)

        return z_pred, logits, probs, value_probs, reward_pred, ht

    def get_transition(self, z, action):
        """Unroll the dynamics 5 steps from ``z`` using ``action`` of shape (batch, 5).

        Returns:
            ``(z_proj_pred (batch, 5, 2*hiddens), reward_pred (batch, 5, n_bins))``.
        """
        z = z.contiguous().view(-1, *z.shape[-3:])
        # Batch size taken from the input itself rather than from `self.batch`, which
        # is whatever the last `encode` call happened to leave behind.
        batch = z.shape[0]

        planes = torch.ones((batch, 5, z.shape[2], z.shape[3]), device=action.device, dtype=torch.float)
        planes = (action[:, :, None, None] * planes / self.n_actions)[:, :, None]

        z_pred = self.transition(torch.cat((z, planes[:, 0]), 1))

        z_preds = [z_pred.clone()]

        for k in range(4):
            z_pred = self.transition(torch.cat((z_pred, planes[:, k+1]), 1))
            z_preds.append(z_pred)

        z_pred = torch.stack(z_preds, 1)

        reward_pred = self.reward_pred(z_pred)

        z_proj_pred = self.projection(z_pred.flatten(-3, -1)).view(batch, 5, -1)
        z_proj_pred = self.prediction(z_proj_pred)

        return z_proj_pred, reward_pred

    def evaluate(self, X):
        """Expected value of the value head, sum(p * symexp(support)), for frames ``X``."""
        z = self.encode_z(X)
        values = self.ac(z)[-1]
        values = (values*symexp(self.support)).sum(-1)

        return values

    def network_ema(self, rand_network, target_network, alpha=0.5):
        """In place: target_param <- alpha * target_param + (1 - alpha) * rand_param (parameters only)."""
        for param, param_target in zip(rand_network.parameters(), target_network.parameters()):
            param_target.data = alpha * param_target.data + (1 - alpha) * param.data.clone()

    def hard_reset(self, random_model, alpha=0.5):
        """Shrink-and-perturb the encoder/dynamics towards ``random_model``; fully reset the heads.

        Only parameters are touched (``network_ema`` iterates ``parameters()``), so
        BatchNorm running statistics are not reset.

        Bug fix: previously referenced nonexistent ``reward_mlp``, ``a`` and ``v``
        attributes and always raised ``AttributeError``. The heads are now
        ``reward_pred`` and ``ac``.
        """
        with torch.no_grad():

            self.network_ema(random_model.encoder_cnn, self.encoder_cnn, alpha)
            self.network_ema(random_model.transition, self.transition, alpha)

            self.network_ema(random_model.projection, self.projection, 0)
            self.network_ema(random_model.prediction, self.prediction, 0)
            self.network_ema(random_model.reward_pred, self.reward_pred, 0)

            self.network_ema(random_model.ac, self.ac, 0)


def copy_states(source, target):
    """Copy per-parameter optimizer state from ``source`` into ``target``.

    Parameters are paired by their flattened position across ``param_groups``, which
    is the same integer key ``Optimizer.state_dict()`` uses, so both optimizers must
    have been built over parameter lists of matching structure.

    Every entry of each source state is deep-copied. This covers Adam
    (``exp_avg``, ``exp_avg_sq``, ``step``, ...) as well as SGD's
    ``momentum_buffer``, where the old Adam-only version raised ``KeyError``.
    Parameters without source state are left untouched.
    """
    src_params = [p for group in source.param_groups for p in group['params']]
    tgt_params = [p for group in target.param_groups for p in group['params']]
    for p_src, p_tgt in zip(src_params, tgt_params):
        src_state = source.state.get(p_src)
        if src_state:
            target.state[p_tgt] = copy.deepcopy(src_state)
        
def target_model_ema(model, model_target, decay=critic_ema_decay):
    """Polyak-average ``model``'s parameters into ``model_target`` in place.

    For each parameter pair: target <- decay * target + (1 - decay) * online.
    The default ``decay`` is bound to ``critic_ema_decay`` when this cell runs.
    """
    keep = decay
    blend = 1.0 - decay
    with torch.no_grad():
        for online, target in zip(model.parameters(), model_target.parameters()):
            target.data = keep * target.data + blend * online.data


# Online network and its target copy. The target starts as an exact copy of the online
# weights (load_state_dict below); it is presumably updated later with `target_model_ema`.
model=EfficientZero(n_actions).cuda()
model_target=EfficientZero(n_actions).cuda()
#model_reanalyze=DQN(n_actions).cuda()

model_target.load_state_dict(model.state_dict())
#model_reanalyze.load_state_dict(model.state_dict())


Efficient Zero Network Parameters: 5.80M
Efficient Zero Network Parameters: 5.80M


<All keys matched successfully>

In [4]:
class MCTS_Node(nsd_Module):
    """One node of the search tree.

    The constructor arguments (``z``, ``value``, ``logits``, ``probs``, ``ht``,
    ``reward``, ``prev_state``, ``n_actions``) are read back through ``self`` in
    ``get_stats``, so nsd_Module presumably stores them as attributes.

    Per-action statistics:
        transitions: child node per action (``None`` until expanded).
        n: visit counts, shape (n_actions,).
        Q: accumulated returns, same shape as ``logits``.
        choosen_action: action taken from this node on the current path (-1 = none).
    All tensors live on the same device as ``logits`` (previously hard-coded 'cuda').
    """
    def __init__(self, z, value, logits, probs, ht, reward, prev_state, n_actions):
        super().__init__()

        self.transitions = [None]*self.n_actions

        device = logits.device
        self.n = torch.zeros(n_actions, device=device)

        self.Q = torch.zeros_like(logits)

        self.choosen_action = torch.tensor(-1, device=device, dtype=torch.long)

    def reset_n(self):
        """Zero the visit counts (keeps their current device)."""
        self.n = torch.zeros(self.n_actions, device=self.n.device)

    def get_stats(self):
        """Return ``(Q, logits, probs, value, ht, z, n, reward, choosen_action)``."""
        return self.Q, self.logits, self.probs, self.value, self.ht, self.z, self.n, self.reward, self.choosen_action

    def forward(self, x):
        # Identity; nodes are containers, not computations.
        return x


class MCTS(nsd_Module):
    """Batched tree search over the learned EfficientZero model (Gumbel-style root selection).

    One tree is kept per root (one root per input batch element). Each of the
    ``n_sim`` simulations descends exactly ``k`` steps from the roots, expanding and
    caching children via ``model.transition_one_step``, then backs up discounted
    returns along the visited path (see ``backup``).

    Root actions are chosen among a Gumbel-perturbed top-k candidate set that shrinks
    on a halving schedule. Deeper actions maximise the improved policy minus a
    visit-count penalty.

    Args (presumably stored as attributes by nsd_Module; read via ``self`` below):
        n_actions: size of the action space.
        k: search depth per simulation.
        batch_size: nominal number of roots (the actual count comes from the input).
        n_sim: number of simulations.
        topk_actions: initial number of root candidates.
        c_visit, c_scale: scale of the completed-Q term, ``(c_visit + max n) * c_scale``.
    """
    def __init__(self, n_actions, k=5, batch_size=3, n_sim=16, topk_actions=8, c_visit=50, c_scale=0.1):
        # Good c_scale values are 0.1 and 1
        super().__init__()
        # Tensor because `forward` applies torch.log2 to it.
        self.topk_actions = torch.tensor(topk_actions)

    def get_root(self, model, x):
        """Encode ``x`` with ``model.get_root`` and build one root node per batch element."""
        z, logits, probs, value_probs = model.get_root(x)
        value = (value_probs*symexp(model.support)).sum(-1)
        # Zero reward-LSTM state sized from the model (was hard-coded to 512).
        ht = torch.zeros(z.shape[0], model.hiddens, device=z.device)
        # Roots receive no reward. Sized from the actual batch rather than `batch_size`.
        root_reward = torch.zeros((), dtype=torch.long, device=z.device)
        self.root = [
            MCTS_Node(z[i], value[i], logits[i], probs[i], (ht[i], ht[i]), root_reward,
                      prev_state=None, n_actions=self.n_actions)
            for i in range(z.shape[0])
        ]
        return self.root

    def collate_nodes(self):
        """Stack the stats of every node in ``self.cur_state`` along a new batch axis.

        Returns:
            ``(Q, logits, probs, values, (h, c), z, n, rewards, chosen_actions)``.
        """
        Q, logits, probs, values, hts, cts, Z, N, R, A = [], [], [], [], [], [], [], [], [], []
        for node in self.cur_state:
            q, logit, prob, value, ht, z, n, r, a = node.get_stats()
            Q.append(q)
            logits.append(logit)
            probs.append(prob)
            values.append(value)
            hts.append(ht[0])
            cts.append(ht[1])
            Z.append(z)
            N.append(n)
            R.append(r)
            A.append(a)

        return torch.stack(Q, 0), torch.stack(logits, 0), torch.stack(probs, 0), torch.stack(values, 0), \
            (torch.stack(hts, 0), torch.stack(cts, 0)), torch.stack(Z, 0), \
            torch.stack(N, 0), torch.stack(R, 0), torch.stack(A, 0)

    def transition(self, model, x, action, ht):
        """Step every node in ``self.cur_state`` by ``action`` and return the child nodes.

        A child is created and cached in ``parent.transitions[a]`` the first time
        action ``a`` is taken; afterwards the cached child is reused (the fresh model
        outputs are then discarded for that element).
        """
        z, logits, probs, value_probs, reward_pred, ht = model.transition_one_step(x, action, ht)
        value = (value_probs*symexp(model.support)).sum(-1)
        reward_pred = (reward_pred*model.reward_support).sum(-1)

        children = []
        for i, parent in enumerate(self.cur_state):
            a = int(action[i])
            child = parent.transitions[a]
            if child is None:
                child = MCTS_Node(z[i], value[i], logits[i], probs[i], (ht[0][i], ht[1][i]), reward_pred[i],
                                  prev_state=parent, n_actions=self.n_actions)
                parent.transitions[a] = child
            children.append(child)

        return children

    def _improved_policy(self, q, logits, probs, values, n):
        # softmax(logits + sigma(completed Q)): unvisited actions use the mixed value
        # estimate v_mix, visited ones their accumulated Q.
        visits = n.sum(-1)
        visited = (n != 0)
        v_mix = (1/(1+visits)) * (values + (visits/((visited*probs + 0.01).sum(-1))) * (visited*q*probs).sum(-1))
        v_mix = v_mix[:, None].repeat_interleave(q.shape[-1], -1)

        completed_Q = (n == 0)*v_mix + (n > 0)*q
        completed_Q = (self.c_visit + n.max(-1)[0][:, None])*self.c_scale*completed_Q

        return F.softmax((logits + completed_Q), -1)

    def backup(self, model):
        """Propagate the leaves' value estimates back up to the roots.

        Assumes ``self.cur_state`` holds the leaves, exactly ``k`` levels below the
        roots. The leaf value (recomputed with ``model.ac.one_step``) is bootstrapped.
        Walking up one level at a time, the running return is updated as
        ``G <- r_child + gamma * G``, where ``r_child`` is the predicted reward stored
        on the child just left. ``G`` is then added to ``Q[parent, chosen_action]``
        and the visit count is incremented.

        ``Q`` therefore holds a *sum* of returns. Previously each level added its
        own incoming reward and skipped the intermediate rewards.
        """
        gamma = 0.997
        _, _, _, _, _, z, _, child_reward, _ = self.collate_nodes()

        _, _, value_probs = model.ac.one_step(z)
        ret = (value_probs*symexp(model.support)).sum(-1)

        self.cur_state = [node.prev_state for node in self.cur_state]

        for _ in range(self.k):
            Q, _, _, _, _, _, n, r_t, choosen_action = self.collate_nodes()
            rows = torch.arange(len(self.cur_state), device=n.device)

            ret = child_reward + gamma*ret

            Q[rows, choosen_action] += ret
            n[rows, choosen_action] += 1

            for i, node in enumerate(self.cur_state):
                node.Q = Q[i]
                node.n = n[i]

            child_reward = r_t
            self.cur_state = [node.prev_state for node in self.cur_state]

    def forward(self, model, x, w_gumbel=1):
        """Run ``n_sim`` simulations from the roots encoded from ``x``.

        Returns:
            - root value estimate: max_a Q(root, a) / n_sim;
            - improved policies along the last simulated path, stacked as (B, k, n_actions);
            - a one-element list holding an action sampled from the root's improved
              policy. Previously the sample came from the deepest node's policy of the
              last simulation.
        """
        with torch.no_grad():

            self.cur_state = self.get_root(model, x)

            logits = self.collate_nodes()[1]
            gumbel = F.gumbel_softmax(torch.zeros_like(logits)) * w_gumbel
            gumbel_sa = gumbel + logits

            k = min(self.topk_actions, self.n_actions)

            Q_mask = F.one_hot(gumbel_sa.topk(k)[1], self.n_actions).sum(-2)

            halves = 2

            next_halve = math.floor(self.n_sim/(torch.log2(self.topk_actions)*(self.topk_actions)))

            improved_policies = []
            for sim in range(self.n_sim):
                improved_policies = []
                for l in range(self.k):
                    q, logits, probs, values, ht, z, n, _, _ = self.collate_nodes()

                    if l == 0 and sim == next_halve:
                        # Resample Gumbel noise and shrink the root candidate set.
                        gumbel = F.gumbel_softmax(torch.zeros_like(logits)) * w_gumbel

                        gumbel_sa = gumbel + logits
                        k = min(self.topk_actions//halves*2, self.n_actions)

                        Q_mask = F.one_hot(gumbel_sa.topk(k)[1], self.n_actions).sum(-2)

                        next_halve += math.floor(self.n_sim/(torch.log2(self.topk_actions)*(self.topk_actions/(halves*2))))

                        halves *= 2

                    improved_policy = self._improved_policy(q, logits, probs, values, n)
                    improved_policies.append(improved_policy)

                    if l == 0:
                        Q = gumbel_sa + (self.c_visit + n.max(-1)[0][:, None])*self.c_scale*q/(sim+1)
                        # Exclude non-candidates with -inf. Multiplying by the mask let
                        # a masked action (0) beat negative candidate scores.
                        action = Q.masked_fill(Q_mask == 0, float('-inf')).argmax(-1)
                    else:
                        Q = improved_policy - n/(1+n.sum(-1)[:, None])
                        action = Q.argmax(-1)

                    for node, a in zip(self.cur_state, action):
                        node.choosen_action = a

                    self.cur_state = self.transition(model, z, action, ht)

                self.backup(model)
                self.cur_state = self.root

            Q, logits, probs, values, ht, z, n, _, _ = self.collate_nodes()
            root_policy = self._improved_policy(Q, logits, probs, values, n)

            return Q.max(-1)[0]/self.n_sim, torch.stack(improved_policies, 1), [torch.multinomial(root_policy, 1)]


# Two searchers sharing the same hyper-parameters: `mcts` sized for `batch_size` roots,
# and `mcts_inference` for a single root (presumably used when acting in the env).
mcts           = MCTS(n_actions, n_sim = n_sim, topk_actions = topk_actions, batch_size = batch_size)
mcts_inference = MCTS(n_actions, n_sim = n_sim, topk_actions = topk_actions, batch_size = 1)

In [5]:
# Collect trainable parameters in two groups. Together they cover every submodule of
# EfficientZero that owns parameters, and both go into the same optimizer below.
# Note: the loop variables `module` and `param` are left behind as notebook globals.
perception_modules=[model.encoder_cnn, model.transition]
actor_modules=[model.prediction, model.projection, model.ac, model.reward_pred]

params_wm=[]
for module in perception_modules:
    for param in module.parameters():
        if param.requires_grad==True: # They all require grad
            params_wm.append(param)

params_ac=[]
for module in actor_modules:
    for param in module.parameters():
        if param.requires_grad==True:
            params_ac.append(param)


#optimizer = torch.optim.AdamW(chain(params_wm, params_ac),
#                                lr=lr, weight_decay=1e-4)
# NOTE(review): SGD hyper-parameters are hard-coded; the `lr` config value is only
# used by the commented-out AdamW alternative above.
optimizer = torch.optim.SGD(chain(params_wm, params_ac), momentum=0.9,
                                lr=0.2, weight_decay=1e-4)

In [6]:
import torchvision.transforms as transforms

# Frame transform applied by `preprocess`: resize every frame to 96x96.
train_tfms = transforms.Compose(
    [transforms.Resize((96, 96))]
)


def preprocess(state):
    """Convert a uint8 frame batch (N, H, W, C) to a float CUDA tensor (N, C, 96, 96) in [0, 1]."""
    frames = torch.tensor(state, dtype=torch.float, device='cuda') / 255
    channels_first = frames.permute(0, 3, 1, 2)  # NHWC -> NCHW for torchvision
    return train_tfms(channels_first)

# https://github.com/google/dopamine/blob/master/dopamine/jax/agents/dqn/dqn_agent.py
def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon):
    """Epsilon schedule (Dopamine-style).

    Returns 1.0 during warm-up, then decays linearly over ``decay_period`` steps
    down to ``epsilon``, and stays there afterwards.
    """
    remaining = decay_period + warmup_steps - step
    unclipped = (1.0 - epsilon) * remaining / decay_period
    return epsilon + np.clip(unclipped, 0., 1. - epsilon)


#def epsilon_greedy(Q_action, step, final_eps=0, num_envs=1):
def epsilon_greedy(actions_to_step, step, final_eps=0, num_envs=1):
    """Pick a uniformly random action with probability epsilon, else the planned one.

    Epsilon follows ``linearly_decaying_epsilon`` (2000 warm-up steps, then decay over
    2001 steps down to ``final_eps``). The planned action is popped from the front of
    ``actions_to_step`` only on the greedy branch. Random actions are CPU tensors drawn
    from the global ``n_actions``.
    """
    epsilon = linearly_decaying_epsilon(2001, step, 2000, final_eps)
    explore = random.random() < epsilon

    if not explore:
        return actions_to_step.pop(0).squeeze()
    return torch.randint(0, n_actions, (num_envs,), dtype=torch.int64).squeeze(0)


In [7]:
# Element-wise MSE: reduction='none' returns per-element losses, so reduction/weighting
# is left to the caller.
mse = torch.nn.MSELoss(reduction='none')


# NOTE(review): `torch.cuda.amp.GradScaler()` is deprecated in recent PyTorch in favour of
# `torch.amp.GradScaler("cuda")`. The autocast region in `optimize` is bf16 and disabled,
# and bf16 does not normally need loss scaling — confirm whether the scaler is needed.
scaler = torch.cuda.amp.GradScaler()
def optimize(step, grad_step, n):
        
    model.train()
    model_target.train()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=False):
        with torch.no_grad():
            states, next_states, rewards, action, c_flag, idxs, is_w = memory.sample(n, batch_size, grad_step)
            z = model_target.encode(states[:,1:6])[0]
            
        terminal=1-c_flag
        #print(f"STUFF HERE {states.shape, rewards.shape, c_flag.shape, action.shape, n}")
    
        z_proj, z_proj_pred, reward_pred, logits, probs, value_probs = model(states[:,:5], action[:,:5].long())
        
        value = (value_probs*symexp(model.support)).sum(-1).detach().squeeze()

        

        
        next_values = model_target.evaluate(next_states[:,n-1][:,None].contiguous())
        

        #action = action[:,0,None].expand(batch_size,num_buckets)
        #action = action[:,None]
        action = F.one_hot(action[:,0], n_actions)
        with torch.no_grad():
            gammas_one=torch.ones(batch_size,n,1,dtype=torch.float, device='cuda')
            gamma_step = 1-torch.tensor(( (schedule_max_step - min(grad_step, schedule_max_step)) / schedule_max_step) * (initial_gamma-final_gamma) + final_gamma).exp()
            gammas=gammas_one*gamma_step

            
            returns = []
            for t in range(n):
                ret = 0
                for u in reversed(range(t, n)):
                    ret += torch.prod(c_flag[:,t+1:u+1],-2)*torch.prod(gammas[:,t:u],-2)*rewards[:,u+1]
                returns.append(ret)
            returns = torch.stack(returns,1)
        
        plot_vs = returns.clone().sum(-1)
        
        same_traj = (torch.prod(c_flag[:,:n],-2)).squeeze()
        loss = 0



        '''Value non-SVE'''
        value_prefix = returns[:,:n]
        returns = returns[:,0]
        #returns = returns + torch.prod(gammas[0,:initial_n],-2).squeeze()*same_traj[:,None]*model.support[None,:]
        #returns = returns.squeeze()
            
        #next_values = next_values[:,0]


        returns = returns + torch.prod(c_flag[:,:n],-2) * torch.prod(gammas[:,:n],-2) * next_values

        th_returns = two_hot_view(returns, 51, model.support)
        
        loss += -(th_returns*torch.log(value_probs[:,0].squeeze()+eps)).sum(-1) * 0.25

        #loss += -(action*torch.log(probs.squeeze()+eps)).sum(-1) * (returns.squeeze()-value)
        

        '''MCTS'''
        value_mcts, improved_policy, _ = mcts(model_target, states[:,0][:,None])
        
        if step > reset_every*0.4:
            
            
            

            steps = idxs - reset_every*(idxs//reset_every)
            sve_mask = (steps<reset_every*0.4).float() + (steps>reset_every*0.8).float()
            sve_mask = sve_mask.cuda()
            
            #print(f"{sve_mask, idxs}")

            loss = sve_mask*loss

            #print(f"{value_mcts.shape, value_probs.shape}")
            value_mcts_th = two_hot(value_mcts, 51, model.support)
            loss += -(value_mcts_th*torch.log(value_probs[:,0].squeeze()+eps)).sum(-1) * (1-sve_mask) * 0.25
            
            #print(f"optim improved policy, probs {improved_policy.shape, probs.shape}")
        
        loss += -(improved_policy*torch.log(probs.squeeze()+eps)).sum(-1).mean(-1)# * (1-sve_mask)
        ### print(f"{improved_policy[0], probs[0].squeeze()}")
        
        
        dqn_loss = loss.clone().mean()
        
        '''Reward Pred'''

        value_prefix_th = two_hot_view_no_symlog(value_prefix.squeeze().clip(-2,2), 51, model.reward_support)
        value_prefix_th = value_prefix_th.view(batch_size,-1,51)
        
        
        loss += -(value_prefix_th*torch.log(reward_pred+eps)).sum(-1).mean(-1)
        
        batched_loss = loss.clone()


        '''Entropy Term'''
        loss += 5e-3 * (probs.squeeze() * torch.log(probs.squeeze()+eps)).sum(-1).mean(-1)

        
        '''Recon'''
        z = F.normalize(z, 2, dim=-1, eps=1e-5)
        z_proj_pred = F.normalize(z_proj_pred, 2, dim=-1, eps=1e-5)

        
        recon_loss = (mse(z_proj_pred.contiguous().view(-1,1024), z.contiguous().view(-1,1024))).sum(-1)
        recon_loss = 2*(recon_loss.view(batch_size, -1).mean(-1))*same_traj
        
        
        loss += recon_loss

        
        loss = (loss*is_w).mean() # mean across batch axis

    loss.backward()

    param_norm, grad_norm = params_and_grad_norm(model)
    #scaler.scale(loss).backward()
    #scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
    #scaler.step(optimizer)
    #scaler.update()
    
    optimizer.step()
    optimizer.zero_grad()
    
    #memory.set_priority(idxs, batched_loss)
    memory.set_priority(idxs, batched_loss, same_traj)
    
    
    lr = optimizer.param_groups[0]['lr']
    
    wandb.log({'loss': loss, 'dqn_loss': dqn_loss, 'recon_loss': recon_loss.mean(), 'lr': lr, 'returns': plot_vs.mean(),
               'buffer rewards': rewards.mean(0).sum(), 'is_w': is_w.mean(),
               'gamma': gamma_step, 'param_norm': param_norm.sum(), 'grad_norm': grad_norm.sum()})
    



# Fresh run-level bookkeeping before the training loop.
scores = []
step = 0
grad_step = 0
# `memory.free()` presumably empties the replay buffer -- confirm in utils.experience_replay.
memory.free()
#model.share_memory()


In [8]:
# Main training loop: act with MCTS plans (epsilon-greedy, scheduled on the replay-buffer
# size), push transitions into `memory`, and run one `optimize` call per environment step
# once the buffer holds more than 2000 transitions. Saves a checkpoint at the end.
step=0

progress_bar = tqdm.tqdm(total=total_steps)

# NOTE(review): the inner loop never breaks on episode end (it runs until total_steps),
# so this outer body executes once and its 80000 bound is effectively unused.
while step<(80000):
    state, info = env.reset()
    state = preprocess(state)

    # Stack of the 4 most recent preprocessed observations, seeded with the first one.
    states = deque(maxlen=4)
    for i in range(4):
        states.append(state)
    
    
    eps_reward=torch.tensor([0], dtype=torch.float)
    
    # "Previous step" values: pushed together with the next frame stack.
    reward=np.array([0])
    done_flag=np.array([False])
    terminated=np.array([False])

    last_lives=np.array([0])
    life_loss=np.array([0])
    resetted=np.array([0])

    actions_to_step = []
    
    last_grad_update=0
    while step<(total_steps):
        progress_bar.update(1)
        model_target.train()
        
        len_memory = len(memory)
        #if resetted[0]>0:
        #    states = env.noop_steps(states)

        
        # Re-plan with MCTS on the online model only when the pending plan is exhausted.
        if len(actions_to_step)==0:
            actions_to_step = mcts_inference(model, torch.cat(list(states),-3).unsqueeze(0), w_gumbel=0)[-1]
        # Disabled alternative kept as a no-op string expression.
        '''
        if len(actions_to_step)==0:
            if step > reset_every*0.4: 
                _, _, actions_to_step = mcts_inference(model, torch.cat(list(states),-3).unsqueeze(0))
            else:
                #Q_action = model_target.env_step(torch.cat(list(states),-3).unsqueeze(0))
                Q_action = model.env_step(torch.cat(list(states),-3).unsqueeze(0))
                actions_to_step = [Q_action]
                #print(f"ENV STEP CALLED {actions_to_step}")
        '''
        # The exploration schedule is driven by the replay-buffer size, not by `step`.
        action = epsilon_greedy(actions_to_step, len_memory).cpu()
        
        
        # Store the current frame stack and the action about to be taken, together with the
        # (clipped) reward and the done-or-life-lost flag produced by the PREVIOUS env step.
        memory.push(torch.cat(list(states),-3).detach().cpu(), torch.tensor(reward,dtype=torch.float), action,
                    torch.tensor(np.logical_or(done_flag, life_loss),dtype=torch.bool))
        #print('action', action, action.shape)
        
        state, reward, terminated, truncated, info = env.step([action.numpy()])
        state = preprocess(state)
        states.append(state)
        
        # The logged episode return uses the raw reward; the stored reward is clipped to [-1, 1].
        eps_reward+=reward
        reward = reward.clip(-1, 1)


        # Hard target update every 400 grad steps.
        # NOTE(review): grad_step stays 0 until training starts, so this copies on every env step before then.
        if grad_step%400==0:
            model_target.load_state_dict(model.state_dict())
        
        done_flag = np.logical_or(terminated, truncated)
        # A lost life is stored as a boundary flag on the next push.
        lives = info['lives']
        life_loss = (last_lives-lives).clip(min=0)
        resetted = (lives-last_lives).clip(min=0)
        last_lives = lives

        
        # n-step horizon interpolated geometrically from initial_n to final_n; with the current
        # config (both 5) it is constant. The np.array(...).item() round-trip is redundant.
        n = int(initial_n * (final_n/initial_n)**(min(grad_step,schedule_max_step) / schedule_max_step))
        n = np.array(n).item()
        
        

        # One gradient step per env step once the buffer holds more than 2000 transitions.
        if len_memory>2000:
            for i in range(1):
                optimize(step, grad_step, n)
                
                #target_model_ema(model, model_target)
                
                #target_model_ema(model, model_reanalyze, decay=0.98)
                
                grad_step+=1

        
        # Rolling checkpoint every 10000 env steps.
        if ((step+1)%10000)==0:
            save_checkpoint(model, model_target, optimizer, step,
                            'checkpoints/atari_last.pth')
        
            
        
        # Hard reset of both networks and the optimizer.
        # NOTE(review): with total_steps=102000 and at most one grad step per env step,
        # grad_step never exceeds 200000, so this branch never runs (cf. the reset_every variant).
        #if grad_step>reset_every:
        if grad_step>200000:
            #eval()
            print('Reseting on step', step, grad_step)
            
            random_model = DQN(n_actions).cuda()
            model.hard_reset(random_model)
            
            random_model = DQN(n_actions).cuda()
            model_target.hard_reset(random_model)

            #random_model = DQN(n_actions).cuda()
            #model_reanalyze.hard_reset(random_model)
            random_model=None
            
            seed_np_torch(SEED)
            
            grad_step=0

            # Re-collect parameter groups (no requires_grad filter here, unlike the initial setup).
            actor_modules=[model.prediction, model.projection, model.ac, model.reward_pred]
            params_ac=[]
            for module in actor_modules:
                for param in module.parameters():
                    params_ac.append(param)
                    

            perception_modules=[model.encoder_cnn, model.transition]
            params_wm=[]
            for module in perception_modules:
                for param in module.parameters():
                    params_wm.append(param)
            
            # Rebuild the optimizer; copy_states (defined elsewhere) presumably transfers
            # optimizer state between optimizers -- confirm its semantics.
            #optimizer_aux = torch.optim.AdamW(params_wm, lr=lr, weight_decay=1e-4)
            optimizer_aux = torch.optim.SGD(params_wm, lr=0.2, momentum=0.9, weight_decay=1e-4)
            
            copy_states(optimizer, optimizer_aux)
            
            #optimizer = torch.optim.AdamW(chain(params_wm, params_ac),
            #                    lr=lr, weight_decay=1e-4)
            optimizer = torch.optim.SGD(chain(params_wm, params_ac), momentum=0.9,
                                lr=0.2, weight_decay=1e-4)
            copy_states(optimizer_aux, optimizer)
        
        
        
        step+=1
        
        # Log and reset the raw return of every episode that ended this step.
        log_t = done_flag.astype(float).nonzero()[0]
        
        if len(log_t)>0:
            for log in log_t:
                wandb.log({'eps_reward': eps_reward[log].sum()})
                scores.append(eps_reward[log].clone())
            eps_reward[log_t]=0

save_checkpoint(model, model_target, optimizer, step, f'checkpoints/effz_{env_name}_{SEED}.pth')

  gamma_step = 1-torch.tensor(( (schedule_max_step - min(grad_step, schedule_max_step)) / schedule_max_step) * (initial_gamma-final_gamma) + final_gamma).exp()
  4%|▎         | 3790/102000 [2:41:07<137:08:01,  5.03s/it]
KeyboardInterrupt



In [None]:
# Set to True to restore both networks from the final checkpoint before evaluating.
load_to_eval=False

if load_to_eval:
    # Deserialize the checkpoint once and restore both networks from the same dict.
    checkpoint = torch.load(f'checkpoints/effz_{env_name}_{SEED}.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    model_target.load_state_dict(checkpoint['model_target_state_dict'])

In [None]:
import matplotlib.pyplot as plt

num_envs = 1

# Single vectorised Atari env wrapped with the project's MaxLast2FrameSkipWrapper.
# (For a visual check, pass render_mode='human' to gym.vector.make.)
env = MaxLast2FrameSkipWrapper(
    gym.vector.make(f"{env_name}NoFrameskip-v4", num_envs=num_envs),
    seed=SEED,
)

def eval_phase(eval_runs=50, max_eval_steps=27000, num_envs=1):
    """Roll out the agent in the global `env` until `eval_runs` episode scores are collected.

    Actions come from ``mcts_inference(model, ...)`` plans passed through
    ``epsilon_greedy`` at a fixed near-greedy setting (schedule step 5000,
    final epsilon 0.0005).

    Args:
        eval_runs: number of episode scores to collect.
        max_eval_steps: per-episode step budget. An episode running longer is
            scored as-is and, for a single env, a fresh episode is started.
        num_envs: number of vectorised environments (``eval`` only allows 1).

    Returns:
        List of 0-dim float tensors, each the raw (unclipped) return of one episode.
    """
    progress_bar = tqdm.tqdm(total=eval_runs)

    scores = []

    state, info = env.reset()
    state = preprocess(state)
    print(f"init state {state.shape}")

    # Stack of the 4 most recent observations, seeded with the first one.
    states = deque([state] * 4, maxlen=4)

    eps_reward = torch.tensor([0] * num_envs, dtype=torch.float)
    # Per-env "stop scoring" flags; currently nothing sets them.
    finished_envs = np.array([False] * num_envs)
    step = np.array([0] * num_envs)
    eval_run = 0

    actions_to_step = []

    while eval_run < eval_runs:
        env.seed = SEED + eval_run
        model_target.train()

        if len(actions_to_step) == 0:
            actions_to_step = mcts_inference(model, torch.cat(list(states), -3).unsqueeze(0), w_gumbel=0)[-1]
        action = epsilon_greedy(actions_to_step, 5000, 0.0005, num_envs).cpu()

        state, reward, terminated, truncated, info = env.step([action.numpy()] if num_envs == 1 else action.numpy())
        state = preprocess(state)
        states.append(state)

        eps_reward += reward
        step += 1

        # Score every episode that ended on this step.
        done_flag = np.logical_or(terminated, truncated)
        for log in done_flag.nonzero()[0]:
            if not finished_envs[log]:
                scores.append(eps_reward[log].clone())
                eval_run += 1
                progress_bar.update(1)
            eps_reward[log] = 0
            step[log] = 0

        # Enforce the step budget on every step. An env that just ended already had its
        # counter reset above, so this only catches episodes still running past the budget.
        over_budget = (step > max_eval_steps).nonzero()[0]
        for i in over_budget:
            if not finished_envs[i]:
                scores.append(eps_reward[i].clone())
                step[i] = 0
                eval_run += 1
                eps_reward[i] = 0
                progress_bar.update(1)
        if len(over_budget) > 0 and num_envs == 1:
            # The truncated episode is still running inside the env. Start a new one so its
            # remaining reward is not credited to the next score, and drop the stale plan.
            # (env.reset() resets every sub-env, hence the single-env guard.)
            env.seed = SEED + eval_run
            state, info = env.reset()
            state = preprocess(state)
            states = deque([state] * 4, maxlen=4)
            actions_to_step = []

    progress_bar.close()
    return scores



def eval(eval_runs=50, max_eval_steps=27000, num_envs=1):
    """Evaluate the agent, then print, plot and persist summary statistics.

    Collects episode scores with ``eval_phase``, sorts them, and reports the mean
    plus the interquartile mean/std (the middle half of the sorted scores). Appends a
    row to ``results.csv`` via ``add_to_csv`` and writes a text summary to
    ``results/{env_name}-{SEED}.txt``.

    Args:
        eval_runs: number of episodes to score.
        max_eval_steps: per-episode step budget forwarded to ``eval_phase``.
        num_envs: must be 1.

    Returns:
        1-D tensor of episode scores sorted in ascending order.

    Raises:
        AssertionError: if ``num_envs != 1``.
    """
    import os

    assert num_envs==1, 'The code for num eval envs > 1 is messed up.'

    scores = eval_phase(eval_runs, max_eval_steps, num_envs)
    scores, _ = torch.stack(scores).sort()

    # Drop a quarter of the collected scores from each end. The explicit stop index keeps
    # this correct when the trim is 0 (scores[0:-0] would be empty and make the IQM NaN).
    n_scores = len(scores)
    trim = n_scores // 4
    iq = scores[trim:n_scores - trim]
    iqm = iq.mean()
    iqs = iq.std()

    print(f"Scores Mean {scores.mean()}")
    print(f"Inter Quantile Mean {iqm}")
    print(f"Inter Quantile STD {iqs}")

    plt.xlabel('Episode (Sorted by Reward)')
    plt.ylabel('Reward')
    plt.plot(scores)

    new_row = {'env_name': env_name, 'mean': scores.mean().item(), 'iqm': iqm.item(), 'std': iqs.item(), 'seed': SEED}
    add_to_csv('results.csv', new_row)

    # The results directory may not exist on a fresh checkout.
    os.makedirs('results', exist_ok=True)
    with open(f'results/{env_name}-{SEED}.txt', 'w') as f:
        f.write(f" Scores Mean {scores.mean()}\n Inter Quantile Mean {iqm}\n Inter Quantile STD {iqs}")

    return scores

# Final evaluation: 100 episodes in a single environment.
scores = eval(num_envs=1, eval_runs=100)

In [None]:
# Disabled manual fallback for appending a results row with pandas. It is wrapped in a
# string literal, so this cell is a no-op; the row values are placeholders.
'''
import pandas as pd
new_row = {'env_name': "Amidar", 'mean': 11.0, 'iqm': 11.0, 'std': 11.0, 'seed': 000}

df = pd.read_csv('results.csv',sep=',')
df.loc[len(df.index)] = new_row    
#df.to_csv('results.csv', index=False)

df
'''
# NOTE: kept as a fallback because add_to_csv suddenly stopped working (per the original author).