# **Cassava EfficientNet fine-tuning**
2021/01/12 written by T.Yonezu

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader, Dataset

import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import glob 
import os
from tqdm import tqdm

from cassava_dataset import *
from augmentation import *
from my_nn_module import *

import warnings
warnings.simplefilter('ignore')

In [2]:
# Competition data folder, two levels above the working directory.
input_dir = os.path.join(os.pardir, os.pardir, "input", "cassava-leaf-disease-classification")

## **Fine-tuning**

In [3]:
from sklearn.model_selection import train_test_split

# Read the label table and attach the full path of every training image.
x = pd.read_csv(os.path.join(input_dir, "train.csv"))
x["image_path"] = os.path.join(input_dir, "train_images") + os.path.sep + x["image_id"]

# Hold out 20% of the rows for validation; the fixed seed keeps the split repeatable.
train_df, valid_df = train_test_split(x, test_size=0.2, random_state=42)
#train_df = EqualizeLabels(train_df,NUM=100)

# {image path: label} lookups that are handed to the dataset class.
train_dict = {path: label for path, label in zip(train_df["image_path"], train_df["label"])}
valid_dict = {path: label for path, label in zip(valid_df["image_path"], valid_df["label"])}

In [4]:
BATCH_SIZE = 4
size = (512, 512)
# Per-channel normalisation constants passed to the transform
# (presumably the usual ImageNet statistics -- confirm in augmentation.py).
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# One transform object is shared by both datasets; each dataset is created with
# its own `phase` value ("train" / "val").
transform = ImageTransform(size, mean, std)

# Training batches are reshuffled every epoch.
train_data = DataLoader(
    CassavaDataset(train_dict, transform=transform, phase="train"),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Validation batches keep the DataLoader's default (unshuffled) order.
valid_data = DataLoader(
    CassavaDataset(valid_dict, transform=transform, phase="val"),
    batch_size=BATCH_SIZE,
)

In [5]:
# Total number of training epochs; also substituted into the checkpoint file name.
EPOCH_NUM = 150

# Checkpoint location: <OUT_DIR>/<MODEL_NAME with %d replaced by EPOCH_NUM>.
OUT_DIR = os.path.join(os.pardir, os.pardir, "input", "cassava-models")
MODEL_NAME = "EfficientNet-b4_cassava(512x512)_EL_finetuned_%dEpoch.mdl"
PATH = os.path.join(OUT_DIR, MODEL_NAME % (EPOCH_NUM,))

In [6]:
from efficientnet_pytorch import EfficientNet
# `nn` is imported here so this cell runs on a fresh kernel without depending on
# a later cell or on whatever the star imports happen to export.
from torch import nn

# Width of the replacement classifier head (one output per label value).
NUM_CLASSES = 5

# Start from the pretrained EfficientNet-B4 and swap its final fully connected
# layer for a new NUM_CLASSES-way linear layer with the same input width.
model = EfficientNet.from_pretrained("efficientnet-b4")
in_features = model._fc.in_features
model._fc = nn.Linear(in_features=in_features, out_features=NUM_CLASSES, bias=True)

#model

Loaded pretrained weights for efficientnet-b4


In [7]:
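# NOTE: disabled alternative -- trains through the NN_TRAINER helper instead of the manual loop in the next cell; kept for reference.
# #optimizer = optim.SGD(model.parameters(), lr=0.001)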
# optimizer = optim.Adam(model.parameters(), lr=0.0001)
# criterion = nn.CrossEntropyLoss()

# trainer = NN_TRAINER(model=model,
#                      criterion=criterion,
#                      optimizer=optimizer,
#                      OUT_DIR=OUT_DIR,
#                      MODEL_NAME=MODEL_NAME)

# trainer.run(train_dataloader=train_data,
#             valid_dataloader=valid_data,
#             epoch_num=EPOCH_NUM,
#             device="cuda")

In [7]:
import torch.optim as optim
from torch import nn

# Fine-tune `model` with mixed precision (autocast + GradScaler).
# Each epoch: one pass over train_data with weight updates, one pass over
# valid_data without updates, checkpoint on best validation accuracy, then
# redraw and save the learning curves.

# Redraw / save the learning curves every PLOT_EVERY epochs.
PLOT_EVERY = 1

model = model.cuda()
#optimizer = optim.SGD(model.parameters(), lr=0.001)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

# epoch -> [train acc, train loss, valid acc, valid loss]
log = {}

best_acc = -np.inf

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs(OUT_DIR, exist_ok=True)

scaler = torch.cuda.amp.GradScaler()
torch.backends.cudnn.benchmark = True
for epoch in tqdm(range(EPOCH_NUM)):

    # ---------------- training pass ----------------
    model.train()
    train_acc = 0
    train_loss = 0
    for batch in train_data:
        X = batch[0].cuda()
        y = batch[1].cuda()

        # zero the gradient buffers
        optimizer.zero_grad()

        # Only the forward pass and the loss run under autocast.
        with torch.cuda.amp.autocast():
            pred = model(X)
            loss = criterion(pred, y)

        # Backward and the optimizer step happen outside the autocast region,
        # as the PyTorch AMP documentation recommends.
        scaler.scale(loss).backward()
        scaler.step(optimizer)  # unscales the gradients, then does the update
        scaler.update()

        pred_label = pred.argmax(dim=1).cpu().numpy()
        y = y.cpu().numpy()
        train_acc += (pred_label == y).sum()
        # loss is a batch mean; weight it by the batch size so the epoch value
        # is a per-sample average even when the last batch is smaller.
        train_loss += loss.item() * X.size(0)

    train_acc = train_acc / len(train_data.dataset)
    train_loss = train_loss / len(train_data.dataset)

    # ---------------- validation pass ----------------
    model.eval()
    valid_acc = 0
    valid_loss = 0
    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for batch in valid_data:
            X = batch[0].cuda()
            y = batch[1].cuda()

            with torch.cuda.amp.autocast():
                pred = model(X)
                loss = criterion(pred, y)

            pred_label = pred.argmax(dim=1).cpu().numpy()
            y = y.cpu().numpy()

            valid_acc += (pred_label == y).sum()
            valid_loss += loss.item() * X.size(0)

    valid_acc = valid_acc / len(valid_data.dataset)
    valid_loss = valid_loss / len(valid_data.dataset)

    # Keep only the weights of the best epoch so far (by validation accuracy).
    if valid_acc > best_acc:
        torch.save(model.state_dict(), PATH)
        best_acc = valid_acc

    log[epoch] = [train_acc, train_loss, valid_acc, valid_loss]

    if epoch % PLOT_EVERY == 0:
        log_df = pd.DataFrame.from_dict(log,
                                        orient="index",
                                        columns=["train acc", "train loss", "valid acc", "valid loss"])

        fig = plt.figure(figsize=(10, 3))

        ax = fig.add_subplot(1, 2, 1)
        log_df[["train loss", "valid loss"]].plot(marker="o", ax=ax)
        ax.set_xlim(0, EPOCH_NUM + 1)
        ax.grid(True)

        ax2 = fig.add_subplot(1, 2, 2)
        log_df[["train acc", "valid acc"]].plot(marker="o", ax=ax2)
        ax2.grid(True)
        ax2.set_xlim(0, EPOCH_NUM + 1)

        fig.suptitle(MODEL_NAME % EPOCH_NUM)
        # Use the formatted model name (as in the title) so the file name does
        # not contain a literal "%d". Save before show(), which may release the figure.
        fig.savefig(MODEL_NAME % EPOCH_NUM + ".png", format="png")
        plt.show()
        # One figure is created per plotted epoch; close it so they do not accumulate.
        plt.close(fig)

  0%|                                                               | 0/150 [00:15<?, ?it/s]
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\Users\organ\anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-115453ed3374>", line 20, in <module>
    for batch in (train_data):
  File "C:\Users\organ\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 435, in __next__
    data = self._next_data()
  File "C:\Users\organ\anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 475, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "C:\Users\organ\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\organ\anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "C:\Users\organ\Google Drive\workspace\kagg

TypeError: object of type 'NoneType' has no len()