In [3]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

import utils

from data_utils import WBCdataset

from transformers import ViTForImageClassification, ViTMAEConfig

from torch.utils.tensorboard import SummaryWriter

# Three-element mean / std vectors handed to transforms.Normalize in
# get_WBC_transform (presumably per-channel RGB statistics of the WBC
# images -- TODO confirm how they were computed).
WBC_mean = np.array((0.7049, 0.5392, 0.5885))
WBC_std = np.array((0.1626, 0.1902, 0.0974))

def seed_everything(seed):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU and CUDA generators, exports ``PYTHONHASHSEED`` and switches cuDNN to
    deterministic, non-autotuned operation.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Exported for any child process; the hash seed of the interpreter that is
    # already running is fixed at start-up and is not changed by this line.
    os.environ['PYTHONHASHSEED'] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick kernels by timing them, which may differ
    # from run to run; turn it off in case it was enabled earlier in the kernel.
    torch.backends.cudnn.benchmark = False
    
def get_WBC_transform(is_train):
    """Build the torchvision preprocessing pipeline for WBC images.

    The pipeline resizes to 224x224, converts to a tensor and normalises with
    ``WBC_mean`` / ``WBC_std`` (in place). A random horizontal flip is inserted
    after the resize only when ``is_train`` is truthy.

    Args:
        is_train: Whether to include the random-flip augmentation.

    Returns:
        A ``transforms.Compose`` wrapping the steps described above.
    """
    augmentations = [transforms.RandomHorizontalFlip()] if is_train else []
    pipeline = [
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(WBC_mean, WBC_std, inplace=True),
    ]
    return transforms.Compose(pipeline)

def run(device, hps, wbc_subset="wbc50", pretrain_options="pRCC", use_mask=True):
    """Fine-tune a ViT image classifier on a WBC training subset.

    Args:
        device: Device the model and batches are placed on.
        hps: Hyper-parameter object. This function reads ``hps.out_dir``,
            ``hps.WBCdata`` (``training_files_1/10/50/100``,
            ``validation_files``, ``label_dict``) and ``hps.finetune``
            (``batch_size``, ``learning_rate``, ``lr_decay``, ``epochs``).
        wbc_subset: ``"wbc1"``, ``"wbc10"`` or ``"wbc50"`` select the matching
            ``training_files_*`` entry; any other value selects
            ``training_files_100``.
        pretrain_options: ``"pRCC"`` or ``"facebook"`` load the corresponding
            Hub checkpoint; any other value builds the model from the
            ``facebook/vit-mae-base`` configuration without loading weights.
        use_mask: When true, logs go to a ``mask`` sub-directory. Inside this
            function the flag only affects the output path.

    Returns:
        The fine-tuned model.
    """
    out_dir = os.path.join(hps.out_dir, f'{wbc_subset}', f'{pretrain_options}')
    if use_mask:
        out_dir = os.path.join(out_dir, 'mask')
    os.makedirs(out_dir, exist_ok=True)

    # Resolve the file list lazily so that only the requested entry of the
    # configuration has to exist.
    subset_attr = {
        "wbc1": "training_files_1",
        "wbc10": "training_files_10",
        "wbc50": "training_files_50",
    }.get(wbc_subset, "training_files_100")
    training_files = getattr(hps.WBCdata, subset_attr)

    # Train on the subset that was actually requested. Previously the 10%
    # list was hard-coded here, silently ignoring ``wbc_subset``.
    train_data = WBCdataset(training_files, hps.WBCdata.label_dict, transform=get_WBC_transform(True))
    valid_data = WBCdataset(hps.WBCdata.validation_files, hps.WBCdata.label_dict, transform=get_WBC_transform(False))

    label_dict = hps.WBCdata.label_dict
    label2id = {name: label_dict[name] for name in label_dict.keys()}
    id2label = {idx: name for name, idx in label2id.items()}

    hub_checkpoints = {
        "pRCC": "Mo0310/vitmae_pRCC_80epochs",
        "facebook": "facebook/vit-mae-base",
    }
    if pretrain_options in hub_checkpoints:
        model = ViTForImageClassification.from_pretrained(
            hub_checkpoints[pretrain_options],
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes=True,
        ).to(device)
    else:
        # NOTE(review): this branch feeds a ViTMAEConfig to
        # ViTForImageClassification and forwards ignore_mismatched_sizes to the
        # config loader; kept as in the original -- confirm this is intended.
        config = ViTMAEConfig.from_pretrained("facebook/vit-mae-base",
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes = True)
        model = ViTForImageClassification(config).to(device)

    train_loader = DataLoader(dataset=train_data, batch_size=hps.finetune.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset=valid_data, batch_size=hps.finetune.batch_size, shuffle=False)

    criterion = nn.CrossEntropyLoss()
    ft_optimizer = optim.AdamW(model.parameters(), lr=hps.finetune.learning_rate)
    # Multiply the learning rate by lr_decay every 5 epochs.
    ft_scheduler = StepLR(ft_optimizer, step_size=5, gamma=hps.finetune.lr_decay)

    writer = SummaryWriter(out_dir)
    try:
        for epoch in range(hps.finetune.epochs):
            train_and_evaluate(device, epoch, model, criterion, ft_optimizer, ft_scheduler, [train_loader, valid_loader], writer)
    finally:
        # Flush pending events and release the log file even if training fails.
        writer.close()

    return model
    

def train_and_evaluate(device, epoch, model, criterion, optimizer, scheduler, loaders, writer):
    """Run one training epoch and one validation pass, then log the metrics.

    Args:
        device: Device each batch is moved to before the forward pass.
        epoch: Zero-based epoch index; logged and printed as ``epoch + 1``.
        model: Module whose forward output exposes ``.logits``.
        criterion: Loss called as ``criterion(logits, label)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per call, after the
            validation pass.
        loaders: ``[train_loader, valid_loader]``; each must support ``len()``
            and yield ``(data, label)`` pairs.
        writer: Object providing ``add_scalar(tag, value, step)``.

    Loss and accuracy are unweighted means over batches, so a smaller final
    batch carries the same weight as a full one. The model is left in training
    mode on return.
    """
    train_loader, valid_loader = loaders
    n_train_batches = len(train_loader)
    n_valid_batches = len(valid_loader)

    # Models returned by from_pretrained start in eval mode; make sure the
    # first epoch trains with the same module behaviour as the later ones.
    model.train()

    epoch_loss = 0.0
    epoch_accuracy = 0.0
    for data, label in tqdm(train_loader):
        optimizer.zero_grad()

        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output.logits, label)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        acc = (output.logits.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / n_train_batches
        # Detach before accumulating so the running sum does not keep every
        # batch's autograd graph alive until the end of the epoch.
        epoch_loss += loss.detach() / n_train_batches

    epoch_val_loss = 0.0
    epoch_val_accuracy = 0.0
    model.eval()
    with torch.no_grad():
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output.logits, label)

            acc = (val_output.logits.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / n_valid_batches
            epoch_val_loss += val_loss / n_valid_batches
    model.train()

    scheduler.step()

    # Log plain Python floats rather than tensors.
    epoch_loss = float(epoch_loss)
    epoch_accuracy = float(epoch_accuracy)
    epoch_val_loss = float(epoch_val_loss)
    epoch_val_accuracy = float(epoch_val_accuracy)

    writer.add_scalar('./Loss/train', epoch_loss, epoch+1)
    writer.add_scalar('./ACC/train', epoch_accuracy, epoch+1)
    writer.add_scalar('./Loss/val', epoch_val_loss, epoch+1)
    writer.add_scalar('./ACC/val', epoch_val_accuracy, epoch+1)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

In [None]:
# Load the experiment configuration and seed the RNGs before anything else.
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.seed)

# Use the first CUDA device when one is available, otherwise the CPU.
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Fine-tune; the returned model is used by the following cells.
model = run(device, hps)

In [None]:
# Read the Hugging Face access token from the environment instead of storing
# it in the notebook, where it ends up in version control and shared copies.
# NOTE(review): the token previously hard-coded in this cell must be treated
# as leaked and revoked in the Hugging Face account settings.
token = os.environ.get("HF_TOKEN")
if not token:
    raise RuntimeError("Set the HF_TOKEN environment variable before pushing to the Hub.")
model.push_to_hub("5242_w_pRCC_wbc50_mask", token=token)

In [None]:
# Print the module structure of the fine-tuned model returned by run().
print(model)