In [1]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm

import utils

from data_utils import (
    CS4243dataset
)
from models import (
    ViT
)

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def get_transform(is_train):
    data_transforms = []
    data_transforms.append(transforms.Resize((512, 512)))
    if is_train:
        data_transforms.append(transforms.RandomHorizontalFlip())
    data_transforms.append(transforms.ToTensor())
    return transforms.Compose(data_transforms)

def run(device, hps):
    train_data = CS4243dataset(hps.CS4243dataset.dataset_path, hps.CS4243dataset.label_dict, is_train=True, transform=get_transform(True))
    valid_data = CS4243dataset(hps.CS4243dataset.dataset_path, hps.CS4243dataset.label_dict, is_train=False, transform=get_transform(False))
    
    train_loader = DataLoader(dataset = train_data, batch_size=hps.train.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset = valid_data, batch_size=hps.train.batch_size, shuffle=False)
    
    vit = ViT(
        image_size = hps.CS4243dataset.image_size,
        patch_size = hps.CS4243dataset.patch_size,
        num_classes = hps.CS4243dataset.num_classes,
        **hps.ViTmodel
    ).to(device)
    
    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(vit.parameters(), lr=hps.train.learning_rate)
    # scheduler
    scheduler = StepLR(optimizer, step_size=1, gamma=hps.train.lr_decay)
    
    for epoch in range(hps.train.epochs):
        train_and_evaluate(device, hps, epoch, vit, criterion, optimizer, [train_loader, valid_loader])

def train_and_evaluate(device, hps, epoch, model, criterion, optimizer, loaders):
    train_loader, valid_loader = loaders
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)
    
    if (epoch+1) % 2 == 0:
        epoch_val_accuracy, epoch_val_loss = evaluate(device, model, criterion, valid_loader)
        utils.save_checkpoint(model, optimizer, epoch, os.path.join(hps.save.model_save_path, "model_{}.pth".format(epoch)))

        print(
            f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
        )
    else:
        print(
            f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f}\n"
        )

def evaluate(device, model, criterion, valid_loader):
    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data, label = data.to(device), label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
    model.train()
    return epoch_val_accuracy, epoch_val_loss

In [2]:
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.train.seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

run(device, hps)

100%|██████████| 56/56 [03:24<00:00,  3.65s/it]


Epoch : 1 - loss : 0.8439 - acc: 0.5878



100%|██████████| 56/56 [03:23<00:00,  3.64s/it]


Epoch : 2 - loss : 0.5870 - acc: 0.6857 - val_loss : 0.6475 - val_acc: 0.6262



100%|██████████| 56/56 [03:24<00:00,  3.66s/it]


Epoch : 3 - loss : 0.4548 - acc: 0.7865



100%|██████████| 56/56 [03:22<00:00,  3.62s/it]


Epoch : 4 - loss : 0.3659 - acc: 0.8371 - val_loss : 0.8126 - val_acc: 0.5641



100%|██████████| 56/56 [03:22<00:00,  3.61s/it]


Epoch : 5 - loss : 0.2784 - acc: 0.8882



100%|██████████| 56/56 [03:21<00:00,  3.60s/it]


Epoch : 6 - loss : 0.2536 - acc: 0.8973 - val_loss : 0.9501 - val_acc: 0.5780



100%|██████████| 56/56 [03:23<00:00,  3.63s/it]


Epoch : 7 - loss : 0.1931 - acc: 0.9226



100%|██████████| 56/56 [03:22<00:00,  3.62s/it]


Epoch : 8 - loss : 0.1767 - acc: 0.9272 - val_loss : 1.2290 - val_acc: 0.6293



100%|██████████| 56/56 [03:27<00:00,  3.70s/it]


Epoch : 9 - loss : 0.1456 - acc: 0.9432



100%|██████████| 56/56 [03:24<00:00,  3.65s/it]


Epoch : 10 - loss : 0.1317 - acc: 0.9523 - val_loss : 1.3329 - val_acc: 0.6030



  9%|▉         | 5/56 [00:20<03:28,  4.09s/it]


KeyboardInterrupt: 