# Training Example Utilising Hugging Face Datasets and torchvision Models

## Prepare Datasets & Dataloaders (cifar10 & resnet50 example)

### Load raw dataset

In [1]:
import torchvision
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import random_split

import torch  # needed here only for the seeded generator used by random_split

# Fixed seed so the train/validation partition is identical on every run.
SPLIT_SEED = 42

# NOTE(review): `cifar10` is not referenced by any other cell visible in this
# notebook; the data actually used for training comes from torchvision below.
cifar10 = load_dataset('cifar10')

# Convert images to tensors, then normalise every channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='../data', train=True, download=True, transform=transform)

# Hold out 20% of the official training split for validation. The explicit
# generator makes the split reproducible: without it, every kernel restart
# yields a different partition and validation accuracies are not comparable
# between runs.
valid_size = int(0.2 * len(train_dataset))
train_dataset, valid_dataset = random_split(
    train_dataset,
    [len(train_dataset) - valid_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

test_dataset = torchvision.datasets.CIFAR10(root='../data', train=False, download=True, transform=transform)

Files already downloaded and verified
Files already downloaded and verified


### Prepare dataloaders

In [2]:
from torch.utils.data import DataLoader

# Settings shared by all three loaders.
loader_kwargs = dict(batch_size=64, num_workers=4)

# Only the training loader shuffles its samples; validation and test loaders
# iterate in dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

## Prepare, train & test model

### Prepare model, optimizer and loss function

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet50

from kiss.utils.configs import CONFIGS

# Device comes from the project configuration.
device = CONFIGS.torch.device

# ResNet-50 with a 10-way classification head. The model is moved to the
# target device *before* the optimizer is constructed, so the optimizer is
# built from parameters that already live where training will happen.
model = resnet50(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.001)

### Train loop

In [4]:
from tqdm import tqdm
from kiss.utils.strings import Format

import os  # used to make sure the checkpoint directory exists before saving

num_epochs = 10
best_valid_acc = 0
checkpoint_path = "../checkpoints/resnet50_prototype.pth"

# torch.save does not create missing parent directories; create it up front so
# a fresh checkout does not fail only after the first full training epoch.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0

    with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit=" batch") as pbar:
        for batch_idx, (inputs, labels) in enumerate(train_loader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            pbar.update(1)
            # Average over the batches processed so far. Dividing by the total
            # number of batches would under-report the loss until the very
            # last step of the epoch.
            pbar.set_postfix(loss=f"{running_loss / batch_idx:.4f}")

    # ---- Validation pass ----
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in valid_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # Predicted class = index of the largest logit per sample.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total

    # Keep only the weights of the best-performing epoch on the validation set.
    if accuracy > best_valid_acc:
        with Format(Format.BOLD, Format.CYAN):
            print(f"Best valid accuracy improved from {best_valid_acc * 100:.2f}% to {accuracy * 100:.2f}%. Saving checkpoint...")
        best_valid_acc = accuracy
        torch.save(model.state_dict(), checkpoint_path)

Epoch 1/10: 100%|██████████| 625/625 [01:37<00:00,  6.42 batch/s, loss=1.9096]


[1m[36mBest valid accuracy improved from 0.00% to 33.56%. Saving checkpoint...
[0m

Epoch 2/10: 100%|██████████| 625/625 [01:36<00:00,  6.45 batch/s, loss=1.7142]


[1m[36mBest valid accuracy improved from 33.56% to 43.85%. Saving checkpoint...
[0m

Epoch 3/10: 100%|██████████| 625/625 [01:36<00:00,  6.48 batch/s, loss=1.5894]


[1m[36mBest valid accuracy improved from 43.85% to 44.18%. Saving checkpoint...
[0m

Epoch 4/10: 100%|██████████| 625/625 [01:35<00:00,  6.51 batch/s, loss=1.5328]


[1m[36mBest valid accuracy improved from 44.18% to 51.80%. Saving checkpoint...
[0m

Epoch 5/10: 100%|██████████| 625/625 [01:35<00:00,  6.51 batch/s, loss=1.3929]
Epoch 6/10: 100%|██████████| 625/625 [01:36<00:00,  6.49 batch/s, loss=1.5160]
Epoch 7/10: 100%|██████████| 625/625 [01:35<00:00,  6.51 batch/s, loss=1.3796]


[1m[36mBest valid accuracy improved from 51.80% to 57.46%. Saving checkpoint...
[0m

Epoch 8/10: 100%|██████████| 625/625 [01:35<00:00,  6.54 batch/s, loss=1.4863]
Epoch 9/10: 100%|██████████| 625/625 [01:35<00:00,  6.52 batch/s, loss=1.2527]


[1m[36mBest valid accuracy improved from 57.46% to 57.54%. Saving checkpoint...
[0m

Epoch 10/10: 100%|██████████| 625/625 [01:36<00:00,  6.51 batch/s, loss=1.1694]


[1m[36mBest valid accuracy improved from 57.54% to 60.42%. Saving checkpoint...
[0m

### Test Model

In [5]:
from tqdm import tqdm

# Rebuild the architecture and restore the best checkpoint written by the
# training loop. map_location remaps the stored tensors onto the current
# device, so the cell also works when the checkpoint was saved on a different
# device than the one available now (e.g. trained on GPU, evaluated on CPU).
model = resnet50(num_classes=10)
state_dict = torch.load("../checkpoints/resnet50_prototype.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)

model.eval()
correct = 0
total = 0

with torch.no_grad():
    # Wrapping the loader lets tqdm derive the total and advance the bar itself.
    for inputs, labels in tqdm(test_loader, desc="Testing", unit=" batch"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # Predicted class = index of the largest logit per sample.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")

Testing: 100%|██████████| 157/157 [00:28<00:00,  5.59 batch/s]

Test Accuracy: 61.00%



