# Notebook to experiment with training:

Reference tutorial: [PyTorch — Reinforcement Learning (DQN) Tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

## Code:

In [None]:
import itertools
import json
import os
import random
import traceback

import numpy as np
import torch
import wandb
from gymnasium import spaces

In [None]:
# One seed shared by every random number generator used in this notebook.
SEED = 42

# CPU-side generators: PyTorch, NumPy's global RNG and the stdlib `random` module.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

if torch.cuda.is_available():
    # Seed the current CUDA device, then every CUDA device.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Turn off cuDNN benchmarking and request deterministic cuDNN behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


In [None]:
# === Configuration ===

# --- Run naming and resume bookkeeping ---
RUN_NAME_BASE = "Trial"
START_TRIAL = 19  # start numbering from Trial019 when no progress file exists
PROGRESS_FILE = "grid_progress.json"  # JSON file with the completed combos and the next trial number

# --- Descriptive labels stored in each run's W&B config ---
ENVIRONMENT = "Glioblastoma"
NET = "DQN"
AGENT = "DQNAgent"
BUFFER = "ReplayBuffer"
NOTES = ""

# --- Hyperparameters held fixed for every run of the grid ---
LR = 1e-4                   # passed to the network as learning_rate
GAMMA = 0.99
BATCH_SIZE = 128
MEMORY_SIZE = 15000         # passed to the agent as memory_size
MAX_EPISODES = 90
DECAY_TYPE = "exponential"  # passed to the agent as eps_decay_type
DNN_UPD = 1                 # passed to training as dnn_update_frequency
DNN_SYNC = 20               # passed to training as dnn_sync_frequency

# --- Grid parameters ---
# 3 * 4 * 3 * 2 * 4 = 288 combinations in total.
# Each row is one `rewards` setting for the environment config.
# NOTE(review): the meaning of the three positions is defined by the environment — confirm and document.
reward_grid = [
    [10.0, -2.0, -0.5],
    [5.0, -1.0, -0.2],
    [8.0, -1.5, -0.3],
]

burnin_grid = [100, 150, 200, 500]           # passed to the agent as buffer_initial
epsilon_grid = [1.0, 0.9, 0.7]               # passed to the agent as epsilon
epsilon_min_grid = [0.1, 0.05]               # passed to the agent as epsilon_min
epsilon_decay_grid = [0.95, 0.9, 0.85, 0.8]  # passed to the agent as eps_decay

In [None]:
# --- Build the hyperparameter grid ---
# Cartesian product of the five grid axes. Each element is a
# (rewards, burnin, eps, eps_min, eps_decay) tuple, and its position in this
# list is the combo id recorded in the progress file, so the grid lists must
# stay unchanged between resumed sessions.
all_combos = [
    *itertools.product(
        reward_grid, burnin_grid, epsilon_grid, epsilon_min_grid, epsilon_decay_grid
    )
]

total = len(all_combos)
print(f"Total combinations: {total}")


In [None]:
# Restore the grid-search state saved by a previous session, if there is one.
if not os.path.exists(PROGRESS_FILE):
    # Fresh start: nothing completed yet, numbering begins at START_TRIAL.
    completed, current_trial = set(), START_TRIAL
else:
    with open(PROGRESS_FILE, "r") as f:
        progress = json.load(f)
    # "completed" lists the combo ids already run; "last_trial" is the trial
    # number to use for the next run.
    completed = set(progress.get("completed", []))
    current_trial = progress.get("last_trial", START_TRIAL)

print(f"Resuming from trial {current_trial:03d}, {len(completed)} of {total} completed.")


In [None]:
# --- Training function ---
def train_grid(instance_gid, rewards, burnin, eps, eps_min, eps_decay):
    """Train one DQN agent for a single grid combination and log it to W&B.

    Builds the Glioblastoma environment, the DQN network and the DQNAgent from
    the module-level constants plus the per-run values given here, opens a
    Weights & Biases run, trains the agent and closes the run. The run is
    closed even when training fails, so a failed trial does not leave an open
    run behind in the kernel.

    Args:
        instance_gid: Run identifier; used as the W&B run name and id and as
            the agent's ``save_name``.
        rewards: Value stored under ``'rewards'`` in the environment config.
        burnin: Passed to the agent as ``buffer_initial``.
        eps: Passed to the agent as ``epsilon``.
        eps_min: Passed to the agent as ``epsilon_min``.
        eps_decay: Passed to the agent as ``eps_decay`` (together with
            ``eps_decay_type=DECAY_TYPE``).

    Raises:
        BaseException: Anything raised by ``agent.train`` is re-raised after
            the W&B run has been finished with a non-zero exit code.
    """
    # Project-local modules are imported here rather than at the top of the notebook.
    from training_environments import prepare, Glioblastoma
    from training_dqn import DQN
    from training_agents import DQNAgent
    from training_buffers import ReplayBuffer

    CURRENT_CONFIG = {
        'grid_size': 4,
        'rewards': rewards,
        'action_space': spaces.Discrete(3)
    }

    train_pairs = prepare()
    # Environment built from the first training pair; used to construct the
    # network and to print the space sizes below.
    env = Glioblastoma(*train_pairs[0], **CURRENT_CONFIG)
    print(env.observation_space.shape)
    print(env.action_space.n)
    print(np.arange(env.action_space.n))

    net = DQN(env, learning_rate=LR, device='cpu')
    # The agent receives the buffer class and its sizing parameters; no buffer
    # instance is created here.
    agent = DQNAgent(env_config=CURRENT_CONFIG, dnnetwork=net, buffer_class=ReplayBuffer,
                    train_pairs=train_pairs, env_class=Glioblastoma,
                    epsilon=eps, eps_decay=eps_decay, eps_decay_type=DECAY_TYPE, epsilon_min=eps_min,
                    batch_size=BATCH_SIZE, gamma=GAMMA,
                    memory_size=MEMORY_SIZE, buffer_initial=burnin,
                    save_name=instance_gid)

    wandb.login()
    wandb.init(
        project="TFG_Glioblastoma",
        name=instance_gid,
        id=instance_gid,
        config={
            "environment": ENVIRONMENT,
            "configuration": CURRENT_CONFIG,
            "model": NET,
            "agent": AGENT,
            "buffer": BUFFER,
            "notes": NOTES,
            "lr": LR,
            "MEMORY_SIZE": MEMORY_SIZE,
            "MAX_EPISODES": MAX_EPISODES,
            "EPSILON": eps,
            "EPSILON_DECAY": eps_decay,
            "Decay type": DECAY_TYPE,
            "EPSILON_MIN": eps_min,
            "GAMMA": GAMMA,
            "BATCH_SIZE": BATCH_SIZE,
            "BURN_IN": burnin,
            "DNN_UPD": DNN_UPD,
            "DNN_SYNC": DNN_SYNC,
            "rewards": rewards,
        }
    )

    try:
        agent.train(
            train_pairs=train_pairs,
            gamma=GAMMA,
            max_episodes=MAX_EPISODES,
            dnn_update_frequency=DNN_UPD,
            dnn_sync_frequency=DNN_SYNC
        )
    except BaseException:
        # Close the run and mark it as failed (this also covers a kernel
        # interrupt), then let the caller see the original error.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()

In [8]:
# --- Run every grid combination that is not yet marked as completed ---
for i, combo in enumerate(all_combos):
    # A combination is identified by its index in `all_combos`, stored as a string.
    combo_id = str(i)
    if combo_id in completed:
        continue  # skip already finished runs

    rewards, burnin, eps, eps_min, eps_decay = combo
    run_id = f"{RUN_NAME_BASE}{current_trial:03d}"

    print(f"\n=== Running {run_id} ({i+1}/{total}) ===")

    try:
        train_grid(run_id, rewards, burnin, eps, eps_min, eps_decay)
        # --- Save progress after each successful run ---
        completed.add(combo_id)
        current_trial += 1
        # Write to a temporary file and swap it into place, so an interruption
        # during the write cannot leave a truncated progress file behind.
        progress_tmp = f"{PROGRESS_FILE}.tmp"
        with open(progress_tmp, "w") as f:
            json.dump({"completed": sorted(completed, key=str), "last_trial": current_trial}, f)
        os.replace(progress_tmp, PROGRESS_FILE)
    except Exception as e:
        # Stop at the first failure; the trial number is not advanced, so the
        # next session retries this combination under the same run id.
        print(f"❌ Error on {run_id}: {e}")
        traceback.print_exc()  # keep the full traceback visible in the cell output
        break

print("\n✅ Grid search finished or interrupted. Progress saved.")



✅ Grid search finished or interrupted. Progress saved.
