In [None]:
import gym
import time
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import wandb
from mygame import Game2048Env
from matplotlib import pyplot as plt

class ReplayBuffer:
    """Fixed-capacity ring buffer with uniform random sampling.

    After ``max_size`` inserts, each new item overwrites the oldest slot.
    """

    def __init__(self, max_size=100000):
        self.max_size = max_size
        self.buffer = [None for _ in range(max_size)]
        self.index = 0  # slot that the next insert writes to
        self.size = 0   # number of filled slots; saturates at max_size

    def insert(self, obj):
        """Store ``obj`` in the next slot, overwriting the oldest entry when full."""
        self.buffer[self.index] = obj
        if self.size < self.max_size:
            self.size += 1
        self.index += 1
        if self.index == self.max_size:
            self.index = 0  # wrap around to the start of the ring

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items chosen uniformly at random.

        ``random.sample`` raises ``ValueError`` when ``batch_size`` exceeds the
        number of stored items.
        """
        chosen = random.sample(range(self.size), batch_size)
        return [self.buffer[slot] for slot in chosen]
    
    
class Model(nn.Module):
    """Fully connected Q-network: flat observation -> one Q-value per action.

    Args:
        obs_shape: observation shape; must be 1-D, e.g. ``(16,)``.
        num_actions: number of discrete actions (width of the output layer).
        lr: learning rate of the Adam optimizer stored as ``self.optimizer``.
        hidden_sizes: widths of the hidden ReLU layers, in order.
    """

    def __init__(self, obs_shape, num_actions, lr=1e-4, hidden_sizes=(512, 256, 128)):
        super(Model, self).__init__()
        assert len(obs_shape)==1 , "This network only works for flat observations"
        self.obs_shape = obs_shape
        self.num_actions = num_actions

        layers = []
        in_features = obs_shape[0]
        for width in hidden_sizes:
            layers += [torch.nn.Linear(in_features, width), torch.nn.ReLU()]
            in_features = width
        layers.append(torch.nn.Linear(in_features, num_actions))
        # With the default hidden_sizes this yields the same layer order (and
        # therefore the same state_dict keys net.0/2/4/6) as the original
        # hand-written Sequential, so existing checkpoints still load.
        self.net = nn.Sequential(*layers)
        self.optimizer = optim.Adam(self.net.parameters(), lr=lr)

    def forward(self, x):
        """Return the Q-values for a batch (or single) flat observation ``x``."""
        return self.net(x)

def update_target_model(model, target):
    """Overwrite ``target``'s parameters/buffers with ``model``'s and log the sync."""
    weights = model.state_dict()
    target.load_state_dict(weights)
    print('updating target model')
    
class Agent:
    """Epsilon-greedy DQN agent with an online ``model`` and a ``target`` copy.

    Hyperparameters are class attributes, so they can be overridden on an
    instance before use. ``env`` must be set to an environment exposing
    ``reset()`` and ``step(action) -> (obs, reward, done, info)`` before
    ``run_episode`` is called.
    """

    lr = 1e-4  # NOTE(review): not read by this class; Model configures its own optimizer
    gamma = 0.99  # discount applied to the bootstrapped next-state value in train_step
    epsilon_min = 0.05
    epsilon_decay_factor = 0.999999  # currently unused (epsilon_decay is linear)
    epsilon = 1.0
    min_rb_size = 20000  # buffer fill level required before _maybe_train trains
    sample_size = 750  # minibatch size per gradient step
    env_steps_before_train = 100  # gradient step every N environment steps
    tgt_model_update = 1000  # target-network sync every N environment steps
    device = 'cpu'
    env = None

    episode = 0
    num_episodes = 1000  # episodes over which epsilon decays linearly to epsilon_min
    step_num = 0

    def __init__(self):
        self.rb = ReplayBuffer()
        # Per-instance histories: mutable class attributes would be shared by
        # every Agent instance.
        self.rewards = []
        self.losses = []
        self.average_rewards = []
        self.epsilons = []

    def start_agent(self, obs_shape, num_actions):
        """Build the online and target networks and sync their weights.

        Args:
            obs_shape: flat observation shape passed to ``Model``.
            num_actions: number of discrete actions.
        """
        self.obs_shape = obs_shape
        self.num_actions = num_actions
        self.model = Model(obs_shape, num_actions).to(self.device)
        self.target = Model(obs_shape, num_actions).to(self.device)
        self.update_target_model()

    def select_action(self, observations, deterministic=False):
        """Return an action index for ``observations``.

        Greedy w.r.t. ``self.model`` with probability ``1 - epsilon`` (always
        when ``deterministic``); otherwise uniform over ``num_actions``.
        Epsilon is not modified here; ``run_episode`` schedules it per episode.
        """
        if np.random.random() > self.epsilon or deterministic:
            with torch.no_grad():  # acting needs no autograd graph
                obs = torch.as_tensor(observations, dtype=torch.float32, device=self.device)
                return self.model(obs).argmax(-1).item()
        return np.random.randint(0, self.num_actions)

    def store_transition(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple to the replay buffer."""
        self.rb.insert((state, action, reward, next_state, done))

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target.load_state_dict(self.model.state_dict())

    def train_step(self, state_transitions):
        """Run one DQN gradient step on a list of transition tuples.

        Each tuple is ``(state, action, reward, next_state, done)``. The TD
        target is ``reward + gamma * max_a' target(next_state)``, with the
        bootstrap term dropped for terminal transitions.

        Returns:
            The detached scalar MSE loss tensor.
        """
        def as_float(values):
            return torch.as_tensor(np.stack(values), dtype=torch.float32, device=self.device)

        cur_states = as_float([s[0] for s in state_transitions])
        next_states = as_float([s[3] for s in state_transitions])
        actions = torch.as_tensor([s[1] for s in state_transitions], dtype=torch.long, device=self.device)
        rewards = torch.as_tensor([s[2] for s in state_transitions], dtype=torch.float32, device=self.device)
        mask = torch.as_tensor([0.0 if s[4] else 1.0 for s in state_transitions],
                               dtype=torch.float32, device=self.device)

        with torch.no_grad():
            qvals_next = self.target(next_states).max(-1)[0]
            # All tensors are 1-D (batch,), so no accidental (B, B) broadcasting.
            targets = rewards + self.gamma * mask * qvals_next

        self.model.optimizer.zero_grad()  # must be called, otherwise grads accumulate
        q_vals = self.model(cur_states)
        q_taken = q_vals.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
        loss = F.mse_loss(q_taken, targets)

        loss.backward()
        self.model.optimizer.step()
        return loss.detach()  # drop the graph so callers can store losses cheaply

    def epsilon_decay(self):
        """Set epsilon by linear decay from 1 to ``epsilon_min`` over ``num_episodes``."""
        progress = self.episode / self.num_episodes
        self.epsilon = max(1 - (1 - self.epsilon_min) * progress, self.epsilon_min)

    def preprocess(self, state, divide=16, logarithmic = False):
        """Optionally log2-scale and flatten an observation.

        With ``logarithmic`` the state becomes ``log2(1 + state) / divide``
        flattened to 1-D; otherwise it is returned unchanged.
        """
        if logarithmic:
            state = np.log2(1 + np.asarray(state)) / divide
            return state.reshape(-1)
        return state

    def _maybe_train(self):
        """Take a gradient step and/or sync the target net on the step schedule."""
        if self.rb.size < self.min_rb_size:
            return
        if self.step_num % self.env_steps_before_train == 0:
            loss = self.train_step(self.rb.sample(self.sample_size))
            self.losses.append(loss.item())
        if self.step_num % self.tgt_model_update == 0:
            self.update_target_model()

    def run_episode(self, test = False, *, train=False):
        """Play one episode, storing every transition in the replay buffer.

        Args:
            test: act greedily (epsilon 0) and do not record the reward in
                ``self.rewards``; epsilon is restored afterwards.
            train: when true (and not ``test``), call ``_maybe_train`` after
                each environment step.

        Returns:
            The episode's total reward.
        """
        self.episode += 1
        self.epsilon_decay()
        epsilon_initial = self.epsilon
        if test:
            self.epsilon = 0

        last_observation = self.preprocess(self.env.reset())
        done = False
        rolling_reward = 0

        while not done:
            self.step_num += 1
            action = self.select_action(last_observation, deterministic=test)
            observation, reward, done, info = self.env.step(action)
            observation = self.preprocess(observation)
            rolling_reward += reward

            self.rb.insert((last_observation, action, reward, observation, done))
            last_observation = observation  # advance so the next transition starts here

            if train and not test:
                self._maybe_train()

        self.epsilon = epsilon_initial
        if not test:
            self.rewards.append(rolling_reward)
        return rolling_reward
            
