In [None]:
# Appears to be a quick check that inline matplotlib rendering works in this kernel.
# NOTE(review): unrelated to the rest of the notebook; consider removing this cell.
import matplotlib.pyplot as plt
plt.plot(list(range(1, 5)))

In [None]:
import random
import gym
import sys
import numpy as np
import matplotlib.pyplot as plt
from collections import deque,namedtuple
import os
import copy
import torch
import torch.nn as nn
from torch.optim import RAdam,Adam
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage

In [None]:
# Vectorized Breakout env built by SB3's make_atari_env (presumably with SB3's default
# Atari preprocessing wrappers — confirm against the installed SB3 version).
env = make_atari_env('BreakoutNoFrameskip-v4')

In [None]:
x = env.reset()

In [None]:
# Observation shape before frame stacking.
x.shape

In [None]:
plt.imshow(x.squeeze())

In [None]:
# Stack the 4 most recent frames (presumably along the trailing channel axis).
# NOTE(review): not idempotent — re-running this cell wraps `env` a second time;
# re-run from the make_atari_env cell instead.
env = VecFrameStack(env,n_stack=4)

In [None]:
x = env.reset()

In [None]:
# Observation shape after frame stacking.
x.shape

In [None]:
# NOTE(review): if the squeezed image is (H, W, 4), imshow interprets the 4 stacked
# frames as RGBA, so this is not a faithful picture of any single frame; plotting a
# single channel (e.g. x.squeeze(0)[..., -1]) would be clearer.
plt.imshow(x.squeeze(0))

In [None]:
# Take one step with action index 3 (its meaning depends on the game's action set)
# and show the resulting stacked observation.
x_new,r,done,_ = env.step([3])
plt.imshow(x_new.squeeze(0))

In [None]:
x_new,r,done,_ = env.step([3])
plt.imshow(x_new.squeeze(0))

In [None]:
# Presumably moves the channel axis first so observations match nn.Conv2d's
# (N, C, H, W) layout expected by the DQN below.
# NOTE(review): also not idempotent on re-run.
env = VecTransposeImage(env)

In [None]:
# Final observation/action space sizes used to build and drive the network.
state_sz = env.observation_space.shape
action_sz = env.action_space.n
print('State space: ',state_sz)
print('Action space: ',action_sz)

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network (same layer sizes as the Nature-DQN architecture).

    Maps a float batch of stacked frames of shape (N, in_channels, 84, 84) to
    Q-values of shape (N, n_actions).

    NOTE: this class shadows `stable_baselines3.DQN` imported at the top of the notebook.
    NOTE(review): no input scaling happens here, so raw pixel values reach the first
    conv layer — confirm this is intended (SB3 divides image observations by 255).

    Args:
        n_actions: number of discrete actions (output width). Defaults to 4 to match the
            previously hard-coded head.
        in_channels: number of stacked frames / input channels. Defaults to 4.
    """
    def __init__(self, n_actions=4, in_channels=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=(8, 8), stride=(4, 4)),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2)),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)),
            nn.ReLU(),
            nn.Flatten(start_dim=1, end_dim=-1),
            # An 84x84 input shrinks to 20x20 -> 9x9 -> 7x7, so 64 * 7 * 7 = 3136 features.
            nn.Linear(in_features=3136, out_features=512, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=n_actions, bias=True),
        )

    def forward(self, x):
        """Return Q-values of shape (N, n_actions) for a float batch `x`."""
        return self.net(x)

    @torch.no_grad()
    def predict(self, x):
        """Return the greedy action index for each batch row as a numpy array.

        Runs without autograd: action selection never needs gradients.
        """
        return self(x).argmax(dim=-1).cpu().numpy()
            

In [None]:
# Online Q-network, moved to the GPU.
dqn = DQN().cuda()

In [None]:
# Smoke test: push one reset observation from the wrapped env through the network.
x = env.reset()

In [None]:
z = torch.tensor(x).to(torch.float).cuda()

In [None]:
# Expect Q-values of shape (n_envs, 4) if the observation shape matches the network.
dqn(z)

In [None]:
# Training hyperparameters.
n_transitions = int(1e7)   # training stops once more than this many env steps are taken
batch_size = 32            # minibatch size sampled from the replay buffer
gamma = 0.99               # discount factor in the TD target
lr = 1e-4                  # Adam learning rate
eps = 1.0                  # NOTE(review): not read by Agent, which takes its own eps argument
decay = 1e-6               # amount subtracted from epsilon on every env step

In [None]:
# Experience replay memory: a bounded deque that drops the oldest transition once it
# already holds `buffer_sz` entries.
# NOTE(review): each stored transition keeps two observations. Assuming uint8
# observations of shape (1, 4, 84, 84), that is ~56 KB per transition, i.e. roughly
# 56 GB at full capacity — confirm the host has that much RAM or lower buffer_sz.
buffer_sz = 1_000_000
replay_buffer = deque([], maxlen=buffer_sz)

In [None]:
# One replay-buffer record. Field order is (x_new, reward, x, action, done): the NEXT
# state comes first, so prefer keyword construction to avoid mix-ups.
transition = namedtuple('transition', 'x_new reward x action done')

In [None]:
def store(transition):
    """Push a single `transition` record onto the global `replay_buffer`.

    The buffer is a bounded deque, so the oldest record is evicted once it is full.
    (The parameter name shadows the `transition` namedtuple class inside this function.)
    """
    replay_buffer.extend((transition,))

In [None]:
# Online network on the GPU (a no-op if it is already there) plus a copy used as the
# bootstrap target in the TD update.
dqn = dqn.cuda()
target = copy.deepcopy(dqn)
# The target network changes only via load_state_dict, never via the optimizer (which
# is built over dqn.parameters()). Freeze its parameters so backward passes through the
# TD target neither track it nor accumulate stale gradients on it.
target.requires_grad_(False)

In [None]:
# Huber loss (PyTorch default delta) applied to the TD error.
loss_fn = nn.HuberLoss()
# Adam over the online network's parameters only.
optimizer = Adam(params=dqn.parameters(), lr=lr)

In [None]:
def update():
    """Run one Double-DQN gradient step on a random minibatch from `replay_buffer`.

    Does nothing until the buffer holds at least `batch_size` transitions. The online
    network `dqn` picks the greedy next action and the `target` network evaluates it.
    The TD target is computed without autograd, so gradients flow only into `dqn` and
    never accumulate on `target`'s parameters.

    Uses notebook globals: replay_buffer, batch_size, gamma, dqn, target, optimizer, loss_fn.
    """
    if len(replay_buffer) < batch_size:
        return

    batch = random.sample(replay_buffer, batch_size)
    device = next(dqn.parameters()).device

    def to_tensor(values, dtype):
        # Stack a list of per-transition values into one tensor on the network's device.
        return torch.as_tensor(np.array(values), dtype=dtype, device=device)

    # Observations presumably carry a leading env dim of size 1 (one vectorized env),
    # giving (B, 1, C, H, W); squeeze it to (B, C, H, W).
    x = to_tensor([t.x for t in batch], torch.float32).squeeze(1)
    x_new = to_tensor([t.x_new for t in batch], torch.float32).squeeze(1)
    # reshape (not squeeze(1)) so plain scalars and length-1 arrays both flatten to (B,).
    r = to_tensor([t.reward for t in batch], torch.float32).reshape(-1)
    done = to_tensor([t.done for t in batch], torch.float32).reshape(-1)
    a = to_tensor([t.action for t in batch], torch.int64).reshape(-1, 1)

    with torch.no_grad():
        # Double DQN: the online net selects a', the target net scores Q_target(s', a').
        a_next = dqn(x_new).argmax(dim=1, keepdim=True)
        q_next = target(x_new).gather(1, a_next).squeeze(1)
        # Terminal transitions do not bootstrap.
        target_q = r + gamma * q_next * (1.0 - done)

    # squeeze(1), not squeeze(), so a batch of size 1 still yields shape (1,).
    prediction_q = dqn(x).gather(1, a).squeeze(1)

    # nn.HuberLoss expects (input, target).
    loss = loss_fn(prediction_q, target_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    

In [None]:
class Agent():
    """Epsilon-greedy agent that collects experience in `env` and trains `dqn`.

    Relies on notebook globals: env, dqn, target, action_sz, decay, n_transitions,
    transition, store and update.

    Args:
        target_update_frequency: number of environment steps between copies of the
            online network's weights into `target`.
        eps: initial exploration rate.
        eps_min: floor to which epsilon is linearly decayed (by `decay` per step).
    """
    def __init__(self, target_update_frequency=10000, eps=1, eps_min=0.1):
        self.eps = eps
        self.eps_min = eps_min
        self.target_update_frequency = target_update_frequency
        self.target_update_counter = 0   # env steps since the last target sync
        self.total_rewards = 0.0         # reward summed over the current logging window
        self.total_transitions = 0       # env steps taken so far
        self.episodes = 0                # episodes completed so far
        self.episode_reward = 0.0        # reward of the current / most recent episode

    def select_action(self, x, eps):
        """Return a uniformly random action with probability `eps`, else the greedy one.

        `x` is a batched observation as returned by the vectorized env.
        """
        if np.random.random() < eps:
            return int(np.random.randint(action_sz))
        # Inference only: don't build an autograd graph for action selection.
        with torch.no_grad():
            q = dqn(torch.as_tensor(x, dtype=torch.float32).cuda())
        return q.argmax().item()

    def run_episode(self, render):
        """Play one episode, storing every transition and training after each step.

        Args:
            render: if true, call `env.render()` after every step.
        """
        x = env.reset()
        self.episode_reward = 0.0
        transition_count = 0
        done = False

        # `done` is presumably a length-1 array (one vectorized env); `not done` works on it.
        while not done:
            # Linear epsilon decay, clamped at eps_min.
            self.eps = max(self.eps_min, self.eps - decay)

            action = self.select_action(x, self.eps)
            x_new, reward, done, _ = env.step([action])
            if render:
                env.render()

            # Record (s', r, s, a, done) BEFORE advancing `x`, so the stored state and
            # next state are distinct. On a terminal step the vectorized env presumably
            # auto-resets, making x_new the next episode's first observation; its value
            # is masked out of the TD target by `done` in update().
            store(transition(x_new, float(reward.item()), x, action, done))
            update()

            x = x_new
            transition_count += 1
            self.episode_reward += reward.item()

            # Sync the target network on a fixed step schedule, independent of episode length.
            self.target_update_counter += 1
            if self.target_update_counter >= self.target_update_frequency:
                self.target_update_counter = 0
                target.load_state_dict(dqn.state_dict())
                print('Target network updated')

        self.episodes += 1
        self.total_rewards += self.episode_reward
        self.total_transitions += transition_count

    def train(self, episodes_per_log=4):
        """Run episodes until more than `n_transitions` env steps have been taken.

        Progress (total steps and mean episode reward) is printed every
        `episodes_per_log` episodes.
        """
        while self.total_transitions <= n_transitions:
            for _ in range(episodes_per_log):
                self.run_episode(False)

            print('Total Transitions', self.total_transitions)
            print('Avg. Reward per Episode', self.total_rewards / episodes_per_log)

            print('\n --------------------------------------------------------------')

            self.total_rewards = 0.0

            

       

In [None]:
# Fresh agent with the default settings (epsilon starts at 1, target sync every 10000 steps).
agent = Agent()

In [None]:
# Main training loop (long-running): plays episodes in groups until more than
# `n_transitions` environment steps have been taken.
agent.train()

In [None]:
# Debug: shape of a reset observation as a GPU tensor (torch.tensor keeps the numpy dtype).
obs = env.reset()
torch.tensor(obs).cuda().shape

In [None]:
# Debug: raw Q-values of the online network for that observation.
dqn(torch.FloatTensor(obs).cuda())

In [None]:
# Watch the trained greedy policy (no exploration) for a fixed number of steps.
# The vectorized env presumably auto-resets on episode end, so `dones` is not checked.
n_eval_steps = 10000
obs = env.reset()
# Pure inference: disable autograd for the whole rollout.
with torch.no_grad():
    for _ in range(n_eval_steps):
        action = dqn.predict(torch.FloatTensor(obs).cuda())
        obs, rewards, dones, info = env.step(action)
        env.render()