In [1]:
import numpy as np
import pandas as pd
import os.path as osp
import glob
import matplotlib.pyplot as plt
import random
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold

from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
from torchvision import models, transforms

In [2]:
num_folds = 5

rootpath = osp.join('..','input','cassava-leaf-disease-classification')

# size, mean, std
size = 224
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# mini-batch size
batch_size = 32

# number of epochs
num_epochs = 2

# params to update
update_param_names = ['classifier.6.weight', 'classifier.6.bias']

In [3]:
train = pd.read_csv(osp.join(rootpath, 'train.csv'))
train['image_path'] = osp.join(rootpath, 'train_images')
train['image_path'] = train['image_path'].str.cat(train['image_id'], sep=osp.sep)
train.head()

Unnamed: 0,image_id,label,image_path
0,1000015157.jpg,0,..\input\cassava-leaf-disease-classification\t...
1,1000201771.jpg,3,..\input\cassava-leaf-disease-classification\t...
2,100042118.jpg,1,..\input\cassava-leaf-disease-classification\t...
3,1000723321.jpg,1,..\input\cassava-leaf-disease-classification\t...
4,1000812911.jpg,3,..\input\cassava-leaf-disease-classification\t...


In [4]:
submission = pd.read_csv(osp.join(rootpath, 'sample_submission.csv'))
submission.head()

Unnamed: 0,image_id,label
0,2216849948.jpg,4


# Preprocessing Class

In [5]:
class ImageTransform():
    def __init__(self, resize, mean, std):
        self.data_transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(resize, scale=(0.5, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'val': transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ]),
            'test': transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean, std)
            ])
        }

    def __call__(self, img, phase='train'):
        return self.data_transform[phase](img)


# Dataset Class

In [6]:
class CassavaDataset(data.Dataset):
    def __init__(self, filepath2label, transform=None, phase='train', output_label=True):
        self.file_list = list(filepath2label.keys())
        self.transform = transform
        self.filepath2label = filepath2label
        self.phase = phase
        self.output_label = output_label
        
    def __len__(self):
        return len(self.file_list)
    
    def __getitem__(self, index):
        img_path = self.file_list[index]
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img, self.phase)
        if self.output_label:
            label = self.filepath2label[img_path]
            return img, label
        else:
            return img        

In [7]:
# # 動作確認

# index = 0
# print(train_dataset.__getitem__(index)[0].size())

# DataLoader

In [8]:
def get_DataLoader(dataset, batch_size, shuffle=True):
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle)

In [9]:
# # 動作確認

# batch_iterator = iter(dataloaders_dict['train'])
# inputs, labels = next(batch_iterator)
# print(inputs.size())
# print(labels)

# Make Network model

In [None]:
def get_network(model_name, use_pretrained=True):
    if model_name == 'vgg19':
        # load pretrained model
        net = models.vgg19(pretrained=use_pretrained)

        # change output layer
        net.classifier[6] = nn.Linear(in_features=4096, out_features=5, bias=True)
        return net
    elif model_name == 'resnet152':
        # load pretrained model
        net = models.resnet152(pretrained=use_pretrained)

        # change output layer
        net.fc[0] = nn.Linear(in_features=4096, out_features=5, bias=True)
        return net


# Define Loss function

In [11]:
def get_criterion():
    criterion = nn.CrossEntropyLoss()
    return criterion

# Set Optimization

In [12]:
def set_params(net, update_param_names=['classifier.6.weight', 'classifier.6.bias']):
    # add parameters to learn by fine-tuning to params_to_update
    params_to_update = []

    for name, param in net.named_parameters():
        if name in update_param_names:
            param.requires_grad = True
            params_to_update.append(param)
            print(name)
        else:
            param.requires_grad = False

    print("-----------------")
    print(params_to_update)
    
    return params_to_update

In [13]:
def get_optimizer(params_to_update):
    # set optimizer
    optimizer = optim.SGD(params=params_to_update, lr=0.001, momentum=0.9)
    
    return optimizer

# Train & Validation

In [14]:
def train_model(net, dataloaders_dict, criterion, optimizer, num_epochs):
    # LOOP: epoch
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('--------------------')
        
        # LOOP: train & valid at each epoch
        for phase in ['train', 'val']:
            if phase == 'train':
                net.train()   # set train mode
            else:
                net.eval()   # set validation mode
            
            epoch_loss = 0.0   # sum of epoch loss
            epoch_corrects = 0   # number of epoch correctness
            
            # training at epoch = 0 is omitted to confirm the verification performance when unlearned.
            if (epoch == 0) and (phase == 'train'):
                continue
            
            # LOOP: mini-batch
            for inputs, labels in tqdm(dataloaders_dict[phase]):
                # initialize optimizer
                optimizer.zero_grad()
                
                # forward
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = net(inputs)
                    loss = criterion(outputs, labels)   # calc loss
                    _, preds = torch.max(outputs, 1)   # get predicted label
                    
                    # when train, run back propagation
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                    # calc iteration result
                    epoch_loss += loss.item() * inputs.size(0)
                    epoch_corrects += torch.sum(preds == labels.data)
            
            # print loss & accuracy in each epoch
            epoch_loss = epoch_loss / len(dataloaders_dict[phase].dataset)
            epoch_acc = epoch_corrects.double() / len(dataloaders_dict[phase].dataset)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

# Inference

# Utils

In [15]:
def set_random_seed(seed):
    # set random seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

In [16]:
def make_datapath_list(rootpath, phase='train'):
    target_path = osp.join(rootpath, phase+'_images', '*.jpg')

    path_list = []
    for path in glob.glob(target_path):
        path_list.append(path)
    return path_list

# main

In [17]:
if __name__ == "__main__":
    # CV loop
    folds = StratifiedKFold(n_splits=num_folds).split(np.arange(train.shape[0]), train.label.values)
    
    for fold, (train_idx, val_idx) in tqdm(enumerate(folds)):
        if fold > 0:
            break
        
        # Dataset
        train_filepath2label = dict(zip(train.loc[train_idx, :].image_path, train.loc[train_idx, :].label))
        val_filepath2label = dict(zip(train.loc[val_idx, :].image_path, train.loc[val_idx, :].label))

        train_dataset = CassavaDataset(train_filepath2label, ImageTransform(size, mean, std), 'train', True)
        val_dataset = CassavaDataset(val_filepath2label, ImageTransform(size, mean, std), 'val', True)
        
        # DataLoader
        train_dataloader = get_DataLoader(train_dataset, batch_size, True)
        val_dataloader = get_DataLoader(val_dataset, batch_size, True)
        dataloaders_dict = {'train': train_dataloader, 'val': val_dataloader}
        
        # get NetWork model
        model_name = 'resnet152'
        use_pretrained = True
        net = get_network(model_name, use_pretrained)
        
        # set train mode
        net.train()

        # criterion
        criterion = get_criterion()

        # get optimizer
        params_to_update = set_params(net, update_param_names)
        optimizer = get_optimizer(params_to_update)
        
        # train & valid
        train_model(net, dataloaders_dict, criterion, optimizer, num_epochs)


0it [00:00, ?it/s]
  0%|                                                                                          | 0/134 [00:00<?, ?it/s][A

classifier.6.weight
classifier.6.bias
-----------------
[Parameter containing:
tensor([[ 0.0005, -0.0150, -0.0109,  ..., -0.0105,  0.0110,  0.0077],
        [-0.0019,  0.0074,  0.0028,  ..., -0.0040,  0.0144, -0.0023],
        [-0.0115, -0.0058, -0.0071,  ...,  0.0054, -0.0118,  0.0132],
        [-0.0033,  0.0076, -0.0082,  ...,  0.0138,  0.0116,  0.0056],
        [-0.0049,  0.0091,  0.0150,  ...,  0.0032, -0.0045,  0.0088]],
       requires_grad=True), Parameter containing:
tensor([-0.0011, -0.0080,  0.0153,  0.0007, -0.0104], requires_grad=True)]
Epoch 1/2
--------------------



  1%|▌                                                                                 | 1/134 [00:04<10:09,  4.58s/it][A
  1%|█▏                                                                                | 2/134 [00:08<09:58,  4.53s/it][A
  2%|█▊                                                                                | 3/134 [00:13<09:42,  4.44s/it][A
  3%|██▍                                                                               | 4/134 [00:17<09:27,  4.36s/it][A
  4%|███                                                                               | 5/134 [00:21<09:17,  4.32s/it][A
  4%|███▋                                                                              | 6/134 [00:25<09:04,  4.25s/it][A
  5%|████▎                                                                             | 7/134 [00:29<08:59,  4.25s/it][A
  6%|████▉                                                                             | 8/134 [00:34<08:54,  4.24s/it][A
  7%|█████▌    

val Loss: 1.6298 Acc: 0.1456
Epoch 2/2
--------------------



  0%|▏                                                                                 | 1/535 [00:04<44:03,  4.95s/it][A
  0%|▎                                                                                 | 2/535 [00:09<43:27,  4.89s/it][A
  1%|▍                                                                                 | 3/535 [00:14<42:28,  4.79s/it][A
  1%|▌                                                                                 | 4/535 [00:19<42:53,  4.85s/it][A
  1%|▊                                                                                 | 5/535 [00:23<42:05,  4.77s/it][A
  1%|▉                                                                                 | 6/535 [00:28<41:42,  4.73s/it][A
  1%|█                                                                                 | 7/535 [00:33<41:21,  4.70s/it][A
  1%|█▏                                                                                | 8/535 [00:37<40:50,  4.65s/it][A
  2%|█▍        

train Loss: 0.9465 Acc: 0.6580



  1%|▌                                                                                 | 1/134 [00:03<08:16,  3.73s/it][A
  1%|█▏                                                                                | 2/134 [00:07<08:12,  3.73s/it][A
  2%|█▊                                                                                | 3/134 [00:11<08:10,  3.74s/it][A
  3%|██▍                                                                               | 4/134 [00:14<08:06,  3.74s/it][A
  4%|███                                                                               | 5/134 [00:18<08:03,  3.74s/it][A
  4%|███▋                                                                              | 6/134 [00:22<07:59,  3.75s/it][A
  5%|████▎                                                                             | 7/134 [00:26<07:54,  3.74s/it][A
  6%|████▉                                                                             | 8/134 [00:29<07:51,  3.75s/it][A
  7%|█████▌    

val Loss: 0.8177 Acc: 0.6991



