# Imports

In [None]:
import os

from dotenv import load_dotenv
import torch
from torch.utils.data import DataLoader
import lightning as pl
from lightning.pytorch.loggers import MLFlowLogger
import mlflow.pytorch

from src.datasets import SurveillanceAnomalyDataset
from src.hstgcnn import HSTGCNN


# Pull variables such as MLFLOW_TRACKING_URI (read by the MLflow logger cell)
# from a .env file into os.environ. Variables already present in the
# environment win, since override=False.
load_dotenv(override=False)

# Data loaders

In [None]:
dataset_path = "./data/processed/UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train"
BATCH_SIZE = 64
num_workers = 4
# Fixed seed for the train/val split so the partition (and therefore val_loss)
# is identical every time this notebook is re-run.
SPLIT_SEED = 42

# Create dataset and dataloaders
dataset = SurveillanceAnomalyDataset(dataset_path)
# Without an explicit generator, random_split draws from the global RNG and each
# run gets a different 80/20 partition, making MLflow runs not comparable.
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [0.8, 0.2],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoader raises ValueError if persistent_workers=True with num_workers == 0,
# so only keep workers alive when worker processes are actually used.
use_persistent_workers = num_workers > 0

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
)

# Training loop

In [None]:
# Pick the best available accelerator rather than hard-coding "mps": the
# Trainer(accelerator=DEVICE) and model.to(DEVICE) calls otherwise fail on
# machines without Apple MPS support. On MPS machines this still yields "mps".
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
EPOCHS = 20
learning_rate = 1e-3

experiment_name = "HSTGCNN_UCSD_Ped1"
model_version = "1.0.0"
run_name = f"HSTGCNN@{model_version}"

# Source file uploaded to MLflow alongside the checkpoint.
model_definition_path = "./src/hstgcnn.py"

In [None]:
mlflow.set_experiment(experiment_name)
mlf_logger = MLFlowLogger(
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),
    experiment_name=experiment_name,
    run_name=run_name,
    log_model=False,
)

# Values that shape this run, kept as named constants so exactly what is used
# below is also what gets recorded in MLflow.
OA_WEIGHTS = [0.2, 0.5, 0.3]
EARLY_STOPPING_PATIENCE = 3

# Initialize trainer
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=DEVICE,
    log_every_n_steps=1,
    logger=mlf_logger,
    callbacks=[
        pl.pytorch.callbacks.EarlyStopping(
            monitor='val_loss',
            patience=EARLY_STOPPING_PATIENCE,
            mode='min',
        ),
        # Keep only the single best checkpoint by val_loss, e.g. "hstgcnn-epoch_05.ckpt".
        pl.pytorch.callbacks.ModelCheckpoint(
            dirpath='./checkpoints',
            filename='hstgcnn-epoch_{epoch:02d}',
            monitor='val_loss',
            mode='min',
            save_top_k=1,
            save_last=False,
            auto_insert_metric_name=False,
        ),
    ],
)

model = HSTGCNN(oa_weights=OA_WEIGHTS, lr=learning_rate).to(DEVICE)
# Log hyperparameters to MLflow (previously oa_weights and the early-stopping
# patience were not recorded, so runs differing only in them were indistinguishable).
mlf_logger.log_hyperparams({
    'model_name': model.__class__.__name__,
    'learning_rate': learning_rate,
    'batch_size': BATCH_SIZE,
    'max_epochs': EPOCHS,
    'oa_weights': OA_WEIGHTS,
    'early_stopping_patience': EARLY_STOPPING_PATIENCE,
})
# Train the model
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
# Log artifacts: the model source is recorded for every run; the best checkpoint
# only when ModelCheckpoint actually saved one (best_model_path is "" otherwise).
mlf_logger.experiment.log_artifact(mlf_logger.run_id, model_definition_path, "models")
best_model_path = trainer.checkpoint_callback.best_model_path
if best_model_path:
    mlf_logger.experiment.log_artifact(mlf_logger.run_id, best_model_path, "models")