In [1]:
import torch
import torchvision.transforms as transforms
from torchvision import datasets
from torcheval.metrics import BinaryAccuracy, BinaryRecall, BinaryPrecision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from model import Net

if torch.cuda.is_available():
    torch.set_default_device('cuda')
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    torch.set_default_device('mps')
    device = torch.device('mps')
else:
    torch.set_default_device('cpu')
    device = torch.device('cpu')

torch.manual_seed(0)
batch_size = 64

Using cache found in /Users/kiran/.cache/torch/hub/pytorch_vision_v0.10.0


In [24]:
erasing_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((256, 256)),
     transforms.RandomErasing(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((256, 256)),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

In [25]:
train = datasets.OxfordIIITPet("", split="trainval", transform=erasing_transform, target_types="binary-category", download=True).__add__(datasets.OxfordIIITPet("", split="trainval", transform=transform, target_types="binary-category", download=True))
test = datasets.OxfordIIITPet("", split="test", transform=transform, target_types="binary-category", download=True)

In [28]:
model = torch.load('model.pth', map_location=device, weights_only=False).to(device)

In [29]:
criterion = nn.BCELoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
metrics = {"Accuracy":BinaryAccuracy(device=device), "Recall": BinaryRecall(device=device), "Precision":BinaryPrecision(device=device)}

In [30]:
trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=device).manual_seed(0))

In [31]:
batches = len(trainloader)
for epoch in range(1):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device).to(torch.float32)
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        for _,metric in metrics.items():
            metric.update(outputs, labels.to(torch.long))
        if i % 10 == 9 or i == batches - 1:
            print(f'[{epoch + 1}, {i + 1}/{batches}] loss: {running_loss / (batches%10 if i == batches-1 else 10)}', end=' ')

            for name,metric in metrics.items():
                print(f'{name}: {metric.compute():.3f}', end=' ')
            print()
            running_loss = 0.0

[1, 10/115] loss: 0.10239905714988709 Accuracy: 0.964 Recall: 0.976 Precision: 0.969 
[1, 20/115] loss: 0.10004920810461045 Accuracy: 0.956 Recall: 0.968 Precision: 0.968 
[1, 30/115] loss: 0.13408401794731617 Accuracy: 0.953 Recall: 0.971 Precision: 0.960 
[1, 40/115] loss: 0.08397868014872074 Accuracy: 0.957 Recall: 0.970 Precision: 0.965 
[1, 50/115] loss: 0.0795778676867485 Accuracy: 0.959 Recall: 0.972 Precision: 0.968 
[1, 60/115] loss: 0.10301439184695482 Accuracy: 0.960 Recall: 0.974 Precision: 0.967 
[1, 70/115] loss: 0.11131671108305455 Accuracy: 0.960 Recall: 0.973 Precision: 0.968 
[1, 80/115] loss: 0.06792753376066685 Accuracy: 0.962 Recall: 0.974 Precision: 0.971 
[1, 90/115] loss: 0.1049001483246684 Accuracy: 0.962 Recall: 0.975 Precision: 0.968 
[1, 100/115] loss: 0.065762023255229 Accuracy: 0.963 Recall: 0.976 Precision: 0.970 
[1, 110/115] loss: 0.09676657989621162 Accuracy: 0.963 Recall: 0.975 Precision: 0.970 
[1, 115/115] loss: 0.09142168462276459 Accuracy: 0.963 R

In [32]:
testloader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=False)
for name,metric in metrics.items():
    metric.reset()
batches = len(testloader)
for i, data in enumerate(testloader, 0):
    inputs, labels = data
    inputs, labels = inputs.to(device), labels.to(device).to(torch.float32)
    outputs = model(inputs)
    for _,metric in metrics.items():
        metric.update(outputs, labels.to(torch.int64))
    if i % 10 == 9 or i == batches - 1:
        print(f'[{i + 1}/{batches}]', end=' ')
        for name,metric in metrics.items():
            print(f'{name}: {metric.compute():.3f}', end=' ')
        print()

[10/58] Accuracy: 0.700 Recall: 0.892 Precision: 0.706 
[20/58] Accuracy: 0.677 Recall: 0.924 Precision: 0.641 
[30/58] Accuracy: 0.735 Recall: 0.889 Precision: 0.767 
[40/58] Accuracy: 0.736 Recall: 0.886 Precision: 0.769 
[50/58] Accuracy: 0.733 Recall: 0.886 Precision: 0.763 
[58/58] Accuracy: 0.718 Recall: 0.881 Precision: 0.748 
