## Offline Deep Q-Network

Resources used while writing this notebook:
- [d3rlpy library](https://github.com/takuseno/d3rlpy/blob/master/d3rlpy/algos/qlearning/dqn.py)

Check my [Kaggle notebook](https://www.kaggle.com/code/aryamanbansal/offline-dqn) for the outputs.

In [None]:
import torch
import numpy as np
from collections import namedtuple
import random
from typing import Tuple, Optional

import torch.nn as nn
import torch.optim as optim

# One experience sample: state, chosen action, reward, successor state and
# an end-of-episode flag.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'reward', 'next_state', 'done'],
)

class DQNNetwork(nn.Module):
    """Fully connected Q-function approximator.

    The module is a three-layer MLP: two hidden ``Linear`` layers of width
    ``hidden_dim`` each followed by ``ReLU``, and a final ``Linear`` layer
    with ``action_dim`` outputs (one value per action).
    """

    def __init__(self, state_dim: int, action_dim: int, hidden_dim: int = 256):
        super().__init__()
        # Layers are listed input-to-output and stored under ``self.network``
        # so parameter names remain ``network.<index>.weight/bias``.
        layers = [
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the resulting Q-values."""
        q_values = self.network(x)
        return q_values

class ReplayBuffer:
    """Fixed-capacity ring buffer holding ``Transition`` records."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # slot the next push() will write to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest slot once full."""
        still_growing = len(self.buffer) < self.capacity
        if still_growing:
            # Reserve the slot first; it is filled in right below.
            self.buffer.append(None)
        slot = self.position
        self.buffer[slot] = Transition(state, action, reward, next_state, done)
        # Advance the write cursor, wrapping around at capacity.
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size: int):
        """Return ``batch_size`` distinct stored entries chosen at random."""
        picked = random.sample(self.buffer, k=batch_size)
        return picked

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.buffer)

class OfflineDQN:
    """Offline Deep Q-Network implementation.

    Holds an online Q-network, a target network that is synchronised on
    request, an Adam optimiser over the online network and an MSE loss.
    Training consumes a fixed list of transition records (objects exposing
    ``state``, ``action``, ``reward``, ``next_state`` and ``done``); no
    environment interaction takes place.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        lr: float = 1e-3,
        gamma: float = 0.99,
        hidden_dim: int = 256,
        device: str = 'cpu'
    ):
        """Build both networks, the optimiser and the loss.

        Args:
            state_dim: Size of a flat state vector.
            action_dim: Number of discrete actions.
            lr: Learning rate passed to Adam.
            gamma: Discount factor used in the bootstrap target.
            hidden_dim: Width of the hidden layers of each network.
            device: Torch device string the networks and batches live on.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.device = device

        # Initialize Q-networks
        self.q_network = DQNNetwork(state_dim, action_dim, hidden_dim).to(device)
        self.target_network = DQNNetwork(state_dim, action_dim, hidden_dim).to(device)

        # Target starts as an exact copy of the online network
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Only the online network is optimised
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

        # Loss function
        self.criterion = nn.MSELoss()

    def update_target_network(self):
        """Update target network with current Q-network weights."""
        self.target_network.load_state_dict(self.q_network.state_dict())

    def _batch_to_tensors(self, batch: list):
        """Stack a list of transitions into per-field tensors on ``self.device``.

        Actions, rewards and done flags are forced to shape ``(batch,)`` so a
        dataset that stores them as length-1 arrays cannot trigger accidental
        broadcasting; anything that is not one scalar per transition raises
        ``ValueError`` from ``reshape``.
        """
        n = len(batch)
        # Going through a single ndarray avoids the slow element-wise path
        # torch takes when handed a Python list of arrays.
        states = np.asarray([t.state for t in batch], dtype=np.float32)
        next_states = np.asarray([t.next_state for t in batch], dtype=np.float32)
        actions = np.asarray([t.action for t in batch]).reshape(n)
        rewards = np.asarray([t.reward for t in batch], dtype=np.float32).reshape(n)
        dones = np.asarray([t.done for t in batch], dtype=bool).reshape(n)

        return (
            torch.as_tensor(states, device=self.device),
            torch.as_tensor(actions, dtype=torch.long, device=self.device),
            torch.as_tensor(rewards, device=self.device),
            torch.as_tensor(next_states, device=self.device),
            torch.as_tensor(dones, device=self.device),
        )

    def train_step(self, batch: list) -> float:
        """Perform one gradient step on ``batch`` and return the loss.

        Args:
            batch: Transition records to learn from.

        Returns:
            The scalar MSE loss as a Python float, or ``0.0`` for an empty
            batch (in which case no update is performed).
        """
        if len(batch) == 0:
            return 0.0

        states, actions, rewards, next_states, dones = self._batch_to_tensors(batch)

        # Q(s, a) for the action actually taken; squeeze only the gathered
        # column so a batch of one keeps shape (1,).
        current_q_values = (
            self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)
        )

        # Bootstrap target from the target network; no gradient flows here.
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(1)[0]
            # Terminal transitions contribute the reward only.
            not_done = (~dones).to(next_q_values.dtype)
            target_q_values = rewards + self.gamma * next_q_values * not_done

        # Compute loss
        loss = self.criterion(current_q_values, target_q_values)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def get_action(self, state, epsilon: float = 0.0):
        """Get action using epsilon-greedy policy.

        Args:
            state: A single state (sequence, ndarray or tensor).
            epsilon: Probability of returning a uniformly random action.

        Returns:
            The selected action index as a Python int.
        """
        if random.random() < epsilon:
            return random.randint(0, self.action_dim - 1)

        with torch.no_grad():
            state_tensor = torch.as_tensor(
                state, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            q_values = self.q_network(state_tensor)
            return q_values.argmax().item()

    def train_offline(
        self,
        dataset: list,
        batch_size: int = 64,
        num_epochs: int = 100,
        target_update_freq: int = 10
    ):
        """Train the DQN on an offline dataset.

        Args:
            dataset: Transition records; the list itself is left untouched.
            batch_size: Number of transitions per gradient step.
            num_epochs: Number of full passes over ``dataset``.
            target_update_freq: The target network is synchronised at the end
                of every epoch whose index is a multiple of this value.

        Returns:
            A list with the mean loss of each epoch.
        """
        losses = []

        # Shuffle a private copy so the caller's list keeps its order.
        shuffled = list(dataset)

        for epoch in range(num_epochs):
            random.shuffle(shuffled)

            epoch_losses = [
                self.train_step(shuffled[start:start + batch_size])
                for start in range(0, len(shuffled), batch_size)
            ]

            # Update target network
            if epoch % target_update_freq == 0:
                self.update_target_network()

            avg_loss = np.mean(epoch_losses) if epoch_losses else 0
            losses.append(avg_loss)

            if epoch % 10 == 0:
                print(f"Epoch {epoch}, Average Loss: {avg_loss:.4f}")

        return losses

    def save_model(self, path: str):
        """Save both networks and the optimiser state to ``path``."""
        torch.save({
            'q_network_state_dict': self.q_network.state_dict(),
            'target_network_state_dict': self.target_network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
        }, path)

    def load_model(self, path: str):
        """Load a checkpoint written by ``save_model``.

        Tensors are remapped onto ``self.device`` so a checkpoint saved on a
        GPU can be restored on a CPU-only machine.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)
        self.q_network.load_state_dict(checkpoint['q_network_state_dict'])
        self.target_network.load_state_dict(checkpoint['target_network_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
import d3rlpy
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

# Load CartPole dataset from d3rlpy
# get_cartpole() is unpacked into the offline dataset and an environment
# object; `env` is not referenced again in the visible cells.
dataset, env = d3rlpy.datasets.get_cartpole()

# Print a short summary of the loaded dataset.
# NOTE(review): what len(dataset) counts (episodes vs. transitions) depends on
# the d3rlpy dataset type -- confirm the "Dataset size" label is accurate.
print(f"Dataset size: {len(dataset)}")
# The first entry of the observation shape is reported as the state dimension.
# NOTE(review): get_observation_shape()/get_action_size() are assumed to be
# available on the installed d3rlpy version -- confirm.
print(f"State dimension: {dataset.get_observation_shape()[0]}")
print(f"Action dimension: {dataset.get_action_size()}")

# Flatten d3rlpy episodes into a list of Transition tuples
def convert_dataset_to_transitions(dataset):
    """Turn every episode of ``dataset`` into consecutive-step transitions.

    For each index ``i`` that has a following observation, one ``Transition``
    is built from ``observations[i]``, ``actions[i]``, ``rewards[i]`` and
    ``observations[i + 1]``. The transition built from the last such index of
    an episode is flagged ``done``.

    NOTE(review): if actions/rewards hold one entry per observation, the final
    entry of each episode is never emitted -- confirm this is intended.
    """
    transitions = []
    for episode in dataset.episodes:
        step_count = len(episode.observations) - 1
        transitions.extend(
            Transition(
                episode.observations[i],
                episode.actions[i],
                episode.rewards[i],
                episode.observations[i + 1],
                i == step_count - 1,
            )
            for i in range(step_count)
        )
    return transitions

# Flatten the d3rlpy episodes into a plain list of Transition tuples.
offline_data = convert_dataset_to_transitions(dataset)
print(f"Converted {len(offline_data)} transitions")

# Split dataset for evaluation
# 80/20 split over individual transitions with a fixed random_state.
# NOTE(review): eval_data is not used anywhere in the visible cells.
train_data, eval_data = train_test_split(offline_data, test_size=0.2, random_state=42)

# Initialize our Offline DQN
# Network sizes are taken from the dataset; the GPU is used when available.
state_dim = dataset.get_observation_shape()[0]
action_dim = dataset.get_action_size()
our_dqn = OfflineDQN(
    state_dim=state_dim,
    action_dim=action_dim,
    lr=1e-3,
    gamma=0.99,
    device='cuda' if torch.cuda.is_available() else 'cpu'
)

# Train our implementation
# Only the training split is used; our_losses holds one mean loss per epoch.
print("Training our Offline DQN...")
our_losses = our_dqn.train_offline(
    dataset=train_data,
    batch_size=64,
    num_epochs=100,
    target_update_freq=10
)

# Initialize d3rlpy DQN
print("\nTraining d3rlpy DQN...")
d3rlpy_dqn = d3rlpy.algos.DQNConfig().create(device='cuda:0' if torch.cuda.is_available() else 'cpu:0')

# Train d3rlpy implementation
# NOTE(review): this fits on the full `dataset`, whereas our DQN above only
# sees `train_data`, and the budgets are expressed differently (n_steps=10000
# vs. num_epochs=100) -- the two runs are not trained under matched conditions.
d3rlpy_dqn.fit(dataset, n_steps=10000, show_progress=True)

# Evaluate both models using FQE
print("\nEvaluating models using FQE...")

# Create FQE evaluator
# NOTE(review): confirm that FQEConfig is exposed under d3rlpy.algos in the
# installed version and that create() needs no target algorithm.
fqe = d3rlpy.algos.FQEConfig().create(device='cuda:0' if torch.cuda.is_available() else 'cpu:0')

# Evaluate d3rlpy DQN
# NOTE(review): the value returned by fit() is later formatted with ':.4f' and
# used as a bar height, i.e. treated as a single number -- confirm fit()
# really returns a scalar and accepts the policy as its second argument.
d3rlpy_value = fqe.fit(dataset, d3rlpy_dqn, n_steps=5000, show_progress=True)

# Adapter that exposes our DQN through a batch-style ``predict`` method
class OurDQNPolicy:
    """Greedy batch-prediction wrapper around an object with ``get_action``."""

    def __init__(self, dqn_model):
        self.dqn_model = dqn_model

    def predict(self, x):
        """Return an array holding the ``epsilon=0.0`` action for each state in ``x``."""
        return np.array(
            [self.dqn_model.get_action(state, epsilon=0.0) for state in x]
        )

# Wrap the trained agent so it can be queried through a batch predict() call.
our_policy = OurDQNPolicy(our_dqn)

# Minimal algorithm-shaped shim so the policy can be handed to FQE
class MockAlgorithm:
    """Object exposing only ``predict``, forwarded to the wrapped policy."""

    def __init__(self, policy):
        self.policy = policy

    def predict(self, x):
        """Delegate to ``self.policy.predict`` and return its result unchanged."""
        predicted = self.policy.predict(x)
        return predicted

# Instantiate the shim around our greedy policy.
mock_algo = MockAlgorithm(our_policy)

# Evaluate our implementation
# NOTE(review): this reuses the same `fqe` object that was already fitted for
# the d3rlpy policy; whether fit() re-initialises its state is not visible
# here -- confirm the second estimate is independent of the first.
our_value = fqe.fit(dataset, mock_algo, n_steps=5000, show_progress=True)

# Plot training losses for our implementation
# One figure with two side-by-side panels.
plt.figure(figsize=(12, 4))

# Left panel: mean training loss per epoch of our implementation.
plt.subplot(1, 2, 1)
plt.plot(our_losses)
plt.title('Our Offline DQN Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)

# Right panel: the two FQE results as bars.
# NOTE(review): assumes our_value and d3rlpy_value are plain numbers -- confirm.
plt.subplot(1, 2, 2)
plt.bar(['Our DQN', 'd3rlpy DQN'], [our_value, d3rlpy_value])
plt.title('FQE Evaluation Results')
plt.ylabel('Estimated Value')
plt.grid(True, axis='y')

plt.tight_layout()
plt.show()

# Text summary of both estimates and their absolute difference.
print(f"\nFinal Results:")
print(f"Our Offline DQN FQE Value: {our_value:.4f}")
print(f"d3rlpy DQN FQE Value: {d3rlpy_value:.4f}")
print(f"Difference: {abs(our_value - d3rlpy_value):.4f}")

hello world


In [None]:
# Prerequisites (must already exist in the notebook session; rename to match
# your own variables):
#   different_action_indices - positional row indices of the patients whose
#       physician action differs from the RL agent's action
#   healthcare_data          - dataframe with one column per feature
#       (the grid below is sized for 173 columns)

# If you don't have these variables yet, here's how you might create them:
# different_action_indices = [list of patient indices where physician != RL agent actions]
# healthcare_data = [your dataframe with 173 columns/features]

# A 14 x 13 grid yields 182 panels, enough for all 173 features
fig, axes = plt.subplots(nrows=14, ncols=13, figsize=(50, 40))
axes = axes.flatten()  # 1-D array so each panel is addressed by one index

# Rows of the patients whose actions differ
patients_with_diff_actions = healthcare_data.iloc[different_action_indices]

# One scatter panel per feature (patient index on x, feature value on y).
# zip() stops at the last available panel, so columns beyond the grid size
# are left unplotted.
for ax, column in zip(axes, healthcare_data.columns):
    ax.scatter(
        different_action_indices,
        patients_with_diff_actions[column],
        alpha=0.6,
        s=20,
    )
    ax.set_title(f'{column}', fontsize=8)
    ax.set_xlabel('Patient Index', fontsize=6)
    ax.set_ylabel('Feature Value', fontsize=6)
    ax.tick_params(labelsize=5)
    ax.grid(True, alpha=0.3)

# Switch off the panels that received no feature
for spare_ax in axes[len(healthcare_data.columns):]:
    spare_ax.set_visible(False)

plt.suptitle(
    'Feature Distributions for Patients with Different Physician vs RL Agent Actions',
    fontsize=16,
    y=0.98,
)
plt.tight_layout()
plt.subplots_adjust(top=0.95)  # leave room for the figure title
plt.show()