In [1]:
from gym_env import FintechAppDRLEnv
from gym_env.utils.scheduler import Scheduler,linear_decay,exponential_decay
import torch
import gymnasium as gym
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

pygame 2.6.0 (SDL 2.28.4, Python 3.11.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Select the compute backend once for the whole notebook:
# CUDA when present, otherwise Apple MPS, otherwise plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [3]:
class Model(torch.nn.Module):
    """Two-branch Q-network producing one value per node id (128 outputs).

    * Possible-locations branch: embedding -> LayerNorm -> Conv1d/MaxPool1d
      stack -> flatten.
    * Past-observations branch: embedding -> flatten -> LayerNorm -> GRU.

    Both branches are concatenated (768 features), normalised, passed through
    dropout and a LeakyReLU MLP head.

    The layer sizes fix the expected input widths: ``past_observations`` must
    have 5 ids per sample (``LayerNorm(5*128)`` after flattening 128-d
    embeddings) and ``possible_locations`` must have 10 ids per sample (the
    first ``Conv1d`` consumes 10 channels, and the CNN stack then flattens to
    256 features so that 512 + 256 == 768).
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.embed_nodes_possible_locations = torch.nn.Embedding(128,256)
        self.embed_nodes_past_observations = torch.nn.Embedding(128,128)
        self.gru = torch.nn.GRU(5*128,256,num_layers=9,bidirectional=True,dropout=0.3)
        self.hidden_1 = torch.nn.Linear(768,1024)
        self.output_layer = torch.nn.Linear(1024,128)
        self.flatten = torch.nn.Flatten()
        self.layer_norm_start_1 = torch.nn.LayerNorm(5*128)
        self.layer_norm_start_2 = torch.nn.LayerNorm(256)
        self.layer_norm_final = torch.nn.LayerNorm(768)
        self.dropout = torch.nn.Dropout(p=0.3)
        self.cnn_1_7 = torch.nn.Conv1d(10,64,kernel_size=7,stride=3)
        self.cnn_1_3 = torch.nn.Conv1d(64,64,kernel_size=3)
        self.cnn_2_3 = torch.nn.Conv1d(64,128,kernel_size=3)
        self.max_pool_3 = torch.nn.MaxPool1d(3)
        self.max_pool_7 = torch.nn.MaxPool1d(7)

    def forward(self,past_observations:torch.Tensor,possible_locations:torch.Tensor):
        """Compute Q-values for a batch of states.

        Args:
            past_observations: ``(B, 5)`` tensor of node ids (any numeric
                dtype; cast to int32 for the embedding lookup).
            possible_locations: ``(B, 10)`` tensor of node ids (same casting).

        Returns:
            ``(B, 128)`` tensor with one value per node id.
        """
        # ---- possible-locations branch: (B, 10) -> (B, 10, 256) -> (B, 256)
        output_possible_locations = self.embed_nodes_possible_locations(possible_locations.to(torch.int32))
        output_possible_locations = self.layer_norm_start_2(output_possible_locations)
        # Conv1d treats the 10 location slots as channels and the 256-d
        # embedding as the length axis.
        output_possible_locations = self.cnn_1_7(output_possible_locations)
        output_possible_locations = self.max_pool_7(output_possible_locations)
        output_possible_locations = self.cnn_1_3(output_possible_locations)
        output_possible_locations = self.cnn_2_3(output_possible_locations)
        output_possible_locations = self.max_pool_3(output_possible_locations)
        output_possible_locations = self.flatten(output_possible_locations)
        #---------------------------------------------------------------
        # ---- past-observations branch: (B, 5) -> (B, 640) -> (B, 512)
        output_past_observations = self.embed_nodes_past_observations(past_observations.to(torch.int32))
        output_past_observations = self.flatten(output_past_observations)
        output_past_observations = self.layer_norm_start_1(output_past_observations)
        # A 2-D tensor handed to nn.GRU is interpreted as ONE unbatched
        # sequence, i.e. the B samples of a batch would be processed as B
        # consecutive time steps and leak into each other through the hidden
        # state. Add an explicit sequence axis of length 1 so every sample is
        # an independent sequence: (1, B, 640) -> (1, B, 512) -> (B, 512).
        # For B == 1 this is numerically identical to the previous behaviour.
        gru_output,_ = self.gru(output_past_observations.unsqueeze(0))
        output_past_observations = gru_output.squeeze(0)
        #----------------------------------------------------------------
        output = torch.cat((output_past_observations,output_possible_locations),dim=1)
        output = self.layer_norm_final(output)
        output = self.dropout(output)
        #----------------------------------------------------------------
        output = self.hidden_1(output)
        output = F.leaky_relu(output)
        output = self.output_layer(output)
        return output

In [4]:
class Memory(Dataset):
    """Replay memory split into a "true" and a "false" FIFO buffer.

    Iterating yields class-balanced mini-batches forever: each call to
    ``next`` draws the same number of entries (with replacement) from both
    buffers, false entries first, and returns them as one float32 tensor on
    the module-level ``device``.
    """

    def __init__(self,batch_size:int=32,memory_len:int=15_000, *args, **kwargs):
        """Create empty buffers.

        Args:
            batch_size: nominal batch size; half of it is drawn per buffer.
            memory_len: maximum number of entries kept in each buffer.
            *args, **kwargs: accepted and ignored.
        """
        self.__batch_size = batch_size
        self.__memory_len = memory_len
        self.__true_memory = []
        self.__false_memory = []

    def __iter__(self):
        # The object is its own (never-ending) iterator.
        return self

    def __next__(self):
        """Sample one balanced batch (false rows followed by true rows)."""
        # Half the batch per buffer, capped by the size of the smaller buffer.
        per_buffer = min(
            5 * self.__batch_size // 10,
            len(self.__false_memory),
            len(self.__true_memory),
        )
        false_indexes = torch.randint(low=0, high=len(self.__false_memory), size=(per_buffer,))
        true_indexes = torch.randint(low=0, high=len(self.__true_memory), size=(per_buffer,))
        rows = self.__collect_memory(false_indexes, False)
        rows = rows + self.__collect_memory(true_indexes, True)
        return torch.tensor(rows, device=device, dtype=torch.float32)

    def __collect_memory(self,indexes:list,collecting_type:bool)->list:
        """Return the stored entries at ``indexes`` from the chosen buffer."""
        source = self.__true_memory if collecting_type else self.__false_memory
        return [source[position] for position in indexes]

    def push(self,env_output,node_type:bool):
        """Append ``env_output`` (wrapped in a list) to the buffer selected by
        ``node_type`` and drop the oldest entry once the buffer is over capacity."""
        bucket = self.__true_memory if node_type else self.__false_memory
        bucket.append([env_output])
        if len(bucket) > self.__memory_len:
            bucket.pop(0)

    def __bool__(self):
        # Ready for sampling only when both buffers hold at least two entries.
        return min(len(self.__false_memory), len(self.__true_memory)) > 1

In [5]:
# Two identically shaped networks: the policy network is trained, the target
# network starts as an exact copy of it and is only used for inference.
target_network = Model().to(device)
policy_network = Model().to(device)
target_network.load_state_dict(policy_network.state_dict())
target_network.eval()  # eval() returns the module itself, so no rebinding is needed

# Training utilities: MSE loss, AdamW on the policy parameters only, and a
# TensorBoard writer with its default log directory.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.AdamW(policy_network.parameters(), lr=1e-4)
writer = SummaryWriter()

In [6]:
def choose_action(q_values:torch.Tensor,epsilon:"Scheduler",possible_locations:list[int]):
    """Epsilon-greedy action selection over the currently possible locations.

    With probability ``float(epsilon)`` a random element of
    ``possible_locations`` is returned. Otherwise the element of
    ``possible_locations`` with the highest Q-value is returned, so both
    branches choose from the same candidate set (previously the greedy branch
    took the argmax over every network output and could return an action that
    is not in ``possible_locations``).

    Args:
        q_values: network output for a single state; flattened before use.
        epsilon: exploration rate; anything convertible with ``float()``.
        possible_locations: candidate action indices.

    Returns:
        The chosen action index. The greedy branch always returns a Python
        ``int``; the exploratory branch returns whatever
        ``np.random.choice`` yields.
    """
    if np.random.random() < float(epsilon):
        return np.random.choice(possible_locations)

    flat_q = q_values.detach().reshape(-1).cpu()
    # Keep only candidates that actually index into the Q-vector so that the
    # lookup below can never raise where the old global argmax would not.
    candidates = [int(a) for a in possible_locations if 0 <= int(a) < flat_q.numel()]
    if not candidates:
        # No usable candidate: fall back to the unrestricted argmax.
        return int(torch.argmax(flat_q, dim=0).item())
    best_position = int(torch.argmax(flat_q[candidates], dim=0).item())
    return candidates[best_position]

In [7]:
# Global optimisation-step counter; used as the x-axis of the TensorBoard scalars.
steps = 0
# Every outer iteration builds a fresh environment, epsilon schedule and replay
# memory; the networks, optimizer and writer created earlier are reused.
# NOTE(review): `seed` is never read inside the loop body.
for seed in range(512):
    env = FintechAppDRLEnv(5)
    epsilon = Scheduler(0.1,linear_decay(100))
    memory = Memory(batch_size=32)
    for epoch in tqdm(range(300)):
        done = False
        # The reset seed is drawn from {0, ..., 7}.
        past_observation,past_input_possible_locations = env.reset(seed=np.random.randint(low=0,high=8))
        ep_return = 0
        # NOTE(review): `truncate` and `false_count` are incremented below but
        # never read in the visible cells.
        truncate = 0
        false_count = 0
        while not done:
            truncate += 1
            false_count += 1
            # ---- act in the environment and store the transition (no gradients)
            with torch.no_grad():
                # NOTE(review): only target_network is switched to eval() in the
                # visible cells, so policy_network's dropout is active while
                # acting — confirm this is intended.
                q_values = policy_network(torch.tensor(past_observation,device=device,dtype=torch.float32).view(1,-1),torch.tensor(past_input_possible_locations,device=device,dtype=torch.float32).view(1,-1))
                possible_locations = env.get_possible_locations()
                action = choose_action(q_values,epsilon,possible_locations)
                observation, reward, terminated, node_type = env.step(action)
                input_possible_locations = env.possible_locations_input()
                done = terminated
                # Stored row order: next observation, previous observation,
                # previous possible-locations input, next possible-locations
                # input, action, reward, terminated. The slicing in the update
                # step assumes 5 observation values and 10 possible-location
                # values (columns 0:5, 5:10, 10:20, 20:30, 30, 31, 32).
                # Zero-reward transitions always go to the "false" buffer;
                # all others are routed by `node_type`.
                if reward == 0:
                    memory.push(env_output=[*observation,*past_observation,*past_input_possible_locations,*input_possible_locations,int(action), reward,terminated],node_type=False)
                else:
                    memory.push(env_output=[*observation,*past_observation,*past_input_possible_locations,*input_possible_locations,int(action), reward,terminated],node_type=node_type)
                past_observation = observation
                past_input_possible_locations = input_possible_locations
            # ---- one gradient step per environment step, once both buffers
            # hold more than one entry (Memory.__bool__)
            if memory:
                for i, batch in enumerate(memory):
                    # Each stored entry is wrapped in a list, so the sampled
                    # tensor has a singleton middle axis; squeeze it away.
                    batch = batch.squeeze()
                    # Boolean one-hot mask over the 128 outputs selecting the
                    # Q-value of the action that was actually taken.
                    action_indexes = (F.one_hot(batch[:,30].view(-1,1).to(torch.int64),num_classes=128)==1).view(-1,128)
                    current_state_outputs = policy_network(batch[:,5:10],batch[:,10:20])[action_indexes]
                    with torch.no_grad():
                        # TD target: reward + 0.99 * max_a Q_target(next state),
                        # with the bootstrap term zeroed for terminal transitions.
                        next_state_outputs = target_network(batch[:,:5],batch[:,20:30])
                        next_state_outputs = torch.max(next_state_outputs,dim=1).values
                        next_state_outputs = batch[:,31]+ ( ~(batch[:,32].to(torch.bool))*0.99*next_state_outputs)
                        # NOTE(review): the reason for this (x + 1) / 2 rescale of
                        # the target is not evident from this cell — confirm.
                        next_state_outputs = (next_state_outputs +1) / 2
                    optimizer.zero_grad()
                    loss = criterion(current_state_outputs.to(device).squeeze(),next_state_outputs.to(device).squeeze())
                    writer.add_scalar('loss',loss.cpu().item(),steps)
                    # Logged before this step's reward is added to ep_return below.
                    writer.add_scalar('reward',ep_return,steps)
                    steps += 1
                    loss.backward()
                    optimizer.step()
                    # Memory's iterator never raises StopIteration, so stop
                    # after a single sampled batch.
                    break
            ep_return += reward
        # Epsilon is stepped only during the first 100 episodes of each run.
        if epoch < 100:
            epsilon.step()        
        # Copy the policy weights into the target network every 10 episodes.
        if (epoch+1)%10 == 0:
            target_network.load_state_dict(policy_network.state_dict())

  0%|          | 0/300 [00:00<?, ?it/s]

finished with FALSE
finish count: 0
--------------------
finished with FALSE
finish count: 1
--------------------
finished with TRUE
finish count: 2
--------------------
finished with FALSE
finish count: 3
--------------------
finished with FALSE
finish count: 4
--------------------
finished with TRUE
finish count: 5
--------------------
finished with FALSE
finish count: 6
--------------------
finished with FALSE
finish count: 7
--------------------
finished with TRUE
finish count: 8
--------------------
finished with FALSE
finish count: 9
--------------------
finished with TRUE
finish count: 10
--------------------
finished with FALSE
finish count: 11
--------------------
finished with TRUE
finish count: 12
--------------------
finished with FALSE
finish count: 13
--------------------
finished with FALSE
finish count: 14
--------------------
finished with TRUE
finish count: 15
--------------------
finished with FALSE
finish count: 16
--------------------
finished with TRUE
finish coun

  0%|          | 0/300 [00:00<?, ?it/s]

finished with TRUE
finish count: 0
--------------------
finished with TRUE
finish count: 1
--------------------
finished with TRUE
finish count: 2
--------------------
finished with TRUE
finish count: 3
--------------------
finished with TRUE
finish count: 4
--------------------
finished with TRUE
finish count: 5
--------------------
finished with TRUE
finish count: 6
--------------------
finished with TRUE
finish count: 7
--------------------
finished with TRUE
finish count: 8
--------------------
finished with TRUE
finish count: 9
--------------------
finished with TRUE
finish count: 10
--------------------
finished with TRUE
finish count: 11
--------------------
finished with TRUE
finish count: 12
--------------------
finished with TRUE
finish count: 13
--------------------
finished with TRUE
finish count: 14
--------------------
finished with FALSE
finish count: 15
--------------------
finished with TRUE
finish count: 16
--------------------
finished with TRUE
finish count: 17
----

  0%|          | 0/300 [00:00<?, ?it/s]

finished with FALSE
finish count: 0
--------------------
finished with TRUE
finish count: 1
--------------------
finished with TRUE
finish count: 2
--------------------
finished with TRUE
finish count: 3
--------------------
finished with TRUE
finish count: 4
--------------------
finished with TRUE
finish count: 5
--------------------
finished with TRUE
finish count: 6
--------------------
finished with TRUE
finish count: 7
--------------------
finished with TRUE
finish count: 8
--------------------
finished with TRUE
finish count: 9
--------------------
finished with TRUE
finish count: 10
--------------------
finished with TRUE
finish count: 11
--------------------
finished with TRUE
finish count: 12
--------------------
finished with TRUE
finish count: 13
--------------------
finished with TRUE
finish count: 14
--------------------
finished with TRUE
finish count: 15
--------------------
finished with TRUE
finish count: 16
--------------------
finished with TRUE
finish count: 17
----

KeyboardInterrupt: 