In [None]:
import pandas as pd
import json

from model import SpotGRU
from procedures import Training
from dataset import SpotDataset
from utils import ResultPlotter

# Data Loading

In [None]:
# Read the four pickled frames in one pass; the path order below matches the
# names on the left-hand side, so each name is bound to its own file.
train_df, val_df, test_df, instance_info_df = map(
    pd.read_pickle,
    (
        "data/train_df.pkl",
        "data/val_df.pkl",
        "data/test_df.pkl",
        "data/instance_info_df.pkl",
    ),
)

# Quick look at the test split: number of distinct instance ids, then the
# standard DataFrame summary.
print(f"Number of different id_instances: {test_df['id_instance'].nunique()}")
test_df.info()

In [None]:
# Wrap the train and validation frames in SpotDataset objects, both built
# against the same "config.yaml" (train first, then validation).
train_dataset, val_dataset = (
    SpotDataset(split_df, "config.yaml") for split_df in (train_df, val_df)
)

In [None]:
# # Optional DataLoader profiling (disabled): uncomment to run once the datasets above are loaded
# from utils.profiler import profile_dataloader, find_optimal_batch_size

# model = SpotGRU("config.yaml")

# # Profile DataLoader performance
# batch_sizes = [32, 64, 128, 256]
# num_workers_list = [0, 2, 4, 8]

# profiling_results = profile_dataloader(
#     train_dataset,
#     batch_sizes=batch_sizes,
#     num_workers_list=num_workers_list
# )

# # Print results
# for key, metrics in profiling_results.items():
#     print(f"\n{key}:")
#     print(f"Average batch time: {metrics['avg_batch_time']:.4f}s")
#     print(f"Throughput: {metrics['throughput']:.2f} samples/s")
#     print(f"Memory usage: {metrics['avg_memory_mb']:.2f} MB")
#     print(f"Device memory: {metrics['avg_device_mb']:.2f} MB")

# # Find optimal batch size
# optimal_batch_size = find_optimal_batch_size(
#     model=model,
#     dataset=train_dataset,
#     start_size=32,
#     max_size=512,
#     target_memory_usage=0.8
# )

# print(f"\nOptimal batch size: {optimal_batch_size}")

# Hyperparameter Tuning

In [None]:
def lr(train_loader=None):
    """Run the learning-rate finder on a fresh ``SpotGRU`` and plot the result.

    Args:
        train_loader: Batch source handed to ``find_lr``. When omitted, a
            shuffled ``torch.utils.data.DataLoader`` over the notebook-level
            ``train_dataset`` is built using ``model_config["batch_size"]``.

    Returns:
        None. The curve is drawn via ``ResultPlotter.plot_learning_rate_finder``.

    Note:
        Previously this function read a global ``train_loader`` that is never
        defined in the notebook, so ``lr()`` raised ``NameError`` on a fresh
        kernel. Calling ``lr()`` with no arguments now works on its own.
    """
    from procedures import find_lr

    # Model configuration
    model_config = {
        "window_size": 20,
        "batch_size": 128,  # Smaller for better generalization
        "shuffle_buffer": 1000,
        "epochs": 150,  # More training time
        # NOTE(review): this is the dataset length, not the number of batches
        # in the loader — confirm which one find_lr expects.
        "steps_per_epoch": len(train_dataset),
        "init_learning_rate": 6e-7,
        "final_learning_rate": 1.2e-6,
        "weight_decay": 1.5e-5,
        "mse_weight": 0.8,
    }

    if train_loader is None:
        # Assumes SpotDataset is compatible with torch's DataLoader — TODO confirm.
        from torch.utils.data import DataLoader

        train_loader = DataLoader(
            train_dataset,
            batch_size=model_config["batch_size"],
            shuffle=True,
        )

    model = SpotGRU("config.yaml")

    log_lrs, losses = find_lr(model, train_loader, model_config)

    ResultPlotter().plot_learning_rate_finder(log_lrs, losses)


# lr()

# Model Training

In [None]:
# Build the network from the shared "config.yaml".
model = SpotGRU("config.yaml")

# Hand the model to the Training procedure (same config file) and train on the
# train/validation datasets created earlier in the notebook.
# Presumably this run produces "output/training_history.json", which the next
# cell reads — TODO confirm against Training.train_model.
modelTraining = Training(model, "config.yaml")
modelTraining.train_model(train_dataset, val_dataset)

In [None]:
# Load the saved training history and plot it.
# The encoding is pinned to UTF-8 so the JSON is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open("output/training_history.json", "r", encoding="utf-8") as history_file:
    history = json.load(history_file)

ResultPlotter().plot_training_history(history)