In [1]:
from gym_env import FintechAppDRLEnv
from gym_env.utils.scheduler import Scheduler,linear_decay
import torch
import gymnasium as gym
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

pygame 2.6.0 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Choose the compute backend once for the whole notebook:
# CUDA first, then Apple-silicon MPS, and CPU as the fallback.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
class Model(torch.nn.Module):
    """Fully connected Q-network that maps an observation to one value per action.

    The network is ``Linear -> LeakyReLU`` three times followed by a final
    ``Linear`` with no activation, so the outputs are unbounded.

    Args:
        *args: Forwarded unchanged to ``torch.nn.Module.__init__``.
        in_features: Size of the observation vector (default 4).
        hidden_features: Width of each of the three hidden layers (default 512).
        n_actions: Number of outputs, one per discrete action (default 137).
        **kwargs: Forwarded unchanged to ``torch.nn.Module.__init__``.

    The defaults reproduce the sizes that used to be hard-coded, and the layer
    attribute names are unchanged, so ``Model()`` builds the same architecture
    and existing ``state_dict`` checkpoints still load.
    """

    def __init__(
        self,
        *args,
        in_features: int = 4,
        hidden_features: int = 512,
        n_actions: int = 137,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        # Layers are created in a fixed order so that seeded initialisation
        # stays reproducible.
        self.input_layer = torch.nn.Linear(in_features, hidden_features)
        self.hidden_1 = torch.nn.Linear(hidden_features, hidden_features)
        self.hidden_2 = torch.nn.Linear(hidden_features, hidden_features)
        self.output_layer = torch.nn.Linear(hidden_features, n_actions)

    def forward(self, input):
        """Return the action values for ``input``.

        Args:
            input: Float tensor whose last dimension equals ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of ``n_actions``.
        """
        output = input
        for layer in (self.input_layer, self.hidden_1, self.hidden_2):
            # The in-place activation is safe here: it overwrites the fresh
            # output of the linear layer, never the caller's tensor.
            output = F.leaky_relu_(layer(output))
        return self.output_layer(output)

In [4]:
class Memory(Dataset):
    """Bounded FIFO replay buffer that is also an iterator of random mini-batches.

    Rows are stored as plain Python lists. Each ``next()`` call draws indexes
    uniformly *with replacement* and returns the selected rows as a float32
    tensor placed on the notebook-level ``device``.

    Args:
        batch_size: Maximum number of rows per sampled batch.
        memory_len: Maximum number of rows kept; the oldest row is evicted
            once this is exceeded.
        *args, **kwargs: Accepted and ignored.

    Truthiness: the buffer is truthy only once it holds strictly more than
    ``batch_size`` rows.
    """

    def __init__(self,batch_size:int=32,memory_len:int=250_000, *args, **kwargs):
        self.__memory = []
        self.__memory_len = memory_len
        self.__batch_size = batch_size

    def __iter__(self):
        return self

    def __next__(self):
        """Sample one mini-batch.

        Returns:
            float32 tensor of shape ``(k, 1, row_len)`` on ``device``, where
            ``k = min(batch_size, rows stored)``. The singleton middle axis
            comes from ``push`` wrapping every row in a list.

        Raises:
            StopIteration: if the buffer is empty, since there is nothing to
                sample from.
        """
        stored = len(self.__memory)
        if stored == 0:
            raise StopIteration
        indexes = torch.randint(low=0, high=stored, size=(min(self.__batch_size, stored),))
        # Materialise only the sampled rows. Converting the whole buffer to a
        # tensor on every call costs time proportional to the buffer length
        # (up to `memory_len` rows) per training step.
        sampled_rows = [self.__memory[index] for index in indexes.tolist()]
        return torch.tensor(sampled_rows, device=device, dtype=torch.float32)

    def push(self,env_output):
        """Append one row (a flat sequence of numbers) and evict the oldest if full."""
        self.__memory.append([env_output])
        if len(self.__memory) > self.__memory_len:
            del self.__memory[0]

    def __bool__(self):
        return len(self.__memory) > self.__batch_size

In [5]:
# Environment and exploration schedule.
# NOTE(review): the meaning of the arguments 4, 0.1 and 500 is defined by the
# gym_env package and cannot be seen from this notebook; confirm there.
env = FintechAppDRLEnv(4)
epsilon = Scheduler(0.1, linear_decay(500))

# Two networks with the same architecture. The target network starts as an
# exact copy of the policy network and stays in eval mode; only the policy
# network's parameters are handed to the optimizer.
target_network = Model().to(device)
policy_network = Model().to(device)
target_network.load_state_dict(policy_network.state_dict())
target_network = target_network.eval()

# Replay buffer, optimisation objects and the TensorBoard writer.
memory = Memory()
optimizer = torch.optim.AdamW(policy_network.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()
writer = SummaryWriter()

In [6]:
def choose_action(q_values:torch.Tensor,epsilon:"Scheduler | float"):
    """Pick an action with an epsilon-greedy rule.

    Args:
        q_values: Action values for a single state. The tensor is flattened,
            so its number of elements is the number of actions.
        epsilon: Exploration probability. Any object accepted by ``float()``
            works (a ``Scheduler`` or a plain number).

    Returns:
        The chosen action index as a Python ``int``. With probability
        ``float(epsilon)`` it is drawn uniformly from ``[0, q_values.numel())``;
        otherwise it is the index of the largest Q-value.
    """
    flat_q_values = q_values.detach().reshape(-1)
    if torch.rand(size=(1,)) < float(epsilon):
        # Size the random range from the Q-vector itself, so exploration can
        # never return an action the network does not have.
        return torch.randint(high=flat_q_values.numel(), size=(1,)).item()
    return torch.argmax(flat_q_values.cpu(), dim=0).item()

In [7]:
# First node in the environment map whose color is "green", or None if there is none.
true_node = next(
    (candidate for candidate in env._env_map.values() if candidate.color == "green"),
    None,
)

In [8]:
# Move one level down to a green child, if there is one.
# When several children are green, the last one returned wins.
green_children = [child for child in true_node.return_children() if child.color == "green"]
if green_children:
    true_node = green_children[-1]

In [9]:
# Display the node that the previous two cells selected.
print(true_node)

joinscreen


In [10]:
# Training hyper-parameters (same values as the literals they replace).
NUM_EPISODES = 500
MAX_EPISODE_STEPS = 75    # an episode is cut off once its step count exceeds this
GAMMA = 0.99              # discount applied to the bootstrapped next-state value
TARGET_SYNC_EVERY = 10    # episodes between target-network refreshes

# Column layout of one replay row, as written by `memory.push` below:
#   [0:4] next observation | [4:8] observation | 8 action | 9 reward | 10 terminated
steps = 0  # number of optimizer updates so far; used as the TensorBoard x-axis
for epoch in tqdm(range(NUM_EPISODES)):
    done = False
    past_observation = env.reset(seed=7)
    ep_return = 0
    # Keep the per-episode step counter in its own name. It used to share `i`
    # with the replay-sampling loop, which reset it on every update and
    # silently disabled the step cap below.
    episode_step = 0
    while not done:
        episode_step += 1

        # --- act -----------------------------------------------------------
        with torch.no_grad():
            q_values = policy_network(
                torch.tensor(past_observation, device=device, dtype=torch.float32)
            )
            action = choose_action(q_values, epsilon)
        observation, reward, terminated = env.step(action)
        done = terminated
        # Accumulate right away so the reward of a truncated final step is
        # counted and the logged return includes the current step.
        ep_return += reward
        memory.push(env_output=[*observation, *past_observation, int(action), reward, terminated])
        past_observation = observation

        # --- learn: one gradient step per environment step once the buffer is warm
        if memory:
            batch = next(memory).squeeze(1)  # (B, 1, 11) -> (B, 11)
            actions = batch[:, 8].to(torch.int64).unsqueeze(1)
            # Q(s, a) for the action that was actually taken.
            current_state_outputs = policy_network(batch[:, 4:8]).gather(1, actions).squeeze(1)
            with torch.no_grad():
                next_state_outputs = target_network(batch[:, :4]).max(dim=1).values
                not_terminated = ~batch[:, 10].to(torch.bool)
                next_state_outputs = batch[:, 9] + not_terminated * GAMMA * next_state_outputs
                # NOTE(review): this rescale is kept from the original. It turns
                # the target into (r + 1) / 2 + (GAMMA / 2) * max Q', which halves
                # the effective discount. Confirm that this is intended.
                next_state_outputs = (next_state_outputs + 1) / 2
            optimizer.zero_grad()
            loss = criterion(current_state_outputs, next_state_outputs)
            writer.add_scalar('loss', loss.item(), steps)
            writer.add_scalar('reward', ep_return, steps)
            steps += 1
            loss.backward()
            optimizer.step()

        if episode_step > MAX_EPISODE_STEPS:
            break
    epsilon.step()
    if (epoch + 1) % TARGET_SYNC_EVERY == 0:
        target_network.load_state_dict(policy_network.state_dict())

# Make sure the last buffered scalars reach disk without closing the shared writer.
writer.flush()

  0%|          | 0/500 [00:00<?, ?it/s]

finished with FALSE
true finished count: 0
--------------------
