# Model inference pipeline

### Importing the necessary Python modules

In [None]:
import wandb
import h5py
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger
import sys
sys.path.append("../model")
import vanilla_transformer
import matplotlib.pyplot as plt

### Logging in to Weights & Biases

In [None]:
# Authenticate with Weights & Biases. The inference cell opens a W&B run and
# downloads the trained model as an artifact, which requires being logged in.
wandb.login()

### Defining hyperparameters

In [None]:
# Data location: the dataset file is read from f"{cwd}/{subdataset}/".
cwd: str = "../data"        # root data directory (relative path)
subdataset: str = "FD001"   # sub-dataset folder inside the data directory

# Unit identifier; only used in the plot legend and the output file names.
unit: int = 24

# Normalisation constant: targets and model outputs are multiplied by it to
# express them back in RUL units.
RUL_max: int = 125

# Half-width of the shaded band drawn around the true RUL, as a fraction of
# RUL_max.
ci: float = 0.1

### Loading the test dataset

In [None]:
# Path of the HDF5 file holding the RTF inputs ("RTF_X") and targets ("RTF_Y").
save_dir = f"{cwd}/{subdataset}/RTF.h5"

# Read both arrays fully into memory, then close the file immediately instead
# of leaving the HDF5 handle open for the rest of the session.
with h5py.File(save_dir, "r") as database:
    rtf_set = TensorDataset(
        torch.tensor(np.array(database["RTF_X"]), dtype=torch.float),
        torch.tensor(np.array(database["RTF_Y"]), dtype=torch.float),
    )

# shuffle=False (the DataLoader default) is spelled out because predictions
# are compared sample-by-sample against the dataset targets in file order.
rtf_loader = DataLoader(
    rtf_set,
    batch_size=256,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Configuration for the run: the W&B artifact ("<name>:<alias>") that holds
# the trained Lightning checkpoint.
config = {
    "model": "best_model:latest"
}

# Initialize a new W&B run; it is finished when the `with` block exits.
with wandb.init(
    project="RUL Prediction",
    job_type="inference",
    notes="Testing Vanilla Transformer for RUL prediction",
    tags=["baseline", "Vanilla", "RUL"],
    config=config
) as run:
    # Download the model artifact named in this run's config; `path` is the
    # local directory the artifact files were written to.
    path = run.use_artifact(run.config["model"]).download()

    # Attach the Lightning logger explicitly to the run opened above rather
    # than letting it implicitly reuse the global run.
    wandb_logger = WandbLogger(
        project="RUL Prediction",
        job_type="inference",
        notes="Testing Vanilla Transformer for RUL prediction",
        tags=["baseline", "Vanilla", "RUL"],
        experiment=run,
    )

    # Single-device trainer: with several devices Lightning shards the
    # dataloader across processes, and each rank would predict only part of
    # the sequence. "auto" uses a GPU when one is available and otherwise
    # falls back to CPU instead of failing.
    trainer = Trainer(
        logger=wandb_logger,
        accelerator="auto",
        devices=1,
    )

    # Load the model from the Lightning checkpoint. Map tensors to CPU so a
    # GPU-saved checkpoint also loads on CPU-only machines; the trainer moves
    # the model to the selected device for prediction.
    model = vanilla_transformer.VanTransLitModule.load_from_checkpoint(
        os.path.join(path, "model.ckpt"),
        map_location="cpu",
    )

    # Trigger inference; the result is a list with one output per batch.
    predictions = trainer.predict(model=model, dataloaders=rtf_loader)

### Plotting the results

In [None]:
# Set the font before creating the figure so every text element, including
# the tick labels created together with the axes, uses it.
plt.rcParams["font.family"] = "Times New Roman"

# Create new figure
fig, ax = plt.subplots(figsize=(8, 4))

# Get the true and predicted RUL and de-normalize them with RUL_max.
# trainer.predict returns one output per batch, so concatenate all batches;
# indexing only the first one would drop every sample after the first batch.
y_true = rtf_loader.dataset.tensors[1].numpy() * RUL_max
y_pred = torch.cat(predictions).numpy() * RUL_max
if len(y_pred) != len(y_true):
    raise ValueError(
        f"Got {len(y_pred)} predictions for {len(y_true)} ground-truth samples"
    )

# Shared x axis: one point per inspection interval.
t = np.arange(len(y_true))

# Plot the true RUL
ax.plot(
    t,
    y_true,
    color="tab:blue",
    label=f"True RUL for unit {unit}"
)

# Shade a constant-width band of +/- ci * RUL_max around the true RUL
# (np.squeeze gives fill_between 1-D bounds).
ci_lower = np.squeeze(y_true - ci * RUL_max)
ci_upper = np.squeeze(y_true + ci * RUL_max)
ax.fill_between(t, ci_lower, ci_upper, color='grey', alpha=.3)

# Plot the predicted RUL
ax.plot(
    t,
    y_pred,
    color="tab:orange",
    label="Predicted RUL"
)

# Format the plot
ax.legend(fontsize=12)
ax.grid()
ax.set_xlabel("Inspection intervals", fontsize=12)
ax.set_ylabel("RUL", fontsize=12)
title_str = "Vanilla Transformer for RUL Prediction"
ax.set_title(title_str, fontsize=12)
fig.tight_layout()

# Save the plot as SVG and PNG (save_path ends up holding the PNG path).
for ext in ("svg", "png"):
    save_path = f"./RULfTimeSteps_unit_{subdataset}_{unit}.{ext}"
    fig.savefig(save_path, format=ext, dpi=1200)