# Import the packages and modules

In [1]:
from gym_vrp.envs import SantaIRPEnv
from agents import IRPAgent

import torch


## Train the agent

In [2]:
# Experiment configuration. Alternative presets are kept for reference;
# swap one in by replacing the "active" line below.
#
#   Original:    batch_size, seed, num_nodes = 256, 69, 20
#   Quick test:  batch_size, seed, num_nodes = 10, 23, 5

# Active preset.
batch_size, seed, num_nodes = 64, 23, 6

# Training length. Later cells load the checkpoint `model_epoch_{num_epochs-1}.pt`.
# A longer run previously used 251 epochs.
num_epochs = 151

In [3]:
# Output locations for this run. The plotting cell reads the same loss CSV,
# and the evaluation cell loads `model_epoch_*.pt` from the checkpoint directory.
train_log_path = f"./train_logs/loss_log_santa_irp_{num_nodes}_{seed}.csv"
check_point_dir = f"./check_points/santa_irp_{num_nodes}_{seed}/"

# Santa IRP environment built from the active configuration.
env_santa_irp = SantaIRPEnv(
    num_nodes=num_nodes,
    batch_size=batch_size,
    seed=seed,
)

# NOTE(review): IRPAgent is used on SantaIRPEnv on the assumption that the two
# are compatible; this notebook does not verify it.
agent_santa_irp = IRPAgent(seed=seed, csv_path=train_log_path)

# Run training for `num_epochs` epochs, writing checkpoints to `check_point_dir`.
agent_santa_irp.train(
    env_santa_irp,
    epochs=num_epochs,
    check_point_dir=check_point_dir,
)

INFO:root:Start Training


!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

KeyboardInterrupt: 

## Visualise the agent's actions in the environment

In [None]:
# Evaluate on the same environment instance that was used for training.
# Alternative: build a fresh environment instead, e.g.
#   env = SantaIRPEnv(num_nodes=num_nodes, batch_size=batch_size, seed=seed, num_draw=3)
env = env_santa_irp

# Path of the final checkpoint from the training cell. The name `TSPModel` is a
# legacy leftover and is kept for compatibility. The path previously pointed at
# ./check_points/tsp_..., but training writes to ./check_points/santa_irp_... .
TSPModel = f'./check_points/santa_irp_{num_nodes}_{seed}/model_epoch_{num_epochs-1}.pt'

In [None]:
# Turn on video capture for the evaluation environment; output goes to an MP4 under ./videos/.
video_path = f"./videos/video_test_santa_irp_{num_nodes}_{seed}.mp4"
env.enable_video_capturing(video_save_path=video_path)

In [None]:
# Load the final training checkpoint into a fresh agent for evaluation.
#
# map_location="cpu" lets a checkpoint saved from a GPU run be read on a
# CPU-only machine; load_state_dict then copies the tensors into the model's
# existing parameters, wherever those live.
#
# NOTE: the training cell above was interrupted (KeyboardInterrupt), so the
# final-epoch file may not exist. If so, point `check_point_path` at an earlier
# model_epoch_*.pt in the same directory.
agent = IRPAgent(seed=seed)
check_point_path = f"./check_points/santa_irp_{num_nodes}_{seed}/model_epoch_{num_epochs-1}.pt"
agent.model.load_state_dict(torch.load(check_point_path, map_location="cpu"))

In [None]:
# Evaluate the agent
# Runs IRPAgent.evaluate on `env` (video capture was enabled on it earlier).
# The return value is stored in `loss_a`; its exact meaning is defined by
# IRPAgent.evaluate — TODO confirm.
loss_a = agent.evaluate(env)

In [None]:
# Close the video recorder
# Closes the recorder that enable_video_capturing attached as `env.vid`.
# NOTE(review): presumably this finalises the MP4 on disk — confirm in SantaIRPEnv.
env.vid.close()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Plot the training-loss log that the training cell passed to IRPAgent as `csv_path`.
csv_path = f"./train_logs/loss_log_santa_irp_{num_nodes}_{seed}.csv"
data = pd.read_csv(csv_path)

# Fail with an explicit message (rather than a bare KeyError) if the log
# does not have the columns this plot needs.
missing_columns = {'Epoch', 'Loss'} - set(data.columns)
if missing_columns:
    raise KeyError(f"{csv_path} is missing expected column(s): {sorted(missing_columns)}")

epochs = data['Epoch']
loss = data['Loss']

# Explicit Figure/Axes interface instead of the pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, loss, label='Training Loss', color='blue', marker='o')
ax.set(xlabel='Epochs', ylabel='Loss', title='Training Loss over Epochs')
ax.legend()
ax.grid(True)
plt.show()
