In [1]:
import numpy as np
import pandas as pd
import json
import matplotlib.pyplot as plt
from PIL import Image

from pathlib import Path
import os
from tqdm import tqdm
import datetime
import pretrainedmodels
import timm

In [2]:
# enable IPython's autoreload extension for this kernel
%load_ext autoreload

In [3]:
# autoreload mode 2: modules are re-imported before each cell runs, so edits to
# the local helper modules imported below are picked up without a kernel restart
%autoreload 2

In [4]:
# dataset root: the competition folder sitting in the current working directory
data_path = Path.cwd() / "cassava-leaf-disease-classification"

In [5]:
from data_utils import DatasetConstructor, reshape_model
from experimenter import Experimenter
from loss import BiTemperedLogisticLoss

from models import replace_activations, MishActivation

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import time
import copy

import albumentations as A
from albumentations.pytorch import ToTensorV2
from efficientnet_pytorch import EfficientNet
from ranger.ranger2020 import Ranger
import pytorch_warmup as warmup

In [6]:
# --- Hyperparameters and runtime configuration ---

# side length (in pixels) of the square images fed to the network
img_size = 384
# how many target classes the dataset has
num_classes = 5

# two batch sizes are configured: `data_batch_size` is what the dataset
# constructor receives, and both values are later handed to the Experimenter
# NOTE(review): presumably small data batches are accumulated up to
# `training_batch_size` — confirm in experimenter.py
data_batch_size, training_batch_size = 4, 32

# upper bound on the number of training epochs
num_epochs = 25
# number of initial epochs with a constant learning rate before annealing kicks in
flatness = 4

# prefer the first CUDA device and fall back to the CPU when none is available
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [7]:
# list every architecture name known to timm that matches the '*resnext*' wildcard
timm.list_models('*resnext*')

['cspresnext50',
 'cspresnext50_iabn',
 'ecaresnext26tn_32x4d',
 'gluon_resnext50_32x4d',
 'gluon_resnext101_32x4d',
 'gluon_resnext101_64x4d',
 'gluon_seresnext50_32x4d',
 'gluon_seresnext101_32x4d',
 'gluon_seresnext101_64x4d',
 'ig_resnext101_32x8d',
 'ig_resnext101_32x16d',
 'ig_resnext101_32x32d',
 'ig_resnext101_32x48d',
 'legacy_seresnext26_32x4d',
 'legacy_seresnext50_32x4d',
 'legacy_seresnext101_32x4d',
 'resnext50_32x4d',
 'resnext50d_32x4d',
 'resnext101_32x4d',
 'resnext101_32x8d',
 'resnext101_64x4d',
 'seresnext26_32x4d',
 'seresnext26d_32x4d',
 'seresnext26t_32x4d',
 'seresnext26tn_32x4d',
 'seresnext50_32x4d',
 'seresnext101_32x4d',
 'seresnext101_32x8d',
 'skresnext50_32x4d',
 'ssl_resnext50_32x4d',
 'ssl_resnext101_32x4d',
 'ssl_resnext101_32x8d',
 'ssl_resnext101_32x16d',
 'swsl_resnext50_32x4d',
 'swsl_resnext101_32x4d',
 'swsl_resnext101_32x8d',
 'swsl_resnext101_32x16d',
 'tv_resnext50_32x4d']

In [8]:
# Load timm's pretrained 'seresnext50_32x4d' as the starting network
# (its classifier head is replaced in a later cell).
# Earlier alternative, kept for reference (plain ResNeXt50 through torch.hub):
#   torch.hub.load('pytorch/vision:v0.6.0', 'resnext50_32x4d', pretrained=True)
model = timm.create_model('seresnext50_32x4d', pretrained=True)

In [9]:
# display the module tree of the freshly loaded network
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): ReLU(inplace=True)
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (se): SEModule(
        (fc1): Conv2

In [10]:
from collections import OrderedDict

# width of the feature vector that feeds the existing classifier layer
features = model.fc.in_features

# new head: dropout followed by a linear layer producing one logit per class
model.fc = nn.Sequential(OrderedDict(
    dropout=nn.Dropout(0.5),
    linear=nn.Linear(features, num_classes, bias=True),
))

In [11]:
# Swap the activations stored under `act1` / `act2` — and the activation inside
# each squeeze-excite sub-module — for Mish.
# The traversal runs over a snapshot so that replacing children while walking
# the tree cannot disturb the iteration.
for module in list(model.modules()):
    for act_name in ('act1', 'act2'):
        if hasattr(module, act_name):
            setattr(module, act_name, MishActivation())
    # A block may carry an `se` attribute that is None (backbones without
    # squeeze-excite); guard it the same way the `downsample` handling in the
    # later cells does, so switching backbone does not crash this cell.
    if getattr(module, 'se', None) is not None:
        module.se.act = MishActivation()

In [12]:
from models import WSConv2d

In [13]:
def _to_ws_conv(conv):
    """Build a WSConv2d configured like ``conv``.

    Only the hyper-parameters are carried over (channels, kernel size, stride,
    padding, dilation, groups and whether a bias term exists).

    NOTE(review): no weights are copied here, so unless WSConv2d does that
    internally the replacement layers start from a fresh initialisation and the
    pretrained convolution weights are dropped — confirm this is intended.
    """
    return WSConv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        # Pass a real boolean flag. `conv.bias` is either None or a Parameter;
        # handing the Parameter over as the flag only worked because every conv
        # touched so far happens to be bias-free.
        # (assumes WSConv2d mirrors nn.Conv2d's positional signature — TODO confirm)
        conv.bias is not None,
    )


# Replace the convolutions reachable through conv1/conv2/conv3 and the first
# element of every non-empty downsample branch.
# A snapshot of the module list is used so that children can be swapped safely
# while walking the tree.
for module in list(model.modules()):
    for conv_name in ('conv1', 'conv2', 'conv3'):
        if hasattr(module, conv_name):
            setattr(module, conv_name, _to_ws_conv(getattr(module, conv_name)))
    if getattr(module, 'downsample', None) is not None:
        module.downsample[0] = _to_ws_conv(module.downsample[0])

In [14]:
# Replace every normalisation layer reachable through bn1/bn2/bn3 — and the
# second element of each non-empty downsample branch — with a 32-group
# GroupNorm over the same number of channels.
for layer in model.modules():
    for norm_attr in ('bn1', 'bn2', 'bn3'):
        if not hasattr(layer, norm_attr):
            continue
        width = getattr(layer, norm_attr).num_features
        setattr(layer, norm_attr, torch.nn.GroupNorm(num_groups=32, num_channels=width))
    shortcut = getattr(layer, 'downsample', None)
    if shortcut is not None:
        shortcut[1] = torch.nn.GroupNorm(num_groups=32, num_channels=shortcut[1].num_features)

In [13]:
# Re-inspect the network after the layer-replacement cells above.
# NOTE(review): the saved output below still lists Conv2d and BatchNorm2d layers,
# and this cell's execution count (13) collides with the WSConv2d cell's, so the
# WSConv2d / GroupNorm replacement cells were apparently not run in the saved
# session. Verify with Restart Kernel -> Run All.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): MishActivation()
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): MishActivation()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act2): MishActivation()
      (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (se): SEModule(
        (fc1): Conv2d(256,

In [14]:
class SnapMixNet(nn.Module):
    """Wrapper that returns both the logits and the pre-pooling feature maps.

    The wrapped ``model`` must expose ``global_pool`` and ``fc`` attributes, and
    its last two children are dropped when the backbone is assembled.
    """

    def __init__(self, model):
        super().__init__()
        # everything except the last two children forms the feature extractor
        layers = list(model.children())
        self.backbone = nn.Sequential(*layers[:-2])
        self.pool = model.global_pool
        self.fc = model.fc

    def forward_features(self, x):
        """Run only the backbone and return its output."""
        return self.backbone(x)

    def forward(self, x):
        """Return ``(logits, feature_maps)`` for the input batch ``x``."""
        feature_maps = self.forward_features(x)
        pooled = self.pool(feature_maps)
        # flatten to one row per sample before the classifier
        logits = self.fc(pooled.view(x.size(0), -1))
        return logits, feature_maps

In [15]:
# Wrap the network once. The isinstance guard keeps this cell safe to re-run:
# wrapping an already wrapped model would fail because SnapMixNet itself has no
# `global_pool` attribute.
if not isinstance(model, SnapMixNet):
    model = SnapMixNet(model)

In [16]:
# display the module tree of the SnapMixNet-wrapped network
model


SnapMixNet(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): MishActivation()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): MishActivation()
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): MishActivation()
        (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (s

In [17]:
from loss import TaylorCrossEntropyLoss  # alternative loss, see the note at the bottom of this cell

# move the network to the selected device before the optimiser collects its parameters
# (an earlier variant called reshape_model(model, num_classes) here instead)
model = model.to(device)

# Ranger hyper-parameters, kept in a dict because they are also passed on separately
# (previous choice: optim.Adam(model.parameters(), lr=2e-4))
optimiser_config = dict(lr=1e-4, weight_decay=1e-6)
optimiser = Ranger(model.parameters(), **optimiser_config)

# scheduler classes (not instances) together with the keyword arguments for the LR scheduler
warmup_scheduler = warmup.RAdamWarmup
lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts
lr_config = dict(T_0=2000, T_mult=2, eta_min=1e-6, last_epoch=-1)

# training criterion; TaylorCrossEntropyLoss() is the alternative that was tried
loss_function = BiTemperedLogisticLoss(t1=0.8, t2=1.4)

Ranger optimizer loaded. 
Gradient Centralization usage = True
GC applied to both conv and fc layers


In [18]:
# Albumentations pipelines, keyed by the split they are meant for.
transform = dict(
    # training: random resized crop to img_size, cheap geometric flips, normalise, to tensor
    train_transform=A.Compose(
        [
            A.RandomResizedCrop(img_size, img_size),
            A.Transpose(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # augmentations currently switched off:
            #   A.ShiftScaleRotate(p=0.5)
            #   A.HueSaturationValue(hue_shift_limit=0.2, sat_shift_limit=0.2, val_shift_limit=0.2, p=0.5)
            #   A.RandomBrightnessContrast(brightness_limit=(-0.1,0.1), contrast_limit=(-0.1, 0.1), p=0.5)
            #   A.CoarseDropout(p=0.5), A.Cutout(p=0.5), A.FancyPCA(), A.GaussNoise(), A.OpticalDistortion()
            A.Normalize(
                mean=[0.4303, 0.4967, 0.3134],
                std=[0.2330, 0.2359, 0.2237],
                max_pixel_value=255.0,
                p=1.0,
            ),
            ToTensorV2(p=1.0),
        ],
        p=1.,
    ),
    # evaluation: deterministic resize, same normalisation statistics, to tensor
    eval_transform=A.Compose(
        [
            A.Resize(img_size, img_size),
            A.Normalize(mean=(0.4303, 0.4967, 0.3134), std=(0.2330, 0.2359, 0.2237)),
            ToTensorV2(),
        ],
        p=1.,
    ),
)

In [19]:
# Build the dataset constructor over the "train_images" folder of the dataset,
# using the data-loading batch size and the train/eval transform dictionary.
# NOTE(review): `k=5` presumably selects 5-fold cross-validation — confirm in data_utils.
constructor = DatasetConstructor(
    data_path / "train_images", 
    data_batch_size, 
    k=5, 
    transform=transform
)

In [20]:
# SnapMix settings handed to the Experimenter
# NOTE(review): `prob` is presumably the per-batch chance of applying SnapMix and
# `alpha` the Beta-distribution parameter — confirm in experimenter.py
snapmix_config = dict(prob=0.5, alpha=5.0)

In [21]:
# Bundle the model, loss, optimiser, scheduler settings, dataset constructor and
# run-length settings into the Experimenter object whose train() is called next.
# NOTE(review): the run name is "wideresnet_test" although the backbone loaded
# earlier is timm's 'seresnext50_32x4d' — confirm the name is still intended.
# NOTE(review): `warmup_scheduler=None` is passed even though a `warmup_scheduler`
# variable (warmup.RAdamWarmup) is defined in the optimiser cell — confirm that
# warm-up is meant to be disabled.
experiment = Experimenter(
    model, 
    "wideresnet_test", 
    device, 
    loss_function, 
    optimiser, 
    optimiser_config, 
    lr_scheduler, 
    lr_config, 
    constructor,
    data_batch_size,
    training_batch_size,
    num_epochs, 
    num_classes,
    img_size,
    warmup_scheduler=None,
    flatness=flatness, 
    patience=5, 
    tta=10,
    snapmix=snapmix_config
    
)

one hot


In [None]:
# Start the training run. Long-running: the saved log shows roughly 21-22 minutes
# per training epoch plus about 2 minutes of validation, for up to 25 epochs.
experiment.train()

  0%|                                                                                      | 0/4280 [00:00<?, ?batch/s]

Epoch 1/25



100%|███████████████████████████████████████████████████████████████| 4280/4280 [21:33<00:00,  3.31batch/s, loss=0.251]
  0%|                                                                   | 1/1070 [00:00<02:21,  7.58batch/s, loss=1.15]

Training loss - Epoch 1/25: 1.1058278656847853
Training accuracy - Epoch 1/25: 0.48343751825670384


100%|████████████████████████████████████████████████████████████████| 1070/1070 [01:50<00:00,  9.71batch/s, loss=1.12]
  0%|                                                                                      | 0/4280 [00:00<?, ?batch/s]

Validation loss - Epoch 1/25: 0.6738356157998058
Validation accuracy - Epoch 1/25: 0.6691588759422302
Epoch 2/25



100%|████████████████████████████████████████████████████████████████| 4280/4280 [21:53<00:00,  3.26batch/s, loss=2.04]
  0%|                                                                   | 1/1070 [00:00<01:47,  9.90batch/s, loss=1.19]

Training loss - Epoch 2/25: 0.7211252148619265
Training accuracy - Epoch 2/25: 0.6604545188993398


100%|█████████████████████████████████████████████████████████████████| 1070/1070 [01:49<00:00,  9.75batch/s, loss=1.2]
  0%|                                                                                      | 0/4280 [00:00<?, ?batch/s]

Validation loss - Epoch 2/25: 0.521926538059575
Validation accuracy - Epoch 2/25: 0.7495327591896057
Epoch 3/25



 48%|███████████████████████████████▍                                 | 2066/4280 [10:37<10:47,  3.42batch/s, loss=1.6]