In [None]:
import os
import torch
import warnings
import tqdm
import pandas as pd
import numpy as np
import gym
import torch.nn as nn
from gym import spaces
from torch.utils.data import Dataset, DataLoader
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from transformers import BertModel
from stable_baselines3 import PPO

In [None]:
# Global runtime configuration: device selection, GPU memory budget, warnings, seeds.
SEED = 52
GPU_MEMORY_FRACTION = 0.75  # share of GPU 0's memory this process may allocate

device = "cuda" if torch.cuda.is_available() else "cpu"

# ipc_collect / set_per_process_memory_fraction initialise the CUDA runtime and
# fail on a CPU-only machine, so the GPU housekeeping only runs when a GPU exists.
if device == "cuda":
    torch.cuda.empty_cache()  # release cached blocks left over from earlier runs
    torch.cuda.ipc_collect()
    torch.cuda.set_per_process_memory_fraction(GPU_MEMORY_FRACTION, device=0)

# NOTE(review): ignoring every warning also hides useful ones (e.g. deprecations);
# consider filtering specific categories instead.
warnings.filterwarnings('ignore')

# torch.manual_seed seeds the CPU and all CUDA generators.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
class TransformerFeatureExtractor(nn.Module):
    """Encode a batch of sequences with a pretrained BERT and project to ``hidden_dim``.

    Args:
        input_dim: Kept for interface compatibility; not used by this module.
        hidden_dim: Output feature size of the final linear projection.
        pretrained_name: Checkpoint name passed to ``BertModel.from_pretrained``
            (defaults to the original ``"bert-base-uncased"``).
    """

    def __init__(self, input_dim, hidden_dim, pretrained_name="bert-base-uncased"):
        super(TransformerFeatureExtractor, self).__init__()
        self.transformer = BertModel.from_pretrained(pretrained_name)
        # The projection consumes BERT's hidden states, so its input width must be
        # the encoder's hidden size (read from its config), not ``hidden_dim``.
        # The previous Linear(hidden_dim, hidden_dim) only worked when the two
        # happened to be equal.
        self.fc = nn.Linear(self.transformer.config.hidden_size, hidden_dim)

    def forward(self, x):
        """Return a ``(batch, hidden_dim)`` feature tensor for the input batch.

        NOTE(review): BertModel's first positional argument is ``input_ids``, so
        ``x`` presumably must be an integer token-id tensor of shape
        (batch, seq_len). Float observations would need ``inputs_embeds=``
        instead; confirm with the caller.
        """
        hidden_states = self.transformer(x).last_hidden_state  # (batch, seq_len, hidden_size)
        # Pool by taking the hidden state at the final sequence position.
        # NOTE(review): BERT conventionally pools the [CLS] token at index 0;
        # confirm that the last position (-1) is intended.
        return self.fc(hidden_states[:, -1, :])

class CustomTransformerPolicy(ActorCriticPolicy):
    """Actor-critic policy that attaches a TransformerFeatureExtractor after SB3 setup.

    After the parent constructor runs, this adds ``feature_extractor`` (a
    TransformerFeatureExtractor producing 128 features) and ``fc`` (a linear
    layer mapping those 128 features back to ``self.features_dim``).

    NOTE(review): the points below concern Stable-Baselines3 internals and should
    be confirmed against the installed SB3 version:
      * ActorCriticPolicy appears to build its optimizer inside ``__init__``, so
        modules added afterwards (``feature_extractor``, ``fc``) are presumably
        absent from that optimizer and never trained by PPO.
      * SB3 uses its own ``features_extractor`` attribute (plural) in
        ``evaluate_actions`` / ``get_distribution``. The ``feature_extractor``
        attribute set here is separate and is likely bypassed there.
      * PPO rollouts call ``policy(obs)`` and unpack ``(actions, values,
        log_prob)``. This ``forward`` returns a single tensor, which would break
        that contract.
      The SB3-documented route is a ``BaseFeaturesExtractor`` subclass passed via
      ``policy_kwargs=dict(features_extractor_class=..., features_extractor_kwargs=...)``.
    """

    def __init__(self, *args, **kwargs):
        super(CustomTransformerPolicy, self).__init__(*args, **kwargs)
        # self.features_dim is read after the parent constructor; presumably set
        # there by SB3. Confirm before changing the call order.
        # The 128 here must match the input width of self.fc below.
        self.feature_extractor = TransformerFeatureExtractor(self.features_dim, 128)
        self.fc = nn.Linear(128, self.features_dim)

    def forward(self, features):
        """Run ``features`` through the transformer extractor and project back.

        Args:
            features: Passed unchanged to ``TransformerFeatureExtractor.forward``.

        Returns:
            The output of ``self.fc``, whose last dimension is ``self.features_dim``.
        """
        transformer_out = self.feature_extractor(features)
        return self.fc(transformer_out)

In [None]:
class TimeSeriesEnv(gym.Env):
    """Gym environment that walks through a time series one sample at a time.

    At each step the agent picks one of ``n_actions`` discrete actions for the
    current sample. If ``labels`` are provided, the reward is +1 when the action
    equals the sample's label and -1 otherwise; without labels the reward is 0.

    Args:
        data: Indexable sequence of observations; ``data[i]`` is the observation
            the agent acts on at step ``i``.
        labels: Optional indexable sequence aligned with ``data``; ``labels[i]``
            is the target action for ``data[i]``.
        obs_shape: Shape of a single observation (defaults to the original (100, 1)).
        n_actions: Number of discrete actions (defaults to the original 3).
    """

    def __init__(self, data, labels=None, obs_shape=(100, 1), n_actions=3):
        super(TimeSeriesEnv, self).__init__()

        self.data = data
        self.labels = labels
        self.current_step = 0

        # Must be an instance attribute: gym wrappers and SB3 read
        # ``env.observation_space``. It was previously bound to a local variable,
        # leaving the environment without an observation space.
        # NOTE(review): bounds [0, 1] assume the series is already min-max normalised.
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=obs_shape, dtype=np.float32)
        self.action_space = spaces.Discrete(n_actions)

    def reset(self):
        """Rewind to the first sample and return it as the initial observation."""
        self.current_step = 0
        return self.data[self.current_step]

    def step(self, action):
        """Score ``action`` on the current sample, then advance one sample.

        Returns the old-style gym 4-tuple ``(obs, reward, done, info)``. The
        episode ends only after every sample (including the last) has been
        scored. Previously ``done`` fired on arrival at the last sample, so
        ``labels[-1]`` was never rewarded and a one-sample series raised
        IndexError.
        """
        reward = self.get_reward(action)
        self.current_step += 1
        done = self.current_step >= len(self.data)
        # On the terminal transition there is no next sample; return the last one.
        obs_index = min(self.current_step, len(self.data) - 1)
        return self.data[obs_index], reward, done, {}

    def get_reward(self, action):
        """Return +1 / -1 for a correct / incorrect action on the current sample, 0 if unlabelled."""
        if self.labels is None:
            return 0
        return 1 if action == self.labels[self.current_step] else -1