In [None]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms as T
from torchvision import io
import torchutils as tu
import json
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import resnet50, ResNet50_Weights


# Seed PyTorch's RNGs (CPU and all CUDA devices) so the freshly initialised
# classification head and the DataLoader shuffling are reproducible across
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [None]:
# ImageNet-pretrained ResNet-50; its classification head is swapped out below.
model = resnet50(weights=ResNet50_Weights.DEFAULT)

TRAIN_DIR = 'data/train'
VALID_DIR = 'data/test'
BATCH_SIZE = 64

# Per-channel statistics the pretrained weights were trained with. Without
# this normalization the pretrained backbone sees inputs on a different scale
# than during pretraining, which degrades its features.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

trnsfrms = T.Compose(
    [
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

train_dataset = torchvision.datasets.ImageFolder(TRAIN_DIR, transform=trnsfrms)
valid_dataset = torchvision.datasets.ImageFolder(VALID_DIR, transform=trnsfrms)

# ImageFolder derives label indices from the sub-directory names, so both
# splits must expose the same classes for a label to mean the same thing.
assert valid_dataset.class_to_idx == train_dataset.class_to_idx, (
    f'class mismatch: {train_dataset.class_to_idx} vs {valid_dataset.class_to_idx}'
)
# The head below emits a single logit trained with BCE, which only makes
# sense for exactly two classes (labels 0 and 1).
assert len(train_dataset.classes) == 2, (
    f'expected 2 classes for binary classification, got {train_dataset.classes}'
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation order does not matter; keep it fixed so validation batches are
# identical every epoch.
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Replace the 1000-way ImageNet head with a single-logit binary head.
model.fc = nn.Linear(model.fc.in_features, 1)

# Reverse mapping: label index -> class (folder) name.
idx2class = {idx: name for name, idx in train_dataset.class_to_idx.items()}

In [None]:
# Feature extraction: freeze the whole pretrained network, then unfreeze only
# the classification head so it is the only part that receives gradients.
for param in model.parameters():
    param.requires_grad = False
model.fc.requires_grad_(True)

model.to(device)

LEARNING_RATE = 0.01

# Give the optimizer only the trainable (head) parameters; the frozen
# backbone never receives gradients, so there is nothing to optimise there.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)

# Single-logit binary classification: BCEWithLogitsLoss applies the sigmoid
# internally, so the model's raw outputs (logits) are passed in directly.
criterion = nn.BCEWithLogitsLoss()

def compute_batch_accuracy(preds, labels):
    """Return the fraction of a batch whose predicted class equals its label.

    Args:
        preds: raw model outputs (logits). A trailing size-1 dimension, if
            present, is squeezed away before comparing with ``labels``.
        labels: ground-truth labels, one per sample (expected to be 0/1).

    Returns:
        Accuracy as a Python float.
    """
    # sigmoid -> probability of class 1; rounding maps it to a 0/1 prediction.
    predicted_classes = torch.round(torch.sigmoid(preds.squeeze(-1)))
    n_correct = (predicted_classes == labels).sum()
    return (n_correct / len(labels)).item()

In [None]:
N_EPOCHS = 10


def run_epoch(loader, train):
    """Make one pass over ``loader`` and return ``(mean_loss, mean_accuracy)``.

    With ``train=True`` the model is put in training mode and the optimizer
    steps after every batch. With ``train=False`` the model is put in eval
    mode and autograd is disabled, so no computation graph is built.

    Both means are weighted by batch size, so a smaller final batch does not
    skew the epoch metric. Returns ``(nan, nan)`` for an empty loader.
    """
    # NOTE(review): even with the backbone frozen, train mode still updates
    # BatchNorm running statistics; consider keeping the backbone in eval mode
    # if those statistics should stay at their pretrained values.
    model.train(train)

    loss_sum, acc_sum, n_samples = 0.0, 0.0, 0
    with torch.set_grad_enabled(train):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            preds = model(images)
            loss = criterion(preds.squeeze(-1), labels.float())

            # criterion returns a per-sample mean, so scale back by batch size.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            acc_sum += compute_batch_accuracy(preds, labels) * batch_size
            n_samples += batch_size

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if n_samples == 0:
        return float('nan'), float('nan')
    return loss_sum / n_samples, acc_sum / n_samples


train_epoch_acc = []
train_epoch_losses = []
valid_epoch_losses = []
valid_epoch_acc = []
for epoch in range(N_EPOCHS):
    train_loss, train_acc = run_epoch(train_loader, train=True)
    valid_loss, valid_acc = run_epoch(valid_loader, train=False)

    train_epoch_losses.append(train_loss)
    train_epoch_acc.append(train_acc)
    valid_epoch_losses.append(valid_loss)
    valid_epoch_acc.append(valid_acc)

    print(f'Epoch: {epoch}, loss_train: {train_epoch_losses[-1]:.3f}, loss_valid: {valid_epoch_losses[-1]:.3f}')
    print(f'\t metrics_train: {train_epoch_acc[-1]:.3f}, metrics_valid: {valid_epoch_acc[-1]:.3f}')