In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
from scipy.stats import entropy
from environments import *
# Seed every RNG the notebook uses (Python `random`, NumPy, PyTorch) so that
# Restart & Run All reproduces the same exploration and weight initialisation.
# NOTE(review): the environments may keep their own RNGs, and GPU/cuDNN LSTM
# kernels can still be non-deterministic — confirm if exact reproducibility matters.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Module-level device handle (DQNAgent also selects its own device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class LSTMDQN(nn.Module):
    """Q-network: a single-layer LSTM followed by a linear head.

    ``forward`` expects ``x`` shaped ``(batch, seq_len, input_dim)`` (the LSTM
    is built with ``batch_first=True``) and returns Q-values of shape
    ``(batch, output_dim)`` computed from the last timestep's LSTM output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LSTMDQN, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        lstm_out, _ = self.lstm(x)
        return self.fc(lstm_out[:, -1, :])  # Take the last timestep output


class DQNAgent:
    """DQN agent with experience replay, a target network and epsilon-greedy exploration.

    Every state is a flat vector of ``state_size`` features. Before it reaches
    the :class:`LSTMDQN` it is shaped as a length-1 sequence,
    ``(batch, 1, state_size)``, in BOTH ``act`` and ``replay`` so the online
    and replay code paths feed the network the same layout.

    Args:
        state_size: Number of features per state.
        action_size: Number of discrete actions.
        hidden_dim: LSTM hidden size.
        lr: Adam learning rate.
        gamma: Discount factor for the TD target.
        batch_size: Replay minibatch size; ``replay`` is a no-op until the
            memory holds at least this many transitions.
        memory_size: Maximum number of transitions kept in the replay buffer.
    """

    def __init__(self, state_size, action_size, hidden_dim=64, lr=0.001, gamma=0.99, batch_size=32, memory_size=10000):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.lr = lr
        self.batch_size = batch_size
        # Bounded FIFO replay buffer: oldest transitions are dropped first.
        self.memory = deque(maxlen=memory_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = LSTMDQN(state_size, hidden_dim, action_size).to(self.device)
        self.target_model = LSTMDQN(state_size, hidden_dim, action_size).to(self.device)
        self.target_model.load_state_dict(self.model.state_dict())

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.loss_fn = nn.MSELoss()
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995

    def _to_batch(self, states):
        """Stack a sequence of states into a float32 tensor of shape ``(len(states), 1, state_size)``."""
        array = np.asarray(states, dtype=np.float32).reshape(len(states), 1, self.state_size)
        return torch.as_tensor(array, device=self.device)

    def remember(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action for a single state with an epsilon-greedy policy.

        With probability ``epsilon`` a uniformly random action from
        ``range(action_size)`` is returned; otherwise the argmax of the online
        network's Q-values.
        """
        if np.random.rand() <= self.epsilon:
            # Sample from the configured action space rather than a hard-coded {0, 1}.
            return random.randrange(self.action_size)
        state = self._to_batch([state])
        with torch.no_grad():
            action_values = self.model(state)
        return torch.argmax(action_values).item()

    def replay(self):
        """Run one gradient step on a random minibatch from the replay buffer, then decay epsilon.

        Does nothing while the buffer holds fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return
        batch = random.sample(self.memory, self.batch_size)

        states, actions, rewards, next_states, dones = zip(*batch)
        # Same (batch, 1, state_size) layout that act() uses.
        states = self._to_batch(states)
        next_states = self._to_batch(next_states)
        actions = torch.as_tensor(actions, dtype=torch.long, device=self.device).unsqueeze(1)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device)
        dones = torch.as_tensor(dones, dtype=torch.float32, device=self.device)

        # squeeze(1), not squeeze(): keeps a (batch,) shape even when batch_size == 1,
        # so predictions always match target_q_values and MSELoss does not broadcast.
        q_values = self.model(states).gather(1, actions).squeeze(1)
        with torch.no_grad():
            next_q_values = self.target_model(next_states).max(1)[0]
            # Terminal transitions (done == 1) do not bootstrap from the next state.
            target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        loss = self.loss_fn(q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Epsilon decays once per replay (gradient) step, not per episode, so with
        # long episodes it can reach epsilon_min within a few episodes. Clamp so it
        # never undershoots the floor on the final decay.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.load_state_dict(self.model.state_dict())

# Train the LSTM-based DQN agent on whichever environment is passed in (SBEOS_Environment or SBEDS_Environment below).

def _as_observation(state):
    """Flatten a scalar or array-like observation into a plain list of floats (scalar -> ``[value]``)."""
    return np.atleast_1d(np.asarray(state, dtype=float)).ravel().tolist()


def train_dqn(env, episodes=500, state_size=1, action_size=2, log_every=1, **agent_kwargs):
    """Train an LSTM-DQN agent on ``env`` and return the trained agent.

    ``env`` is driven through ``reset()`` (initial observation) and
    ``step(action)``, which must return ``(next_state, reward, done, info)``.

    Args:
        env: Environment to train on.
        episodes: Number of training episodes.
        state_size: Features per observation (default 1: the observation is a
            single entropy value).
        action_size: Number of discrete actions (default 2: actions {0, 1}).
        log_every: Print a progress line every ``log_every`` episodes and on the
            final episode; ``0`` or a negative value disables printing.
        **agent_kwargs: Forwarded to ``DQNAgent`` (e.g. ``lr``, ``gamma``,
            ``batch_size``, ``hidden_dim``).

    Returns:
        The trained ``DQNAgent``.
    """
    agent = DQNAgent(state_size, action_size, **agent_kwargs)

    for ep in range(episodes):
        state = _as_observation(env.reset())
        total_reward = 0
        done = False

        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            next_state = _as_observation(next_state)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward
            # One gradient step per environment step (no-op until the buffer fills).
            agent.replay()

        # Sync the target network once per episode.
        agent.update_target_model()
        if log_every > 0 and ((ep + 1) % log_every == 0 or ep + 1 == episodes):
            print(f"Episode {ep+1}/{episodes}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")

    return agent


In [None]:
# SBEOS training run is currently disabled; uncomment the call to train on it.
# A distinct name avoids reusing `env`, which the next cell rebinds to SBEDS.
sbeos_env = SBEOS_Environment()
# train_dqn(sbeos_env)

In [None]:
env = SBEDS_Environment()
# Capture train_dqn's return value for use in later cells instead of discarding it.
sbeds_agent = train_dqn(env)

Episode 1/500, Total Reward: 940, Epsilon: 0.4738
Episode 2/500, Total Reward: 1210, Epsilon: 0.1922
Episode 3/500, Total Reward: 1235, Epsilon: 0.0780
Episode 4/500, Total Reward: 1160, Epsilon: 0.0316
Episode 5/500, Total Reward: 1280, Epsilon: 0.0128
Episode 6/500, Total Reward: 1290, Epsilon: 0.0100
Episode 7/500, Total Reward: 1190, Epsilon: 0.0100
Episode 8/500, Total Reward: 1270, Epsilon: 0.0100
Episode 9/500, Total Reward: 1290, Epsilon: 0.0100
Episode 10/500, Total Reward: 1295, Epsilon: 0.0100
Episode 11/500, Total Reward: 1195, Epsilon: 0.0100
Episode 12/500, Total Reward: 1250, Epsilon: 0.0100
Episode 13/500, Total Reward: 1255, Epsilon: 0.0100
Episode 14/500, Total Reward: 1315, Epsilon: 0.0100
Episode 15/500, Total Reward: 1250, Epsilon: 0.0100
Episode 16/500, Total Reward: 1265, Epsilon: 0.0100
Episode 17/500, Total Reward: 1255, Epsilon: 0.0100
Episode 18/500, Total Reward: 1315, Epsilon: 0.0100
Episode 19/500, Total Reward: 1320, Epsilon: 0.0100
Episode 20/500, Total 