In [2]:
from __future__ import print_function
from __future__ import division

from PIL import Image

import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader
# from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import random_split

import matplotlib.pyplot as plt
import itertools    # confusion matrix에서 사용

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.models as models

In [3]:
# Root folder of the competition data; the CSV files below are read relative to it.
BASE_DIR = '../input/microsoft-rice-disease-classification-challenge'

# Training table: report its shape, then preview the first rows.
train_csv_path = os.path.join(BASE_DIR, 'Train.csv')
train = pd.read_csv(train_csv_path)
print(train.shape)
train.head()

In [4]:
# Keep only the rows whose Image_id does not contain '_rgn'
# (the dataset class derives the matching '_rgn' file name itself).
train_is_rgn = train['Image_id'].str.contains('_rgn')
train_rgb = train.loc[~train_is_rgn]
train_rgb.head()

In [5]:
# Test table, filtered the same way as the training table: drop the '_rgn' rows.
test_csv_path = os.path.join(BASE_DIR, 'Test.csv')
test = pd.read_csv(test_csv_path)
test_is_rgn = test['Image_id'].str.contains('_rgn')
test_rgb = test.loc[~test_is_rgn]
test_rgb.head()

In [6]:
# Submission template; the prediction columns are filled in after each experiment.
sample_submission_path = os.path.join(BASE_DIR, 'SampleSubmission.csv')
ss = pd.read_csv(sample_submission_path)
ss.head()

In [7]:
class Img_Dataset(Dataset):
    """Dataset pairing each image with its ``_rgn`` companion image.

    For row ``index`` of ``table`` the image named by ``Image_id`` and the file
    obtained by replacing ``.jpg`` with ``_rgn.jpg`` are both opened, converted
    to RGB, passed through ``transform`` and concatenated along dim 0 (with a
    ``ToTensor``-based transform that gives one 6-channel tensor).

    Args:
        file_path (str): directory containing the image files.
        transform (callable): applied to each PIL image; must return a tensor.
        table (pandas.DataFrame): needs an ``Image_id`` column and, when
            ``is_train`` is true, a ``Label`` column.
        is_train (bool): if true ``__getitem__`` returns ``(tensor, label_index)``,
            otherwise only the tensor.
    """

    # Class name -> integer target returned for training rows.
    LABEL_TO_INDEX = {'blast': 0, 'brown': 1, 'healthy': 2}

    def __init__(self, file_path, transform, table, is_train=True):
        self.file_path = file_path
        self.transform = transform
        self.table = table
        self.is_train = is_train

    def __len__(self):
        """Number of rows in the table."""
        return len(self.table)

    def _encode_label(self, label):
        """Map a class name to its integer index.

        Raises:
            ValueError: if ``label`` is not one of the known class names
                (previously this surfaced as an UnboundLocalError on ``y``).
        """
        try:
            return self.LABEL_TO_INDEX[label]
        except KeyError:
            raise ValueError(
                'Unknown label {!r}; expected one of {}'.format(label, sorted(self.LABEL_TO_INDEX))
            ) from None

    def _load_transformed(self, path):
        """Open ``path``, force 3-channel RGB, apply the transform and close the file."""
        with Image.open(path) as img:
            # convert() loads the pixel data, so the file handle can be released right after.
            return self.transform(img.convert('RGB'))

    def __getitem__(self, index):
        """Return the concatenated tensor for row ``index`` (plus the label index when training)."""
        row = self.table.iloc[index]  # single positional lookup instead of one per field

        # Validate the label before touching the disk so bad rows fail fast.
        y = self._encode_label(row['Label']) if self.is_train else None

        img_name = row['Image_id']
        img_path = os.path.join(self.file_path, img_name)
        img_path_rgn = os.path.join(self.file_path, img_name.replace('.jpg', '_rgn.jpg'))

        img_concat = torch.cat(
            [self._load_transformed(img_path), self._load_transformed(img_path_rgn)],
            dim=0,
        )

        if self.is_train:
            return img_concat, y
        return img_concat

In [8]:
class EarlyStopping:
    """Stop training early when the validation loss has not improved for `patience` calls."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): number of non-improving calls tolerated after the last improvement.
                            Default: 7
            verbose (bool): if True, print a message every time a checkpoint is saved.
                            Default: False
            delta (float): minimum change of the monitored quantity that counts as an improvement.
                            Default: 0
            path (str): where the checkpoint (model state_dict) is written.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf instead of np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one validation loss; save a checkpoint on improvement, otherwise count towards patience."""

        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            # First call: whatever we got is the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # Not better than the best score by at least delta.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model state_dict to `self.path` and remember `val_loss` as the new minimum.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [9]:
def make_data_loader(batch_size=128, split=0.8, seed=None, num_workers=2):
    """Build the train / validation / test data loaders.

    Args:
        batch_size (int): batch size of the train and validation loaders.
        split (float): fraction of `train_rgb` used for training; the rest is validation.
        seed (int | None): if given, seeds the train/validation split so that
            repeated calls produce the same partition. None keeps the
            previous behaviour (split drawn from the global RNG).
        num_workers (int): worker processes for the train and validation loaders.

    Returns:
        tuple: (train_loader, val_loader, test_loader). The test loader yields
        one unshuffled sample per batch, in the row order of `test_rgb`.
    """
    IMG_DIR = '../input/microsoft-rice-disease-classification-challenge/Images'
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # train data
    full_train_dataset = Img_Dataset(IMG_DIR, transform, train_rgb)

    train_size = int(len(full_train_dataset) * split)
    val_size = len(full_train_dataset) - train_size

    # Only pass a generator when a seed was requested, so the default call is unchanged.
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size], **split_kwargs)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers)

    # test data
    test_dataset = Img_Dataset(IMG_DIR, transform, test_rgb, is_train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    return train_loader, val_loader, test_loader

In [10]:
def train_model(device, model, train_loader, val_loader, criterion, optimizer, num_epochs=5, early_stopping=None):
    """Train `model`, validating after every epoch, with optional early stopping.

    Args:
        device: torch device the model and batches are moved to.
        model (nn.Module): network to train.
        train_loader, val_loader: loaders yielding `(inputs, labels)` batches.
        criterion: loss called as `criterion(outputs, labels)`.
        optimizer: optimizer stepping the trainable parameters.
        num_epochs (int): maximum number of epochs.
        early_stopping: optional object called as `early_stopping(val_loss, model)`
            that exposes `.early_stop` and `.path`. When given, the checkpoint at
            `.path` is loaded into the model before returning.

    Returns:
        tuple: `(model, val_label, val_pred)`. The two lists hold the labels and
        predicted class indices of the last validation pass that was run, so they
        are also filled when early stopping ends training before `num_epochs`.
        They are not necessarily from the checkpointed epoch.
    """
    model = model.to(device)

    loaders = {'train': train_loader,
               'val': val_loader}

    val_label = []
    val_pred = []

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        val_loss = 0.0

        # Each epoch has a training and a validation phase
        for phase in ['train', 'val']:
            is_training = phase == 'train'
            if is_training:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            epoch_labels = []
            epoch_preds = []

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Clear gradients left over from the previous step
                optimizer.zero_grad()

                # Track history only while training
                with torch.set_grad_enabled(is_training):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if is_training:
                        loss.backward()
                        optimizer.step()
                    else:
                        # Collected every epoch (not only the nominal last one) so an
                        # early stop still leaves predictions to return.
                        epoch_labels += labels.tolist()
                        epoch_preds += preds.tolist()

                # loss is the batch mean, so weight it by the batch size (first dimension)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            n_samples = len(loaders[phase].dataset)
            epoch_loss = running_loss / n_samples  # mean loss per sample for this epoch
            epoch_acc = running_corrects / n_samples

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if not is_training:
                val_loss = epoch_loss
                val_label, val_pred = epoch_labels, epoch_preds

        if early_stopping is not None:
            early_stopping(val_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

    # Load the checkpointed weights; skipped when no early-stopping helper was supplied
    # (previously this line raised AttributeError on None).
    if early_stopping is not None:
        model.load_state_dict(torch.load(early_stopping.path, map_location=device))
    return model, val_label, val_pred

In [None]:
# Freeze all parameters of the model
def set_parameter_requires_grad(model):
    """Disable gradient tracking on every parameter of `model` (modifies it in place)."""
    for weight in model.parameters():
        weight.requires_grad_(False)

In [None]:
def initialize_model(num_classes):
    """Build a pretrained ResNet50 that accepts 6-channel input and predicts `num_classes` classes.

    All pretrained parameters are frozen. Two layers are then replaced and stay
    trainable: the first convolution (widened from 3 to 6 input channels) and
    the final fully connected layer.

    Args:
        num_classes (int): size of the new output layer.

    Returns:
        tuple: `(model, params_to_update)` where `params_to_update` lists the
        parameters whose `requires_grad` is True, ready to hand to an optimizer.
    """
    # Use ResNet50
    # Note kept from the original for Colab: models.resnet50(weights="IMAGENET1K_V2")
    model_ft = models.resnet50(pretrained=True)
    set_parameter_requires_grad(model_ft)

    # Widen the first conv layer: keep the pretrained weights for the first three
    # input channels and start the three extra channels at zero.
    pretrained_w = model_ft.conv1.weight.detach()
    model_ft.conv1 = nn.Conv2d(6, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    # requires_grad is set on the Parameter itself. The earlier
    # `model_ft.conv1.requires_grad = True` only set an attribute on the Module
    # and had no effect on training.
    model_ft.conv1.weight = nn.Parameter(
        torch.cat((pretrained_w, torch.zeros_like(pretrained_w)), dim=1),
        requires_grad=True,
    )

    # Replace the output layer; a freshly created nn.Linear is trainable by default.
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_classes)

    print("Params to learn:")
    params_to_update = []
    for name, param in model_ft.named_parameters():
        if param.requires_grad:
            params_to_update.append(param)
            print("\t",name)

    return model_ft, params_to_update

In [11]:
# confusion matrix 시각화
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, target_names=None, labels=True):
    """Draw a confusion matrix as a heat map, with the overall accuracy in the title.

    Args:
        cm (numpy.ndarray): 2-D confusion matrix; rows are drawn on the
            'True label' axis and columns on the 'Predicted label' axis.
        target_names (sequence | None): tick labels for both axes; omitted when None.
        labels (bool): if True, write each cell's count inside the cell.
    """
    # Accuracy = diagonal / total; guard against an all-zero matrix.
    total = float(np.sum(cm))
    accuracy = np.trace(cm) / total if total else 0.0

    cmap = plt.get_cmap('Blues')

    plt.figure(figsize=(9, 6))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    # The accuracy used to be computed and then discarded; show it.
    plt.title('Confusion matrix (accuracy={:.4f})'.format(accuracy))
    plt.colorbar()
    # Cells darker than half of the maximum get white text for contrast.
    thresh = cm.max() / 2

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names)
        plt.yticks(tick_marks, target_names)

    if labels:
        for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
            plt.text(j, i, "{:,}".format(cm[i, j]), horizontalalignment="center",
                     color="white" if cm[i,j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    # Lay out after the axis labels exist so they are included in the fit.
    plt.tight_layout()
    plt.show()

In [12]:
def test_model(device, model, test_loader):
    test_pred = []

    model.eval()
    model = model.to(device)
    with torch.set_grad_enabled(False):
        for features in test_loader:
            features = features.to(device)

            outputs = model(features.to(torch.float))
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
            test_pred.append(probabilities.tolist())

    return test_pred

In [None]:
# Experiment 1: ResNet50 from initialize_model with a 3-class output layer,
# plus the list of parameters that will be optimised.
resnet_ft, params_to_update = initialize_model(num_classes=3)
print(resnet_ft)

In [None]:
# Training setup for experiment 1: only the parameters collected in
# params_to_update are handed to the optimiser.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
early_stopping = EarlyStopping(
    patience=3,
    verbose=True,
    delta=1e-5,
    path="resnet2_checkpoint.pt",
)
train_loader, val_loader, test_loader = make_data_loader()

In [None]:
# Train for at most 30 epochs; early stopping may end the run sooner.
resnet_tf, val_label, val_pred = train_model(
    device=device,
    model=resnet_ft,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
    early_stopping=early_stopping,
)

### 2. 6채널 입력을 3채널로 줄이는 1x1 conv layer를 ResNet50 앞에 붙임

In [None]:
class myResNet50(nn.Module):
    """Pretrained ResNet50 preceded by a 1x1 convolution that maps 6 input channels to 3.

    The backbone is not frozen here, so every layer is trainable.

    Args:
        num_classes (int): size of the replacement output layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Backbone, then adapter, then the new head.
        backbone = models.resnet50(pretrained=True)
        self.model_ft = backbone
        self.first_conv = nn.Conv2d(6, 3, kernel_size=1)

        # Replace the final fully connected layer with a `num_classes`-way one.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

    def forward(self, x):
        """Reduce the input to 3 channels with the 1x1 conv, then run the backbone."""
        return self.model_ft(self.first_conv(x))

In [None]:
# Experiment 2: myResNet50 with three output classes; print its layer structure.
resnet_ft2 = myResNet50(3)
print(resnet_ft2)

In [None]:
# Training setup for experiment 2: all parameters of resnet_ft2 are optimised.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
# Alternative optimiser kept for reference: optim.Adam(resnet_ft2.parameters(), lr=3e-4)
optimizer = optim.SGD(resnet_ft2.parameters(), lr=0.0001, momentum=0.9)
early_stopping = EarlyStopping(
    patience=3,
    verbose=True,
    delta=1e-5,
    path="resnet3_checkpoint.pt",
)
train_loader, val_loader, test_loader = make_data_loader()

In [None]:
# Train for at most 30 epochs; early stopping may end the run sooner.
resnet_tf3, val_label, val_pred = train_model(
    device=device,
    model=resnet_ft2,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
    early_stopping=early_stopping,
)

In [None]:
# Rebuild the architecture and restore the weights stored at early_stopping.path.
best_state_resnet3 = torch.load(early_stopping.path)
resnet_tf3 = myResNet50(num_classes=3)
resnet_tf3.load_state_dict(best_state_resnet3)

In [None]:
# Class probabilities for every test image, in test_loader order.
test_pred = test_model(device=device, model=resnet_tf3, test_loader=test_loader)

In [None]:
# Write the predictions into a copy of the template so that `ss` stays untouched
# and this cell can be re-run (or run after another experiment) safely.
# NOTE(review): assumes the template rows are in the same order as test_rgb — confirm.
class_columns_resnet3 = ['blast', 'brown', 'healthy']
# Building the frame on ss.index raises right away if the number of predictions
# differs from the number of template rows.
pred_frame_resnet3 = pd.DataFrame(test_pred, columns=class_columns_resnet3, index=ss.index)
submission_resnet50_3 = ss.copy()
submission_resnet50_3[class_columns_resnet3] = pred_frame_resnet3
submission_resnet50_3.to_csv("result_resnet50_3.csv", index=False)
submission_resnet50_3.head()

### 3. 충분한 epoch로 다시 시도.

In [None]:
# Experiment 3: a fresh myResNet50 instance with three output classes.
resnet_ft4 = myResNet50(num_classes=3)

In [None]:
# Training setup for experiment 3: same recipe as before but with lr=3e-4.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
# Alternative optimiser kept for reference: optim.Adam(resnet_ft4.parameters(), lr=3e-4)
optimizer = optim.SGD(resnet_ft4.parameters(), lr=3e-4, momentum=0.9)
early_stopping = EarlyStopping(
    patience=3,
    verbose=True,
    delta=1e-5,
    path="resnet4_checkpoint.pt",
)
train_loader, val_loader, test_loader = make_data_loader()

In [None]:
# Train for at most 30 epochs; early stopping may end the run sooner.
resnet_tf4, val_label, val_pred = train_model(
    device=device,
    model=resnet_ft4,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
    early_stopping=early_stopping,
)

In [None]:
# Rebuild the architecture and restore the weights stored at early_stopping.path.
best_state_resnet4 = torch.load(early_stopping.path)
resnet_tf4 = myResNet50(num_classes=3)
resnet_tf4.load_state_dict(best_state_resnet4)

In [None]:
# Class probabilities for every test image, in test_loader order.
test_pred = test_model(device=device, model=resnet_tf4, test_loader=test_loader)

In [None]:
# Write the predictions into a copy of the template so that `ss` stays untouched
# and this cell can be re-run (or run after another experiment) safely.
# NOTE(review): assumes the template rows are in the same order as test_rgb — confirm.
class_columns_resnet4 = ['blast', 'brown', 'healthy']
# Building the frame on ss.index raises right away if the number of predictions
# differs from the number of template rows.
pred_frame_resnet4 = pd.DataFrame(test_pred, columns=class_columns_resnet4, index=ss.index)
submission_resnet50_4 = ss.copy()
submission_resnet50_4[class_columns_resnet4] = pred_frame_resnet4
submission_resnet50_4.to_csv("result_resnet50_4.csv", index=False)
submission_resnet50_4.head()

score : 0.52..^^ 이 모델은 포기..

### 4. DeiT 사용

In [13]:
# Install timm and requests; presumably needed by the DeiT model fetched through torch.hub further down — confirm.
!pip install timm requests

In [14]:
class myDeiT(nn.Module):
    """`deit_base_patch16_224` from torch.hub preceded by a 1x1 convolution mapping 6 input channels to 3.

    The hub model is not frozen here, so every layer is trainable.

    Args:
        num_classes (int): size of the replacement `head` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Hub model, then adapter, then the new head.
        hub_model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
        self.model_ft = hub_model
        self.first_conv = nn.Conv2d(6, 3, kernel_size=1)

        # Replace the classification head with a `num_classes`-way linear layer.
        hub_model.head = nn.Linear(hub_model.head.in_features, num_classes)

    def forward(self, x):
        """Reduce the input to 3 channels with the 1x1 conv, then run the hub model."""
        return self.model_ft(self.first_conv(x))

In [15]:
# Experiment 4: DeiT wrapper with three output classes.
deit_tf = myDeiT(num_classes=3)

In [None]:
# Print the module tree of the DeiT wrapper to inspect its layers.
print(deit_tf)

In [17]:
# Training setup for experiment 4: Adam on all DeiT parameters, batches of 50.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
# Alternative optimiser kept for reference: optim.SGD(deit_tf.parameters(), lr=3e-4, momentum=0.9)
optimizer = optim.Adam(deit_tf.parameters(), lr=3e-4)
early_stopping = EarlyStopping(
    patience=3,
    verbose=True,
    delta=1e-5,
    path="deit_checkpoint.pt",
)
train_loader, val_loader, test_loader = make_data_loader(batch_size=50)

In [18]:
# Train for at most 30 epochs; early stopping may end the run sooner.
deit_tf, val_label, val_pred = train_model(
    device=device,
    model=deit_tf,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=30,
    early_stopping=early_stopping,
)

In [19]:
# Class probabilities for every test image, in test_loader order.
test_pred = test_model(device=device, model=deit_tf, test_loader=test_loader)

In [20]:
# Write the predictions into a copy of the template so that `ss` stays untouched
# and this cell can be re-run (or run after another experiment) safely.
# NOTE(review): assumes the template rows are in the same order as test_rgb — confirm.
class_columns_deit = ['blast', 'brown', 'healthy']
# Building the frame on ss.index raises right away if the number of predictions
# differs from the number of template rows.
pred_frame_deit = pd.DataFrame(test_pred, columns=class_columns_deit, index=ss.index)
submission_deit = ss.copy()
submission_deit[class_columns_deit] = pred_frame_deit
submission_deit.to_csv("result_deit.csv", index=False)
submission_deit.head()