In [None]:
"""
!pip install numpy==1.25.2
!pip install torch==2.7.0
!pip install torchvision==0.22.0
!pip install gym_super_mario_bros==7.4.0
!pip install gym==0.26.2
!pip install nes_py==8.2.1
!pip install gymnasium==0.29.1
!pip install opencv-python==4.11.0.86
"""

In [None]:
import warnings
warnings.filterwarnings("ignore")
import torch,gym,random
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation,FrameStack,ResizeObservation
from gym.vector import SyncVectorEnv
import numpy as np
    #
# --- Reproducibility: seed every RNG used in this notebook (torch CPU/CUDA, random, numpy).
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- PPO hyper-parameters.
# _lambda_ and gamma are the GAE lambda and the discount factor used when advantages are computed.
_lambda_ = 1.0
gamma = 0.99
# epsilon is the clipping range applied to the probability ratio in the policy loss.
epsilon = 0.2
# lr is the learning rate of the Adam optimizer created inside the network.
lr = 1e-4
# NOTE(review): c1, c2 and c_pi are not referenced anywhere else in the visible notebook code;
# presumably they were meant as loss coefficients — confirm or remove.
c1 = 0.5
c2 = 1e-3
c_pi = 1e4

# --- Environment settings.
# Number of environments handed to SyncVectorEnv (original note: 30, presumably the intended full-run value).
num_env = 1
# Number of emulator steps each chosen action is repeated for (passed to CustomEnv).
skip_frame = 4
# Number of consecutive frames stacked into one observation (passed to FrameStack).
num_stack = 4
# Target size passed to ResizeObservation.
obs_shape = (100,100)

In [None]:
class CustomEnv(gym.Wrapper):
    """Frame-skipping wrapper.

    Repeats every action for ``skip`` consecutive emulator steps and returns the
    last observation together with the summed reward of the skipped steps.
    """

    def __init__(self,env,skip):
        """Wrap ``env``.

        Args:
            env: environment following the 5-tuple ``step`` API.
            skip: number of times each action is repeated; must be >= 1.

        Raises:
            ValueError: if ``skip`` is smaller than 1 (``step`` would otherwise
                have nothing to return).
        """
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self.skip = skip
        self.score = 0

    def step(self, action):
        """Apply ``action`` up to ``skip`` times.

        Returns:
            ``(obs, total_reward, done, truncated, info)`` where ``obs``/``info``
            come from the last emulator step that was executed and
            ``total_reward`` is the sum over the executed steps.
        """
        total_reward = 0
        for _ in range(self.skip):
            obs,reward,done,truncated,info = self.env.step(action)
            total_reward += reward
            if done:
                # Keep the original contract: the wrapped env is reset right away
                # while the terminal observation is still what gets returned.
                self.reset()
                break
            if truncated:
                # Do not keep stepping an episode that was cut off.
                break
        return obs,total_reward,done,truncated,info

    def reset(self, **kwargs):
        """Reset the wrapped env, forwarding ``seed``/``options`` keyword arguments."""
        obs, info = self.env.reset(**kwargs)
        return obs,info

def make_env():
    """Build the vectorised training environment.

    Returns:
        A ``SyncVectorEnv`` holding ``num_env`` independent Super Mario Bros
        environments, each resized to ``obs_shape``, frame-skipped by
        ``skip_frame``, restricted to ``SIMPLE_MOVEMENT``, converted to
        grayscale and stacked over ``num_stack`` frames.
    """
    def _make_single_env():
        # Called once per vector slot so every slot owns its own emulator.
        # (Repeating one pre-built env object would make all slots share state.)
        x = gym_super_mario_bros.make("SuperMarioBros-v1",apply_api_compatibility=True)
        x = ResizeObservation(x,obs_shape)
        x = CustomEnv(x,skip_frame)
        x = JoypadSpace(x, SIMPLE_MOVEMENT)
        x = GrayScaleObservation(x,True)
        x = FrameStack(x,num_stack)
        return x

    return SyncVectorEnv(env_fns=[_make_single_env for _ in range(num_env)])

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class network(nn.Module):
    """Actor-critic CNN with a shared convolutional trunk.

    Four convolutions and one linear layer feed two heads: a policy head that
    yields action probabilities and a value head that yields a scalar state
    value. All layers are lazy, so input sizes are inferred on the first
    forward pass.

    Args:
        num_actions: size of the discrete action space (default 7, the value
            that was previously hard-coded).
        learning_rate: Adam learning rate; when ``None`` the module-level
            ``lr`` is used, as before.
    """

    def __init__(self, num_actions=7, learning_rate=None):
        super().__init__()
        self.conv1 = nn.LazyConv2d(32,1,1,0)
        self.conv2 = nn.LazyConv2d(32,3,2,2)
        self.conv3 = nn.LazyConv2d(32,3,2,2)
        self.conv4 = nn.LazyConv2d(32,3,2,2)
        self.output = nn.LazyLinear(80)

        self.policy_head = nn.LazyLinear(num_actions)
        self.value_head = nn.LazyLinear(1)
        # The optimizer lives on the module so callers can reach it via `.optim`.
        self.optim = torch.optim.Adam(
            self.parameters(),
            lr=lr if learning_rate is None else learning_rate,
        )

    def forward(self,x):
        """Return ``(action_probabilities, state_value)`` for a batch of stacked frames.

        ``action_probabilities`` has shape ``(batch, num_actions)`` and sums to 1
        along the last axis; ``state_value`` has shape ``(batch, 1)``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = torch.flatten(x,start_dim=1)  # (batch, channels * height * width)
        x = F.relu(self.output(x))
        policy_output = self.policy_head(x)
        value_output = self.value_head(x)
        return F.softmax(policy_output,-1),value_output
    
# Build the actor-critic network and move it to the selected device.
model = network().to(device)
# One forward pass on random data shaped like a batch of stacked frames
# (num_env, num_stack, *obs_shape); the network is built from Lazy* layers,
# so this is presumably what fixes their input sizes before training starts.
model.forward(torch.rand((num_env,num_stack,*(obs_shape)),dtype=torch.float32,device=device))
# From here on `model` is the DataParallel wrapper; the underlying network
# (and its `.optim`) is reachable through `model.module`.
model = nn.DataParallel(model)

In [None]:
import gc,sys
from torch.distributions import Categorical
from collections import deque

class Memory:
    """On-policy rollout buffer for PPO.

    ``rollout`` plays the vectorised environment with the current policy,
    stores one row per time step and appends the GAE advantage to each row.
    Every row is ``[observation, reward, value, log_prob, action, done, advantage]``.
    ``sample`` hands the (shuffled) rows out in consecutive slices.
    """

    def __init__(self,env : "SyncVectorEnv"):
        self.network = model
        self.env = env
        self.gamma = gamma
        self._lambda_ = _lambda_
        self.device = device
        self.data = []
        self.pointer = 0
        # Running per-environment statistics of the episode currently being played.
        self.episode_reward = np.zeros(self.env.num_envs, dtype=np.float32)
        self.total_steps = np.zeros(self.env.num_envs, dtype=np.float32)
        # Return / length of the most recently finished episodes (for logging).
        self.finished_reward = deque(maxlen=5)
        self.log_total_steps = deque(maxlen=5)

    def _preprocess(self,observation):
        """Convert a raw batch of stacked frames into a float tensor scaled by 1/255.

        The trailing axis (kept by the grayscale wrapper) is squeezed away, giving
        (num_env, num_stack, height, width).
        """
        frames = torch.from_numpy(np.array(observation).copy()).squeeze(-1)
        return frames.to(self.device,torch.float32) / 255.

    def _compute_advantages(self,rewards,values,dones,last_value):
        """Generalised Advantage Estimation with episode-boundary masking.

        Args:
            rewards: (T, N) rewards.
            values: (T, N) value predictions for the visited states.
            dones: (T, N) float flags, 1.0 where the step ended an episode.
            last_value: (N,) value prediction for the state after the last step.

        Returns:
            (T, N) tensor of advantages.
        """
        next_values = torch.cat([values[1:], last_value.unsqueeze(0)])
        not_done = 1.0 - dones
        advantages = torch.zeros_like(rewards)
        gae = torch.zeros_like(last_value)
        for n in reversed(range(rewards.shape[0])):
            # A terminal step neither bootstraps from the next state nor inherits
            # the accumulated advantage of the following episode.
            td = rewards[n] + self.gamma * next_values[n] * not_done[n] - values[n]
            gae = td + self._lambda_ * self.gamma * not_done[n] * gae
            advantages[n] = gae
        return advantages

    def rollout(self,batchsize):
        """Collect ``batchsize`` vectorised steps and attach advantages.

        The buffer is cleared and the environment reset first. After collection
        the rows are shuffled in place.

        Raises:
            ValueError: if ``batchsize`` is smaller than 1.
        """
        if batchsize < 1:
            raise ValueError(f"batchsize must be >= 1, got {batchsize}")
        self.clear()
        self._observation,_ = self.env.reset()
        with torch.no_grad():
            for _step in range(batchsize):
                self._observation = self._preprocess(self._observation)
                policy_output , value = self.network(self._observation)
                distribution = Categorical(policy_output)
                action = distribution.sample()
                prob = distribution.log_prob(action)
                state,reward,done,_,_ = self.env.step(action.tolist())

                for i in range(self.env.num_envs): # per-episode statistics
                    self.episode_reward[i] += reward[i]
                    self.total_steps[i] += 1
                    if done[i]:
                        self.finished_reward.append(self.episode_reward[i])
                        # Store a plain float: a tensor/array view would be
                        # overwritten by the reset on the lines below.
                        self.log_total_steps.append(float(self.total_steps[i]))
                        self.episode_reward[i] = 0
                        self.total_steps[i] = 0

                self.data.append(
                    [
                        self._observation,
                        torch.tensor(reward,dtype=torch.float32),
                        value,
                        prob,
                        action,
                        done
                    ]
                )
                self._observation = state

            # Value of the state reached after the final step, used to bootstrap
            # the last advantage instead of treating the rollout end as terminal.
            _,last_value = self.network(self._preprocess(self._observation))

        _,rewards,values,_,_,dones = zip(*self.data)
        _rewards = torch.stack(rewards).to(self.device)
        _values = torch.stack(values).squeeze(-1)
        _dones = torch.as_tensor(np.stack(dones),dtype=torch.float32,device=self.device)
        assert _rewards.dtype == _values.dtype == torch.float32
        advantages = self._compute_advantages(_rewards,_values,_dones,last_value.squeeze(-1))

        for data,item in zip(self.data,advantages): # append advantages to data
            data.append(item)
        random.shuffle(self.data)

    def sample(self,number):
        """Return the next ``number`` rows, transposed into per-field tuples.

        Returns:
            ``(states, actions, rewards, values, logProb, advantages, done)``.

        Raises:
            ValueError: if no unread rows are left in the buffer.
        """
        output = self.data[self.pointer:self.pointer+number]
        if not output:
            raise ValueError("no rollout data left to sample; call rollout() first")
        self.pointer+=number
        states,rewards,values,logProb,actions,done,advantages = zip(*output)
        return states,actions,rewards,values,logProb,advantages,done

    def clear(self):
        """Drop all stored rows and release cached memory."""
        self.data = []
        self.pointer = 0
        gc.collect()
        torch.cuda.empty_cache()

    def traj_reward(self):
        """Return ``(finished episode returns, finished episode lengths)`` deques."""
        return self.finished_reward,self.log_total_steps

In [None]:
# Turn on autograd anomaly detection globally for every backward pass that follows.
# NOTE(review): this is a debugging aid and presumably slows training — consider disabling for real runs.
torch.autograd.set_detect_anomaly(True)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class agent:
    """PPO trainer: owns the environment, the rollout buffer, the model and the TensorBoard writer."""

    def __init__(self,):
        self.env = make_env()
        self.memory = Memory(self.env)
        self.network = model
        self.writter = SummaryWriter("./")

    def save(self,k):
        """Write a checkpoint (model/optimizer state, lr, epsilon) to ``./mario{k}``."""
        checkpoint = {
            "model_state" : self.network.state_dict(),
            "optim_state" : self.network.module.optim.state_dict(),
            "lr" : lr,
            "epsilon" : epsilon
        }
        torch.save(checkpoint,f"./mario{k}")

    @staticmethod
    def _explained_variance(targets,predictions):
        """Return ``1 - Var(targets - predictions) / Var(targets)`` (1.0 means a perfect fit)."""
        return 1.0 - torch.var(targets - predictions) / (torch.var(targets) + 1e-10)

    def train(self,num_game,batchsize,minibatch,optim_step):
        """Run PPO.

        Args:
            num_game: number of rollout/update iterations.
            batchsize: environment steps collected per rollout.
            minibatch: rollout rows used per update batch.
            optim_step: gradient steps taken on each minibatch.

        Raises:
            ValueError: if the arguments would result in no gradient step at all
                (the per-iteration logging needs at least one).
        """
        if batchsize // minibatch < 1 or optim_step < 1:
            raise ValueError("need batchsize >= minibatch and optim_step >= 1")

        optimizer = self.network.module.optim
        for traj in tqdm(range(num_game),total = num_game):
            self.memory.rollout(batchsize)

            for _ in range(batchsize//minibatch):
                states,actions,rewards,values,old_log_prob,advantages,done= self.memory.sample(minibatch)

                # Everything below depends only on the sampled rollout data, so it
                # is built once per minibatch rather than once per gradient step.
                _advantages = torch.stack(advantages).view(-1)
                _values = torch.stack(values).view(-1)          # value estimates recorded during the rollout
                _old_log_prob = torch.stack(old_log_prob).view(-1)
                stacked_actions = torch.stack(actions).view(-1)
                vtarget = _advantages + _values                 # return targets for the critic
                norm_advantage  = (_advantages - _advantages.mean()) / (_advantages.std() + 1e-10)
                stacked_states = torch.stack(states) # - > [minibatch, num_env, channel, image_shape]
                flat_states = torch.flatten(stacked_states,0,1).to(device) # - > [(minibatch * num_env), channel, image_shape]

                for _ in range(optim_step):
                    p_output,v_output = self.network(flat_states)

                    # The critic loss must use the *fresh* value prediction:
                    # the stored rollout values carry no gradient.
                    new_values = v_output.view(-1)
                    loss_critic = torch.mean(torch.pow((new_values - vtarget),2))
                    expl_variance = self._explained_variance(vtarget,new_values.detach())

                    dist = Categorical(p_output)
                    new_log_prob = dist.log_prob(stacked_actions)

                    ratio = torch.exp(new_log_prob - _old_log_prob)
                    prox1 = ratio * norm_advantage
                    prox2 = torch.clamp(ratio ,1-epsilon,1+epsilon) * norm_advantage
                    loss_policy = -torch.mean(torch.min(prox1,prox2)) * 2e3

                    # Entropy bonus: summed over actions and environments, averaged over the minibatch.
                    policy_output = p_output.view(-1,self.env.num_envs,p_output.shape[-1])
                    _entropy = -(policy_output * torch.log(policy_output + 1e-10)).sum(dim=2)
                    entropy = _entropy.sum(dim=1).mean() * 1e-1

                    total_loss = loss_policy + loss_critic - entropy

                    optimizer.zero_grad()
                    total_loss.backward()  # the graph is rebuilt on every step, nothing to retain
                    nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                    optimizer.step()

            # Log the statistics of the last gradient step of this iteration.
            self.writter.add_scalar("Policy/entropy",entropy,traj)
            self.writter.add_scalar("Policy/prox 1",prox1.mean(),traj)
            self.writter.add_scalar("Policy/prox 2",prox2.mean(),traj)
            self.writter.add_scalar("Policy/loss policy",loss_policy,traj)

            self.writter.add_scalar("Value/values",_values.sum(),traj)
            self.writter.add_scalar("Value/vtarget",vtarget.sum(),traj)
            self.writter.add_scalar("Value/value loss",loss_critic,traj)
            self.writter.add_scalar("Value/Explained variance",expl_variance,traj)

            self.writter.add_scalar("main/total loss",total_loss,traj)
            self.writter.add_scalar("main/epi rewards",torch.tensor([self.memory.traj_reward()[0]]).mean(),traj)
            self.writter.add_scalar("main/total steps",torch.tensor([self.memory.traj_reward()[1]]).mean(),traj)

            if traj % 20 == 0 : # save every 20 iterations
                self.save(traj)

        self.save("_end")

    @staticmethod
    def run_training(run=False):
        """Create an agent and train it when ``run`` is true; otherwise do nothing."""
        if run:
            agnt = __class__()
            agnt.train(num_game=100,batchsize=20,minibatch=5,optim_step=5) # test
            # agnt.train(num_game=1400,batchsize=512,minibatch=16,optim_step=10) #

# Executing this cell starts training immediately (run=True); pass False to skip it.
agent.run_training(True)

In [None]:
class test:
    """Play the game with a trained checkpoint in a rendered, single (non-vectorised) environment."""

    @staticmethod
    def make_env():
        """Build one rendered environment with the same preprocessing as training."""
        x = gym_super_mario_bros.make("SuperMarioBros-v1",apply_api_compatibility=True,render_mode="human")
        x = ResizeObservation(x,obs_shape)
        x = CustomEnv(x,skip_frame)
        x = JoypadSpace(x, SIMPLE_MOVEMENT)
        x = GrayScaleObservation(x,True)
        x = FrameStack(x,num_stack)
        return x

    @staticmethod
    def _strip_module_prefix(state_dict):
        """Drop the leading ``module.`` that ``nn.DataParallel`` adds to parameter names."""
        prefix = "module."
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): tensor
            for key,tensor in state_dict.items()
        }

    @staticmethod
    def run(start,num_game,checkpoint="./mario180"):
        """Run the trained policy for ``num_game`` environment steps.

        Args:
            start: nothing happens unless this is true.
            num_game: number of (frame-skipped) steps to play.
            checkpoint: path of a file written by ``agent.save``.
        """
        if not start:
            return

        policy = network()
        chk = torch.load(checkpoint,map_location="cpu")
        # Checkpoints are saved from the DataParallel wrapper, so their keys carry a
        # "module." prefix; strip it and load strictly so a mismatch is an error
        # instead of silently leaving the policy untrained.
        policy.load_state_dict(__class__._strip_module_prefix(chk["model_state"]))
        policy.to(device).eval()

        env = __class__.make_env()
        try:
            with torch.no_grad():
                done = True
                re = 0
                for _ in range(num_game):
                    if done:
                        state,_ = env.reset()
                        print(re)
                        re = 0
                    # Same preprocessing as during training: drop the trailing
                    # axis, add a batch axis and scale to [0, 1].
                    frames = torch.from_numpy(np.array(state).copy()).squeeze(-1)
                    frames = frames.to(device,torch.float32).unsqueeze(0) / 255.
                    dist,_ = policy(frames)
                    action = Categorical(dist).sample().item()
                    # Act with the policy's action (not a random one).
                    state, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    re += reward
                    env.render()
        finally:
            env.close()

test.run(start=False,num_game=10_000)

Useful links

https://spinningup.openai.com/en/latest/algorithms/ppo.html

https://github.com/vietnh1009/Super-mario-bros-PPO-pytorch/tree/master

https://github.com/yumouwei/super-mario-bros-reinforcement-learning

https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/

https://en.wikipedia.org/wiki/Proximal_policy_optimization