In [None]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm

import utils

from data_utils import (
    WBCdataset
)
from models import (
    ViT
)

def seed_everything(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and CUDA), sets
    the ``PYTHONHASHSEED`` environment variable and switches cuDNN into
    deterministic mode.

    Args:
        seed: Seed value shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators all take the same seed.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # CUDA generators: the current device, then every device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    
def get_transform(is_train):
    """Build the image preprocessing pipeline for one dataset split.

    Args:
        is_train: When truthy, a random horizontal flip is inserted between
            the resize and the tensor conversion.

    Returns:
        A ``transforms.Compose`` that resizes to 512x512, optionally applies
        the random flip, and converts the image to a tensor.
    """
    # Augmentation is only wanted for the training split.
    augmentations = [transforms.RandomHorizontalFlip()] if is_train else []
    pipeline = [
        transforms.Resize((512, 512)),
        *augmentations,
        transforms.ToTensor(),
    ]
    return transforms.Compose(pipeline)

def run(device, hps):
    """Build the data loaders, model and optimiser, then train for all epochs.

    Args:
        device: Device the model is moved to and batches are processed on.
        hps: Hyper-parameter object. This function reads ``hps.WBCdata``
            (training/validation files, label dict, image size, patch size,
            number of classes), ``hps.train`` (batch size, learning rate,
            learning-rate decay, epoch count) and ``hps.ViTmodel`` (a mapping
            of extra keyword arguments forwarded to ``ViT``).
    """
    train_data = WBCdataset(
        hps.WBCdata.training_files,
        hps.WBCdata.label_dict,
        transform=get_transform(True),
    )
    valid_data = WBCdataset(
        hps.WBCdata.validation_files,
        hps.WBCdata.label_dict,
        transform=get_transform(False),
    )

    train_loader = DataLoader(dataset=train_data, batch_size=hps.train.batch_size, shuffle=True)
    # Validation does not need shuffling; a fixed order keeps the per-batch
    # averaged metrics comparable from one epoch to the next.
    valid_loader = DataLoader(dataset=valid_data, batch_size=hps.train.batch_size, shuffle=False)

    model = ViT(
        image_size=hps.WBCdata.image_size,
        patch_size=hps.WBCdata.patch_size,
        num_classes=hps.WBCdata.num_classes,
        **hps.ViTmodel
    ).to(device)

    # loss function
    criterion = nn.CrossEntropyLoss()
    # optimizer
    optimizer = optim.Adam(model.parameters(), lr=hps.train.learning_rate)
    # scheduler: multiplies the learning rate by lr_decay after every epoch
    scheduler = StepLR(optimizer, step_size=1, gamma=hps.train.lr_decay)

    for epoch in range(hps.train.epochs):
        train_and_evaluate(device, epoch, model, criterion, optimizer, [train_loader, valid_loader])
        # The scheduler only takes effect when stepped; do it once per epoch,
        # after that epoch's optimizer updates.
        scheduler.step()

def train_and_evaluate(device, epoch, model, criterion, optimizer, loaders):
    """Run one training epoch, then one validation pass, and print the metrics.

    The model is switched to training mode for the first loop and to
    evaluation mode for the second; it is left in evaluation mode on return.
    Reported loss and accuracy are the mean of the per-batch values.

    Args:
        device: Device each batch is moved to before the forward pass.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        model: Module being trained; its parameters are updated through
            ``optimizer``.
        criterion: Loss callable invoked as ``criterion(output, label)``.
        optimizer: Optimizer stepped once per training batch.
        loaders: Two-item sequence ``[train_loader, valid_loader]``. Each
            loader yields ``(data, label)`` batches and supports ``len()``.
    """
    train_loader, valid_loader = loaders

    # Enable training behaviour of layers such as dropout / batch norm.
    model.train()
    num_train_batches = len(train_loader)
    epoch_loss = 0.0
    epoch_accuracy = 0.0

    for data, label in tqdm(train_loader):
        data, label = data.to(device), label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        # Accumulate plain floats so the running totals do not hold on to
        # tensors (and, for the loss, its autograd history).
        epoch_accuracy += acc.item() / num_train_batches
        epoch_loss += loss.item() / num_train_batches

    # Switch layers such as dropout / batch norm to inference behaviour.
    model.eval()
    num_valid_batches = len(valid_loader)
    epoch_val_accuracy = 0.0
    epoch_val_loss = 0.0
    with torch.no_grad():
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc.item() / num_valid_batches
            epoch_val_loss += val_loss.item() / num_valid_batches

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [None]:
# Load the hyper-parameters and seed all RNGs before anything else is built.
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.train.seed)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

run(device, hps)