# **Cassava VisionTransformer fine-tuning**
2021/01/25 written by T.Yonezu

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader, Dataset
import timm

import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import glob 
import os
from tqdm import tqdm

from cassava_dataset import *
from augmentation import *

import warnings
warnings.simplefilter('ignore')

In [2]:
input_dir = os.path.join('..',"..", 'input', 'cassava-leaf-disease-classification')

## **Fine-tuning**

In [3]:
x = pd.read_csv(os.path.join(input_dir, 'train.csv'))
x["image_path"] = os.path.join(input_dir,"train_images")
x["image_path"] = x["image_path"].str.cat(x["image_id"], sep=os.path.sep)


from sklearn.model_selection import train_test_split

train_df , valid_df = train_test_split(x,test_size=0.33, random_state=42)
train_df = EqualizeLabels(train_df,NUM=15000)

train_dict = dict( zip(train_df["image_path"],train_df["label"]) )
valid_dict = dict( zip(valid_df["image_path"],valid_df["label"]) )

In [4]:
BATCH_SIZE = 32
size = (224,224)
mean = [0.485, 0.456, 0.406]
std=[0.229, 0.224, 0.225]

transform = ImageTransform(size,mean,std)
train_data = CassavaDataset(train_dict,transform=transform,phase="train")
train_data = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True)

valid_data = CassavaDataset(valid_dict,transform=transform,phase="val")
valid_data = DataLoader(valid_data,batch_size=BATCH_SIZE)

In [5]:
EPOCH_NUM = 40

OUT_DIR = os.path.join("..","..","input","cassava-models")
MODEL_NAME = "vit_base_resnet50d_224_cassava(224x224)_EL_finetuned_%dEpoch.mdl"
PATH = os.path.join(OUT_DIR,MODEL_NAME%EPOCH_NUM)

In [6]:
print("AVAILABLE VisionTransformer Models:")
timm.list_models("vit*")

AVAILABLE VisionTransformer Models:


['vit_base_patch16_224',
 'vit_base_patch16_384',
 'vit_base_patch32_384',
 'vit_base_resnet26d_224',
 'vit_base_resnet50d_224',
 'vit_huge_patch16_224',
 'vit_huge_patch32_384',
 'vit_large_patch16_224',
 'vit_large_patch16_384',
 'vit_large_patch32_384',
 'vit_small_patch16_224',
 'vit_small_resnet26d_224',
 'vit_small_resnet50d_s3_224']

In [7]:
import timm

model = timm.create_model('vit_base_resnet50d_224',pretrained=True)
in_features = model.head.in_features
model.head = nn.Linear(in_features=in_features, out_features=5, bias=True)

#model

In [7]:
import torch.optim as optim
from torch import nn

model = model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

log = {}

best_acc = -np.inf

scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
for epoch in tqdm(range(EPOCH_NUM)):
    
    model.train()
    train_acc = 0
    train_loss = 0
    for batch in (train_data):
        X = batch[0].cuda()
        y = batch[1].cuda()
        
        # zero the gradient buffers
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            pred = model(X)
            loss = criterion(pred, y)
        
            scaler.scale(loss).backward()
            scaler.step(optimizer) # Does the update
            scaler.update()
        
        pred_label = ( pred.argmax(axis=1) ).cpu().numpy()
        y = y.cpu().numpy()
        train_acc += (pred_label == y).sum()
        train_loss += loss.item() * X.size(0)
        
    train_acc = train_acc/len(train_data.dataset)
    train_loss = train_loss/len(train_data.dataset)
    
    
    model.eval()
    valid_acc = 0
    valid_loss = 0
    for batch in (valid_data): 
        X = batch[0].cuda()
        y = batch[1].cuda()
        
        with torch.cuda.amp.autocast():
            pred = model(X)
            loss = criterion(pred, y)
        
        pred_label = ( pred.argmax(axis=1) ).cpu().numpy()
        y = y.cpu().numpy()
        
        valid_acc += (pred_label == y).sum()
        valid_loss += loss.item() * X.size(0)
    
    valid_acc = valid_acc/len(valid_data.dataset)
    valid_loss = valid_loss/len(valid_data.dataset)
    
    
    if valid_acc > best_acc:
        torch.save(model.state_dict(),PATH)
        best_acc = valid_acc
    
    
    log[epoch] = [train_acc, train_loss, valid_acc, valid_loss]
    
    if epoch%1 == 0:
        log_df = pd.DataFrame.from_dict(log,
                                        orient="index",
                                        columns=["train acc","train loss","valid acc","valid loss"])
        
        fig = plt.figure(figsize=(10,3))
        
        ax = fig.add_subplot(1,2,1)
        log_df[["train loss","valid loss"]].plot(marker="o",ax=ax)
        ax.set_xlim(0,EPOCH_NUM+1)
        ax.grid(True)
        
        ax2 = fig.add_subplot(1,2,2)
        log_df[["train acc","valid acc"]].plot(marker="o",ax=ax2)
        ax2.grid(True)
        ax2.set_xlim(0,EPOCH_NUM+1)
        
        fig.suptitle(MODEL_NAME%EPOCH_NUM)
        plt.savefig(MODEL_NAME+".png",format="png")
        plt.show()

  0%|                                                                | 0/40 [00:27<?, ?it/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\Users\organ\anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-e6f9f2ce7bac>", line 30, in <module>
    scaler.step(optimizer) # Does the update
  File "C:\Users\organ\anaconda3\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 320, in step
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
  File "C:\Users\organ\anaconda3\lib\site-packages\torch\cuda\amp\grad_scaler.py", line 320, in <genexpr>
    if not sum(v.item() for v in optimizer_state["found_inf_per_device"].values()):
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\organ\anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 2045, in showtraceback
    stb = value._render_traceback_()
AttributeError: 'KeyboardInterrupt' object has

TypeError: object of type 'NoneType' has no len()