# Model training pipeline

### Import the necessary Python modules

In [None]:
import wandb
import h5py
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
import sys
sys.path.append("../model")
import vanilla_transformer

### Weights & Biases

In [None]:
# Log in to Weights & Biases before the training run is created below.
wandb.login()

### Define a run config

In [None]:
# Directory prefix used to build the path to the dataset's HDF5 file.
cwd = "../data"

# Run configuration handed to wandb.init(): trainer hyperparameters plus
# labels for the model architecture and the dataset being trained on.
config = dict(
    trainer=dict(epochs=300, batch_size=256),
    architecture="Vanilla Transformer",
    dataset="FD001",
)

### Create a run

In [None]:
# Train the Vanilla Transformer inside a Weights & Biases run.
# The `with` block finishes the run on exit (also when an exception is
# raised), so no explicit wandb.finish() call is needed.
with wandb.init(
    project="RUL Prediction",
    job_type="training",
    notes="Training Vanilla Transformer for RUL prediction",
    tags=["baseline", "Vanilla", "RUL"],
    config=config,
) as run:
    trainer_cfg = run.config["trainer"]
    dataset_name = run.config["dataset"]

    # Load the data. The arrays are copied into memory (np.array) and then
    # into tensors, so the HDF5 file can be closed as soon as the datasets
    # are built instead of leaking an open handle for the whole training.
    dataset_path = f"{cwd}/{dataset_name}/{dataset_name}.h5"
    with h5py.File(dataset_path, "r") as database:
        training_set = TensorDataset(
            torch.tensor(np.array(database["X_train"]), dtype=torch.float),
            torch.tensor(np.array(database["Y_train"]), dtype=torch.float),
        )
        # NOTE: the file's X_test / Y_test split is used as the validation set.
        validation_set = TensorDataset(
            torch.tensor(np.array(database["X_test"]), dtype=torch.float),
            torch.tensor(np.array(database["Y_test"]), dtype=torch.float),
        )

    print(f"Train set size X: {training_set.tensors[0].shape}")
    print(f"Train set size y: {training_set.tensors[1].shape}")
    print(f"Validation set size X: {validation_set.tensors[0].shape}")
    print(f"Validation set size y: {validation_set.tensors[1].shape}")

    # Create data loaders; only the training data is shuffled.
    training_loader = DataLoader(
        training_set,
        batch_size=trainer_cfg["batch_size"],
        shuffle=True,
        num_workers=4,
    )
    validation_loader = DataLoader(
        validation_set,
        batch_size=trainer_cfg["batch_size"],
        num_workers=4,
    )

    # Create the model
    model = vanilla_transformer.VanTransLitModule()

    # Define the callbacks: checkpoint on the lowest "val_RMSE" and log the
    # learning rate once per epoch.
    callbacks = [
        ModelCheckpoint(
            monitor="val_RMSE",
            mode="min",
        ),
        LearningRateMonitor(logging_interval="epoch"),
    ]

    # Define the logger
    logger = WandbLogger(
        name="Vanilla Transformer",
        checkpoint_name="best_model",
        project="RUL Prediction",
        log_model=True,
    )

    # Define the trainer
    trainer = Trainer(
        logger=logger,
        callbacks=callbacks,
        accelerator="auto",
        max_epochs=trainer_cfg["epochs"],
    )

    # Train the model
    trainer.fit(model, training_loader, validation_loader)