In [10]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm

# Plant Pathology 2020 - FGVC7

In [11]:
# Number of samples per mini-batch passed to every DataLoader.
BATCH_SIZE = 8
# Side length (in pixels) each image is resized to before entering the network.
INPUT_SIZE = 500
# Number of passes over the training set.
EPOCH_TIME = 100

# Device on which the model and every batch are placed.
DEVICE = torch.device('cpu')

## Data Loading

The code below shows how I load the data sets. 80% of the labelled training set is used for training, and the remaining 20% is held out for validation.

In addition, I use the following torchvision transforms to augment the training set:

* RandomGrayscale
* RandomHorizontalFlip
* RandomVerticalFlip
* RandomRotation
* RandomPerspective


In [28]:
from torch.utils.data import Dataset, random_split
from torchvision.transforms import RandomHorizontalFlip, RandomVerticalFlip, RandomGrayscale, RandomPerspective, RandomRotation
from torchvision.transforms import Compose, Resize, Normalize, ToTensor
from PIL import Image


class PlantDataset(Dataset):
    """Image dataset read from ``./dataset/train.csv`` or ``./dataset/test.csv``.

    Each CSV row names an image via its ``image_id`` column; the image itself
    is read from ``./dataset/images/<image_id>.jpg``.

    Args:
        train: When true, rows come from ``train.csv``, random augmentation is
            applied to every image and ``__getitem__`` returns
            ``(image, label)``. When false, rows come from ``test.csv``, no
            augmentation is applied and only the image is returned.

    Note:
        Augmentation is tied to the ``train`` flag of the instance, so any
        subset split off a ``train=True`` instance (for example a validation
        subset) also receives the random augmentation.
    """

    # One-hot label columns of train.csv; the class index is their argmax.
    LABEL_COLUMNS = ['healthy', 'multiple_diseases', 'rust', 'scab']

    def __init__(self, train = True):

        self.train = train

        self.dataset = pd.read_csv('./dataset/{}.csv'.format('train' if train else 'test'))

        if self.train:

            # Random, label-preserving perturbations applied to the PIL image
            # before resizing.
            self.augment = Compose([
                RandomGrayscale(),
                RandomHorizontalFlip(),
                RandomVerticalFlip(),
                RandomRotation((-180, +180)),
                RandomPerspective(),
            ])

        # Deterministic preprocessing: square resize, tensor conversion and
        # per-channel normalisation with 3-channel mean/std.
        self.transform = Compose([
            Resize((INPUT_SIZE, INPUT_SIZE)),
            ToTensor(),
            Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

    def __len__(self):
        """Return the number of rows in the loaded CSV."""

        return self.dataset.shape[0]

    def __getitem__(self, index):
        """Load, (optionally) augment and preprocess the sample at ``index``.

        Returns:
            ``(image, label)`` when ``self.train`` is true, where ``label`` is
            the argmax over the one-hot label columns; otherwise just
            ``image``.
        """

        path = './dataset/images/{}.jpg'.format(self.dataset.loc[index, 'image_id'])

        # Image.open is lazy and keeps the file open; read it inside a context
        # manager so the handle is released, and force three channels so the
        # 3-channel Normalize also works for grayscale / palette / RGBA files.
        with Image.open(path) as handle:
            image = handle.convert('RGB')

        if self.train:
            image = self.augment(image)

        image = self.transform(image)

        if not self.train:
            return image

        label = np.argmax(self.dataset.loc[index, self.LABEL_COLUMNS].values)

        return image, label


# Labelled data; PlantDataset defaults to train=True, so augmentation is enabled.
train_raw = PlantDataset()

# 80% / 20% random split into training and held-out validation subsets.
# No `generator` is passed, so the split is not seeded here. Both subsets wrap
# the same train_raw instance, so the validation images are augmented as well.
train_set, test_set = random_split(train_raw, [int(0.8 * len(train_raw)), len(train_raw) - int(0.8 * len(train_raw))])

## CNN Model

In [13]:
from torch.nn import Module, Linear, Sequential, Conv2d, BatchNorm2d, ReLU, MaxPool2d


class CNN(Module):
    """Small convolutional classifier with two conv stages and one linear head.

    Every stage is ``Conv2d(5x5, stride 1, padding 2) -> BatchNorm2d -> ReLU ->
    MaxPool2d(2, stride 2)``: the convolution keeps the spatial size and the
    pooling halves it, so two stages reduce each side to ``INPUT_SIZE // 4``.
    The linear head maps the flattened feature map to 4 output scores.
    """

    def __init__(self):
        super().__init__()

        def stage(in_channels, out_channels):
            # Layers of one conv stage, in the order they are applied.
            return [
                Conv2d(in_channels, out_channels, kernel_size = 5, stride = 1, padding = 2),
                BatchNorm2d(out_channels),
                ReLU(),
                MaxPool2d(kernel_size = 2, stride = 2),
            ]

        # (BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
        #   -> (BATCH_SIZE, 32, INPUT_SIZE // 4, INPUT_SIZE // 4)
        self.layer = Sequential(*stage(3, 16), *stage(16, 32))

        # (BATCH_SIZE, 32 * (INPUT_SIZE // 4) * (INPUT_SIZE // 4)) -> (BATCH_SIZE, 4)
        side = INPUT_SIZE // 4
        self.fc = Linear(32 * side * side, 4)

    def forward(self, x):
        """Return the 4 class scores for each image in ``x``.

        ``x`` is reshaped to ``(-1, 3, INPUT_SIZE, INPUT_SIZE)`` before the
        convolutional stages, and the resulting feature map is flattened per
        sample before the linear head.
        """
        side = INPUT_SIZE // 4

        batch = x.reshape(-1, 3, INPUT_SIZE, INPUT_SIZE)
        features = self.layer(batch)

        return self.fc(features.reshape(-1, 32 * side * side))


# Build the network and move its parameters to the configured device.
cnn = CNN().to(DEVICE)

## Model Training

In [None]:
import os
import json
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss


def validate(model, test_loader, pbar):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    The model is evaluated in eval mode (so BatchNorm uses its running
    statistics and does not update them on validation data) with gradients
    disabled. The mode the model had on entry is restored before returning,
    so calling this from inside a training loop does not leave the model in
    eval mode.

    Args:
        model: Network mapping an image batch to per-class scores.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        pbar: Progress bar; advanced by ``BATCH_SIZE`` after every batch.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` when the loader
        yields no samples.
    """

    correct, total = 0, 0

    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for images, labels in test_loader:

                outputs = model(images.to(DEVICE).float())

                # Predicted class = index of the highest score per sample.
                predictions = outputs.argmax(dim = 1).cpu()

                correct += int((predictions == labels.cpu()).sum().item())
                total += int(predictions.shape[0])

                pbar.update(BATCH_SIZE)
    finally:
        # Hand the model back in whatever mode the caller had it.
        model.train(was_training)

    return correct / total if total else 0.0


def train(model, train_set, test_set, ax = None, label = None):
    """Train ``model`` for ``EPOCH_TIME`` epochs, validating after each epoch.

    After every epoch the loss of the final training mini-batch and the
    validation accuracy are appended to the loss / score histories.

    Args:
        model: Network to optimise (in place) with Adam and cross-entropy loss.
        train_set: Dataset yielding ``(image, label)`` pairs used for training.
        test_set: Dataset yielding ``(image, label)`` pairs used for validation.
        ax: Optional matplotlib axes; when given, the score and loss histories
            are plotted on it once training finishes.
        label: Optional run name. When given, model weights and the histories
            are resumed from ``./checkpoints/<label>.ckpt``,
            ``./checkpoints/<label>.scores.json`` and
            ``./checkpoints/<label>.losses.json`` if those files exist, and are
            written back there after every epoch. Optimizer state is not
            persisted. When ``None``, nothing is read or written.
    """

    checkpoint_path = scores_path = losses_path = None

    if label is not None:
        checkpoint_path = './checkpoints/{}.ckpt'.format(label)
        scores_path = './checkpoints/{}.scores.json'.format(label)
        losses_path = './checkpoints/{}.losses.json'.format(label)

    if (checkpoint_path is not None) and os.path.isfile(checkpoint_path):
        # load_state_dict expects the state dict itself, not a file path.
        model.load_state_dict(torch.load(checkpoint_path, map_location = DEVICE))


    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters())


    scores, losses = [], []

    if (scores_path is not None) and os.path.isfile(scores_path):
        with open(scores_path, 'r') as file:
            scores = json.load(file)

    if (losses_path is not None) and os.path.isfile(losses_path):
        with open(losses_path, 'r') as file:
            losses = json.load(file)


    train_loader = DataLoader(train_set, batch_size = BATCH_SIZE)
    test_loader = DataLoader(test_set, batch_size = BATCH_SIZE)


    with tqdm(total = EPOCH_TIME * (len(train_loader) + len(test_loader)) * BATCH_SIZE) as pbar:
        for _ in range(0, EPOCH_TIME):

            # Make sure BatchNorm is in training mode for the optimisation
            # pass, whatever mode the previous validation left the model in.
            model.train()

            batch_loss = None

            for images, labels in train_loader:

                outputs = model(images.to(DEVICE).float())
                batch_loss = criterion(outputs, labels.to(DEVICE))

                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

                pbar.update(BATCH_SIZE)

            # Recorded loss is that of the epoch's last mini-batch; NaN if the
            # training loader produced no batches.
            loss = batch_loss.item() if batch_loss is not None else float('nan')
            score = validate(model, test_loader, pbar)

            pbar.set_postfix(loss = loss, score = score)

            losses.append(loss)
            scores.append(score)

            if label is not None:
                # Persist progress so a later call with the same label resumes.
                os.makedirs('./checkpoints', exist_ok = True)
                torch.save(model.state_dict(), checkpoint_path)
                with open(scores_path, 'w') as file:
                    json.dump(scores, file)
                with open(losses_path, 'w') as file:
                    json.dump(losses, file)

    if ax is not None:

        ax.plot(scores, label = 'score')
        ax.plot(losses, label = 'loss')

        ax.legend()


# Train the CNN on the split above and plot score/loss on the axes returned by plt.subplot().
train(cnn, train_set, test_set, ax = plt.subplot())

  5%|▍         | 8928/182400 [31:51<8:11:31,  5.88it/s, loss=1.6, score=0.425]  