In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count, chain
import tqdm
import copy

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.autograd import Variable


from nosaveddata.nsd_utils.save_hypers import Hypers, nsd_Module
from nosaveddata.nsd_utils.nsd_csv import add_to_csv
from nosaveddata.nsd_utils.networks import params_count, params_and_grad_norm, seed_np_torch
from nosaveddata.nsd_utils.einstein import Rearrange

from nosaveddata.builders.mlp import *
from nosaveddata.builders.weight_init import *
from nosaveddata.builders.resnet import IMPALA_Resnet, DQN_Conv, IMPALA_YY


from utils.experience_replay import *


import locale
locale.getpreferredencoding = lambda: "UTF-8"

import wandb


# Environment configuration
#env_name = 'Kangaroo'
#SEED = 8712

env_name = 'BattleZone'
SEED = 7782


# Optimization
# Defined before wandb.init so the learning rate that is logged and the one
# handed to the optimizer come from a single value.
batch_size = 32
lr=1e-4

# NOTE(review): `eps` is defined here but the optimizer cell in this notebook
# passes its own literal eps -- confirm whether this constant is still needed.
eps=1e-8


wandb.init(
    project="Atari-100k-BBF",
    name=f"BBF-{env_name}",

    #id='rotdmtc5',
    #resume='must',

    config={
        "learning_rate": lr,
        "architecture": "BBF",
        # Log the game that is actually being run; this used to be a stale
        # hard-coded "Assault" that did not follow env_name.
        "dataset": env_name,
        "epochs": 100,
    },

    reinit=False
)


# Target network EMA rate
critic_ema_decay=0.995


# Return function
# Both discount endpoints are stored as log(1 - gamma).
initial_gamma=torch.tensor(1-0.97).log()
final_gamma=torch.tensor(1-0.997).log()

# n-step return horizon endpoints (presumably annealed from initial to final
# by the training code -- not visible in this notebook).
initial_n = 10
final_n = 3

# Number of atoms of the categorical return distribution.
num_buckets=51


# Reset Schedule and Buffer
reset_every=40000 # grad steps, not steps.
schedule_max_step=reset_every//4
total_steps=102000

prefetch_cap=1 # actually, no prefetch is being done


Transition = namedtuple('Transition',
                        ('state', 'reward', 'action', 'c_flag'))
# Replay buffer sized slightly above the total number of environment steps.
memory = PrioritizedReplay_nSteps_Sqrt(total_steps+5, total_steps=schedule_max_step, prefetch_cap=prefetch_cap)




KeyboardInterrupt



In [None]:
# Adapted from: https://github.com/weipu-zhang/STORM/blob/main/env_wrapper.py
class MaxLast2FrameSkipWrapper(Hypers, gym.Wrapper):
    """Frame-skip wrapper that repeats an action and max-pools the last two frames.

    ``step`` applies the same action up to ``skip`` times, sums the rewards and
    returns the element-wise maximum of the last two observations.

    NOTE(review): ``self.skip``, ``self.noops`` and ``self.seed`` are read by the
    methods below but never assigned in this class; presumably the ``Hypers``
    base stores the constructor arguments as attributes -- TODO confirm.
    """
    def __init__(self, env, skip=4, noops=30, seed=0):
        # Only `env` is forwarded explicitly; see the class-level note about
        # how skip/noops/seed are expected to end up on `self`.
        super().__init__(env=env)
        # Seed the action-space sampler of the wrapped env.
        self.env.action_space.seed(seed)
        
    def reset(self, **kwargs):
        """Reset the wrapped env, always overriding ``seed`` with ``self.seed``.

        Returns the ``(obs, info)`` pair produced by the wrapped env.
        """
        # Any seed passed by the caller is replaced by the wrapper's own seed.
        kwargs["seed"] = self.seed
        obs, _ = self.env.reset(**kwargs)

        return obs, _
        
    def noop_steps(self, states):
        """Take a random number (0..``self.noops`` inclusive) of action-0 steps.

        Each resulting observation is passed through the module-level
        ``preprocess`` and appended to ``states``, which is also returned.
        """
        noops = random.randint(0,self.noops)
        
        for i in range(noops):
            # Action 0 is sent as a length-1 array (presumably the NOOP action
            # for a single-env vector environment -- TODO confirm).
            state = self.step(np.array([0]))[0]
            state = preprocess(state)
            states.append(state)
        return states

    def step(self, action):
        """Repeat ``action`` for up to ``self.skip`` env steps.

        Returns ``(obs, total_reward, done, truncated, info)`` where ``obs`` is
        the element-wise max over the last two raw observations (or the single
        observation if only one step was taken), ``total_reward`` is the sum of
        the per-step rewards, and the remaining values come from the last
        underlying step.
        """
        total_reward = 0
        # Only the two most recent frames are kept for max-pooling.
        self.obs_buffer = deque(maxlen=2)
        for _ in range(self.skip):
            obs, reward, done, truncated, info = self.env.step(action)
            self.obs_buffer.append(obs)
            total_reward += reward

            # NOTE(review): `terminated` is computed but never read.
            terminated = np.logical_or(done, truncated)
            #if terminated.any():
            #    for i in range(len(terminated)):
            #       obs[i] = self.reset()[0][i]
            # Stop repeating as soon as the episode ends.
            # NOTE(review): if done/truncated are arrays this truth test only
            # works for a single element -- assumes num_envs == 1.
            if done or truncated:
                break
        if len(self.obs_buffer) == 1:
            obs = self.obs_buffer[0]
        else:
            obs = np.max(np.stack(self.obs_buffer), axis=0)
        return obs, total_reward, done, truncated, info
        # Life loss is calculated on the training code

# Build a one-environment vector env for the selected game and wrap it with the
# frame-skip / max-pool wrapper defined above.
env = gym.vector.make(f"{env_name}NoFrameskip-v4", num_envs=1)
env = MaxLast2FrameSkipWrapper(env,seed=SEED)


#n_actions = env.action_space.n
# The action space is indexed per sub-environment; use the first one's size.
n_actions = env.action_space[0].n

state, info = env.reset()
# NOTE(review): `state` comes from a vector env, so len(state) is presumably the
# number of envs (1) rather than an observation size -- confirm before use.
n_observations = len(state)

# Seed the RNGs via the project helper (called after the first env reset).
seed_np_torch(SEED)

In [None]:

def renormalize(tensor, has_batch=False):
    """Min-max rescale every sample of ``tensor`` to roughly ``[0, 1]``.

    All dimensions after the first are flattened, and each row is shifted by
    its minimum and divided by ``(max - min + 1e-5)``; the result is returned
    in the original shape.

    Args:
        tensor: tensor with at least one dimension; dim 0 indexes the samples.
        has_batch: unused; kept so existing callers keep working.

    Returns:
        Tensor of the same shape as ``tensor``.
    """
    shape = tensor.shape
    # reshape (instead of view) also accepts non-contiguous inputs, e.g.
    # transposed or channels_last tensors, for which view raises.
    flat = tensor.reshape(shape[0], -1)
    max_value, _ = torch.max(flat, -1, keepdim=True)
    min_value, _ = torch.min(flat, -1, keepdim=True)
    # The small constant keeps the division finite for constant rows.
    return ((flat - min_value) / (max_value - min_value + 1e-5)).reshape(shape)


class DQN(nn.Module):
    """Distributional, dueling Q-network with a latent transition model.

    Components:
      * ``encoder_cnn`` -- ``IMPALA_Resnet`` encoder applied to the input frames;
      * ``projection`` / ``prediction`` -- MLPs applied to flattened latents;
      * ``transition`` -- two conv layers that roll a latent forward given a
        one-hot action broadcast over the spatial grid;
      * ``a`` / ``v`` -- dueling advantage and value heads, each producing
        ``n_atoms`` logits per output.

    Args:
        n_actions: number of discrete actions.
        hiddens: width of the projection / prediction / head inputs.
        mlp_layers: accepted for compatibility; not used by this class.
        scale_width: width multiplier of the convolutional parts.
        n_atoms: number of atoms of the categorical return distribution.
        Vmin, Vmax: range covered by the distribution support.

    The action and atom counts are kept on the instance and used everywhere
    inside the class, so the heads, the reshapes and ``support`` always agree
    with the constructor arguments instead of depending on module-level globals.
    """
    def __init__(self, n_actions, hiddens=2048, mlp_layers=1, scale_width=4,
                 n_atoms=51, Vmin=-10, Vmax=10):
        super().__init__()
        self.n_actions = n_actions
        self.n_atoms = n_atoms
        # Atom locations of the return distribution. Created directly on the GPU
        # and kept as a plain attribute, so it is not part of state_dict.
        self.support = torch.linspace(Vmin, Vmax, n_atoms).cuda()

        self.hiddens=hiddens
        self.scale_width=scale_width
        self.act = nn.ReLU()

        self.encoder_cnn = IMPALA_Resnet(scale_width=scale_width, norm=False, init=init_xavier, act=self.act)

        # Single layer dense that maps the flattened encoded representation into hiddens.
        # NOTE(review): 13824 is the hard-coded flattened encoder size; presumably
        # 32*scale_width channels x 12 x 9 for 96x72 inputs at scale_width=4 -- confirm.
        self.projection = MLP(13824, med_hiddens=hiddens, out_hiddens=hiddens,
                              last_init=init_xavier, layers=1)
        self.prediction = MLP(hiddens, out_hiddens=hiddens, layers=1, last_init=init_xavier)

        # The first conv takes the latent channels plus one channel per action.
        self.transition = nn.Sequential(DQN_Conv(32*scale_width+n_actions, 32*scale_width, 3, 1, 1, norm=False, init=init_xavier, act=self.act),
                                        DQN_Conv(32*scale_width, 32*scale_width, 3, 1, 1, norm=False, init=init_xavier, act=self.act))

        # Single layer dense that maps hiddens into the output dim according to:
        # 1. https://arxiv.org/pdf/1707.06887.pdf -- Distributional Reinforcement Learning
        # 2. https://arxiv.org/pdf/1511.06581.pdf -- Dueling DQN
        self.a = MLP(hiddens, out_hiddens=n_actions*n_atoms, layers=1, in_act=self.act, last_init=init_xavier)
        self.v = MLP(hiddens, out_hiddens=n_atoms, layers=1, in_act=self.act, last_init=init_xavier)

        params_count(self, 'DQN')

    def forward(self, X, y_action):
        """Encode ``X``, compute the Q distribution and roll the latent forward.

        Args:
            X: batch of frame sequences, ``(batch, seq, ...)`` (the last three
               dims are fed to the encoder).
            y_action: integer actions used for the latent roll-out; see
               ``get_transition``.

        Returns:
            ``(q, action, X[:, 1:] detached, z_pred)`` where ``q`` is the
            per-action atom distribution, ``action`` the greedy action, the
            third item the projected features of every step but the first, and
            ``z_pred`` the predicted future features.
        """
        X, z = self.encode(X)

        q, action = self.q_head(X)
        z_pred = self.get_transition(z, y_action)

        return q, action, X[:,1:].clone().detach(), z_pred

    def env_step(self, X):
        """Return the greedy action for ``X`` without tracking gradients."""
        with torch.no_grad():
            X, _ = self.encode(X)
            _, action = self.q_head(X)

            return action.detach()

    def encode(self, X):
        """Encode frames into projected features and spatial latents.

        Also records ``self.batch`` and ``self.seq`` from the first two input
        dims; ``get_transition`` and ``dueling_dqn`` rely on ``self.batch``.

        Returns:
            ``(features, z)`` -- ``features`` is the projection of the flattened
            latents, ``z`` the renormalised latents with shape
            ``(batch, seq, C, H, W)``.
        """
        batch, seq = X.shape[:2]
        self.batch = batch
        self.seq = seq
        X = self.encoder_cnn(X.contiguous().view(self.batch*self.seq, *(X.shape[2:])))
        # Restore the (batch, seq) leading dims once; the former second,
        # identical view was a no-op and has been dropped.
        X = renormalize(X).contiguous().view(self.batch, self.seq, *X.shape[-3:])
        z = X.clone()
        X = X.flatten(-3,-1)
        X = self.projection(X)
        return X, z

    def get_transition(self, z, action):
        """Roll the latent ``z`` forward through 5 actions.

        The horizon is hard-coded to 5 steps (one initial step plus four in the
        loop). ``encode`` must have been called first because ``self.batch`` is
        used for the final reshape.

        Args:
            z: latents whose last three dims are ``(C, H, W)``.
            action: integer actions reshapeable to ``(-1, 5)``.

        Returns:
            Predicted features for the 5 future steps, ``(batch, 5, hiddens)``.
        """
        z = z.contiguous().view(-1, *z.shape[-3:])

        # One-hot actions broadcast over the latent's spatial grid.
        action = F.one_hot(action.clone(), self.n_actions).view(-1, self.n_actions)
        action = action.view(-1, 5, self.n_actions, 1, 1).expand(-1, 5, self.n_actions, *z.shape[-2:])

        z_pred = torch.cat( (z, action[:,0]), 1)
        z_pred = self.transition(z_pred)
        z_pred = renormalize(z_pred)

        z_preds=[z_pred.clone()]

        for k in range(4):
            z_pred = torch.cat( (z_pred, action[:,k+1]), 1)
            z_pred = self.transition(z_pred)
            z_pred = renormalize(z_pred)

            z_preds.append(z_pred)

        z_pred = torch.stack(z_preds,1)

        z_pred = self.projection(z_pred.flatten(-3,-1)).view(self.batch,5,-1)
        z_pred = self.prediction(z_pred)

        return z_pred

    def q_head(self, X):
        """Return the atom distribution and the greedy action for features ``X``."""
        q = self.dueling_dqn(X)
        # Expected value per action = sum over atoms of probability * support.
        action = (q*self.support).sum(-1).argmax(-1)

        return q, action

    def get_max_action(self, X):
        """Return the greedy action for raw input ``X`` without gradients."""
        with torch.no_grad():
            X, _ = self.encode(X)
            q = self.dueling_dqn(X)

            action = (q*self.support).sum(-1).argmax(-1)
            return action

    def evaluate(self, X, action):
        """Return the atom distribution of the given ``action`` for input ``X``.

        ``action`` is indexed as ``(batch, seq)``; the result keeps a singleton
        action dimension: ``(batch, seq, 1, n_atoms)``.
        """
        with torch.no_grad():
            X, _ = self.encode(X)

            q = self.dueling_dqn(X)

            # Build a (batch, seq, 1, n_atoms) index that repeats the action id
            # along the atom dimension so gather picks that action's row.
            action = action[:,:,None,None].expand_as(q)[:,:,0][:,:,None]
            q = q.gather(-2,action)

            return q

    def dueling_dqn(self, X):
        """Combine value and advantage heads into per-action atom probabilities.

        Uses ``self.batch`` (set by ``encode``) for the reshape. Returns a
        tensor of shape ``(batch, -1, n_actions, n_atoms)`` that sums to 1 over
        the last dimension.
        """
        X = F.relu(X)

        a = self.a(X).view(self.batch, -1, self.n_actions, self.n_atoms)
        v = self.v(X).view(self.batch, -1, 1, self.n_atoms)

        # Dueling aggregation: subtract the mean advantage over actions.
        q = v + a - a.mean(-2,keepdim=True)
        q = F.softmax(q,-1)

        return q

    def network_ema(self, rand_network, target_network, alpha=0.5):
        """Blend ``rand_network`` into ``target_network`` parameter by parameter.

        ``target = alpha * target + (1 - alpha) * rand``; ``alpha=0`` replaces
        the target parameters with a copy of ``rand_network``'s.
        """
        for param, param_target in zip(rand_network.parameters(), target_network.parameters()):
            param_target.data = alpha * param_target.data + (1 - alpha) * param.data.clone()

    def hard_reset(self, random_model, alpha=0.5):
        """Reset this model towards ``random_model``.

        The encoder and transition model are interpolated with weight ``alpha``
        on the current parameters; projection, prediction and both heads are
        fully replaced by ``random_model``'s parameters.
        """
        with torch.no_grad():

            self.network_ema(random_model.encoder_cnn, self.encoder_cnn, alpha)
            self.network_ema(random_model.transition, self.transition, alpha)

            self.network_ema(random_model.projection, self.projection, 0)
            self.network_ema(random_model.prediction, self.prediction, 0)

            self.network_ema(random_model.a, self.a, 0)
            self.network_ema(random_model.v, self.v, 0)



def copy_states(source, target):
    """Copy the Adam moment estimates and step counts from one optimizer to another.

    For every state key present in both optimizers (paired in iteration order;
    the source's key is the one used for lookup), ``exp_avg_sq``, ``exp_avg``
    and ``step`` are deep-copied from ``source`` into ``target``.

    The state dicts are built once up front instead of being rebuilt for every
    access inside the loop. Like the original code, this relies on the
    per-parameter dicts inside ``state_dict()['state']`` being the optimizer's
    live state objects, so assigning into them updates ``target`` itself.

    Args:
        source: optimizer to read the state from.
        target: optimizer whose state is overwritten.
    """
    source_state = source.state_dict()['state']
    target_state = target.state_dict()['state']

    for key, _ in zip(source_state.keys(), target_state.keys()):
        src_entry = source_state[key]
        dst_entry = target_state[key]

        dst_entry['exp_avg_sq'] = copy.deepcopy(src_entry['exp_avg_sq'])
        dst_entry['exp_avg'] = copy.deepcopy(src_entry['exp_avg'])
        dst_entry['step'] = copy.deepcopy(src_entry['step'])
        
def target_model_ema(model, model_target):
    """Move ``model_target``'s parameters towards ``model``'s by an EMA step.

    Each target parameter becomes
    ``critic_ema_decay * target + (1 - critic_ema_decay) * online``, using the
    module-level ``critic_ema_decay``. Parameters are paired in iteration order.
    """
    online_weight = 1.0 - critic_ema_decay
    with torch.no_grad():
        paired = zip(model.parameters(), model_target.parameters())
        for online, target in paired:
            blended = critic_ema_decay * target.data + online_weight * online.data.clone()
            target.data = blended


    
# Online network and target network, both moved to the GPU.
model=DQN(n_actions).cuda()
model_target=DQN(n_actions).cuda()

# Start the target network from exactly the same weights as the online one.
model_target.load_state_dict(model.state_dict())


# Testing only
#with torch.no_grad():
#    q, action, X, z_pred = model(torch.randn(4,1,12,96,72, device='cuda', dtype=torch.float), torch.randint(0,n_actions,(4,5),device='cuda').long())
#z = model.encode(torch.randn(4,5,12,96,72, device='cuda'))[0]

# I believe the authors actually miscalculated the parameter count in the paper.
# My training time is lower than theirs despite having more parameters, and the same architecture is used as in their original code.

In [None]:
# Sub-modules grouped by role: perception (encoder + transition model) and
# actor (projection / prediction / dueling heads).
perception_modules = [model.encoder_cnn, model.transition]
actor_modules = [model.prediction, model.projection, model.a, model.v]

# Trainable parameters of each group (they all require grad).
params_wm = [
    param
    for module in perception_modules
    for param in module.parameters()
    if param.requires_grad
]

params_ac = [
    param
    for module in actor_modules
    for param in module.parameters()
    if param.requires_grad
]


# A single AdamW instance optimises both groups with the same hyperparameters.
optimizer = torch.optim.AdamW(
    chain(params_wm, params_ac),
    lr=lr,
    weight_decay=0.1,
    eps=1.5e-4,
)

In [None]:
import torchvision.transforms as transforms

# Image pipeline applied to every observation: resize to 96x72 (height, width).
train_tfms = transforms.Compose([
                         transforms.Resize((96,72)),
                        ])


def preprocess(state):
    """Turn a batch of raw frames into a resized float tensor on the GPU.

    The input is converted to a float CUDA tensor, divided by 255, permuted
    from channels-last ``(N, H, W, C)`` to channels-first ``(N, C, H, W)`` and
    passed through the module-level ``train_tfms``.
    """
    frames = torch.tensor(state, dtype=torch.float, device='cuda')
    frames = frames / 255
    channels_first = frames.permute(0, 3, 1, 2)
    return train_tfms(channels_first)

# https://github.com/google/dopamine/blob/master/dopamine/jax/agents/dqn/dqn_agent.py
def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon):
    """Epsilon schedule that decays linearly from 1.0 down to ``epsilon``.

    The value stays at 1.0 during the first ``warmup_steps`` steps, then falls
    linearly over ``decay_period`` steps and remains at ``epsilon`` afterwards.

    Args:
        decay_period: number of steps over which the decay happens.
        step: current step.
        warmup_steps: number of steps before the decay starts.
        epsilon: final value of the schedule.

    Returns:
        The exploration rate for ``step``.
    """
    max_bonus = 1.0 - epsilon
    remaining = decay_period + warmup_steps - step
    # Extra exploration on top of the final epsilon, limited to [0, 1 - epsilon].
    raw_bonus = max_bonus * remaining / decay_period
    return epsilon + np.clip(raw_bonus, 0., max_bonus)


def epsilon_greedy(Q_action, step, final_eps=0, num_envs=1):
    """Pick between a uniformly random action and the greedy ``Q_action``.

    The exploration rate comes from ``linearly_decaying_epsilon`` with a decay
    period of 2001 steps after 2000 warm-up steps. A single random draw decides
    for all environments at once.

    Args:
        Q_action: greedy action(s) proposed by the network.
        step: current step, used to look up epsilon.
        final_eps: value epsilon decays to.
        num_envs: number of environments an action is needed for.

    Returns:
        int64 action tensor; for ``num_envs == 1`` the leading dim is squeezed.
    """
    epsilon = linearly_decaying_epsilon(2001, step, 2000, final_eps)
    explore = random.random() < epsilon

    if not explore:
        greedy = Q_action.view(num_envs).squeeze(0)
        return greedy.to(torch.int64)

    sampled = torch.randint(0, n_actions, (num_envs,), dtype=torch.int64, device='cuda')
    return sampled.squeeze(0)


In [None]:
# When True, restore both networks from the checkpoint of the selected game.
load_to_eval=True

if load_to_eval:
    # Read the checkpoint from disk once and reuse it for both networks
    # (previously the same file was deserialised twice).
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(f'checkpoints/{env_name}.pth')
    model.load_state_dict(checkpoint['model_state_dict'])
    model_target.load_state_dict(checkpoint['model_target_state_dict'])

In [None]:
import matplotlib.pyplot as plt

# Number of parallel evaluation environments.
num_envs=1

# Re-create the environment for evaluation, this time with on-screen rendering,
# and wrap it with the same frame-skip wrapper. This replaces the earlier `env`.
env = gym.vector.make(f"{env_name}NoFrameskip-v4", num_envs=num_envs, render_mode='human')
env = MaxLast2FrameSkipWrapper(env,seed=SEED)

def eval_phase(eval_runs=50, max_eval_steps=27000, num_envs=1):
    """Play evaluation episodes with ``model_target`` and collect episode returns.

    Relies on the module-level ``env``, ``model_target``, ``preprocess``,
    ``epsilon_greedy`` and ``SEED``.

    Args:
        eval_runs: number of episodes to record before returning.
        max_eval_steps: intended per-episode step cap (see the NOTE(review)
            next to the cap check below).
        num_envs: number of parallel environments; the public ``eval`` entry
            point only supports 1.

    Returns:
        List of 0-d float tensors, one episode return per recorded run.
    """
    progress_bar = tqdm.tqdm(total=eval_runs)
    
    scores=[]
    
    state, info = env.reset()
    state = preprocess(state)
    print(f"init state {state.shape}")
    
    # Frame stack: start with the first frame repeated four times.
    states = deque(maxlen=4)
    for i in range(4):
        states.append(state)
    
    
    # Running return of the current episode, one entry per env.
    eps_reward=torch.tensor([0]*num_envs, dtype=torch.float)
    
    # NOTE(review): reward/terminated are overwritten by env.step before being
    # read; life_loss, resetted, done_flag's initial value and last_grad_update
    # are never read in this function.
    reward=np.array([0]*num_envs)
    terminated=np.array([False]*num_envs)
    
    last_lives=np.array([0]*num_envs)
    life_loss=np.array([0]*num_envs)
    resetted=np.array([0])

    finished_envs=np.array([False]*num_envs)
    done_flag=0
    last_grad_update=0
    eval_run=0
    # Steps taken in the current episode, one entry per env.
    step=np.array([0]*num_envs)
    while eval_run<eval_runs:
        #seed_np_torch(SEED+eval_run)
        # The wrapper's reset() reads this attribute as its seed.
        env.seed=SEED+eval_run
        # NOTE(review): the target network is put in train mode during
        # evaluation -- confirm this is intentional.
        model_target.train()
        
        #if resetted[0]>0:
        #    states = env.noop_steps(states)
        
        # Stack the four frames along the channel dim and add a leading dim.
        Q_action = model_target.env_step(torch.cat(list(states),-3).unsqueeze(0))
        # A fixed step of 5000 is passed, past the 2000 warm-up + 2001 decay
        # steps used inside epsilon_greedy, with final_eps=0.0005.
        action = epsilon_greedy(Q_action.squeeze(), 5000, 0.0005, num_envs).cpu()
        
        state, reward, terminated, truncated, info = env.step([action.numpy()] if num_envs==1 else action.numpy())
        state = preprocess(state)
        states.append(state)
        
        eps_reward+=reward

        
        done_flag = np.logical_or(terminated, truncated)
        # Life bookkeeping; the derived life_loss/resetted values are not used below.
        lives = info['lives']
        life_loss = (last_lives-lives).clip(min=0)
        resetted = (lives-last_lives).clip(min=0)
        last_lives = lives        
        
        step+=1
        
        # Indices of the envs whose episode ended on this step.
        log_t = done_flag.astype(float).nonzero()[0]
        if len(log_t)>0:# or (step>max_eval_steps).any():
            progress_bar.update(1)
            for log in log_t:
                #wandb.log({'eval_eps_reward': eps_reward[log].sum()})
                if finished_envs[log]==False:
                    scores.append(eps_reward[log].clone())
                    eval_run+=1
                    #finished_envs[log]=True
                step[log]=0
                
            eps_reward[log_t]=0            
            # NOTE(review): this cap check only runs inside the "an episode just
            # ended" branch, right after step[log] was reset to 0. With a single
            # env `step>max_eval_steps` therefore cannot be true here, so
            # max_eval_steps is effectively never enforced -- confirm whether
            # this loop was meant to run on every step.
            for i, log in enumerate(step>max_eval_steps):
                if log==True and finished_envs[i]==False:
                    scores.append(eps_reward[i].clone())
                    step[i]=0
                    eval_run+=1
                    eps_reward[i]=0
                    #finished_envs[i]=True
            
    return scores



def eval(eval_runs=50, max_eval_steps=27000, num_envs=1):
    """Run the evaluation phase, report summary statistics and persist them.

    Prints the mean, the interquartile mean and the interquartile standard
    deviation of the episode returns, plots the sorted returns, appends a row
    to ``results.csv`` and writes a text summary to
    ``results/{env_name}-{SEED}.txt``.

    Note: this function shadows the ``eval`` builtin in the notebook namespace.

    Args:
        eval_runs: number of episodes to evaluate.
        max_eval_steps: forwarded to ``eval_phase``.
        num_envs: must be 1.

    Returns:
        1-D tensor with the episode returns sorted in ascending order.

    Raises:
        AssertionError: if ``num_envs`` is not 1.
    """
    assert num_envs==1, 'The code for num eval envs > 1 is messed up.'
    
    scores = eval_phase(eval_runs, max_eval_steps, num_envs)    
    scores = torch.stack(scores)
    scores, _ = scores.sort()
    
    # Drop the lowest and the highest quarter of the runs. The upper bound is
    # written as an explicit index: the former `scores[_25th:-_25th]` became
    # `scores[0:0]` (empty -> NaN statistics) whenever fewer than 4 runs were
    # evaluated, because -0 == 0.
    _25th = len(scores)//4

    iq = scores[_25th:len(scores)-_25th]
    iqm = iq.mean()
    iqs = iq.std()

    print(f"Scores Mean {scores.mean()}")
    print(f"Inter Quantile Mean {iqm}")
    print(f"Inter Quantile STD {iqs}")

    
    plt.xlabel('Episode (Sorted by Reward)')
    plt.ylabel('Reward')
    plt.plot(scores)
    
    new_row = {'env_name': env_name, 'mean': scores.mean().item(), 'iqm': iqm.item(), 'std': iqs.item(), 'seed': SEED}
    add_to_csv('results.csv', new_row)

    with open(f'results/{env_name}-{SEED}.txt', 'w') as f:
        f.write(f" Scores Mean {scores.mean()}\n Inter Quantile Mean {iqm}\n Inter Quantile STD {iqs}")
    
    
    return scores

# Short evaluation with 3 episodes (the function's default is 50 runs).
scores = eval(eval_runs=3, num_envs=1)

In [None]:
# Disabled snippet, kept as a string so it does not execute: it reads
# results.csv with pandas and appends a row by hand (the write-back is
# commented out inside the snippet).
'''
import pandas as pd
new_row = {'env_name': "Amidar", 'mean': 11.0, 'iqm': 11.0, 'std': 11.0, 'seed': 000}

df = pd.read_csv('results.csv',sep=',')
df.loc[len(df.index)] = new_row    
#df.to_csv('results.csv', index=False)

df
'''
# NOTE: add_to_csv suddenly stopped working; the snippet above is presumably a
# manual check / workaround for that -- confirm before removing it.