In [52]:
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torcheval.metrics import BinaryAccuracy, BinaryRecall, BinaryPrecision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from model import Net

# Pick the compute backend once: CUDA if present, otherwise Apple MPS,
# otherwise fall back to the CPU. The MPS probe is only evaluated when CUDA
# is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
# Make newly created tensors land on the chosen backend by default.
torch.set_default_device(device)

# Fixed seed and the mini-batch size shared by the train and test loaders.
torch.manual_seed(0)
batch_size = 64

In [53]:
# Deterministic preprocessing: convert to a tensor, resize to 256x256 and
# normalise each of the three channels with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(256, 256)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Augmenting variant: same steps as `transform`, with an affine transform
# (rotation, up to 10% translation, 0.8x-1.2x scaling) inserted before the
# normalisation.
# NOTE(review): degrees=(180, 180) is a zero-width (min, max) range, so the
# rotation angle is always 180 degrees. If a random rotation was intended,
# (-180, 180) is probably what was meant - confirm before changing.
affine_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(256, 256)),
    transforms.RandomAffine(
        degrees=(180, 180),
        translate=(0.1, 0.1),
        scale=(0.8, 1.2),
    ),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [54]:
# Training set: the "trainval" split is loaded twice - once through the
# augmenting `affine_transform` and once through the plain `transform` - and
# the two views are concatenated, so every image appears in both forms.
# ConcatDataset is what `Dataset.__add__` builds; it is constructed explicitly
# here instead of calling the dunder method by hand.
train = torch.utils.data.ConcatDataset([
    datasets.OxfordIIITPet(
        "",
        split="trainval",
        transform=affine_transform,
        target_types="binary-category",
        download=True,
    ),
    datasets.OxfordIIITPet(
        "",
        split="trainval",
        transform=transform,
        target_types="binary-category",
        download=True,
    ),
])

# Held-out evaluation set: the "test" split with the plain preprocessing only.
test = datasets.OxfordIIITPet(
    "",
    split="test",
    transform=transform,
    target_types="binary-category",
    download=True,
)

In [55]:
# Restore the previously saved model object from 'model.pth', mapping its
# storages onto the selected device.
# SECURITY: weights_only=False unpickles an arbitrary Python object, which can
# execute code - only load checkpoints from a trusted source.
# NOTE(review): the otherwise unused `from model import Net` import is
# presumably required so the pickled class can be resolved - confirm.
model = torch.load('model.pth', map_location=device, weights_only=False)
model = model.to(device)

In [56]:
# Binary cross-entropy loss, placed on the training device.
criterion = nn.BCELoss()
criterion = criterion.to(device)

# Plain SGD with momentum over all model parameters.
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Running binary-classification metrics, keyed by the label printed in the
# progress lines. The same objects are reused for training and evaluation.
metrics = {
    label: metric_cls(device=device)
    for label, metric_cls in (
        ("Accuracy", BinaryAccuracy),
        ("Recall", BinaryRecall),
        ("Precision", BinaryPrecision),
    )
}

In [57]:
# Shuffled mini-batch loader over the training set. The shuffling generator is
# created on `device` and seeded with 0 so the batch order is repeatable.
# NOTE(review): the generator presumably has to live on `device` because the
# default device was changed earlier in the notebook - confirm.
trainloader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator(device=device).manual_seed(0),
)

In [58]:
# Fine-tune `model` for one epoch over `trainloader`.
# Every 10 batches (and after the final batch) print the mean loss since the
# previous report together with the running accuracy / recall / precision
# accumulated since the start of the epoch.
batches = len(trainloader)
# Put the model in training mode explicitly; otherwise the mode depends on
# how the checkpoint was saved or on a previously executed evaluation cell.
model.train()
for epoch in range(1):
    running_loss = 0.0
    # Number of batches folded into `running_loss` since the last report.
    # Dividing by this (instead of `batches % 10`) avoids a ZeroDivisionError
    # when the number of batches is an exact multiple of 10.
    steps_since_report = 0
    # Start each epoch from clean metric state so re-running this cell after
    # the evaluation cell does not mix test statistics into the training ones.
    for metric in metrics.values():
        metric.reset()
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # BCELoss needs floating-point targets, so labels are cast to float32.
        inputs, labels = inputs.to(device), labels.to(device).to(torch.float32)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        steps_since_report += 1
        # The metrics are fed integer class labels rather than the float targets.
        for metric in metrics.values():
            metric.update(outputs, labels.to(torch.long))
        if i % 10 == 9 or i == batches - 1:
            print(f'[{epoch + 1}, {i + 1}/{batches}] loss: {running_loss / steps_since_report}', end=' ')

            for name, metric in metrics.items():
                print(f'{name}: {metric.compute():.3f}', end=' ')
            print()
            running_loss = 0.0
            steps_since_report = 0

[1, 10/115] loss: 0.27667738646268847 Accuracy: 0.880 Recall: 0.922 Precision: 0.898 
[1, 20/115] loss: 0.21343095153570174 Accuracy: 0.898 Recall: 0.938 Precision: 0.913 
[1, 30/115] loss: 0.24370730966329573 Accuracy: 0.896 Recall: 0.940 Precision: 0.908 
[1, 40/115] loss: 0.19823341965675353 Accuracy: 0.898 Recall: 0.935 Precision: 0.916 
[1, 50/115] loss: 0.2171710893511772 Accuracy: 0.902 Recall: 0.940 Precision: 0.917 
[1, 60/115] loss: 0.19047725349664688 Accuracy: 0.905 Recall: 0.944 Precision: 0.918 
[1, 70/115] loss: 0.1857640914618969 Accuracy: 0.908 Recall: 0.944 Precision: 0.922 
[1, 80/115] loss: 0.15513026118278503 Accuracy: 0.912 Recall: 0.947 Precision: 0.924 
[1, 90/115] loss: 0.16650660187005997 Accuracy: 0.914 Recall: 0.949 Precision: 0.925 
[1, 100/115] loss: 0.14292916655540466 Accuracy: 0.916 Recall: 0.950 Precision: 0.929 
[1, 110/115] loss: 0.18078498542308807 Accuracy: 0.917 Recall: 0.950 Precision: 0.929 
[1, 115/115] loss: 0.153315232694149 Accuracy: 0.917 R

In [59]:
# Evaluate `model` on the held-out test split.
# The shared metric objects are reset first so they reflect the test data only;
# running accuracy / recall / precision are printed every 10 batches and after
# the final batch.
testloader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)
for metric in metrics.values():
    metric.reset()
batches = len(testloader)
# Switch to inference behaviour (affects layers such as dropout / batch-norm,
# if the model has any). The model is left in eval mode afterwards.
model.eval()
# No gradients are needed for evaluation; disabling autograd avoids building
# a graph for every batch and reduces memory use.
with torch.no_grad():
    for i, data in enumerate(testloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        # No loss is computed here, so the labels go straight to the integer
        # dtype passed to the metrics (no float32 round-trip).
        labels = labels.to(device).to(torch.int64)
        outputs = model(inputs)
        for metric in metrics.values():
            metric.update(outputs, labels)
        if i % 10 == 9 or i == batches - 1:
            print(f'[{i + 1}/{batches}]', end=' ')
            for name, metric in metrics.items():
                print(f'{name}: {metric.compute():.3f}', end=' ')
            print()

[10/58] Accuracy: 0.670 Recall: 0.827 Precision: 0.700 
[20/58] Accuracy: 0.666 Recall: 0.871 Precision: 0.642 
[30/58] Accuracy: 0.706 Recall: 0.831 Precision: 0.767 
[40/58] Accuracy: 0.716 Recall: 0.837 Precision: 0.772 
[50/58] Accuracy: 0.712 Recall: 0.836 Precision: 0.767 
[58/58] Accuracy: 0.698 Recall: 0.830 Precision: 0.750 
