In [None]:
# Python 3.10.17
# pip 23.0.1
# setuptools 80.3.1
# wheel 0.45.1

!pip install numpy==1.23.0
!pip install gym-super-mario-bros==7.4.0

# gym 0.26.2
# torch 2.6.0
# nes-py 8.2.1

In [None]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

In [None]:
import warnings
warnings.filterwarnings("ignore")
import torch,gym,random
from dataclasses import dataclass
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation,FrameStack,ResizeObservation
from gym.vector import SyncVectorEnv
import numpy as np

@dataclass(frozen=False)
class hypers:
    """Hyperparameters for the PPO Super Mario Bros run.

    Not frozen, so individual fields can be overridden on `configs` in place.
    """
    seed : int = 42                         # seeds torch / random / numpy / XLA RNGs in train_fn
    lambda_ : float = 1.0                   # GAE lambda
    gamma : float = 0.99                    # discount factor
    epsilon : float = 0.2                   # PPO ratio clip range: [1 - epsilon, 1 + epsilon]
    lr : float = 1e-4                       # Adam learning rate
    critic_coeff : float = 1e-1             # weight of the value (critic) loss
    # NOTE: not read by the training code in this notebook; the (misspelled)
    # name is kept so existing overrides keep working. Annotated as float
    # because the value is a float.
    policy_ceoff : float = 1e2
    entropy_coeff : float = 1e-1            # weight of the entropy bonus
    skip_frame : int = 4                    # env frames each chosen action is repeated for
    num_stack : int  = 4                    # frames stacked per observation (network input channels)
    obs_shape : tuple[int,int] = (100,100)  # observation (H, W) after resizing
    num_env : int = 90                      # parallel envs in the SyncVectorEnv
    num_game : int = 500                    # training iterations (one rollout + update each)
    batchsize : int =  512                  # vectorized env steps collected per rollout
    minibatch : int = 16                    # rollout time steps per minibatch (each step holds all envs)
    optim_steps : int = 10                  # gradient steps taken on each minibatch

configs = hypers()

In [None]:
class CustomEnv(gym.Wrapper):
    """Frame-skip wrapper: repeat each action `skip` times and sum the rewards.

    The summed reward is divided by 10 before being returned. The loop stops
    early as soon as the episode ends (terminated or truncated). Resetting is
    left to the caller: a SyncVectorEnv auto-resets finished sub-envs, so
    resetting here as well would reset the emulator twice.
    """
    def __init__(self,env,skip):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip
        self.score = 0

    def step(self, action):
        """Apply `action` for up to `skip` frames; returns the gym 0.26 5-tuple."""
        total_reward = 0
        for _ in range(self.skip):
            obs,reward,terminated,truncated,info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs,(total_reward/10.),terminated,truncated,info

    def reset(self, **kwargs):
        """Reset the wrapped env and return (obs, info).

        kwargs (e.g. seed/options passed by SyncVectorEnv) are deliberately not
        forwarded. NOTE(review): presumably because nes_py's JoypadSpace.reset
        does not accept them; confirm before forwarding.
        """
        self.score = 0
        obs,info = self.env.reset()
        return obs,info

def make_env():
    """Build a SyncVectorEnv of `configs.num_env` preprocessed Super Mario Bros envs.

    Each sub-env is wrapped, innermost first: JoypadSpace(SIMPLE_MOVEMENT) ->
    ResizeObservation(configs.obs_shape) -> CustomEnv(configs.skip_frame) ->
    GrayScaleObservation(keep_dim=True) -> FrameStack(configs.num_stack).
    """
    def build_single_env():
        wrapper_chain = (
            lambda inner: JoypadSpace(inner, SIMPLE_MOVEMENT),
            lambda inner: ResizeObservation(inner, configs.obs_shape),
            lambda inner: CustomEnv(inner, configs.skip_frame),
            lambda inner: GrayScaleObservation(inner, keep_dim=True),
            lambda inner: FrameStack(inner, configs.num_stack),
        )
        wrapped = gym_super_mario_bros.make("SuperMarioBros-v0", apply_api_compatibility=True)
        for apply_wrapper in wrapper_chain:
            wrapped = apply_wrapper(wrapped)
        return wrapped

    factories = [build_single_env] * configs.num_env
    return SyncVectorEnv(factories)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class network(nn.Module):
    """Actor-critic CNN with a shared convolutional trunk.

    Input: float tensor of shape (batch, in_channels, H, W) with (H, W) == obs_shape.
    forward returns (probs, value): action probabilities of shape
    (batch, num_actions) with softmax already applied, and state values of
    shape (batch, 1).

    Args:
        in_channels: stacked frames per observation (default 4, matching configs.num_stack).
        num_actions: size of the discrete action space (default 7, the SIMPLE_MOVEMENT set).
        obs_shape: spatial (H, W) of the input; defaults to configs.obs_shape.
        lr: Adam learning rate for `self.optim`; defaults to configs.lr.
    """
    def __init__(self, in_channels=4, num_actions=7, obs_shape=None, lr=None):
        super().__init__()
        if obs_shape is None:
            obs_shape = configs.obs_shape
        if lr is None:
            lr = configs.lr
        self.conv1 = nn.Conv2d(in_channels,32,1,1,0)
        self.conv2 = nn.Conv2d(32,32,3,2,2)
        self.conv3 = nn.Conv2d(32,32,3,2,2)
        self.conv4 = nn.Conv2d(32,32,3,2,2)
        # Derive the flattened trunk width from the input size instead of
        # hard-coding it (it is 7200 == 32 * 15 * 15 for 100x100 inputs).
        with torch.no_grad():
            flat_dim = self._trunk(torch.zeros(1, in_channels, *obs_shape)).shape[1]
        self.output1 = nn.Linear(flat_dim,1500)
        self.output2 = nn.Linear(1500,500)

        self.policy_head = nn.Linear(500,num_actions)
        self.value_head = nn.Linear(500,1)
        # NOTE: the optimizer keeps the Parameter objects that exist right now.
        # Module.to() replaces them with new objects when the target tensor type
        # is incompatible (e.g. CPU -> XLA), so rebuild `optim` after such a move.
        self.optim = torch.optim.Adam(self.parameters(),lr=lr)
        self.apply(self.init_weights)

    def _trunk(self, x):
        """Conv stack + flatten: (batch, C, H, W) -> (batch, features)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        return torch.flatten(x, 1)

    def forward(self,x):
        """Return (action probabilities, state value) for a batch of observations."""
        x = self._trunk(x)
        x = F.relu(self.output1(x))
        x = F.relu(self.output2(x))
        policy_output = self.policy_head(x)
        value_output = self.value_head(x)
        return F.softmax(policy_output,-1),value_output

    def init_weights(self,layer):
        """Orthogonal weights and zero biases for every Conv2d / Linear layer."""
        if isinstance(layer,(nn.Conv2d,nn.Linear)):
            nn.init.orthogonal_(layer.weight)
            nn.init.constant_(layer.bias,0.0)

In [None]:
import gc,sys
from torch.distributions import Categorical
from collections import deque

class Memory:
    """On-policy rollout buffer for a vectorized env.

    `rollout` collects `batchsize` synchronous steps from every sub-env with
    `self.network`, attaches GAE advantages and shuffles the time steps;
    `sample` then serves them in consecutive chunks. Each stored item is
    [obs, reward, value, log_prob, action, done, advantage], where every
    entry has a leading num_envs dimension.
    """
    def __init__(self,env : SyncVectorEnv):
        # NOTE: a fresh network is built here; a caller that optimizes a
        # different network instance must assign it to `self.network` so that
        # rollouts sample from the policy being trained.
        self.network = network()
        self.env = env
        self.gamma = configs.gamma
        self._lambda_ = configs.lambda_
        self.data = []
        self.pointer = 0
        # Latest env observation; None until the first reset. It is kept across
        # rollouts so episodes are not cut off at every rollout boundary.
        self._observation = None
        self.episode_reward = np.zeros(self.env.num_envs, dtype=np.float32)
        self.finished_reward = deque(maxlen=5)
        self.log_total_steps = deque(maxlen=1)
        self.total_steps = np.zeros(self.env.num_envs, dtype=np.int64)

    def rollout(self,batchsize,device):
        """Collect `batchsize` vectorized steps and attach GAE advantages.

        Args:
            batchsize: number of env.step calls (each steps all sub-envs).
            device: device the network runs on; stored tensors live there.
        """
        self.clear()
        self.network = self.network.to(device)
        if self._observation is None:
            self._observation,_ = self.env.reset()
        with torch.no_grad():
            for _ in range(batchsize):
                obs = self._obs_to_tensor(self._observation, device)
                policy_output, value = self.network.forward(obs)
                distribution = Categorical(policy_output)
                action = distribution.sample()
                log_prob = distribution.log_prob(action)
                state,reward,terminated,truncated,_ = self.env.step(action.tolist())
                # SyncVectorEnv auto-resets a sub-env on either flag, so both
                # mark an episode boundary.
                done = np.logical_or(terminated, truncated)
                self._record_episode_stats(reward, done)
                self.data.append(
                    [
                        obs,
                        torch.tensor(reward,dtype=torch.float32),
                        value,
                        log_prob,
                        action,
                        done
                    ]
                )
                self._observation = state
            if not self.data:
                return
            # Value of the observation after the last step: bootstraps returns
            # of episodes that are still running when the rollout ends.
            _, last_value = self.network.forward(self._obs_to_tensor(self._observation, device))

        advantages = self._compute_advantages(last_value.squeeze(-1), device)
        for item,advantage in zip(self.data,advantages):
            item.append(advantage)
        random.shuffle(self.data)

    @staticmethod
    def _obs_to_tensor(observation, device):
        """Stacked frames (num_envs, stack, H, W, 1) -> float32 (num_envs, stack, H, W) / 255 on `device`.

        Assumes raw pixel values in [0, 255] (NES frames) — TODO confirm dtype.
        """
        return torch.from_numpy(np.array(observation).copy()).squeeze(-1).to(device,torch.float32) / 255.

    def _record_episode_stats(self, reward, done):
        """Accumulate per-env episode reward/length; log and zero finished episodes."""
        self.episode_reward += np.asarray(reward, dtype=np.float32)
        self.total_steps += 1
        for i in np.flatnonzero(done):
            # Plain floats: copies, not views into the running counters.
            self.finished_reward.append(float(self.episode_reward[i]))
            self.log_total_steps.append(float(self.total_steps[i]))
        self.episode_reward[done] = 0
        self.total_steps[done] = 0

    def _compute_advantages(self, last_value, device):
        """GAE over the collected steps; returns a (T, num_envs) tensor.

        At a step flagged done the next value is not bootstrapped and the
        running GAE is reset, so advantages never leak across episodes.
        """
        rewards = torch.stack([item[1] for item in self.data]).to(device)      # (T, N)
        values = torch.stack([item[2] for item in self.data]).squeeze(-1)     # (T, N)
        values = torch.cat([values, last_value.reshape(1, -1).to(values.dtype)])  # (T+1, N)
        assert rewards.dtype == values.dtype == torch.float32
        dones = np.stack([item[5] for item in self.data]).astype(np.float32)
        not_done = 1.0 - torch.from_numpy(dones).to(device)                   # (T, N)

        advantages = torch.zeros_like(rewards)
        gae = torch.zeros_like(rewards[0])
        for t in reversed(range(len(rewards))):
            td = rewards[t] + self.gamma * values[t + 1] * not_done[t] - values[t]
            gae = td + self._lambda_ * self.gamma * not_done[t] * gae
            advantages[t] = gae
        return advantages

    def sample(self,number):
        """Return the next `number` stored steps as per-field tuples.

        Order: (states, actions, rewards, values, logProb, advantages, done).
        """
        output = self.data[self.pointer:self.pointer+number]
        self.pointer+=number
        states,rewards,values,logProb,actions,done,advantages = zip(*output)
        return states,actions,rewards,values,logProb,advantages,done

    def clear(self):
        """Drop stored steps, rewind the sample pointer and run the garbage collector."""
        self.data = []
        self.pointer = 0
        gc.collect()

    def traj_reward(self):
        """Return (last finished-episode rewards, last finished-episode lengths)."""
        return self.finished_reward,self.log_total_steps

In [None]:
# Autograd anomaly detection records a traceback for every op and is meant for
# debugging only (it slows training substantially). Flip to True only when
# hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class agent:
    """PPO trainer.

    Owns the vectorized env, the rollout Memory, the actor-critic network being
    optimized, and a TensorBoard SummaryWriter logging to "./".
    """
    def __init__(self,device):
        self.env = make_env()
        self.memory = Memory(self.env)
        self.device = device
        self.network = network().to(self.device)
        # Module.to() replaces Parameter objects (rather than updating them in
        # place) when the target tensor type is incompatible, as with XLA. The
        # optimizer built inside network.__init__ would then step stale CPU
        # copies, so rebuild it over the parameters that now live on `device`.
        self.network.optim = torch.optim.Adam(self.network.parameters(),lr=configs.lr)
        # Rollouts must sample from the network being trained; otherwise
        # old_log_prob and advantages come from a policy that never updates.
        self.memory.network = self.network
        self.writter = SummaryWriter("./")

    def save(self,k):
        """Save model and optimizer state to ./mario{k}."""
        checkpoint = {
            "model_state" : self.network.state_dict(),
            "optim_state" : self.network.optim.state_dict(),
        }
        torch.save(checkpoint,f"./mario{k}")

    def train(self,num_game,batchsize,minibatch,optim_step):
        """Run `num_game` PPO iterations.

        Each iteration does the following:
        - collects a rollout of `batchsize` vectorized steps;
        - for each of the `batchsize // minibatch` minibatches, takes `optim_step`
          clipped-PPO gradient steps;
        - logs the last step's metrics to TensorBoard and stdout.

        A checkpoint is saved every 20 iterations and once more at the end.
        """
        for traj in tqdm(range(num_game),total = num_game):
            self.memory.rollout(batchsize,self.device)

            metrics = None
            for _ in range(batchsize//minibatch):
                batch = self.memory.sample(minibatch)
                step_metrics = self._update_minibatch(batch, optim_step)
                if step_metrics is not None:
                    metrics = step_metrics

            if metrics is not None:
                self._log_metrics(traj, num_game, metrics)
            if traj % 20 == 0 : # save every 20 iterations
                self.save(traj)

        self.save("_end")

    def _update_minibatch(self, batch, optim_step):
        """Take `optim_step` clipped-PPO gradient steps on one sampled minibatch.

        Returns a dict of detached scalar metrics from the last step, or None
        when optim_step < 1.
        """
        if optim_step < 1:
            return None
        states,actions,rewards,values,old_log_prob,advantages,done = batch

        # Rollout-time quantities are constants during the update: build them once.
        _advantages = torch.stack(advantages).view(-1)
        old_values = torch.stack(values).view(-1)
        vtarget = (_advantages + old_values).detach()  # GAE returns: A + V_old
        stacked_actions = torch.stack(actions)  # - > [minibatch, num_env]
        _old_log_prob = torch.stack(old_log_prob).detach()
        # - > [(minibatch * num_env), channel, image_shape]
        flat_states = torch.flatten(torch.stack(states),0,1).to(self.device)

        for _ in range(optim_step):
            p_output,v_output = self.network.forward(flat_states)
            # The critic must be trained on the *current* prediction: the stored
            # rollout values are constants and give the value loss no gradient.
            new_values = v_output.view(-1)
            policy_output = p_output.view(*stacked_actions.shape,-1)
            dist = Categorical(policy_output)
            new_log_prob = dist.log_prob(stacked_actions)

            ratio = torch.exp(new_log_prob - _old_log_prob).view(-1)
            prox1 = ratio * _advantages
            prox2 = torch.clamp(ratio ,1-configs.epsilon,1+configs.epsilon) * _advantages
            loss_policy = -torch.mean(torch.min(prox1,prox2))
            loss_critic = F.smooth_l1_loss(new_values,vtarget) * configs.critic_coeff
            # Mean over all (step, env) samples, the same reduction as the two
            # losses above (summing over envs scaled the bonus by num_env).
            entropy = dist.entropy().mean() * configs.entropy_coeff

            total_loss = loss_policy + loss_critic - entropy

            self.network.optim.zero_grad()
            total_loss.backward()
            nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
            xm.optimizer_step(self.network.optim)
            xm.mark_step()

        new_values = new_values.detach()
        explained_variance = 1.0 - torch.var(vtarget - new_values) / (torch.var(vtarget) + 1e-10)
        return {
            "entropy": entropy.detach(),
            "loss_policy": loss_policy.detach(),
            "loss_critic": loss_critic.detach(),
            "total_loss": total_loss.detach(),
            "values": new_values.mean(),
            "vtarget": vtarget.mean(),
            "explained_variance": explained_variance,
        }

    def _log_metrics(self, traj, num_game, m):
        """Write one iteration's metrics to TensorBoard and print a summary line."""
        finished_rewards,episode_steps = self.memory.traj_reward()
        mean_reward = torch.tensor([finished_rewards]).mean()
        mean_steps = torch.tensor([episode_steps]).mean()

        self.writter.add_scalar("Policy/entropy",m["entropy"],traj)
        self.writter.add_scalar("Policy/loss policy",m["loss_policy"],traj)
        self.writter.add_scalar("Value/values",m["values"],traj)
        self.writter.add_scalar("Value/vtarget",m["vtarget"],traj)
        self.writter.add_scalar("Value/value loss",m["loss_critic"],traj)
        self.writter.add_scalar("Value/Explained variance",m["explained_variance"],traj)
        self.writter.add_scalar("main/total loss",m["total_loss"],traj)
        self.writter.add_scalar("main/epi rewards",mean_reward,traj)
        self.writter.add_scalar("main/total steps",mean_steps,traj)

        epireward = round(mean_reward.tolist(),2)
        print(
            f"{traj}/{num_game} | REWA {epireward} | ENTR {m['entropy']:.2f} | POLI {m['loss_policy']:.2f} | CRIT {m['loss_critic']:.2f} | TLoss {m['total_loss']:.2f} | Val {m['values']:.2f} | Vtarg {m['vtarget']:.2f}"
            )
    
def train_fn(rank):
    """Per-process entry point for `xmp.spawn`.

    Seeds torch, `random`, numpy and the XLA RNG with `configs.seed`, then
    builds an `agent` on this process's XLA device and trains it with the
    sizes from `configs`. `rank` (the process index) is not used.
    """
    for seed_everything in (torch.manual_seed, random.seed, np.random.seed, xm.set_rng_state):
        seed_everything(configs.seed)

    trainer = agent(xm.xla_device())
    trainer.train(
        num_game=configs.num_game,
        batchsize=configs.batchsize,
        minibatch=configs.minibatch,
        optim_step=configs.optim_steps,
    )
   
if __name__ == "__main__":
    # Launch training in a single XLA process; every argument is spelled out.
    xmp.spawn(train_fn, args=(), nprocs=1, start_method=None)