In [146]:
import os

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F

import pytorch_lightning as pl

from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

from pytorch_lightning.callbacks import EarlyStopping


## Three-class classification model

In [199]:
class SimpleClassifier(pl.LightningModule):

    def __init__(self):
        super().__init__()

        # Linear  
        self.layer_1 = torch.nn.Linear(4, 32)
        self.layer_2 = torch.nn.Linear(32, 64)
        self.layer_3 = torch.nn.Linear(64, 32)
        self.layer_4 = torch.nn.Linear(32, 2)

    def forward(self, x):

        # Layer 1 
        x = self.layer_1(x)
        x = torch.relu(x)

        # Layer 2
        x = self.layer_2(x)
        x = torch.relu(x)
        
        # Layer 3
        x = self.layer_3(x)
        x = torch.relu(x)
        
        # Layer 4
        x = self.layer_4(x)
        x = torch.relu(x)

        # Probability distribution over labels
        x = torch.log_softmax(x, dim=1)

        return x

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        logits = self.forward(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-2)
        return optimizer

In [200]:
class MyDataModule(pl.LightningDataModule):
    
    def __init__(self, X, y, train_size):
        super().__init__()
        self.X = X
        self.y = y
        self.train_size = train_size

    def setup(self, stage):
        
        X_train = torch.tensor(self.X[:self.train_size])
        y_train = torch.tensor(self.y[:self.train_size], dtype=torch.long)
        X_test = torch.tensor(self.X[self.train_size:])
        y_test = torch.tensor(self.y[self.train_size:], dtype=torch.long)
        
        self.train = TensorDataset(X_train, y_train)
        self.test = TensorDataset(X_test, y_test)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=16)

    def val_dataloader(self):
        return DataLoader(self.test, batch_size=16)

In [201]:
X = np.array([
    [1, 1, 1, 1],
    [2, 4, 6, 24], 
    [-1, -2, -1, -192],
    [-191, -3, -2, -7],
    [102, 12, 16, 200],
    [7, 9, 13, 177],
] * 200, dtype='float32')

y = np.array([
    0, 
    0, 
    1, 
    1, 
    0, 
    0
] * 200, dtype='int32')
 


In [202]:
data_module = MyDataModule(X, y, train_size=16)

In [203]:
early_stopping = EarlyStopping('val_loss', patience=20)

model = SimpleClassifier()

trainer = pl.Trainer(max_epochs=1000, callbacks=[early_stopping])

trainer.fit(model, data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 160   
1 | layer_2 | Linear | 2.1 K 
2 | layer_3 | Linear | 2.1 K 
3 | layer_4 | Linear | 66    
-----------------------------------
4.4 K     Trainable params
0         Non-trainable params
4.4 K     Total params
0.018     Total estimated model params size (MB)


Epoch 0:   1%|▋                                                   | 1/75 [00:00<00:00, 117.14it/s, loss=0.51, v_num=32]
Validating: 0it [00:00, ?it/s][A
Epoch 0: 100%|██████████████████████████████████████████████████| 75/75 [00:00<00:00, 1079.89it/s, loss=0.51, v_num=32][A
Epoch 1:   1%|▋                                                  | 1/75 [00:00<00:00, 200.44it/s, loss=0.362, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epoch 1: 100%|█████████████████████████████████████████████████| 75/75 [00:00<00:00, 1193.08it/s, loss=0.362, v_num=32][A
Epoch 2:   1%|▋                                                  | 1/75 [00:00<00:00, 250.63it/s, loss=0.296, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epoch 2: 100%|█████████████████████████████████████████████████| 75/75 [00:00<00:00, 1089.33it/s, loss=0.296, v_num=32][A
Epoch 3:   1%|▋                                                  | 1/75 [00:00<00:00, 200.42it/s, loss=0.255, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epoch 

Epoch 58:   1%|▋                                              | 1/75 [00:00<00:00, 142.79it/s, loss=5.59e-09, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epoch 58: 100%|██████████████████████████████████████████████| 75/75 [00:00<00:00, 775.00it/s, loss=5.59e-09, v_num=32][A
Epoch 59:   1%|▋                                              | 1/75 [00:00<00:00, 250.69it/s, loss=4.47e-09, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epoch 59: 100%|█████████████████████████████████████████████| 75/75 [00:00<00:00, 1161.76it/s, loss=4.47e-09, v_num=32][A
Epoch 60:   1%|▋                                              | 1/75 [00:00<00:00, 125.32it/s, loss=3.35e-09, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epoch 60: 100%|█████████████████████████████████████████████| 75/75 [00:00<00:00, 1058.53it/s, loss=3.35e-09, v_num=32][A
Epoch 61:   1%|▋                                              | 1/75 [00:00<00:00, 199.56it/s, loss=2.24e-09, v_num=32][A
Validating: 0it [00:00, ?it/s][A
Epo

In [204]:
model(torch.tensor([[1, 1, 1, 79]], dtype=torch.float)).argmax()

tensor(0)