# **Cassava VGG11 fine-tuning**
2021/01/12 written by T.Yonezu

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from torch.utils.data import DataLoader, Dataset

import numpy as np 
import matplotlib.pyplot as plt
import pandas as pd
import glob 
import os
from tqdm import tqdm

from cassava_dataset import *
from myvgg import *

import warnings
warnings.simplefilter('ignore')

In [4]:
# Competition data root, given relative to the current working directory
# (two levels up, then input/cassava-leaf-disease-classification).
input_dir = os.path.join(
    *['..', '..', 'input', 'cassava-leaf-disease-classification']
)

## **Fine-tuning**

In [5]:
# Load the training index and attach the on-disk path of every image.
x = pd.read_csv(os.path.join(input_dir, 'train.csv'))

# Every row gets "<input_dir>/train_images<sep><image_id>".
image_dir = os.path.join(input_dir, "train_images")
x["image_path"] = image_dir + os.path.sep + x["image_id"]

# Mapping image path -> label, the form handed to CassavaDataset below.
train_dict = {path: label for path, label in zip(x["image_path"], x["label"])}

In [6]:
# Preprocessing parameters passed to ImageTransform: target image size and
# per-channel normalization constants.
# NOTE(review): these look like the ImageNet mean/std that torchvision's
# pretrained weights expect — confirm against ImageTransform's implementation.
size = (224, 224)
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

transform = ImageTransform(size, mean, std)

In [7]:
BATCH_SIZE = 10
# NOTE(review): NUM_WORKERS is declared but was never passed to the DataLoader.
# It is left unwired here because worker processes can misbehave when a
# notebook kernel runs on Windows — pass num_workers=NUM_WORKERS once verified.
NUM_WORKERS = 8

# Keep the Dataset and the DataLoader under separate names so `train_data`
# always refers to one kind of object (the batch iterator used for training).
train_dataset = CassavaDataset(train_dict, transform=transform)

# shuffle=True: without it every epoch visits the samples in the same CSV
# order, which hurts SGD convergence.
train_data = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:
# Number of passes over the training set.
EPOCH_NUM = 100

# Where the fine-tuned weights are written.
OUT_DIR = os.path.join(*("..", "..", "input", "cassava-models"))
# NOTE(review): the file name says "VGG16" but this notebook fine-tunes vgg11.
# The string is kept unchanged because other notebooks may load the weights by
# this exact name — rename it together with every consumer.
MODEL_NAME = "VGG16_cassava_finetuned_%dEpoch.mdl"
PATH = os.path.join(OUT_DIR, MODEL_NAME % (EPOCH_NUM,))

In [9]:
import torchvision.models as models
# `nn` is used in this cell, but the notebook only imports it in a later cell;
# import it here so the cell works on a fresh kernel (Restart & Run All).
from torch import nn

# Number of output classes for the replacement classification head.
NUM_CLASSES = 5

# Pretrained VGG11 from the torchvision hub repository (v0.6.0 tag).
model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg11', pretrained=True)

# Swap the final fully connected layer for one sized to this task. The input
# width is read from the layer being replaced instead of hard-coding 4096.
head_in_features = model.classifier[6].in_features
model.classifier[6] = nn.Linear(in_features=head_in_features, out_features=NUM_CLASSES, bias=True)
model

Using cache found in C:\Users\organ/.cache\torch\hub\pytorch_vision_v0.6.0


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [11]:
import torch.optim as optim
from torch import nn

# Use the GPU when one is available; fall back to the CPU instead of crashing
# on an unconditional .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.train()  # make training mode explicit before optimizing

optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Create the output directory before training so a bad path fails immediately
# rather than after the full run; exist_ok avoids the exists()/makedirs() race.
os.makedirs(OUT_DIR, exist_ok=True)

progress = tqdm(range(EPOCH_NUM))
for epoch in progress:
    running_loss = 0.0
    num_batches = 0

    for batch in train_data:
        # Each batch is indexed as (inputs, labels).
        X = batch[0].to(device)
        y = batch[1].to(device)

        # Clear gradients left over from the previous step, then run
        # forward / backward / parameter update.
        optimizer.zero_grad()
        pred = model(X)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Show the mean training loss of the finished epoch on the progress bar.
    if num_batches:
        progress.set_postfix(loss=running_loss / num_batches)

# Persist only the weights (state_dict) of the fine-tuned model.
torch.save(model.state_dict(), PATH)


  0%|                                                                                          | 0/100 [00:00<?, ?it/s]
  1%|▊                                                                              | 1/100 [02:53<4:46:28, 173.62s/it]
  2%|█▌                                                                             | 2/100 [05:46<4:43:16, 173.43s/it]
  3%|██▎                                                                            | 3/100 [08:39<4:40:03, 173.23s/it]
  4%|███▏                                                                           | 4/100 [11:31<4:36:51, 173.03s/it]
  5%|███▉                                                                           | 5/100 [14:25<4:33:59, 173.05s/it]
  6%|████▋                                                                          | 6/100 [17:18<4:31:07, 173.06s/it]
  7%|█████▌                                                                         | 7/100 [20:13<4:29:25, 173.83s/it]
  8%|██████▎                           