In [1]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm

import utils

from data_utils import (
    WBCdataset,
    pRCCdataset
)
from models import (
    ViT,
    MAE
)

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def get_WBC_transform(is_train):
    data_transforms = []
    data_transforms.append(transforms.Resize((256, 256)))
    if is_train:
        data_transforms.append(transforms.RandomHorizontalFlip())
    data_transforms.append(transforms.ToTensor())
    data_transforms.append(transforms.Normalize([0.7049, 0.5392, 0.5885], [0.1626, 0.1902, 0.0974], inplace=True))
    return transforms.Compose(data_transforms)

def get_pRCC_transform():
    data_transforms = []
    data_transforms.append(transforms.RandomCrop((256, 256)))
    # data_transforms.append(transforms.RandomHorizontalFlip())
    data_transforms.append(transforms.ToTensor())
    data_transforms.append(transforms.Normalize([0.6843, 0.5012, 0.6436], [0.2148, 0.2623, 0.1969], inplace=True))
    return transforms.Compose(data_transforms)

def run(device, hps):
    pRCC_data = pRCCdataset(hps.pRCCdata.training_files, transform=get_pRCC_transform())
    train_data = WBCdataset(hps.WBCdata.training_files_1, hps.WBCdata.label_dict, transform=get_WBC_transform(True))
    valid_data = WBCdataset(hps.WBCdata.validation_files, hps.WBCdata.label_dict, transform=get_WBC_transform(False))
    
    vit = ViT(
        image_size = hps.WBCdata.image_size,
        patch_size = hps.WBCdata.patch_size,
        num_classes = hps.WBCdata.num_classes,
        **hps.ViTmodel
    ).to(device)
    
    mae = MAE(
        encoder = vit,
        **hps.MAEmodel
    ).to(device)
    
    pRCC_loader = DataLoader(dataset = pRCC_data, batch_size=hps.pretrain.batch_size, shuffle=True)
    
    # pretrain optimizer
    pt_optimizer = optim.Adam(mae.parameters(), lr=hps.pretrain.learning_rate)
    # pretrain scheduler
    pt_scheduler = StepLR(pt_optimizer, step_size=5, gamma=hps.pretrain.lr_decay)
    
    for epoch in range(hps.pretrain.epochs):
        pretrain(device, epoch, mae, pt_optimizer, pt_scheduler, pRCC_loader)
    
    train_loader = DataLoader(dataset = train_data, batch_size=hps.finetune.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset = valid_data, batch_size=hps.finetune.batch_size, shuffle=True)
    
    # loss function
    criterion = nn.CrossEntropyLoss()
    # finetune optimizer
    ft_optimizer = optim.Adam(vit.parameters(), lr=hps.finetune.learning_rate)
    # finetune scheduler
    ft_scheduler = StepLR(ft_optimizer, step_size=1, gamma=hps.finetune.lr_decay)
    
    for epoch in range(hps.finetune.epochs):
        train_and_evaluate(device, epoch, vit, criterion, ft_optimizer, ft_scheduler, [train_loader, valid_loader])


def pretrain(device, epoch, model, optimizer, scheduler, loader):
    epoch_loss = 0
    
    for data, *_ in tqdm(loader):
        data = data.to(device)
        
        recon_loss = model(data)
        
        optimizer.zero_grad()
        recon_loss.backward()
        optimizer.step()
        scheduler.step()
        
        epoch_loss += recon_loss.item() / len(loader)
        
    print(
        f"Pretrain Epoch : {epoch+1} - Reconstruction loss : {epoch_loss:.4f}\n"
    )
        

def train_and_evaluate(device, epoch, model, criterion, optimizer, scheduler, loaders):
    train_loader, valid_loader = loaders
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [2]:
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

run(device, hps)

100%|██████████| 23/23 [00:58<00:00,  2.56s/it]


Pretrain Epoch : 1 - Reconstruction loss : 1.0406



100%|██████████| 23/23 [00:57<00:00,  2.52s/it]


Pretrain Epoch : 2 - Reconstruction loss : 0.9727



100%|██████████| 23/23 [00:58<00:00,  2.56s/it]


Pretrain Epoch : 3 - Reconstruction loss : 0.9569



100%|██████████| 23/23 [01:00<00:00,  2.64s/it]


Pretrain Epoch : 4 - Reconstruction loss : 0.9526



100%|██████████| 23/23 [00:59<00:00,  2.60s/it]


Pretrain Epoch : 5 - Reconstruction loss : 0.9576



100%|██████████| 23/23 [00:58<00:00,  2.52s/it]


Pretrain Epoch : 6 - Reconstruction loss : 0.9551



100%|██████████| 23/23 [00:58<00:00,  2.54s/it]


Pretrain Epoch : 7 - Reconstruction loss : 0.9700



100%|██████████| 23/23 [00:58<00:00,  2.56s/it]


Pretrain Epoch : 8 - Reconstruction loss : 0.9541



100%|██████████| 23/23 [00:56<00:00,  2.48s/it]


Pretrain Epoch : 9 - Reconstruction loss : 0.9619



100%|██████████| 23/23 [00:58<00:00,  2.53s/it]


Pretrain Epoch : 10 - Reconstruction loss : 0.9617



100%|██████████| 23/23 [00:58<00:00,  2.53s/it]


Pretrain Epoch : 11 - Reconstruction loss : 0.9544



100%|██████████| 23/23 [00:57<00:00,  2.50s/it]


Pretrain Epoch : 12 - Reconstruction loss : 0.9509



100%|██████████| 23/23 [00:58<00:00,  2.55s/it]


Pretrain Epoch : 13 - Reconstruction loss : 0.9601



100%|██████████| 23/23 [00:59<00:00,  2.59s/it]


Pretrain Epoch : 14 - Reconstruction loss : 0.9537



100%|██████████| 23/23 [00:59<00:00,  2.59s/it]


Pretrain Epoch : 15 - Reconstruction loss : 0.9567



100%|██████████| 23/23 [01:00<00:00,  2.61s/it]


Pretrain Epoch : 16 - Reconstruction loss : 0.9559



100%|██████████| 23/23 [00:59<00:00,  2.60s/it]


Pretrain Epoch : 17 - Reconstruction loss : 0.9710



100%|██████████| 23/23 [00:59<00:00,  2.58s/it]


Pretrain Epoch : 18 - Reconstruction loss : 0.9485



100%|██████████| 23/23 [00:59<00:00,  2.57s/it]


Pretrain Epoch : 19 - Reconstruction loss : 0.9506



100%|██████████| 23/23 [01:00<00:00,  2.65s/it]


Pretrain Epoch : 20 - Reconstruction loss : 0.9614



100%|██████████| 2/2 [00:00<00:00,  2.60it/s]


Epoch : 1 - loss : 1.5688 - acc: 0.2925 - val_loss : 1.1891 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.59it/s]


Epoch : 2 - loss : 1.0156 - acc: 0.6979 - val_loss : 1.1292 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.85it/s]


Epoch : 3 - loss : 1.1918 - acc: 0.6181 - val_loss : 1.1108 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.80it/s]


Epoch : 4 - loss : 1.0097 - acc: 0.6181 - val_loss : 1.1016 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.80it/s]


Epoch : 5 - loss : 0.9418 - acc: 0.6580 - val_loss : 1.0938 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.50it/s]


Epoch : 6 - loss : 1.1078 - acc: 0.6380 - val_loss : 1.0928 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.75it/s]


Epoch : 7 - loss : 1.0413 - acc: 0.6181 - val_loss : 1.0928 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.79it/s]


Epoch : 8 - loss : 1.1384 - acc: 0.5582 - val_loss : 1.0891 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.79it/s]


Epoch : 9 - loss : 1.0392 - acc: 0.6181 - val_loss : 1.0904 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.65it/s]


Epoch : 10 - loss : 1.1091 - acc: 0.5781 - val_loss : 1.0892 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.72it/s]


Epoch : 11 - loss : 1.0143 - acc: 0.6181 - val_loss : 1.0874 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.81it/s]


Epoch : 12 - loss : 1.0474 - acc: 0.6380 - val_loss : 1.0926 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.80it/s]


Epoch : 13 - loss : 0.9608 - acc: 0.6580 - val_loss : 1.0888 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.75it/s]


Epoch : 14 - loss : 1.0384 - acc: 0.6181 - val_loss : 1.0890 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.69it/s]


Epoch : 15 - loss : 1.1136 - acc: 0.5781 - val_loss : 1.0891 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.78it/s]


Epoch : 16 - loss : 1.0750 - acc: 0.6181 - val_loss : 1.0867 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.68it/s]


Epoch : 17 - loss : 1.1120 - acc: 0.5582 - val_loss : 1.0904 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.56it/s]


Epoch : 18 - loss : 0.9712 - acc: 0.6380 - val_loss : 1.0882 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.61it/s]


Epoch : 19 - loss : 1.1407 - acc: 0.5981 - val_loss : 1.0910 - val_acc: 0.6128



100%|██████████| 2/2 [00:00<00:00,  2.57it/s]


Epoch : 20 - loss : 1.0888 - acc: 0.6380 - val_loss : 1.0896 - val_acc: 0.6128

