In [83]:
import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [84]:
# The IPython display helper is only needed (and only imported) when
# matplotlib reports an inline backend; plot() uses it to clear cell output.
is_ipython = 'inline' in matplotlib.get_backend()
print(is_ipython)
if is_ipython:
    from IPython import display

True


In [85]:
class DQN(nn.Module):
    """Convolutional Q-network mapping an RGB screen tensor to one Q-value per action.

    Architecture: two 5x5 convolutions (3 -> 6 -> 12 channels, stride 1, no
    padding) followed by two 64-unit fully connected layers and a linear
    output layer.

    Args:
        img_height: Height in pixels of the input images.
        img_width: Width in pixels of the input images.
        num_actions: Number of Q-values produced by the output layer.
            Defaults to 2, the value that was previously hard-coded.

    Raises:
        ValueError: If the image is too small to survive both convolutions.
    """

    _KERNEL_SIZE = 5
    _NUM_CONV_LAYERS = 2

    def __init__(self,img_height, img_width, num_actions = 2):
        super().__init__()
        # Each unpadded, stride-1 convolution trims (kernel_size - 1) pixels
        # from every spatial dimension; two 5x5 layers therefore trim 8.
        conv_height = self._conv_output_size(img_height)
        conv_width = self._conv_output_size(img_width)
        if conv_height <= 0 or conv_width <= 0:
            # Without this check two negative sizes would multiply into a
            # positive in_features and silently build a mis-sized layer.
            raise ValueError(
                "Input of size {}x{} is too small for {} convolutions with kernel size {}".format(
                    img_height, img_width, self._NUM_CONV_LAYERS, self._KERNEL_SIZE))

        # TODO: consider adding batch normalisation after the convolution
        # layers; it may help convergence on higher-resolution images.
        self.c2d1 = nn.Conv2d(in_channels = 3, out_channels = 6, kernel_size = self._KERNEL_SIZE)
        self.c2d2 = nn.Conv2d(in_channels = 6, out_channels = 12, kernel_size = self._KERNEL_SIZE)
        self.fc1 = nn.Linear(in_features = conv_height*conv_width*12, out_features = 64)
        self.fc2 = nn.Linear(in_features = 64, out_features = 64)
        self.out = nn.Linear(in_features = 64, out_features = num_actions)

    @classmethod
    def _conv_output_size(cls, size):
        """Return the spatial size left after all unpadded, stride-1 convolutions."""
        return size - cls._NUM_CONV_LAYERS*(cls._KERNEL_SIZE - 1)

    def forward(self, t):
        """Compute Q-values for a batch of images shaped (batch, 3, img_height, img_width)."""
        t = F.relu(self.c2d1(t))
        t = F.relu(self.c2d2(t))
        # Keep the batch dimension, flatten channels and spatial dimensions.
        t = t.flatten(start_dim = 1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t

In [104]:
# Sanity check: build a throwaway network for 100x100 inputs and print its
# layer summary (fc1 should report (100-8)*(100-8)*12 input features).
a = DQN(100,100)
print(a)

DQN(
  (c2d1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (c2d2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=101568, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (out): Linear(in_features=64, out_features=2, bias=True)
)


In [86]:
# One transition stored in replay memory. Field order matters: the training
# loop constructs it positionally as (state, action, next_state, reward).
Experience = namedtuple('Experience',('state','action','next_state','reward'))

In [87]:
# Scratch example with placeholder integers to show how the fields are accessed.
e = Experience(2,3,1,4)

In [88]:
e

Experience(state=2, action=3, next_state=1, reward=4)

In [89]:
e.state

2

In [90]:
class ReplayMemory():
    """Bounded experience store that overwrites existing slots cyclically once full."""

    def __init__(self,capacity):
        self.capacity = capacity
        self.memory = []
        # Total number of pushes so far; selects the slot to overwrite when full.
        self.push_count = 0

    def push(self,experience):
        """Store an experience, replacing the slot at push_count % capacity when full."""
        if len(self.memory) >= self.capacity:
            slot = self.push_count % self.capacity
            self.memory[slot] = experience
        else:
            self.memory.append(experience)
        self.push_count += 1

    def sample(self, batch_size):
        """Return batch_size distinct stored experiences chosen at random."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least batch_size experiences are stored."""
        return batch_size <= len(self.memory)

In [91]:
class EpsilonGreedyStrategy():
    """Exploration rate that decays exponentially from `start` towards `end`."""

    def __init__(self,start,end,decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        """Return end + (start - end) * exp(-current_step * decay)."""
        span = self.start - self.end
        decay_factor = math.exp(-1 * current_step * self.decay)
        return self.end + span * decay_factor

In [92]:
class Agent():
    """Epsilon-greedy action selector.

    Args:
        strategy: Object exposing get_exploration_rate(current_step).
        num_actions: Number of discrete actions to choose between.
        device: torch device on which the returned action tensors are placed.
    """

    def __init__(self, strategy, num_actions, device):
        self.current_step = 0
        self.device = device
        self.strategy = strategy
        self.num_actions = num_actions

    def select_action(self, state, policy_net):
        """Pick an action for `state`, exploring with the current epsilon.

        Uses the strategy and device stored on this instance. The previous
        code read the notebook-level globals `strategy` and `device`, so it
        only worked when those happened to be defined and in sync.

        Returns:
            A 1-element tensor holding the chosen action index when exploring,
            or the argmax over dim 1 of policy_net(state) when exploiting.
        """
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        if rate > random.random():
            # explore: uniformly random action
            action = random.randrange(self.num_actions)
            return torch.tensor([action]).to(self.device)
        else:
            # exploit: greedy action; no gradients are needed for acting
            with torch.no_grad():
                return policy_net(state).argmax(dim = 1).to(self.device)

In [93]:
class CartPoleEnvManager():
    """Wraps the gym CartPole-v0 environment and exposes screen-based states.

    A state is the difference between two consecutive processed screens; the
    first state of an episode and the state after termination are all-zero
    ("black") screens.

    Args:
        device: torch device on which returned tensors are placed.
    """

    def __init__(self, device):
        self.device = device
        self.env = gym.make('CartPole-v0').unwrapped
        self.env.reset()
        # Last processed screen; None means no screen has been read this episode.
        self.current_screen = None
        # Set from the environment's done flag on every take_action call.
        self.done = False

    def reset(self):
        """Start a new episode and clear all per-episode bookkeeping."""
        self.env.reset()
        self.current_screen = None
        # Clear the terminal flag as well; otherwise it stays True from the
        # previous episode until the next step and get_state() would keep
        # returning black screens.
        self.done = False

    def close(self):
        """Close the underlying environment."""
        self.env.close()

    def render(self, mode = 'human'):
        """Render the environment; `mode` is forwarded positionally to env.render."""
        return self.env.render(mode)

    def num_actions_available(self):
        """Return the size of the environment's discrete action space."""
        return self.env.action_space.n

    def take_action(self,action):
        """Step the environment with `action` (a 1-element tensor).

        Updates self.done and returns the reward as a 1-element tensor on
        self.device.
        """
        # NOTE(review): unpacks the 4-tuple returned by env.step; newer
        # gym/gymnasium releases return 5 values - confirm the installed version.
        _observation, reward, self.done, _info = self.env.step(action.item())
        return torch.tensor([reward], device = self.device)

    def just_starting(self):
        """Return True if no screen has been captured since the last reset."""
        return self.current_screen is None

    def get_state(self):
        """Return the current state tensor.

        At the start of an episode or once the episode is done this is an
        all-zero tensor shaped like a processed screen; otherwise it is the
        difference between the new processed screen and the previous one.
        """
        if self.just_starting() or self.done:
            self.current_screen = self.get_processed_screen()
            black_screen = torch.zeros_like(self.current_screen)
            return black_screen
        else:
            previous_screen = self.current_screen
            latest_screen = self.get_processed_screen()
            self.current_screen = latest_screen
            return latest_screen - previous_screen

    def get_screen_height(self):
        """Return the height (dim 2) of a processed screen."""
        screen = self.get_processed_screen()
        return screen.shape[2]

    def get_screen_width(self):
        """Return the width (dim 3) of a processed screen."""
        screen = self.get_processed_screen()
        return screen.shape[3]

    def get_processed_screen(self):
        """Render an RGB frame, crop it, and convert it to a batched tensor."""
        screen = self.render('rgb_array').transpose((2, 0, 1)) # PyTorch expects CHW
        screen = self.crop_screen(screen)
        return self.transform_screen_data(screen)

    def crop_screen(self, screen):
        """Keep the rows between 40% and 80% of the height of a CHW array."""
        screen_height = screen.shape[1]
        top = int(screen_height * 0.4)
        bottom = int(screen_height * 0.8)
        screen = screen[:,top:bottom,:]
        return screen

    def transform_screen_data(self, screen):
        """Scale a CHW array to [0, 1], resize to 40x90, add a batch dimension."""
        # Convert to float, rescale, convert to tensor
        screen = np.ascontiguousarray(screen, dtype=np.float32) / 255
        screen = torch.from_numpy(screen)

        # Use torchvision package to compose image transforms
        resize = T.Compose([
            T.ToPILImage()
            ,T.Resize((40,90))
            ,T.ToTensor()
        ])

        return resize(screen).unsqueeze(0).to(self.device)

In [94]:
# Report whether a CUDA device is visible to PyTorch on this machine.
print(torch.cuda.is_available())

False


In [95]:
# Disabled demo (kept as a string literal so it does not execute): renders and
# shows one unprocessed RGB frame from the environment.
'''device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
em = CartPoleEnvManager(device)
em.reset()
screen = em.render('rgb_array')

plt.figure()
plt.imshow(screen)
plt.title('Non-processed screen example')
plt.show()
em.close()'''

'device = torch.device("cuda" if torch.cuda.is_available() else "cpu")\nem = CartPoleEnvManager(device)\nem.reset()\nscreen = em.render(\'rgb_array\')\n\nplt.figure()\nplt.imshow(screen)\nplt.title(\'Non-processed screen example\')\nplt.show()\nem.close()'

In [96]:
# Disabled demo (kept as a string literal so it does not execute): shows one
# cropped and resized frame as produced by get_processed_screen().
'''screen = em.get_processed_screen()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')
plt.title('Processed screen example')
plt.show()
em.close()'''

"screen = em.get_processed_screen()\n\nplt.figure()\nplt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\nplt.title('Processed screen example')\nplt.show()\nem.close()"

In [97]:
# Disabled demo (kept as a string literal so it does not execute): shows the
# state returned by get_state() at the start of an episode.
'''screen = em.get_state()
    
plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')
plt.title('Starting state example')
plt.show()
em.close()'''

"screen = em.get_state()\n    \nplt.figure()\nplt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\nplt.title('Starting state example')\nplt.show()\nem.close()"

In [98]:
# Disabled demo (kept as a string literal so it does not execute): takes five
# steps with action 1 and shows the resulting difference-frame state.
'''for i in range(5):
    em.take_action(torch.tensor([1]))
screen = em.get_state()

plt.figure()
plt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')
plt.title('Non starting state example')
plt.show()
em.close()'''

"for i in range(5):\n    em.take_action(torch.tensor([1]))\nscreen = em.get_state()\n\nplt.figure()\nplt.imshow(screen.squeeze(0).permute(1, 2, 0).cpu(), interpolation='none')\nplt.title('Non starting state example')\nplt.show()\nem.close()"

In [99]:
class QValues():
    """Static helpers computing the Q-values used in the DQN loss."""

    # Retained for backward compatibility. get_next allocates its result on
    # the device of its input rather than on this attribute.
    device = torch.device("cpu")

    @staticmethod
    def get_current(policy_net, states, actions):
        """Return the predicted Q-value of each taken action.

        Args:
            policy_net: Callable mapping `states` to a (batch, num_actions) tensor.
            states: Batch of states fed to policy_net.
            actions: 1-D integer tensor with one action index per state.

        Returns:
            Tensor of shape (batch, 1).
        """
        return policy_net(states).gather(dim = 1, index = actions.unsqueeze(-1))

    @staticmethod
    def get_next(target_net, next_states):
        """Return the maximum target-network Q-value of each next state.

        Terminal next states are encoded as all-zero screens and receive the
        value 0. Returns a 1-D tensor of length batch on next_states.device.
        """
        # A state is terminal only if every pixel is exactly zero. Testing
        # max == 0 instead would also flag difference frames whose non-zero
        # pixels are all negative.
        final_state_locations = next_states.flatten(start_dim = 1).eq(0).all(dim = 1)
        non_final_state_locations = ~final_state_locations
        batch_size = next_states.shape[0]
        # Allocate on the input's device so the masked assignment below never
        # mixes devices.
        values = torch.zeros(batch_size, device = next_states.device)
        # Skip the forward pass when every next state is terminal, so the
        # network is never called on an empty batch.
        if non_final_state_locations.any():
            # Boolean-mask indexing selects only the non-terminal states.
            non_final_states = next_states[non_final_state_locations]
            values[non_final_state_locations] = target_net(non_final_states).max(dim = 1)[0].detach()
        return values

In [100]:
def plot(values, moving_avg_period):
    """Redraw the training curve and its moving average, then print the latest average.

    Args:
        values: Sequence of per-episode durations recorded so far.
        moving_avg_period: Window length passed to get_moving_average.
    """
    plt.figure(2)
    plt.clf()
    plt.title('Training.....')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(values)

    smoothed = get_moving_average(moving_avg_period, values)
    plt.plot(smoothed)
    # Short pause so the figure has a chance to refresh.
    plt.pause(0.001)

    print("Episode", len(values),"\n",moving_avg_period, "episode moving avg: ",smoothed[-1])
    if is_ipython:
        # wait=True defers clearing until new output replaces the old one.
        display.clear_output(wait = True)

def get_moving_average(period, values):
    """Return the trailing `period`-element moving average of `values` as a numpy array.

    The result has the same length as `values`. Positions that do not yet
    have a full window are 0, and the whole result is zeros when fewer than
    `period` values exist.
    """
    series = torch.tensor(values, dtype = torch.float)
    if len(series) < period:
        return torch.zeros(len(series)).numpy()

    # Sliding windows of length `period`, one per valid end position.
    windows = series.unfold(dimension = 0, size = period, step = 1)
    window_means = windows.mean(dim = 1).flatten(start_dim = 0)
    # Left-pad with zeros so the output lines up with the input indices.
    padded = torch.cat((torch.zeros(period-1), window_means))
    return padded.numpy()

In [101]:
def extract_tensors(experiences):
    """Turn a list of Experience tuples into batched tensors.

    Returns:
        A tuple (states, actions, rewards, next_states), each formed by
        concatenating the corresponding field of every experience.
    """
    # Transpose the list of Experiences into a single Experience of tuples.
    batch = Experience(*zip(*experiences))

    states = torch.cat(batch.state)
    actions = torch.cat(batch.action)
    rewards = torch.cat(batch.reward)
    next_states = torch.cat(batch.next_state)

    return (states, actions, rewards, next_states)

In [102]:
# Hyperparameters for the training cell.
batch_size = 256        # experiences sampled from replay memory per optimisation step
gamma = 0.999           # discount applied to next-state Q-values in the target
eps_start = 1           # initial exploration rate passed to EpsilonGreedyStrategy
eps_end = 0.01          # asymptotic exploration rate
eps_decay = 0.001       # exponential decay constant per agent step
target_update = 10      # episodes between copies of policy_net weights into target_net
memory_size = 100000    # replay memory capacity
lr = 0.001              # Adam learning rate
num_episodes = 500      # episodes run by each training loop

In [105]:
# ---- Training setup ----
device = torch.device("cpu")
em = CartPoleEnvManager(device)
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
agent = Agent(strategy, em.num_actions_available(), device)
memory = ReplayMemory(memory_size)
policy_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
target_net = DQN(em.get_screen_height(), em.get_screen_width()).to(device)
# The target network starts as a copy of the policy network and is only used
# for inference.
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(params = policy_net.parameters(), lr = lr)
episode_durations = []


def optimize_policy():
    """Run one gradient step on policy_net from a random replay batch.

    Does nothing until replay memory holds at least batch_size experiences.
    """
    if not memory.can_provide_sample(batch_size):
        return
    experiences = memory.sample(batch_size)
    states, actions, rewards, next_states = extract_tensors(experiences)

    current_q_values = QValues.get_current(policy_net, states, actions)
    next_q_values = QValues.get_next(target_net, next_states)
    # Bellman target: reward plus discounted best next-state value.
    target_q_values = (next_q_values * gamma) + rewards

    loss = F.mse_loss(current_q_values, target_q_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def run_training(n_episodes):
    """Train for n_episodes, continuing from the current networks and memory."""
    for episode in range(n_episodes):
        em.reset()
        state = em.get_state()

        for timestep in count():
            action = agent.select_action(state, policy_net)
            reward = em.take_action(action)
            next_state = em.get_state()
            memory.push(Experience(state, action, next_state, reward))
            state = next_state

            optimize_policy()

            if em.done:
                # timestep is zero-based, so the episode lasted timestep + 1 steps.
                episode_durations.append(timestep + 1)
                plot(episode_durations, 100)
                break

        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())


# ---- Training ----
# The original cell contained the same loop twice; both phases are kept so the
# total amount of training is unchanged. try/finally guarantees the
# environment window is closed even if training raises.
try:
    run_training(num_episodes)

    # Second phase: continue training with the same networks, replay memory
    # and exploration schedule. To train this phase on a GPU instead, change
    # `device` here (the environment manager's device must then match too).
    device = torch.device("cpu")
    policy_net = policy_net.to(device)
    target_net = target_net.to(device)
    agent.device = device

    run_training(num_episodes)
finally:
    em.close()

AttributeError: 'Tensor' object has no attribute 'c2d1'