In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque
from scipy.stats import entropy
from environments import *
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
class DQN(nn.Module):
    """Q-network: a three-layer MLP with ReLU activations.

    Maps an input whose last dimension is ``state_dim`` to ``action_dim``
    unbounded outputs (one Q-value per discrete action).

    Args:
        state_dim: size of the input's last dimension.
        action_dim: number of discrete actions / output units.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        hidden_width = 64
        # Layers are created in the order fc1, fc2, fc3 so parameter
        # initialisation and state_dict keys stay stable.
        self.fc1 = nn.Linear(state_dim, hidden_width)
        self.fc2 = nn.Linear(hidden_width, hidden_width)
        self.fc3 = nn.Linear(hidden_width, action_dim)

    def forward(self, x):
        """Return Q-values shaped ``x.shape[:-1] + (action_dim,)``."""
        features = torch.relu(self.fc1(x))
        features = torch.relu(self.fc2(features))
        q_values = self.fc3(features)
        return q_values

# Replay Memory
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` transitions.

    Backed by ``deque(maxlen=capacity)``: once full, each new push evicts
    the oldest stored transition.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store a single transition."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random.

        Returns:
            A 5-tuple of numpy arrays ``(states, actions, rewards, next_states,
            dones)``, each stacked from the corresponding transition field.

        Raises:
            ValueError: if ``batch_size`` exceeds the number of stored
                transitions (propagated from ``random.sample``).
        """
        picked = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*picked)
        return tuple(np.array(field) for field in (states, actions, rewards, next_states, dones))

    def __len__(self):
        """Number of transitions currently held."""
        return len(self.buffer)

# Epsilon-greedy action selection
def select_action(state, policy_net, epsilon, action_dim):
    """Choose an action epsilon-greedily.

    With probability ``epsilon`` a uniformly random integer in
    ``[0, action_dim - 1]`` is returned. Otherwise ``state`` is converted to a
    float32 tensor, given a leading batch dimension, moved to the module-level
    ``device``, and the arg-max of ``policy_net``'s output is returned.

    Args:
        state: one observation (anything ``torch.tensor`` accepts).
        policy_net: network mapping a batch of states to per-action Q-values.
        epsilon: exploration probability.
        action_dim: number of discrete actions.

    Returns:
        int: the selected action index.
    """
    explore = random.random() < epsilon
    if explore:
        return random.randint(0, action_dim - 1)

    with torch.no_grad():
        batch = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        q_values = policy_net(batch)
        return q_values.argmax(dim=1).item()

# Training the DQN

def _to_obs_vector(state):
    """Return ``state`` as a float32 numpy array with at least one dimension (scalar -> shape ``(1,)``)."""
    return np.atleast_1d(np.asarray(state, dtype=np.float32))


def _optimize_model(policy_net, target_net, memory, optimizer, batch_size, gamma):
    """Perform one DQN gradient step on a minibatch sampled from ``memory``.

    Loss is the MSE between ``Q_policy(s, a)`` and the one-step TD target
    ``r + gamma * max_a' Q_target(s', a') * (1 - done)``.
    """
    states, actions, rewards, next_states, dones = memory.sample(batch_size)
    states = torch.tensor(states, dtype=torch.float32, device=device)
    actions = torch.tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
    rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
    next_states = torch.tensor(next_states, dtype=torch.float32, device=device)
    dones = torch.tensor(dones, dtype=torch.float32, device=device)

    # squeeze(1) rather than squeeze(): with batch_size == 1 a bare squeeze()
    # would produce a 0-d tensor that no longer matches the target's shape.
    q_values = policy_net(states).gather(1, actions).squeeze(1)

    # The bootstrap target is a constant for this update; computing it under
    # no_grad keeps gradients out of it and avoids building a graph for target_net.
    with torch.no_grad():
        next_q_values = target_net(next_states).max(dim=1).values
        target_q_values = rewards + gamma * next_q_values * (1.0 - dones)

    loss = nn.functional.mse_loss(q_values, target_q_values)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def train_dqn(env, episodes=1000, gamma=0.99, batch_size=32,
              lr=0.001, epsilon_start=1.0, epsilon_end=0.01,
              epsilon_decay=0.995, memory_size=10000, target_update=10,
              state_dim=1, action_dim=2):
    """Train a DQN agent with experience replay and a periodically synced target network.

    ``env`` is driven via ``env.reset() -> state`` and
    ``env.step(action) -> (next_state, reward, done, info)``. Each observation is
    converted to a 1-D float32 array, so both scalar observations (the default
    ``state_dim=1`` case) and flat vector observations of length ``state_dim``
    are supported.

    Args:
        env: environment following the reset/step interface above.
        episodes: number of training episodes.
        gamma: discount factor for the TD target.
        batch_size: minibatch size; updates start once the buffer holds more
            than this many transitions.
        lr: Adam learning rate.
        epsilon_start, epsilon_end, epsilon_decay: exploration schedule.
            Epsilon is multiplied by ``epsilon_decay`` after every episode and
            floored at ``epsilon_end``.
        memory_size: replay buffer capacity.
        target_update: copy the online weights into the target network every
            this many episodes.
        state_dim: length of the (flattened) observation. Defaults to 1 because
            the default environment's observation is a scalar entropy value.
        action_dim: number of discrete actions. Defaults to 2 (binary actions).

    Returns:
        The trained online (policy) network.
    """
    policy_net = DQN(state_dim, action_dim).to(device)
    target_net = DQN(state_dim, action_dim).to(device)
    target_net.load_state_dict(policy_net.state_dict())
    target_net.eval()  # only used to compute bootstrap targets

    optimizer = optim.Adam(policy_net.parameters(), lr=lr)
    memory = ReplayBuffer(memory_size)
    epsilon = epsilon_start

    for episode in range(episodes):
        state = _to_obs_vector(env.reset())
        done = False
        episode_reward = 0

        while not done:
            action = select_action(state, policy_net, epsilon, action_dim)
            raw_next_state, reward, done, _info = env.step(action)
            next_state = _to_obs_vector(raw_next_state)
            memory.push(state, action, reward, next_state, done)
            state = next_state
            episode_reward += reward

            # Learn only once the buffer holds more than one minibatch.
            if len(memory) > batch_size:
                _optimize_model(policy_net, target_net, memory, optimizer, batch_size, gamma)

        # Hard update of the target network every `target_update` episodes.
        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())

        epsilon = max(epsilon_end, epsilon * epsilon_decay)

        if episode % 100 == 0:
            print(f"Episode {episode}, Reward: {episode_reward}, Epsilon: {epsilon:.3f}")

    return policy_net


In [8]:
# env = SBEOS_Environment()
# trained_model = train_dqn(env, episodes=500)

In [None]:
# Seed every RNG the training loop uses (epsilon-greedy draws, replay
# sampling, network initialisation) so Restart & Run All is reproducible.
# Seeding happens before the environment is built in case it draws from the
# global numpy RNG.
# NOTE(review): if SBEDS_Environment keeps its own RNG, seed it here as well.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

env = SBEDS_Environment()
trained_model = train_dqn(env, episodes=500)

Episode 0, Reward: 670, Epsilon: 0.995
