# –ö–ª–∞—Å—Å–∏—Ñ–∏–∫–∞—Ü–∏—è –Ω–∞ PyTorch Lightning


In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score


In [2]:
wine = load_wine()
X = wine.data
y = wine.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

print(f"Train: {X_train.shape}, Test: {X_test.shape}")
print(f"Classes: {len(np.unique(y))}")


Train: torch.Size([142, 13]), Test: torch.Size([36, 13])
Classes: 3


In [3]:
class WineDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class WineDataModule(pl.LightningDataModule):
    def __init__(self, X_train, X_test, y_train, y_test, batch_size=16, val_size=0.2, seed=42):
        super().__init__()
        # TODO
        self.X_train = X_train
        self.X_test = X_test
        self.y_train = y_train
        self.y_test = y_test
        self.batch_size = batch_size
        self.val_size = val_size
        self.seed = seed

        # placeholders (—Å–æ–∑–¥–∞—é—Ç—Å—è –≤ setup)
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        # TODO
        # stage –º–æ–∂–µ—Ç –±—ã—Ç—å: 'fit', 'validate', 'test', 'predict'
        if stage == "fit" or stage is None:
            # –¥–µ–ª–∞–µ–º –≤–∞–ª–∏–¥–∞—Ü–∏–æ–Ω–Ω—É—é –≤—ã–±–æ—Ä–∫—É –∏–∑ train (—Å—Ç—Ä–∞—Ç–∏—Ñ–∏—Ü–∏—Ä–æ–≤–∞–Ω–Ω–æ)
            X_tr, X_val, y_tr, y_val = train_test_split(
                self.X_train.numpy(),
                self.y_train.numpy(),
                test_size=self.val_size,
                random_state=self.seed,
                stratify=self.y_train.numpy()
            )
            X_tr = torch.FloatTensor(X_tr)
            X_val = torch.FloatTensor(X_val)
            y_tr = torch.LongTensor(y_tr)
            y_val = torch.LongTensor(y_val)

            self.train_dataset = WineDataset(X_tr, y_tr)
            self.val_dataset = WineDataset(X_val, y_val)

        if stage == "test" or stage is None:
            self.test_dataset = WineDataset(self.X_test, self.y_test)

    def train_dataloader(self):
        # TODO
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        # TODO
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        # TODO
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)


In [4]:
class NeuralNetLightning(pl.LightningModule):
    def __init__(self, input_size, num_classes, lr=0.001, hidden_sizes=(32, 16)):
        super().__init__()
        # TODO
        self.save_hyperparameters()

        layers = []
        prev = input_size
        for hs in hidden_sizes:
            layers.append(nn.Linear(prev, hs))
            layers.append(nn.ReLU())
            prev = hs
        layers.append(nn.Linear(prev, num_classes))
        self.network = nn.Sequential(*layers)

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        # TODO
        return self.network(x)

    def _shared_step(self, batch, stage: str):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()

        self.log(f"{stage}_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        self.log(f"{stage}_acc", acc, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        # TODO
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        # TODO
        self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        # TODO
        self._shared_step(batch, "test")

    def configure_optimizers(self):
        # TODO
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)


In [5]:
data_module = WineDataModule(X_train, X_test, y_train, y_test, batch_size=16)
model = NeuralNetLightning(input_size=13, num_classes=3, lr=0.001)

print(model)


NeuralNetLightning(
  (network): Sequential(
    (0): Linear(in_features=13, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=16, bias=True)
    (3): ReLU()
    (4): Linear(in_features=16, out_features=3, bias=True)
  )
  (criterion): CrossEntropyLoss()
)


In [6]:
trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',
    devices=1,
    log_every_n_steps=5,
    enable_progress_bar=True
)

trainer.fit(model, data_module)


üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
C:\Anaconda3\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name      | Type             | Params | Mode  | FLOPs
---------------------------------------------------------------
0 | network   | Sequential       | 1.0 K  | train | 0    
1 | criterion | Cross

Sanity Checking: |                                                                               | 0/? [00:00<‚Ä¶

C:\Anaconda3\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.
C:\Anaconda3\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


Training: |                                                                                      | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

Validation: |                                                                                    | 0/? [00:00<‚Ä¶

`Trainer.fit` stopped: `max_epochs=100` reached.


In [7]:
test_result = trainer.test(model, data_module)
print(f"Test Accuracy: {test_result[0]['test_acc']:.4f}")


C:\Anaconda3\lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:434: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=27` in the `DataLoader` to improve performance.


Testing: |                                                                                       | 0/? [00:00<‚Ä¶

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
       Test metric             DataLoader 0
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        test_acc                    1.0
        test_loss          0.0023363062646239996
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚î

In [8]:
model.eval()
with torch.no_grad():
    logits = model(X_test)
    predicted = torch.argmax(logits, dim=1)
    
from sklearn.metrics import classification_report
print("\nClassification Report:")
print(classification_report(y_test, predicted, target_names=wine.target_names))



Classification Report:
              precision    recall  f1-score   support

     class_0       1.00      1.00      1.00        14
     class_1       1.00      1.00      1.00        14
     class_2       1.00      1.00      1.00         8

    accuracy                           1.00        36
   macro avg       1.00      1.00      1.00        36
weighted avg       1.00      1.00      1.00        36

