In [1]:
import os
import json
import random
import argparse
import itertools
import math
import torch
import numpy as np
from torch import nn, optim
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

import utils

from data_utils import WBCdataset_Mask

from transformers import ViTForImageClassification, ViTMAEConfig

from torch.utils.tensorboard import SummaryWriter

def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    
def get_WBC_transform():
    data_transforms = []
    data_transforms.append(transforms.Resize((224, 224)))
    data_transforms.append(transforms.ToTensor())
    return transforms.Compose(data_transforms)

def run(device, hps):
    wbc_subset = "wbc1"
    pretrain_options = "pRCC"
    use_mask=True
    
    out_dir = os.path.join(hps.out_dir, f'{wbc_subset}', f'{pretrain_options}')
    if use_mask:
        out_dir = os.path.join(out_dir, 'mask')
    os.makedirs(out_dir, exist_ok = True)
    writer = SummaryWriter(out_dir)
    
    if wbc_subset == "wbc1":
        training_files = hps.WBCdata.training_files_1
    elif wbc_subset == "wbc10":
        training_files = hps.WBCdata.training_files_10
    elif wbc_subset == "wbc50":
        training_files = hps.WBCdata.training_files_50
    else:
        training_files = hps.WBCdata.training_files_100
    
    train_data = WBCdataset_Mask(training_files, hps.WBCdata.label_dict, transform=get_WBC_transform(), use_mask=use_mask, is_train=True)
    valid_data = WBCdataset_Mask(hps.WBCdata.validation_files, hps.WBCdata.label_dict, transform=get_WBC_transform())
    
    label2id = {}
    id2label = {}

    for label in hps.WBCdata.label_dict.keys():
        label2id[label] = hps.WBCdata.label_dict[label]
        id2label[hps.WBCdata.label_dict[label]] = label
    
    if pretrain_options == "pRCC":
        model = ViTForImageClassification.from_pretrained("Mo0310/vitmae_pRCC_80epochs", 
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes = True,
        ).to(device)
    elif pretrain_options == "facebook":
        model = ViTForImageClassification.from_pretrained("facebook/vit-mae-base", 
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes = True,
        ).to(device)
    else:
        config = ViTMAEConfig.from_pretrained("facebook/vit-mae-base",
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes = True)
        model = ViTForImageClassification(config).to(device)
        
    masked_pixel = torch.rand(1).to(device)
    
    train_loader = DataLoader(dataset = train_data, batch_size=hps.finetune.batch_size, shuffle=True)
    valid_loader = DataLoader(dataset = valid_data, batch_size=hps.finetune.batch_size, shuffle=False)
    
    # loss function
    criterion = nn.CrossEntropyLoss()
    # finetune optimizer
    learnable_params = list(model.parameters())
    learnable_params.append(masked_pixel)
    ft_optimizer = optim.AdamW(learnable_params, lr=hps.finetune.learning_rate)
    #ft_optimizer.param_groups.append({'params': masked_pixel })
    # finetune scheduler
    #ft_scheduler = optim.lr_scheduler.MultiStepLR(ft_optimizer, milestones=[1, 2], gamma=hps.pretrain.lr_decay)
    ft_scheduler = StepLR(ft_optimizer, step_size=5, gamma=hps.finetune.lr_decay)
    
    for epoch in range(hps.finetune.epochs):
        train_and_evaluate(device, epoch, model, masked_pixel, criterion, ft_optimizer, ft_scheduler, [train_loader, valid_loader], writer)
        
    return model
    

def train_and_evaluate(device, epoch, model, masked_pixel, criterion, optimizer, scheduler, loaders, writer):
    train_loader, valid_loader = loaders
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label, mask in tqdm(train_loader):
        optimizer.zero_grad()
        
        data = data.to(device)
        mask = mask.to(device)
        label = label.to(device)
        data = data * mask + masked_pixel * (1 - mask)

        output = model(data)
        loss = criterion(output.logits, label)
        
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        
        optimizer.step()

        acc = (output.logits.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        model.eval()
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label, mask in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output.logits, label)

            acc = (val_output.logits.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)
        model.train()
        
    scheduler.step()
    
    writer.add_scalar('./Loss/train', epoch_loss, epoch+1)
    writer.add_scalar('./ACC/train', epoch_accuracy, epoch+1)
    writer.add_scalar('./Loss/val', epoch_val_loss, epoch+1)
    writer.add_scalar('./ACC/val', epoch_val_accuracy, epoch+1)
    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

  from .autonotebook import tqdm as notebook_tqdm
2023-10-23 05:01:42.772232: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-23 05:01:42.820141: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
hps = utils.get_hparams_from_file('./configs/wbc1_40epochs.json')
seed_everything(hps.seed)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = run(device, hps)

You are using a model of type vit_mae to instantiate a model of type vit. This is not supported for all configurations of models and can yield errors.
Some weights of ViTForImageClassification were not initialized from the model checkpoint at Mo0310/vitmae_pRCC_80epochs and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 2/2 [00:04<00:00,  2.23s/it]


Epoch : 1 - loss : 1.2858 - acc: 0.3646 - val_loss : 1.0785 - val_acc: 0.6128



100%|██████████| 2/2 [00:03<00:00,  1.69s/it]


Epoch : 2 - loss : 0.9986 - acc: 0.6059 - val_loss : 1.1238 - val_acc: 0.4884



100%|██████████| 2/2 [00:03<00:00,  1.70s/it]


Epoch : 3 - loss : 0.9500 - acc: 0.5790 - val_loss : 1.0996 - val_acc: 0.6128



100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Epoch : 4 - loss : 0.8896 - acc: 0.6780 - val_loss : 0.8835 - val_acc: 0.6545



100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Epoch : 5 - loss : 0.6696 - acc: 0.7648 - val_loss : 0.7774 - val_acc: 0.6950



100%|██████████| 2/2 [00:03<00:00,  1.60s/it]


Epoch : 6 - loss : 0.5304 - acc: 0.8394 - val_loss : 0.5350 - val_acc: 0.8235



100%|██████████| 2/2 [00:03<00:00,  1.71s/it]


Epoch : 7 - loss : 0.4004 - acc: 0.8819 - val_loss : 0.4222 - val_acc: 0.8490



100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


Epoch : 8 - loss : 0.2704 - acc: 0.9453 - val_loss : 0.3657 - val_acc: 0.8733



100%|██████████| 2/2 [00:03<00:00,  1.67s/it]


Epoch : 9 - loss : 0.3082 - acc: 0.9210 - val_loss : 0.3398 - val_acc: 0.8762



100%|██████████| 2/2 [00:03<00:00,  1.57s/it]


Epoch : 10 - loss : 0.1193 - acc: 0.9453 - val_loss : 0.3473 - val_acc: 0.8681



100%|██████████| 2/2 [00:03<00:00,  1.70s/it]


Epoch : 11 - loss : 0.1110 - acc: 0.9488 - val_loss : 0.2896 - val_acc: 0.8993



100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


Epoch : 12 - loss : 0.0932 - acc: 0.9922 - val_loss : 0.3476 - val_acc: 0.8848



100%|██████████| 2/2 [00:03<00:00,  1.70s/it]


Epoch : 13 - loss : 0.0723 - acc: 0.9844 - val_loss : 0.2611 - val_acc: 0.9074



100%|██████████| 2/2 [00:03<00:00,  1.67s/it]


Epoch : 14 - loss : 0.0446 - acc: 0.9922 - val_loss : 0.2491 - val_acc: 0.9126



100%|██████████| 2/2 [00:03<00:00,  1.66s/it]


Epoch : 15 - loss : 0.0239 - acc: 1.0000 - val_loss : 0.2568 - val_acc: 0.9167



100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Epoch : 16 - loss : 0.0163 - acc: 1.0000 - val_loss : 0.2685 - val_acc: 0.9161



100%|██████████| 2/2 [00:03<00:00,  1.69s/it]


Epoch : 17 - loss : 0.0166 - acc: 1.0000 - val_loss : 0.2712 - val_acc: 0.9161



100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


Epoch : 18 - loss : 0.0103 - acc: 1.0000 - val_loss : 0.2619 - val_acc: 0.9196



100%|██████████| 2/2 [00:03<00:00,  1.66s/it]


Epoch : 19 - loss : 0.0091 - acc: 1.0000 - val_loss : 0.2327 - val_acc: 0.9277



100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Epoch : 20 - loss : 0.0069 - acc: 1.0000 - val_loss : 0.2096 - val_acc: 0.9358



100%|██████████| 2/2 [00:01<00:00,  1.14it/s]


Epoch : 21 - loss : 0.0054 - acc: 1.0000 - val_loss : 0.2001 - val_acc: 0.9387



100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Epoch : 22 - loss : 0.0057 - acc: 1.0000 - val_loss : 0.1951 - val_acc: 0.9421



100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Epoch : 23 - loss : 0.0045 - acc: 1.0000 - val_loss : 0.1930 - val_acc: 0.9427



100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Epoch : 24 - loss : 0.0043 - acc: 1.0000 - val_loss : 0.1928 - val_acc: 0.9439



100%|██████████| 2/2 [00:03<00:00,  1.63s/it]


Epoch : 25 - loss : 0.0036 - acc: 1.0000 - val_loss : 0.1938 - val_acc: 0.9450



In [3]:
token = "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"
model.push_to_hub("5242_w_pRCC_wbc1_mask", token=token)

pytorch_model.bin: 100%|██████████| 343M/343M [02:11<00:00, 2.61MB/s]   


CommitInfo(commit_url='https://huggingface.co/Mo0310/5242_w_pRCC_wbc1_mask/commit/c911c5ca42a0de6e238ef6d9bdce8c671c59f6fc', commit_message='Upload ViTForImageClassification', commit_description='', oid='c911c5ca42a0de6e238ef6d9bdce8c671c59f6fc', pr_url=None, pr_revision=None, pr_num=None)

In [4]:
print(model)

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7