# Imports

In [None]:
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import mlflow.pytorch

from src.datasets import SurveillanceAnomalyDataset
from src.hstgcnn import HSTGCNN

# Data loaders

In [None]:
dataset_path = "./data/processed/UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train"
BATCH_SIZE = 32
num_workers = 4

# Build the full dataset, then hold out 20% of it for validation.
# NOTE(review): fractional lengths in random_split need a reasonably recent
# PyTorch (1.13+); the split is also unseeded, so it differs between runs.
dataset = SurveillanceAnomalyDataset(dataset_path)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [0.8, 0.2])

# Settings shared by both loaders. DataLoader's keyword is `batch_size`
# (lowercase); passing `BATCH_SIZE=` raises a TypeError. Persistent workers
# are only valid when worker processes exist, so derive the flag from
# num_workers instead of hard-coding True.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=num_workers,
    persistent_workers=num_workers > 0,
)

# Shuffle only the training split; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Training loop

In [None]:
# --- MLflow bookkeeping -------------------------------------------------
# The run name embeds the model version so runs of different versions are
# distinguishable inside the same experiment.
model_version = "1.0.0"
experiment_name = "HSTGCNN_UCSD_Ped1"
run_name = f"HSTGCNN@{model_version}"

# --- Training settings --------------------------------------------------
# DEVICE is the accelerator string, EPOCHS the maximum number of epochs.
DEVICE, EPOCHS = "mps", 20
learning_rate = 0.001

In [None]:
# Train HSTGCNN inside a single MLflow run: log the hyperparameters, fit with
# checkpointing and early stopping on validation loss, then log the model.
mlflow.set_experiment(experiment_name)
with mlflow.start_run(run_name=run_name) as run:
    run_id = run.info.run_id

    # Keep the three best checkpoints ranked by validation loss. The filename
    # template formats the same metric that is monitored (it previously
    # referenced `val_score`, which is not the metric used for ranking).
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath="checkpoints",
        filename="hstgcnn-{epoch:02d}-{val_loss:.2f}",
        save_top_k=3,
        monitor="val_loss",
        mode="min",
    )
    # Stop once validation loss has not improved for 5 consecutive checks.
    early_stopping_callback = pl.callbacks.EarlyStopping(
        monitor="val_loss",
        patience=5,
        mode="min",
    )

    # Initialize trainer
    trainer = pl.Trainer(
        max_epochs=EPOCHS,
        accelerator=DEVICE,
        log_every_n_steps=10,
        enable_checkpointing=True,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )

    # NOTE(review): learning_rate is logged below but is not passed to
    # HSTGCNN() here - confirm the model's own default matches it.
    model = HSTGCNN().to(DEVICE)

    # Log hyperparameters to MLflow
    mlflow.log_params({
        'model_name': model.__class__.__name__,
        'learning_rate': learning_rate,
        'BATCH_SIZE': BATCH_SIZE,
        'max_epochs': EPOCHS,
    })

    # Train the model
    trainer.fit(model, train_loader, val_loader)

    # Log model to MLflow. This logs the in-memory model object after fit
    # returns, not a checkpoint file written by ModelCheckpoint.
    mlflow.pytorch.log_model(model, "models")