# Imports

In [1]:
import os

from dotenv import load_dotenv
import torch
from torch.utils.data import DataLoader
import lightning as pl
from lightning.pytorch.loggers import MLFlowLogger
import mlflow.pytorch

from src.datasets import SurveillanceAnomalyDataset, load_and_process_ontology
from src.hstgcnn import KnowledgeHSTGCNN


# Read a local .env file into os.environ (e.g. MLFLOW_TRACKING_URI, read via os.getenv in the
# training cell). Variables already set in the shell take precedence (override=False is the default).
load_dotenv(override=False)

True

# Data loaders

In [2]:
dataset_path = "../data/processed/UCSD_Anomaly_Dataset.v1p2/UCSDped1/Train"
BATCH_SIZE = 32
num_workers = 4
SEED = 42  # fixes the train/val split and global RNGs so reruns are comparable

# Seed python / numpy / torch RNGs; workers=True lets Lightning give each dataloader
# worker a distinct but reproducible seed.
pl.seed_everything(SEED, workers=True)

# Create the dataset and a reproducible 80/20 train/validation split. A dedicated
# generator keeps the split independent of any other RNG consumption in the notebook.
dataset = SurveillanceAnomalyDataset(dataset_path)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [0.8, 0.2], generator=torch.Generator().manual_seed(SEED)
)
assert len(train_dataset) + len(val_dataset) == len(dataset), "split lost samples"

# persistent_workers requires worker processes: DataLoader raises ValueError when it is
# True with num_workers == 0, so only enable it when workers are actually used.
use_persistent_workers = num_workers > 0

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=num_workers,
    persistent_workers=use_persistent_workers,
)

# Training configuration and run

In [3]:
# Hardware / optimisation. Fall back to CPU so the notebook still runs end-to-end on a
# machine without CUDA (pl.Trainer(accelerator="cuda") would otherwise raise).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS = 20
learning_rate = 1e-3

# MLflow experiment and run identification; model_version also appears in checkpoint names.
experiment_name = "HSTGCNN_UCSD_Ped1"
model_version = "3.5.0"
run_name = f"KnowledgeHSTGCNN@{model_version}"

# Knowledge-graph inputs passed to load_and_process_ontology.
ontology_path = "../data/TrafficEnvironment.rdf"
sentence_transformer = "all-MiniLM-L6-v2"
# Model source file uploaded to MLflow alongside the best checkpoint.
model_definition_path = "../src/hstgcnn.py"

In [4]:
# --- Experiment tracking -------------------------------------------------------------
mlflow.set_experiment(experiment_name)
mlf_logger = MLFlowLogger(
    tracking_uri=os.getenv("MLFLOW_TRACKING_URI"),  # may come from .env via load_dotenv()
    experiment_name=experiment_name,
    run_name=run_name,
    log_model=False,  # the best checkpoint is uploaded manually after fit() below
)

# Values handed to KnowledgeHSTGCNN as `oa_weights`. Kept in one name so the exact values
# are also logged to MLflow. TODO: document what each of the three weights controls.
oa_weights = [0.2, 0.5, 0.3]

# --- Callbacks -----------------------------------------------------------------------
checkpoint_dir = "../checkpoints/UCSD_Ped1/"
# e.g. model_version "3.5.0" -> "KnowledgeHSTGCNN_v350-epoch_07" (Lightning appends .ckpt).
# The epoch placeholder is a ModelCheckpoint template, not an f-string field.
checkpoint_filename = "KnowledgeHSTGCNN_v" + model_version.replace(".", "") + "-epoch_{epoch:02d}"

early_stopping = pl.pytorch.callbacks.EarlyStopping(monitor="val_loss", patience=3, mode="min")
checkpoint_callback = pl.pytorch.callbacks.ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename=checkpoint_filename,
    monitor="val_loss",
    mode="min",
    save_top_k=1,                   # keep only the best checkpoint by val_loss
    save_last=False,
    auto_insert_metric_name=False,  # the template already labels the field ("epoch_")
)

# --- Trainer and model ---------------------------------------------------------------
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator=DEVICE,
    log_every_n_steps=1,
    logger=mlf_logger,
    callbacks=[early_stopping, checkpoint_callback],
)
pyg_data = load_and_process_ontology(ontology_path, sentence_transformer, device=DEVICE)
model = KnowledgeHSTGCNN(pyg_data, oa_weights=oa_weights, lr=learning_rate).to(DEVICE)

# Log run configuration to MLflow (values are stored as strings by MLflow).
mlf_logger.log_hyperparams({
    'model_name': model.__class__.__name__,
    'model_version': model_version,
    'learning_rate': learning_rate,
    'batch_size': BATCH_SIZE,
    'max_epochs': EPOCHS,
    'oa_weights': str(oa_weights),
    # "../data/TrafficEnvironment.rdf" -> "TrafficEnvironment"
    'ontology': os.path.splitext(os.path.basename(ontology_path))[0],
    'sentence_transformer': sentence_transformer,
})

# --- Train ---------------------------------------------------------------------------
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# --- Artifacts -----------------------------------------------------------------------
# Upload the best checkpoint and the model source under the run's "models" folder.
# best_model_path is an empty string when no checkpoint was written.
best_model_path = checkpoint_callback.best_model_path
if best_model_path:
    mlf_logger.experiment.log_artifact(mlf_logger.run_id, best_model_path, "models")
    mlf_logger.experiment.log_artifact(mlf_logger.run_id, model_definition_path, "models")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                   | Params | Mode 
---------------------------------------------------------------------
0 | high_level_stgcnn | STGCNN                 | 97     | train
1 | low_level_stgcnn  | STGCNN                 | 193    | train
2 | feature_fusion    | KnowledgeFeatureFusion | 37.2 K | train
3 | ffp               | FFP                    | 199 K  | train
4 | oa                | OA                     | 3      | train
5 | loss_fn           | MSELoss                | 0      | train
---------------------------------------------------------------------
237 K     Trainable params
0         Non-trainable params
237 K     Total params
0.949     Total estimated model params size

Sanity Checking: |                                                   | 0/? [00:00<?, ?it/s]

Training: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.
