# In [None]:
from tqdm import tqdm
import gym
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any
from random import sample, random
import torch.optim as optim
import wandb
from collections import deque
import numpy as np
# import torchsnooper

@dataclass
class Sarsd:
    """A single environment transition kept in the replay buffer.

    The class name lists its fields: State, Action, Reward, next State, Done.
    """

    state: Any  # observation the action was chosen from
    action: int  # index of the discrete action that was taken
    reward: float  # reward for this step (callers may rescale it before storing)
    next_state: Any  # observation returned by the environment after the action
    done: bool  # whether this step ended the episode
        
class DQN_Agent:
    """Greedy policy that picks the action with the largest predicted Q-value."""

    def __init__(self, model):
        # model: any callable mapping a batch of observations to Q-values.
        self.model = model

    def get_actions(self, observations):
        """Return, for every observation, the index of its highest Q-value.

        The reduction is over the last dimension of the model output. For
        CartPole the batch is presumably (N, 4) observations producing (N, 2)
        Q-values (left, right) — assumption, confirm against the caller.
        """
        predicted = self.model(observations)
        best = predicted.max(dim=-1)
        return best.indices
    
class Model(nn.Module):
    """Two-layer MLP Q-network that carries its own Adam optimizer.

    Args:
        obs_shape: observation shape; must be one-dimensional (flat observations).
        num_actions: number of discrete actions, i.e. the output width.
        hidden_size: width of the single hidden layer (default 256).
        lr: learning rate of the Adam optimizer stored as ``self.opt``
            (default 1e-4).

    Raises:
        AssertionError: if ``obs_shape`` is not one-dimensional.
    """

    def __init__(self, obs_shape, num_actions, hidden_size=256, lr=0.0001):
        super().__init__()
        assert len(obs_shape) == 1, "This network only works for flat observations"
        self.obs_shape = obs_shape
        self.num_actions = num_actions
        self.net = torch.nn.Sequential(
            torch.nn.Linear(obs_shape[0], hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, num_actions),
        )
        # The optimizer is attached to the model so callers can step it via model.opt.
        self.opt = optim.Adam(self.net.parameters(), lr=lr)

    def forward(self, x):
        """Map observations (..., obs_shape[0]) to Q-values (..., num_actions)."""
        # This is the forward pass used both for acting and inside the training step.
        return self.net(x)

    

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Once ``buffer_size`` items are stored, every insert evicts the oldest one.
    """

    def __init__(self, buffer_size=100000):
        self.buffer_size = buffer_size
        # deque(maxlen=...) discards the oldest entry automatically when full.
        self.buffer = deque(maxlen=buffer_size)

    def insert(self, sarsd):
        """Append one transition, evicting the oldest one if the buffer is full."""
        self.buffer.append(sarsd)

    def sample(self, num_samples):
        """Return ``num_samples`` distinct stored items chosen uniformly at random.

        Raises:
            AssertionError: if more samples are requested than are stored.
        """
        assert num_samples <= len(self.buffer)
        # Indexed access into the middle of a deque is O(n), so sampling from
        # the deque directly costs O(num_samples * n). A single O(n) copy into
        # a list makes each pick O(1); random.sample only uses len() and
        # indexing, so it draws the same elements from the copy.
        return sample(list(self.buffer), num_samples)
    
    
def update_tgt_model(m, tgt):
    """Copy every parameter and buffer of ``m`` into ``tgt`` in place.

    Both modules must have matching state-dict keys (strict load). Returns None.
    """
    source_weights = m.state_dict()
    tgt.load_state_dict(source_weights)
    
    
#@torchsnooper.snoop()
def train_step(model, state_transitions, tgt, num_actions, gamma=1.0):
    """Run one DQN gradient step on a batch of transitions.

    For each transition the regression target is
    ``reward + gamma * max_a tgt(next_state)[a]``, with the bootstrap term
    dropped when ``done`` is true. The prediction is ``model(state)[action]``,
    and the loss is the mean squared error between the two.

    Args:
        model: Q-network being trained; its optimizer must be ``model.opt``.
        state_transitions: sequence of objects exposing ``state``, ``action``,
            ``reward``, ``next_state`` and ``done`` (e.g. ``Sarsd``).
        tgt: target Q-network, evaluated without gradients.
        num_actions: number of discrete actions (width of the Q-value output).
        gamma: discount factor; the default 1.0 keeps the undiscounted target.

    Returns:
        The scalar loss tensor (still attached to the autograd graph).
    """
    cur_states = torch.stack(
        [torch.as_tensor(s.state, dtype=torch.float32) for s in state_transitions]
    )
    next_states = torch.stack(
        [torch.as_tensor(s.next_state, dtype=torch.float32) for s in state_transitions]
    )
    # Every per-sample term is kept 1-D with shape (N,). Mixing an (N, 1)
    # tensor with (N,) tensors would silently broadcast the loss to (N, N).
    rewards = torch.tensor([float(s.reward) for s in state_transitions], dtype=torch.float32)
    mask = torch.tensor([0.0 if s.done else 1.0 for s in state_transitions], dtype=torch.float32)
    actions = torch.tensor([int(s.action) for s in state_transitions], dtype=torch.long)

    with torch.no_grad():
        qvals_next = tgt(next_states).max(-1)[0]  # (N,)
    targets = rewards + gamma * mask * qvals_next  # (N,)

    model.opt.zero_grad()
    qvals = model(cur_states)  # (N, num_actions)
    # One-hot masking keeps only the Q-value of the action actually taken.
    one_hot_actions = F.one_hot(actions, num_actions).to(qvals.dtype)
    chosen_qvals = torch.sum(qvals * one_hot_actions, -1)  # (N,)

    loss = ((targets - chosen_qvals) ** 2).mean()
    loss.backward()
    model.opt.step()
    return loss

    
def main(test=False, chkpt=None,
         model_dir="D:/college/machine_learning/Jack of Some's cartpole tut/models_cartpole"):
    """Train a DQN agent on CartPole-v1, or watch an agent play.

    Args:
        test: if True, render the environment and act greedily (eps = 0). Each
            episode's return is printed; training, wandb logging and
            checkpointing are skipped.
        chkpt: optional path of a state dict to load into the online model.
        model_dir: directory where target-model checkpoints are written
            during training, one file named ``<step_num>.pth`` per target sync.

    Runs until interrupted with Ctrl+C.
    """
    import os

    if not test:
        wandb.init(project="dqn-tutorial", name="dqn-cartpole")
    min_rb_size = 10000
    sample_size = 2500

    eps_min = 0.01
    eps_decay = 0.99998

    env_steps_before_train = 100  # train on one sampled batch every 100 env steps
    tgt_model_update = 150  # sync the target network every 150 training steps

    env = gym.make("CartPole-v1")
    last_observation = env.reset()

    m = Model(env.observation_space.shape, env.action_space.n)
    if chkpt is not None:
        m.load_state_dict(torch.load(chkpt))
    tgt = Model(env.observation_space.shape, env.action_space.n)
    update_tgt_model(m, tgt)

    rb = ReplayBuffer()
    steps_since_former_train = 0
    epochs_since_tgt = 0

    # Starts negative, so eps = eps_decay ** step_num stays >= 1 (purely
    # random actions) while the replay buffer is being warmed up.
    step_num = -1 * min_rb_size

    episode_rewards = []
    rolling_reward = 0

    if not test:
        os.makedirs(model_dir, exist_ok=True)

    tq = tqdm()
    try:
        while True:
            if test:
                env.render()
                time.sleep(0.05)
            tq.update(1)

            # Exponential decay, floored at eps_min so exploration never stops.
            eps = 0 if test else max(eps_decay ** step_num, eps_min)

            if random() < eps:
                action = env.action_space.sample()  # random exploratory action
            else:
                with torch.no_grad():  # acting only; no graph needed
                    action = m(torch.Tensor(last_observation)).max(-1)[1].item()

            observation, reward, done, info = env.step(action)
            rolling_reward += reward

            reward = reward / 100.0  # scale rewards down before storing them

            rb.insert(Sarsd(last_observation, action, reward, observation, done))
            last_observation = observation

            if done:
                episode_rewards.append(rolling_reward)
                if test:
                    print(rolling_reward)
                rolling_reward = 0
                # The next action must be chosen from the new episode's first
                # observation, not from the terminal one.
                last_observation = env.reset()

            steps_since_former_train += 1
            step_num += 1

            # After the buffer is warmed up, train once every env_steps_before_train steps.
            if (not test
                    and len(rb.buffer) > min_rb_size
                    and steps_since_former_train >= env_steps_before_train):
                loss = train_step(m, rb.sample(sample_size), tgt, env.action_space.n)
                log = {'loss': loss.detach().item(), 'eps': eps}
                # np.mean([]) is nan (plus a warning) when no episode finished
                # since the previous log, so only report a mean when one exists.
                if episode_rewards:
                    log['avg_reward'] = np.mean(episode_rewards)
                wandb.log(log, step=step_num)
                episode_rewards = []
                epochs_since_tgt += 1
                if epochs_since_tgt >= tgt_model_update:
                    print("updating target model")
                    update_tgt_model(m, tgt)
                    epochs_since_tgt = 0
                    torch.save(tgt.state_dict(), os.path.join(model_dir, f"{step_num}.pth"))

                steps_since_former_train = 0
    except KeyboardInterrupt:
        pass
    finally:
        tq.close()
        env.close()


if __name__ == '__main__':
    # Watch the agent play in test mode. To view a trained agent, also pass
    # chkpt="<path to a saved .pth file>".
    main(test=True)