In [None]:
from google.colab import drive

# Mount Google Drive; the expert trace and all model checkpoints used by this
# notebook live under /content/drive/MyDrive/...
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

## Packages

In [None]:
!pip install ALE
!pip install gym[atari,accept-rom-license]==0.21.0

In [None]:
from copy import deepcopy
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
from matplotlib import pyplot as plt
import random
from collections import deque
import pickle

# Seed every RNG the notebook draws from (`random`, NumPy, PyTorch) so runs are
# repeatable. NOTE(review): the gym environment and CUDA kernels are not seeded here.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Policy: DQN agent with expert pre-training

In [None]:
class Policy(nn.Module):
    """DQN agent with prioritized replay, a target network and expert pre-training.

    ``expert_train`` fits the online Q-network on the transitions already in
    ``buffer`` (in this notebook: an expert trace). ``train`` then collects new
    experience from ``env`` with an epsilon-greedy policy.

    Args:
        buffer: replay buffer exposing ``add``, ``sample``, ``update_priorities``
            and ``get_memory_size``.
        env: Gym environment using the 4-tuple ``step`` API and an
            observation-only ``reset``.
        device: stored as ``self.device``. NOTE(review): the networks are not
            moved to this device anywhere in this class.
    """

    # Directory every checkpoint method below reads from / writes to.
    checkpoint_dir = "/content/drive/MyDrive/Colab Notebooks/RL/Demonstration"

    def __init__(self, buffer, env, device=torch.device('cpu')):
        super(Policy, self).__init__()
        self.device = device
        self.env = env
        learning_rate = 0.001
        self.epsilon = 0.5            # initial exploration rate; decayed per episode in train()
        self.batch_size = 64
        self.network = Q_network(self.env, learning_rate)
        self.target_network = deepcopy(self.network)
        self.buffer = buffer
        self.window = 50              # number of most recent entries averaged in the logs
        self.reward_threshold = 800   # train() stops once the windowed mean reward reaches this
        self.training_rewards = []
        self.training_loss = []
        self.update_loss = []
        self.mean_training_rewards = []
        self.sync_eps = []
        self.rewards = 0
        self.step_count = 0

    def forward(self, x):
        """Identity pass-through; Q-values come from ``self.network``."""
        return x

    def act(self, state):
        """Return the greedy action for a single raw observation (no autograd graph)."""
        with torch.no_grad():
            return self.network.greedy_action(torch.FloatTensor(state))

    def update(self):
        """Take one optimizer step on a prioritized minibatch and refresh its priorities."""
        self.network.optimizer.zero_grad()
        batch, weights, tree_idxs = self.buffer.sample(batch_size=self.batch_size)
        loss, td_error = self.calculate_loss(batch, weights=weights)
        # td_error holds one entry per sampled transition, aligned with tree_idxs.
        self.buffer.update_priorities(tree_idxs, td_error)
        loss.backward()
        self.network.optimizer.step()

        self.update_loss.append(loss.item())

    def take_step(self, mode='exploit'):
        """Advance the environment one step from ``self.s_0`` and store the transition.

        Args:
            mode: ``'explore'`` samples a random action; anything else acts greedily.

        Returns:
            The environment's ``done`` flag. When it is true the env is reset
            and ``self.s_0`` holds the new initial observation.
        """
        if mode == 'explore':
            action = self.env.action_space.sample()
        else:
            with torch.no_grad():
                action = self.network.greedy_action(torch.FloatTensor(self.s_0))

        s_1, r, done, _ = self.env.step(action)

        self.buffer.add(self.s_0, action, r, done, s_1)
        self.rewards += r
        self.s_0 = s_1.copy()

        self.step_count += 1
        if done:
            self.s_0 = self.env.reset()

        return done

    def _flush_update_loss(self):
        """Append the mean of the pending per-update losses (0 if none) to ``training_loss``."""
        self.training_loss.append(np.mean(self.update_loss) if self.update_loss else 0)
        self.update_loss = []

    def expert_train(self):
        """Pre-train on the transitions already in the buffer (e.g. an expert trace).

        Runs one ``update`` per ``batch_size`` stored transitions (plus one).
        The target network is synced every 2 updates and a checkpoint is
        written every 10. ``pre_model.pt`` is saved at the end.
        """
        self.gamma = 0.99
        network_sync_frequency = 2
        batch_size_sum = 0
        self.loss_function = nn.MSELoss()  # not used by calculate_loss; kept as an attribute

        ep = 0

        print("############ Pre-training with expert trace started!")
        while batch_size_sum <= self.buffer.get_memory_size():
            self.update()

            if ep % network_sync_frequency == 0:
                self.target_network.load_state_dict(self.network.state_dict())
                self.sync_eps.append(ep)

            self._flush_update_loss()
            mean_loss = np.mean(self.training_loss[-self.window:])
            print("\rEpisode {:d} || mean loss = {:.2f}\t\t".format(ep, mean_loss))

            if ep % 10 == 0:
                print("Checkpoint!")
                self.save_check("pre" + str(ep))

            batch_size_sum += self.batch_size
            ep += 1

        self.pretrain_save()

    def train(self, max_episodes=5):
        """Epsilon-greedy online training with periodic updates and target syncs.

        Every 10 environment steps one ``update`` runs, and every 200 steps the
        online weights are copied into the target network. After each episode,
        epsilon is multiplied by 0.7 while it is >= 0.05. Training stops after
        ``max_episodes`` episodes or once the windowed mean reward reaches
        ``self.reward_threshold``. ``model.pt`` is saved and the reward curve
        is plotted at the end.

        Args:
            max_episodes: episode budget (default 5, the previously hard-coded value).
        """
        self.gamma = 0.99
        network_update_frequency = 10
        network_sync_frequency = 200
        self.loss_function = nn.MSELoss()  # not used by calculate_loss; kept as an attribute
        self.s_0 = self.env.reset()

        # Take batch_size random steps first so sample() has enough transitions
        # even when the buffer starts empty.
        for _ in range(self.batch_size):
            self.take_step(mode='explore')

        ep = 0
        training = True
        self.populate = False
        print("############ Training started")
        while training:
            self.s_0 = self.env.reset()

            self.rewards = 0
            done = False
            while not done:
                mode = 'explore' if np.random.random() < self.epsilon else 'exploit'
                done = self.take_step(mode=mode)

                if self.step_count % network_update_frequency == 0:
                    self.update()

                if self.step_count % network_sync_frequency == 0:
                    self.target_network.load_state_dict(self.network.state_dict())
                    self.sync_eps.append(ep)

            # End-of-episode bookkeeping.
            ep += 1

            if self.epsilon >= 0.05:
                self.epsilon = self.epsilon * 0.7

            self.training_rewards.append(self.rewards)
            self._flush_update_loss()
            mean_rewards = np.mean(self.training_rewards[-self.window:])
            mean_loss = np.mean(self.training_loss[-self.window:])
            self.mean_training_rewards.append(mean_rewards)
            print(
                "\rEpisode {:d} Mean Rewards {:.2f}  Episode reward = {:.2f}   mean loss = {:.2f}\t\t".format(
                    ep, mean_rewards, self.rewards, mean_loss))

            if ep >= max_episodes:
                training = False
                print('\nEpisode limit reached.')
                break
            if mean_rewards >= self.reward_threshold:
                training = False
                print('\nEnvironment solved in {} episodes!'.format(ep))

            if ep % 10 == 0:
                print("Checkpoint!")
                self.save_check(ep)

        self.save()
        self.plot_training_rewards()

    def plot_training_rewards(self):
        """Plot the per-episode windowed mean reward and save it to ``mean_training_rewards.png``."""
        fig, ax = plt.subplots()
        ax.plot(self.mean_training_rewards)
        ax.set_title('Mean training rewards')
        ax.set_ylabel('Reward')
        ax.set_xlabel('Episodes')
        # Save before show(): inline backends close the figure on show(), so
        # saving afterwards writes an empty image.
        fig.savefig('mean_training_rewards.png')
        plt.show()
        plt.close(fig)

    @staticmethod
    def _as_float_batch(items):
        """Return ``items`` (a tensor, or a sequence of arrays/tensors) as one float32 tensor."""
        if isinstance(items, torch.Tensor):
            return items.float()
        return torch.stack([torch.as_tensor(x, dtype=torch.float32) for x in items])

    def calculate_loss(self, batch, weights=None):
        """Importance-weighted mean squared TD error over a sampled batch.

        For every sample ``i``:
        ``target_i = r_i + (1 - done_i) * gamma * max_a Q_target(s'_i, a)``
        and the loss is ``mean_i(w_i * (Q(s_i, a_i) - target_i) ** 2)``.

        Args:
            batch: ``(states, actions, rewards, next_states, dones)`` as returned
                by the buffer's ``sample``. ``actions`` is 2-D and the action
                index is read from its last column.
            weights: per-sample importance-sampling weights (any shape holding
                one value per sample), or ``None`` for uniform weights.

        Returns:
            ``(loss, td_error)``. ``loss`` is a scalar tensor whose gradient
            reaches only the online network. ``td_error`` is a detached 1-D
            tensor of ``|Q(s, a) - target|``, one entry per sample in batch order.

        Raises:
            ValueError: if the batch is empty.
        """
        states, actions, rewards, next_states, dones = batch
        states = self._as_float_batch(states)
        next_states = self._as_float_batch(next_states)
        actions = torch.as_tensor(actions)[:, -1].long()
        rewards = torch.as_tensor(rewards, dtype=torch.float32).reshape(-1)
        dones = torch.as_tensor(dones, dtype=torch.float32).reshape(-1)

        n = len(states)
        if n == 0:
            raise ValueError("calculate_loss needs a non-empty batch")
        # Normalise weights *before* use; each sample gets its own weight.
        if weights is None:
            weights = torch.ones(n)
        else:
            weights = torch.as_tensor(weights, dtype=torch.float32).reshape(-1)

        losses, td_errors = [], []
        # Net handles one observation at a time, hence the per-sample loop.
        for i in range(n):
            q_sa = self.network.get_qvals(states[i])[actions[i]]
            # The target network only provides a fixed regression target.
            with torch.no_grad():
                next_q = self.target_network.get_qvals(next_states[i]).max(dim=-1).values
                target = rewards[i] + (1.0 - dones[i]) * self.gamma * next_q
            diff = q_sa - target
            losses.append(weights[i] * diff.pow(2))
            td_errors.append(diff.detach().abs())

        loss = torch.stack(losses).mean()
        td_error = torch.stack(td_errors).reshape(-1)
        return loss, td_error

    def save_check(self, ep):
        """Save a checkpoint as ``model_<ep>.pt`` in ``checkpoint_dir``."""
        torch.save(self.state_dict(), f"{self.checkpoint_dir}/model_{ep}.pt")

    def save(self):
        """Save the final weights as ``model.pt`` in ``checkpoint_dir``."""
        torch.save(self.state_dict(), f"{self.checkpoint_dir}/model.pt")

    def pretrain_save(self):
        """Save the pre-trained weights as ``pre_model.pt`` in ``checkpoint_dir``."""
        torch.save(self.state_dict(), f"{self.checkpoint_dir}/pre_model.pt")

    def load(self):
        """Load ``model.pt`` from ``checkpoint_dir``.

        ``map_location`` lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
        """
        state = torch.load(f"{self.checkpoint_dir}/model.pt", map_location=self.device)
        self.load_state_dict(state)

    def to(self, device):
        """Move the module and record the new device on ``self.device``."""
        ret = super().to(device)
        ret.device = device
        return ret

## DQN components: Q-network, sum tree and prioritized replay buffer

In [None]:
def from_tuple_to_tensor(tuple_of_np):
    """Stack a sequence of equally-shaped arrays/tensors into one float32 tensor.

    Args:
        tuple_of_np: non-empty sequence (tuple, list, or a tensor iterated along
            dim 0) of NumPy arrays or tensors that all share one shape, of any rank
            and numeric dtype.

    Returns:
        A new float32 tensor of shape ``(len(tuple_of_np), *item_shape)``.

    Raises:
        ValueError: if ``tuple_of_np`` is empty.
    """
    # as_tensor with an explicit dtype accepts uint8/float64/... inputs, unlike
    # torch.FloatTensor(tensor), which rejects tensors of other dtypes.
    items = [torch.as_tensor(x, dtype=torch.float32) for x in tuple_of_np]
    if not items:
        raise ValueError("from_tuple_to_tensor needs at least one item")
    # torch.stack allocates a new tensor, so the result never aliases the inputs.
    return torch.stack(items)

In [None]:
class Net(nn.Module):
    """Small convolutional action-value network.

    Two unpadded convolutions (kernels 7 and 5), a per-channel global max-pool
    over the spatial dims, and two linear layers. Accepts one image
    ``(n_frames, H, W)`` or a batch ``(B, n_frames, H, W)``, and returns
    ``(n_actions,)`` or ``(B, n_actions)`` respectively.

    Args:
        n_frames: input channels.
        n_actions: number of output action-values.
        hidden_size: channels / hidden units in every intermediate layer.
        bias: whether the conv and linear layers use a bias term.
    """

    def __init__(self, n_frames, n_actions, hidden_size=32, bias=True):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_frames = n_frames
        self.conv1 = nn.Conv2d(n_frames, hidden_size, 7, bias=bias)
        self.conv2 = nn.Conv2d(hidden_size, hidden_size, 5, bias=bias)
        self.fc1 = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.fc2 = nn.Linear(hidden_size, n_actions, bias=bias)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Global max over H and W: (..., C, H, W) -> (..., C). Works with or
        # without a leading batch dim, unlike reshape(hidden_size, -1).
        x = torch.amax(x, dim=(-2, -1))
        x = F.relu(self.fc1(x))
        # Raw outputs, no softmax: the Q-values are regressed onto
        # r + gamma * max Q(s', .), which is not confined to [0, 1]. Removing the
        # softmax leaves the argmax (greedy action) unchanged.
        return self.fc2(x)

class Q_network(nn.Module):
    """Q-function wrapper: frame preprocessing, a single-channel ``Net`` and an Adam optimizer.

    Args:
        env: provides ``action_space.n`` (the number of Q-value outputs).
        learning_rate: Adam learning rate.
        crop_rows: only the first ``crop_rows`` image rows are kept.
        frame_size: side length of the square frame fed to ``Net``.
    """

    def __init__(self, env, learning_rate=1e-4, crop_rows=83, frame_size=64):
        super(Q_network, self).__init__()
        self.crop_rows = crop_rows
        self.network = Net(1, env.action_space.n)
        self.optimizer = torch.optim.Adam(self.network.parameters(), lr=learning_rate)
        self.gs = transforms.Grayscale()
        self.rs = transforms.Resize((frame_size, frame_size))

    def forward(self, state):
        """Q-values for one raw frame (same as ``get_qvals``)."""
        return self.get_qvals(state)

    def greedy_action(self, state):
        """Index of the largest Q-value for ``state``, as a Python int.

        Runs under ``no_grad``: choosing an action never needs gradients, and
        skipping the graph saves memory and time on every environment step.
        """
        with torch.no_grad():
            qvals = self.get_qvals(state)
        greedy_a = torch.max(qvals, dim=-1)[1].item()
        return greedy_a

    def get_qvals(self, state):
        """Preprocess a raw frame and return the network's Q-values."""
        state = self.preproc_state(state)
        return self.network(state)

    def preproc_state(self, state):
        """Crop, grayscale, resize and scale one frame.

        Expects an ``(H, W, C)`` frame as a tensor or NumPy array. The frame is
        cropped to its first ``crop_rows`` rows and moved to channels-first.
        It is then converted to grayscale, resized, and divided by 255. Tensor
        ops are used throughout, so GPU tensors and tensors that require grad
        work too (no ``.numpy()`` round-trip).
        """
        state = torch.as_tensor(state)
        state = state[: self.crop_rows].permute(2, 0, 1)  # HWC -> CHW
        state = self.gs(state)
        state = self.rs(state)
        return state / 255

In [None]:
class SumTree:
    """Array-backed binary sum tree for proportional prioritized sampling.

    ``nodes`` has ``2 * size - 1`` entries. Leaves ``nodes[size - 1:]`` hold the
    per-item priorities and every internal node holds the sum of its two
    children, so ``total`` is the root. ``data`` is a ring buffer of payloads
    aligned with the leaves; ``count`` is the next slot to write.
    """

    def __init__(self, size):
        self.nodes = [0] * (2 * size - 1)
        self.data = [0] * size

        self.size = size
        self.count = 0
        self.real_size = 0

    @property
    def total(self):
        """Sum of all leaf priorities."""
        return self.nodes[0]

    def update(self, data_idx, value):
        """Set the priority of slot ``data_idx`` and propagate the change to the root."""
        idx = data_idx + self.size - 1
        change = value - self.nodes[idx]

        self.nodes[idx] = value

        parent = (idx - 1) // 2
        while parent >= 0:
            self.nodes[parent] += change
            parent = (parent - 1) // 2

    def add(self, value, data):
        """Write ``data`` with priority ``value`` into the next ring-buffer slot."""
        self.data[self.count] = data

        self.update(self.count, value)

        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def get(self, cumsum):
        """Find the leaf whose cumulative-priority interval contains ``cumsum``.

        Returns:
            ``(data_idx, priority, data)`` for the selected leaf.

        While ``total > 0`` the descent never enters a subtree whose sum is 0.
        Rounding drift in the incrementally-updated internal sums (or
        ``cumsum == 0`` with an empty left subtree) therefore cannot route the
        search to an empty, zero-priority leaf.
        """
        assert cumsum <= self.total

        idx = 0
        while 2 * idx + 1 < len(self.nodes):
            left, right = 2 * idx + 1, 2 * idx + 2
            left_sum, right_sum = self.nodes[left], self.nodes[right]

            if right_sum <= 0 or (left_sum > 0 and cumsum <= left_sum):
                idx = left
            else:
                idx = right
                cumsum = cumsum - left_sum

        data_idx = idx - self.size + 1

        return data_idx, self.nodes[idx], self.data[data_idx]

    def get_size(self):
        """Number of slots written so far (capped at ``size``)."""
        return self.real_size

    def __repr__(self):
        return f"SumTree(nodes={self.nodes.__repr__()}, data={self.data.__repr__()})"

In [None]:
class Prioritized_experience_replay_buffer:
    """Proportional prioritized experience replay stored in preallocated tensors.

    Transitions live in ring-buffer tensors of length ``memory_size``, and a
    ``SumTree`` keeps one priority per slot. ``sample`` draws one transition
    from each of ``batch_size`` equal-width priority segments and returns
    importance-sampling weights normalised to a maximum of 1.

    Args:
        env: provides ``observation_space.shape`` and ``action_space.n``.
        device: stored only. NOTE(review): the storage tensors are allocated on
            torch's default device; ``device`` is not used.
        memory_size: capacity, in transitions.
        burn_in: reference size used by ``burn_in_capacity``.
        eps: added to every priority so none becomes 0.
        alpha: priority exponent (0 gives uniform sampling).
        beta: importance-sampling exponent.
    """

    def __init__(self, env, device, memory_size=50000, burn_in=10000, eps=1e-2, alpha=0, beta=0.1):
        # Public `.shape` instead of the private `_shape`; any observation rank works.
        obs_shape = tuple(env.observation_space.shape)
        self.env = env
        self.tree = SumTree(size=memory_size)
        self.device = device
        self.memory_size = memory_size
        self.burn_in = burn_in
        self.count = 0
        self.real_size = 0
        # Legacy attribute kept for compatibility; transitions are stored in the tensors below.
        self.replay_memory = deque(maxlen=memory_size)
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = eps  # priority assigned to newly added transitions
        # NOTE(review): float32 frame storage is large. For full-size RGB Atari
        # frames (presumably 210x160x3) each of state/next_state is ~20 GB at
        # the default memory_size.
        self.state = torch.empty((memory_size, *obs_shape), dtype=torch.float)
        # Each row stores the scalar action broadcast across all n columns.
        self.action = torch.empty(memory_size, env.action_space.n, dtype=torch.float)
        self.reward = torch.empty(memory_size, dtype=torch.float)
        self.next_state = torch.empty((memory_size, *obs_shape), dtype=torch.float)
        self.done = torch.empty(memory_size, dtype=torch.int)

    def sample(self, batch_size=32):
        """Draw ``batch_size`` transitions proportionally to their priorities.

        Returns:
            ``(batch, weights, tree_idxs)``. ``batch`` is
            ``(states, actions, rewards, next_states, dones)``, ``weights`` is a
            ``(batch_size, 1)`` tensor of IS weights, and ``tree_idxs`` holds
            the slot indices to pass back to ``update_priorities``.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"
        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)
        total = self.tree.total
        segment = total / batch_size

        for i in range(batch_size):
            a, b = segment * i, segment * (i + 1)

            # segment * batch_size can round to just above total, which would
            # trip SumTree.get's assertion; clamp into range.
            cumsum = min(random.uniform(a, b), total)
            tree_idx, priority, sample_idx = self.tree.get(cumsum)
            priorities[i] = torch.tensor(priority)
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        probs = priorities / self.tree.total
        weights = ((1 / self.real_size) * (1 / probs)) ** self.beta
        weights = weights / weights.max()

        batch = (
            self.state[sample_idxs],
            self.action[sample_idxs],
            self.reward[sample_idxs],
            self.next_state[sample_idxs],
            self.done[sample_idxs]
        )
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, priorities):
        """Set the priority of each slot to ``(|td| + eps) ** alpha``.

        ``priorities`` may be a tensor or array of any shape with one value per
        index (e.g. ``(B,)`` or ``(B, 1)``). It is flattened so every index
        gets a plain float rather than a 1-element array.
        """
        if isinstance(priorities, torch.Tensor):
            priorities = priorities.detach().cpu().numpy()
        priorities = np.asarray(priorities, dtype=np.float64).reshape(-1)

        for data_idx, priority in zip(data_idxs, priorities):
            priority = float((priority + self.eps) ** self.alpha)
            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)

    def burn_in_capacity(self):
        """Stored transitions as a fraction of ``burn_in``."""
        return self.real_size / self.burn_in

    def capacity(self):
        """Stored transitions as a fraction of ``memory_size``."""
        return self.real_size / self.memory_size

    def get_memory_size(self):
        """Number of transitions currently stored."""
        return self.real_size

    def add(self, state, action, reward, done, next_state):
        """Store one transition in the next ring slot with the current max priority."""
        self.tree.add(self.max_priority, self.count)

        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        self.count = (self.count + 1) % self.memory_size
        self.real_size = min(self.memory_size, self.real_size + 1)

In [None]:
def _read_pickled_objects(path):
    """Return every object pickled back-to-back in ``path``, in file order."""
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    loaded = []
    with open(path, "rb") as handle:
        while True:
            try:
                loaded.append(pickle.load(handle))
            except EOFError:
                return loaded


objects = _read_pickled_objects(
    "/content/drive/MyDrive/Colab Notebooks/RL/Demonstration/expert_tracefinal.pkl"
)

# Only the first pickled object is used; its first five entries are read as
# states, actions, rewards, next states and dones.
# NOTE(review): `next` shadows the builtin of the same name; later cells read it.
trace = objects[0]
obs = trace[0]
act = trace[1]
rew = trace[2]
next = trace[3]
done = trace[4]

In [None]:
# Report how many entries each part of the expert trace holds.
trace_labels = ["states", "actions", "rewards", "next states", "dones"]
for k, label in enumerate(trace_labels):
    print("Num " + label + ": " + str(len(objects[0][k])))

In [None]:
ENV_ID = 'MontezumaRevenge-v4'
env = gym.make(ENV_ID, render_mode='rgb_array')
buffer = Prioritized_experience_replay_buffer(env=env, device=device)

In [None]:
# Copy every expert transition into the replay buffer (add() order: s, a, r, done, s').
for i, state_i in enumerate(obs):
    buffer.add(state_i, act[i], rew[i], done[i], next[i])

In [None]:
agent = Policy(buffer,env,device)

In [None]:
agent.expert_train()

In [None]:
agent.train()