In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
from torchvision import models
import torchvision.transforms as transforms

from torch.optim import lr_scheduler

from torch.utils.data import random_split, DataLoader
from dataset_handlers.vgg16.vgg16_feature_dataset import FeatureDataset

from sklearn.model_selection import KFold

import os
from PIL import Image
from matplotlib import pyplot as plt

In [48]:
vit_b = models.vit_b_16(pretrained=True)

In [49]:
vit_b.heads = nn.Sequential(
    nn.Linear(768, 2)
)

In [50]:
for param in vit_b.conv_proj.parameters():
    param.requires_grad = True

for param in vit_b.encoder.parameters():
    param.requires_grad = False

In [51]:
for param in vit_b.heads.parameters():
    param.requires_grad = True

In [52]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [53]:
def accuracy(model, data_loader, transform):
    acc = 0
    for i, (image, label) in enumerate(data_loader, 1):
        image = image.to(device)
        label = label.to(device)

        image = transform(image)

        output = model(image).reshape(-1, 2)
        acc += (torch.argmax(output, dim=1) == label).sum().item()

    return acc / len(data_loader.dataset)

In [54]:
def train_model(num_epochs, model, criterion, optimizer, acc_training_set, acc_val_set, l1_factor, train_loader, val_loader):
    for epoch in range(num_epochs):
        for i, (image, label) in enumerate(train_loader, 1):
            image = image.to(device)
            label = label.to(device)

            output = model(image)

            loss = criterion(output, label)

            l1_regularization = torch.tensor(0., requires_grad=False)
            for param in model.parameters():
                l1_regularization += torch.norm(param, 1)

            l1_regularization.requires_grad_(True)
            loss += l1_factor * l1_regularization

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 1 == 0:
                print('Epoch: {:2.0f}/{}, Batch: {:3.0f}, Loss: {:.6f}'
                      .format(epoch+1, num_epochs, i, loss.item()))

In [58]:
epochs = 5
l1_factor = 0.0001
l2_factor = 0.01

In [59]:
critereon = nn.CrossEntropyLoss()
optimizer = optim.AdamW(vit_b.parameters(), lr=0.0001, weight_decay=l2_factor)

In [63]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

In [64]:
dataset = torchvision.datasets.ImageFolder(
    root='data/splitted/train',
    transform=transform
)

In [66]:
acc_training_set = []
acc_val_set = []

for fold in range(1):
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=True)

    train_model(epochs, vit_b, critereon, optimizer, acc_training_set, acc_val_set, l1_factor, train_loader, val_loader)

Epoch:  1/5, Batch:   1, Loss: 182.879929
Epoch:  1/5, Batch:   2, Loss: 182.923660
Epoch:  1/5, Batch:   3, Loss: 183.078873
Epoch:  1/5, Batch:   4, Loss: 182.918396
Epoch:  1/5, Batch:   5, Loss: 182.949097
Epoch:  1/5, Batch:   6, Loss: 182.955841
Epoch:  1/5, Batch:   7, Loss: 182.913269
Epoch:  1/5, Batch:   8, Loss: 182.919449
Epoch:  1/5, Batch:   9, Loss: 182.944977
Epoch:  1/5, Batch:  10, Loss: 182.926666
Epoch:  1/5, Batch:  11, Loss: 182.929642
Epoch:  1/5, Batch:  12, Loss: 182.932648
Epoch:  1/5, Batch:  13, Loss: 182.921082
Epoch:  1/5, Batch:  14, Loss: 182.902588
Epoch:  1/5, Batch:  15, Loss: 182.933762
Epoch:  1/5, Batch:  16, Loss: 182.917984
Epoch:  1/5, Batch:  17, Loss: 182.903412
Epoch:  1/5, Batch:  18, Loss: 182.932938
Epoch:  1/5, Batch:  19, Loss: 182.937546
Epoch:  1/5, Batch:  20, Loss: 182.952896
Epoch:  1/5, Batch:  21, Loss: 182.963013
Epoch:  1/5, Batch:  22, Loss: 182.984543
Epoch:  1/5, Batch:  23, Loss: 182.883240
Epoch:  1/5, Batch:  24, Loss: 182

In [44]:
test_dataset = torchvision.datasets.ImageFolder(
    root='data/splitted/test',
    transform=transform
)

test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=True)