In [116]:
# --- optimisation / PPO hyperparameters ---
lr = 1e-3           # Adam learning rate used by both networks
gamma = 0.99        # discount factor
_lambda_ = 0.99     # GAE lambda (advantage smoothing)
epsilon = 0.2       # PPO clipping range for the probability ratio
c1 = 0.5            # weight of the critic loss in the total loss

# --- network architecture ---
kernel = (1, 1)     # filter shape for the conv layers
OUTPUT_CHANNEL = 5  # output channels of every conv layer
numConv = 2         # conv layers in each network
numlinear = 2       # linear layers in the policy head

In [117]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

class policy(nn.Module):
    """Actor network: maps one stacked observation to action probabilities.

    The leading dimension of ``x`` is fed to the conv layers as a batch of frames, and the
    result is then flattened over *every* dimension. Each call therefore scores exactly one
    observation and returns a 1-D probability vector of length ``num_actions``.
    """

    def __init__(self, num_actions=7):
        """
        Args:
            num_actions: width of every linear layer, and hence the number of action
                probabilities returned. Defaults to 7, presumably len(SIMPLE_MOVEMENT);
                TODO confirm against the env's action space.
        """
        super().__init__()
        self.numConv = numConv
        self.numLinear = numlinear
        self.outLinear = 1
        self.kernel = kernel
        self.channel = OUTPUT_CHANNEL
        self.convs = nn.ModuleList(
            [nn.LazyConv2d(self.channel, self.kernel) for _ in range(self.numConv)]
        )
        self.linear = nn.ModuleList(
            [nn.LazyLinear(num_actions) for _ in range(self.numLinear)]
        )
        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, x):
        for conv in self.convs:
            x = F.relu(conv(x))
        x = torch.flatten(x, 0)
        layers = list(self.linear)
        for layer in layers[:-1]:
            x = F.relu(layer(x))
        if layers:
            # The last layer emits logits. There is deliberately no ReLU here: it would clamp
            # every negative logit to 0, so no action could ever drop below another's score.
            x = layers[-1](x)
        return F.softmax(x, -1)
    
class value(nn.Module):
    """Critic network: estimates the state value V(s) of one stacked observation.

    Like ``policy``, the conv output is flattened over every dimension. The result is a
    1-element tensor per call.
    """

    def __init__(self):
        super().__init__()
        self.channel = OUTPUT_CHANNEL
        self.kernel = kernel
        self.numConv = numConv
        self.convs = nn.ModuleList(
            [nn.LazyConv2d(self.channel, self.kernel) for _ in range(self.numConv)]
        )
        self.linear = nn.LazyLinear(1)
        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, x):
        for conv in self.convs:
            x = F.relu(conv(x))
        x = torch.flatten(x)
        # No activation on the head. Value targets can be negative, and a ReLU would pin
        # those estimates at 0 with zero gradient.
        return self.linear(x)

# Smoke test: push one random 5-frame, single-channel 240x256 stack through an untrained policy.
r = torch.rand(5, 1, 240, 256, dtype=torch.float)
a = policy()
a(r)

tensor([0.1294, 0.1398, 0.1407, 0.1303, 0.1853, 0.1294, 0.1452],
       grad_fn=<SoftmaxBackward0>)

In [118]:
import random
import sys

import warnings
warnings.filterwarnings("ignore")
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FrameStack,GrayScaleObservation
from gymnasium.spaces import Box

import numpy as np
from torch.distributions import Categorical
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.functional import to_tensor

numFrames = 5  # depth of the frame stack; each stacked frame is grayscaled separately downstream

# Base Mario env (new step API via apply_api_compatibility), restricted to the
# SIMPLE_MOVEMENT action set, then wrapped so every observation holds the last numFrames frames.
env = FrameStack(
    JoypadSpace(
        gym_super_mario_bros.make('SuperMarioBros-v0', apply_api_compatibility=True),
        SIMPLE_MOVEMENT,
    ),
    numFrames,
)

class Memory:
    """Collects on-policy rollouts and computes GAE advantages for PPO.

    ``rollout`` returns parallel tuples
    ``(states, actions, rewards, values, logProb, advantages, done)`` in shuffled order.
    ``states[i]`` is the raw observation that action ``i`` was taken from.
    """

    def __init__(self):
        self.policy = policy()
        self.value = value()
        self.env = env
        self.data = []
        self.gamma = gamma
        self._lambda_ = _lambda_

    @staticmethod
    def _preprocess(observation):
        """Grayscale every frame of a stacked observation into a (num_frames, 1, H, W) tensor."""
        grayscale = v2.Grayscale(1)
        frames = np.asarray(observation)
        return torch.stack(
            [to_tensor(grayscale(Image.fromarray(frame))) for frame in frames], dim=0
        )

    def _gae(self, rewards, values, dones, last_value):
        """Generalized Advantage Estimation over one time-ordered segment.

        delta_t = r_t + gamma * (1 - done_t) * V_{t+1} - V_t
        A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}

        Args:
            rewards, values, dones: 1-D tensors in time order (dones as 0./1.).
            last_value: V of the observation following the final step (0 if it ended).
        Returns:
            1-D tensor of advantages, one per step.
        """
        advantages = torch.zeros(len(rewards))
        running = torch.zeros(())
        next_value = last_value
        for t in reversed(range(len(rewards))):
            not_done = 1.0 - dones[t]
            delta = rewards[t] + self.gamma * not_done * next_value - values[t]
            running = delta + self.gamma * self._lambda_ * not_done * running
            advantages[t] = running
            next_value = values[t]
        return advantages

    def rollout(self, batchsize):
        """Play ``batchsize`` consecutive env steps with the current policy.

        The episode continues across steps and the env is reset only when an episode
        ends, so the samples form real trajectories for GAE. ``done`` is True when the
        step terminated or truncated the episode. Returns seven empty tuples when
        ``batchsize <= 0``.
        """
        self.clear()
        if batchsize <= 0:
            return (), (), (), (), (), (), ()

        observation = self.env.reset()[0]
        done = False
        for _ in range(batchsize):
            states = self._preprocess(observation)
            # Rollout data are constants for the PPO update, so no autograd graph is kept.
            with torch.no_grad():
                distribution = Categorical(self.policy(states))
                state_value = self.value(states)
                action = distribution.sample()
                log_prob = distribution.log_prob(action)
            next_observation, reward, terminated, truncated, _ = self.env.step(action.item())
            done = bool(terminated or truncated)
            # Store the observation the action was taken FROM, so training re-evaluates
            # the same (state, action) pair.
            self.data.append(
                [observation, torch.tensor(reward, dtype=torch.float32), state_value, log_prob, action, done]
            )
            observation = self.env.reset()[0] if done else next_observation

        # Bootstrap an unfinished trajectory from the critic instead of assuming 0.
        if done:
            last_value = torch.zeros(())
        else:
            with torch.no_grad():
                last_value = self.value(self._preprocess(observation)).reshape(())

        rewards = torch.stack([row[1] for row in self.data])
        values = torch.stack([row[2] for row in self.data]).reshape(-1)
        dones = torch.tensor([float(row[5]) for row in self.data])
        advantages = self._gae(rewards, values, dones, last_value)

        for row, advantage in zip(self.data, advantages):  # append advantages to the data
            row.append(advantage)
        random.shuffle(self.data)
        states, rewards, values, logProb, actions, done_flags, advantages = zip(*self.data)
        return states, actions, rewards, values, logProb, advantages, done_flags

    def clear(self):
        self.data = []


In [134]:
# Anomaly detection instruments every autograd op and slows training considerably.
# Set this to True only while hunting NaNs or in-place-modification errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import sys

class Agent:
    """PPO trainer: collects rollouts through ``Memory`` and updates the actor and critic.

    The loss is the clipped surrogate objective plus ``c1`` times the critic MSE.
    """

    def __init__(self):
        self.env = env
        self.memory = Memory()
        # Optimise the SAME networks that generate the rollouts. Separate instances would
        # keep the behaviour policy frozen, make the PPO ratio compare two unrelated
        # networks, and leave the critic that produced the rollout values untrained.
        self.policy = self.memory.policy
        self.value = self.memory.value
        self.batchsize = 2
        self.epochs = 2
        self.writter = SummaryWriter("./")

    def save(self):
        """Write the policy weights to ./mario.pth."""
        torch.save(self.policy.state_dict(), "./mario.pth")

    @staticmethod
    def _preprocess(observation):
        """Grayscale every frame of a stacked observation into a (num_frames, 1, H, W) tensor."""
        grayscale = v2.Grayscale(1)
        frames = np.asarray(observation)
        return torch.stack(
            [to_tensor(grayscale(Image.fromarray(frame))) for frame in frames], dim=0
        )

    def train(self):
        """Run ``self.epochs`` rollout/update iterations, log to TensorBoard, then save."""
        for iteration in tqdm(range(self.epochs), total=self.epochs):
            states, actions, rewards, values, oldProb, advantages, _ = self.memory.rollout(self.batchsize)
            if not actions:
                continue

            # Rollout quantities are fixed targets for this update, so no gradient flows through them.
            advantages = torch.stack(advantages).detach()
            old_values = torch.stack(values).detach().squeeze(-1)
            old_log_probs = torch.stack(oldProb).detach()
            returns = advantages + old_values

            # One full frame stack per sample (the policy/value nets score one observation per call).
            observations = [self._preprocess(state) for state in states]

            new_log_probs = torch.stack(
                [Categorical(self.policy(obs)).log_prob(a) for obs, a in zip(observations, actions)]
            )
            ratio = torch.exp(new_log_probs - old_log_probs)
            # Clipped surrogate: clamp the RATIO, not the ratio*advantage product.
            surrogate = ratio * advantages
            clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
            lossPolicy = -torch.min(surrogate, clipped).mean()

            # Critic is re-evaluated with current weights and regressed onto the GAE returns.
            new_values = torch.stack([self.value(obs) for obs in observations]).squeeze(-1)
            lossCritic = F.mse_loss(new_values, returns)

            totalLoss = lossPolicy + c1 * lossCritic
            self.policy.optim.zero_grad()
            self.value.optim.zero_grad()
            totalLoss.backward()
            self.policy.optim.step()
            self.value.optim.step()

            self.writter.add_scalar("main/Reward", torch.stack(rewards).float().mean().item(), iteration)
            self.writter.add_scalar("main/Loss", totalLoss.item(), iteration)

        # TODO: checkpoint implementation
        self.save()

            
             

             
          
          
     
t = Agent()  # builds the rollout memory (with its networks) and the TensorBoard writer
t.train()    # runs t.epochs rollout/update iterations, then saves ./mario.pth

100%|██████████| 2/2 [00:01<00:00,  1.46it/s]
