In [None]:
# Hyperparameters shared by the cells below (read as module-level globals).
lr = 0.001  # learning rate handed to the Adam optimizer built inside `network`
_lambda_ = 0.99  # multiplied with gamma in Memory's advantage formula (presumably the GAE lambda)
gamma = 0.99  # discount applied to the next state's value in the TD residual
epsilon = 0.2  # clipping half-width: the policy loss clamps to [1 - epsilon, 1 + epsilon]
c1 = 0.5  # weight of the critic loss when it is added to the policy loss

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from IPython.display import clear_output
 
class network(nn.Module):
    """Actor-critic model with a shared trunk and two output heads.

    Every layer is a lazy module, so input sizes are inferred from the first
    tensor passed through ``forward``. The trunk is a 1x1 convolution, three
    3x3 convolutions (all with a single output channel), and four linear
    layers; ReLU is applied after every second layer, exactly as wired below.

    The instance owns its optimizer (``self.optim``), an Adam optimizer over
    all of the module's parameters using the notebook-level ``lr``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk: one output channel throughout.
        self.input = nn.LazyConv2d(out_channels=1, kernel_size=(1, 1))
        self.conv1 = nn.LazyConv2d(out_channels=1, kernel_size=(3, 3), stride=2)
        self.conv2 = nn.LazyConv2d(out_channels=1, kernel_size=(3, 3), stride=1)
        self.conv3 = nn.LazyConv2d(out_channels=1, kernel_size=(3, 3), stride=2)

        # Fully connected trunk, halving the width at each layer.
        self.linear1 = nn.LazyLinear(out_features=3000)
        self.linear2 = nn.LazyLinear(out_features=1500)
        self.linear3 = nn.LazyLinear(out_features=750)
        self.linear4 = nn.LazyLinear(out_features=375)

        # Heads: 7 action scores for the policy, one scalar for the critic.
        self.policyHead = nn.LazyLinear(out_features=7)
        self.valueHead = nn.LazyLinear(out_features=1)

        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return ``(action_probabilities, value)`` for the input ``x``.

        ``torch.flatten`` is called without a start dimension, so every
        dimension of the convolution output (including any leading
        batch/frame dimension) is merged into one vector. The heads therefore
        produce a single 7-element probability vector and a single 1-element
        value per call.
        """
        stem = F.relu(self.input(x))
        reduced = self.conv1(stem)
        refined = F.relu(self.conv2(reduced))
        encoded = self.conv3(refined)

        features = F.relu(encoded.flatten())
        hidden = self.linear1(features)
        hidden = F.relu(self.linear2(hidden))
        hidden = self.linear3(hidden)
        hidden = F.relu(self.linear4(hidden))

        actionProbs = F.softmax(self.policyHead(hidden), dim=-1)
        stateValue = self.valueHead(hidden)
        return actionProbs, stateValue
    
# Smoke test: push one random (5, 1, 150, 150) tensor through a freshly built
# network. The instance is not bound to a name, so it is discarded afterwards;
# networks constructed later start with their own uninitialised lazy layers.
network()(torch.rand((5,1,150,150),dtype=torch.float))
clear_output()  # wipe whatever the call above displayed in the cell output

In [5]:
import random
import warnings
warnings.filterwarnings("ignore")
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FrameStack

from PIL import Image
from torchvision.transforms import v2,Resize
from torchvision.transforms.functional import to_tensor
import numpy as np
from torch.distributions import Categorical

numFrames = 5  # how many observations FrameStack bundles into one state
# Build the environment once at module level; Memory and Agent both reference it.
env = gym_super_mario_bros.make('SuperMarioBros-v0',apply_api_compatibility=True)
# Restrict the controller to the SIMPLE_MOVEMENT action list
# (presumably 7 entries, matching the 7 outputs of the policy head — TODO confirm).
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = FrameStack(env,numFrames)

class Memory:
    """Collects on-policy transitions from ``env`` and computes GAE advantages.

    Attributes:
        network: model used to pick actions and estimate state values.
        env: the module-level environment.
        data: rows of ``[state, reward, value, log_prob, action, done]`` for
            the current rollout; ``rollout`` appends the advantage as a
            seventh entry.
        gamma: discount factor.
        _lambda_: GAE decay factor.
    """

    def __init__(self):
        self.network = network()
        self.env = env
        self.data = []
        self.gamma = gamma
        self._lambda_ = _lambda_
        # Raw frame stack the next action will be chosen from. ``None`` means
        # the environment has to be reset before the next step.
        self._current = None

    def _preprocess(self, frames):
        """Grayscale and resize each frame; return a (len(frames), 1, 150, 150) tensor."""
        processed = []
        for frame in frames:
            gray = to_tensor(v2.Grayscale(1)(Image.fromarray(frame)))
            processed.append(Resize((150,150))(gray))
        return torch.stack(processed,dim=0)

    def rollout(self,batchsize):
        """Play ``batchsize`` consecutive steps and return the shuffled batch.

        The episode continues from one step (and one call) to the next; the
        environment is reset only on the first use or after an episode ends.

        Args:
            batchsize: number of environment steps to collect (must be >= 1).

        Returns:
            ``(states, actions, rewards, values, logProb, advantages, done)``,
            each a tuple of length ``batchsize`` in the same shuffled order:
            raw frame stacks the action was chosen from (numpy arrays),
            sampled actions (0-d tensors), float32 rewards (0-d tensors),
            value estimates (shape ``(1,)``), log-probabilities of the sampled
            actions (0-d tensors), GAE advantages (shape ``(1,)``) and
            episode-ended flags (bool). All tensors are detached from autograd.

        Raises:
            ValueError: if ``batchsize`` is smaller than 1.
        """
        if batchsize < 1:
            raise ValueError("batchsize must be at least 1")
        self.clear()

        # Acting needs no gradients; the trainer re-evaluates the states itself.
        with torch.no_grad():
            for _ in range(batchsize):
                if self._current is None:
                    self._current = np.array(self.env.reset()[0])
                frames = self._current

                _policyOutput, value = self.network.forward(self._preprocess(frames))
                _distribution = Categorical(_policyOutput)
                action = _distribution.sample()
                prob = _distribution.log_prob(action)

                # The 4th item is taken to be the truncation flag of the
                # 5-tuple step API; either flag ends the episode here.
                nextFrames,reward,terminated,truncated,_ = self.env.step(action.item())
                done = bool(terminated or truncated)

                # Store the observation that produced `value`/`action`, and a
                # float reward so downstream means work for integer rewards.
                self.data.append([frames,torch.tensor(float(reward)),value,prob,action,done])
                self._current = None if done else np.array(nextFrames)

            # Value of the state after the final step: 0 if the episode ended,
            # otherwise the critic's estimate for the state we stopped in.
            if self._current is None:
                bootstrap = torch.zeros(1)
            else:
                _, bootstrap = self.network.forward(self._preprocess(self._current))

        # Keep everything as (batchsize, 1) columns so rewards and values line
        # up element-wise instead of broadcasting into a square matrix.
        _rewards = torch.stack([row[1] for row in self.data]).reshape(batchsize,1)
        _values = torch.stack([row[2] for row in self.data]).reshape(batchsize,1)
        _notDone = torch.tensor([[0.0 if row[5] else 1.0] for row in self.data])
        _nextValues = torch.cat((_values[1:],bootstrap.reshape(1,1)))

        # TD residuals; the next-state value is dropped where the episode ended.
        _temporalDifferences = _rewards + self.gamma*_nextValues*_notDone - _values

        # GAE: A_t = delta_t + gamma * lambda * A_{t+1}, accumulated backwards
        # in time and cut at episode boundaries.
        advantages = torch.zeros_like(_temporalDifferences)
        _advantage = torch.zeros(1)
        for t in reversed(range(batchsize)):
            _advantage = _temporalDifferences[t] + self._lambda_*self.gamma*_notDone[t]*_advantage
            advantages[t] = _advantage

        for data,item in zip(self.data,advantages): # append advantages to data
            data.append(item)
        random.shuffle(self.data)
        states,rewards,values,logProb,actions,done,advantages = zip(*self.data)
        return  states,actions,rewards,values,logProb,advantages,done

    def clear(self):
        """Drop the transitions of the previous rollout."""
        self.data = []


In [None]:
# Debugging aid: makes autograd report the forward operation behind a failing
# backward pass. NOTE(review): PyTorch documents this mode as debug-only because
# the extra checks slow execution — consider turning it off for long runs.
torch.autograd.set_detect_anomaly(True)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class Agent:
    """PPO trainer: collects rollouts through ``Memory`` and updates the network.

    Attributes:
        env: the module-level environment.
        memory: rollout collector.
        network: the model being optimised; it is the same object the
            collector acts with, so updates change the acting policy.
        batchsize: environment steps per rollout.
        epochs: number of rollout/update iterations.
        writter: TensorBoard writer (logs to the current directory).
    """

    def __init__(self):
        self.env = env
        self.memory = Memory()
        # Train the very network that gathers experience. Optimising a second,
        # separate instance would leave the acting policy frozen.
        self.network = self.memory.network
        self.batchsize = 2 # 128
        self.epochs = 10  # 40_000 
        self.writter = SummaryWriter("./")

    def save(self):
        """Write the network weights to ``./mario.pth``."""
        torch.save(self.network.state_dict(),"./mario.pth")

    def _preprocess(self, frames):
        """Grayscale and resize each frame; return a (len(frames), 1, 150, 150) tensor."""
        processed = []
        for frame in frames:
            gray = to_tensor(v2.Grayscale(1)(Image.fromarray(frame)))
            processed.append(Resize((150,150))(gray))
        return torch.stack(processed,dim=0)

    def train(self):
        """Run ``self.epochs`` iterations of rollout + one clipped-PPO update.

        Each iteration logs the mean rollout reward and the total loss to
        TensorBoard under ``main/Reward`` and ``main/Loss``, indexed by the
        iteration number.
        """
        for traj in tqdm(range(self.epochs),total=self.epochs):
            states,actions,rewards,values,oldProb,advantages,done = self.memory.rollout(self.batchsize)
            count = len(actions)
            # Cast first: the mean of an integer tensor is not defined.
            _rewards = torch.stack(rewards).float().mean()

            # Rollout quantities are fixed targets for this update, so none of
            # them may carry gradients.
            oldLogProb = torch.stack(oldProb).detach().reshape(count,1)
            oldValues = torch.stack(values).detach().reshape(count,1)
            advantages = torch.stack(advantages).detach().reshape(count,-1)
            vtarget = advantages + oldValues

            # Re-evaluate every stored state with the current weights, feeding
            # the full frame stack of that state, as is done when acting.
            newLogProb = []
            newValues = []
            for frames,a in zip(states,actions):
                _policyOut,_value = self.network.forward(self._preprocess(np.array(frames)))
                dist = Categorical(_policyOut)
                newLogProb.append(dist.log_prob(a))
                newValues.append(_value)
            newLogProb = torch.stack(newLogProb).reshape(count,1)
            newValues = torch.stack(newValues).reshape(count,1)

            # Clipped surrogate objective: clip the probability ratio itself,
            # then weight it by the advantage of the same sample.
            ratio = torch.exp(newLogProb - oldLogProb)
            unclipped = ratio*advantages
            clipped = torch.clamp(ratio,(1-epsilon),(1+epsilon))*advantages
            lossPolicy = -torch.mean(torch.min(unclipped,clipped))

            # Regress the fresh value predictions onto the fixed targets, so
            # the critic actually receives a gradient. expand_as keeps the
            # shapes equal whatever width the rollout's advantages have.
            lossCritic = F.mse_loss(newValues.expand_as(vtarget),vtarget)

            totalLoss = lossPolicy + c1*lossCritic
            self.network.optim.zero_grad()
            # The graph is rebuilt every iteration, so it need not be retained.
            totalLoss.backward()
            self.network.optim.step()

            # Pass the iteration as global_step so the curves have an x-axis.
            self.writter.add_scalar("main/Reward",_rewards.item(),traj)
            self.writter.add_scalar("main/Loss",totalLoss.item(),traj)

        self.writter.flush()
            
            

In [None]:
# Build an agent and run its training loop (Agent.epochs rollout/update iterations).
Agent().train()