In [1]:
import random
import warnings
warnings.filterwarnings("ignore")
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gymnasium.wrappers import FrameStack
import numpy as np
import torch
from PIL import Image
from torchvision.transforms import v2,Resize
from torchvision.transforms.functional import to_tensor

# hypers 
lr = 0.0003
_lambda_ = 0.99
gamma = 0.99
epsilon = 0.2
c1 = 0.5
numFrames = 5 

env = gym_super_mario_bros.make('SuperMarioBros-v0',apply_api_compatibility=True)
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = FrameStack(env,numFrames)

def _stransform_env_output(observation): # output -> torch.Size([1, 5, 150, 150])
    _list_ = []
    for element in observation:
        _tonumpy = np.array(element)
        _topil = Image.fromarray(_tonumpy)
        _gray = to_tensor(v2.Grayscale(1)(_topil))
        _resized = Resize((150,150))(_gray)
        _list_.append(_resized)
    return torch.stack(_list_,dim=0).permute(1,0,2,3)

In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
 
class network(nn.Module):
    def __init__(self):
        super().__init__()
        self.input = nn.LazyConv2d(1,(1,1))
        self.conv1 = nn.LazyConv2d(1,(3,3),stride=2)
        self.conv2 = nn.LazyConv2d(1,(3,3),stride=1)
        self.conv3 = nn.LazyConv2d(1,(3,3),stride=2)
        
        self.linear1 = nn.LazyLinear(3000)
        self.linear2 = nn.LazyLinear(1500)
        self.linear3 = nn.LazyLinear(750)
        self.linear4 = nn.LazyLinear(375)

        self.policyHead = nn.LazyLinear(7)
        self.valueHead = nn.LazyLinear(1)

        self.optim = Adam(self.parameters(),lr=lr)
        
    def forward(self,x):
        x = F.relu(self.input(x))
        x = self.conv1(x)
        x = F.relu(self.conv2(x))
        x = self.conv3(x)
        x = F.relu(torch.flatten(x,start_dim=1))
        x = self.linear1(x)
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        x = F.relu(self.linear4(x))
        policyOut = self.policyHead(x)
        valueOut = self.valueHead(x)
        return F.softmax(policyOut,-1),valueOut
    
network()(torch.rand((1,5,150,150),dtype=torch.float)) # init
 

(tensor([[0.1471, 0.1359, 0.1433, 0.1376, 0.1473, 0.1419, 0.1469]],
        grad_fn=<SoftmaxBackward0>),
 tensor([[0.0047]], grad_fn=<AddmmBackward0>))

In [None]:
from torch.distributions import Categorical
import gc

class Memory:
    def __init__(self):
        self.network = network()
        self.env = env
        self.data = []
        self.gamma = gamma
        self._lambda_ = _lambda_
        self.pointer = 0
    
    def rollout(self,batchsize):
        self.clear()
        _stacked_frames,_ = self.env.reset()
        for _ in range(batchsize):
            _tranformed_observation = _stransform_env_output(_stacked_frames)
            _policy_output , value = self.network.forward(_tranformed_observation)
            _distribution = Categorical(_policy_output)
            action = _distribution.sample()
            prob = _distribution.log_prob(action)
            state,reward,done,_,_ = self.env.step(action.item())
            self.data.append([state,torch.tensor(reward),value,prob,action,done])

        _,_rewards,_values,_,_,_ = zip(*self.data) # compute advantages --
        _rewards = torch.stack(_rewards)
        _values = torch.stack(_values).reshape(batchsize)
        _values = torch.cat((_values,torch.tensor([0])))
        
        n = torch.arange(batchsize)
        _temporalDifferences  = _rewards[n] + self.gamma*_values[n+1] - _values[n]
        _temporalDifferences = torch.flip(_temporalDifferences,dims=[-1])
        _advantage = 0
        advantages = _temporalDifferences[n] + (self._lambda_ * self.gamma * _advantage)
        advantages = torch.flip(advantages,dims=[-1]) 
      
        for data,item in zip(self.data,advantages): # append advantages to data 
            data.append(item)
        random.shuffle(self.data) 

    def sample(self,number):
        output = self.data[self.pointer:self.pointer+number]
        self.pointer+=number
        states,rewards,values,logProb,actions,done,advantages = zip(*output)
        return states,actions,rewards,values,logProb,advantages,done

    def clear(self):
        self.data = []
        self.pointer = 0
        gc.collect()
 

In [None]:
torch.autograd.set_detect_anomaly(True)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import sys

class Agent:
    def __init__(self):
        self.env = env
        self.memory = Memory()
        self.network = network()
        self.batchsize = 10 # 128
        self.minibatch = 5
        self.epochs = 10  # 40_000 
        self.writter = SummaryWriter("./")

    def save(self):
        torch.save(self.network.state_dict(),"./mario.pth")
  
    def train(self):
        for traj in tqdm(range(self.epochs),total=self.epochs):
            self.memory.rollout(self.batchsize)
            for n in range(self.batchsize//self.minibatch):
                states,actions,rewards,values,_old_log_prob,advantages,done = self.memory.sample(self.minibatch)        
                advantages = torch.stack(advantages)
                values = torch.stack(values,dim=-1)
                _vtarget = advantages + values
                loss_critic = F.mse_loss(values,_vtarget)
               
                _states_list = []
                for element in states: # process each Lazyframe then stack 
                    _transformed_states = _stransform_env_output(element)
                    _states_list.append(_transformed_states)
                _stacked_states = torch.stack(_states_list,dim=0).squeeze(1) # ->  torch.Size([5, 5, 150, 150])
                _stacked_actions = torch.stack(actions,dim=1)

                _policy_output,_ = self.network.forward(_stacked_states)
                dist = Categorical(_policy_output)
                _new_log_prob = dist.log_prob(_stacked_actions)
                _old_log_prob = torch.stack(_old_log_prob,dim=1)
                ratio = torch.exp(_new_log_prob)/torch.exp(_old_log_prob)

                _loss_policy_list = [] 
                for i in range(len(advantages)):
                    _ratio_advantages = ratio*advantages[i]
                    _clipped_ratio_advantages = torch.clamp(_ratio_advantages,(1-epsilon),(1+epsilon))*advantages[i]
                    loss = torch.min(_ratio_advantages,_clipped_ratio_advantages)
                    _loss_policy_list.append(loss)

                loss_policy = -torch.mean(torch.stack(_loss_policy_list))
                totalLoss = loss_policy + c1*loss_critic
                self.network.optim.zero_grad()
                totalLoss.backward(retain_graph=True)
                self.network.optim.step()

                _rewards = torch.mean(torch.stack(rewards)) 
            self.writter.add_scalar("main/Reward",_rewards)
            self.writter.add_scalar("main/Loss",totalLoss)
            
        self.save()

Agent().train()


  0%|          | 0/10 [00:00<?, ?it/s]

0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0


 10%|█         | 1/10 [00:01<00:17,  1.96s/it]

0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0


 20%|██        | 2/10 [00:03<00:13,  1.73s/it]

0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0


 30%|███       | 3/10 [00:04<00:11,  1.61s/it]

0.0
0.0
-1.0
0.0
1.0
0.0
0.0
0.0
0.0
1.0


 40%|████      | 4/10 [00:06<00:09,  1.62s/it]

0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
0.0
1.0


 50%|█████     | 5/10 [00:08<00:07,  1.55s/it]

0.0
0.0
-1.0
0.0
1.0
0.0
0.0
0.0
0.0
0.0


 60%|██████    | 6/10 [00:09<00:06,  1.57s/it]

0.0
-1.0
0.0
1.0
0.0
0.0
0.0
0.0
1.0
0.0


 70%|███████   | 7/10 [00:11<00:04,  1.59s/it]

0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0


 80%|████████  | 8/10 [00:13<00:03,  1.71s/it]

0.0
0.0
0.0
0.0
0.0
1.0
0.0
0.0
1.0
0.0


 90%|█████████ | 9/10 [00:15<00:01,  1.74s/it]

0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
1.0
0.0


100%|██████████| 10/10 [00:16<00:00,  1.67s/it]
