In [1]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm

import utils

from data_utils import (
    WBCdataset
)
from models import (
    ViT
)

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def get_WBC_transform(is_train):
    data_transforms = []
    data_transforms.append(transforms.Resize((256, 256)))
    if is_train:
        data_transforms.append(transforms.RandomHorizontalFlip())
    data_transforms.append(transforms.ToTensor())
    data_transforms.append(transforms.Normalize([0.7049, 0.5392, 0.5885], [0.1626, 0.1902, 0.0974], inplace=True))
    return transforms.Compose(data_transforms)

def get_pRCC_transform():
    data_transforms = []
    data_transforms.append(transforms.RandomCrop((256, 256)))
    data_transforms.append(transforms.RandomHorizontalFlip())
    data_transforms.append(transforms.ToTensor())
    data_transforms.append(transforms.Normalize([0.6843, 0.5012, 0.6436], [0.2148, 0.2623, 0.1969], inplace=True))
    return transforms.Compose(data_transforms)

def run(device, hps):
    train_data = WBCdataset(hps.WBCdata.training_files, hps.WBCdata.label_dict, transform=get_WBC_transform(True))
    valid_data = WBCdataset(hps.WBCdata.validation_files, hps.WBCdata.label_dict, transform=get_WBC_transform(False))
    
    train_loader = DataLoader(dataset = train_data, batch_size=hps.train.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset = valid_data, batch_size=hps.train.batch_size, shuffle=True)
    
    vit = ViT(
        image_size = hps.WBCdata.image_size,
        patch_size = hps.WBCdata.patch_size,
        num_classes = hps.WBCdata.num_classes,
        **hps.ViTmodel
    ).cuda()
    
    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(vit.parameters(), lr=hps.train.learning_rate)
    # scheduler
    scheduler = StepLR(optimizer, step_size=1, gamma=hps.train.lr_decay)
    
    for epoch in range(hps.train.epochs):
        train_and_evaluate(device, epoch, vit, criterion, optimizer, [train_loader, valid_loader])

def train_and_evaluate(device, epoch, model, criterion, optimizer, loaders):
    train_loader, valid_loader = loaders
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data, label = data.cuda(), label.cuda()

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.cuda()
            label = label.cuda()

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [2]:
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.train.seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

run(device, hps)

100%|██████████| 132/132 [01:42<00:00,  1.28it/s]


Epoch : 1 - loss : 0.9812 - acc: 0.6601 - val_loss : 0.4312 - val_acc: 0.8264



100%|██████████| 132/132 [01:43<00:00,  1.28it/s]


Epoch : 2 - loss : 0.3581 - acc: 0.8667 - val_loss : 0.3175 - val_acc: 0.8843



100%|██████████| 132/132 [01:44<00:00,  1.27it/s]


Epoch : 3 - loss : 0.2953 - acc: 0.8920 - val_loss : 0.2816 - val_acc: 0.9010



100%|██████████| 132/132 [01:43<00:00,  1.28it/s]


Epoch : 4 - loss : 0.2498 - acc: 0.9048 - val_loss : 0.2188 - val_acc: 0.9178



100%|██████████| 132/132 [01:43<00:00,  1.27it/s]


Epoch : 5 - loss : 0.2455 - acc: 0.9088 - val_loss : 0.2296 - val_acc: 0.9242



100%|██████████| 132/132 [01:43<00:00,  1.27it/s]


Epoch : 6 - loss : 0.2331 - acc: 0.9193 - val_loss : 0.1966 - val_acc: 0.9352



100%|██████████| 132/132 [01:43<00:00,  1.28it/s]


Epoch : 7 - loss : 0.2257 - acc: 0.9168 - val_loss : 0.2475 - val_acc: 0.9057



  4%|▍         | 5/132 [00:03<01:32,  1.37it/s]


KeyboardInterrupt: 