In [None]:
import pandas as pd
import json

from model import SpotLSTM, Training
from dataset import SpotDataset
from utils import ResultPlotter

# Data loading

In [None]:
# Load the train / validation / test frames and the instance-info table from the
# local data/ directory.
# NOTE(review): read_pickle unpickles arbitrary objects — only load trusted files.
train_df = pd.read_pickle("data/train_df.pkl")
val_df = pd.read_pickle("data/val_df.pkl")
test_df = pd.read_pickle("data/test_df.pkl")
instance_info_df = pd.read_pickle("data/instance_info_df.pkl")

# Sanity-check the test frame: distinct id_instance count, then dtypes / non-null counts.
print("Number of different id_instances: {}".format(test_df["id_instance"].nunique()))
test_df.info()

In [None]:
def _dataset_and_loader(frame):
    """Wrap ``frame`` in a SpotDataset configured from config.yaml and build its loader.

    Returns the ``(dataset, loader)`` pair, where ``loader`` is whatever
    ``SpotDataset.get_data_loader()`` returns.
    """
    dataset = SpotDataset(frame, "config.yaml")
    return dataset, dataset.get_data_loader()


# Train split first, then validation — same construction order as before.
train_dataset, train_loader = _dataset_and_loader(train_df)
val_dataset, val_loader = _dataset_and_loader(val_df)

# Hyperparameter Tuning

In [None]:
def lr(loader=None, model_config=None, config_path="config.yaml", plot=True):
    """Run ``find_lr`` on a freshly built SpotLSTM and optionally plot the curve.

    Calling ``lr()`` with no arguments behaves as before: it uses the notebook's
    ``train_loader`` and the default settings below.

    Args:
        loader: Data passed to ``find_lr``. Defaults to the notebook's ``train_loader``.
        model_config: Optional dict of overrides merged on top of the default
            settings below (e.g. ``{"init_learning_rate": 1e-7}``).
        config_path: YAML file used to construct the ``SpotLSTM``.
        plot: When True, plot the result with
            ``ResultPlotter().plot_learning_rate_finder``.

    Returns:
        The ``(log_lrs, losses)`` pair returned by ``find_lr``, so the curve can be
        inspected without re-running the sweep.
    """
    from model import find_lr

    if loader is None:
        loader = train_loader

    config = {
        "window_size": 20,
        "batch_size": 128,
        "shuffle_buffer": 1000,
        "epochs": 150,
        # NOTE(review): this is len(train_dataset); confirm whether find_lr expects
        # a sample count or batches per epoch (len(loader)).
        "steps_per_epoch": len(train_dataset),
        # Presumably the start / end of the LR sweep — TODO confirm in find_lr.
        "init_learning_rate": 6e-7,
        "final_learning_rate": 1.2e-6,
        "weight_decay": 1.5e-5,
        "mse_weight": 0.8,
    }
    if model_config:
        config.update(model_config)

    model = SpotLSTM(config_path)
    log_lrs, losses = find_lr(model, loader, config)

    if plot:
        ResultPlotter().plot_learning_rate_finder(log_lrs, losses)
    return log_lrs, losses


# Set to True to run the LR finder before training (off by default).
RUN_LR_FINDER = False
if RUN_LR_FINDER:
    lr()

# Model Training

In [None]:
# Build a fresh model from config.yaml and wrap it in Training, then train using
# both the train and val loaders.
# NOTE(review): the second Training argument is len(train_dataset); confirm whether
# Training expects a sample count or batches per epoch (len(train_loader)).
model = SpotLSTM("config.yaml")
train_size = len(train_dataset)
modelTraining = Training(model, train_size, "config.yaml")
modelTraining.train_model(train_loader, val_loader)

In [None]:
# Reload the saved training history and plot it.
# NOTE(review): presumably written by modelTraining.train_model above — confirm.
# JSON text is UTF-8 (RFC 8259), so pass the encoding explicitly rather than
# relying on the platform's locale-dependent default.
with open("output/training_history.json", "r", encoding="utf-8") as f:
    history = json.load(f)

ResultPlotter().plot_training_history(history)