EfficientNet1dXS is written as a PyTorch module and can be trained using the [standard PyTorch training workflow](https://docs.pytorch.org/tutorials/beginner/basics/intro.html) or any other PyTorch-oriented methods. 

This example demonstrates a general training process using the PyTorch Lightning library. Only a small mock dataset of 15 records is used.

For more detailed information, and before training your own model, see the [PyTorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/).

In [1]:
import os
import numpy as np
import torch
from pathlib import Path
from sklearn.model_selection import train_test_split
from torchinfo import summary

import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint
)

from modules.networks import EfficientNet1dXS

# 1. Training dataset

Training data should be prepared using PyTorch [Datasets and DataLoaders](https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html). Here we build a mock training dataset for our 4-channel time series, which are stored in binary files. The training labels are assigned according to the file names: files named "f*" are FRB events (class 1), all others are class 0. Note that the split into training and validation sets is simplified here: in actual training, the class balance should be maintained in both sets.

In [2]:
class MockDataset(torch.utils.data.Dataset):
    def __init__(self, file_list):
        # Instance initialization
        self.file_list = file_list  # input file list

    def __len__(self):
        # Defining dataset length
        return len(self.file_list)

    def __getitem__(self, idx):
        # Reading from a binary file: file_list[idx]
        record = np.fromfile(self.file_list[idx], dtype=np.float32)
        record = np.reshape(record, (4, -1))  # [channels, length]
        record = torch.tensor(record)
        
        # Setting a label based on the file name
        is_frb = self.file_list[idx].stem[0] == 'f'
        label = 1. if is_frb else 0.  # should be float for loss function
                       
        # Return the readout data instance and its label   
        return record, label
    
    
# File list from the "data" folder
file_list = sorted(Path('data').glob('*'))

# Split into train and validation (simplified)
train, val = train_test_split(file_list, test_size=0.2, random_state=13)
train_dataset = MockDataset(train)
val_dataset = MockDataset(val)
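One possible way to keep the class balance mentioned above is the `stratify` argument of `train_test_split`. The following is a minimal sketch, not the split used in this notebook; the file names are hypothetical (5 "f*" records and 10 others), and the labels are derived from the first letter of the file name as in `MockDataset`.

```python
from pathlib import Path
from sklearn.model_selection import train_test_split

# Hypothetical file names (assumption): 5 FRB records ("f*") and 10 others
file_list = (
    [Path(f'data/frb_{i:02d}') for i in range(5)]
    + [Path(f'data/noise_{i:02d}') for i in range(10)]
)

# Same labeling rule as in MockDataset: first letter "f" means class 1
labels = [int(p.stem[0] == 'f') for p in file_list]

# stratify=labels keeps the class proportions similar in both subsets
train, val = train_test_split(
    file_list, test_size=0.2, random_state=13, stratify=labels
)

n_frb_val = sum(p.stem[0] == 'f' for p in val)
print(len(train), len(val), n_frb_val)
```

With very small datasets, stratification requires at least two records per class; otherwise `train_test_split` raises an error.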

In [3]:
# First element in the training dataset: the record and its label

train_dataset[0]

(tensor([[-1.9084e-17,  0.0000e+00, -1.2854e+00,  ...,  2.1157e+00,
           7.9065e-01, -7.3887e-02],
         [-5.4010e-18, -4.6295e-18,  9.7796e-01,  ..., -2.1284e+00,
           2.4998e-01,  2.8782e-01],
         [-5.5152e-18, -4.4122e-18,  6.3552e-01,  ..., -1.1870e+00,
          -1.5253e+00, -7.6169e-01],
         [-1.1947e-17, -5.9734e-17, -5.4967e-02,  ...,  1.6960e+00,
           1.3684e+00, -7.7276e-01]]),
 1.0)
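To make the file layout read by `MockDataset` explicit, the sketch below writes a synthetic record to a temporary file and reads it back the same way as `__getitem__` does. The record length of 4080 is taken from the next section; the random data and the file name `f_example` are assumptions for illustration only.

```python
import tempfile
from pathlib import Path

import numpy as np

# Synthetic 4-channel record (assumption): random values, length 4080
rng = np.random.default_rng(0)
original = rng.standard_normal((4, 4080)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical file name; the leading "f" would mark it as class 1
    path = Path(tmp) / 'f_example'

    # Raw float32 values, channel after channel (C order), no header
    original.tofile(path)

    # Reading as in MockDataset.__getitem__
    raw = np.fromfile(path, dtype=np.float32)
    restored = np.reshape(raw, (4, -1))  # [channels, length]
    is_frb = path.stem[0] == 'f'

print(restored.shape, is_frb)
```

Because the files carry no header, the number of channels and the data type must be known in advance; a mismatch in either silently produces a wrongly shaped or meaningless record.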

# 2. Neural network

Here we load the neural network and check its forward/backward pass. The input consists of 4 time series, each 4080 elements long, so the data shape is [4 x 4080]. The number of output classes is 1 (binary classification). The batch size is 2.

In [4]:
model = EfficientNet1dXS(inchan=4, out_classes=1)

# Check for batch size = 2. CPU is used; for GPU: device='cuda'
print(summary(model, input_size=(2, 4, 4080), device='cpu'))

# Clear memory
del model

Layer (type:depth-idx)                                  Output Shape              Param #
EfficientNet1dXS                                        [2, 1]                    --
├─Sequential: 1-1                                       [2, 640]                  --
│    └─Conv1d: 2-1                                      [2, 12, 2040]             144
│    └─BatchNorm1d: 2-2                                 [2, 12, 2040]             24
│    └─GELU: 2-3                                        [2, 12, 2040]             --
│    └─MBConv1d: 2-4                                    [2, 12, 2040]             --
│    │    └─Identity: 3-1                               [2, 12, 2040]             --
│    │    └─Sequential: 3-2                             [2, 12, 2040]             624
│    └─MBConv1d: 2-5                                    [2, 24, 1020]             --
│    │    └─Downsample1d: 3-3                           [2, 24, 1020]             336
│    │    └─Sequential: 3-4                             [
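The forward/backward check can also be done by hand with a dummy batch. The sketch below uses a small hypothetical stand-in network, not EfficientNet1dXS, with the same input and output shapes ([batch, 4, 4080] to [batch, 1]); in the notebook, the stand-in would be replaced by the actual model.

```python
import torch
from torch import nn

# Hypothetical stand-in (assumption), NOT EfficientNet1dXS:
# it only reproduces the shapes [batch, 4, 4080] -> [batch, 1]
model = nn.Sequential(
    nn.Conv1d(4, 12, kernel_size=3, stride=2, padding=1, bias=False),
    nn.BatchNorm1d(12),
    nn.GELU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(12, 1),
)

batch = torch.randn(2, 4, 4080)    # dummy batch, batch size = 2
labels = torch.tensor([1.0, 0.0])  # one float label per record

# Forward pass: [2, 1] logits flattened to [2]
logits = torch.flatten(model(batch))
loss = nn.BCEWithLogitsLoss()(logits, labels)

# Backward pass: every parameter should receive a finite gradient
loss.backward()
grads_ok = all(
    p.grad is not None and bool(torch.isfinite(p.grad).all())
    for p in model.parameters()
)
print(tuple(logits.shape), grads_ok)
```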

# 3. Training with PyTorch Lightning

The PyTorch Lightning module below contains all the functions necessary to perform the training:
* initialization
* network forward pass
* training step with calculation of the training loss
* setting an optimizer and a learning rate scheduler (the latter is optional)
* validation step with calculation of the validation loss
* training data loader
* validation data loader

In [5]:
# PyTorch Lightning module
class FRB_Lightning(pl.LightningModule):
    def __init__(self, model, batch_size, lr, train_dataset, val_dataset):
        super().__init__()
        self.save_hyperparameters(ignore='model')  # for checkpoints
        self.model = model                  # model (neural network)
        self.batch_size = batch_size        # batch size
        self.lr = lr                        # learning rate
        self.train_dataset = train_dataset  # training dataset
        self.val_dataset = val_dataset      # validation dataset
        self.criterion = torch.nn.BCEWithLogitsLoss()  # loss function

    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        records, labels = batch
        out = self.forward(records)
        out = torch.flatten(out) # batch flattening (binary classification)
        loss = self.criterion(out, labels)
        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def configure_optimizers(self):
        # Optimizer and learning rate scheduler
        opt = torch.optim.Adamax(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=10)
        return [opt], [scheduler]

    def validation_step(self, batch, batch_idx):
        records, labels = batch
        out = self.forward(records)
        out = torch.flatten(out)  # batch flattening (binary classification)
        loss = self.criterion(out, labels)
        self.log('val_loss', loss, prog_bar=True)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=os.cpu_count(),
            drop_last=True
        )
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            num_workers=os.cpu_count(),
        )

We now create an instance of the PyTorch Lightning module and a `ModelCheckpoint` callback that saves the best model according to a monitored value, here the minimum of the validation loss. A PyTorch Lightning `Trainer` is then created with the maximum number of epochs, the logging frequency, and the callbacks, and the model is trained. The checkpoints are saved in the 'lightning_logs' folder. Training logs can be viewed in a browser via TensorBoard:

<tt>tensorboard --logdir=lightning_logs</tt>

In [None]:
# Lightning module
pl_model = FRB_Lightning(
    EfficientNet1dXS(inchan=4, out_classes=1), 
    batch_size=2, lr=0.002, 
    train_dataset=train_dataset, 
    val_dataset=val_dataset)

# Checkpoint saving based on the val_loss minimum 
checkpoint_callback = ModelCheckpoint(monitor='val_loss', mode='min')

# Lightning Trainer
trainer = pl.Trainer(
    max_epochs=10,
    log_every_n_steps=2,
    callbacks=[
        checkpoint_callback,
        LearningRateMonitor('epoch'),
        EarlyStopping(monitor='val_loss', mode='min', patience=5)
    ]
)

# Model training
trainer.fit(pl_model)

# 4. Loading checkpoints and saving the models

The checkpoints saved in the 'lightning_logs' folder can be loaded back into the Lightning module. You may also want to save the model as TorchScript so that it can be used later without PyTorch Lightning and the other training dependencies.

Loading a previously saved checkpoint (put an actual file name below):

In [None]:
best_model = FRB_Lightning.load_from_checkpoint(
    'lightning_logs/version_?/checkpoints/epoch=?-step=?.ckpt',
    model=EfficientNet1dXS(inchan=4, out_classes=1))
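To find an actual checkpoint file name, the 'lightning_logs' folder can be searched with `pathlib`. This is a minimal sketch: the helper `list_checkpoints` is hypothetical, and it assumes the default folder layout shown in the path above. Within the training session itself, the path of the best checkpoint should also be available as `checkpoint_callback.best_model_path`.

```python
from pathlib import Path


def list_checkpoints(log_dir='lightning_logs'):
    """Return all *.ckpt files under log_dir, oldest first.

    Hypothetical helper; assumes the layout
    <log_dir>/version_*/checkpoints/*.ckpt
    """
    pattern = 'version_*/checkpoints/*.ckpt'
    return sorted(
        Path(log_dir).glob(pattern),
        key=lambda p: p.stat().st_mtime,  # sort by modification time
    )


checkpoints = list_checkpoints()
if checkpoints:
    print('Most recent checkpoint:', checkpoints[-1])
else:
    print('No checkpoints found')
```

Note that the most recent checkpoint is not necessarily the best one if several are kept; the file name only encodes the epoch and the step.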

Saving the model as TorchScript for later use in plain PyTorch:

In [8]:
best_model.to('cpu').eval()
script = best_model.to_torchscript()
torch.jit.save(script, 'filename.pt')

Loading the saved TorchScript model:

In [None]:
# Use GPU if available, else CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

best_model = torch.jit.load('filename.pt').to(device).eval()
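It can be worth checking that a model restored from TorchScript reproduces the outputs of the original one. The sketch below illustrates such a round trip with a small hypothetical stand-in network and `torch.jit.script`, since the trained model is not available here; the temporary file name is also an assumption.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical stand-in network (assumption), NOT EfficientNet1dXS
model = nn.Sequential(
    nn.Conv1d(4, 8, kernel_size=3, padding=1),
    nn.GELU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(8, 1),
).eval()

x = torch.randn(1, 4, 4080)  # one dummy record with a batch dimension

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'standin.pt')

    # Save as TorchScript and load it back
    torch.jit.save(torch.jit.script(model), path)
    restored = torch.jit.load(path).eval()

    # Both models should give the same output for the same input
    with torch.no_grad():
        same = torch.allclose(model(x), restored(x))

print('Outputs match:', same)
```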

# 5. Making a prediction

Here we make a prediction for a single record. The prediction may be incorrect, as the mock data are too scarce for proper training. The data format should be the same as in training.

In [None]:
# Reading our binary file
record = np.fromfile('data/noise_h00_r1_s0000', dtype=np.float32)
record = np.reshape(record, (4, -1))  # [channels, length]
record = torch.tensor(record)

# Adding the batch dimension and copying the data to the device (GPU or CPU)
record = record.unsqueeze(0).to(device)

# Binary classification: a positive logit corresponds to a probability above 0.5
with torch.no_grad():
    logit = best_model(record)
    if logit > 0:
        print('Detection')
    else:
        print('No detection')

No detection
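The model outputs a logit, so the comparison with 0 used above is equivalent to comparing the sigmoid probability with 0.5. If a probability or a different decision threshold is needed, the logit can be converted explicitly. This is a minimal sketch: the helper `classify` and the example logit values are hypothetical.

```python
import torch


def classify(logit, threshold=0.5):
    """Convert a single logit to (probability, detection flag).

    Hypothetical helper; `logit` is a one-element tensor
    such as the model output for one record.
    """
    prob = torch.sigmoid(logit).item()
    return prob, prob > threshold


# Made-up logits for illustration (not model outputs)
for value in (-2.0, 0.0, 2.0):
    prob, detected = classify(torch.tensor([[value]]))
    print(f'logit={value:+.1f}  probability={prob:.3f}  detected={detected}')
```

Raising the threshold reduces false detections at the cost of missing weaker events; a suitable value would have to be chosen on validation data.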
