# Notebook to experiment with testing:

## Code:

In [None]:
import numpy as np
import random
import torch
from gymnasium import spaces
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from gymnasium import spaces

SEED = 42

# Seed every CPU-side random number generator the notebook relies on
# (Python's `random`, NumPy's legacy global RNG, and PyTorch).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# When a GPU is present, seed the CUDA generators as well and switch cuDNN
# to deterministic mode with autotuning (benchmark) disabled.
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False


In [None]:
import os
import sys

# Folder that contains general.py. Set the TFG_PROJECT_DIR environment variable
# to point elsewhere on another machine; the default is the original location.
PROJECT_DIR = os.environ.get("TFG_PROJECT_DIR", "/home/martina/codi2/4year/tfg")
if PROJECT_DIR not in sys.path:  # re-running the cell must not add duplicate entries
    sys.path.append(PROJECT_DIR)

from general import prepare, Glioblastoma, Glioblastoma2, testing


In [9]:
# Build the list of test pairs with the project's `prepare` helper (mode='test').
# Each entry is later unpacked as a two-element (image, mask) pair by REINFORCEAgent.
test_pairs = prepare(mode='test')

✅ Found 100 pairs out of 100 listed in CSV.


In [17]:
class CNNPolicy(nn.Module):
    """Convolutional policy network mapping an image observation to action probabilities.

    Three strided convolutions extract features, which a two-layer fully
    connected head turns into a softmax distribution over `action_dim` actions.
    The head's input size (64 * 5 * 5) is hard-wired for 60x60 observations
    (spatial size 60 -> 28 -> 12 -> 5 through the three convolutions).

    Args:
        action_dim: Number of discrete actions (width of the softmax output).
        channels: Number of channels in the input observation.
    """

    def __init__(self, action_dim, channels=1):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
        )

        self.fc = nn.Sequential(
            nn.Linear(64 * 5 * 5, 256),   # flatten size for 60x60 input
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return `(action_probabilities, None)` for a batch shaped (N, C, H, W).

        The second tuple element is always None; callers in this file discard it.
        """
        x = self.cnn(x)
        x = x.flatten(1)
        return self.fc(x), None

    def act(self, state):
        """Sample one action for a single observation.

        Args:
            state: Observation of shape (H, W) or (C, H, W), given as anything
                `torch.tensor` accepts (e.g. a NumPy array) or as a tensor.

        Returns:
            Tuple `(action, log_prob)`: the sampled action as a Python int and
            the log-probability tensor of that action (still attached to the
            autograd graph so it can be used in a policy-gradient loss).
        """
        # Create the input on the same device as the network's weights so the
        # policy keeps working after `.to("cuda")`.
        device = next(self.parameters()).device
        if isinstance(state, torch.Tensor):
            # Copy without going through torch.tensor(), which warns on tensor input.
            state = state.detach().clone().to(device=device, dtype=torch.float32)
        else:
            # torch.tensor copies, so later in-place changes to the caller's
            # array cannot alter what autograd saved for the backward pass.
            state = torch.tensor(state, dtype=torch.float32, device=device)

        if state.ndim == 2:
            state = state.unsqueeze(0)  # (H, W) -> (1, H, W)
        if state.ndim == 3:
            state = state.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)

        probs, _ = self.forward(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()

        return action.item(), dist.log_prob(action)

class REINFORCEAgent:
    """REINFORCE (Monte-Carlo policy gradient) agent built around `CNNPolicy`.

    Args:
        env_class: Environment class. It is called as
            `env_class(img, mask, grid_size=..., rewards=..., action_space=...)`
            and must expose Gymnasium-style `reset()` -> `(obs, info)` and
            `step(action)` -> `(obs, reward, terminated, truncated, info)`.
        train_pairs: Non-empty sequence of two-element (image, mask) items; the
            first one is used to build a sample environment at construction.
        env_config: Dict with the keys "grid_size", "rewards" and
            "action_space"; the action space must expose `.n`.
        gamma: Discount factor used when computing returns.
        lr: Learning rate of the Adam optimiser.
    """

    def __init__(self, env_class, train_pairs, env_config,
                 gamma=0.99, lr=1e-4):

        self.env_class = env_class
        self.train_pairs = train_pairs
        self.env_config = env_config
        self.gamma = gamma

        # Build one throw-away environment only to inspect the observation shape.
        sample_img, sample_mask = train_pairs[0]
        sample_env = self._make_env(sample_img, sample_mask)
        obs, _ = sample_env.reset()

        # A 2-D observation is treated as one channel; otherwise the leading
        # axis of the observation is taken as the channel count.
        channels = 1 if obs.ndim == 2 else obs.shape[0]

        self.action_dim = env_config["action_space"].n
        self.policy = CNNPolicy(self.action_dim, channels)
        self.model = self.policy  # alias for the same network object
        self.optim = optim.Adam(self.policy.parameters(), lr=lr)
        self.best_reward = -1e9
        self.save_path = "reinforce_policy.pt"

    def _make_env(self, img_path, mask_path):
        """Instantiate `env_class` for one (image, mask) pair using `env_config`."""
        return self.env_class(
            img_path,
            mask_path,
            grid_size=self.env_config["grid_size"],
            rewards=self.env_config["rewards"],
            action_space=self.env_config["action_space"]
        )

    def run_episode(self, img_path, mask_path):
        """Roll out one full episode with the current policy.

        Returns:
            Tuple `(log_probs, rewards)`: per-step log-probabilities of the
            sampled actions and the per-step rewards, in chronological order.
        """
        env = self._make_env(img_path, mask_path)

        log_probs = []
        rewards = []

        state, _ = env.reset()
        done = False

        while not done:
            action, log_prob = self.policy.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)

            log_probs.append(log_prob)
            rewards.append(reward)

            state = next_state
            done = terminated or truncated

        return log_probs, rewards

    def compute_returns(self, rewards):
        """Compute discounted returns and standardise them.

        Args:
            rewards: Sequence of per-step rewards in chronological order.

        Returns:
            1-D float32 tensor with one discounted return per step. When it has
            at least two elements it is standardised to zero mean and unit
            (sample) standard deviation. Shorter inputs are returned
            unnormalised, because the sample standard deviation of a single
            value is NaN and would poison the loss and the network weights.
        """
        running = 0.0
        discounted = []
        # Accumulate back-to-front, then flip once (avoids repeated insert(0, ...)).
        for r in reversed(rewards):
            running = r + self.gamma * running
            discounted.append(running)
        discounted.reverse()

        returns = torch.tensor(discounted, dtype=torch.float32)
        if returns.numel() < 2:
            return returns
        return (returns - returns.mean()) / (returns.std() + 1e-8)

    def update_policy(self, log_probs, returns):
        """Take one optimiser step on the REINFORCE loss `sum_t(-log_prob_t * G_t)`.

        Args:
            log_probs: Per-step log-probability tensors from `run_episode`.
            returns: Per-step returns from `compute_returns`.

        Returns:
            The scalar loss value as a Python float.
        """
        loss = torch.stack([-lp * Gt for lp, Gt in zip(log_probs, returns)]).sum()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()

    # Train over ALL train_pairs each epoch
    def train(self, epochs=200):
        """Train the policy with one update per episode over every training pair.

        After each epoch the policy is checkpointed to `self.save_path` whenever
        the epoch's mean episode reward beats the best seen so far. The final
        weights are always written to "reinforce_final.pt".

        Args:
            epochs: Number of passes over `self.train_pairs`.
        """
        for e in range(epochs):
            epoch_loss = 0
            epoch_rewards = []

            # Train on ALL image/mask pairs each epoch
            for i, (img_path, mask_path) in enumerate(self.train_pairs):
                log_probs, rewards = self.run_episode(img_path, mask_path)
                returns = self.compute_returns(rewards)
                loss = self.update_policy(log_probs, returns)

                episode_reward = sum(rewards)
                epoch_loss += loss
                epoch_rewards.append(episode_reward)
                if i % 10 == 0:
                    print(f"  episode = {i} | Episode Reward={episode_reward:.2f} | Loss={loss:.4f}")

            avg_reward = np.mean(epoch_rewards)
            # Save best model so far
            if avg_reward > self.best_reward:
                self.best_reward = avg_reward
                torch.save(self.policy.state_dict(), self.save_path)
                print(f"  -> Saved new best model (Reward {avg_reward:.2f})")

            if e % 10 == 0:
                # Note: the printed loss is the sum over the epoch's episodes.
                print(f"[Epoch {e+1}] Avg Reward per Episode={avg_reward:.2f} | Loss={epoch_loss:.4f}")

        torch.save(self.policy.state_dict(), "reinforce_final.pt")
        print("Training finished. Final model saved to reinforce_final.pt")

# TESTING:

In [20]:
# Rebuild the agent and load the trained REINFORCE weights for testing.
LR = 1e-4  # From paper
CURRENT_CONFIG = {
    'grid_size': 4,
    'rewards': [5.0, -1.0, -0.2],
    'action_space': spaces.Discrete(3)
}

# REINFORCEAgent's constructor reads train_pairs[0] to infer the observation
# channels, so the test pairs are passed in the `train_pairs` slot.
agent = REINFORCEAgent(
    env_class=Glioblastoma,
    train_pairs=test_pairs,
    env_config=CURRENT_CONFIG,
    gamma=0.99,
    lr=LR
)

# map_location="cpu" lets a checkpoint saved on a GPU machine load on a CPU-only
# one; the policy built above has not been moved off the CPU.
agent.policy.load_state_dict(torch.load("reinforce_final.pt", map_location="cpu"))


<All keys matched successfully>

In [23]:
# Evaluate the loaded policy with the project's `testing` helper, requesting as
# many episodes as there are test pairs and writing GIFs into TEST_GIFS.
overall_results = testing(
    agent,
    test_pairs,
    agent_type="reinforce",
    num_episodes=len(test_pairs),
    env_config=CURRENT_CONFIG,
    save_gifs=True,
    gif_folder="TEST_GIFS",
)

Saved GIF for episode 0 at TEST_GIFS/episode_0_002_58.gif
Saved GIF for episode 10 at TEST_GIFS/episode_10_013_86.gif
Saved GIF for episode 20 at TEST_GIFS/episode_20_024_49.gif
Saved GIF for episode 30 at TEST_GIFS/episode_30_038_84.gif
Saved GIF for episode 40 at TEST_GIFS/episode_40_052_98.gif
Saved GIF for episode 50 at TEST_GIFS/episode_50_104_74.gif
Saved GIF for episode 60 at TEST_GIFS/episode_60_176_99.gif
Saved GIF for episode 70 at TEST_GIFS/episode_70_204_52.gif
Saved GIF for episode 80 at TEST_GIFS/episode_80_260_62.gif
Saved GIF for episode 90 at TEST_GIFS/episode_90_300_107.gif

TEST RESULTS (REINFORCE Agent)
Success Rate: 49.00%
Average Episode Reward: 23.78
Average Steps to Find Tumor: 10.85
Average Tumor Rewards per Episode: 6.98
Tumor Size Statistics:
  Biggest Tumor: 4910 pixels (8.52%)
  Smallest Tumor: 296 pixels (0.51%)
  Average Tumor: 1873 pixels (3.25%)
Overall Action Distribution: [0.788 0.111 0.101]
  Successful Episodes: [0.87755102 0.06632653 0.05612245]
  

In [None]:
# # load model to test:
# LR = 1e-4 #From paper
# CURRENT_CONFIG = {
#     'grid_size': 4,
#     'rewards': [5.0, -1.0, -0.2], 
#     'action_space': spaces.Discrete(3)
# }

# env = Glioblastoma(*test_pairs[0], **CURRENT_CONFIG)

# model = DQN(env, learning_rate=LR, device='cpu')
# #model.load_state_dict(torch.load("/home/martina/codi2/4year/tfg/grid_search/Trial124.dat"))
# model.load_state_dict(torch.load("/home/martina/codi2/4year/tfg/other_models/current/Extension020.dat"))

# agent = DQNAgent(env_config=CURRENT_CONFIG, dnnetwork=model, buffer_class=ReplayBuffer, train_pairs=test_pairs,
#                  env_class=Glioblastoma,
#                  epsilon=0.00)  # very low epsilon for testing
