In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from gymnasium import spaces

In [None]:
import sys
# Extend the module search path so that `general` (imported below) can be found.
# NOTE(review): this is a hard-coded absolute path, so it presumably only resolves
# on the original author's machine — consider a configurable or relative path.
sys.path.append("/home/martina/codi2/4year/tfg")  # add parent folder of general.py

# `prepare` supplies the training pairs and `Glioblastoma` is the environment
# class handed to the agent in the cells further down.
from general import prepare, Glioblastoma


In [21]:
class CNNPolicy(nn.Module):
    """Convolutional policy that maps an image observation to action probabilities.

    Three strided convolutions extract features; a two-layer fully connected
    head ends in a softmax over ``action_dim`` discrete actions.

    Args:
        action_dim: Number of discrete actions (width of the softmax output).
        channels: Number of channels in the input observation.
        input_size: ``(height, width)`` of the observations. The default
            ``(60, 60)`` gives a 64 * 5 * 5 flattened feature vector, which
            matches the previously hard-coded head size, so existing
            checkpoints trained on 60x60 inputs still load.
    """

    def __init__(self, action_dim, channels=1, input_size=(60, 60)):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(channels, 16, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),
            nn.ReLU(),
        )

        # Derive the flattened feature count from the conv stack itself instead
        # of hard-coding it, so other observation sizes work as well. The probe
        # is all zeros and runs without autograd, so it consumes no randomness
        # and leaves parameter initialisation unchanged.
        with torch.no_grad():
            probe = torch.zeros(1, channels, *input_size)
            flat_features = self.cnn(probe).flatten(1).shape[1]

        self.fc = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """Return ``(action_probs, None)`` for a batch ``x`` of shape (N, C, H, W).

        The second tuple element is always ``None``; it is kept so callers that
        unpack two values keep working.
        """
        x = self.cnn(x)
        x = x.flatten(1)
        return self.fc(x), None

    def act(self, state):
        """Sample one action for a single observation.

        Args:
            state: Observation of shape (H, W) or (C, H, W); anything
                ``torch.as_tensor`` accepts (array, nested list, tensor).

        Returns:
            Tuple ``(action, log_prob)``: the sampled action as a Python int and
            its log-probability as a tensor that still carries gradients.
        """
        # Build the input on the same device as the network so the policy keeps
        # working after `.to(device)`; as_tensor also avoids the copy warning
        # that torch.tensor emits when `state` is already a tensor.
        device = next(self.parameters()).device
        state = torch.as_tensor(state, dtype=torch.float32, device=device)

        if state.ndim == 2:
            state = state.unsqueeze(0)  # (H, W) -> (1, H, W)
        if state.ndim == 3:
            state = state.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)

        probs, _ = self(state)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()

        return action.item(), dist.log_prob(action)

class REINFORCEAgent:
    """REINFORCE (Monte-Carlo policy gradient) trainer built around ``CNNPolicy``.

    Each training episode builds a fresh environment from one (image, mask)
    pair, rolls the stochastic policy out until the episode ends, and applies
    one gradient step on the return-weighted log-probabilities.
    """

    def __init__(self, env_class, train_pairs, env_config,
                 gamma=0.99, lr=1e-4):
        """Create the policy, its optimizer, and bookkeeping for checkpoints.

        Args:
            env_class: Environment constructor, called as
                ``env_class(img, mask, grid_size=..., rewards=..., action_space=...)``.
            train_pairs: Indexable, iterable collection of ``(img, mask)`` pairs;
                the first pair is used to probe the observation shape.
            env_config: Mapping with keys ``"grid_size"``, ``"rewards"`` and
                ``"action_space"`` (the latter must expose ``.n``).
            gamma: Discount factor used when computing returns.
            lr: Learning rate for the Adam optimizer.
        """
        self.env_class = env_class
        self.train_pairs = train_pairs
        self.env_config = env_config
        self.gamma = gamma

        # Create sample env to read obs shape
        sample_img, sample_mask = train_pairs[0]

        sample_env = env_class(
            sample_img,
            sample_mask,
            grid_size=env_config["grid_size"],
            rewards=env_config["rewards"],
            action_space=env_config["action_space"]
        )

        obs, _ = sample_env.reset()

        # A 2-D observation is a single-channel image; otherwise the leading
        # axis is taken as the channel count.
        channels = 1 if obs.ndim == 2 else obs.shape[0]

        self.action_dim = env_config["action_space"].n
        self.policy = CNNPolicy(self.action_dim, channels)
        self.model = self.policy  # alias kept for callers that expect `.model`
        self.optim = optim.Adam(self.policy.parameters(), lr=lr)
        self.best_reward = -1e9
        self.save_path = "reinforce_policy.pt"

    def run_episode(self, img_path, mask_path):
        """Roll out one full episode on a freshly constructed environment.

        Returns:
            Tuple ``(log_probs, rewards)``: the per-step log-probability tensors
            of the sampled actions and the per-step rewards.
        """
        env = self.env_class(
            img_path,
            mask_path,
            grid_size=self.env_config["grid_size"],
            rewards=self.env_config["rewards"],
            action_space=self.env_config["action_space"]
        )

        log_probs = []
        rewards = []

        state, _ = env.reset()
        done = False

        while not done:
            action, log_prob = self.policy.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)

            log_probs.append(log_prob)
            rewards.append(reward)

            state = next_state
            done = terminated or truncated

        return log_probs, rewards

    def compute_returns(self, rewards):
        """Compute discounted returns and standardise them when possible.

        Args:
            rewards: Sequence of per-step rewards in time order.

        Returns:
            1-D float32 tensor with one return per step. With two or more steps
            the returns are shifted to zero mean and scaled to unit (sample)
            standard deviation. With fewer than two steps they are returned
            unnormalised, because the sample standard deviation of a single
            value is NaN and would poison the loss and the network weights.
        """
        discounted = []
        running = 0.0
        # Accumulate from the last step backwards, then flip once; this avoids
        # the quadratic cost of inserting at the front of a list.
        for r in reversed(rewards):
            running = r + self.gamma * running
            discounted.append(running)
        discounted.reverse()

        returns = torch.tensor(discounted, dtype=torch.float32)
        if returns.numel() < 2:
            return returns
        return (returns - returns.mean()) / (returns.std() + 1e-8)

    def update_policy(self, log_probs, returns):
        """Apply one REINFORCE gradient step and return the scalar loss.

        Args:
            log_probs: Per-step log-probability tensors from ``run_episode``.
            returns: Per-step returns from ``compute_returns``.

        Returns:
            The loss value as a Python float; ``0.0`` (and no optimizer step)
            when there are no steps to learn from.
        """
        if len(log_probs) == 0:
            # torch.stack cannot handle an empty list; nothing to update.
            return 0.0

        loss = torch.stack(
            [-lp * Gt for lp, Gt in zip(log_probs, returns)]
        ).sum()

        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        return loss.item()

    # Train over ALL train_pairs each epoch
    def train(self, epochs=200):
        """Train for ``epochs`` passes over every pair in ``train_pairs``.

        After each epoch the policy weights are written to ``self.save_path``
        whenever the mean episode reward beats the best seen so far. The final
        weights are always written to ``reinforce_final.pt``.
        """
        for e in range(epochs):
            epoch_loss = 0
            epoch_rewards = []

            # Train on ALL image/mask pairs each epoch
            for i, (img_path, mask_path) in enumerate(self.train_pairs):
                log_probs, rewards = self.run_episode(img_path, mask_path)
                returns = self.compute_returns(rewards)
                loss = self.update_policy(log_probs, returns)

                epoch_loss += loss
                epoch_rewards.append(sum(rewards))
                if i % 10 == 0:
                    print(f"  episode = {i} | Episode Reward={sum(rewards):.2f} | Loss={loss:.4f}")

            avg_reward = np.mean(epoch_rewards)
            # Save best model so far
            if avg_reward > self.best_reward:
                self.best_reward = avg_reward
                torch.save(self.policy.state_dict(), self.save_path)
                print(f"  -> Saved new best model (Reward {avg_reward:.2f})")

            if e % 10 == 0:
                print(f"[Epoch {e+1}] Avg Reward per Episode={avg_reward:.2f} | Loss={epoch_loss:.4f}")

        torch.save(self.policy.state_dict(), "reinforce_final.pt")
        print("Training finished. Final model saved to reinforce_final.pt")

In [14]:
# Environment settings forwarded to every environment the agent constructs.
CURRENT_CONFIG = {
    "grid_size": 4,
    # Reward values handed to the environment; what each entry stands for is
    # defined by the environment class and is not visible in this notebook.
    "rewards": [5.0, -1.0, -0.2],
    "action_space": spaces.Discrete(3),
}

# Training data returned by the project helper; consumed as (image, mask) pairs.
train_pairs = prepare()


✅ Found 30 pairs out of 30 listed in CSV.


In [22]:
# Seed the RNGs before building the policy so weight initialisation and action
# sampling are repeatable under Restart & Run All.
# NOTE(review): if the environment draws from another RNG (e.g. Python's
# `random`), that source needs seeding as well — confirm against `general.py`.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build the REINFORCE agent on the Glioblastoma environment with the shared config.
agent = REINFORCEAgent(
    env_class=Glioblastoma,
    train_pairs=train_pairs,
    env_config=CURRENT_CONFIG,
    gamma=0.99,
    lr=1e-4
)

# Long-running cell: 300 passes over every training pair.
agent.train(epochs=300)


  episode = 0 | Episode Reward=-4.00 | Loss=-0.0924
  episode = 10 | Episode Reward=7.20 | Loss=-0.0075
  episode = 20 | Episode Reward=-15.20 | Loss=0.1712
  -> Saved new best model (Reward -8.17)
[Epoch 1] Avg Reward per Episode=-8.17 | Loss=-0.1538
  episode = 0 | Episode Reward=-15.20 | Loss=0.1591
  episode = 10 | Episode Reward=-10.00 | Loss=0.0599
  episode = 20 | Episode Reward=-4.80 | Loss=-0.3426
  episode = 0 | Episode Reward=-10.00 | Loss=0.0475
  episode = 10 | Episode Reward=2.00 | Loss=0.4708
  episode = 20 | Episode Reward=1.20 | Loss=-0.1603
  -> Saved new best model (Reward -7.55)
  episode = 0 | Episode Reward=-4.00 | Loss=0.3678
  episode = 10 | Episode Reward=-15.20 | Loss=0.0611
  episode = 20 | Episode Reward=-15.20 | Loss=-0.0530
  -> Saved new best model (Reward -3.29)
  episode = 0 | Episode Reward=-4.00 | Loss=0.9734
  episode = 10 | Episode Reward=7.20 | Loss=-0.3104
  episode = 20 | Episode Reward=-15.20 | Loss=0.4595
  -> Saved new best model (Reward 0.91)