In [1]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm
from einops import rearrange
import matplotlib.pyplot as plt

import utils

from data_utils import WBCdataset

from transformers import ViTForImageClassification

WBC_mean = np.array([0.7049, 0.5392, 0.5885])
WBC_std = np.array([0.1626, 0.1902, 0.0974])

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def get_WBC_transform(is_train):
    data_transforms = []
    data_transforms.append(transforms.Resize((224, 224)))
    if is_train:
        data_transforms.append(transforms.RandomHorizontalFlip())
    data_transforms.append(transforms.ToTensor())
    data_transforms.append(transforms.Normalize(WBC_mean, WBC_std, inplace=True))
    return transforms.Compose(data_transforms)

def run(device, hps):
    train_data = WBCdataset(hps.WBCdata.training_files_10, hps.WBCdata.label_dict, transform=get_WBC_transform(True))
    valid_data = WBCdataset(hps.WBCdata.validation_files, hps.WBCdata.label_dict, transform=get_WBC_transform(False))
    
    label2id = {}
    id2label = {}

    for label in hps.WBCdata.label_dict.keys():
        label2id[label] = hps.WBCdata.label_dict[label]
        id2label[hps.WBCdata.label_dict[label]] = label
    
    model = ViTForImageClassification.from_pretrained("Mo0310/vitmae_pRCC_80epochs", 
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes = True,
     ).to(device)
    
    train_loader = DataLoader(dataset = train_data, batch_size=hps.finetune.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset = valid_data, batch_size=hps.finetune.batch_size, shuffle=False)
    
    # loss function
    criterion = nn.CrossEntropyLoss()
    # finetune optimizer
    ft_optimizer = optim.Adam(model.parameters(), lr=hps.finetune.learning_rate)
    # finetune scheduler
    ft_scheduler = StepLR(ft_optimizer, step_size=10, gamma=hps.finetune.lr_decay)
    
    for epoch in range(hps.finetune.epochs):
        train_and_evaluate(device, epoch, model, criterion, ft_optimizer, ft_scheduler, [train_loader, valid_loader])
        
    return model
    

def train_and_evaluate(device, epoch, model, criterion, optimizer, scheduler, loaders):
    train_loader, valid_loader = loaders
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        optimizer.zero_grad()
        
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output.logits, label)
        
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()
        scheduler.step()

        acc = (output.logits.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        model.eval()
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output.logits, label)

            acc = (val_output.logits.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
        model.train()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
hps = utils.get_hparams_from_file('./configs/base.json')
seed_everything(hps.seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = run(device, hps)

You are using a model of type vit_mae to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at Mo0310/vitmae_pRCC_80epochs and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 14/14 [00:13<00:00,  1.02it/s]


Epoch : 1 - loss : 0.8815 - acc: 0.6386 - val_loss : 0.8786 - val_acc: 0.7187



100%|██████████| 14/14 [00:12<00:00,  1.14it/s]


Epoch : 2 - loss : 0.3162 - acc: 0.8897 - val_loss : 0.2308 - val_acc: 0.9329



100%|██████████| 14/14 [00:12<00:00,  1.09it/s]


Epoch : 3 - loss : 0.1111 - acc: 0.9699 - val_loss : 0.1520 - val_acc: 0.9525



100%|██████████| 14/14 [00:12<00:00,  1.08it/s]


Epoch : 4 - loss : 0.0476 - acc: 0.9888 - val_loss : 0.1331 - val_acc: 0.9572



100%|██████████| 14/14 [00:13<00:00,  1.03it/s]


Epoch : 5 - loss : 0.0228 - acc: 0.9967 - val_loss : 0.1159 - val_acc: 0.9664



100%|██████████| 14/14 [00:13<00:00,  1.04it/s]


Epoch : 6 - loss : 0.0131 - acc: 0.9989 - val_loss : 0.1143 - val_acc: 0.9676



100%|██████████| 14/14 [00:13<00:00,  1.03it/s]


Epoch : 7 - loss : 0.0077 - acc: 1.0000 - val_loss : 0.1109 - val_acc: 0.9676



100%|██████████| 14/14 [00:12<00:00,  1.08it/s]


Epoch : 8 - loss : 0.0071 - acc: 0.9989 - val_loss : 0.1131 - val_acc: 0.9670



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 9 - loss : 0.0066 - acc: 1.0000 - val_loss : 0.1104 - val_acc: 0.9676



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 10 - loss : 0.0057 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 11 - loss : 0.0054 - acc: 1.0000 - val_loss : 0.1098 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.05it/s]


Epoch : 12 - loss : 0.0057 - acc: 1.0000 - val_loss : 0.1099 - val_acc: 0.9682



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 13 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 14 - loss : 0.0056 - acc: 1.0000 - val_loss : 0.1099 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.04it/s]


Epoch : 15 - loss : 0.0056 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:12<00:00,  1.09it/s]


Epoch : 16 - loss : 0.0053 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 17 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:12<00:00,  1.09it/s]


Epoch : 18 - loss : 0.0052 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:12<00:00,  1.09it/s]


Epoch : 19 - loss : 0.0056 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.08it/s]


Epoch : 20 - loss : 0.0056 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 21 - loss : 0.0054 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:12<00:00,  1.09it/s]


Epoch : 22 - loss : 0.0053 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 23 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 24 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.05it/s]


Epoch : 25 - loss : 0.0061 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.04it/s]


Epoch : 26 - loss : 0.0053 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 27 - loss : 0.0053 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.05it/s]


Epoch : 28 - loss : 0.0056 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.05it/s]


Epoch : 29 - loss : 0.0067 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 30 - loss : 0.0053 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 31 - loss : 0.0053 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.03it/s]


Epoch : 32 - loss : 0.0054 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 33 - loss : 0.0052 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 34 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.04it/s]


Epoch : 35 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 36 - loss : 0.0051 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 37 - loss : 0.0054 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 38 - loss : 0.0055 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.07it/s]


Epoch : 39 - loss : 0.0056 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687



100%|██████████| 14/14 [00:13<00:00,  1.06it/s]


Epoch : 40 - loss : 0.0054 - acc: 1.0000 - val_loss : 0.1100 - val_acc: 0.9687

