In [12]:
import torch
import gymnasium as gym
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

In [13]:
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)

In [14]:
env = gym.make('CartPole-v1', render_mode="human")

In [15]:
class Model(torch.nn.Module):
    
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.input_layer = torch.nn.Linear(4,128)
        self.hidden_layer_1 = torch.nn.Linear(128,128)
        self.output_layer = torch.nn.Linear(128,2)
    
    def forward(self,x):
        output = self.input_layer(x)
        output = F.leaky_relu_(output)
        output = self.hidden_layer_1(output)
        output = F.leaky_relu_(output)
        output = self.output_layer(output)
        return output      
    

In [16]:
class Memory(Dataset):
    
    def __init__(self,batch_size:int=32,memory_len:int=100000, *args, **kwargs):
        self.__memory = []
        self.__memory_len = memory_len
        self.__batch_size = batch_size
    
    def __iter__(self):
        return self

    def __next__(self):
        memory_as_tensor = torch.tensor(self.__memory,device=device,dtype=torch.float32)
        indexes = torch.randint(low=0,high=len(self.__memory),size=(min(self.__batch_size,len(self.__memory)),))
        return memory_as_tensor[indexes]
        
    def push(self,env_output):
        self.__memory.append([env_output])
        if len(self.__memory) > self.__memory_len:
            del self.__memory[0]
    
    def __bool__(self):
        return len(self.__memory) > self.__batch_size    
    

In [17]:
def choose_action(q_values:torch.Tensor):
    if torch.rand(size=(1,)) <0.2:
        return torch.randint(2,size=(1,)).cpu().item()
    return torch.argmax(q_values.clone().detach().view(-1).cpu(),dim=0).cpu().item()

In [18]:
def target_action(q_values:torch.Tensor):
    return torch.max(q_values.clone().detach(),dim=0).values.cpu().item()

In [19]:
criterion = torch.nn.MSELoss()
target_network = Model().to(device)
policy_network = Model().to(device)
target_network.load_state_dict(policy_network.state_dict())
optimizer = torch.optim.AdamW(policy_network.parameters(),lr=1e-4)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=optimizer,end_factor=1e-8,total_iters=500)
memory = Memory()
target_network = target_network.eval()

In [20]:
writer = SummaryWriter()

In [21]:
steps = 0
for epoch in tqdm(range(1000)):
    done = False
    past_observation, info = env.reset()
    ep_return = 0
    while not done:
        with torch.no_grad():
            action = choose_action(policy_network(torch.tensor(past_observation,device=device,dtype=torch.float32)))
            observation, reward, terminated, truncated, info = env.step(action)
            done = truncated or terminated
            memory.push(env_output=[*observation,*past_observation,int(action), reward,terminated])
            past_observation = observation
        if memory:
            for i, batch in enumerate(memory):
                batch = batch.squeeze()
                action_indexes = (F.one_hot(batch[:,8].view(-1,1).to(torch.int64),num_classes=2)==1).view(-1,2)
                current_state_outputs = policy_network(batch[:,4:8])[action_indexes]
                with torch.no_grad():
                    next_state_outputs = target_network(batch[:,:4])
                    next_state_outputs = torch.max(next_state_outputs,dim=1).values
                    next_state_outputs = batch[:,9]+ ( (~batch[:,10].to(torch.bool))*0.99*next_state_outputs)
                    next_state_outputs = (next_state_outputs +1) / 2
                optimizer.zero_grad()
                loss = criterion(current_state_outputs.to(device).squeeze(),next_state_outputs.to(device).squeeze())
                writer.add_scalar('loss',loss.cpu().item(),steps)
                writer.add_scalar('reward',ep_return,steps)
                steps += 1
                loss.backward()
                optimizer.step()
                break
        ep_return += reward        
    if (epoch+1)%10 == 0:
        target_network.load_state_dict(policy_network.state_dict())

  0%|          | 0/2000 [00:00<?, ?it/s]

In [None]:
past_observation, info = env.reset()
for _ in range(1000):
    with torch.no_grad():
        q_values = policy_network(torch.tensor(past_observation,device=device,dtype=torch.float32))
        action = torch.argmax(q_values,dim=0).cpu().item()  # agent policy that uses the observation and info
        env_output = env.step(action)
        observation, reward, terminated, truncated, info = env_output
        if not(terminated or truncated):
            past_observation = observation
env.close()

  logger.warn(
