In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical
import gymnasium as gym
from collections import deque
import random
import matplotlib.pyplot as plt
import os
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import pandas as pd
from gymnasium import spaces


In [None]:
class Glioblastoma(gym.Env):
    """Grid-world environment over one 2D MRI slice and its label mask.

    The slice is split into ``grid_size`` x ``grid_size`` square patches of
    ``block_size = image.shape[0] // grid_size`` pixels (leftover rows/columns
    are ignored). The agent starts on the top-left patch, moves between patches
    and is rewarded when its current patch contains tumor (mask labels 1 or 4).
    Episodes last ``max_steps`` steps.

    Observations are dicts: ``'patch'`` (float32 image patch under the agent)
    and ``'position'`` (int32 ``[row, col]`` grid coordinates).
    """

    metadata = {"render_modes": ["human"], "render_fps": 4}

    # Action id -> (row delta, col delta). Discrete(3) uses ids 1-2, Discrete(5)
    # ids 1-4 and Discrete(9) ids 1-8; id 0 (or any id outside the space) means
    # "stay in place".
    _MOVES = {
        1: (1, 0),    # down
        2: (0, 1),    # right
        3: (-1, 0),   # up
        4: (0, -1),   # left
        5: (1, 1),    # down-right
        6: (-1, 1),   # up-right
        7: (1, -1),   # down-left
        8: (-1, -1),  # up-left
    }
    # Action-space sizes with defined movement semantics; other sizes never move.
    _SUPPORTED_N_ACTIONS = (3, 5, 9)

    def __init__(self, image_path, mask_path, grid_size=4, tumor_threshold=0.0001, rewards = [1.0, -2.0, -0.5], action_space=spaces.Discrete(3), render_mode="human"):
        """
        Args:
            image_path: path to a 2D ``.npy`` image; rescaled to [0, 1] when its
                maximum exceeds 1.
            mask_path: path to the matching 2D ``.npy`` label mask.
            grid_size: number of patches per side.
            tumor_threshold: minimum fraction (0-1) of tumor pixels for a patch
                to count as "on tumor". The default 1e-4 is below 1/3600, so for
                60x60 patches a single tumor pixel is enough.
            rewards: ``[on_tumor, stay_without_tumor, move_without_tumor]``.
            action_space: ``spaces.Discrete`` with 3, 5 or 9 actions (see ``_MOVES``).
            render_mode: only ``"human"`` renders anything.
        """
        super().__init__()

        self.image = np.load(image_path).astype(np.float32)
        self.mask = np.load(mask_path).astype(np.uint8)

        img_min, img_max = self.image.min(), self.image.max()
        if img_max > 1.0:  # only normalize if not already in [0, 1]
            self.image = (self.image - img_min) / (img_max - img_min + 1e-8)  # epsilon avoids division by 0

        self.grid_size = grid_size
        self.block_size = self.image.shape[0] // grid_size  # e.g. 240 // 4 = 60

        self.action_space = action_space
        self.tumor_threshold = tumor_threshold
        # Copy so the shared default list (or the caller's list) is never aliased.
        self.rewards = list(rewards)  # [reward_on_tumor, reward_stay_no_tumor, reward_move_no_tumor]

        self.render_mode = render_mode

        # Patches are expected in [0, 1]: networks train better on small,
        # consistent input ranges than on raw 0-255 values.
        self.observation_space = spaces.Dict({
            'patch': spaces.Box(low=0, high=1, shape=(self.block_size, self.block_size), dtype=np.float32),
            'position': spaces.Box(low=0, high=self.grid_size-1, shape=(2,), dtype=np.int32)
        })

        self.agent_pos = [0, 0]  # [row, col]; starts at the top-left patch
        self.current_step = 0
        self.max_steps = 20  # like in the paper

    def reset(self, seed=None, options=None):
        """Start a new episode on the top-left patch and return ``(obs, info)``.

        ``gym.Env.reset`` only (re)seeds ``self.np_random`` and returns ``None``,
        so its result must not be unpacked.
        """
        super().reset(seed=seed, options=options)
        self.agent_pos = [0, 0]
        self.current_step = 0
        return self._get_obs(), {}

    def _patch_slices(self):
        """Row and column slices of the image/mask patch under the agent."""
        r0 = self.agent_pos[0] * self.block_size
        c0 = self.agent_pos[1] * self.block_size
        return slice(r0, r0 + self.block_size), slice(c0, c0 + self.block_size)

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        Moves that would leave the grid are ignored (the agent stays where it is).
        The episode terminates once ``max_steps`` steps have been taken.
        """
        self.current_step += 1

        prev_pos = self.agent_pos.copy()  # to tell "stayed" from "moved" in the reward

        # Accept numpy scalars / 0-d arrays / tensors as well as plain ints.
        action_id = action.item() if hasattr(action, "item") else action
        n_actions = getattr(self.action_space, "n", None)
        delta = None
        if n_actions in self._SUPPORTED_N_ACTIONS and 0 < action_id < n_actions:
            delta = self._MOVES.get(action_id)
        if delta is not None:
            new_row = self.agent_pos[0] + delta[0]
            new_col = self.agent_pos[1] + delta[1]
            if 0 <= new_row < self.grid_size and 0 <= new_col < self.grid_size:
                self.agent_pos[0], self.agent_pos[1] = new_row, new_col

        reward = self._get_reward(action, prev_pos)
        obs = self._get_obs()

        terminated = self.current_step >= self.max_steps
        truncated = False  # the step limit is modelled as termination
        info = {}

        return obs, reward, terminated, truncated, info

    def _get_obs(self):
        """Observation dict for the agent's current patch."""
        rows, cols = self._patch_slices()
        return {
            'patch': self.image[rows, cols].astype(np.float32),
            'position': np.array(self.agent_pos, dtype=np.int32)
        }

    def _get_reward(self, action, prev_pos):
        """Reward for the patch the agent occupies after the move.

        * ``rewards[0]`` if the fraction of tumor pixels (labels 1 and 4; label 2,
          edema, is not counted) reaches ``tumor_threshold``;
        * otherwise ``rewards[1]`` if the agent did not move (action 0, or a move
          blocked by the grid edge), else ``rewards[2]``.
        """
        rows, cols = self._patch_slices()
        patch_mask = self.mask[rows, cols]

        tumor_count = np.sum(np.isin(patch_mask, [1, 4]))
        total = self.block_size * self.block_size
        inside = (tumor_count / total) >= self.tumor_threshold

        if inside:
            return self.rewards[0]
        if action == 0 or prev_pos == self.agent_pos:
            return self.rewards[1]
        return self.rewards[2]

    def render(self):
        """Show the slice with the lesion tinted red, the grid and the agent's patch."""
        if self.render_mode != "human":  # e.g. rgb_array / ansi are not implemented
            return

        # Grayscale -> RGB so the mask can be drawn in colour.
        vis_img = np.stack([self.image] * 3, axis=-1).astype(np.float32)

        # Red channel marks every non-zero mask label (tumor and edema).
        tumor_overlay = np.zeros_like(vis_img)
        tumor_overlay[..., 0] = (self.mask > 0).astype(float)

        alpha = 0.4  # overlay opacity
        vis_img = (1 - alpha) * vis_img + alpha * tumor_overlay

        fig, ax = plt.subplots(figsize=(3, 3))
        ax.imshow(vis_img, cmap='gray', origin='upper')

        # Semi-transparent grid lines between patches.
        for i in range(1, self.grid_size):
            ax.axhline(i * self.block_size, color='white', lw=1, alpha=0.5)
            ax.axvline(i * self.block_size, color='white', lw=1, alpha=0.5)

        r0 = self.agent_pos[0] * self.block_size
        c0 = self.agent_pos[1] * self.block_size
        rect = patches.Rectangle(
            (c0, r0),  # (x, y): top-left corner on screen because origin='upper'
            self.block_size,  # width
            self.block_size,  # height
            linewidth=2,
            edgecolor='yellow',
            facecolor='none'
        )
        ax.add_patch(rect)

        ax.set_title(f"Agent at {self.agent_pos} | Step {self.current_step}")
        ax.axis('off')
        plt.show()
        plt.close(fig)  # release the figure so repeated renders do not accumulate

    def current_patch_overlap_with_lesion(self):
        """Count non-zero mask pixels (any label, edema included) in the agent's patch.

        A value > 0 means the agent is over the lesion. This is broader than the
        training reward, which only counts labels 1 and 4.
        """
        rows, cols = self._patch_slices()
        return np.sum(self.mask[rows, cols] > 0)


In [None]:
def prepare(mode = "train", base_dir=None, csv_path=None):
    """Build the (image_path, mask_path) pairs for one dataset split.

    Reads the split CSV (columns ``Patient`` and ``SliceIndex``) and derives the
    expected files ``{Patient:03d}_{SliceIndex}.npy`` and
    ``{Patient:03d}_{SliceIndex}_mask.npy`` inside ``base_dir``. Only pairs whose
    image and mask both exist on disk are kept.

    Args:
        mode: "train" selects the training split; any other value selects the
            testing split.
        base_dir: optional override for the directory holding the .npy files.
        csv_path: optional override for the split CSV.

    Returns:
        List of ``(image_path, mask_path)`` tuples in CSV order.
    """
    if mode == "train":
        default_dir = "/home/martina/codi2/4year/tfg/training_set_npy"
        default_csv = "/home/martina/codi2/4year/tfg/set_training.csv"
    else:
        default_dir = "/home/martina/codi2/4year/tfg/testing_set_npy"
        default_csv = "/home/martina/codi2/4year/tfg/set_testing.csv"
    if base_dir is None:
        base_dir = default_dir
    if csv_path is None:
        csv_path = default_csv

    df = pd.read_csv(csv_path)

    # Iterate the two columns directly: a row-wise apply would upcast an
    # all-numeric row to float (breaking ":03d") and fails on an empty CSV.
    pairs = []
    for patient, slice_index in zip(df["Patient"], df["SliceIndex"]):
        stem = f"{int(patient):03d}_{slice_index}"
        img = os.path.join(base_dir, f"{stem}.npy")
        mask = os.path.join(base_dir, f"{stem}_mask.npy")
        if os.path.exists(img) and os.path.exists(mask):
            pairs.append((img, mask))

    print(f"✅ Found {len(pairs)} pairs out of {len(df)} listed in CSV.")
    missing = len(df) - len(pairs)
    if missing:
        print(f"⚠️ {missing} listed slice(s) are missing an image or mask file in {base_dir}.")
    return pairs

In [None]:
class GlobalAwarePPOActorCritic(nn.Module):
    """Actor-critic network for dict observations ``{'patch', 'position'}``.

    A small CNN encodes the grayscale patch and an MLP embeds the (row, col)
    grid position. Their concatenation feeds two heads: the actor (softmax
    action probabilities) and the critic (scalar state value). The network owns
    its Adam optimizer as ``self.optimizer``.
    """

    def __init__(self, env, learning_rate=3e-4, device='cpu'):
        """
        Args:
            env: object exposing ``action_space.n`` and
                ``observation_space['patch'].shape`` / ``['position'].shape``.
            learning_rate: Adam learning rate.
            device: any value accepted by ``Module.to`` (e.g. 'cpu', 'cuda:0').
        """
        super().__init__()
        self.device = device
        self.n_outputs = env.action_space.n
        self.learning_rate = learning_rate

        # CNN for patch processing: three stride-2 convolutions.
        input_channels = 1
        patch_shape = env.observation_space['patch'].shape  # e.g. (60, 60)

        self.patch_features = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.ELU(),
        )

        # Size of the flattened CNN output, measured with a dummy forward pass.
        with torch.no_grad():
            dummy_patch = torch.zeros(1, input_channels, *patch_shape)
            patch_flatten = self.patch_features(dummy_patch).view(1, -1).size(1)

        # Position embedding
        position_size = env.observation_space['position'].shape[0]  # 2: (row, col)
        self.position_embedding = nn.Sequential(
            nn.Linear(position_size, 16),
            nn.ELU(),
            nn.Linear(16, 32),
            nn.ELU()
        )

        combined_features_size = patch_flatten + 32

        self.actor = nn.Sequential(
            nn.Linear(combined_features_size, 256),
            nn.ELU(),
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Linear(128, self.n_outputs),
            nn.Softmax(dim=-1)
        )

        self.critic = nn.Sequential(
            nn.Linear(combined_features_size, 256),
            nn.ELU(),
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Linear(128, 1)
        )

        # Move first (a no-op for 'cpu'), then build the optimizer over the
        # parameters that actually live on the target device. Previously only
        # the exact string 'cuda' triggered the move.
        self.to(self.device)
        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)

    def _to_batch(self, patch, position):
        """Convert patch/position (arrays, lists or tensors) to batched float tensors.

        Returns tensors on ``self.device`` with shapes (B, 1, H, W) and (B, 2).
        A 2D patch and a 1D position are treated as one observation (B = 1);
        a 3D patch is read as (B, H, W) and gets a channel dimension.
        """
        patch = torch.as_tensor(patch, dtype=torch.float32, device=self.device)
        position = torch.as_tensor(position, dtype=torch.float32, device=self.device)
        if patch.dim() == 2:
            patch = patch.unsqueeze(0)
        if patch.dim() == 3:
            patch = patch.unsqueeze(1)
        if position.dim() == 1:
            position = position.unsqueeze(0)
        return patch, position

    def forward(self, x):
        """Return ``(action_probs, state_values)`` with shapes (B, n_actions) and (B, 1).

        ``x`` is either one observation dict ``{'patch', 'position'}`` (numpy
        arrays, or already-batched tensors as built in the PPO update) or a
        list of such dicts.

        Raises:
            TypeError: ``x`` is neither a dict nor a list.
            ValueError: ``x`` is an empty list.
        """
        if isinstance(x, dict):
            patch, position = x['patch'], x['position']
        elif isinstance(x, list):
            if not x:
                raise ValueError("forward() received an empty list of observations")
            patch = np.array([obs['patch'] for obs in x])
            position = np.array([obs['position'] for obs in x])
        else:
            raise TypeError(
                f"Expected an observation dict or a list of dicts, got {type(x).__name__}"
            )

        patch, position = self._to_batch(patch, position)

        patch_flat = self.patch_features(patch).flatten(start_dim=1)
        position_embedded = self.position_embedding(position)
        combined = torch.cat([patch_flat, position_embedded], dim=-1)

        action_probs = self.actor(combined)
        state_values = self.critic(combined)

        return action_probs, state_values



# Environment variant used for PPO training (dict observations with global position).
class GlobalAwareGlioblastoma(Glioblastoma):
    """``Glioblastoma`` with a self-contained ``reset`` and ``step``.

    * ``reset`` seeds the RNG through ``gym.Env.reset`` and puts the agent back
      on the top-left patch with the step counter at zero.
    * ``step`` picks movement by ``action_space.n`` (3, 5 or 9 actions).

    Loading and normalising the image/mask, observations, rewards, rendering
    and ``current_patch_overlap_with_lesion`` are inherited from ``Glioblastoma``.
    """

    # Action id -> (row delta, col delta); ids with no entry (e.g. 0) mean "stay".
    _MOVES = {
        1: (1, 0),    # down
        2: (0, 1),    # right
        3: (-1, 0),   # up
        4: (0, -1),   # left
        5: (1, 1),    # down-right
        6: (-1, 1),   # up-right
        7: (1, -1),   # down-left
        8: (-1, -1),  # up-left
    }
    # Action-space sizes with defined movement semantics; other sizes never move.
    _SUPPORTED_N_ACTIONS = (3, 5, 9)

    def __init__(self, image_path, mask_path, grid_size=4, tumor_threshold=0.0001, rewards=[1.0, -2.0, -0.5], action_space=spaces.Discrete(3), render_mode="human"):
        # The base constructor loads and normalises the image and mask and sets
        # every attribute used here (grid/block size, spaces, rewards, position,
        # step counters), so nothing is re-loaded from disk.
        super().__init__(image_path, mask_path, grid_size, tumor_threshold, rewards, action_space, render_mode)

    def reset(self, seed=None, options=None):
        """Start a new episode on the top-left patch and return ``(obs, info)``."""
        # Call gym.Env.reset directly: it only (re)seeds self.np_random and
        # returns None, and this keeps the class independent of the base reset.
        gym.Env.reset(self, seed=seed, options=options)
        self.agent_pos = [0, 0]
        self.current_step = 0
        return self._get_obs(), {}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, terminated, truncated, info)``.

        Moves that would leave the grid are ignored. The episode terminates
        once ``max_steps`` steps have been taken.
        """
        self.current_step += 1
        prev_pos = self.agent_pos.copy()  # to tell "stayed" from "moved" in the reward

        # Accept numpy scalars / 0-d arrays / tensors as well as plain ints.
        action_id = action.item() if hasattr(action, "item") else action
        n_actions = self.action_space.n
        delta = None
        if n_actions in self._SUPPORTED_N_ACTIONS and 0 < action_id < n_actions:
            delta = self._MOVES.get(action_id)
        if delta is not None:
            new_row = self.agent_pos[0] + delta[0]
            new_col = self.agent_pos[1] + delta[1]
            if 0 <= new_row < self.grid_size and 0 <= new_col < self.grid_size:
                self.agent_pos[0], self.agent_pos[1] = new_row, new_col

        reward = self._get_reward(action, prev_pos)
        obs = self._get_obs()
        terminated = self.current_step >= self.max_steps
        truncated = False
        info = {}

        return obs, reward, terminated, truncated, info


class PPOAgent:
    """Holder for PPO hyper-parameters, the model, the data and training logs.

    Stores the environment factory (``env_class`` + ``env_config``), the
    actor-critic ``model`` (whose ``device`` is mirrored on the agent), the
    training (image, mask) pairs, and one empty list per logged quantity.
    """

    def __init__(self, env_config, model, train_pairs, env_class,
                 gamma=0.99, gae_lambda=0.95,
                 clip_epsilon=0.2, ppo_epochs=4, batch_size=64,
                 save_name="PPO_Agent"):
        # Environment factory and training data.
        self.env_class = env_class
        self.env_config = env_config
        self.train_pairs = train_pairs

        # Network; the agent works on whatever device the model was built for.
        self.model = model
        self.device = model.device

        # PPO / GAE hyper-parameters and the checkpoint file prefix.
        self.gamma, self.gae_lambda = gamma, gae_lambda
        self.clip_epsilon = clip_epsilon
        self.ppo_epochs, self.batch_size = ppo_epochs, batch_size
        self.save_name = save_name

        # Training logs, each its own independent empty list.
        for log_name in ("training_rewards", "mean_training_rewards",
                         "actor_losses", "critic_losses", "entropies"):
            setattr(self, log_name, [])


# PPO agent whose rollouts and updates handle dict observations ({'patch', 'position'}).
class GlobalAwarePPOAgent(PPOAgent):
    """PPO trainer for the dict-observation environment and actor-critic.

    Rollouts run in ``env_class(img_path, mask_path, **env_config)`` built from
    a randomly chosen training pair, and a new pair is drawn whenever an episode
    ends. An episode still running when a rollout ends continues in the next
    rollout (its value is bootstrapped in between). No steps or episode returns
    are discarded, and rollouts shorter than an episode still make progress.
    """

    def __init__(self, env_config, model, train_pairs, env_class,
                 gamma=0.99, gae_lambda=0.95,
                 clip_epsilon=0.2, ppo_epochs=4, batch_size=64,
                 save_name="GlobalAware_PPO"):
        super().__init__(env_config, model, train_pairs, env_class,
                         gamma=gamma, gae_lambda=gae_lambda,
                         clip_epsilon=clip_epsilon, ppo_epochs=ppo_epochs,
                         batch_size=batch_size, save_name=save_name)
        # Rollout state carried between collect_trajectories() calls.
        self._env = None
        self._state = None
        self._episode_reward = 0

    def compute_gae(self, rewards, values, dones, next_value):
        """Generalised Advantage Estimation over one rollout.

        Args:
            rewards, values, dones: per-step sequences of equal length.
            next_value: critic value of the state after the last step, used to
                bootstrap when the rollout ends mid-episode.

        Returns:
            ``(returns, advantages)`` lists aligned with ``rewards``.
        """
        values = list(values) + [next_value]
        gae = 0
        returns = []
        advantages = []

        # Walk backwards, then reverse once (insert(0, ...) would be O(n^2)).
        for step in reversed(range(len(rewards))):
            not_done = 1 - dones[step]
            delta = rewards[step] + self.gamma * values[step + 1] * not_done - values[step]
            gae = delta + self.gamma * self.gae_lambda * not_done * gae
            advantages.append(gae)
            returns.append(gae + values[step])

        advantages.reverse()
        returns.reverse()
        return returns, advantages

    def _start_episode(self):
        """Build an env for a random training pair, reset it and zero the episode return."""
        if not self.train_pairs:
            raise ValueError("train_pairs is empty: no (image, mask) pairs to train on")
        img_path, mask_path = random.choice(self.train_pairs)
        self._env = self.env_class(img_path, mask_path, **self.env_config)
        self._state, _ = self._env.reset()
        self._episode_reward = 0

    def collect_trajectories(self, num_steps=2048):
        """Run the current policy for ``num_steps`` steps.

        Returns:
            ``(states, actions, rewards, dones, values, log_probs, next_value,
            episode_rewards)`` where ``episode_rewards`` holds the returns of
            the episodes that finished during this rollout.

        Raises:
            ValueError: ``num_steps`` is not positive.
        """
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")

        all_states = []
        all_actions = []
        all_rewards = []
        all_dones = []
        all_values = []
        all_log_probs = []
        episode_rewards = []

        if getattr(self, "_env", None) is None:
            self._start_episode()

        for _ in range(num_steps):
            state = self._state
            with torch.no_grad():
                action_probs, value = self.model(state)
                dist = Categorical(action_probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)

            next_state, reward, terminated, truncated, _ = self._env.step(action.item())
            done = terminated or truncated

            # Copy so later env mutations cannot alter stored observations.
            all_states.append({
                'patch': state['patch'].copy(),
                'position': state['position'].copy()
            })
            all_actions.append(action.item())
            all_rewards.append(reward)
            all_dones.append(done)
            all_values.append(value.item())
            all_log_probs.append(log_prob.item())

            self._episode_reward += reward
            self._state = next_state

            if done:
                episode_rewards.append(self._episode_reward)
                self._start_episode()

        # Bootstrap value of the state the next rollout will continue from.
        with torch.no_grad():
            _, next_value = self.model(self._state)

        return (all_states, all_actions, all_rewards, all_dones,
                all_values, all_log_probs, next_value.item(), episode_rewards)

    def update(self, states, actions, returns, advantages, old_log_probs):
        """Run ``ppo_epochs`` epochs of clipped-PPO minibatch updates on one rollout.

        Loss = policy_loss + 0.5 * value_loss - 0.01 * entropy, with the global
        gradient norm clipped to 0.5. Each minibatch's policy loss, value loss
        and entropy are appended to the training logs.
        """
        patch_array = np.array([s['patch'] for s in states])
        if patch_array.ndim == 3:  # (N, H, W) -> (N, 1, H, W) for the CNN
            patch_array = patch_array[:, np.newaxis, :, :]
        patch_tensor = torch.FloatTensor(patch_array).to(self.device)
        position_tensor = torch.FloatTensor(np.array([s['position'] for s in states])).to(self.device)

        actions_tensor = torch.LongTensor(actions).to(self.device)
        returns_tensor = torch.FloatTensor(returns).to(self.device)
        advantages_tensor = torch.FloatTensor(advantages).to(self.device)
        old_log_probs_tensor = torch.FloatTensor(old_log_probs).to(self.device)

        # Normalise advantages; the std of a single sample is NaN, so skip then.
        if advantages_tensor.numel() > 1:
            advantages_tensor = (advantages_tensor - advantages_tensor.mean()) / (advantages_tensor.std() + 1e-8)

        n_samples = len(states)
        indices = np.arange(n_samples)

        for _ in range(self.ppo_epochs):
            np.random.shuffle(indices)

            for start in range(0, n_samples, self.batch_size):
                batch_indices = indices[start:start + self.batch_size]

                batch_states = {
                    'patch': patch_tensor[batch_indices],
                    'position': position_tensor[batch_indices]
                }
                batch_actions = actions_tensor[batch_indices]
                batch_returns = returns_tensor[batch_indices]
                batch_advantages = advantages_tensor[batch_indices]
                batch_old_log_probs = old_log_probs_tensor[batch_indices]

                action_probs, values = self.model(batch_states)
                dist = Categorical(action_probs)
                new_log_probs = dist.log_prob(batch_actions)
                entropy = dist.entropy().mean()

                values = values.squeeze(-1)  # (B, 1) -> (B,), also for B == 1

                ratios = torch.exp(new_log_probs - batch_old_log_probs)

                # Clipped surrogate policy loss
                surr1 = ratios * batch_advantages
                surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                value_loss = 0.5 * (values - batch_returns).pow(2).mean()

                loss = policy_loss + 0.5 * value_loss - 0.01 * entropy

                self.model.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=0.5)
                self.model.optimizer.step()

                self.actor_losses.append(policy_loss.item())
                self.critic_losses.append(value_loss.item())
                self.entropies.append(entropy.item())

    def train(self, max_episodes=1000, num_steps=512, checkpoint_every=100):
        """Alternate rollouts and PPO updates until ``max_episodes`` episodes finish.

        Saves ``{save_name}_best.pth`` whenever the mean return over the last 100
        episodes improves. Saves ``{save_name}_checkpoint.pth`` each time the
        finished-episode count crosses a multiple of ``checkpoint_every``
        (0/None disables checkpoints). Saves ``{save_name}_final.pth`` at the end.
        """
        print("Starting Global-Aware PPO training...")
        self.model.train()

        episode = 0
        best_mean_reward = -float('inf')
        next_checkpoint = checkpoint_every

        while episode < max_episodes:
            (states, actions, rewards, dones, values,
             old_log_probs, next_value, episode_rewards) = self.collect_trajectories(num_steps)

            returns, advantages = self.compute_gae(rewards, values, dones, next_value)
            self.update(states, actions, returns, advantages, old_log_probs)

            self.training_rewards.extend(episode_rewards)
            if not self.training_rewards:
                # No episode has finished yet: nothing to average or compare.
                continue

            mean_reward = np.mean(self.training_rewards[-100:])
            self.mean_training_rewards.append(mean_reward)

            if episode_rewards:
                avg_episode_reward = np.mean(episode_rewards)
                print(f"Episode {episode} | "
                      f"Avg Reward: {avg_episode_reward:.2f} | "
                      f"Mean Reward (100): {mean_reward:.2f} | "
                      f"Actor Loss: {np.mean(self.actor_losses[-10:] or [0]):.4f}")

            episode += len(episode_rewards)

            if mean_reward > best_mean_reward:
                best_mean_reward = mean_reward
                torch.save(self.model.state_dict(), f"{self.save_name}_best.pth")
                print("New best model saved!")

            # `episode` grows by many episodes per rollout, so compare against a
            # threshold instead of testing `episode % N == 0`.
            if checkpoint_every and episode >= next_checkpoint:
                torch.save(self.model.state_dict(), f"{self.save_name}_checkpoint.pth")
                next_checkpoint = (episode // checkpoint_every + 1) * checkpoint_every

        torch.save(self.model.state_dict(), f"{self.save_name}_final.pth")
        print("Training completed!")


# Training

In [None]:
# Training setup: every knob for this run lives here.
SEED = 0
random.seed(SEED)        # choice of training pair per episode
np.random.seed(SEED)     # minibatch shuffling in the PPO update
torch.manual_seed(SEED)  # network initialisation and action sampling

CURRENT_CONFIG = {
    'grid_size': 4,
    'rewards': [5.0, -0.01, 0.0],  # [on tumor, stay without tumor, move without tumor]
    'action_space': gym.spaces.Discrete(5)
}

LR = 3e-4
MAX_EPISODES = 1000
NUM_STEPS = 512  # rollout length per PPO update (kept small for quick testing)
BATCH_SIZE = 128
DEVICE = 'cpu'

train_pairs = prepare()
assert train_pairs, "No (image, mask) pairs found for the training split - check the paths used by prepare()."

# Any pair works here: this env instance only sizes the network's inputs/outputs.
env = GlobalAwareGlioblastoma(*train_pairs[0], **CURRENT_CONFIG)
model = GlobalAwarePPOActorCritic(env, learning_rate=LR, device=DEVICE)
agent = GlobalAwarePPOAgent(
    env_config=CURRENT_CONFIG,
    model=model,
    train_pairs=train_pairs,
    env_class=GlobalAwareGlioblastoma,
    gamma=0.99,
    clip_epsilon=0.2,
    ppo_epochs=4,
    batch_size=BATCH_SIZE,
    save_name=f"GlobalAware_PPO_batch{BATCH_SIZE}"
)

# Start training
agent.train(max_episodes=MAX_EPISODES, num_steps=NUM_STEPS)


# Testing

In [None]:
def test_ppo_agent(agent, test_pairs):
    """Test the trained PPO agent on test data"""
    results = []
    
    for img_path, mask_path in test_pairs:
        env = agent.env_class(img_path, mask_path, **agent.env_config)
        state, _ = env.reset()
        
        tumor_hits = 0
        final_on_tumor = False
        action_distribution = np.zeros(env.action_space.n)
        
        for step in range(20):  # Fixed 20 steps
            with torch.no_grad():
                action_probs, _ = agent.model(state)
                dist = Categorical(action_probs)
                action = dist.sample()
                
            action_distribution[action.item()] += 1
            
            next_state, reward, terminated, truncated, _ = env.step(action.item())
            done = terminated or truncated
            
            # Check if current patch overlaps with lesion
            if env.current_patch_overlap_with_lesion() > 0:
                tumor_hits += 1
            
            state = next_state
            
            if done:
                break
        
        # Check if final position is on tumor
        if env.current_patch_overlap_with_lesion() > 0:
            final_on_tumor = True
        
        results.append({
            'image_path': img_path,
            'tumor_hits': tumor_hits,
            'final_on_tumor': final_on_tumor,
            'action_distribution': action_distribution / np.sum(action_distribution)  # Normalize
        })
    
    return results

# Evaluate the trained agent on the held-out split.
test_pairs = prepare(mode="test")
test_results = test_ppo_agent(agent, test_pairs)

# One line per test slice.
for i, result in enumerate(test_results):
    print(f"#{i} Image: {result['image_path']}, Tumor Hits: {result['tumor_hits']}, "
          f"Final on Tumor: {result['final_on_tumor']}, "
          f"Action Distribution: {result['action_distribution']}")

# Number of episodes whose final patch overlaps the lesion.
count = sum(1 for entry in test_results if entry['final_on_tumor'])

print(f"Total images where agent ended on tumor: {count} out of {len(test_results)}")

[48 - batch64](GlobalAware_PPO_batch64_best.pth)

[49 - batch128](GlobalAware_PPO_batch128_best.pth)

In [None]:
import os
import numpy as np
import torch
from torch.distributions import Categorical
import imageio
from PIL import Image
import matplotlib.pyplot as plt

def test_agent_unified(agent, test_pairs, agent_type, num_episodes=None, env_config=None, save_gifs=True, gif_folder="TEST_GIFS"):
    """
    Unified testing function for both DQN and PPO agents
    
    Args:
        agent: The trained agent (DQN or PPO)
        test_pairs: List of (image_path, mask_path) tuples
        agent_type: Either "dqn" or "ppo"
        num_episodes: Number of episodes to test (default: all test pairs)
        env_config: Environment configuration dictionary
        save_gifs: Whether to save GIFs of episodes
        gif_folder: Folder to save GIFs
    
    Returns:
        Dictionary with test results including success rate and action distributions
    """
    if num_episodes is None:
        num_episodes = len(test_pairs)
    
    # Create GIF folder if needed
    if save_gifs and not os.path.exists(gif_folder):
        os.makedirs(gif_folder)
    
    # Set model to evaluation mode
    if agent_type.lower() == "dqn":
        agent.dnnetwork.eval()
    elif agent_type.lower() == "ppo":
        agent.model.eval()
    
    results = {
        'success_rate': [],
        'final_position_accuracy': [],
        'average_reward': [],
        'steps_to_find_tumor': [],
        'tumor_coverage': [],
        'total_tumor_reward': [],
        'episode_details': []
    }
    
    grid_size = env_config.get('grid_size', 4)
    rewards = env_config.get('rewards', [5.0, -1.0, -0.2])
    action_space = env_config.get('action_space', None)
    
    for i in range(min(num_episodes, len(test_pairs))):
        img_path, mask_path = test_pairs[i]
        
        # Create environment
        if hasattr(agent, 'env_class'):
            env = agent.env_class(img_path, mask_path, grid_size=grid_size, rewards=rewards, action_space=action_space)
        else:
            env = Glioblastoma(img_path, mask_path, grid_size=grid_size, rewards=rewards, action_space=action_space)
        
        state, _ = env.reset()
        total_reward = 0
        found_tumor = False
        tumor_positions_visited = set()
        steps_to_find = env.max_steps
        tumor_rewards = 0
        
        # For action distribution tracking
        action_counts = np.zeros(env.action_space.n)
        
        # For GIF creation
        frames = []
        
        for step in range(env.max_steps):
            with torch.no_grad():
                if agent_type.lower() == "dqn":
                    action = agent.dnnetwork.get_action(state, epsilon=0.00)
                    action_idx = action
                elif agent_type.lower() == "ppo":
                    action_probs, _ = agent.model(state)
                    dist = Categorical(action_probs)
                    action = dist.sample()
                    action_idx = action.item()
            
            action_counts[action_idx] += 1
            
            next_state, reward, terminated, truncated, _ = env.step(action_idx)
            state = next_state
            total_reward += reward
            
            # Track tumor-related metrics
            current_overlap = env.current_patch_overlap_with_lesion()
            if current_overlap > 0:
                tumor_positions_visited.add(tuple(env.agent_pos))
                if not found_tumor:
                    found_tumor = True
                    steps_to_find = step + 1
                
                # Count positive rewards (when on tumor)
                if reward > 0:
                    tumor_rewards += 1
            
            # Capture frame for GIF
            if save_gifs:
                frame = env.render(mode='rgb_array')
                if frame is not None:
                    frames.append(frame)
            
            if terminated or truncated:
                break
        
        # Save GIF
        gif_path = None
        if save_gifs and frames:
            gif_path = os.path.join(gif_folder, f"episode_{i}_{os.path.basename(img_path).split('.')[0]}.gif")
            # Convert frames to PIL Images and save as GIF
            pil_frames = [Image.fromarray(frame) for frame in frames]
            pil_frames[0].save(
                gif_path,
                save_all=True,
                append_images=pil_frames[1:],
                duration=500,  # milliseconds per frame
                loop=0
            )
        
        # Calculate metrics for this episode
        final_overlap = env.current_patch_overlap_with_lesion()
        
        # Success: ended on tumor region
        success = final_overlap > 0
        results['success_rate'].append(success)
        
        # Final position accuracy
        results['final_position_accuracy'].append(final_overlap > 0)
        
        # Average reward
        results['average_reward'].append(total_reward)
        
        # Steps to find tumor
        results['steps_to_find_tumor'].append(steps_to_find)
        
        # Tumor coverage (percentage of tumor patches visited)
        total_tumor_patches = count_tumor_patches(env)
        coverage = len(tumor_positions_visited) / total_tumor_patches if total_tumor_patches > 0 else 0
        results['tumor_coverage'].append(coverage)
        
        # Total positive rewards from tumor
        results['total_tumor_reward'].append(tumor_rewards)
        
        # Store detailed episode information
        episode_detail = {
            'image_path': img_path,
            'success': success,
            'final_on_tumor': final_overlap > 0,
            'total_reward': total_reward,
            'steps_to_find_tumor': steps_to_find,
            'tumor_coverage': coverage,
            'tumor_rewards': tumor_rewards,
            'action_distribution': action_counts / np.sum(action_counts),  # Normalized
            'gif_path': gif_path
        }
        results['episode_details'].append(episode_detail)
    
    # Calculate overall metrics
    overall_results = {
        'success_rate': np.mean(results['success_rate']),
        'final_position_accuracy': np.mean(results['final_position_accuracy']),
        'average_reward': np.mean(results['average_reward']),
        'avg_steps_to_find_tumor': np.mean(results['steps_to_find_tumor']),
        'avg_tumor_coverage': np.mean(results['tumor_coverage']),
        'avg_tumor_rewards': np.mean(results['total_tumor_reward']),
        'episode_details': results['episode_details']
    }
    
    # Print summary
    print("\n" + "="*60)
    print(f"TEST RESULTS ({agent_type.upper()} Agent)")
    print("="*60)
    print(f"Success Rate: {overall_results['success_rate']*100:.2f}%")
    print(f"Final Position Accuracy: {overall_results['final_position_accuracy']*100:.2f}%")
    print(f"Average Episode Reward: {overall_results['average_reward']:.2f}")
    print(f"Average Steps to Find Tumor: {overall_results['avg_steps_to_find_tumor']:.2f}")
    print(f"Average Tumor Coverage: {overall_results['avg_tumor_coverage']*100:.2f}%")
    print(f"Average Tumor Rewards per Episode: {overall_results['avg_tumor_rewards']:.2f}")
    
    # Print individual episode results
    print(f"\nDetailed Results for {len(results['episode_details'])} episodes:")
    print("-" * 80)
    for i, detail in enumerate(results['episode_details']):
        print(f"Episode {i}: {os.path.basename(detail['image_path'])}")
        print(f"  Success: {detail['success']}, Final on Tumor: {detail['final_on_tumor']}")
        print(f"  Total Reward: {detail['total_reward']:.2f}, Steps to Find: {detail['steps_to_find_tumor']}")
        print(f"  Tumor Coverage: {detail['tumor_coverage']*100:.2f}%, Tumor Rewards: {detail['tumor_rewards']}")
        print(f"  Action Distribution: {detail['action_distribution']}")
        if detail['gif_path']:
            print(f"  GIF saved: {detail['gif_path']}")
        print()
    
    return overall_results

def count_tumor_patches(env):
    """Count total number of patches that contain tumor.

    The agent is temporarily placed on every cell of the ``grid_size x grid_size``
    grid. A cell counts when ``env.current_patch_overlap_with_lesion()`` is > 0.
    The agent's original position is restored afterwards, even if the overlap
    query raises.

    Args:
        env: Environment exposing ``grid_size``, an assignable ``agent_pos`` and
            ``current_patch_overlap_with_lesion()``.

    Returns:
        int: Number of grid patches overlapping the lesion.
    """
    import copy

    # copy.copy handles list, tuple and numpy-array positions alike
    original_pos = copy.copy(env.agent_pos)
    tumor_patches = 0
    try:
        for i in range(env.grid_size):
            for j in range(env.grid_size):
                env.agent_pos = [i, j]
                if env.current_patch_overlap_with_lesion() > 0:
                    tumor_patches += 1
    finally:
        env.agent_pos = original_pos  # Restore original position
    return tumor_patches

# Render

In [None]:
# Render (human mode) one rollout of the PPO policy on the last test image.
# The rollout is capped at 20 steps, the same fixed horizon used by test_ppo_agent.
last_pair = test_pairs[-1]
env = agent.env_class(last_pair[0], last_pair[1], **agent.env_config)
state, _ = env.reset()
for step in range(20):
    # Sample an action from the policy without tracking gradients.
    with torch.no_grad():
        action_probs, _ = agent.model(state)
        dist = Categorical(action_probs)
        action = dist.sample()

    next_state, reward, terminated, truncated, _ = env.step(action.item())
    env.render()
    state = next_state

    episode_over = terminated or truncated
    if episode_over:
        break
    
