### packages

In [1]:
from pathlib import Path
from PIL import Image
import pickle
import pdb
from fastprogress.fastprogress import master_bar, progress_bar
import random
import os

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cv2 

import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataloader import default_collate
from torchvision import transforms as T
from torchvision import models
from efficientnet_pytorch import EfficientNet 

from sklearn.model_selection import GroupKFold, StratifiedKFold
from sklearn.metrics import accuracy_score
from scipy.special import softmax

import albumentations as A
from albumentations.pytorch.transforms import ToTensor, ToTensorV2
from albumentations import ImageOnlyTransform

### global configurations

These are the core configurations of this notebook. 
Read them carefully: 

- paths do not need to change unless you change the folder structure of this project 
- if you add more data (in particular, a new folder of images under ./data/train), you should only need to alter *label_encoding*
- each time you want to train and save a new model, change *model_name* so you don't overwrite other people's results
- *epoch* is the total number of training epochs. 8-15 should be sufficient for this task. You can increase it, but you probably shouldn't go beyond 20


In [2]:
# data root and its train/test split folders
path = Path('./data/')
path_train, path_test = (path / split for split in ('train', 'test'))

In [3]:
image_size = 384  # images are resized to image_size x image_size pixels
bs = 8            # batch size used by the DataLoaders

In [4]:
# Integer label for every tree type, starting at 0.
# Each key must be spelled exactly like its folder under ./data/train/
# (the folder an image sits in is the tree type it belongs to).
label_encoding = {
    "Chinar": 0,
    "Gauva": 1,
    "Jamun": 2,
}
num_labels = len(label_encoding)
num_labels, label_encoding

(3, {'Chinar': 0, 'Gauva': 1, 'Jamun': 2})

In [None]:
# name of the checkpoint written after all epochs (saved as ./models/<model_name>.pth);
# change it for every new run so earlier checkpoints are not overwritten
model_name = 'baseline-model'

In [None]:
# total number of training epochs passed to fit()
epoch = 12 

## helper functions to load data

In [5]:
# put your images in one sub-folder per tree type, for example:
# - train 
# - - tree type 1
# - - - all tree type 1 images...
# - - tree type 2
# - - - ...
# - - tree type 3
# - - - ...
def list_all_train_files(path: Path):
    '''Return the image files stored one folder level below ``path``.

    Expects the layout ``path/<class folder>/<image file>``. Entries directly
    under ``path`` that are not folders are skipped (previously they crashed
    the listing with NotADirectoryError), and nested sub-folders inside a
    class folder are not returned as if they were images.

    Args:
        path: the folder whose sub-folders hold the images (``Path`` or ``str``).

    Returns:
        files: a list of image file paths, sorted so the order is reproducible
            across runs and operating systems.
    '''
    return sorted(
        f
        for folder in Path(path).iterdir() if folder.is_dir()
        for f in folder.iterdir() if f.is_file()
    )

# for all files in the same folder
# def list_all_train_files(path:Path):
#     return [f for f in path.iterdir()]

In [6]:
# collect every training image path, then preview the first five and the total count
train_fnames = list_all_train_files(path_train)
train_fnames[:5], len(train_fnames)

([WindowsPath('data/train/Chinar/0011_0001.JPG'),
  WindowsPath('data/train/Chinar/0011_0002.JPG'),
  WindowsPath('data/train/Chinar/0011_0003.JPG'),
  WindowsPath('data/train/Chinar/0011_0004.JPG'),
  WindowsPath('data/train/Chinar/0011_0005.JPG')],
 1094)

In [7]:
# NOTE(review): list_all_train_files only looks one folder level down, so test
# images must sit inside sub-folders of ./data/test to be picked up
test_fnames = list_all_train_files(path_test)
test_fnames[:5], len(test_fnames)

([], 0)

## Dataset

In [8]:
class TreeDataset(Dataset):
    '''Image-classification dataset over a list of image file paths.

    The label of an image is the name of its parent folder, mapped to an
    integer through the module-level ``label_encoding`` dict.

    Args:
        f_paths: list of image file paths (``Path`` or ``str``).
        transforms: optional albumentations-style callable, invoked as
            ``transforms(image=...)`` and returning a dict with an ``'image'`` key.
        is_test: if True, ``__getitem__`` returns only the image (no label).
    '''
    def __init__(self, f_paths: list, transforms=None, is_test=False):
        self.f_paths = f_paths
        self.transforms = transforms
        self.is_test = is_test

    def __getitem__(self, index):
        '''Return ``image`` (test mode) or ``(image, target)``; target is a 1-element long tensor.'''
        # 1. load the image
        img_path = Path(self.f_paths[index])
        # Decode as uint8 RGB. The legacy albumentations ToTensor rescales by 1/255 only
        # for uint8 input, so the previous float32 array reached ImageNet normalization
        # still in [0, 255]. convert('RGB') also unifies grayscale / RGBA / palette
        # images, and the `with` block closes the underlying file handle.
        with Image.open(img_path) as img:
            image = np.array(img.convert('RGB'))

        if self.transforms:
            image = self.transforms(image=image)['image']

        if self.is_test:
            return image

        # 2. the label is the parent folder name; Path.parent.name works on every OS,
        # unlike splitting the string on a Windows-only '\\' separator
        label = label_encoding[img_path.parent.name]
        target = torch.tensor([label], dtype=torch.long)

        return image, target

    def __len__(self):
        return len(self.f_paths)

## model

In [9]:
class TreeEfficientNet(nn.Module):
    '''Pretrained EfficientNet backbone followed by an MLP head and a linear classifier.

    Args:
        model_name: EfficientNet variant passed to ``EfficientNet.from_pretrained``.
        pool_type: pooling function called as ``pool_type(features, 1)`` to collapse
            the backbone's spatial feature map to one vector per image.
        num_classes: number of output logits; defaults to the global ``num_labels``.
    '''
    def __init__(self, model_name='efficientnet-b3', pool_type=F.adaptive_avg_pool2d, num_classes=None):
        super(TreeEfficientNet, self).__init__()
        self.pool_type = pool_type
        self.backbone = EfficientNet.from_pretrained(model_name)

        image_in_features = self.backbone._fc.in_features
        # ReLUs between the Linear layers: without a nonlinearity the stacked Linears
        # collapse into a single linear map and the extra layers add nothing.
        # NOTE: this shifts the efn_head indices, so state_dicts saved from the old
        # activation-free head will not load into this model.
        self.efn_head = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(image_in_features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
        )
        n_out = num_labels if num_classes is None else num_classes
        # attribute name 'classifer' (sic) kept so the classifier's state_dict keys are unchanged
        self.classifer = nn.Linear(128, n_out)

    def forward(self, x):
        '''Return raw class logits, shape (batch, num_classes).'''
        cnn_features = self.pool_type(self.backbone.extract_features(x), 1)
        cnn_features = cnn_features.view(x.size(0), -1)
        cnn_features = self.efn_head(cnn_features)

        return self.classifer(cnn_features)

## Focal Loss

In [10]:
class FocalLoss(nn.Module):
    '''Multi-class focal loss (Lin et al., 2017) built on per-sample cross-entropy.

    FL = alpha * (1 - p_t) ** gamma * CE, where CE = -log(p_t) and p_t is the
    predicted probability of the true class. With gamma=0 it reduces to
    alpha * cross-entropy.

    Args:
        alpha: scalar weight applied to every sample's loss.
        gamma: focusing parameter; larger values down-weight well-classified samples.
    '''
    def __init__(self, alpha, gamma):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, preds, truth):
        '''Return the mean focal loss over the batch.

        Args:
            preds: raw logits, shape (batch, num_classes).
            truth: class indices, shape (batch,); cast to long.
        '''
        # Per-sample CE (reduction='none'). The previous version took the batch-MEAN
        # CE, treated it as p_t and applied log() to it, which is not focal loss and
        # could even become negative.
        ce = F.cross_entropy(preds, truth.to(dtype=torch.long), reduction='none')
        pt = torch.exp(-ce)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce
        return focal_loss.mean()

## helper functions for forward&backward propagation

In [11]:
# get device

def get_device():
    '''Return the CUDA device when one is available, otherwise the CPU device.'''
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)

In [12]:
# this is just for cleaning the possible None values

def collate_fn(batch):
    '''Collate a batch after dropping unusable samples.

    A sample is dropped when it is None, or when it is an (image, target)
    tuple/list whose image is None. Image-only samples (e.g. from a dataset in
    test mode) are kept as-is. The old tuple unpacking crashed on both a None
    sample and an image-only sample.
    '''
    cleaned = []
    for item in batch:
        if item is None:
            continue
        data = item[0] if isinstance(item, (tuple, list)) else item
        if data is None:
            continue
        cleaned.append(item)
    return default_collate(cleaned)

In [13]:
# a large set of data augmentations built with albumentations

def get_augmentations(p=0.5, img_size=image_size):
    '''Build the albumentations pipelines for training, validation and TTA.

    Args:
        p: probability used for each optional augmentation group.
        img_size: side length in pixels that every pipeline resizes images to.

    Returns:
        (train_tfms, valid_tfms, test_tfms): three ``A.Compose`` pipelines, each
        ending with a resize to img_size x img_size and ``ToTensor`` with
        ImageNet mean/std normalization.
    '''
    # ImageNet statistics, as used with the ImageNet-pretrained backbone
    imagenet_stats = {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}

    # training: heavy random augmentation, then resize + tensor conversion
    train_tfms = A.Compose([
        # simple cutout regularization
        A.Cutout(p=p),
        # rotation (currently disabled)
        #A.RandomRotate90(p=p),
        #A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=45, p=p),
        # flip
        A.Flip(p=p),
        # one of colour augmentation
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2,
                                       contrast_limit=0.2,),
            A.HueSaturationValue(
                hue_shift_limit=20,
                sat_shift_limit=50,
                val_shift_limit=50)
        ], p=p),
        # one of noise augmentation
        A.OneOf([
            A.IAAAdditiveGaussianNoise(),
            A.GaussNoise()
        ], p=p),
        # one of blurring augmentation
        A.OneOf([
            A.MotionBlur(p=0.2),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.Blur(blur_limit=3, p=0.1),
        ], p=p),
        # one of distortion
        A.OneOf([
            A.OpticalDistortion(p=0.3),
            A.GridDistortion(p=0.1),
            A.IAAPiecewiseAffine(p=0.3),
        ], p=p),
        A.Resize(img_size, img_size, always_apply=True),
        # must do: to tensor
        ToTensor(normalize=imagenet_stats),
    ])

    # test-time augmentation (TTA)
    # NOTE(review): RandomRotate90 is used here but disabled in train_tfms, so TTA
    # produces views the model never trained on -- confirm this is intended.
    test_tfms = A.Compose([
        A.RandomRotate90(p=p),
        A.Flip(p=p),
        A.OneOf([
            A.RandomBrightnessContrast(brightness_limit=0.2,
                                       contrast_limit=0.2,
                                       ),
            A.HueSaturationValue(
                hue_shift_limit=20,
                sat_shift_limit=50,
                val_shift_limit=50)
        ], p=p),
        A.OneOf([
            A.IAAAdditiveGaussianNoise(),
            A.GaussNoise(),
        ], p=p),
        # Resize like training: the model is trained at img_size, and default_collate
        # can only stack same-sized images into a batch.
        A.Resize(img_size, img_size, always_apply=True),
        ToTensor(normalize=imagenet_stats),
    ])

    # validation: deterministic -- only resize (same reason as above) and normalize
    valid_tfms = A.Compose([
        A.Resize(img_size, img_size, always_apply=True),
        ToTensor(normalize=imagenet_stats),
    ])

    return train_tfms, valid_tfms, test_tfms

In [14]:
def _split_by_class(fnames, valid_pct, seed):
    '''Split paths into (train, valid) lists, holding out ~valid_pct of each parent folder.'''
    by_class = {}
    for f in fnames:
        by_class.setdefault(Path(f).parent.name, []).append(f)

    rng = random.Random(seed)
    trn, val = [], []
    for cls in sorted(by_class):
        # sort first so the shuffle (and thus the split) only depends on `seed`
        files = sorted(by_class[cls], key=str)
        rng.shuffle(files)
        n_val = int(round(len(files) * valid_pct))
        # always keep at least one training image per class
        n_val = max(0, min(n_val, len(files) - 1))
        val.extend(files[:n_val])
        trn.extend(files[n_val:])
    return trn, val


def get_data(train_tfms, valid_tfms, valid_pct=0.2, seed=42, fnames=None):
    '''Build train/validation DataLoaders from a held-out, per-class split.

    Previously the validation set was the full training set, so validation
    accuracy only measured fit on data the model had already trained on.

    Args:
        train_tfms: albumentations pipeline applied to training images.
        valid_tfms: albumentations pipeline applied to validation images.
        valid_pct: fraction of each class folder held out for validation.
        seed: seed for the shuffle that decides the split (reproducible).
        fnames: image paths to split; defaults to the global ``train_fnames``.

    Returns:
        (train_dl, valid_dl)
    '''
    fnames = train_fnames if fnames is None else fnames
    trn_fnames, val_fnames = _split_by_class(fnames, valid_pct, seed)

    train_ds = TreeDataset(trn_fnames, train_tfms)
    valid_ds = TreeDataset(val_fnames, valid_tfms)
    train_dl = DataLoader(dataset=train_ds, batch_size=bs, shuffle=True, num_workers=0, collate_fn=collate_fn)
    # validation order does not affect the metric; a fixed order keeps runs comparable
    valid_dl = DataLoader(dataset=valid_ds, batch_size=bs, shuffle=False, num_workers=0, collate_fn=collate_fn)
    return train_dl, valid_dl

## Set up our model

In [15]:
def get_model(model_name='efficientnet-b3', lr=1e-5, wd=0.01, freeze_backbone=False, opt_fn=torch.optim.AdamW, device=None):
    '''Create a TreeEfficientNet on ``device`` together with its optimizer.

    Args:
        model_name: EfficientNet variant to load.
        lr: learning rate passed to the optimizer.
        wd: weight decay passed to the optimizer.
        freeze_backbone: if True, backbone parameters get ``requires_grad=False``
            and are not given to the optimizer.
        opt_fn: optimizer class, called as ``opt_fn(params, lr=..., weight_decay=...)``;
            AdamW by default.
        device: target device; defaults to ``get_device()``.

    Returns:
        (model, optimizer)
    '''
    # 1. get device
    device = device if device else get_device()

    # 2. build the model, optionally freezing the pretrained backbone
    model = TreeEfficientNet(model_name=model_name, pool_type=F.adaptive_avg_pool2d)
    if freeze_backbone:
        for parameter in model.backbone.parameters():
            parameter.requires_grad = False

    # 3. move to the device BEFORE creating the optimizer, as the torch.optim docs advise
    model.to(device)

    # 4. optimizer over trainable parameters only, so it never tracks frozen weights
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = opt_fn(trainable, lr=lr, weight_decay=wd)

    return model, optimizer

In [16]:
def training_step(xb, yb, model, loss_fn, opt, device, scheduler):
    '''Run one optimisation step on a batch and return its loss as a Python float.

    ``yb`` is flattened to 1-D before computing the loss, and the LR scheduler
    is advanced once per batch, right after the optimizer step.
    '''
    inputs = xb.to(device)
    targets = yb.reshape(-1).to(device)

    batch_loss = loss_fn(model(inputs), targets)

    opt.zero_grad()
    batch_loss.backward()
    opt.step()
    scheduler.step()
    return batch_loss.item()

In [17]:
def validation_step(xb, yb, model, loss_fn, device):
    '''Evaluate one batch and return ``(loss value, class probabilities)``.

    The caller is responsible for ``torch.no_grad()`` and ``model.eval()``.
    Probabilities come from a softmax over the class dimension, so each row
    sums to 1. The old version used an element-wise sigmoid, which is meant
    for multi-label output, not this single-label multi-class task.
    '''
    xb, yb = xb.to(device), yb.reshape(-1).to(device)
    logits = model(xb)
    loss = loss_fn(logits, yb)

    probs = torch.softmax(logits, dim=1)

    return loss.item(), probs

In [18]:
# wrap training and validation into a single fit function (cosine-annealed learning rate)
def fit(epochs, train_dl, valid_dl, model, loss_fn, opt, device=None):
    '''Train ``model`` for ``epochs`` epochs, validating after each one.

    The learning rate follows ``CosineAnnealingLR`` over all training steps
    (the scheduler is stepped once per batch inside ``training_step``).
    Progress and a per-epoch table (mean train loss, mean valid loss,
    accuracy) are written through a fastprogress master bar.

    Args:
        epochs: number of passes over ``train_dl``.
        train_dl, valid_dl: DataLoaders yielding ``(images, targets)`` batches.
        model: the network to train (already on ``device``).
        loss_fn: callable ``loss_fn(logits, targets)``.
        opt: optimizer over the model's parameters.
        device: device for the batches; defaults to ``get_device()``.

    Returns:
        (model, val_accuracy_scores): the trained model and one validation
        accuracy per epoch (NaN for an epoch with no validation data).
    '''
    device = device if device else get_device()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, len(train_dl) * epochs)
    val_accuracy_scores = []

    mb = master_bar(range(epochs))
    mb.write(['epochs', 'train_loss', 'valid_loss', 'accuracy'], table=True)

    for epoch in mb:
        # --- training ---
        model.train()
        trn_loss = 0.
        for xb, yb in progress_bar(train_dl, parent=mb):
            trn_loss += training_step(xb, yb, model=model, loss_fn=loss_fn, opt=opt, device=device, scheduler=scheduler)
        trn_loss /= max(len(train_dl), 1)

        # --- validation ---
        # eval() disables dropout (it was left active before); no_grad() skips autograd.
        model.eval()
        val_loss = 0.
        batch_preds, batch_targets = [], []
        with torch.no_grad():
            for xb, yb in progress_bar(valid_dl, parent=mb):
                loss, out = validation_step(xb, yb, model=model, loss_fn=loss_fn, device=device)
                val_loss += loss
                # Collect per batch and concatenate afterwards. The old i*bs indexing used
                # the *current* batch size, so a shorter final batch landed in the wrong
                # rows; the old buffer also hard-coded 3 classes.
                batch_preds.append(out.cpu().numpy())
                batch_targets.append(yb.reshape(-1).cpu().numpy())
        # report the mean per-batch loss, consistent with trn_loss
        val_loss /= max(len(valid_dl), 1)

        if batch_preds:
            # argmax is unaffected by any monotonic per-row transform (softmax/sigmoid)
            preds = np.concatenate(batch_preds).argmax(axis=1)
            true = np.concatenate(batch_targets)
            accuracy = accuracy_score(true, preds)
        else:
            # no validation data: nothing to score
            accuracy = float('nan')
        val_accuracy_scores.append(accuracy)
        mb.write([epoch, f'{trn_loss:.6f}', f'{val_loss:.6f}', f'{accuracy:.6f}'], table=True)

    return model, val_accuracy_scores

## start training validation

In [19]:
# seed the Python, NumPy and PyTorch RNGs so runs are repeatable
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

train_tfms, valid_tfms, test_tfms = get_augmentations()
# get train and validation dataloaders
train_dl, valid_dl = get_data(train_tfms=train_tfms, valid_tfms=valid_tfms)
# get our loss func
loss_fn = FocalLoss(alpha=0.25, gamma=2)

model, opt = get_model(model_name='efficientnet-b3', lr=1e-5, wd=1e-2)

# Do NOT bind the result to `accuracy_score`: that would shadow
# sklearn.metrics.accuracy_score, which fit() calls, and break any re-run.
model, val_accuracy_scores = fit(epoch, train_dl, valid_dl, model, loss_fn, opt)

Loaded pretrained weights for efficientnet-b3


KeyboardInterrupt: 

In [None]:
print(f'Training finished, TOTAL epochs: {epoch}\nSaving model as :{model_name}')
# torch.save does not create missing folders, so make sure ./models exists first
models_dir = Path('./models')
models_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), models_dir / f'{model_name}.pth')
print('You can check your model in ./models')