In [1]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm
from einops import rearrange

import utils

from data_utils import (
    WBCdataset,
    pRCCdataset
)
from models import (
    ViT,
    MAE
)

# Normalisation statistics for the WBC images. They are handed to
# transforms.Normalize as (mean, std) by get_WBC_transform; three values each,
# presumably one per colour channel.
WBC_mean = np.array((0.7049, 0.5392, 0.5885), dtype=np.float64)
WBC_std = np.array((0.1626, 0.1902, 0.0974), dtype=np.float64)

# Same statistics for the pRCC images; consumed by get_pRCC_transform and by
# run_one_image when it displays a sample.
pRCC_mean = np.array((0.6843, 0.5012, 0.6436), dtype=np.float64)
pRCC_std = np.array((0.2148, 0.2623, 0.1969), dtype=np.float64)

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and configures
    cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Hash randomisation of the running interpreter is fixed at start-up, so
    # this value only reaches processes spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick a different convolution algorithm from run
    # to run; switch it off so that algorithm selection is repeatable too.
    torch.backends.cudnn.benchmark = False
    
def get_WBC_transform(is_train):
    """Build the torchvision preprocessing pipeline for WBC images.

    The pipeline resizes to 256x256, optionally applies a random horizontal
    flip, converts to a tensor and normalises in place with the WBC statistics.

    Args:
        is_train: When truthy, the random horizontal flip is included.

    Returns:
        A ``transforms.Compose`` wrapping the steps above, in that order.
    """
    # The flip is the only augmentation and is applied to training data only.
    augmentation = [transforms.RandomHorizontalFlip()] if is_train else []
    pipeline = [
        transforms.Resize((256, 256)),
        *augmentation,
        transforms.ToTensor(),
        transforms.Normalize(WBC_mean, WBC_std, inplace=True),
    ]
    return transforms.Compose(pipeline)

def get_pRCC_transform():
    """Build the torchvision preprocessing pipeline for pRCC images.

    Returns:
        A ``transforms.Compose`` that takes a random 256x256 crop, converts it
        to a tensor and normalises it in place with the pRCC statistics.
    """
    # A random horizontal flip is deliberately left out of this pipeline
    # (it was disabled in the original code).
    return transforms.Compose([
        transforms.RandomCrop((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(pRCC_mean, pRCC_std, inplace=True),
    ])

def run(device, hps):
    """Pre-train an MAE on pRCC images, then fine-tune its ViT encoder on WBC.

    Args:
        device: Torch device that both models and every batch are moved to.
        hps: Hyper-parameter object exposing the ``pRCCdata``, ``WBCdata``,
            ``ViTmodel``, ``MAEmodel``, ``pretrain`` and ``finetune`` sections
            read below.
    """
    # --- datasets -----------------------------------------------------------
    pretrain_set = pRCCdataset(hps.pRCCdata.training_files, transform=get_pRCC_transform())

    wbc_cfg = hps.WBCdata
    finetune_set = WBCdataset(
        wbc_cfg.training_files_1, wbc_cfg.label_dict, transform=get_WBC_transform(True)
    )
    validation_set = WBCdataset(
        wbc_cfg.validation_files, wbc_cfg.label_dict, transform=get_WBC_transform(False)
    )

    # --- models -------------------------------------------------------------
    encoder = ViT(
        image_size=wbc_cfg.image_size,
        patch_size=wbc_cfg.patch_size,
        num_classes=wbc_cfg.num_classes,
        **hps.ViTmodel
    ).to(device)

    # The MAE receives the very same ViT instance that is fine-tuned below.
    autoencoder = MAE(encoder=encoder, **hps.MAEmodel).to(device)

    # --- stage 1: reconstruction pre-training on pRCC ------------------------
    pre_cfg = hps.pretrain
    pretrain_loader = DataLoader(dataset=pretrain_set, batch_size=pre_cfg.batch_size, shuffle=True)
    pretrain_optimizer = optim.Adam(autoencoder.parameters(), lr=pre_cfg.learning_rate)
    pretrain_scheduler = StepLR(pretrain_optimizer, step_size=5, gamma=pre_cfg.lr_decay)

    for epoch_idx in range(pre_cfg.epochs):
        pretrain(device, epoch_idx, autoencoder, pretrain_optimizer, pretrain_scheduler, pretrain_loader)

    # --- stage 2: supervised fine-tuning on WBC ------------------------------
    ft_cfg = hps.finetune
    finetune_loader = DataLoader(dataset=finetune_set, batch_size=ft_cfg.batch_size, shuffle=True)
    validation_loader = DataLoader(dataset=validation_set, batch_size=ft_cfg.batch_size, shuffle=True)

    classification_loss = nn.CrossEntropyLoss()
    finetune_optimizer = optim.Adam(encoder.parameters(), lr=ft_cfg.learning_rate)
    finetune_scheduler = StepLR(finetune_optimizer, step_size=1, gamma=ft_cfg.lr_decay)

    for epoch_idx in range(ft_cfg.epochs):
        train_and_evaluate(
            device,
            epoch_idx,
            encoder,
            classification_loss,
            finetune_optimizer,
            finetune_scheduler,
            [finetune_loader, validation_loader],
        )


        
def pretrain(device, epoch, model, optimizer, scheduler, loader):
    """Run one reconstruction pre-training epoch and show a sample reconstruction.

    Args:
        device: Device every batch is moved to before the forward pass.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        model: Module whose call returns ``(reconstruction_loss, extra)``;
            only the loss is used for the update.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler, also stepped once per batch.
        loader: Iterable of batches whose first element is the image tensor;
            any further elements are ignored.

    An empty ``loader`` is tolerated: no update and no visualisation happen
    and the reported loss is 0.
    """
    loss_sum = 0.0
    num_batches = 0
    last_batch = None

    for data, *_ in tqdm(loader):
        data = data.to(device)

        recon_loss, _ = model(data)

        optimizer.zero_grad()
        recon_loss.backward()
        optimizer.step()
        # NOTE(review): the scheduler advances once per *batch*. StepLR counts
        # scheduler.step() calls, so with step_size=5 the rate decays every five
        # batches, not every five epochs. Confirm this is intended.
        scheduler.step()

        loss_sum += recon_loss.item()
        num_batches += 1
        last_batch = data

    # Mean loss over the batches actually seen (0 when the loader was empty).
    epoch_loss = loss_sum / num_batches if num_batches else 0.0

    # Visualise the first image of the final batch; skipped when there was no
    # batch at all, where the old code hit an unbound `data`.
    if last_batch is not None:
        img = last_batch[0].cpu()
        model.eval()
        try:
            run_one_image(device, img, model, set_name='pRCC')
        finally:
            # Always hand the model back in training mode, even if the
            # visualisation fails.
            model.train()

    print(
        f"Pretrain Epoch : {epoch+1} - Reconstruction loss : {epoch_loss:.4f}\n"
    )

@torch.no_grad()
def run_one_image(device, img, model, set_name='pRCC'):
    """Display one image next to the model's reconstruction of it.

    Args:
        device: Device the image is moved to for the forward pass.
        img: Single image tensor laid out as (channels, height, width).
        model: Module in eval mode whose call returns
            ``(loss, reconstruction)``, the reconstruction being batched
            (batch, channels, height, width).
        set_name: Dataset whose normalisation statistics are passed to
            ``utils.show_image``. Only ``'pRCC'`` is supported.

    Raises:
        AssertionError: If ``model`` is still in training mode.
        NotImplementedError: If ``set_name`` is not ``'pRCC'``.
    """
    # `model.train` is a bound method and is never `False`; the mode flag that
    # `model.eval()` clears is `model.training`.
    assert not model.training, "run_one_image expects the model to be in eval mode"

    if set_name == 'pRCC':
        std = pRCC_std
        mean = pRCC_mean
    else:
        # Raising a plain string is itself a TypeError in Python 3.
        raise NotImplementedError(f"No display statistics for set_name={set_name!r}.")

    # show_image receives channel-last images together with the statistics.
    utils.show_image(rearrange(img, 'c h w -> h w c'), std, mean, "input")

    batch = img.unsqueeze(dim=0).to(device)
    _, recon = model(batch)  # the loss is not needed for visualisation
    recon = recon.detach().cpu()[0]

    utils.show_image(rearrange(recon, 'c h w -> h w c'), std, mean, "reconstructed")
    

def train_and_evaluate(device, epoch, model, criterion, optimizer, scheduler, loaders):
    """Train for one epoch, evaluate on the validation loader and print metrics.

    Args:
        device: Device every batch is moved to.
        epoch: Zero-based epoch index; printed as ``epoch + 1``.
        model: Classifier called as ``model(data)``; predictions are taken
            with ``argmax(dim=1)``.
        criterion: Loss called as ``criterion(output, label)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler, also stepped once per batch.
        loaders: ``[train_loader, valid_loader]``, each yielding
            ``(data, label)`` batches.

    Metrics are unweighted means over batches. The model is left in training
    mode when the function returns.
    """
    train_loader, valid_loader = loaders

    train_loss_sum = 0.0
    train_acc_sum = 0.0
    train_batches = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # NOTE(review): stepped once per batch, so StepLR decays the rate per
        # batch rather than per epoch. Confirm this is intended.
        scheduler.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        # .item() turns the metrics into plain floats, so the running totals do
        # not keep every batch's autograd history alive.
        train_acc_sum += acc.item()
        train_loss_sum += loss.item()
        train_batches += 1

    val_loss_sum = 0.0
    val_acc_sum = 0.0
    val_batches = 0

    model.eval()
    try:
        with torch.no_grad():
            for data, label in valid_loader:
                data = data.to(device)
                label = label.to(device)

                val_output = model(data)
                val_loss = criterion(val_output, label)

                acc = (val_output.argmax(dim=1) == label).float().mean()
                val_acc_sum += acc.item()
                val_loss_sum += val_loss.item()
                val_batches += 1
    finally:
        # Restore training mode even if validation fails part-way.
        model.train()

    # max(..., 1) keeps an empty loader at 0 instead of dividing by zero.
    epoch_loss = train_loss_sum / max(train_batches, 1)
    epoch_accuracy = train_acc_sum / max(train_batches, 1)
    epoch_val_loss = val_loss_sum / max(val_batches, 1)
    epoch_val_accuracy = val_acc_sum / max(val_batches, 1)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [None]:
# Load the experiment configuration and make the run repeatable.
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.seed)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Pre-train on pRCC, then fine-tune on WBC.
run(device, hps)

  9%|▊         | 2/23 [00:35<06:13, 17.78s/it]