### Step 1: Choose the Dataset

- All datasets are provided as `lightning.LightningDataModule` subclasses that include training, validation, and test subsets.
- Find the dataset you want in `data/` and import it.
- Here we use a bearing fault type classification dataset:
    - input shape: `(b_size, seq_len=4096, num_features=1)`
    - output shape: `(b_size, num_classes=4)`
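A `LightningDataModule` essentially bundles the training, validation, and test `DataLoader`s. As a rough, torch-only illustration of the shapes above, the sketch below builds the equivalent loaders from synthetic random signals. It is not the repository's data module: the integer class-index label format and the 140/30/30 split sizes are assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic stand-in for the bearing dataset (random signals, NOT real data)
NUM_SAMPLES, SEQ_LEN, NUM_FEATURES, NUM_CLASSES = 200, 4096, 1, 4

signals = torch.randn(NUM_SAMPLES, SEQ_LEN, NUM_FEATURES)  # (N, seq_len, num_features)
labels = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,))     # class indices (assumed label format)
dataset = TensorDataset(signals, labels)

# Split into train / val / test subsets (sizes are an assumption)
generator = torch.Generator().manual_seed(0)
train_set, val_set, test_set = random_split(dataset, [140, 30, 30], generator=generator)

train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
val_loader = DataLoader(val_set, batch_size=32)
test_loader = DataLoader(test_set, batch_size=32)

xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([32, 4096, 1]) torch.Size([32])
```

The model then maps each `(32, 4096, 1)` batch to `(32, 4)` logits, one score per fault class.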

In [None]:
import torch
from data.bearing_fault_prediction.raw.fault_prediction_datamodule import FaultPredictionDataModule

data_module = FaultPredictionDataModule()   

### Step 2: Choose the Model
- Find the model you want in:
    - `Modules/classification_models.py`  for classification tasks
    - `Modules/regression_models.py`      for regression tasks

In [None]:
from Modules.classification_models import SimpleConv1dClassificationModel

model = SimpleConv1dClassificationModel(
    in_features=1,
    num_classes=4,
    hidden_features=64,
    kernel_size=16,
    stride=8,
    padding=4,
    pool_size=64,
    activation='relu',

    # training hyperparameters are integrated into the model:
    lr=1e-3,
    max_epochs=50,
)
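To make the shape flow concrete, here is a hypothetical, plain-PyTorch sketch of a Conv1d classifier built from the same hyperparameters. It is not the repository's `SimpleConv1dClassificationModel`: `TinyConv1dClassifier` and `conv1d_out_len` are illustrative names, and treating `pool_size` as a `MaxPool1d` kernel followed by a global average over time is an assumption.

```python
import torch
from torch import nn


def conv1d_out_len(seq_len: int, kernel_size: int, stride: int, padding: int) -> int:
    """Output length of nn.Conv1d with dilation=1."""
    return (seq_len + 2 * padding - kernel_size) // stride + 1


class TinyConv1dClassifier(nn.Module):
    """Hypothetical sketch; NOT the repository's SimpleConv1dClassificationModel."""

    def __init__(self, in_features=1, num_classes=4, hidden_features=64,
                 kernel_size=16, stride=8, padding=4, pool_size=64):
        super().__init__()
        self.conv = nn.Conv1d(in_features, hidden_features, kernel_size, stride, padding)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool1d(pool_size)  # assumed meaning of pool_size
        self.head = nn.Linear(hidden_features, num_classes)

    def forward(self, x):                        # x: (b, seq_len, num_features)
        x = x.transpose(1, 2)                    # Conv1d expects (b, channels, seq_len)
        x = self.pool(self.act(self.conv(x)))    # (b, hidden_features, pooled_len)
        x = x.mean(dim=-1)                       # average over time -> (b, hidden_features)
        return self.head(x)                      # logits: (b, num_classes)


print(conv1d_out_len(4096, 16, 8, 4))  # 512 time steps after the convolution
sketch_model = TinyConv1dClassifier()
logits = sketch_model(torch.randn(32, 4096, 1))
print(logits.shape)  # torch.Size([32, 4])
```

With `kernel_size=16`, `stride=8`, and `padding=4`, the 4096-step input shrinks to 512 steps after the convolution, and a `MaxPool1d(64)` would reduce it to 8 steps before averaging.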

# For a higher-performance experiment, try PatchTST:

# from Modules.classification_models import PatchTSTClassificationModel
# model = PatchTSTClassificationModel(
#     in_features=1,
#     d_model=64,
#     num_classes=4,
#     patch_size=64,
#     patch_stride=32,
#     dropout=0.1,
#     nhead=2,
#     num_layers=2,
#     norm_first=True,
#     activation='gelu',
    
#     lr=1e-3,
#     max_epochs=50,
# )
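PatchTST splits each series into (possibly overlapping) patches before the Transformer sees it. Assuming no end padding, `patch_size=64` and `patch_stride=32` turn a 4096-step signal into 127 patches; the sketch below checks this with `Tensor.unfold`. Some implementations, including the original PatchTST code, can optionally pad the end of the series by one stride, which adds one more patch, so the repository's model may differ.

```python
import torch

seq_len, patch_size, patch_stride = 4096, 64, 32

# Number of patches when no end padding is applied
num_patches = (seq_len - patch_size) // patch_stride + 1
print(num_patches)  # 127

# Tensor.unfold(dimension, size, step) slides a window along the time dimension
signal = torch.randn(32, seq_len, 1)                 # (b, seq_len, num_features)
patches = signal.unfold(1, patch_size, patch_stride)
print(patches.shape)                                 # torch.Size([32, 127, 1, 64])
```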

### Step 3: Run Training and Testing
- Training, validation, and testing steps are integrated into every model:
    - run `model.fit(datamodule)` to train the model
    - run `model.test(datamodule)` to test it
    - callbacks, loss functions, and metrics are implemented for every task type, e.g. classification, regression, and others
- A TensorBoard logger is also integrated; logs are written to `lightning_logs/`.
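For readers curious about what `fit` and `test` automate, the sketch below spells out a generic training, validation, and test loop in plain PyTorch. It is only an illustration under assumptions (toy data with short 128-step sequences, a stand-in linear classifier, 3 epochs) and not the repository's code; the real models additionally handle callbacks and logging.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)


def make_loader(n, shuffle=False):
    """Toy loader of random (n, 128, 1) signals with 4 classes; purely illustrative."""
    xs = torch.randn(n, 128, 1)
    ys = torch.randint(0, 4, (n,))
    return DataLoader(TensorDataset(xs, ys), batch_size=16, shuffle=shuffle)


train_loader, val_loader, test_loader = make_loader(64, True), make_loader(32), make_loader(32)

net = nn.Sequential(nn.Flatten(), nn.Linear(128, 4))  # stand-in classifier
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


def evaluate(loader):
    """Return (mean loss, accuracy) over a loader without tracking gradients."""
    net.eval()
    total_loss, correct, count = 0.0, 0, 0
    with torch.no_grad():
        for xb, yb in loader:
            logits = net(xb)
            total_loss += loss_fn(logits, yb).item() * len(yb)
            correct += (logits.argmax(dim=-1) == yb).sum().item()
            count += len(yb)
    return total_loss / count, correct / count


for epoch in range(3):  # plays the role of max_epochs
    net.train()
    for xb, yb in train_loader:  # training step
        optimizer.zero_grad()
        loss = loss_fn(net(xb), yb)
        loss.backward()
        optimizer.step()
    val_loss, val_acc = evaluate(val_loader)  # validation step
    print(f"epoch {epoch}: val_loss={val_loss:.3f} val_acc={val_acc:.2f}")

test_loss, test_acc = evaluate(test_loader)  # test step
print(f"test_loss={test_loss:.3f} test_acc={test_acc:.2f}")
```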

In [None]:
import subprocess

# Launch TensorBoard in the background, then open http://localhost:6006 in your browser to view training logs
pid = subprocess.Popen(["tensorboard", "--logdir=lightning_logs"]).pid
model.fit(data_module)
model.test(data_module)

In [None]:
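import os
import signal

# Stop the background TensorBoard process started above (works on POSIX and Windows)
os.kill(pid, signal.SIGTERM)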
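If you prefer TensorBoard to be shut down automatically, even when training raises an exception, a small context manager can own the background process. This is a sketch: `background_process` is a hypothetical helper name, and the demo runs a harmless sleeping Python process in place of TensorBoard.

```python
import contextlib
import subprocess
import sys


@contextlib.contextmanager
def background_process(cmd):
    """Start `cmd` in the background and always terminate it on exit."""
    proc = subprocess.Popen(cmd)
    try:
        yield proc
    finally:
        proc.terminate()       # SIGTERM on POSIX, TerminateProcess() on Windows
        proc.wait(timeout=10)  # reap the child so it does not linger


# Demo with a long-running Python process standing in for TensorBoard
with background_process([sys.executable, "-c", "import time; time.sleep(60)"]) as proc:
    print("running:", proc.poll() is None)
print("exit code:", proc.returncode)
```

In this notebook, the helper could wrap `model.fit(data_module)` and `model.test(data_module)` with `["tensorboard", "--logdir=lightning_logs"]` as the command. The trade-off is that TensorBoard stops as soon as the `with` block ends, so the logs can no longer be browsed afterwards.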

### Step 4: Model Inference
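- Call `model(input_tensor)` just as you would a standard `nn.Module`. `model.forward(input_tensor)` returns the same output, but calling the model directly is preferred because it also runs any registered hooks.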
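The difference between the two call styles only shows up when hooks are registered. A minimal check on a plain `nn.Linear` layer (unrelated to the models in this repository):

```python
import torch
from torch import nn

layer = nn.Linear(4, 2)
calls = []
# A forward hook that records each time it fires
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
layer(x)           # __call__: runs forward() AND the hooks
layer.forward(x)   # direct forward(): hooks are skipped
print(len(calls))  # 1
```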

In [None]:
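model.eval()  # inference mode: disables dropout, etc.
x = torch.randn(32, 4096, 1)  # batch size 32, 4096 timesteps, 1 feature
with torch.no_grad():  # gradients are not needed for inference
    y = model(x)  # same output as model.forward(x), but also runs registered hooks
print(torch.softmax(y, dim=-1))  # class probabilities, shape (32, 4)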
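For larger inputs, a small helper can batch the forward passes and turn probabilities into class indices. This is a sketch under assumptions: `predict_classes` is a hypothetical name, and the model is assumed to return unnormalized logits of shape `(b, num_classes)`, as the softmax above implies. The demo uses an untrained stand-in classifier.

```python
import torch
from torch import nn


@torch.no_grad()
def predict_classes(net: nn.Module, signals: torch.Tensor, batch_size: int = 64):
    """Return (predicted class indices, class probabilities) for (N, seq_len, features) input."""
    net.eval()
    probs = torch.cat([torch.softmax(net(chunk), dim=-1)
                       for chunk in torch.split(signals, batch_size)])
    return probs.argmax(dim=-1), probs


# Demo with a stand-in classifier (not the trained model from this notebook)
stand_in = nn.Sequential(nn.Flatten(), nn.Linear(4096, 4))
pred, probs = predict_classes(stand_in, torch.randn(100, 4096, 1))
print(pred.shape, probs.shape)  # torch.Size([100]) torch.Size([100, 4])
```

With the trained model from this notebook, `predict_classes(model, signals)` would return one predicted fault class index per signal.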