# Import the packages and modules

In [1]:
from gym_vrp.envs import SantaIRPEnv
from agents import TSPAgentFF

import torch


## Train the agent

In [2]:
# Experiment settings shared by the training, evaluation and plotting cells.
#
# Alternative presets noted alongside these values:
#   "Original"   -> batch_size = 256, seed = 69, num_nodes = 20
#   "Quick Test" -> batch_size = 10,  seed = 23, num_nodes = 5
batch_size, seed, num_nodes = 64, 23, 10

# Passed to the agent's `train` call as `episodes`; the checkpoint file name
# used later in the notebook is derived from `num_epochs - 1`.
num_epochs = 251

In [3]:
# Build the SantaIRPEnv environment from the notebook-level settings.
env_santa_ff = SantaIRPEnv(
    num_nodes=num_nodes,
    batch_size=batch_size,
    seed=seed,
)

# Output locations: the CSV loss log handed to the agent and the directory
# handed to `train` for checkpoints.
loss_log_path = f"./train_logs/loss_log_santa_ff_{num_nodes}_{seed}.csv"
checkpoint_dir = f"./check_points/santa_ff_{num_nodes}_{seed}/"

# NOTE(review): the recorded output of this cell ends in
# "RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x50 and 8x512)".
# Presumably the agent's network input size does not match what SantaIRPEnv
# produces — confirm that TSPAgentFF is compatible with this environment (or
# which constructor argument controls its input size) before relying on it.
agent_santa_ff = TSPAgentFF(seed=seed, csv_path=loss_log_path)

# Run training for `num_epochs` episodes.
agent_santa_ff.train(
    env_santa_ff,
    episodes=num_epochs,
    check_point_dir=checkpoint_dir,
)

INFO:root:Start Training


tensor([[0.5173, 0.9470, 0.1018,  ..., 0.1836, 0.0000, 0.0000],
        [0.6342, 0.9204, 0.1827,  ..., 0.1996, 0.0000, 0.0000],
        [0.0438, 0.6962, 0.1392,  ..., 0.1118, 0.0000, 0.0000],
        ...,
        [0.2358, 0.5336, 0.3337,  ..., 0.3174, 0.0000, 0.0000],
        [0.3550, 0.6016, 0.1740,  ..., 0.0543, 0.0000, 0.0000],
        [0.2248, 0.7298, 0.0000,  ..., 0.3345, 0.0000, 0.0000]])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x50 and 8x512)

## Visualise the actions of the agent in the environments

In [None]:
# Reuse the environment instance created in the training cell for evaluation.
# Alternative kept for reference (builds a separate environment instead):
#env = SantaIRPEnv(num_nodes=num_nodes, batch_size=batch_size, seed=seed, num_draw=3)
env = env_santa_ff
# NOTE(review): TSPModel is not referenced by any other cell visible in this
# notebook, and it points at a `tsp_...` checkpoint directory while the
# checkpoint loaded later comes from `santa_ff_...` — confirm whether this
# variable is still needed and whether the path is intended.
TSPModel=f'./check_points/tsp_{num_nodes}_{seed}/model_epoch_{num_epochs-1}.pt'

In [None]:
# Setup for evaluation: turn on video capturing on the environment, with the
# output file named after the current num_nodes / seed settings.
video_path = f"./videos/video_test_santa_ff_{num_nodes}_{seed}.mp4"
env.enable_video_capturing(video_save_path=video_path)

In [None]:
# Create a new agent and load into its model the state dict stored in the
# checkpoint file numbered `num_epochs - 1`.
agent = TSPAgentFF(seed=seed)
eval_model = agent.model
saved_state = torch.load(
    f"./check_points/santa_ff_{num_nodes}_{seed}/model_epoch_{num_epochs-1}.pt"
)
eval_model.load_state_dict(saved_state)

In [None]:
# Evaluate the agent on `env` and keep whatever `evaluate` returns in `loss_a`.
# NOTE(review): presumably the return value is the evaluation loss/cost —
# confirm against TSPAgentFF.evaluate.
loss_a = agent.evaluate(env)

In [None]:
# Close the video recorder held by the environment.
# NOTE(review): `env.vid` is presumably created by `enable_video_capturing` —
# confirm; it must exist before this cell is run.
env.vid.close()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Read the loss log (same path that was given to the agent as `csv_path`).
csv_path = f"./train_logs/loss_log_santa_ff_{num_nodes}_{seed}.csv"
data = pd.read_csv(csv_path)

# x-axis: the 'Epoch' column; y-axis: the 'Loss' column forced to be
# non-positive, i.e. -|Loss|.
# NOTE(review): taking the absolute value and then negating discards the
# original sign of the logged loss — confirm this is the quantity meant to be
# shown under the label "Training Loss".
epochs = data["Epoch"]
loss = data["Loss"].abs().mul(-1)

# Draw the curve on an explicit figure/axes pair.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs, loss, label="Training Loss", color="blue", marker="o")
ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.set_title("Training Loss over Epochs")
ax.legend()
ax.grid(True)
plt.show()
