# Notebook to experiment with training:

Reference: [PyTorch Reinforcement Learning (DQN) tutorial](https://docs.pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

## Code:

In [1]:
import wandb
import numpy as np
from gymnasium import spaces

In [2]:
from training_environments import prepare, Glioblastoma
from training_dqn import DQN
from training_agents import DQNAgent
from training_buffers import ReplayBuffer

In [None]:
# --- Run identity -----------------------------------------------------------
# RUN_NAME labels the W&B run (name and id) and is handed to the agent as its
# save name. Alternative name left over from another experiment: "Extended002"
RUN_NAME = "Trial001"
ENVIRONMENT = "Glioblastoma"

# --- Environment settings ---------------------------------------------------
# Unpacked as keyword arguments into the environment constructor and also
# given to the agent as `env_config`.
# Alternative reward vector that was tried: [10.0, -2.0, -0.5]
CURRENT_CONFIG = dict(
    grid_size=4,
    rewards=[1.0, -2.0, -0.5],
    action_space=spaces.Discrete(3),
)

# --- Component labels -------------------------------------------------------
# Plain strings recorded in the experiment-tracking config.
NET, AGENT, BUFFER = "DQN", "DQNAgent", "ReplayBuffer"

# Free-text description stored with the run.
NOTES = "First trial default rewards, 500 burnin."

# --- Optimisation / replay settings -----------------------------------------
LR = 1e-4            # From paper
MEMORY_SIZE = 15000  # From paper
MAX_EPISODES = 90    # From paper

# --- Epsilon-greedy exploration schedule ------------------------------------
EPSILON = 1.0      # From paper
EPSILON_MIN = 0.1  # From paper
DECAY_TYPE = "subtraction"  # set to "exponential" to try the multiplicative schedule
if DECAY_TYPE == "exponential":
    EPSILON_DECAY = 0.85  # Let's try exponential decay
    # Number of decay steps until EPSILON * EPSILON_DECAY**k <= EPSILON_MIN.
    # Assumes the agent multiplies epsilon by EPSILON_DECAY once per episode
    # -- TODO confirm against DQNAgent.
    EPISODES_TO_MIN = int(np.ceil(np.log(EPSILON_MIN / EPSILON) / np.log(EPSILON_DECAY)))
else:
    # Linear schedule: the step is sized so the floor is hit exactly at the
    # last episode (assuming one subtraction per episode).
    EPSILON_DECAY = (EPSILON - EPSILON_MIN) / MAX_EPISODES
    EPISODES_TO_MIN = MAX_EPISODES
# Report the episode count that matches the selected schedule; previously the
# message always quoted MAX_EPISODES, which is wrong for the exponential case.
print(f"Starting at {EPSILON}, decaying {EPSILON_DECAY}, will reach {EPSILON_MIN} after {EPISODES_TO_MIN} episodes")

# --- Remaining training hyperparameters -------------------------------------
GAMMA = 0.99      # passed to the agent as `gamma`
BATCH_SIZE = 128  # From paper
BURN_IN = 500     # passed to the agent as `buffer_initial`; earlier value: 150
DNN_UPD = 1       # passed to agent.train as `dnn_update_frequency`
DNN_SYNC = 20     # passed to agent.train as `dnn_sync_frequency`

In [None]:
# Load the training pairs and build an environment from the first pair.
train_pairs = prepare()
env = Glioblastoma(*train_pairs[0], **CURRENT_CONFIG)

# Quick look at the observation shape and the discrete action indices.
n_actions = env.action_space.n
for shown in (env.observation_space.shape, n_actions, np.arange(n_actions)):
    print(shown)

# Q-network built from the environment.
net = DQN(env, learning_rate=LR, device='cpu')

# NOTE(review): this buffer instance is not handed to the agent below (the
# agent receives `buffer_class` and `memory_size` instead) -- confirm whether
# it is still needed.
buffer = ReplayBuffer(capacity=MEMORY_SIZE)

# Collect the agent's constructor arguments in one place, grouped by concern.
agent_kwargs = dict(
    # environment / model wiring
    env_config=CURRENT_CONFIG,
    net=net,
    buffer_class=ReplayBuffer,
    train_pairs=train_pairs,
    env_class=Glioblastoma,
    # exploration schedule
    epsilon=EPSILON,
    eps_decay=EPSILON_DECAY,
    eps_decay_type=DECAY_TYPE,
    epsilon_min=EPSILON_MIN,
    # learning settings
    batch_size=BATCH_SIZE,
    gamma=GAMMA,
    # replay memory
    memory_size=MEMORY_SIZE,
    buffer_initial=BURN_IN,
    # checkpoint naming
    save_name=RUN_NAME,
)
agent = DQNAgent(**agent_kwargs)

In [None]:
# Log in to Weights & Biases before starting the run.
wandb.login()

# Start the tracked run and record every tunable of this experiment.
# NOTE(review): `id` is pinned to RUN_NAME, so re-running with an unchanged
# RUN_NAME targets the same run id -- confirm this is intended (see the
# `resume` argument of wandb.init).
wandb.init(project="TFG_Glioblastoma",
           name=RUN_NAME,
           id=RUN_NAME,
           # Settings only take effect when passed to init(); a bare
           # `wandb.Settings(quiet=True)` expression creates an object that is
           # immediately discarded.
           settings=wandb.Settings(quiet=True),
           config={
            "environment": ENVIRONMENT,
            # NOTE(review): CURRENT_CONFIG holds a gymnasium space object;
            # verify how it is rendered in the W&B config.
            "configuration": CURRENT_CONFIG,
            "model": NET,
            "agent": AGENT,
            "buffer": BUFFER,
            "notes": NOTES,
            "lr": LR,
            "MEMORY_SIZE": MEMORY_SIZE,
            "MAX_EPISODES": MAX_EPISODES,
            "EPSILON": EPSILON,
            "EPSILON_DECAY": EPSILON_DECAY,
            "Decay type": DECAY_TYPE,
            "EPSILON_MIN": EPSILON_MIN,
            "GAMMA": GAMMA,
            "BATCH_SIZE": BATCH_SIZE,
            "BURN_IN": BURN_IN,
            "DNN_UPD": DNN_UPD,
            "DNN_SYNC": DNN_SYNC,
})

In [None]:
# Run training and always close the W&B run afterwards, even when training
# raises or the cell is interrupted; otherwise the run would stay open and
# need a manual wandb.finish().
train_exit_code = 1  # assume failure until training returns normally
try:
    agent.train(
        train_pairs=train_pairs,
        gamma=GAMMA,
        max_episodes=MAX_EPISODES,
        dnn_update_frequency=DNN_UPD,
        dnn_sync_frequency=DNN_SYNC
    )
    train_exit_code = 0
finally:
    # exit_code=0 marks the run as successful; any other value marks it failed.
    wandb.finish(exit_code=train_exit_code)

In [None]:
# Manual cleanup: ends the active W&B run. Presumably kept for the case where
# the training cell is interrupted before reaching its own wandb.finish()
# -- confirm.
wandb.finish()