In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from gymnasium import spaces
from torch.distributions import Categorical


In [2]:
import os
import sys

# Folder that contains general.py. It defaults to the original author's local
# checkout; set the TFG_ROOT environment variable to run the notebook on
# another machine without editing this cell.
PROJECT_ROOT = os.environ.get("TFG_ROOT", "/home/martina/codi2/4year/tfg")

# Guard the append so re-running this cell does not keep adding duplicate
# entries to sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from general import prepare, Glioblastoma, GlioblastomaPositionalEncoding, testing


In [None]:
class CNNPolicy(nn.Module):
    """Small convolutional policy network for discrete action spaces.

    Three strided convolutions extract features from an image-like observation;
    a two-layer fully connected head maps them to one score per action, which
    `forward` turns into probabilities with a softmax.
    """

    def __init__(self, obs_shape, action_dim):
        """
        obs_shape: (C, H, W) of a single observation
        action_dim: number of discrete actions
        """
        super().__init__()
        channels, height, width = obs_shape

        feature_layers = [
            nn.Conv2d(channels, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Push one all-zero observation through the conv stack to discover how
        # many features it emits, so the head adapts to any (C, H, W).
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, channels, height, width))
        n_features = probe[0].numel()

        self.fc = nn.Sequential(
            nn.Linear(n_features, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
        )

    def forward(self, x):
        """
        x: (B, C, H, W)
        returns: action probabilities, (B, action_dim), each row summing to 1
        """
        features = self.conv(x).flatten(1)
        return torch.softmax(self.fc(features), dim=-1)

    def act(self, state, device):
        """
        Sample one action for a single (unbatched) observation.

        state: array-like
          - (H, W)          for grayscale
          - (C, H, W)       for multi-channel (positional encoding)
        device: torch device the observation tensor is created on
        returns:
          action (int), log_prob (scalar tensor on `device`)
        raises:
          ValueError if `state` is neither 2-D nor 3-D
        """
        state_t = torch.as_tensor(state, dtype=torch.float32, device=device)

        if state_t.ndim not in (2, 3):
            raise ValueError(f"Unexpected state ndim={state_t.ndim}, shape={state_t.shape}")

        # Prepend singleton axes until the tensor is (1, C, H, W):
        # two for a grayscale (H, W) input, one for a (C, H, W) input.
        while state_t.ndim < 4:
            state_t = state_t.unsqueeze(0)

        dist = Categorical(self.forward(state_t))   # batch of one distribution
        sampled = dist.sample()                     # (1,)
        sampled_log_prob = dist.log_prob(sampled)   # (1,)

        return sampled.item(), sampled_log_prob.squeeze(0)  # scalar tensor


In [5]:
class REINFORCEAgent:
    """REINFORCE (Monte-Carlo policy gradient) trainer around a `CNNPolicy`.

    Each epoch runs one episode per (image, mask) pair, pools every
    (log-prob, discounted return) sample collected, normalizes the returns
    across the whole epoch and applies a single Adam step.
    """

    def __init__(self, env_class, train_pairs, env_config,
                 gamma=0.99, lr=1e-4, save_path="reinforce_policy.pt"):
        """
        env_class:   callable invoked as env_class(img, mask, **env_config); the
                     result must provide reset() -> (obs, info) and
                     step(action) -> (obs, reward, terminated, truncated, info)
        train_pairs: indexable, iterable collection of (img_path, mask_path) pairs
        env_config:  keyword arguments for env_class; must contain an
                     "action_space" entry exposing `.n` (number of actions)
        gamma:       discount factor used by compute_returns
        lr:          Adam learning rate
        save_path:   where the best-so-far policy state dict is written
        """
        self.env_class = env_class
        self.train_pairs = train_pairs
        self.env_config = env_config
        self.gamma = gamma
        self.save_path = save_path

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # --- Infer observation shape & action_dim from a sample env ---
        sample_img, sample_mask = train_pairs[0]
        sample_env = env_class(sample_img, sample_mask, **env_config)
        obs, _ = sample_env.reset()

        if obs.ndim == 2:
            C, H, W = 1, obs.shape[0], obs.shape[1]
        elif obs.ndim == 3:
            C, H, W = obs.shape
        else:
            raise ValueError(f"Unexpected obs ndim={obs.ndim}, shape={obs.shape}")

        obs_shape = (C, H, W)
        self.action_dim = env_config["action_space"].n

        # --- Policy network ---
        self.policy = CNNPolicy(obs_shape, self.action_dim).to(self.device)
        self.model = self.policy  # for compatibility with your testing() function
        self.optim = optim.Adam(self.policy.parameters(), lr=lr)

        # Best average episode return seen so far; any realistic return beats it.
        self.best_reward = -1e9

    def make_env(self, img_path, mask_path):
        """Build a fresh environment for one image/mask pair using env_config."""
        return self.env_class(img_path, mask_path, **self.env_config)

    def run_episode(self, img_path, mask_path):
        """
        Runs one full episode on a single image/mask pair.

        Returns (log_probs, rewards): the per-step log-probability tensors of
        the sampled actions (still attached to the autograd graph) and the
        per-step rewards, both in time order.
        """
        env = self.make_env(img_path, mask_path)

        log_probs = []
        rewards = []

        state, _ = env.reset()
        done = False

        while not done:
            action, log_prob = self.policy.act(state, self.device)
            next_state, reward, terminated, truncated, _ = env.step(action)

            log_probs.append(log_prob)
            rewards.append(reward)

            state = next_state
            done = terminated or truncated

        return log_probs, rewards

    def compute_returns(self, rewards):
        """
        Compute discounted returns G_t = r_t + gamma * G_{t+1} for a single
        episode (no normalization here).

        rewards: per-step rewards in time order
        returns: plain Python list of floats, same length and order as `rewards`
        """
        running = 0.0
        returns = []
        # Accumulate from the last step backwards, then flip once at the end.
        # Appending + reversing is linear; insert(0, ...) would be quadratic.
        for r in reversed(rewards):
            running = r + self.gamma * running
            returns.append(running)
        returns.reverse()
        return returns

    @staticmethod
    def _normalize_returns(returns_t, eps=1e-8):
        """
        Standardize a 1D tensor of returns to zero mean / unit std.

        With fewer than two samples the (unbiased) std is NaN, which would
        turn the loss and then every policy weight into NaN; in that case the
        returns are passed through unchanged.
        """
        if returns_t.numel() < 2:
            return returns_t
        return (returns_t - returns_t.mean()) / (returns_t.std() + eps)

    def update_policy_batch(self, log_probs, returns):
        """
        Apply one gradient step on the REINFORCE loss -sum(log_prob * return).

        log_probs: list of scalar tensors (already on device)
        returns:   1D tensor on device (same length as log_probs)
        returns the loss value as a Python float
        """
        log_probs_t = torch.stack(log_probs)               # (N,)
        loss = -(log_probs_t * returns).sum()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()

    def train(self, epochs=200, final_path="reinforce_final.pt"):
        """
        One epoch = run 1 episode per train_pair, then do a single batch update
        using all (log_prob, G) pairs, with returns normalized across the epoch.

        epochs:     number of epochs to run
        final_path: where the policy state dict is written after the last epoch

        The policy state dict is also written to `self.save_path` whenever the
        epoch's average episode return beats the best seen so far.
        """
        for e in range(1, epochs + 1):
            all_log_probs = []
            all_returns = []
            episode_returns = []  # sum of rewards per episode

            # ---- Collect trajectories on ALL train_pairs ----
            for i, (img_path, mask_path) in enumerate(self.train_pairs):
                log_probs, rewards = self.run_episode(img_path, mask_path)

                Gs = self.compute_returns(rewards)   # list of floats
                all_log_probs.extend(log_probs)
                all_returns.extend(Gs)

                ep_ret = sum(rewards)
                episode_returns.append(ep_ret)

                if i % 10 == 0:
                    print(f"[Epoch {e} | Episode {i}] "
                          f"Return={ep_ret:.2f} (len={len(rewards)})")

            # ---- Normalize returns across the WHOLE epoch ----
            all_returns_t = torch.tensor(all_returns, dtype=torch.float32, device=self.device)
            all_returns_t = self._normalize_returns(all_returns_t)

            # ---- Single policy update for the epoch ----
            loss = self.update_policy_batch(all_log_probs, all_returns_t)
            avg_ep_return = float(np.mean(episode_returns))

            # ---- Save best model ----
            if avg_ep_return > self.best_reward:
                self.best_reward = avg_ep_return
                torch.save(self.policy.state_dict(), self.save_path)
                print(f"  -> New best model saved ({self.save_path}), "
                      f"Avg return={avg_ep_return:.2f}")

            if e % 5 == 0 or e == 1:
                print(f"[Epoch {e}] Avg Return per Episode={avg_ep_return:.2f} | Loss={loss:.4f}")

        torch.save(self.policy.state_dict(), final_path)
        print(f"Training finished. Final model saved to {final_path}")


In [6]:
# Environment configuration. It is handed to REINFORCEAgent as `env_config`,
# which forwards it as keyword arguments to the environment constructor and
# reads the number of actions from the 'action_space' entry.
CURRENT_CONFIG = {
    'grid_size': 6,
    # Reward values, in order:
    # [staying on tumor, staying off tumor, moving into tumor, movement cost]
    # Earlier three-value setting kept for reference: [3.0, -1.0, -0.2]
    'rewards': [10.0, -2.0, 2.5, -0.1],
    'action_space': spaces.Discrete(5), 
    # NOTE(review): the meaning of max_steps=0 is defined by the environment
    # class, which is not visible here — confirm whether it means "no limit".
    'max_steps': 0
    # Disabled option, kept for reference:
    # 'stop': False
}

# prepare() comes from general.py; its result is used below as a collection of
# (image, mask) pairs, one training episode per pair and per epoch.
train_pairs = prepare()


✅ Found 30 pairs out of 30 listed in CSV.


In [8]:
# Seed the global torch / numpy RNGs before building the agent so that policy
# weight initialisation and action sampling are repeatable after
# "Restart Kernel -> Run All".
# NOTE(review): the environment may keep its own RNG, which this does not
# control — confirm against the environment class if exact replay is needed.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Train a REINFORCE agent on the positional-encoding variant of the environment.
agent = REINFORCEAgent(
    env_class=GlioblastomaPositionalEncoding,
    train_pairs=train_pairs,
    env_config=CURRENT_CONFIG,
    gamma=0.99,
    lr=1e-4,
    save_path="reinforce_best.pt"
)

agent.train(epochs=200)


[Epoch 1 | Episode 0] Return=-7.10 (len=10)
[Epoch 1 | Episode 10] Return=-1.99 (len=2)
[Epoch 1 | Episode 20] Return=-8.00 (len=25)
  -> New best model saved (reinforce_best.pt), Avg return=-2.38
[Epoch 1] Avg Return per Episode=-2.38 | Loss=0.3314
[Epoch 2 | Episode 0] Return=-2.46 (len=5)
[Epoch 2 | Episode 10] Return=-2.28 (len=2)
[Epoch 2 | Episode 20] Return=-2.00 (len=1)
  -> New best model saved (reinforce_best.pt), Avg return=-1.61
[Epoch 3 | Episode 0] Return=-2.00 (len=1)
[Epoch 3 | Episode 10] Return=-2.19 (len=2)
[Epoch 3 | Episode 20] Return=-6.24 (len=13)
  -> New best model saved (reinforce_best.pt), Avg return=-1.54
[Epoch 4 | Episode 0] Return=-2.96 (len=7)
[Epoch 4 | Episode 10] Return=-2.01 (len=2)
[Epoch 4 | Episode 20] Return=-1.95 (len=2)
  -> New best model saved (reinforce_best.pt), Avg return=-0.85
[Epoch 5 | Episode 0] Return=10.00 (len=1)
[Epoch 5 | Episode 10] Return=-2.21 (len=2)
[Epoch 5 | Episode 20] Return=-1.96 (len=2)
  -> New best model saved (reinfo