In [None]:
import os

from torch.utils.data import DataLoader
import torch

from model import SAE_wrapper
from torch_dataset import SAE_dataset

from pytorch_lightning.loggers import MLFlowLogger
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    StochasticWeightAveraging
)

import mlflow


In [None]:
# Initialize MLFlow Logger

# Reproducibility: seed the global RNGs. `workers=True` additionally gives each
# DataLoader worker process a distinct, deterministic seed (the loaders below use
# multiple workers).
SEED = 42
seed_everything(SEED, workers=True)

# Hyperparameters (defined before the logger so the run name can be derived from them)
hparams = {
    "lr": 1e-3,
    "batch_size": 64,
    "epochs": 20,
    "patience": 3,
    "dimension_list": [11, 16, 32, 64, 128, 256],
    "seed": SEED,  # recorded so the logged run fully describes how it was produced
}

experiment_name = "emb256"
# Derive the batch-size suffix from hparams so the run name cannot drift out of
# sync with the batch size actually used.
run_name = f"{experiment_name}_bs{hparams['batch_size']}"
tracking_uri = "file:./sae_experiments"

mlflow_logger = MLFlowLogger(
    experiment_name=experiment_name,
    run_name=run_name,
    tracking_uri=tracking_uri
)
mlflow_logger.log_hyperparams(hparams)

In [None]:
# Data source: the train/val/test CSV splits are read from <root>/<split_root_sae>/
root, split_root_sae = "./data", "sae_data"

In [None]:
# Initialize datasets and dataloaders.
# The worker count is defined once instead of being repeated on every loader, and is
# capped at the machine's CPU count (os.cpu_count() may return None, hence `or 1`).
# It is always >= 1, which `persistent_workers=True` requires.
NUM_WORKERS = min(4, os.cpu_count() or 1)

train_ds = SAE_dataset(df_path=os.path.join(root, split_root_sae, "train.csv"))
val_ds   = SAE_dataset(df_path=os.path.join(root, split_root_sae, "val.csv"))
test_ds  = SAE_dataset(df_path=os.path.join(root, split_root_sae, "test.csv"))

batch_size = hparams["batch_size"]
# Only the training split is shuffled; validation/test keep a fixed order.
train_ld = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=NUM_WORKERS, persistent_workers=True)
val_ld   = DataLoader(dataset=val_ds, batch_size=batch_size, shuffle=False,
                      num_workers=NUM_WORKERS, persistent_workers=True)
# The test loader is only iterated by the single trainer.test call, so its
# workers are not kept alive between iterations.
test_ld  = DataLoader(dataset=test_ds, batch_size=batch_size, shuffle=False,
                      num_workers=NUM_WORKERS)

In [None]:
# Initialize model: the SAE wrapper is built from the configured layer sizes
# (`dimension_list`) and learning rate taken from hparams.
dim_list, lr = hparams["dimension_list"], hparams["lr"]
sparse_autoencoder = SAE_wrapper(dim_list=dim_list, lr=lr)


In [None]:
# Callbacks
# Single source of truth for the checkpoint filename. Previously this variable was
# defined but unused, and the same f-string was duplicated inside ModelCheckpoint.
check_point_name = f"best-checkpoint_{experiment_name}"

training_callbacks = [
        # Stop training once `val_loss` has not improved for `patience` validation checks.
        EarlyStopping(monitor="val_loss", mode="min", patience=hparams['patience']),
        # NOTE(review): swa_lrs=1e-2 is currently 10x the base lr in hparams; confirm
        # this is intended. Per the Lightning docs, SWA starts at swa_epoch_start
        # (default 0.8 of max_epochs), so it has no effect if EarlyStopping ends the
        # run before that point.
        StochasticWeightAveraging(swa_lrs=1e-2),
        LearningRateMonitor(logging_interval="step"),
        # Keep only the single best checkpoint, ranked by lowest validation loss.
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            dirpath="sae_checkpoints/",
            filename=check_point_name
        ),
        # max_depth=-1 includes every nested layer in the printed model summary.
        ModelSummary(max_depth=-1)
    ]

In [None]:
# Training
# Free unused memory held by PyTorch's CUDA caching allocator before training starts.
torch.cuda.empty_cache()

trainer = Trainer(
    max_epochs=hparams["epochs"],
    logger=mlflow_logger,
    callbacks=training_callbacks,
    log_every_n_steps=1,  # log metrics on every training step
)

# Train from scratch: no checkpoint is resumed (ckpt_path=None).
trainer.fit(
    model=sparse_autoencoder,
    train_dataloaders=train_ld,
    val_dataloaders=val_ld,
    ckpt_path=None,
)

In [None]:
# Testing
# Evaluate the best checkpoint (lowest val_loss, saved by ModelCheckpoint) rather than
# the in-memory weights left at the end of fit(). When EarlyStopping triggers, those
# final weights come from epochs *after* the best one. Set this to None to test the
# in-memory weights instead (e.g. to evaluate SWA-averaged weights).
test_ckpt_path = "best"
trainer.test(model=sparse_autoencoder,
        dataloaders=test_ld,
        ckpt_path=test_ckpt_path)