In [1]:
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torcheval.metrics import BinaryAccuracy, BinaryRecall, BinaryPrecision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from model import Net

if torch.cuda.is_available():
    torch.set_default_device('cuda')
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    torch.set_default_device('mps')
    device = torch.device('mps')
else:
    torch.set_default_device('cpu')
    device = torch.device('cpu')

torch.manual_seed(0)
batch_size = 64

Downloading: "https://github.com/pytorch/vision/zipball/v0.10.0" to /home/ptaech/.var/app/com.visualstudio.code/cache/torch/hub/v0.10.0.zip
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/ptaech/.var/app/com.visualstudio.code/cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100.0%


In [2]:
photometric_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((256, 256)),
     transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((256, 256)),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [3]:
train = datasets.OxfordIIITPet("", split="trainval", transform=photometric_transform, target_types="binary-category", download=True).__add__(datasets.OxfordIIITPet("", split="trainval", transform=transform, target_types="binary-category", download=True))
test = datasets.OxfordIIITPet("", split="test", transform=transform, target_types="binary-category", download=True)

In [4]:
model = torch.load('model.pth', map_location=device, weights_only=False).to(device)

In [5]:
criterion = nn.BCELoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
metrics = {"Accuracy":BinaryAccuracy(device=device), "Recall": BinaryRecall(device=device), "Precision":BinaryPrecision(device=device)}

In [6]:
trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=device).manual_seed(0))

In [7]:
batches = len(trainloader)
for epoch in range(1):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device).to(torch.float32)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        for _,metric in metrics.items():
            metric.update(outputs, labels.to(torch.long))
        if i % 10 == 9 or i == batches - 1:
            print(f'[{epoch + 1}, {i + 1}/{batches}] loss: {running_loss / (batches%10 if i == batches-1 else 10)}', end=' ')

            for name,metric in metrics.items():
                print(f'{name}: {metric.compute():.3f}', end=' ')
            print()
            running_loss = 0.0

[1, 10/115] loss: 0.06899387463927269 Accuracy: 0.978 Recall: 0.993 Precision: 0.975 
[1, 20/115] loss: 0.0753303935751319 Accuracy: 0.974 Recall: 0.984 Precision: 0.978 
[1, 30/115] loss: 0.08010146953165531 Accuracy: 0.973 Recall: 0.983 Precision: 0.978 
[1, 40/115] loss: 0.051252171769738196 Accuracy: 0.976 Recall: 0.987 Precision: 0.979 
[1, 50/115] loss: 0.07394364848732948 Accuracy: 0.975 Recall: 0.987 Precision: 0.977 
[1, 60/115] loss: 0.0698537714779377 Accuracy: 0.975 Recall: 0.985 Precision: 0.978 
[1, 70/115] loss: 0.05094455098733306 Accuracy: 0.976 Recall: 0.986 Precision: 0.979 
[1, 80/115] loss: 0.06284597590565681 Accuracy: 0.976 Recall: 0.985 Precision: 0.980 
[1, 90/115] loss: 0.06458655744791031 Accuracy: 0.976 Recall: 0.986 Precision: 0.979 
[1, 100/115] loss: 0.05009535551071167 Accuracy: 0.977 Recall: 0.986 Precision: 0.980 
[1, 110/115] loss: 0.05237346701323986 Accuracy: 0.977 Recall: 0.986 Precision: 0.981 
[1, 115/115] loss: 0.06682261880487203 Accuracy: 0.97

In [8]:
testloader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)
for name,metric in metrics.items():
    metric.reset()
batches = len(testloader)
for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device).to(torch.float32)
    outputs = model(inputs)
    for _,metric in metrics.items():
        metric.update(outputs, labels.to(torch.int64))
    if i % 10 == 9 or i == batches - 1:
        print(f'[{i + 1}/{batches}]', end=' ')
        for name,metric in metrics.items():
            print(f'{name}: {metric.compute():.3f}', end=' ')
        print()

[10/58] Accuracy: 0.686 Recall: 0.877 Precision: 0.698 
[20/58] Accuracy: 0.666 Recall: 0.912 Precision: 0.635 
[30/58] Accuracy: 0.724 Recall: 0.878 Precision: 0.762 
[40/58] Accuracy: 0.731 Recall: 0.880 Precision: 0.767 
[50/58] Accuracy: 0.729 Recall: 0.879 Precision: 0.763 
[58/58] Accuracy: 0.715 Recall: 0.874 Precision: 0.748 
