In [None]:
import warnings
warnings.filterwarnings("ignore")
import torch,gym,random
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.wrappers import GrayScaleObservation,FrameStack,ResizeObservation
from gym.vector import SyncVectorEnv
import numpy as np


In [None]:
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed(seed)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #
lr = 0.0001
_lambda_ = 0.95
gamma = 0.99
epsilon = 0.2
c1 = 0.5
c2 = 0.01
    # 
resized_observation = (90,90)
num_frames = 4
num_env = 6
skip_frame = 4

In [None]:
def make_env():
    """Build the vectorized Super Mario Bros environment used for training.

    Wrapper stack per sub-environment (innermost first):
    ``SuperMarioBros-v0`` -> ``ResizeObservation(resized_observation)`` ->
    ``CustomEnv`` (frame skip + reward shaping) -> ``JoypadSpace(SIMPLE_MOVEMENT)``
    -> ``GrayScaleObservation(keep_dim=True)``.

    Returns:
        A ``SyncVectorEnv`` holding ``num_env`` independent copies of that stack.
    """

    class CustomEnv(gym.Wrapper):
        """Frame-skip and reward-shaping wrapper.

        Repeats each action up to ``skip`` times, stopping early if the episode
        ends. The per-frame rewards are summed. While ``info["life"] == 2`` the
        score gain since the previous step, divided by 50, is added. Otherwise
        200 is subtracted and the episode is forced to end.
        """

        def __init__(self, env, skip):
            super().__init__(env)
            self.skip = skip
            self.score = 0

        def step(self, action):
            total_reward = 0
            for _ in range(self.skip):
                obs, reward, done, truncated, info = self.env.step(action)
                total_reward += reward
                # Never step a finished episode again: the remaining repeats are dropped.
                if done or truncated:
                    break

            if info.get("life") == 2:
                total_reward += (info.get("score") - self.score) / 50
                self.score = info.get("score")
            else:
                total_reward -= 200
                done = True
            return obs, total_reward, done, truncated, info

        def reset(self, **kwargs):
            # A new episode starts from score 0. Without this, the first score delta
            # of the next episode would be (new_score - previous_episode_score).
            # NOTE(review): reset kwargs (e.g. seed) are not forwarded, as before.
            self.score = 0
            return self.env.reset()

    def make_single_env():
        """Create one fully wrapped, independent emulator instance."""
        x = gym_super_mario_bros.make('SuperMarioBros-v0', apply_api_compatibility=True)
        x = ResizeObservation(x, resized_observation)
        x = CustomEnv(x, skip_frame)
        x = JoypadSpace(x, SIMPLE_MOVEMENT)
        x = GrayScaleObservation(env=x, keep_dim=True)
        # x = FrameStack(x, num_frames)
        return x

    # SyncVectorEnv calls each factory once. Passing the factory itself, rather than
    # a lambda that returns one pre-built env, gives num_env distinct emulators.
    return SyncVectorEnv(env_fns=[make_single_env] * num_env)


In [51]:
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from IPython.display import clear_output

class network(nn.Module):
    """Actor-critic CNN shared by the policy and the value function.

    Four ``LazyConv2d`` layers (input channels inferred on the first forward)
    feed an 80-unit ``LazyLinear`` layer. A policy head (softmax probabilities)
    and a scalar value head branch off that layer. The module owns its Adam
    optimizer as ``self.optim``.

    NOTE: ``self.optim`` is built here, before the lazy parameters are
    materialized and before any ``.to(device)``. Rebuild it after the dry run
    when the model is moved to another device.
    """

    def __init__(self, n_actions=7):
        """
        Args:
            n_actions: size of the policy head. The default of 7 matches the
                action count hard-coded elsewhere in this notebook (intended to
                be ``len(SIMPLE_MOVEMENT)``).
        """
        super().__init__()
        self.conv1 = nn.LazyConv2d(32, 1, 1, 0)
        self.conv2 = nn.LazyConv2d(32, 3, 2, 2)
        self.conv3 = nn.LazyConv2d(32, 3, 2, 2)
        self.conv4 = nn.LazyConv2d(32, 3, 2, 2)
        self.output = nn.LazyLinear(80)

        self.policy_head = nn.LazyLinear(n_actions)
        self.value_head = nn.LazyLinear(1)
        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, x):
        """Return ``(action_probabilities, state_value)`` for a batch of frames.

        Args:
            x: float tensor ``[batch, channels, H, W]``. Rollouts pass
               ``[num_env, 1, 90, 90]`` grayscale frames.
               NOTE(review): callers pass raw pixel values with no /255
               scaling; consider normalizing.

        Returns:
            ``probs`` of shape ``[batch, n_actions]`` (softmax over the last dim)
            and ``value`` of shape ``[batch, 1]``.
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        # For a 90x90 input the spatial size goes 90 -> 46 -> 24 -> 13, so this is
        # [batch, 32 * 13 * 13] = [batch, 5408].
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.output(x))
        policy_output = self.policy_head(x)
        value_output = self.value_head(x)
        return F.softmax(policy_output, -1), value_output

model = network().to(device)
# Dry run: materializes the lazy layers. The probe must live on `device`;
# a CPU probe against a CUDA-resident model fails with a device mismatch.
with torch.no_grad():
    _ = model.forward(torch.rand((1, 1, *resized_observation), dtype=torch.float32, device=device))
# Rebuild the optimizer now that every parameter is materialized and on `device`
# (torch.optim guidance: move the model before constructing its optimizer).
model.optim = Adam(model.parameters(), lr=lr)


In [None]:
import gc,sys
from torch.distributions import Categorical

class Memory:
    """On-policy rollout buffer for PPO over a vectorized environment.

    ``self.data`` holds one entry per vectorized time step::

        [state, reward, value, log_prob, action, done, advantage]

    - ``state`` is the (CPU) observation the action was taken in.
    - ``value``, ``log_prob`` and ``action`` come from the network at that step.
    - ``done`` marks sub-envs whose episode ended on that step (terminated or
      truncated).
    - ``advantage`` is the GAE estimate appended at the end of :meth:`rollout`.

    Entries are shuffled after collection and consumed sequentially by
    :meth:`sample`.
    """

    def __init__(self, env: "SyncVectorEnv"):
        self.network = model
        self.env = env
        self.gamma = gamma
        self._lambda_ = _lambda_
        self.data = []
        self.pointer = 0

    @staticmethod
    def _compute_gae(rewards, values, dones, last_value, gamma, lam):
        """Generalized Advantage Estimation over a ``[T, num_env]`` rollout.

        Args:
            rewards: ``[T, num_env]`` float rewards.
            values: ``[T, num_env]`` critic estimates V(s_t).
            dones: ``[T, num_env]`` bool, True where the episode ended on step t.
            last_value: ``[num_env]`` critic estimate for the observation after
                the final step (bootstrap for still-running episodes).
            gamma: discount factor.
            lam: GAE lambda.

        Returns:
            ``[T, num_env]`` advantages, computed as
            A_t = delta_t + gamma * lam * (1 - d_t) * A_{t+1}, where
            delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t).
        """
        not_done = 1.0 - dones.to(rewards.dtype)
        next_values = torch.cat([values[1:], last_value.unsqueeze(0)], dim=0)
        advantages = torch.zeros_like(rewards)
        gae = torch.zeros_like(last_value)
        for t in reversed(range(rewards.shape[0])):
            delta = rewards[t] + gamma * next_values[t] * not_done[t] - values[t]
            gae = delta + gamma * lam * not_done[t] * gae
            advantages[t] = gae
        return advantages

    def rollout(self, batchsize):
        """Collect ``batchsize`` vectorized steps with the current policy, then
        attach GAE advantages and shuffle the buffer.

        Truncation is treated like termination: no bootstrap from the final
        observation of a truncated episode.
        """
        self.clear()
        raw_obs, _ = self.env.reset()
        # NHWC frames -> NCHW float32, e.g. [num_env, 1, 90, 90]
        observation = torch.from_numpy(raw_obs).permute(0, -1, 1, 2).to(device, torch.float32)
        with torch.no_grad():
            for _ in range(batchsize):
                policy_output, value = self.network.forward(observation)
                distribution = Categorical(policy_output)
                action = distribution.sample()
                log_prob = distribution.log_prob(action)
                next_obs, reward, terminated, truncated, _ = self.env.step(action.tolist())
                episode_end = np.logical_or(terminated, truncated)

                # Store the observation the action was *taken in*, so the
                # re-evaluated log-probs during training match this action.
                self.data.append([observation.cpu(),
                                  torch.tensor(reward, dtype=torch.float32),
                                  value, log_prob, action, episode_end])

                # SyncVectorEnv auto-resets finished sub-envs, so next_obs already
                # holds the first frame of the new episode for them.
                observation = torch.from_numpy(next_obs).permute(0, -1, 1, 2).to(device, torch.float32)

            # V(s_T) of the observation after the last step, used as bootstrap.
            _, last_value = self.network.forward(observation)

        if not self.data:
            return

        _, rewards, values, _, _, dones = zip(*self.data)
        rewards_t = torch.stack(rewards).to(device)                 # [T, num_env]
        values_t = torch.stack(values).squeeze(-1)                  # [T, num_env]
        dones_t = torch.as_tensor(np.stack(dones), device=device)   # [T, num_env] bool
        assert rewards_t.dtype == values_t.dtype == torch.float32

        advantages = self._compute_gae(rewards_t, values_t, dones_t,
                                       last_value.squeeze(-1), self.gamma, self._lambda_)
        for entry, advantage in zip(self.data, advantages):
            entry.append(advantage)
        random.shuffle(self.data)

    def sample(self, number):
        """Return the next ``number`` entries as per-field tuples, ordered
        ``(states, actions, rewards, values, logProb, advantages, done)``."""
        output = self.data[self.pointer:self.pointer + number]
        self.pointer += number
        states, rewards, values, logProb, actions, done, advantages = zip(*output)
        return states, actions, rewards, values, logProb, advantages, done

    def clear(self):
        """Drop all stored entries and rewind the sampling pointer."""
        self.data = []
        self.pointer = 0
        gc.collect()

In [None]:
# Anomaly detection re-checks every backward pass and slows training considerably;
# set this to True only while hunting NaN/inf gradients.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class agent:
    """PPO training driver.

    Attributes:
        env: vectorized environment from ``make_env()``.
        memory: ``Memory`` rollout buffer bound to ``env``.
        network: the shared actor-critic ``model`` (owns ``optim``).
        writter: TensorBoard ``SummaryWriter`` logging to ``./``.
    """

    def __init__(self,):
        self.env = make_env()
        self.memory = Memory(self.env)
        self.network = model
        self.writter = SummaryWriter("./")

    def save(self, k):
        """Write model/optimizer state and the PPO coefficients to ``./mario{k}``."""
        checkpoint = {
            "model_state": self.network.state_dict(),
            "optim_state": self.network.optim.state_dict(),
            "lr": lr,
            "c2": c2,
            "c1": c1,
            "epsilon": epsilon,
        }
        torch.save(checkpoint, f"./mario{k}")

    def _ppo_losses(self, flat_states, n_steps, n_envs, actions, old_log_prob, advantages, returns):
        """Evaluate the current network on one minibatch and compute PPO losses.

        Args:
            flat_states: observations flattened to ``[n_steps * n_envs, C, H, W]``.
            n_steps, n_envs: used to un-flatten the network outputs.
            actions, old_log_prob, advantages, returns: ``[n_steps, n_envs]``
                tensors from the rollout (``returns = advantages + old values``).

        Returns:
            ``(total_loss, loss_policy, loss_critic, entropy)``.
        """
        p_output, v_output = self.network.forward(flat_states)
        policy_output = p_output.view(n_steps, n_envs, -1)
        # The critic must be fitted with *fresh* predictions; the stored rollout
        # values were computed under no_grad and carry no gradient.
        new_values = v_output.view(n_steps, n_envs)

        dist = Categorical(policy_output)
        new_log_prob = dist.log_prob(actions)
        # Mean over (step, env): same scale as the policy and critic losses.
        entropy = dist.entropy().mean()

        loss_critic = torch.mean(torch.pow(new_values - returns, 2))

        ratio = torch.exp(new_log_prob - old_log_prob)
        proxy1 = ratio * advantages
        proxy2 = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages
        loss_policy = -torch.mean(torch.min(proxy1, proxy2))

        total_loss = loss_policy + (c1 * loss_critic) - (c2 * entropy)
        return total_loss, loss_policy, loss_critic, entropy

    def train(self, num_game, batchsize, minibatch, optim_step):
        """Run ``num_game`` PPO iterations.

        Each iteration does the following:
        - Collects a ``batchsize``-step rollout.
        - Splits it into ``batchsize // minibatch`` minibatches and takes
          ``optim_step`` Adam steps on each.
        - Logs scalars to TensorBoard.

        A checkpoint is written every 20 iterations and once at the end.

        Raises:
            ValueError: if ``minibatch > batchsize`` or ``optim_step < 1``, in
                which case nothing would be trained or logged.
        """
        n_minibatches = batchsize // minibatch
        if n_minibatches < 1 or optim_step < 1:
            raise ValueError("need minibatch <= batchsize and optim_step >= 1")

        for traj in tqdm(range(num_game), total=num_game):
            self.memory.rollout(batchsize)
            minibatch_rewards = []

            for _ in range(n_minibatches):
                states, actions, rewards, values, old_log_prob, advantages, done = self.memory.sample(minibatch)
                minibatch_rewards.append(torch.mean(torch.stack(rewards)))

                # Constant across the optimisation epochs: build once per minibatch.
                stacked_states = torch.stack(states)  # [minibatch, num_env, channel, H, W]
                n_steps, n_envs = stacked_states.shape[:2]
                flat_states = torch.flatten(stacked_states, 0, 1).to(device)  # [(minibatch * num_env), channel, H, W]
                _actions = torch.stack(actions)
                _old_log_prob = torch.stack(old_log_prob)
                _advantages = torch.stack(advantages)
                _values = torch.stack(values).view(n_steps, n_envs)
                vtarget = _advantages + _values  # GAE returns: the critic's regression target

                for _ in range(optim_step):
                    total_loss, loss_policy, loss_critic, entropy = self._ppo_losses(
                        flat_states, n_steps, n_envs, _actions, _old_log_prob, _advantages, vtarget)

                    self.network.optim.zero_grad()
                    total_loss.backward()  # the graph is rebuilt every step; no retain_graph needed
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                    self.network.optim.step()

            # Mean reward over every sampled minibatch of this rollout.
            _rewards = torch.stack(minibatch_rewards).mean()
            self.writter.add_scalar("main/reward", _rewards.item(), traj)
            self.writter.add_scalar("main/entropy", entropy.item(), traj)
            self.writter.add_scalar("main/total loss", total_loss.item(), traj)
            self.writter.add_scalar("main/loss policy", loss_policy.item(), traj)
            self.writter.add_scalar("main/loss critic", loss_critic.item(), traj)

            if traj % 20 == 0:  # save every 20 iterations
                self.save(traj)

        self.save("_end")

# Full training run: 600 PPO iterations, each on a 600-step vectorized rollout
# (6 minibatches of 100 steps, 5 optimisation epochs per minibatch).
train_config = dict(num_game=600, batchsize=600, minibatch=100, optim_step=5)
test = agent()
test.train(**train_config)

In [34]:
import torch
# Scratch: two reproducible random 9-way probability vectors, used below to
# inspect probability ratios and logs (as in the PPO ratio computation).
torch.manual_seed(42)
t, t2 = (torch.softmax(torch.rand((1, 9), dtype=torch.float), -1) for _ in range(2))

In [43]:
# Scratch: the 1e-10 epsilon (as used before taking logs) is invisible at display precision.
torch.add(t, 1e-10)

tensor([[0.1316, 0.1360, 0.0799, 0.1422, 0.0805, 0.0994, 0.0704, 0.1205, 0.1396]])

In [37]:
# Scratch: display the second probability vector for comparison with `t`.
t2

tensor([[0.0653, 0.1456, 0.1036, 0.1364, 0.1009, 0.1200, 0.0879, 0.1387, 0.1015]])

In [17]:
# Scratch: element-wise probability ratio between the two distributions.
x = torch.div(t, t2)
x

tensor([[2.0143, 0.9339, 0.7714, 1.0420, 0.7977, 0.8278, 0.8012, 0.8688, 1.3745]])

In [21]:
# Scratch: log of the ratio, i.e. log(t) - log(t2).
x.log()

tensor([[ 0.7003, -0.0684, -0.2595,  0.0411, -0.2261, -0.1890, -0.2216, -0.1406,
          0.3181]])

In [50]:
# Scratch: the literal -1e23242 overflows to -inf; the log of a negative value is NaN.
torch.tensor([-float("inf")]).log()

tensor([nan])