In [None]:
# hypers 
lr = 0.001
_lambda_ = 0.99
gamma = 0.99
epochs = 10

epsilon = 0.2
c1 = 0.5

kernel = (1,1) # filter shape for the conv layers
OUTPUT_CHANNEL = 5
numConv = 2
numlinear = 2

In [130]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

class policy(nn.Module):
    def __init__(self):
        super().__init__()
        self.numConv = numConv 
        self.numLinear = numlinear
        self.outLinear = 1
        self.kernel = kernel
        self.channel = OUTPUT_CHANNEL
        self.convs = nn.ModuleList(
            [nn.LazyConv2d(self.channel,self.kernel) for n in range(self.numConv)]
        )
        self.linear = nn.ModuleList(
            [nn.LazyLinear(7) for n in range(self.numLinear)]
        )
        """self.linear2 = nn.ModuleList(
            [nn.LazyLinear() for n in self.numLinear]
        )"""
        self.optim = Adam(self.parameters(),lr=lr)
        
    def forward(self,x):
        for convs in self.convs:
            x = F.relu(convs(x))
        x = torch.flatten(x,0)
        for linear in self.linear:
            x = F.relu(linear(x))
        return F.softmax(x,-1)
    

class value(nn.Module):
    def __init__(self):
        super().__init__()
        self.channel = OUTPUT_CHANNEL
        self.kernel = kernel
        self.numConv = numConv 
        self.convs = nn.ModuleList(
            [nn.LazyConv2d(self.channel,self.kernel) for n in range(self.numConv)]
        )
        self.linear = nn.LazyLinear(1)
        self.optim = Adam(self.parameters(),lr=lr)

    def forward(self,x):
        for convs in self.convs:
            x = F.relu(convs(x))
        x = torch.flatten(x)
        x = F.relu(self.linear(x))
        return F.relu(x,-1)


In [132]:
r = torch.rand((5,240,256), dtype=torch.float)
a = value()
a(r)

tensor([0.], grad_fn=<ReluBackward0>)

In [None]:
import random
import sys

import warnings
warnings.filterwarnings("ignore")
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FrameStack,GrayScaleObservation
from gymnasium.spaces import Box

import numpy as np
from torch.distributions import Categorical
from PIL import Image
from torchvision.transforms import v2
from torchvision.transforms.functional import to_tensor

numFrames = 5
env = gym_super_mario_bros.make('SuperMarioBros-v0',apply_api_compatibility=True)
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = FrameStack(env,numFrames)

class Memory:
    def __init__(self):
        self.policy = policy()
        self.value = value()
        self.env = env
        self.data = []
        self.gamma = gamma
        self._lambda_ = _lambda_
    
    def rollout(self,batchsize):
        self.clear()
        for _ in range(batchsize):
            _image = np.array(self.env.reset()[0])
            _list_ = []
            for i in range(numFrames): # grayscale frame by frame
                _pil = Image.fromarray(_image[i])
                _observation = to_tensor(v2.Grayscale(1)(_pil))
                _list_.append(_observation)

            _states = torch.stack(_list_,dim=0) # --> torch.Size([5, 1, 240, 256])
            _distribution = Categorical(self.policy.forward(_states)) 
            value = self.value.forward(_states)

            action = _distribution.sample()
            prob = _distribution.log_prob(action)
            state,reward,done,_,_ = self.env.step(action.item())
            self.data.append([state,torch.tensor(reward),value,prob,action.item(),done])
        
        # advantages 
        _,_rewards,_values,_,_,_ = zip(*self.data) 
        _rewards = torch.stack(_rewards)
        _values = torch.stack(_values) 
        _values = torch.cat((_values,torch.tensor([[0]])))
        n = torch.arange(batchsize)
        _temporalDifferences  = _rewards[n] + self.gamma*_values[n+1] - _values[n]
        _temporalDifferences = torch.flip(_temporalDifferences,dims=[-1])
       
        _advantage = 0
        advantages = _temporalDifferences[n] + (self._lambda_ * self.gamma * _advantage)
        advantages = torch.flip(advantages,dims=[-1]) 
  
        for data,item in zip(self.data,advantages): # append advantages to the data 
            data.append(item)
        random.shuffle(self.data) 
        states,rewards,values,logProb,actions,done,advantages = zip(*self.data)
        return  states,actions,rewards,values,logProb,advantages,done

    def clear(self):
        self.data = []


t = Memory()
t.rollout(5)

((<gymnasium.wrappers.frame_stack.LazyFrames at 0x244b1770bd0>,
  <gymnasium.wrappers.frame_stack.LazyFrames at 0x244b0e68ea0>,
  <gymnasium.wrappers.frame_stack.LazyFrames at 0x244b17723e0>,
  <gymnasium.wrappers.frame_stack.LazyFrames at 0x244b123f740>,
  <gymnasium.wrappers.frame_stack.LazyFrames at 0x244b123da80>),
 (3, 1, 5, 6, 0),
 (tensor(0.), tensor(0.), tensor(0.), tensor(0.), tensor(0.)),
 (tensor([0.0927], grad_fn=<ReluBackward0>),
  tensor([0.0927], grad_fn=<ReluBackward0>),
  tensor([0.0927], grad_fn=<ReluBackward0>),
  tensor([0.0927], grad_fn=<ReluBackward0>),
  tensor([0.0927], grad_fn=<ReluBackward0>)),
 (tensor(-2.0227, grad_fn=<SqueezeBackward1>),
  tensor(-2.0227, grad_fn=<SqueezeBackward1>),
  tensor(-2.0227, grad_fn=<SqueezeBackward1>),
  tensor(-2.0150, grad_fn=<SqueezeBackward1>),
  tensor(-1.9380, grad_fn=<SqueezeBackward1>)),
 (tensor([-0.0927, -0.0927, -0.0927, -0.0927, -0.0927],
         grad_fn=<UnbindBackward0>),
  tensor([-0.0009, -0.0009, -0.0009, -0.000

In [None]:
torch.autograd.set_detect_anomaly(True)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class Agent:
    def __init__(self):
        self.env = env
        self.memory = Memory()
        self.policy = policy()
        self.value = value()
        self.batchsize = batchsize
        self.epochs = epochs

    def save(self):
        torch.save(self.policy.state_dict(),"./Mario.pth")

    def train(self):
        for _ in tqdm(range(self.epochs),total=self.epochs):
            states,rewards,values,logProb,advantages,done = self.memory.rollout(self.batchsize)

            # TODO : compute critic loss
            # TODO compute new log probs
            # TODO compute loss critic 
            #  saves data 
            # test policy

            