# Bearing Fault Classification Using PatchTST Pretraining & Finetuning

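Example usage of this repository: pretrain a PatchTST backbone with patch-wise masked reconstruction, then finetune it as a 4-class bearing-fault classifier.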

In [None]:
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.loggers import NeptuneLogger

from Modules.patchtst import PatchTSTRandomMaskedReconstructionModel, PatchTSTClassificationModel
from data.bearing_fault_prediction.raw.fault_prediction_datamodule import FaultPredictionDataModule

In [None]:
# Dataset ===============================================
# Seed all RNGs (Python, NumPy, torch and dataloader workers) so that
# pretraining and finetuning runs are reproducible from a fresh kernel.
SEED = 42
seed_everything(SEED, workers=True)

fault_data = FaultPredictionDataModule(
    train_val_test_split=(2800, 400, 800),  # presumably train / val / test sample counts
    batch_size=40,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Callbacks =============================================
# Rich-rendered progress bar; this same list is handed to both Trainers below.
callbacks = [RichProgressBar()]

## Pretraining: Patch-wise Masked Autoencoding

- Build the Lightning model, logger, and trainer

In [None]:
# Masked-reconstruction model used for self-supervised pretraining.
# The backbone arguments (in_features .. norm_first) must stay identical to the
# ones given to the classification model in the finetuning section, because
# that model's backbone is replaced with this model's pretrained backbone.
patchTST = PatchTSTRandomMaskedReconstructionModel(
    # --- backbone ---
    in_features=1,
    d_model=128,
    patch_size=64,
    patch_stride=64,   # stride == patch_size: presumably non-overlapping patches
    num_layers=2,
    dropout=0.1,
    nhead=4,
    activation='relu',
    norm_first=True,

    # --- masking (pretraining only) ---
    mask_ratio=0.4,
    learnable_mask=False,
)

# Credentials are deliberately NOT stored in the notebook: when `api_token` is
# omitted, NeptuneLogger reads the NEPTUNE_API_TOKEN environment variable.
# NOTE(review): the token that used to be hardcoded here was committed in plain
# text and should be revoked/rotated in Neptune.
# NOTE(review): Neptune usually expects 'workspace-name/project-name' — confirm.
logger = NeptuneLogger(
    project='bearing-fault-classification',
    name='pretrain',
)

trainer = Trainer(
    max_epochs=100,
    accelerator='auto',
    logger=logger,   # type: ignore
    callbacks=callbacks  # type: ignore
)

- Run pretraining and keep the pretrained backbone for finetuning

In [None]:
# Self-supervised pretraining; only the backbone is carried over downstream.
trainer.fit(model=patchTST, datamodule=fault_data)
pretrained_backbone = patchTST.backbone

## Downstream Task Finetuning: Classification

- Build the Lightning model, logger, and trainer

In [None]:
# Classification model for finetuning. The backbone arguments below must match
# the pretraining model exactly, since its freshly built backbone is swapped out
# for `pretrained_backbone` right after construction.
# NOTE: this rebinds `patchTST` (and later `logger` / `trainer`); re-running the
# pretraining fit cell after this one would train the classifier instead.
patchTST = PatchTSTClassificationModel(
    # --- backbone (keep in sync with pretraining) ---
    in_features=1,
    d_model=128,
    patch_size=64,
    patch_stride=64,
    num_layers=2,
    dropout=0.1,
    nhead=4,
    activation='relu',
    norm_first=True,

    # --- classification head / optimisation ---
    num_classes=4,
    lr=1e-4,
)
# Reuse the pretrained weights (same module object, not a copy).
patchTST.backbone = pretrained_backbone

# Credentials are deliberately NOT stored in the notebook: when `api_token` is
# omitted, NeptuneLogger reads the NEPTUNE_API_TOKEN environment variable.
# NOTE(review): the token that used to be hardcoded here was committed in plain
# text and should be revoked/rotated in Neptune.
logger = NeptuneLogger(
    project='bearing-fault-classification',
    name='finetune',
)

trainer = Trainer(
    max_epochs=20,
    accelerator='auto',
    logger=logger,   # type: ignore
    callbacks=callbacks  # type: ignore
)

- Run finetuning, then evaluate on the test split

In [None]:
# Finetune the classifier, then score the resulting weights on the test split.
trainer.fit(model=patchTST, datamodule=fault_data)
trainer.test(model=patchTST, datamodule=fault_data)