In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import torch
import torch.nn as nn
from collections import deque
from copy import deepcopy
import wandb
import random

import pandas as pd
from tqdm.auto import tqdm

In [3]:
def tumor_ratio(mask_path):
    """Return the fraction of mask entries labelled as tumor.

    Args:
        mask_path: Path to a ``.npy`` file holding the label mask.

    Returns:
        Mean of the boolean "label is 1 or 4" array, i.e. the share of
        entries whose label is 1 or 4.
    """
    labels = np.load(mask_path)
    is_tumor = np.isin(labels, [1, 4])
    return np.mean(is_tumor)

In [4]:
class Glioblastoma(gym.Env):
    """Grid-world over a single 2-D image slice and its label mask.

    The image is split into ``grid_size x grid_size`` square blocks. The agent
    starts in the top-left block, sees only the block it currently occupies,
    and can stay (0), move down (1) or move right (2). An episode lasts
    ``max_steps`` (20) steps.

    Args:
        image_path: Path to a ``.npy`` file with the 2-D image.
        mask_path: Path to a ``.npy`` file with the label mask (same shape).
        grid_size: Number of blocks per side.
        render_mode: ``"human"`` to enable :meth:`render`, otherwise ``None``.

    Raises:
        ValueError: If the image is not 2-D or the mask shape differs from
            the image shape.
    """
    metadata = {"render_modes": ["human"], "render_fps": 4}

    # Mask label values that count as tumor for the reward.
    TUMOR_LABELS = (1, 4)

    def __init__(self, image_path, mask_path, grid_size=4, render_mode=None):
        super().__init__()

        self.image = np.load(image_path).astype(np.float32)
        self.mask = np.load(mask_path).astype(np.uint8)

        # The observation space and the reward both index image and mask with
        # the same 2-D block coordinates, so a mismatch would silently
        # misalign rewards and observations.
        if self.image.ndim != 2:
            raise ValueError(
                f"expected a 2-D image, got shape {self.image.shape}"
            )
        if self.mask.shape != self.image.shape:
            raise ValueError(
                f"mask shape {self.mask.shape} does not match image shape {self.image.shape}"
            )

        # Min-max scale to [0, 1] whenever the data leaves that range, so the
        # observations honour the Box(0, 1) declared below. (Checking only
        # `max > 1` let images with negative intensities through unscaled.)
        img_min, img_max = self.image.min(), self.image.max()
        if img_min < 0.0 or img_max > 1.0:
            self.image = (self.image - img_min) / (img_max - img_min + 1e-8)

        self.grid_size = grid_size
        # Block side length is derived from the number of rows only.
        self.block_size = self.image.shape[0] // grid_size
        self.render_mode = render_mode

        # 0 = stay, 1 = down, 2 = right
        self.action_space = spaces.Discrete(3)

        # Observation: the single block the agent currently occupies.
        self.observation_space = spaces.Box(
            low=0, high=1,
            shape=(self.block_size, self.block_size),
            dtype=np.float32
        )

        # The agent always starts in the top-left block.
        self.agent_pos = [0, 0]
        self.current_step = 0
        self.max_steps = 20

    def reset(self, seed=None, options=None):
        """Put the agent back in the top-left block and return ``(obs, info)``."""
        super().reset(seed=seed)

        self.agent_pos = [0, 0]
        self.current_step = 0
        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        A move that would leave the grid is ignored (the position does not
        change), but the reward still treats it as a move because it depends
        on ``action`` rather than on the displacement.
        """
        self.current_step += 1
        prev_pos = self.agent_pos.copy()

        if action == 1 and self.agent_pos[0] < self.grid_size - 1:  # move down
            self.agent_pos[0] += 1
        elif action == 2 and self.agent_pos[1] < self.grid_size - 1:  # move right
            self.agent_pos[1] += 1
        # action == 0: stay still (no position change)

        reward = self._get_reward_paper_exact(action, prev_pos)
        obs = self._get_obs()

        # The step limit is reported through `terminated`; `truncated` is
        # always False.
        terminated = self.current_step >= self.max_steps
        truncated = False
        info = {}

        return obs, reward, terminated, truncated, info

    def _patch_slices(self, pos):
        """Return the (row, column) slices covering grid cell ``pos``."""
        r0 = pos[0] * self.block_size
        c0 = pos[1] * self.block_size
        return slice(r0, r0 + self.block_size), slice(c0, c0 + self.block_size)

    def _on_tumor(self, pos):
        """True if any mask entry inside grid cell ``pos`` has a tumor label."""
        rows, cols = self._patch_slices(pos)
        return bool(np.any(np.isin(self.mask[rows, cols], self.TUMOR_LABELS)))

    def _get_obs(self):
        """Return a float32 copy of the image block under the agent."""
        rows, cols = self._patch_slices(self.agent_pos)
        return self.image[rows, cols].astype(np.float32)

    def _get_reward_paper_exact(self, action, prev_pos):
        """Reward for the transition that just happened.

        * current block overlaps tumor (any action):  +1.0
        * off tumor and ``action == 0`` (stayed):     -2.0
        * off tumor and ``action != 0`` (moved):      -0.5

        Args:
            action: The action that was taken.
            prev_pos: Position before the action. Not used by the reward;
                kept so the call signature stays unchanged.
        """
        if self._on_tumor(self.agent_pos):
            return +1.0
        if action == 0:
            return -2.0
        return -0.5

    def render(self):
        """Show the slice with mask overlay, grid lines and the agent's block.

        Does nothing unless ``render_mode == "human"``.
        """
        if self.render_mode != "human":
            return

        vis_img = np.stack([self.image] * 3, axis=-1).astype(np.float32)
        tumor_overlay = np.zeros_like(vis_img)
        # NOTE(review): the overlay marks every non-zero label, whereas the
        # reward only counts TUMOR_LABELS - confirm this difference is intended.
        tumor_overlay[..., 0] = (self.mask > 0).astype(float)

        alpha = 0.4
        vis_img = (1 - alpha) * vis_img + alpha * tumor_overlay

        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(vis_img, cmap='gray', origin='upper')

        # Draw grid
        for i in range(1, self.grid_size):
            ax.axhline(i * self.block_size, color='white', lw=1, alpha=0.5)
            ax.axvline(i * self.block_size, color='white', lw=1, alpha=0.5)

        # Draw agent position
        r0 = self.agent_pos[0] * self.block_size
        c0 = self.agent_pos[1] * self.block_size
        rect = patches.Rectangle(
            (c0, r0),
            self.block_size,
            self.block_size,
            linewidth=2,
            edgecolor='yellow',
            facecolor='none'
        )
        ax.add_patch(rect)

        ax.set_title(f"Agent at {self.agent_pos} | Step {self.current_step}")
        ax.axis('off')
        plt.show()
        # Release the figure so calling render() every step does not pile up
        # open figures.
        plt.close(fig)



In [5]:
class DQN(torch.nn.Module):
    """Convolutional Q-network: 4 conv layers followed by 3 hidden FC layers.

    Input is a batch of single-channel patches ``(B, 1, H, W)``; output is one
    Q-value per action, shape ``(B, n_actions)``.

    Args:
        env: Environment providing ``action_space.n`` and, when available,
            ``observation_space.shape`` (the ``(H, W)`` patch size). Without
            an observation space a 60x60 patch is assumed.
        learning_rate: Learning rate of the Adam optimizer stored on the
            module as ``self.optimizer``.
        device: Anything accepted by ``Module.to`` (``'cpu'``, ``'cuda'``,
            ``'cuda:0'``, a ``torch.device`` ...).
    """

    def __init__(self, env, learning_rate=1e-4, device='cpu'):
        super(DQN, self).__init__()
        self.device = device
        self.n_outputs = env.action_space.n
        self.actions = np.arange(env.action_space.n)
        self.learning_rate = learning_rate

        # 4 stride-2 conv layers with 32 channels each; every layer roughly
        # halves the spatial size (60 -> 30 -> 15 -> 8 -> 4 for a 60x60 patch).
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
        )

        # Size the first FC layer from the env's real patch shape instead of a
        # hard-coded 60x60, so other image sizes / grid sizes work too.
        obs_space = getattr(env, "observation_space", None)
        obs_shape = tuple(obs_space.shape) if obs_space is not None else (60, 60)
        with torch.no_grad():
            dummy_input = torch.zeros(1, 1, *obs_shape)
            n_flatten = self.features(dummy_input).view(1, -1).size(1)

        # FC head: n_flatten -> 512 -> 256 -> 128 -> n_actions
        self.fc = nn.Sequential(
            nn.Linear(n_flatten, 512),
            nn.ELU(),
            nn.Linear(512, 256),
            nn.ELU(),
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Linear(128, self.n_outputs)
        )

        # Move to the target device before building the optimizer. `.to`
        # handles every device spelling, unlike the former `== 'cuda'` test
        # which left e.g. 'cuda:0' models on the CPU.
        self.to(self.device)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def forward(self, x):
        """Map a ``(B, 1, H, W)`` tensor to ``(B, n_actions)`` Q-values."""
        features = self.features(x)
        features_flat = features.view(x.size(0), -1)
        q_values = self.fc(features_flat)
        return q_values

    def get_action(self, state, epsilon=0.05):
        """Epsilon-greedy action for a single state.

        With probability ``epsilon`` a uniformly random action is returned,
        otherwise the arg-max of the Q-values. Always returns a Python ``int``.
        """
        if np.random.random() < epsilon:
            return int(np.random.choice(self.actions))
        # Action selection needs no gradients.
        with torch.no_grad():
            qvals = self.get_qvals(state)
        return torch.argmax(qvals, dim=-1).item()

    def get_qvals(self, state):
        """Return Q-values for ``state`` (array or tensor) on ``self.device``.

        Accepted layouts:
            * ``(H, W)``        - single patch, becomes ``(1, 1, H, W)``
            * ``(B, H, W)``     - batch of patches, becomes ``(B, 1, H, W)``
            * ``(B, 1, H, W)``  - used as is
        """
        # as_tensor works for both arrays and tensors (on any device) and
        # avoids a copy when dtype/device already match.
        state_t = torch.as_tensor(state, dtype=torch.float32, device=self.device)
        if state_t.dim() == 2:
            state_t = state_t.unsqueeze(0).unsqueeze(0)
        elif state_t.dim() == 3:
            state_t = state_t.unsqueeze(1)
        return self.forward(state_t)


class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, done, next_state).

    Once ``capacity`` transitions are stored, appending drops the oldest one.
    """

    # Column dtypes of a sampled batch, in transition-field order.
    _FIELD_DTYPES = (np.float32, np.int64, np.float32, np.bool_, np.float32)

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)

    def append(self, state, action, reward, done, next_state):
        """Store one transition; both states are copied before storing."""
        transition = (state.copy(), action, reward, done, next_state.copy())
        self.buffer.append(transition)

    def sample_batch(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, dones, next_states)`` of numpy
            arrays with dtypes float32, int64, float32, bool and float32.
        """
        picked = random.sample(self.buffer, batch_size)
        columns = zip(*picked)
        return tuple(
            np.array(column, dtype=dtype)
            for column, dtype in zip(columns, self._FIELD_DTYPES)
        )


In [6]:
class DQNAgent:
    """DQN trainer with an online network, a target network and one replay buffer.

    Args:
        env: Initial environment; replaced by a fresh ``Glioblastoma`` for
            every image during :meth:`train`.
        dnnetwork: Online Q-network. It must expose ``get_action`` and
            ``optimizer`` and be callable on a ``(B, 1, H, W)`` tensor.
        train_pairs: Sequence of ``(image_path, mask_path)`` tuples.
        epsilon: Initial exploration rate.
        eps_decay: Amount subtracted from epsilon after every episode.
        epsilon_min: Lower bound for epsilon.
        batch_size: Number of transitions per gradient step.
        gamma: Discount factor.
    """

    def __init__(self, env, dnnetwork, train_pairs, epsilon=0.7, eps_decay=1e-4, 
                 epsilon_min=1e-4, batch_size=128, gamma=0.99):
        self.env = env
        self.dnnetwork = dnnetwork
        # The target network is a frozen copy; it is never optimised directly.
        self.target_network = deepcopy(dnnetwork)
        self.target_network.optimizer = None

        # Single replay buffer shared by all images.
        self.replay_buffer = ReplayBuffer(capacity=15000)

        self.epsilon = epsilon
        self.eps_decay = eps_decay
        self.epsilon_min = epsilon_min
        self.batch_size = batch_size
        self.gamma = gamma
        self.train_pairs = train_pairs

        self.initialize()

    def initialize(self):
        """Reset training statistics and counters."""
        self.update_loss = []
        self.training_rewards = []
        self.mean_training_rewards = []
        self.total_reward = 0
        self.step_count = 0
        self.state0 = None

    def _load_env(self, img_path, mask_path):
        """Swap in a fresh environment for one image and reset it.

        The grid size of the current environment is carried over (4 if the
        current environment does not define one).
        """
        grid_size = getattr(self.env, "grid_size", 4)
        self.env = Glioblastoma(img_path, mask_path, grid_size=grid_size)
        self.state0, _ = self.env.reset()

    def take_step(self, eps, mode='train'):
        """Act once in the environment and store the transition.

        Args:
            eps: Exploration rate used in ``'train'`` mode.
            mode: ``'explore'`` samples a uniformly random action and does not
                advance ``step_count``; any other value uses the epsilon-greedy
                policy of the online network and increments ``step_count``.

        Returns:
            True if the episode ended with this step (the environment has
            then already been reset).
        """
        if mode == 'explore':
            action = self.env.action_space.sample()
        else:
            action = self.dnnetwork.get_action(self.state0, eps)
            self.step_count += 1

        new_state, reward, terminated, truncated, _ = self.env.step(action)
        done = terminated or truncated
        self.total_reward += reward

        self.replay_buffer.append(self.state0, action, reward, done, new_state)
        self.state0 = new_state.copy()

        if done:
            self.state0 = self.env.reset()[0]
        return done

    def train(self, gamma=None, max_episodes=90, 
              dnn_update_frequency=4, dnn_sync_frequency=2000, warmup_steps=500):
        """Fill the replay buffer with random play, then train.

        Args:
            gamma: Discount factor. ``None`` (default) keeps the value given
                to the constructor; previously the default of 0.99 silently
                overwrote it.
            max_episodes: Number of training episodes, each on a randomly
                chosen training image.
            dnn_update_frequency: Do one gradient step every this many
                learning steps.
            dnn_sync_frequency: Copy the online weights into the target
                network every this many learning steps.
            warmup_steps: Random-action steps per training image used to
                pre-fill the replay buffer.

        Note:
            ``step_count`` only advances in ``'train'`` mode. With the default
            90 episodes of 20 steps (1800 steps) it never reaches the default
            ``dnn_sync_frequency`` of 2000, so the target network keeps its
            initial weights for the whole run unless a smaller value is passed.
        """
        if gamma is not None:
            self.gamma = gamma

        # Warm-up: random experiences from every image go into the one buffer.
        print("Filling replay buffer...")
        for img_path, mask_path in self.train_pairs:
            self._load_env(img_path, mask_path)
            for _ in range(warmup_steps):
                self.take_step(self.epsilon, mode='explore')

        print(f"Training for {max_episodes} episodes as in paper...")
        # The context manager closes the bar even if training raises.
        with tqdm(total=max_episodes, desc="Training") as pbar:
            for episode in range(1, max_episodes + 1):
                img_path, mask_path = random.choice(self.train_pairs)
                self._load_env(img_path, mask_path)
                self.total_reward = 0

                for _ in range(self.env.max_steps):
                    done = self.take_step(self.epsilon, mode='train')

                    if (self.step_count % dnn_update_frequency == 0
                            and len(self.replay_buffer.buffer) >= self.batch_size):
                        self.update()

                    if self.step_count % dnn_sync_frequency == 0:
                        self.target_network.load_state_dict(self.dnnetwork.state_dict())

                    if done:
                        break

                pbar.update(1)
                self.training_rewards.append(self.total_reward)

                # Linear decay applied once per episode: epsilon drops by
                # eps_decay each episode, floored at epsilon_min. How far it
                # gets depends on eps_decay (0.7 -> 0.691 after 90 episodes
                # when eps_decay is 1e-4).
                self.epsilon = max(self.epsilon_min, self.epsilon - self.eps_decay)

                if episode % 10 == 0:
                    mean_reward = np.mean(self.training_rewards[-10:])
                    # update_loss only holds the current episode's losses
                    # because it is cleared at the end of every episode.
                    current_loss = np.mean(self.update_loss) if self.update_loss else 0
                    print(f"Episode {episode:3d} | Mean Reward: {mean_reward:7.2f} | Epsilon: {self.epsilon:.4f} | Loss: {current_loss:.4f}")

                    wandb.log({
                        'episode': episode,
                        'mean_reward': mean_reward,
                        'epsilon': self.epsilon,
                        'loss': current_loss
                    })

                self.update_loss = []

        print("Training completed")

    def calculate_loss(self, batch):
        """MSE between Q(s, a) and the one-step TD target.

        Args:
            batch: ``(states, actions, rewards, dones, next_states)`` as
                returned by ``ReplayBuffer.sample_batch``; states have shape
                ``(B, H, W)``.

        Returns:
            Scalar loss tensor.
        """
        states, actions, rewards, dones, next_states = batch
        # Every tensor lives on the network's device; leaving rewards, actions
        # and dones on the CPU breaks gather/indexing for non-CPU networks.
        device = getattr(self.dnnetwork, "device", "cpu")

        # Add the channel dimension expected by the CNN: (B, H, W) -> (B, 1, H, W)
        states = torch.as_tensor(states, dtype=torch.float32, device=device).unsqueeze(1)
        next_states = torch.as_tensor(next_states, dtype=torch.float32, device=device).unsqueeze(1)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device)
        actions = torch.as_tensor(actions, dtype=torch.int64, device=device).unsqueeze(1)
        dones = torch.as_tensor(dones, dtype=torch.bool, device=device)

        # Q-value of the action actually taken. squeeze(1) keeps the batch
        # dimension even for a batch of one.
        current_q = self.dnnetwork(states).gather(1, actions).squeeze(1)

        with torch.no_grad():
            next_q = self.target_network(next_states).max(1)[0]
            # No bootstrapping from states that ended the episode.
            next_q = next_q.masked_fill(dones, 0.0)
            target_q = rewards + self.gamma * next_q

        return nn.MSELoss()(current_q, target_q)

    def update(self):
        """One optimisation step on a random batch (no-op until enough data)."""
        if len(self.replay_buffer.buffer) < self.batch_size:
            return

        batch = self.replay_buffer.sample_batch(self.batch_size)
        loss = self.calculate_loss(batch)

        self.dnnetwork.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0.
        torch.nn.utils.clip_grad_norm_(self.dnnetwork.parameters(), 1.0)
        self.dnnetwork.optimizer.step()

        self.update_loss.append(loss.item())

In [7]:
# Hyperparameters for the run (values this notebook attributes to the paper).

# --- optimisation ---
LR = 1e-4
BATCH_SIZE = 128
GAMMA = 0.99

# --- replay memory ---
# Not referenced by the visible cells: DQNAgent hard-codes the same capacity.
MEMORY_SIZE = 15000

# --- training length ---
MAX_EPISODES = 90

# --- epsilon-greedy exploration ---
EPSILON_START = 0.7
EPSILON_MIN = 1e-4
# Subtracted from epsilon once per episode, so 90 episodes take epsilon from
# 0.7 to 0.691.
EPSILON_DECAY = 1e-4

In [8]:
# Load your data
# NOTE: absolute, machine-specific paths - adjust them when running elsewhere.
base_dir = "/home/martina/codi2/4year/tfg/training_set_npy"
csv_path = "/home/martina/codi2/4year/tfg/training_dataset_slices.csv"

df = pd.read_csv(csv_path)

# Build the "<base_dir>/<Patient:03d>_<SliceIndex>" stems column-wise.
# Iterating the two columns directly keeps each column's own dtype; a row-wise
# df.apply(axis=1) turns every row into one Series with a common dtype, which
# can break the `:03d` format when the CSV also has float columns.
slice_stems = [
    os.path.join(base_dir, f"{patient:03d}_{slice_index}")
    for patient, slice_index in zip(df["Patient"], df["SliceIndex"])
]
df["image_path"] = [f"{stem}.npy" for stem in slice_stems]
df["mask_path"] = [f"{stem}_mask.npy" for stem in slice_stems]

# Keep only the slices whose image and mask files both exist.
# Optional extra filter, currently disabled: `and tumor_ratio(mask) >= 0.01`
# keeps slices with at least 1% tumor.
train_pairs = [
    (img, mask)
    for img, mask in zip(df["image_path"], df["mask_path"])
    if os.path.exists(img) and os.path.exists(mask)
]

print(f"Found {len(train_pairs)} training pairs")

Found 30 training pairs


In [9]:
# Seed the Python, NumPy and PyTorch RNGs so weight initialisation,
# epsilon-greedy action choice, image sampling and replay sampling are
# repeatable across runs.
# NOTE: warm-up actions come from env.action_space.sample(), which uses the
# space's own RNG and is not covered by these seeds.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the environment (first training slice), the Q-network and the agent.
env = Glioblastoma(*train_pairs[0], grid_size=4)
net = DQN(env, learning_rate=LR, device='cpu')
agent = DQNAgent(env, net, train_pairs,
                epsilon=EPSILON_START, 
                eps_decay=EPSILON_DECAY,
                epsilon_min=EPSILON_MIN,
                batch_size=BATCH_SIZE, 
                gamma=GAMMA)

# Start the Weights & Biases run that DQNAgent.train() logs to.
wandb.init(project="TFG_ExactPaperReplication", config={
    "lr": LR, "episodes": MAX_EPISODES, "epsilon_start": EPSILON_START,
    "epsilon_decay": EPSILON_DECAY, "epsilon_min": EPSILON_MIN,
    "batch_size": BATCH_SIZE, "gamma": GAMMA, "seed": SEED
})

[34m[1mwandb[0m: Currently logged in as: [33mmartinacarrettab[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Pre-fill the replay buffer with random play on every training image, then
# train for MAX_EPISODES episodes (progress is printed and logged to wandb
# every 10 episodes).
agent.train(max_episodes=MAX_EPISODES)

Filling replay buffer...
Training for 90 episodes as in paper...


Training:   0%|          | 0/90 [00:00<?, ?it/s]

Episode  10 | Mean Reward:  -17.35 | Epsilon: 0.6990 | Loss: 0.2671
Episode  20 | Mean Reward:  -13.90 | Epsilon: 0.6980 | Loss: 0.2139
Episode  30 | Mean Reward:  -13.60 | Epsilon: 0.6970 | Loss: 0.2352
Episode  40 | Mean Reward:  -13.60 | Epsilon: 0.6960 | Loss: 0.2369
Episode  50 | Mean Reward:  -17.35 | Epsilon: 0.6950 | Loss: 0.1822
Episode  60 | Mean Reward:  -16.15 | Epsilon: 0.6940 | Loss: 0.1987
Episode  70 | Mean Reward:  -15.40 | Epsilon: 0.6930 | Loss: 0.2298
Episode  80 | Mean Reward:  -15.40 | Epsilon: 0.6920 | Loss: 0.2048
Episode  90 | Mean Reward:  -16.30 | Epsilon: 0.6910 | Loss: 0.2368
Training completed


In [11]:
# Close the wandb run started by wandb.init() above.
wandb.finish()

0,1
episode,▁▂▃▄▅▅▆▇█
epsilon,█▇▆▅▅▄▃▂▁
loss,█▄▅▆▁▂▅▃▆
mean_reward,▁▇██▁▃▅▅▃

0,1
episode,90.0
epsilon,0.691
loss,0.23681
mean_reward,-16.3


In [12]:
def _agent_on_tumor(env):
    """True if the block under the agent contains a tumor label (1 or 4)."""
    r0 = env.agent_pos[0] * env.block_size
    c0 = env.agent_pos[1] * env.block_size
    patch_mask = env.mask[r0:r0 + env.block_size, c0:c0 + env.block_size]
    return bool(np.any(np.isin(patch_mask, [1, 4])))


def evaluate_agent(agent, test_pairs, num_episodes=30):
    """Greedy evaluation: share of images on which the agent reaches tumor.

    For every ``(image_path, mask_path)`` pair one episode of
    ``env.max_steps`` steps is run with epsilon = 0. The image counts as
    correct as soon as the agent's block overlaps tumor after a step; the
    starting block is not checked before the first step.

    Args:
        agent: Object whose ``dnnetwork.get_action(state, epsilon=...)``
            returns an action.
        test_pairs: Sized iterable of ``(image_path, mask_path)`` tuples.
        num_episodes: Unused; kept so existing calls keep working.

    Returns:
        Accuracy in [0, 1]; 0.0 when ``test_pairs`` is empty.
    """
    total = len(test_pairs)
    if total == 0:
        # Avoid dividing by zero when there is nothing to evaluate.
        print("Test Accuracy: n/a (0/0)")
        return 0.0

    correct = 0
    for img_path, mask_path in test_pairs:
        env = Glioblastoma(img_path, mask_path, grid_size=4)
        state, _ = env.reset()

        for _ in range(env.max_steps):
            # Pure inference: no exploration and no autograd bookkeeping.
            with torch.no_grad():
                action = agent.dnnetwork.get_action(state, epsilon=0.0)
            state, _reward, _terminated, _truncated, _ = env.step(action)

            if _agent_on_tumor(env):
                correct += 1
                break
        # No extra "final position" check is needed: the position after the
        # last step was already tested inside the loop.

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.1%} ({correct}/{total})")
    return accuracy

In [13]:
# NOTE: absolute, machine-specific paths - adjust them when running elsewhere.
base_dir = "/home/martina/codi2/4year/tfg/training_set_npy"
csv_path = "/home/martina/codi2/4year/tfg/training_dataset_slices.csv"

df = pd.read_csv(csv_path)

# Build the "<base_dir>/<Patient:03d>_<SliceIndex>" stems column-wise.
# Iterating the two columns directly keeps each column's own dtype; a row-wise
# df.apply(axis=1) turns every row into one Series with a common dtype, which
# can break the `:03d` format when the CSV also has float columns.
slice_stems = [
    os.path.join(base_dir, f"{patient:03d}_{slice_index}")
    for patient, slice_index in zip(df["Patient"], df["SliceIndex"])
]
df["image_path"] = [f"{stem}.npy" for stem in slice_stems]
df["mask_path"] = [f"{stem}_mask.npy" for stem in slice_stems]

# Keep only the slices whose image and mask files both exist.
train_pairs = [
    (img, mask)
    for img, mask in zip(df["image_path"], df["mask_path"])
    if os.path.exists(img) and os.path.exists(mask)
]

print(f"Found {len(train_pairs)} training pairs")

# Evaluate
# NOTE: these pairs come from the same CSV and directory as the training
# data, so the figure below is accuracy on training slices, not on a
# held-out test set.
test_accuracy = evaluate_agent(agent, train_pairs[:30], 30)  # Using first 30 as test

Found 30 training pairs
Test Accuracy: 23.3% (7/30)
