In [1]:
import torchvision
import torchvision.transforms as transforms
import os
import yaml
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

os.chdir('../')
from utils import Config
from models.Resnet50 import resnet50

In [15]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Using device:', device)

Using device: mps


In [2]:
with open("configs/cifar10_base.yml") as f:
        yml_dict = yaml.load(f, Loader=yaml.FullLoader)
config = Config(yml_dict)

In [8]:
dataset_original = torchvision.datasets.CIFAR10(
    root=config.train_dataset_aug_path,
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar10/train_aug/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:16<00:00, 10317331.50it/s]


Extracting data/cifar10/train_aug/cifar-10-python.tar.gz to data/cifar10/train_aug


### Testing ViT on classification task

In [16]:
clf_model = resnet50(num_classes=10, pretrained=False)
clf_model.to(device)



ResNet50(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

In [17]:
train_size = int(0.8 * len(dataset_original))
test_size = len(dataset_original) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset_original, [train_size, test_size])

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False
)

In [18]:
epochs = 20
optimizer = optim.Adam(clf_model.parameters())
criterion = nn.CrossEntropyLoss()

for epoch in range(epochs):
    print("Epoch:", epoch)
    running_loss = 0.0
    for i, data in enumerate(train_dataloader, 0):
        inputs, label = data
        inputs, label = inputs.to(device), label.to(device)
        output = clf_model(inputs)

        optimizer.zero_grad()
        loss = F.cross_entropy(output, label)
        loss.backward()

        optimizer.step()

        running_loss += loss.item()
        if i % 25 == 24:
            print("Loss:", running_loss / 25)
            running_loss = 0.0

print("Finished Training")

Epoch: 0
Epoch: 1
Epoch: 2
Epoch: 3
Epoch: 4
Epoch: 5
Epoch: 6
Epoch: 7
Epoch: 8
Epoch: 9
Epoch: 10
Epoch: 11
Epoch: 12
Epoch: 13
Epoch: 14
Epoch: 15
Epoch: 16
Epoch: 17
Epoch: 18
Epoch: 19
Finished Training


In [19]:
correct = 0
total = 0

with torch.no_grad():
    for data in test_dataloader:
        input, label = data
        input, label = input.to(device), label.to(device)
        output = clf_model(input)
        _, predicted = torch.max(output.data, 1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

print(f"Accuracy of the network on the test images: {100 * correct / total}%")

Accuracy of the network on the test images: 43.95%
