In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from collections import Counter
from torch.utils.data.dataset import random_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
train_ds = datasets.MNIST(
    root='./mnist',train=True,transform=transforms.ToTensor(),download=True
)

test_ds = datasets.MNIST(
    root='./mnist',train=False,transform=transforms.ToTensor(),download=True
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist\MNIST\raw\train-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:35<00:00, 276154.79it/s]


Extracting ./mnist\MNIST\raw\train-images-idx3-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist\MNIST\raw\train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 898879.47it/s]


Extracting ./mnist\MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:03<00:00, 417398.79it/s]


Extracting ./mnist\MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 2324369.05it/s]

Extracting ./mnist\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist\MNIST\raw






In [4]:
len(train_ds),len(test_ds)

(60000, 10000)

In [5]:
train_ds,val_ds = random_split(train_ds,lengths=[55000,5000])

In [6]:
len(train_ds),len(val_ds),len(test_ds)

(55000, 5000, 10000)

In [55]:
train_loader = DataLoader(
    train_ds,
    shuffle=True,
    num_workers=4,
    batch_size=512,
    drop_last=True,
    persistent_workers=True
)

val_loader = DataLoader(
    val_ds,
    shuffle=False,
    batch_size=512,
    num_workers=4,
    persistent_workers=True,
)

test_loader = DataLoader(
    test_ds,
    shuffle=False,
    batch_size=512
)

In [56]:
train_counter = Counter()
for images,labels in train_loader:
    train_counter.update(labels.tolist())
print(train_counter.items())

dict_items([(1, 6153), (9, 5424), (3, 5599), (4, 5343), (5, 4961), (2, 5457), (7, 5727), (6, 5401), (0, 5385), (8, 5334)])


In [57]:
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(784,64),
            nn.ReLU(),
            nn.Linear(64,64),
            nn.ReLU(),
            nn.Linear(64,10),
        )
    def forward(self,x):
        x=x.flatten(1)
        out = self.net(x)
        return out

In [58]:
# Pure PyTorch

In [59]:
def compute_accuracy(model,loader,device):
    model.eval()
    running_acc = 0.0
    for batch_idx,(images,labels) in enumerate(loader):
        logits = model(images.to(device))
        preds = logits.argmax(dim=1).cpu()
        acc = (preds==labels).type(torch.float32).mean()
        running_acc+=acc.item()
    running_acc /= len(loader)
    return running_acc

def train(model,optim,criterion,train_loader,val_loader,num_epochs,device):
    for i in range(1,num_epochs+1):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        for batch_idx,(images,labels) in enumerate(train_loader):
            logits = model(images.to(device))
            loss = criterion(logits,labels.to(device))
            running_loss+=loss.item()
            preds = logits.argmax(dim=1).cpu()
            acc = (preds==labels).type(torch.float32).mean()
            running_acc+=acc.item()
            loss.backward()
            opt.step()
            opt.zero_grad()
        acc = compute_accuracy(model,val_loader,device)
        print(f"Epoch : {i} : train_loss = {running_loss/len(train_loader)} | training_acc = {running_acc/len(train_loader)} | val_acc = {acc}")

In [37]:
train_ds[0][0].flatten(1).shape

torch.Size([1, 784])

In [38]:
device = torch.device('cuda')
model = Classifier().to(device)
opt = torch.optim.AdamW(model.parameters(),lr=3e-4)
criterion = nn.CrossEntropyLoss()
train(model,opt,criterion,train_loader,val_loader,10,device)

Epoch : 1 : train_loss = 1.6965621169482437 | training_acc = 0.6094662675233645 | val_acc = 0.8179727375507355
Epoch : 2 : train_loss = 0.5984015325519526 | training_acc = 0.8554139894859814 | val_acc = 0.8812779009342193
Epoch : 3 : train_loss = 0.398023304538192 | training_acc = 0.8916471962616822 | val_acc = 0.9051339268684387
Epoch : 4 : train_loss = 0.33049944586285923 | training_acc = 0.9080570969626168 | val_acc = 0.9142219364643097
Epoch : 5 : train_loss = 0.29261082577928205 | training_acc = 0.9173663843457944 | val_acc = 0.9233258903026581
Epoch : 6 : train_loss = 0.26602522346460933 | training_acc = 0.9249415887850467 | val_acc = 0.9309430778026581
Epoch : 7 : train_loss = 0.2457373313536154 | training_acc = 0.9303811331775701 | val_acc = 0.9326092183589936
Epoch : 8 : train_loss = 0.22918911250395196 | training_acc = 0.9356929030373832 | val_acc = 0.9393694221973419
Epoch : 9 : train_loss = 0.2145322722809337 | training_acc = 0.9395991530373832 | val_acc = 0.940676820278167

In [41]:
import lightning as L
import torch.nn.functional as F

In [60]:
class LightningModel(L.LightningModule):
    def __init__(self,model,lr):
        super().__init__()
        self.model = model
        self.lr = lr
    def forward(self,x):
        out = self.model(x)
        return out
    def training_step(self,batch,batch_idx):
        imgs,labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits,labels)
        self.log('train loss',loss)
        return loss
    def validation_step(self,batch,batch_idx):
        imgs,labels = batch
        logits = self(imgs)
        loss = F.cross_entropy(logits,labels)
        self.log('validation loss',loss)
    def configure_optimizers(self):
        opt = torch.optim.AdamW(self.parameters(),lr=self.lr)
        return opt

In [61]:
model = Classifier()
lightning_model = LightningModel(model=model,lr=1e-3)
trainer = L.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=10,
    
)
trainer.fit(
    model=lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Classifier | 55.1 K
-------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.220     Total estimated model params size (MB)


Epoch 0: 100%|██████████████████████████████████████████████████████████████| 107/107 [00:05<00:00, 20.81it/s, v_num=4]
Validation: |                                                                                    | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                                                               | 0/10 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                  | 0/10 [00:00<?, ?it/s][A
Validation DataLoader 0:  10%|█████▊                                                    | 1/10 [00:00<00:00, 58.26it/s][A
Validation DataLoader 0:  20%|███████████▌                                              | 2/10 [00:00<00:00, 66.10it/s][A
Validation DataLoader 0:  30%|█████████████████▍                                        | 3/10 [00:00<00:00, 70.62it/s][A
Validation DataLoader 0:  40%|███████████████████████▏                                  | 4/10 [00:00<00:00, 72.75it/s][A
Validation DataLoad

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████████████████████████████████████████████████████████| 107/107 [00:05<00:00, 19.62it/s, v_num=4]


In [64]:
compute_accuracy(model.to(device),train_loader,device)

0.974609375

In [65]:
import torchmetrics

In [70]:
def compute_accuracy(model,dataloader,device):
    model.eval()
    acc = torchmetrics.Accuracy(task='multiclass',num_classes=10).to(device)
    for batch_idx,(images,labels) in enumerate(dataloader):
        with torch.inference_mode():
            logits = model(images.to(device))
        pred = logits.argmax(dim=1)
        acc(pred,labels.to(device))
    return acc

In [72]:
acc = compute_accuracy(model.to(device),test_loader,device)
acc.compute()

tensor(0.9667, device='cuda:0')