# Training SLAT

This notebook trains SLAT on a given dataset. Model performance is logged with Weights & Biases (W&B), and Lightning is used to simplify the training loop.

## Import necessary modules

We use W&B through its Lightning logger integration (`WandbLogger`), and the `Trainer` comes from Lightning. The dataset is stored as `.mat` files, which are loaded with `scipy.io`, reshaped, and converted into `TensorDataset` objects. The `DataLoader` class then takes care of batching. `ModelCheckpoint` keeps the best model during training, and `huggingface_hub` is used to log in to Hugging Face.

In [None]:
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import Trainer
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch.callbacks import ModelCheckpoint
import scipy.io as sio
import torch
import SLAT
import wandb
import huggingface_hub

## API keys

Here, we define the API tokens for W&B and Hugging Face. Fill in your own keys before running the notebook.

In [None]:
WANDB_API_KEY = ""
HF_KEY = ""
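Hard-coding tokens in a notebook makes it easy to commit them by accident. A minimal alternative, sketched below, reads them from environment variables instead. The variable names `WANDB_API_KEY` and `HF_TOKEN` are an assumption here (they are the names the `wandb` and `huggingface_hub` clients look for on their own), and `read_key` is a hypothetical helper, not part of SLAT.

```python
import os


def read_key(name: str) -> str:
    """Hypothetical helper: return the token stored in environment variable `name`.

    Raises RuntimeError instead of returning an empty string, so a missing
    token fails loudly before any login attempt.
    """
    value = os.environ.get(name, "").strip()
    if not value:
        raise RuntimeError(f"environment variable {name!r} is not set")
    return value


# Assumed variable names; export them in the shell before starting Jupyter.
# WANDB_API_KEY = read_key("WANDB_API_KEY")
# HF_KEY = read_key("HF_TOKEN")
```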

## Login

In this section, we log in to W&B and Hugging Face with the tokens defined above.

In [None]:
wandb.login(key=WANDB_API_KEY)
huggingface_hub.login(token=HF_KEY)

## Setting up the data loaders

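The training and test sets are loaded from the `.mat` files and prepared for training. Each input row is reshaped to `(timeWindowSize + 2, features)`, i.e. 42 × 34, and the label arrays are transposed so that their first axis indexes samples, matching the inputs. The test set doubles as the validation set: training batches of 256 samples are shuffled, while validation batches keep the dataset order.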

In [None]:
timeWindowSize = 40
features = 34

X_train = sio.loadmat('./data/trainX_new.mat')
X_train = X_train['train1X_new']
X_train = X_train.reshape(len(X_train), timeWindowSize+2, features)
Y_train = sio.loadmat('./data/trainY.mat')
Y_train = Y_train['train1Y']
Y_train = Y_train.transpose()

X_test = sio.loadmat('./data/testX_new.mat')
X_test = X_test['test1X_new']
X_test = X_test.reshape(len(X_test), timeWindowSize+2, features)
Y_test = sio.loadmat('./data/testY.mat')
Y_test = Y_test['test1Y']
Y_test = Y_test.transpose()

training_set = TensorDataset(
    torch.tensor(X_train, dtype=torch.float),
    torch.tensor(Y_train, dtype=torch.float)
)
validation_set = TensorDataset(
    torch.tensor(X_test, dtype=torch.float),
    torch.tensor(Y_test, dtype=torch.float)
)

training_loader = DataLoader(
    training_set,
    batch_size=256,
    shuffle=True,
    num_workers=4
)
validation_loader = DataLoader(
    validation_set,
    batch_size=256,
    num_workers=4
)
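The reshape and transpose in the cell above only work if the arrays in the `.mat` files have a particular layout. The sketch below reproduces both steps on small synthetic NumPy arrays; the layouts it assumes (one flattened sample per row of `trainX_new.mat`, labels stored as a single row in `trainY.mat`) are inferred from the code, not documented in the source.

```python
import numpy as np

time_window_size = 40
features = 34
n_samples = 5

# Synthetic stand-in for the matrix stored under 'train1X_new' (assumed layout):
# one flattened sample per row with (time_window_size + 2) * features columns.
flat_x = np.arange(n_samples * (time_window_size + 2) * features, dtype=np.float32)
flat_x = flat_x.reshape(n_samples, (time_window_size + 2) * features)

# Same call as in the notebook. NumPy reshapes in row-major (C) order by default,
# so each run of `features` consecutive columns becomes one time step.
x = flat_x.reshape(len(flat_x), time_window_size + 2, features)

# Synthetic stand-in for 'train1Y', assumed to be a single row of shape (1, n_samples).
flat_y = np.linspace(0.0, 1.0, n_samples, dtype=np.float32).reshape(1, n_samples)

# Transposing gives (n_samples, 1), so inputs and labels agree on the first
# (sample) axis, which TensorDataset requires.
y = flat_y.transpose()

print(x.shape, y.shape)  # (5, 42, 34) (5, 1)
```

Whether row-major order matches how the rows were flattened when the files were written is worth checking against the data source; `reshape` also accepts `order='F'` for column-major layouts.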

## Training

The training cell works in six steps:

1. Load the untrained model.
2. Define the checkpoint callback and the metric it monitors: the checkpoint with the lowest validation RMSE (`val_RMSE`) is kept.
3. Instantiate the W&B logger; `log_model='all'` uploads the checkpoints to W&B as artifacts during training.
4. Instantiate the `Trainer` and fit the model on the training data for 300 epochs on a GPU, validating on the test set.
5. Close the W&B run with `wandb.finish()`.
6. Push the trained model to Hugging Face.

In [None]:
model = SLAT.SLAT_LitModule()

checkpoint_callback = ModelCheckpoint(monitor='val_RMSE', mode='min')

wandb_logger = WandbLogger(
    project='SLAT',
    log_model='all'
)

trainer = Trainer(
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu",
    max_epochs=300
)

trainer.fit(model, training_loader, validation_loader)

wandb.finish()

# The token stored by huggingface_hub.login() above is used for the upload.
model.push_to_hub(
    "dschneider96/SLAT",
    commit_message="basic training",
    private=True
)
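The checkpoint callback keeps the checkpoint whose monitored value `val_RMSE` is lowest (`mode='min'`). The source does not show how `SLAT_LitModule` computes `val_RMSE`; assuming it is the standard root-mean-square error over the validation set, the sketch below computes it with NumPy and picks the best epoch the way `mode='min'` does. The targets and per-epoch predictions are made-up toy numbers, not training results.

```python
import numpy as np


def rmse(pred: np.ndarray, target: np.ndarray) -> float:
    """Root-mean-square error: sqrt(mean((pred - target) ** 2))."""
    pred = np.asarray(pred, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    return float(np.sqrt(np.mean((pred - target) ** 2)))


# Toy validation targets and hypothetical predictions after three epochs.
target = np.array([[1.0], [2.0], [3.0], [4.0]])
epoch_predictions = [
    target + 1.0,                                        # error 1 everywhere
    target + np.array([[0.5], [-0.5], [0.5], [-0.5]]),   # error 0.5 everywhere
    target + 2.0,                                        # error 2 everywhere
]

scores = [rmse(p, target) for p in epoch_predictions]

# mode='min': the best checkpoint is the one with the smallest monitored value.
best_epoch = int(np.argmin(scores))
print(scores, best_epoch)  # [1.0, 0.5, 2.0] 1
```

Note that `model` still holds the weights from the final epoch when `push_to_hub` runs; the path of the best checkpoint selected by the callback is available as `checkpoint_callback.best_model_path`.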

## Inference

This is a basic prediction example using the pretrained model from Hugging Face. For simplicity, the validation loader is reused as the test loader.

In [None]:
model_pretrained = SLAT.SLAT_LitModule.from_pretrained("dschneider96/SLAT")

trainer = Trainer(
    accelerator="gpu"
)

test_loader = validation_loader
predictions = trainer.predict(model=model_pretrained, dataloaders=test_loader)
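`trainer.predict` returns a list with one entry per batch. Because `validation_loader` was created without `shuffle=True`, its batches come in dataset order, so concatenating them lines the predictions up with the rows of `Y_test`. The sketch below uses NumPy arrays as stand-ins for the per-batch outputs and a hypothetical "perfect" model; with the real outputs (assuming `SLAT_LitModule`'s predict step returns one tensor per batch), `torch.cat(predictions)` does the same.

```python
import numpy as np

batch_size = 256
n_samples = 600  # toy size: gives batches of 256, 256 and 88

# Toy labels standing in for Y_test, shape (n_samples, 1).
labels = np.arange(n_samples, dtype=np.float32).reshape(-1, 1)

# Stand-ins for the per-batch outputs of trainer.predict: one array per batch,
# in dataset order because the loader does not shuffle.
predictions = [
    labels[start:start + batch_size].copy()
    for start in range(0, n_samples, batch_size)
]

# Join the batches along the sample axis (torch.cat(predictions) for tensors).
all_predictions = np.concatenate(predictions, axis=0)

print([len(p) for p in predictions], all_predictions.shape)  # [256, 256, 88] (600, 1)
```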