In [None]:
from utils.dataset import *
from utils.train import *

import torch
import torchvision
import torchvision.transforms as transforms

In [None]:
# Run configuration: everything a reader might want to tune lives here.

# Fixed seed for PyTorch's global RNG so that stochastic steps driven by it
# (shuffling, random augmentation) are repeatable across Restart & Run All.
# NOTE(review): GPU kernels may still be non-deterministic; see PyTorch's
# reproducibility notes if bit-exact runs are required.
seed = 42
torch.manual_seed(seed)

# Number of samples per mini-batch for both the train and validation loaders.
batch_size = 64

# Optimisation settings forwarded to `model.train(...)`.
epochs = 50
lr = 0.0001
weight_decay = 0

In [None]:
# Training pipeline: resize, then a random 224 crop and a random horizontal
# flip as augmentation, convert to a tensor and normalise each channel.
# The mean/std values are the ImageNet statistics documented by torchvision
# for its pretrained models.
# NOTE(review): whether `Baseline` actually loads pretrained weights is not
# visible here — confirm.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation pipeline: fully deterministic (center crop, no random flip) so
# that the validation loss and F1 are comparable between epochs and runs.
# Random augmentation belongs to the training pipeline only.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


# Datasets: one per split, each paired with its own preprocessing pipeline.
train_dataset = iMatDataset(
    data_type='train',
    transform=train_transform,
)
val_dataset = iMatDataset(
    data_type='val',
    transform=val_transform,
)

# Loaders: training batches are reshuffled every epoch; validation keeps the
# dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Build the baseline classifier; the number of output classes is taken from
# the training dataset's `n_labels` attribute.
# NOTE(review): `epoch_print=1` presumably controls how often progress is
# logged — confirm against `Baseline`.
model = Baseline(
    model='resnet34',
    num_classes=train_dataset.n_labels,
    epoch_print=1,
)

In [None]:
# Run training with the hyperparameters from the configuration cell; the
# validation loader is handed over alongside the training loader.
model.train(
    train_loader,
    val_loader,
    epochs=epochs,
    lr=lr,
    weight_decay=weight_decay,
)

In [None]:
# Loss curves recorded on the model during training (train vs. test).
# `label_fontsize` is also used by the F1 plot further down.
label_fontsize = 25

plt.figure(figsize=(20, 10))

# Keep the line handles so the legend lists exactly these two curves.
train_lossline, = plt.plot(model.train_losses, label='Train')
test_lossline, = plt.plot(model.test_losses, color='red', label='Test')

plt.xlabel('Step', fontsize=label_fontsize)
plt.ylabel('Loss', fontsize=label_fontsize)
plt.legend(
    handles=[train_lossline, test_lossline],
    fontsize=20,
)
plt.show()

In [None]:
# F1 curves recorded on the model during training (train vs. test).
# Reuses `label_fontsize` defined in the loss-plot cell.
plt.figure(figsize=(20, 10))

# Keep the line handles so the legend lists exactly these two curves.
train_f1line, = plt.plot(model.train_f1, label='Train')
test_f1line, = plt.plot(model.test_f1, color='red', label='Test')

plt.xlabel('Step', fontsize=label_fontsize)
plt.ylabel('F1', fontsize=label_fontsize)
plt.legend(
    handles=[train_f1line, test_f1line],
    fontsize=20,
)
plt.show()