In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18, ResNet18_Weights
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import utils

torch.manual_seed(421)
# if gpu available else cpu
device = torch.device(0) if torch.cuda.is_available() else torch.device('cpu')
TRAIN_MEAN = [0.5036, 0.4719, 0.3897]
TRAIN_STD = [0.2623, 0.2577, 0.2671]
classes = ['butterfly','cat', 'chicken', 'cow', 'dog', 'elephant', 'horse', 'sheep', 'spider', 'squirrel']

## Przygotowanie danych

In [2]:
# zdefiniowanie operacji na kazdym obrazie w zbiorze
transform = transforms.Compose([
    transforms.RandomResizedCrop(256),  #  na 256x256
    transforms.RandomHorizontalFlip(),  # wycinamy losowy fragment 128x128
    transforms.ToTensor(),              # obrazy zamieniamy na tensory,
    # srednie i odchylenia po kanałach całęgo zbioru,
    #  wyliczone wczesniej za pomocą utils.data_normalize_values
    transforms.Normalize(TRAIN_MEAN, TRAIN_STD)
])

# loader danych z batchami
train_data = ImageFolder(root='dataset/train/', transform=transform)
test_data = ImageFolder(root='dataset/test/', transform=transform)
batch_size = 32
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)

## Uczenie klasyfikatora

In [3]:
# RESNET18 z wagami przetrenowanymi na zbiorze IMAGENET
res_net_model = resnet18(weights=None)
# res_net_model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
# dostosowanie ostatniego layeru do problemu
res_net_model.fc = nn.Linear(res_net_model.fc.in_features, len(classes))
# inicjalizacja wag w warstwie wyjsciowej
nn.init.xavier_uniform_(res_net_model.fc.weight)


res_net_model = res_net_model.to(device)
metrics = utils.train_fine_tuning(
    model = res_net_model, 
    learning_rate= 5e-4, 
    train_loader=train_loader,
    test_loader=test_loader,
    device=device, num_epochs=5, param_group=False
)

Progress: 1/5 epochs
Epoch: 1, Loss 1.942, Train acc: 0.318, Test acc: 0.383
Progress: 2/5 epochs
Epoch: 2, Loss 1.615, Train acc: 0.442, Test acc: 0.464
Progress: 3/5 epochs
Epoch: 3, Loss 1.465, Train acc: 0.495, Test acc: 0.466
Progress: 4/5 epochs
Epoch: 4, Loss 1.355, Train acc: 0.536, Test acc: 0.547
Progress: 5/5 epochs
Epoch: 5, Loss 1.261, Train acc: 0.566, Test acc: 0.502
