In [221]:
from tqdm import tqdm
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
import numpy as np
import pandas as pd
from helper_functions import accuracy_fn
from torchsummary import summary
from safetensors.torch import save_model
import matplotlib.pyplot as plt
import cv2

In [222]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [224]:
# Per-sample preprocessing for MNIST:
#   * Grayscale(num_output_channels=3) repeats the single gray channel three times so the
#     digits match the 3-channel input expected by the ImageNet-pretrained ResNet.
#   * ToTensor() converts the PIL image into a float tensor scaled to [0, 1].
# NOTE(review): ImageNet normalisation,
#   Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# is currently disabled. torchvision's pretrained weights were trained on inputs normalised
# with these statistics -- consider re-enabling it and comparing accuracy.
data_transforms = transforms.Compose(
    [transforms.Grayscale(num_output_channels=3), transforms.ToTensor()]
)

In [225]:
# MNIST train and test splits. Both are downloaded into ./data on first use and go through
# the same `data_transforms` pipeline; they differ only in the `train` flag.
train_data, test_data = (
    datasets.MNIST(root="data", train=is_train, download=True, transform=data_transforms)
    for is_train in (True, False)
)

In [226]:
batch_size = 16

# Wrap the torchvision MNIST datasets in DataLoaders.
# Training batches are reshuffled on every pass so SGD sees a different order each epoch.
# The evaluation loader only accumulates averaged metrics, so shuffling it is pointless;
# a fixed order keeps test iteration deterministic.
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [227]:
# test if data can be inserted into dataloader
# Smoke-test the input pipeline: pull one training batch and confirm the transforms produced
# what the ResNet expects -- 3-channel 28x28 float images in [0, 1] with one label per image.
x, y = next(iter(train_dl))
assert x.shape == (batch_size, 3, 28, 28), f"unexpected image batch shape {tuple(x.shape)}"
assert y.shape == (batch_size,), f"unexpected label batch shape {tuple(y.shape)}"
assert x.dtype == torch.float32, f"unexpected image dtype {x.dtype}"
assert 0.0 <= x.min().item() and x.max().item() <= 1.0, "pixel values outside [0, 1]"
x.shape, y.shape

### Model Creation

In [247]:
# ImageNet-pretrained ResNet-18 fetched through torch.hub, with the repository pinned to the
# v0.10.0 tag. The local hub cache is reused on later runs.
# NOTE(review): torch.hub.load executes the repo's hubconf.py; building the model from the
# installed torchvision (torchvision.models.resnet18) would avoid running fetched code.
pretrained = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=True)
pretrained = pretrained.to(device)

# Per-layer summary for one 3x28x28 input. The ResNet stem takes 3-channel images, hence
# Grayscale(3) in the transforms. torchvision's ResNet ends in an adaptive average pool, so
# small inputs such as 28x28 are accepted (the summary below runs at that size).
summary(pretrained, (3, 28, 28))

Using cache found in C:\Users\paoma/.cache\torch\hub\pytorch_vision_v0.10.0


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 14, 14]           9,408
       BatchNorm2d-2           [-1, 64, 14, 14]             128
              ReLU-3           [-1, 64, 14, 14]               0
         MaxPool2d-4             [-1, 64, 7, 7]               0
            Conv2d-5             [-1, 64, 7, 7]          36,864
       BatchNorm2d-6             [-1, 64, 7, 7]             128
              ReLU-7             [-1, 64, 7, 7]               0
            Conv2d-8             [-1, 64, 7, 7]          36,864
       BatchNorm2d-9             [-1, 64, 7, 7]             128
             ReLU-10             [-1, 64, 7, 7]               0
       BasicBlock-11             [-1, 64, 7, 7]               0
           Conv2d-12             [-1, 64, 7, 7]          36,864
      BatchNorm2d-13             [-1, 64, 7, 7]             128
             ReLU-14             [-1, 6

In [248]:
class ModifiedModel(nn.Module):
    """Pretrained backbone followed by a freshly initialised linear classification head.

    The backbone's output (1000 values for the ImageNet-head ResNet used in this notebook)
    is passed through a new ``nn.Linear`` that produces ``num_classes`` logits.

    Args:
        pretrained: Module mapping an input batch to a tensor whose last dimension is
            ``in_features`` (assumed ``(batch, in_features)`` -- confirm for other backbones).
        num_classes: Number of output logits; 10 for the MNIST digits.
        in_features: Width of ``pretrained``'s output; 1000 for an ImageNet-head ResNet.

    NOTE(review): the backbone still ends in its own ImageNet classifier, so this stacks two
    linear layers with no nonlinearity between them. Replacing ``pretrained.fc`` with a
    ``num_classes``-wide layer is the more usual fine-tuning setup, but it would change the
    saved state_dict keys, so it is left as is here.
    """

    def __init__(self, pretrained: nn.Module, num_classes: int = 10, in_features: int = 1000):
        super().__init__()
        self.pretrained = pretrained
        # New head mapping the backbone's outputs to our class logits.
        self.output = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw (unnormalised) class logits; pair with ``nn.CrossEntropyLoss``."""
        features = self.pretrained(x)
        return self.output(features)

# Put the 10-way MNIST head on top of the pretrained backbone, then move the whole model
# (backbone and head) to the selected device.
model = ModifiedModel(pretrained)
model = model.to(device)

In [252]:
# loss and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [253]:
# Seed torch's global RNG so the training shuffle order is repeatable.
# NOTE(review): the new head's weights were initialised when `model` was built, before this
# seed -- move the seed above model creation for a fully reproducible run.
torch.manual_seed(20)

epochs = 5

for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n------")

    # TRAINING: one pass over train_dl, with an optimizer step after every batch.
    train_loss, train_acc = 0, 0
    model.train()
    for X, y in train_dl:

        X, y = X.to(device), y.to(device)

        # forward pass
        train_pred = model(X)

        # metrics
        loss = loss_fn(train_pred, y)
        # Accumulate a detached copy. Adding `loss` itself would chain every batch's
        # autograd node onto `train_loss` and keep that history alive for the whole epoch.
        train_loss += loss.detach()
        train_acc += accuracy_fn(y_true=y, y_pred=train_pred.argmax(dim=1))

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean of per-batch values. This equals the per-sample mean only when every batch is full.
    train_loss /= len(train_dl)
    train_acc /= len(train_dl)
    print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.2f}%")

    # TESTING: eval-mode forward passes only. inference_mode disables autograd tracking,
    # so the accumulated loss carries no history here.
    test_loss, test_acc = 0, 0
    model.eval()
    with torch.inference_mode():
        for X, y in test_dl:

            X, y = X.to(device), y.to(device)

            # forward pass
            test_pred = model(X)

            # metrics
            test_loss += loss_fn(test_pred, y)
            test_acc += accuracy_fn(y_true=y, y_pred=test_pred.argmax(dim=1))

        # Mean of per-batch values, as for training.
        test_loss /= len(test_dl)
        test_acc /= len(test_dl)
        print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.2f}%")

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 0
------
Train Loss: 0.0825 | Train Accuracy: 98.14%


 20%|██        | 1/5 [01:18<05:14, 78.68s/it]

Test Loss: 0.0545 | Test Accuracy: 98.71%
Epoch: 1
------
Train Loss: 0.0693 | Train Accuracy: 98.35%


 40%|████      | 2/5 [02:36<03:54, 78.30s/it]

Test Loss: 0.0501 | Test Accuracy: 98.64%
Epoch: 2
------
Train Loss: 0.0622 | Train Accuracy: 98.42%


 60%|██████    | 3/5 [03:56<02:37, 78.77s/it]

Test Loss: 0.0450 | Test Accuracy: 98.75%
Epoch: 3
------
Train Loss: 0.0586 | Train Accuracy: 98.53%


 80%|████████  | 4/5 [05:17<01:19, 79.69s/it]

Test Loss: 0.0430 | Test Accuracy: 98.77%
Epoch: 4
------
Train Loss: 0.0497 | Train Accuracy: 98.69%


100%|██████████| 5/5 [06:33<00:00, 78.71s/it]

Test Loss: 0.0407 | Test Accuracy: 98.79%





In [254]:
# Persist the fine-tuned weights in the safetensors format (tensor data only, no pickled code).
weights_path = "resnet_mnist.safetensors"
save_model(model, weights_path)