In [None]:
import warnings
warnings.filterwarnings("ignore")
import torch,gym,random
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation,FrameStack
from gym.vector import SyncVectorEnv
import numpy as np
    #
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #
lr = 0.0001
_lambda_ = 1.0
gamma = 0.99
epsilon = 0.2
c1 = 0.5
c2 = 0.01
    # 
num_env = 8
skip_frame = 4

In [10]:
class CustomEnv(gym.Wrapper): 
        """Repeat each action for `skip` frames and replace the game reward with a shaped one.

        The reward is computed from the ``info`` dict of the last repeated frame:
          * score increased                      -> (score delta) / 10 + 10
          * otherwise, x_pos beyond the record    -> +15 (and the stall penalty is cleared)
          * otherwise, x_pos away from 40         -> accumulated stall penalty, floored at -10
          * ``info["life"] != 2``                 -> -20 and the episode is flagged done
        """
        def __init__(self,env,skip):
            super().__init__(env)
            self.skip = skip
            self.score = 0            # highest score rewarded so far this episode
            self.current_x_pos = 40   # largest x_pos rewarded so far (40 is treated as the start position)
            self.penalty = 0          # running stall penalty (<= 0)

        def step(self, action):
            """Apply `action` for up to `skip` frames; return the 5-tuple ``(obs, reward, done, truncated, info)``."""
            total_reward = 0
            for _ in range(self.skip):
                obs,_,done,truncated,info = self.env.step(action)
                # Stop repeating the action once the episode is over: stepping a finished
                # episode without a reset is undefined behaviour.
                if done or truncated:
                    break

            # NOTE(review): assumes the episode starts with info["life"] == 2; any other value
            # is treated as a lost life and ends the episode.
            if info.get("life") == 2 : 
                if info.get("score") > self.score :
                    total_reward = ((info.get("score") - self.score) / 10. ) + 10
                    self.score = info.get("score") 
                elif info.get("x_pos") > self.current_x_pos :
                    self.current_x_pos = info.get("x_pos")
                    total_reward = 15
                    self.penalty = 0
                else : 
                    # No progress: grow the penalty, except while still at the start position.
                    if info.get("x_pos") != 40 :
                        self.penalty += -1
                        total_reward += max(self.penalty,-10)
            else:
                total_reward = -20
                done = True
            return obs,total_reward,done,truncated,info

        def reset(self, **kwargs):
            """Reset the per-episode bookkeeping and the wrapped env (forwarding seed/options)."""
            self.score = 0
            self.current_x_pos = 40
            # The stall penalty is per-episode too; without this it leaks into the next episode.
            self.penalty = 0
            obs, info = self.env.reset(**kwargs)
            return obs,info

def make_env(): 
    """Build a `SyncVectorEnv` of `num_env` independent Super Mario Bros environments.

    Each sub-environment is: SuperMarioBros-v3 -> CustomEnv (frame skip + shaped reward)
    -> JoypadSpace(SIMPLE_MOVEMENT) -> GrayScaleObservation(keep_dim=True).
    """
    def _make_single_env():
        # Factory: every call creates a fresh emulator with its own wrapper state.
        env = gym_super_mario_bros.make('SuperMarioBros-v3',apply_api_compatibility=True) 
        env = CustomEnv(env,skip_frame) 
        env = JoypadSpace(env, SIMPLE_MOVEMENT)  
        # NOTE(review): keep_dim grayscale presumably yields 1-channel frames, while the
        # network warm-up uses 3-channel input and test.make_test_env() feeds raw frames —
        # confirm which observation format the model is meant to be trained on.
        env = GrayScaleObservation(env,True)
        return env

    # SyncVectorEnv calls each factory once, so every slot gets its own environment instead
    # of all slots sharing (and stepping) one object.
    return SyncVectorEnv(env_fns=[_make_single_env] * num_env)

In [3]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from IPython.display import clear_output

class network(nn.Module):
    """Shared-trunk actor-critic CNN whose input sizes are inferred lazily.

    Four convolutions and two fully-connected layers feed a 7-way policy head (returned as
    softmax probabilities) and a scalar value head. The Adam optimiser lives on
    ``self.optim`` so callers can step it directly.
    """
    def __init__(self):
        super().__init__()
        # Lazy* layers infer in_channels / in_features on the first forward pass.
        # Arguments: out_channels, kernel_size, stride, padding.
        self.conv1 = nn.LazyConv2d(16, 1, 1, 0)
        self.conv2 = nn.LazyConv2d(16, 2, 2, 1)
        self.conv3 = nn.LazyConv2d(16, 2, 2, 1)
        self.conv4 = nn.LazyConv2d(16, 2, 2, 1)
        self.output1 = nn.LazyLinear(500)
        self.output2 = nn.LazyLinear(80)

        # 7 outputs — presumably one per action in SIMPLE_MOVEMENT; keep in sync.
        self.policy_head = nn.LazyLinear(7)
        self.value_head = nn.LazyLinear(1)
        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return ``(action_probabilities, state_value)`` for a batch of images ``x``."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.relu(conv(x))
        features = torch.flatten(x, start_dim=1)
        for dense in (self.output1, self.output2):
            features = F.relu(dense(features))
        action_probs = F.softmax(self.policy_head(features), -1)
        state_value = self.value_head(features)
        return action_probs, state_value
      
model = network().to(device)
# Dry-run one dummy batch so the Lazy* layers infer their input sizes and materialise their
# parameters before training, then clear the cell output.
# NOTE(review): this warm-up uses 3-channel 240x256 frames, but make_env() applies
# GrayScaleObservation(keep_dim=True), which presumably yields 1-channel frames — confirm the
# intended input format, otherwise conv1 will reject the training observations.
model(torch.rand((8, 3, 240, 256), dtype=torch.float32, device=device))
clear_output()

In [4]:
import gc,sys
from torch.distributions import Categorical

class Memory:
    """Rollout buffer for PPO: collects vectorised transitions and their GAE advantages.

    After ``rollout``, ``self.data`` holds one entry per time step, each
    ``[obs, reward, value, log_prob, action, done, advantage]``, where every element has a
    leading dimension equal to the number of sub-environments in ``env``.
    """
    def __init__(self,env : "SyncVectorEnv"):
        self.network = model
        self.env = env
        self.gamma = gamma
        self._lambda_ = _lambda_
        self.data = []
        self.pointer = 0
        self._observation,info = self.env.reset()

    @staticmethod
    def _obs_to_tensor(observation):
        """Convert a batched channel-last numpy observation to a channel-first float tensor on `device`."""
        return torch.from_numpy(observation.copy()).permute(0,-1,1,2).to(device,torch.float32)

    def rollout(self,batchsize):
        """Collect `batchsize` vectorised steps with the current policy and compute GAE advantages.

        The advantage is appended as the last element of every entry, then ``self.data`` is
        shuffled in place so that ``sample`` yields random minibatches of time steps.
        """
        self.clear()
        with torch.no_grad(): 
            for _ in range(batchsize):
                obs = self._obs_to_tensor(self._observation)  # [num_env, C, H, W]
                policy_output, value = self.network.forward(obs)
                distribution = Categorical(policy_output)
                action = distribution.sample()
                log_prob = distribution.log_prob(action)
                state, reward, done, _, _ = self.env.step(action.tolist())

                # Explicitly reset finished sub-environments (redundant if the vector env
                # already auto-resets, but harmless).
                for i in np.where(done)[0]:
                    state[i], _ = self.env.envs[i].reset()

                self.data.append([obs, torch.tensor(reward, dtype=torch.float32), value, log_prob, action, done])
                self._observation = state

            # The rollout is cut after `batchsize` steps, not at an episode end, so bootstrap
            # from the critic's estimate of the state reached after the final step.
            _, last_value = self.network.forward(self._obs_to_tensor(self._observation))

        if not self.data:
            return

        _, rewards, values, _, _, dones = zip(*self.data)
        _rewards = torch.stack(rewards).to(device)                 # [T, num_env]
        _values = torch.stack(values).squeeze(-1)                  # [T, num_env]
        _values = torch.cat([_values, last_value.reshape(1, -1)])  # [T+1, num_env]
        # 1.0 where the episode continues after step n, 0.0 where it ended at step n.
        _not_done = 1.0 - torch.as_tensor(np.stack(dones), dtype=torch.float32, device=device)
        assert _rewards.dtype == _values.dtype == torch.float32

        advantages = torch.zeros_like(_rewards)
        gae = torch.zeros_like(_rewards[0])
        for n in reversed(range(len(_rewards))):
            # Neither bootstrap nor carry the GAE sum across an episode boundary.
            td = _rewards[n] + self.gamma * _values[n+1] * _not_done[n] - _values[n]
            gae = td + self._lambda_ * self.gamma * _not_done[n] * gae
            advantages[n] = gae

        for entry, advantage in zip(self.data, advantages):
            entry.append(advantage)
        random.shuffle(self.data)

    def sample(self,number):
        """Return the next `number` buffered entries, transposed into per-field tuples.

        Returns ``(states, actions, rewards, values, log_probs, advantages, dones)`` — note this
        order differs from the storage order in ``self.data``.

        Raises:
            ValueError: if no entries remain (call ``rollout`` first).
        """
        batch = self.data[self.pointer:self.pointer+number]
        if not batch:
            raise ValueError("Memory.sample: no data left; call rollout() first")
        self.pointer += number
        states, rewards, values, log_probs, actions, dones, advantages = zip(*batch)
        return states, actions, rewards, values, log_probs, advantages, dones

    def clear(self):
        """Drop all buffered entries, rewind the sampling pointer and trigger garbage collection."""
        self.data = []
        self.pointer = 0
        gc.collect()

In [5]:
# Autograd anomaly detection stores forward-pass tracebacks and adds NaN checks to every
# backward op, which slows training considerably; set True only while debugging NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class agent:
    """PPO trainer: owns the vector env, the rollout memory, the shared model and a TensorBoard writer."""
    def __init__(self,):
        self.env = make_env()
        self.memory = Memory(self.env)
        self.network = model
        self.writter = SummaryWriter("./")

    def save(self,k):
        """Save model/optimiser state plus the PPO hyper-parameters to ``./mario{k}``."""
        checkpoint = {
            "model_state" : self.network.state_dict(),
            "optim_state" : self.network.optim.state_dict(),
            "lr" : lr,
            "c2" : c2,
            "c1" : c1,
            "epsilon" : epsilon
        }
        torch.save(checkpoint,f"./mario{k}")

    def train(self,num_game,batchsize,minibatch,optim_step):
        """Run `num_game` PPO iterations of rollout followed by clipped-surrogate updates.

        Args:
            num_game: number of rollout/update iterations (one TensorBoard point each).
            batchsize: vectorised time steps collected per rollout.
            minibatch: time steps per minibatch (each covers ``minibatch * num_env`` samples).
            optim_step: gradient steps taken on every minibatch.

        Raises:
            ValueError: if ``minibatch`` is not in ``1..batchsize`` or ``optim_step < 1``
                (no update would run and nothing could be logged).
        """
        if minibatch <= 0 or batchsize < minibatch or optim_step < 1:
            raise ValueError(
                f"need 1 <= minibatch <= batchsize and optim_step >= 1, got "
                f"batchsize={batchsize}, minibatch={minibatch}, optim_step={optim_step}")

        for traj in tqdm(range(num_game),total = num_game):
            self.memory.rollout(batchsize)

            for _ in range(batchsize//minibatch):
                states,actions,rewards,values,old_log_prob,advantages,done = self.memory.sample(minibatch)    
                _rewards = torch.mean(torch.stack(rewards)).round(decimals=4)

                # Everything derived from the rollout is constant across optimisation steps.
                flat_states = torch.flatten(torch.stack(states),0,1).to(device)  # [(minibatch * num_env), channel, image_shape]
                stacked_actions = torch.stack(actions)      # [minibatch, num_env]
                _old_log_prob = torch.stack(old_log_prob)   # [minibatch, num_env]
                _advantages = torch.stack(advantages).view(-1)
                _values = torch.stack(values).view(-1)
                # Critic target: GAE returns, on the same scale as the values they were built from.
                vtarget = _advantages + _values
                norm_advantage = (_advantages - _advantages.mean()) / (_advantages.std() + 1e-10)

                for _ in range(optim_step): 
                    p_output, v_output = self.network.forward(flat_states)
                    policy_output = p_output.view(minibatch,num_env,-1)
                    dist = Categorical(policy_output)
                    new_log_prob = dist.log_prob(stacked_actions)

                    # The critic loss must use the *current* value prediction: the stored
                    # rollout values were computed under no_grad and carry no gradient.
                    loss_critic = torch.mean(torch.pow(v_output.view(-1) - vtarget, 2))

                    ratio = torch.exp(new_log_prob - _old_log_prob).view(-1)
                    proxy1 = ratio * norm_advantage
                    proxy2 = torch.clamp(ratio ,1-epsilon,1+epsilon) * norm_advantage
                    loss_policy = -torch.mean(torch.min(proxy1,proxy2))

                    # Mean entropy per sample, matching the per-sample means of the other losses.
                    entropy = dist.entropy().mean()

                    total_loss = loss_policy + (c1*loss_critic) - (c2*entropy)

                    self.network.optim.zero_grad()
                    total_loss.backward()  # the graph is rebuilt every step; no retain_graph needed
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                    self.network.optim.step()

            self.writter.add_scalar("main/reward",_rewards.item(),traj)
            self.writter.add_scalar("main/entropy",entropy.item(),traj)
            self.writter.add_scalar("main/total loss",total_loss.item(),traj)
            self.writter.add_scalar("main/loss policy",loss_policy.item(),traj)
            self.writter.add_scalar("main/loss critic",loss_critic.item(),traj) 

            if traj % 20 == 0 : # save every 20 iterations
                self.save(traj)

        self.save("_end")

# Build the PPO trainer (vector env + rollout memory + writer). It is named `trainer` because
# the evaluation cell below defines `class test`, which would otherwise overwrite this instance.
trainer = agent()
#trainer.train(num_game=5,batchsize=10,minibatch=5,optim_step=5)  # rollout test

In [9]:
class test:
    """Evaluation helpers: play a saved checkpoint in a rendered environment."""
    @staticmethod
    def make_test_env(): 
        """Single (non-vectorised) env with the training reward wrapper and action set, render_mode="human".

        NOTE(review): unlike make_env() this does not apply GrayScaleObservation — confirm the
        observation format matches the checkpoint being loaded.
        """
        x = gym_super_mario_bros.make('SuperMarioBros-v3',apply_api_compatibility=True,render_mode="human") 
        x = CustomEnv(x,skip_frame) 
        x = JoypadSpace(x, SIMPLE_MOVEMENT)  
        return x

    @staticmethod
    def run(start,num_game,checkpoint="./mario1440"):
        """Play `num_game` agent steps (not episodes) with the policy stored in `checkpoint`.

        Prints the step count of each finished episode (a 0 is printed once before the first
        reset). Does nothing when `start` is false.
        """
        if not start:
            return
        model = network()
        # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
        chk = torch.load(checkpoint,map_location="cpu")
        # NOTE(review): strict=False silently ignores missing/unexpected keys.
        model.load_state_dict(chk["model_state"],strict=False)
        env = __class__.make_test_env()
        try:
            with torch.no_grad():
                done = True
                step = 0
                for _ in range(num_game):
                    if done:
                        state,_ = env.reset()
                        print(step)
                        step = 0
                    state = torch.from_numpy(state.copy()).permute(-1,0,1).to(torch.float32).unsqueeze(0) 
                    dist,_ = model.forward(state)
                    action = Categorical(dist).sample().item()
                    # 5-tuple API: (obs, reward, terminated, truncated, info)
                    state, reward, terminated, truncated, info = env.step(action)
                    done = terminated or truncated
                    step += 1
                    env.render()
        finally:
            env.close()

# Play the saved checkpoint for 5000 agent steps with render_mode="human".
test.run(num_game=5000, start=True)

0


  if not isinstance(terminated, (bool, np.bool8)):


204
368
58
63
399
238
133
178
69
54
217
562
460
529
73
123
69
885
