In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.cuda.amp import autocast, GradScaler

import numpy as np

from torchvision import datasets, transforms


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import random
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x106e377d0>

In [3]:
transform = transforms.Compose([transforms.ToTensor()])

train_ds = datasets.MNIST('.', train=True,transform=transform,download=True)
test_ds = datasets.MNIST('.', train = False, download=True, transform=transform)

In [4]:
## for imbalance data
lables = np.array(train_ds.targets)
print(lables)
print(len(lables))

class_counts = np.bincount(lables)
print(class_counts)

class_weights = 1.0 / ( class_counts + 1e-6)
print(class_weights)

sample_weights = class_weights[lables]
print(sample_weights)

sampler = WeightedRandomSampler(sample_weights, num_samples= len(sample_weights), replacement=True)

[5 0 4 ... 5 6 8]
60000
[5923 6742 5958 6131 5842 5421 5918 6265 5851 5949]
[0.00016883 0.00014832 0.00016784 0.00016311 0.00017117 0.00018447
 0.00016898 0.00015962 0.00017091 0.0001681 ]
[0.00018447 0.00016883 0.00017117 ... 0.00018447 0.00016898 0.00017091]


  lables = np.array(train_ds.targets)


In [5]:
train_loader = DataLoader(train_ds, batch_size=64, sampler=sampler)
test_loader = DataLoader(test_ds,   batch_size=64, shuffle=False)

In [7]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channels=1,out_channels=32,kernel_size=3,padding=1)
        self.rel = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=32,out_channels=64,kernel_size=3,padding=1)
        self.rel = nn.ReLU()
        self.pooling = nn.MaxPool2d(2)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(64*14*14,10)



    def forward(self, x):
        print(x.shape)
        x = self.conv1(x)
        print(x.shape)
        x = self.rel(x) 
        print(x.shape)
        x = self.conv2(x)
        print(x.shape)
        x = self.rel(x)
        print(x.shape)
        x = self.pooling(x)
        print(x.shape)
        x = self.flatten(x)
        print(x.shape)
        x = self.linear(x)
        print(x.shape)
        return x





model = SimpleCNN().to(device)






class FocalLoss(nn.Module):
    def  __init__(self, gamma=2.0, weight = None , reduction = 'mean'):
        super().__init__()   
        self.gamma, self.weight, self.reduction = gamma , weight, reduction

    def forward(self,logits,target):
        ce = nn.functional.cross_entropy(logits,target,weight=self.weight,reduction=self.reduction)
        pt = torch.exp(-ce)
        loss = ((1-pt)**self.gamma)*ce
        return loss.mean() if self.reduction=='mean' else loss.sum
    






criterion = FocalLoss()

optimizer = optim.Adam(model.parameters(),lr = 0.001)
scaler = GradScaler()








for epoch in range(3):
    model.train()
    for x, gt in train_loader:
        
        x = x.to(device)
        gt = gt.to(device)

        optimizer.zero_grad()
    
        with autocast():
            logits = model(x)
            loss = criterion(logits, gt)
        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update() 
        print(f"Epoch: {epoch}, Loss: {loss}")
        

        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scaler':scaler.state_dict(),
            'loss':loss
        }

        torch.save(checkpoint,'checkpoint.pth')

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for x , gt in test_loader:
            x, gt = x.to(device), gt.to(device)
            preds = model(x).argmax(1)
            correct += (preds == gt).sum().item()
            total += gt.size(0)
    print(f"Epoch {epoch}: acc {correct/total:.4f}")

  scaler = GradScaler()
  with autocast():


torch.Size([64, 1, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 14, 14])
torch.Size([64, 12544])
torch.Size([64, 10])
Epoch: 0, Loss: 1.8573946952819824
torch.Size([64, 1, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 14, 14])
torch.Size([64, 12544])
torch.Size([64, 10])
Epoch: 0, Loss: 2.151101589202881
torch.Size([64, 1, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 14, 14])
torch.Size([64, 12544])
torch.Size([64, 10])
Epoch: 0, Loss: 1.8608241081237793
torch.Size([64, 1, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 32, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 28, 28])
torch.Size([64, 64, 14, 14])
torch.Size([64, 12544])
torch.Size([64, 10])
Epoch: 0, Loss: 1.703195

KeyboardInterrupt: 