We need to install (Hugging Face) 🤗 Transformers and 🤗 Datasets.

In [2]:
%%capture
! pip install "datasets" "pytorch-lightning" "wandb" "torcheval" "torchmetrics"

In [25]:
import datasets
import torch
import wandb
import numpy as np

import os
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from torch.utils.data import DataLoader, random_split

import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar

from torchmetrics import Accuracy
from torchmetrics import Precision, Recall
from torchmetrics.classification import MulticlassF1Score
from torchmetrics import F1Score  #, BinaryF1Score

from torchmetrics.classification import BinaryAccuracy
from torchmetrics.classification import BinaryPrecision
from torchmetrics.classification import BinaryRecall
from torchmetrics.classification import BinaryF1Score

from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything

from pytorch_lightning.loggers import WandbLogger
from datasets import load_dataset, load_metric


In [4]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = torch.device("cpu")
print(f"Device: {device}")

Device: cpu


# Loading the Dataset

In [5]:
#PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
PATH_DATASETS = "."
BATCH_SIZE = 5

In [6]:
%%capture
train_ds = FashionMNIST(PATH_DATASETS, train=True, download=True, transform=transforms.ToTensor())

In [7]:
# making the dataset binary
for idx in range(len(train_ds)):
    if train_ds.targets[idx] > 5:
        train_ds.targets[idx] = 1
    else:
        train_ds.targets[idx] = 0

In [8]:
class_0_samples_train = sum(1 for label in train_ds.targets if label == 0)
class_1_samples_train = sum(1 for label in train_ds.targets if label == 1)

print("\nNumber of Samples", len(train_ds))
print("Total label: ", class_1_samples_train + class_0_samples_train)
print("Number of Samples for Class 1 in Training Set:", class_1_samples_train)
print("Number of Samples for Class 0 in Training Set:", class_0_samples_train)


Number of Samples 60000
Total label:  60000
Number of Samples for Class 1 in Training Set: 24000
Number of Samples for Class 0 in Training Set: 36000


In [9]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

In [None]:
print(train_ds.data[0].shape)

torch.Size([28, 28])


# Lightning Model

In [30]:
class Net(pl.LightningModule):
    def __init__(self, num_classes=1):
        super().__init__()

        self.l1 = nn.Linear(in_features=28 * 28, out_features=64)
        self.l2 = nn.Linear(in_features=64, out_features=1)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
        self.loss_fn = nn.BCELoss()

        self.num_classes = num_classes
        self.accuracy_metric  = BinaryAccuracy() 
        self.precision_metric = BinaryPrecision() 
        self.recall_metric    = BinaryRecall() 
        self.f1score_metric   = BinaryF1Score()
        
    def forward(self, x):
        x = x.view(x.size(0), -1)  # [ batch_size, 28, 28] -> [ batch_size, 784] 
        x = self.relu(self.l1(x))
        x = self.sigmoid(self.l2(x))
        return x

    def training_step(self, batch):
        x, y = batch
       

        predictions = self.forward(x).float().squeeze()  # I needed to transpose the vector
        y = torch.tensor(y.clone(), dtype=torch.float32)   # using float to don't loose representation


        loss = self.loss_fn(predictions, y)

        acc       = self.accuracy_metric(predictions, y)
        precision = self.precision_metric(predictions, y)
        recall    = self.recall_metric(predictions, y)
        f1_score  = self.f1score_metric(predictions, y)


        wandb.log({"acc": acc,
                   "loss": loss,
                   "precision": precision,
                   "recall": recall,
                   "f1-score:":f1_score
                   })
        
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

# Training With and without WANDB

Trying with and without WANDB

In [12]:
%%capture
personal_key = '8cc162e60d761ed72f9c2298a0ab2b17f40d0f13'
wandb.login(key = personal_key)

#wandb.login()

In [31]:
wandb.finish()

wandb.init(project="Geometric_Algebra_Transformer",
           name="GATr - wandb template",
           config={
               "learning_rate": 0.001,
               "dataset": "MNIST",
               "epochs": 2,
               "train_batch_size": 256,
               "eval_batch_size": 256
           })


In [32]:
model = Net()

wandb_logger = WandbLogger(log_model="all")
trainer = Trainer(max_epochs=10, accelerator="auto", logger=wandb_logger)
trainer.fit(model, train_loader)

# mark the run as finished
wandb.finish()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
C:\Users\pc\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\loggers\wandb.py:389: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.

  | Name             | Type            | Params
-----------------------------------------------------
0 | l1               | Linear          | 50.2 K
1 | l2               | Linear          | 65    
2 | relu             | ReLU            | 0     
3 | sigmoid          | Sigmoid         | 0     
4 | loss_fn          | BCELoss         | 0     
5 | accuracy_metric  | BinaryAccuracy  | 0     
6 | precision_metric | BinaryPrecision | 0     
7 | recall_metric    | BinaryRecall    | 0     
8

Epoch 0:   0%|          | 5/12000 [00:00<05:49, 34.34it/s, v_num=o2g6]

  y = torch.tensor(y.clone(), dtype=torch.float32)   # using float to don't loose representation


Epoch 1:  97%|█████████▋| 11603/12000 [03:04<00:06, 62.95it/s, v_num=o2g6]

C:\Users\pc\AppData\Local\Packages\PythonSoftwareFoundation.Python.3.10_qbz5n2kfra8p0\LocalCache\local-packages\Python310\site-packages\pytorch_lightning\trainer\call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


0,1
acc,▆█▆█▃▁█▆█▆█▆█▃██▃█▆█▆███▆█▆██▆▆█▆▃▆███▆▆
f1-score:,▆█▁█▆▁█▇█▇█▁▁▆██▁█▇█▇██▁▆█▆██▇▇█▁▆▆███▇▇
loss,▃▁▂▁▂▃▁▂▁▃▁▂▁▃▁▁█▁▂▁▂▁▁▁▂▁▂▁▁▂▂▁▂▂▂▁▁▁▂▃
precision,██▁██▁███▇█▁▁███▁█▆████▁█████▆▆█▁▅█████▆
recall,▅█▁█▅▁█▇███▁▁▅██▁███▆██▁▅█▅█████▁█▅███▆█

0,1
acc,0.8
f1-score:,0.85714
loss,0.29259
precision,1.0
recall,0.75
