In [None]:
import gymnasium as gym 
import ale_py
import matplotlib.pyplot as plt
import numpy as np
import torch 
import torch.nn as nn
import torch.optim as optim
from collections import namedtuple, deque
import random
from itertools import count
import torch.nn.functional as F
import math
import os
from torch.cuda.amp import autocast, GradScaler


In [None]:


env = gym.make('ALE/AirRaid-v5',render_mode='rgb_array')

# One replay-memory record. The training loop stores next_state=None for terminal transitions.
transition = namedtuple('transition', ('state', 'action', 'next_state', 'reward'))

# Folder for recorded episode videos. Set the TRAIN_VIDEO_DIR environment variable to override
# this machine-specific default instead of editing the notebook.
path = os.environ.get('TRAIN_VIDEO_DIR', r'E:\projects\atariGAN\train_video')
os.makedirs(path, exist_ok=True)
env = gym.wrappers.RecordVideo(
    env,
    # Record every 100th episode (0, 100, 200, ...).
    episode_trigger=lambda num: num % 100 == 0,
    video_folder=path,
    # RecordVideo inserts its own "-" before "episode-N", so no trailing dash here
    # (the old "video-" prefix produced "video--episode-N" files).
    name_prefix="video",
)

class ReplayMemory:
    """Bounded FIFO buffer of ``transition`` records used for experience replay.

    When ``capacity`` records are stored, pushing another one evicts the oldest.
    """

    def __init__(self, capacity):
        # A deque with maxlen silently drops from the left once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Pack ``args`` into a ``transition`` and append it to the buffer."""
        record = transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen uniformly at random."""
        population = self.memory
        return random.sample(population, batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return self.memory.__len__()
    
class DQN(nn.Module):
    """Three-layer fully connected Q-network.

    Maps an input whose last dimension is ``n_observations`` to one Q-value per action
    (last dimension ``n_actions``), with two 128-unit ReLU hidden layers.
    """

    def __init__(self, n_observations, n_actions):
        super().__init__()
        width = 128
        # Same creation order as before, so seeded initialisation and state_dict keys match.
        self.layer1 = nn.Linear(n_observations, width)
        self.layer2 = nn.Linear(width, width)
        self.layer3 = nn.Linear(width, n_actions)

    def forward(self, x):
        """Return raw (unactivated) Q-values for input ``x``."""
        hidden = x
        for block in (self.layer1, self.layer2):
            hidden = F.relu(block(hidden))
        return self.layer3(hidden)

In [None]:
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

# Hyperparameters
BATCH_SIZE = 128   # transitions per optimisation step
GAMMA = 0.99       # discount factor for bootstrapped targets
EPS_START = 0.9    # initial epsilon of the epsilon-greedy policy
EPS_END = 0.01     # asymptotic epsilon
EPS_DECAY = 2500   # time constant (in agent steps) of the exponential epsilon decay
TAU = 0.005        # soft-update rate for the target network
LR = 3E-5          # AdamW learning rate
MAX_STEPS = 1000   # per-episode cap on agent steps
SEED = 0           # seeds Python, NumPy, torch and the environment for reproducible runs

# Seed every RNG the training touches: replay sampling / exploration (random), torch weight
# init, the env's random actions (action_space) and its internal dynamics (reset seed).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
env.action_space.seed(SEED)

n_actions = env.action_space.n
state, info = env.reset(seed=SEED)
# The network consumes the whole observation flattened into one vector.
state_tensor = torch.tensor(state, dtype=torch.float32)
n_observations = state_tensor.numel()

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
# Start the target network as an exact copy of the policy network.
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

# Global count of actions selected so far; drives the epsilon schedule.
steps_done = 0

def epsilon_threshold(steps, eps_start, eps_end, eps_decay):
    """Exponentially decaying exploration rate.

    Equals ``eps_start`` at ``steps == 0`` and decays towards ``eps_end`` with time constant
    ``eps_decay`` (in steps). It stays within ``[eps_end, eps_start]`` for ``eps_start >= eps_end``.
    """
    return eps_end + (eps_start - eps_end) * math.exp(-steps / eps_decay)


def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    With probability ``1 - eps`` take the greedy action ``argmax_a policy_net(state)``,
    otherwise a uniformly random action from ``env.action_space``. ``eps`` decays from
    EPS_START towards EPS_END as the global ``steps_done`` counter grows; every call
    increments that counter.

    Args:
        state: observation batch accepted by ``policy_net``. The ``.view(1, 1)`` below
            implies a batch of exactly one row.

    Returns:
        A ``(1, 1)`` ``torch.long`` tensor holding the chosen action index.
    """
    global steps_done
    sample = random.random()
    # eps must be positive for any exploration to happen; it starts near EPS_START.
    eps_threshold = epsilon_threshold(steps_done, EPS_START, EPS_END, EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1).indices.view(1, 1)
    return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
    
# Number of steps each finished episode lasted; the training loop appends to it.
episode_durations=[]

def plot_durations(episode,show_result=False,reward=None,done=False):
    """Plot episode durations (with a 100-episode moving average) and print a progress line.

    Args:
        episode: index of the episode just finished. During training the figure is only
            redrawn when this is a multiple of 10, to keep live plotting cheap.
        show_result: marks the final call. The figure is titled 'Result' and is always redrawn,
            whatever ``episode`` is.
        reward: total episode reward to report. It is printed when not None.
        done: accepted for backward compatibility and no longer affects the output.
            Previously the reward was printed only when ``done`` was set, which the
            per-episode training calls never did.
    """
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result or episode % 10 == 0:
        plt.figure(1)
        plt.clf()
        plt.title('Result' if show_result else 'training')
        plt.xlabel('episode')
        plt.ylabel('Duration')
        plt.plot(durations_t.numpy())
        if len(durations_t) >= 100:
            # Sliding-window mean over 100 episodes, left-padded with zeros so it aligns
            # with the raw curve.
            means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
            means = torch.cat((torch.zeros(99), means))
            plt.plot(means.numpy())
        plt.draw()
        plt.pause(0.001)
    if episode_durations:
        reward_text = reward if reward is not None else 'n/a'
        # Adjacent f-strings (not backslash continuations) keep stray indentation out of the output.
        print(f"Episode: {len(episode_durations)} | Last duration: {episode_durations[-1]} "
              f"| episode reward : {reward_text} "
              f"| Average last 100: {durations_t[-100:].mean().item():.2f} ")
    


In [None]:
# Mixed precision is only used on CUDA. On CPU/MPS both GradScaler and autocast are disabled
# and act as plain pass-throughs, so the same code path works everywhere.
_use_amp = device.type == 'cuda'
scaler = torch.amp.GradScaler(device='cuda', enabled=_use_amp)

def model_optimize():
    """Run one DQN optimisation step on a random minibatch from the global replay ``memory``.

    Does nothing until ``memory`` holds at least BATCH_SIZE transitions. Otherwise it
    minimises the Smooth-L1 (Huber) loss between ``Q_policy(s, a)`` and the bootstrapped target
    ``r + GAMMA * max_a' Q_target(s', a')``. That bootstrap term is 0 for transitions whose
    ``next_state`` is None. It then clips gradients to [-100, 100] and steps ``optimizer``.
    """
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of transitions into one transition whose fields are per-sample tuples.
    batch = transition(*zip(*transitions))

    # True where the episode continued after the transition (next_state is not None).
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    state_batch = torch.cat(batch.state).view(BATCH_SIZE, -1)
    action_batch = torch.cat(batch.action).long().view(-1, 1)
    reward_batch = torch.cat(batch.reward)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat([]) raises, so skip the target pass when every sampled transition is terminal.
    if any(s is not None for s in batch.next_state):
        non_final_next_states = torch.cat(
            [s.view(-1) for s in batch.next_state if s is not None]
        ).view(-1, n_observations)
        with torch.no_grad():
            next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    criterion = nn.SmoothL1Loss()
    # autocast must cover the policy forward pass (the expensive part), not just the loss.
    with torch.amp.autocast(device_type='cuda', enabled=_use_amp):
        state_action_values = policy_net(state_batch).gather(1, action_batch)
        loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    # Unscale before clipping. Otherwise the threshold would apply to gradients
    # multiplied by the (large) loss scale.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    scaler.step(optimizer)
    scaler.update()
    


In [None]:
plt.ion()
# Each chosen action is repeated for up to this many env.step calls and the rewards are summed.
# NOTE(review): ALE "-v5" environments appear to apply their own frameskip (default 4),
# so this may mean ~16 emulator frames per agent decision. Confirm this is intended.
frame_skip=4

# torch.cuda memory queries raise on non-CUDA devices, so only use them when training on CUDA.
on_cuda = device.type == 'cuda'

for i_episode in range(2500):
    # Running sum of this episode's rewards (not a maximum, despite the name).
    max_reward = 0
    if on_cuda:
        allocated = torch.cuda.memory_allocated(device)
        reserved = torch.cuda.memory_reserved(device)
        total = torch.cuda.get_device_properties(device).total_memory
        # The readings are taken once per episode, so act on them once per episode
        # rather than on every step with stale values.
        if allocated >= total * 0.9 or reserved >= total * 0.9:
            torch.cuda.empty_cache()
    state, info = env.reset()
    # Flatten the observation into a single (1, n_observations) float row.
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    state = state.view(state.size(0), -1)
    for t in count():
        total_reward = 0.0
        done = False
        action = select_action(state)
        for _ in range(frame_skip):
            observation, reward_frame, terminated, truncated, _ = env.step(action.item())
            total_reward += reward_frame
            done = terminated or truncated
            if done:
                break
        reward = torch.tensor([total_reward], device=device)

        # Only true termination stops bootstrapping; a truncated episode still stores next_state.
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
            next_state = next_state.view(next_state.size(0), -1)

        # Accumulate the Python float directly; reward.item() would force a device sync every step.
        max_reward += total_reward
        memory.push(state, action, next_state, reward)
        state = next_state
        model_optimize()

        # Soft (Polyak) update: target <- TAU * policy + (1 - TAU) * target.
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        # t is 0-based, so t + 1 steps have been taken; cap the episode at exactly MAX_STEPS.
        if done or t + 1 >= MAX_STEPS:
            episode_durations.append(t + 1)
            plot_durations(i_episode, reward=max_reward)
            break

plot_durations(i_episode, show_result=True, reward=max_reward, done=True)
plt.ioff()
plt.show()
