In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import sys
package_dir = '../input/early-stopping-pytorch'
sys.path.append(package_dir)
from pytorchtools import EarlyStopping
from PIL import Image
import numpy as np
import pandas as pd
from functools import partial
from sklearn import metrics
from sklearn.model_selection import KFold
from sklearn.metrics import cohen_kappa_score, confusion_matrix
import seaborn as sns
from collections import Counter, OrderedDict
from itertools import chain
import json
import math
import numbers
import time
import cv2
import gc
import torchvision
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm_notebook as tqdm
from torch.utils.data import Dataset
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.optim.optimizer import Optimizer, required
from torch.nn import functional as F
from torch.utils.data.sampler import SubsetRandomSampler
import albumentations
import matplotlib.pyplot as plt
import os
from pynverse import inversefunc

import random

IMG_SIZE = 256

# Restrict the visible GPUs *before* any CUDA call: CUDA reads this variable when it
# initialises, and device indices below (cuda:0, device_ids) refer to the visible set.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,7,8,9"
device_ids = [0, 1, 2, 3]
device = torch.device("cuda:0")

# To have reproducible results and compare them, seed every RNG in use.
seedValue = 2019
random.seed(seedValue)
np.random.seed(seedValue)
torch.manual_seed(seedValue)
# NOTE: setting PYTHONHASHSEED at runtime does not change hashing in this already
# running interpreter; it only affects processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(seedValue)
torch.cuda.manual_seed(seedValue)
torch.cuda.manual_seed_all(seedValue)
# Deterministic cuDNN reportedly slows training down, so the following two lines are
# left commented out (GPU results are therefore not bit-for-bit reproducible).
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

def quadratic_kappa(y_hat, y, coef, device=None):
    """Quadratic-weighted Cohen's kappa after thresholding predictions with ``coef``.

    Each prediction ``p`` becomes the number of thresholds ``<= p``:
    ``p < coef[0] -> 0``, ``coef[0] <= p < coef[1] -> 1``, ..., ``p >= coef[-1] -> len(coef)``.
    (The previous loop assigned to ``y_hat[1]`` instead of the current element, so the
    thresholds were never applied; it also mutated the caller's tensor.)

    Args:
        y_hat: predictions (tensor, array or sequence); not modified.
        y: ground-truth labels, same length as ``y_hat``.
        coef: ascending thresholds.
        device: device for the returned tensor. ``None`` means the module-level
            ``device`` if it is defined, otherwise CPU.

    Returns:
        0-dim tensor holding the kappa score.
    """
    def _as_numpy(values):
        if torch.is_tensor(values):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    preds = _as_numpy(y_hat).ravel()
    labels = _as_numpy(y).ravel()
    # np.digitize (right=False) maps bins[i-1] <= p < bins[i] to i, exactly the old
    # if/elif ladder; NaN falls in the last bucket just like the old `else` branch.
    classes = np.digitize(preds, np.asarray(coef, dtype=float))
    score = cohen_kappa_score(classes, labels, weights='quadratic')
    if device is None:
        # The parameter shadows the module-level `device`, so look it up explicitly.
        device = globals().get('device', 'cpu')
    return torch.tensor(score, device=device)

# confusion matrix
def plot_cmx(true, output, labels=(0, 1, 2, 3, 4)):
    """Draw the confusion matrix of ``true`` vs ``output`` as an annotated heatmap.

    Args:
        true: ground-truth class labels (any array-like; flattened).
        output: predicted class labels, same length as ``true``.
        labels: classes to show, in row/column order (default: grades 0-4).
    """
    labels = list(labels)
    # Rows are true classes, columns are predicted classes.
    cmx = confusion_matrix(np.ravel(true), np.ravel(output), labels=labels)
    fig, ax = plt.subplots(figsize=(6, 4))
    # fmt="d": seaborn's default ".2g" prints counts >= 100 in scientific notation.
    sns.heatmap(cmx, annot=True, fmt="d", xticklabels=labels, yticklabels=labels, ax=ax)
    ax.set(title="Confusion Matrix", xlabel="Predicted", ylabel="True")
    plt.show()
    
# calculate scale from ratio: `inv` is the numerical inverse of the function below.
def _area_ratio(x):
    """Forward map x -> ratio that `inv` inverts.

    Over the domain used below it goes from ~1.0 at x=0.7 to pi/4 at x=1.0.
    """
    numerator = np.arcsin(2 * x ** 2 - 1) + 2 * x * np.sqrt(1 - x ** 2)
    return numerator / (2 * x ** 2)


inv = inversefunc(_area_ratio, domain=[0.7, 1.0])

# Pre-processing

In [None]:
def crop_image1(img, tol=7):
    """Crop a 2-D image to the rows/columns that contain any value above ``tol``.

    ``tol`` is the tolerance: values ``<= tol`` are treated as background.
    """
    foreground = img > tol
    keep_rows = foreground.any(axis=1)
    keep_cols = foreground.any(axis=0)
    return img[np.ix_(keep_rows, keep_cols)]

def crop_image_from_gray(img, tol=7):
    """Crop away dark borders, keeping rows/columns with any pixel brighter than ``tol``.

    Args:
        img: 2-D grayscale image, or 3-D RGB image (H, W, C) with at least 3 channels.
        tol: tolerance; pixels whose (grayscale) value is ``<= tol`` count as background.

    Returns:
        The cropped image. For 3-D input only channels 0-2 are kept. If every pixel
        is background (image too dark), the input is returned unchanged, for both
        2-D and 3-D images (the 2-D path used to return an empty array).

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D (used to return None silently).
    """
    if img.ndim == 2:
        gray_img = img
    elif img.ndim == 3:
        gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    else:
        raise ValueError('expected a 2-D or 3-D image, got ndim={}'.format(img.ndim))

    mask = gray_img > tol
    rows, cols = mask.any(axis=1), mask.any(axis=0)
    if not rows.any():
        # Image is too dark, so cropping would remove everything: keep it as-is.
        return img

    if img.ndim == 2:
        return img[np.ix_(rows, cols)]
    # Crop the three colour planes in a single indexing operation.
    return np.ascontiguousarray(img[np.ix_(rows, cols, np.arange(3))])
    

def zoom_to_center(img, tol=7, th=0.95, p=1.0):
    """Crop dark borders, then zoom in when bright pixels cover too little of the frame.

    Args:
        img: RGB image (H, W, 3).
        tol: tolerance used both for cropping and for counting bright pixels
            (the count previously hard-coded 7 and ignored a non-default ``tol``).
        th: zoom only when the bright-pixel fraction is at most ``th``.
        p: probability passed to the albumentations ``ShiftScaleRotate``.

    Returns:
        The cropped, possibly shifted/scaled/rotated image.
    """
    img = crop_image_from_gray(img, tol=tol)
    gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # Ratio of object (pixels brighter than `tol`) to the whole cropped image.
    ratio = (gray_img > tol).sum() / gray_img.size
    if ratio <= th:
        # The larger `rest`, the more the image is zoomed. `inv` was built on the
        # domain x in [0.7, 1.0], which covers ratios from pi/4 upward; below pi/4
        # the fixed maximum sqrt(2) - 1 is used instead.
        rest = np.sqrt(2) * inv(ratio) - 1 if ratio >= np.pi / 4 else np.sqrt(2) - 1
        sl = max(rest - 0.125, 0)
        aug = albumentations.ShiftScaleRotate(shift_limit=0.01, scale_limit=(sl, sl + 0.05),
                                              rotate_limit=5, p=p)
        img = aug(image=img)['image']
    return img

In [None]:
class RetinopathyDataset(Dataset):
    """Fundus-image dataset backed by a CSV of image ids (plus labels for training data).

    ``datatype`` selects the image folder and CSV columns:
      - ``'train'``:     APTOS 2019 train images, id column ``id_code``, label ``diagnosis``
      - ``'train_old'``: resized 2015 images, id column ``image``, label ``level``
      - anything else:   APTOS 2019 test images, id column ``id_code``, no label

    Items are dicts with ``'image'`` (CHW tensor built from the transformed HWC array)
    and, for the labelled datatypes, ``'labels'``.
    """

    # datatype -> (image directory, id column, file extension, label column or None)
    _SOURCES = {
        'train': ('../input/aptos2019-blindness-detection/train_images', 'id_code', '.png', 'diagnosis'),
        'train_old': ('../input/diabetic-retinopathy-resized/resized_train', 'image', '.jpeg', 'level'),
        'test': ('../input/aptos2019-blindness-detection/test_images', 'id_code', '.png', None),
    }

    def __init__(self, csv_file, transform, datatype='train'):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.datatype = datatype

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Resolve the image path; unknown datatypes fall back to the unlabelled test set.
        image_dir, id_col, ext, label_col = self._SOURCES.get(self.datatype, self._SOURCES['test'])
        img_name = os.path.join(image_dir, self.data.loc[idx, id_col] + ext)

        # Image loading and preprocessing.
        img = cv2.imread(img_name)
        if img is None:
            # cv2.imread returns None for a missing/unreadable file instead of raising.
            raise FileNotFoundError('could not read image: {}'.format(img_name))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = zoom_to_center(img)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        img = self.transform(image=img)['image']
        img = torch.from_numpy(img).permute(2, 0, 1)

        if label_col is None:
            return {'image': img}
        return {'image': img, 'labels': self.data.loc[idx, label_col]}

In [None]:
# Pipeline used for every RetinopathyDataset built in this notebook: random
# horizontal/vertical flips followed by mean/std normalisation with ImageNet statistics.
_flip_augs = [albumentations.HorizontalFlip(), albumentations.VerticalFlip()]
_imagenet_norm = albumentations.Normalize(mean=(0.485, 0.456, 0.406),
                                          std=(0.229, 0.224, 0.225))
transform = albumentations.Compose(_flip_augs + [_imagenet_norm])

In [None]:
def RMSELoss(yhat, y):
    """Root-mean-squared error between predictions ``yhat`` and targets ``y``."""
    squared_error = (yhat - y) ** 2
    return squared_error.mean().sqrt()
    
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max pooling and adaptive average pooling along dim 1.

    The output has twice the input channels: max-pooled maps first, then averages.

    Args:
        sz: target spatial output size passed to both pooling layers.
    """

    def __init__(self, sz=1):
        super().__init__()
        self.output_size = sz
        self.ap = nn.AdaptiveAvgPool2d(sz)
        self.mp = nn.AdaptiveMaxPool2d(sz)

    def forward(self, x):
        pooled = (self.mp(x), self.ap(x))
        return torch.cat(pooled, dim=1)
    
class RAdam(Optimizer):
    """Rectified Adam optimizer.

    While the variance rectification term is not yet defined (``N_sma < 5``) the update
    is bias-corrected momentum without the adaptive denominator. After that it is an
    Adam step scaled by the rectification term.

    Args:
        params: iterable of parameters or param-group dicts.
        lr: learning rate.
        betas: coefficients for the running averages of the gradient and its square.
        eps: term added to the denominator for numerical stability.
        weight_decay: if non-zero, ``p -= weight_decay * lr * p`` is applied directly
            to the weights before the update.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        # Ring buffer of per-step quantities shared by all parameters:
        # [cache key (step, beta1, beta2), N_sma, lr-independent step factor].
        self.buffer = [[None, None, None] for _ in range(10)]
        super(RAdam, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(RAdam, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError('RAdam does not support sparse gradients')

                p_data_fp32 = p.data.float()

                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
                else:
                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                # Keyword value=/alpha= forms: the positional-scalar overloads are deprecated.
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                state['step'] += 1
                N_sma, step_factor = self._rectification(state['step'], beta1, beta2)
                # The cached factor excludes lr. Previously the cached step size
                # included the lr of whichever group computed it first, so other
                # param groups with a different lr reused the wrong step size.
                step_size = group['lr'] * step_factor

                if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])

                # more conservative since it's an approximated value
                if N_sma >= 5:
                    denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
                else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)

                p.data.copy_(p_data_fp32)

        return loss

    def _rectification(self, step, beta1, beta2):
        """Return ``(N_sma, step_factor)`` for ``step``, where ``step_size = lr * step_factor``.

        Results are cached in ``self.buffer`` keyed by ``(step, beta1, beta2)``, so they
        are computed once per step and beta setting rather than once per parameter.
        """
        buffered = self.buffer[int(step % 10)]
        key = (step, beta1, beta2)
        if buffered[0] == key:
            return buffered[1], buffered[2]

        beta2_t = beta2 ** step
        N_sma_max = 2 / (1 - beta2) - 1
        N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
        # more conservative since it's an approximated value
        if N_sma >= 5:
            step_factor = math.sqrt(
                (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma
                * N_sma_max / (N_sma_max - 2)
            ) / (1 - beta1 ** step)
        else:
            step_factor = 1.0 / (1 - beta1 ** step)
        buffered[0], buffered[1], buffered[2] = key, N_sma, step_factor
        return N_sma, step_factor

In [None]:
class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return gate * x

class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) dimension."""

    def forward(self, x):
        batch_size = x.shape[0]
        return x.reshape(batch_size, -1)
    
class SqueezeExcitation(nn.Module):
    """Squeeze-and-excitation: rescale each channel by a gate computed from its mean.

    Args:
        inplanes: number of input (and output) channels.
        se_planes: width of the bottleneck between the two 1x1 convolutions.
    """

    def __init__(self, inplanes, se_planes):
        super(SqueezeExcitation, self).__init__()
        squeeze = nn.Conv2d(inplanes, se_planes,
                            kernel_size=1, stride=1, padding=0, bias=True)
        excite = nn.Conv2d(se_planes, inplanes,
                           kernel_size=1, stride=1, padding=0, bias=True)
        self.reduce_expand = nn.Sequential(squeeze, Swish(), excite, nn.Sigmoid())

    def forward(self, x):
        # Average over the two trailing dims, keeping them as size-1 dims for broadcasting.
        squeezed = x.mean(dim=(-2, -1), keepdim=True)
        gate = self.reduce_expand(squeezed)
        return gate * x
    
class MBConv(nn.Module):
    """Mobile inverted-bottleneck block.

    Pipeline: 1x1 expansion (only when ``expand_rate > 1``) -> KxK depthwise conv ->
    squeeze-excitation -> 1x1 projection. An identity skip is added when ``stride == 1``
    and input/output shapes match. During training the residual branch of each sample
    is dropped with probability ``drop_connect_rate`` before the skip is added.
    """

    def __init__(self, inplanes, planes, kernel_size, stride,
                 expand_rate=1.0, se_rate=0.25,
                 drop_connect_rate=0.2):
        super(MBConv, self).__init__()

        expand_planes = int(inplanes * expand_rate)
        # SE bottleneck width is relative to the block's input channels.
        se_planes = max(1, int(inplanes * se_rate))

        self.expansion_conv = None
        if expand_rate > 1.0:
            self.expansion_conv = nn.Sequential(
                nn.Conv2d(inplanes, expand_planes,
                          kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(expand_planes, momentum=0.01, eps=1e-3),
                Swish()
            )
            inplanes = expand_planes

        # groups=expand_planes makes this a per-channel (depthwise) convolution.
        self.depthwise_conv = nn.Sequential(
            nn.Conv2d(inplanes, expand_planes,
                      kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2, groups=expand_planes,
                      bias=False),
            nn.BatchNorm2d(expand_planes, momentum=0.01, eps=1e-3),
            Swish()
        )

        self.squeeze_excitation = SqueezeExcitation(expand_planes, se_planes)

        self.project_conv = nn.Sequential(
            nn.Conv2d(expand_planes, planes,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(planes, momentum=0.01, eps=1e-3),
        )

        self.with_skip = stride == 1
        self.drop_connect_rate = torch.tensor(drop_connect_rate, requires_grad=False)

    def _drop_connect(self, x):
        """Zero each sample's branch with probability ``drop_connect_rate``.

        Kept samples are scaled by ``1 / keep_prob`` so the expected value is unchanged.
        """
        keep_prob = 1.0 - float(self.drop_connect_rate)
        # floor(U[0, 1) + keep_prob) is 1 with probability keep_prob, else 0.
        drop_mask = torch.rand(x.shape[0], 1, 1, 1, dtype=x.dtype, device=x.device) + keep_prob
        drop_mask.floor_()
        return drop_mask * x / keep_prob

    def forward(self, x):
        identity = x
        if self.expansion_conv is not None:
            x = self.expansion_conv(x)

        x = self.depthwise_conv(x)
        x = self.squeeze_excitation(x)
        x = self.project_conv(x)

        # Add identity skip
        if self.with_skip and x.shape == identity.shape:
            if (self.training and self.drop_connect_rate is not None
                    and float(self.drop_connect_rate) > 0):
                # The result used to be discarded, so drop connect never took effect.
                x = self._drop_connect(x)
            x = x + identity
        return x
    
def init_weights(module):
    """Re-initialise conv/linear weights in place; meant to be used with ``Module.apply``.

    ``nn.Conv2d``: Kaiming-normal, fan-out mode. ``nn.Linear``: uniform in
    ``[-1/sqrt(in_features), 1/sqrt(in_features)]``. Biases and other modules are untouched.
    """
    if isinstance(module, nn.Linear):
        bound = 1.0 / math.sqrt(module.weight.shape[1])
        nn.init.uniform_(module.weight, a=-bound, b=bound)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight, a=0, mode='fan_out')
        
class EfficientNet(nn.Module):
    """EfficientNet backbone with a linear classifier head.

    The per-stage baseline configuration in ``__init__`` (presumably EfficientNet-B0)
    is scaled by ``width_coefficient`` (channel counts) and ``depth_coefficient``
    (number of MBConv repeats per stage).

    Module registration order (stem, blocks, head) fixes the order of state_dict
    keys; load_EfficientNet pairs checkpoint tensors with these keys by position,
    so that order must not change.

    Args:
        num_classes: output width of the final ``nn.Linear`` (``head[6]``).
        width_coefficient: multiplier applied to every channel count.
        depth_coefficient: multiplier applied to every stage's repeat count.
        se_rate: squeeze-excitation ratio passed to each MBConv.
        dropout_rate: dropout probability before the final linear layer.
        drop_connect_rate: maximum drop-connect rate; the block with running index
            ``counter`` uses ``drop_connect_rate * counter / num_blocks``.
    """

    def _setup_repeats(self, num_repeats):
        # Scale a stage's repeat count by depth_coefficient, rounding up.
        return int(math.ceil(self.depth_coefficient * num_repeats))

    def _setup_channels(self, num_channels):
        # Scale a channel count by width_coefficient and round to the nearest multiple
        # of self.divisor (minimum one divisor). If rounding lost more than 10% of the
        # scaled width, add one more divisor.
        num_channels *= self.width_coefficient
        new_num_channels = math.floor(num_channels / self.divisor + 0.5) * self.divisor
        new_num_channels = max(self.divisor, new_num_channels)
        if new_num_channels < 0.9 * num_channels:
            new_num_channels += self.divisor
        return new_num_channels

    def __init__(self, num_classes,
                 width_coefficient=1.0,
                 depth_coefficient=1.0,
                 se_rate=0.25,
                 dropout_rate=0.2,
                 drop_connect_rate=0.2):
        super(EfficientNet, self).__init__()

        self.width_coefficient = width_coefficient
        self.depth_coefficient = depth_coefficient
        self.divisor = 8

        # Channels: stem output, the 7 stage outputs, then the head conv width.
        list_channels = [32, 16, 24, 40, 80, 112, 192, 320, 1280]
        list_channels = [self._setup_channels(c) for c in list_channels]

        # Number of MBConv blocks in each of the 7 stages.
        list_num_repeats = [1, 2, 2, 3, 3, 4, 1]
        list_num_repeats = [self._setup_repeats(r) for r in list_num_repeats]

        # Per-stage settings (index = stage).
        expand_rates = [1, 6, 6, 6, 6, 6, 6]
        strides = [1, 2, 2, 2, 1, 2, 1]
        kernel_sizes = [3, 3, 5, 3, 5, 5, 3]

        # Define stem:
        self.stem = nn.Sequential(
            nn.Conv2d(3, list_channels[0], kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(list_channels[0], momentum=0.01, eps=1e-3),
            Swish()
        )

        # Define MBConv blocks
        blocks = []
        counter = 0  # running block index over all stages; also used in block names
        num_blocks = sum(list_num_repeats)
        for idx in range(7):

            num_channels = list_channels[idx]
            next_num_channels = list_channels[idx + 1]
            num_repeats = list_num_repeats[idx]
            expand_rate = expand_rates[idx]
            kernel_size = kernel_sizes[idx]
            stride = strides[idx]
            # Drop-connect rate grows linearly with block depth.
            drop_rate = drop_connect_rate * counter / num_blocks

            # First block of the stage changes channels and applies the stage stride.
            name = "MBConv{}_{}".format(expand_rate, counter)
            blocks.append((
                name,
                MBConv(num_channels, next_num_channels,
                       kernel_size=kernel_size, stride=stride, expand_rate=expand_rate,
                       se_rate=se_rate, drop_connect_rate=drop_rate)
            ))
            counter += 1
            # Remaining repeats keep channels and use stride 1.
            for i in range(1, num_repeats):
                name = "MBConv{}_{}".format(expand_rate, counter)
                drop_rate = drop_connect_rate * counter / num_blocks
                blocks.append((
                    name,
                    MBConv(next_num_channels, next_num_channels,
                           kernel_size=kernel_size, stride=1, expand_rate=expand_rate,
                           se_rate=se_rate, drop_connect_rate=drop_rate)
                ))
                counter += 1

        self.blocks = nn.Sequential(OrderedDict(blocks))

        # Define head
        # Indices: 0 conv, 1 BN, 2 Swish, 3 pool, 4 Flatten, 5 Dropout, 6 Linear.
        self.head = nn.Sequential(
            nn.Conv2d(list_channels[-2], list_channels[-1],
                      kernel_size=1, bias=False),
            nn.BatchNorm2d(list_channels[-1], momentum=0.01, eps=1e-3),
            Swish(),
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(list_channels[-1], num_classes)
        )

        self.apply(init_weights)

    def forward(self, x):
        # stem -> MBConv blocks -> head (pooled features -> logits)
        f = self.stem(x)
        f = self.blocks(f)
        y = self.head(f)
        return y
    


In [None]:
def load_EfficientNet(weights_path="../data/model/efficientnet-pytorch/efficientnet-b4-e116e8b3.pth",
                      n_outputs=6):
    """Build an EfficientNet-B4, load pretrained weights and attach a new trainable head.

    Args:
        weights_path: checkpoint whose tensors are matched to the model's
            state_dict keys by position, not by name.
        n_outputs: width of the new final layer. The default 6 is 5 class logits
            plus 1 regression output, the layout MultiTaskLoss splits on.

    Returns:
        The model, with every parameter frozen except those of the new ``head[6]``.
    """
    model = EfficientNet(num_classes=1000,
                         width_coefficient=1.4, depth_coefficient=1.8,
                         dropout_rate=0.4)

    # Load on CPU regardless of where the checkpoint was saved; load_state_dict
    # copies the tensors into the model's own parameters anyway.
    model_state = torch.load(weights_path, map_location="cpu")

    # A basic remapping is required: key names differ, so pair keys by position.
    # NOTE: this assumes both state_dicts enumerate tensors in the same order.
    mapping = dict(zip(model_state.keys(), model.state_dict().keys()))
    mapped_model_state = OrderedDict((mapping[k], v) for k, v in model_state.items())
    model.load_state_dict(mapped_model_state, strict=False)

    # Replace the 1000-way classifier. Reuse its input width instead of
    # hard-coding 1792 (= 1280 * width_coefficient 1.4).
    in_features = model.head[6].in_features
    model.head[6] = nn.Sequential(
        nn.Linear(in_features=in_features, out_features=1024, bias=True),
        nn.ReLU(),
        nn.BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
        nn.Dropout(p=0.5),
        nn.Linear(in_features=1024, out_features=n_outputs, bias=True),
    )  # classification + kappa regressor

    # freeze the pretrained part, unfreeze the new head
    for param in model.parameters():
        param.requires_grad = False
    for param in model.head[6].parameters():
        param.requires_grad = True

    return model


In [None]:
from sklearn.utils import shuffle

df_train = pd.read_csv('../input/aptos2019-blindness-detection/train.csv')
df_test = pd.read_csv('../input/aptos2019-blindness-detection/test.csv')

# Shuffle ids and labels together (no random_state is passed, so this presumably
# draws from the numpy global RNG seeded in the first cell).
x, y = shuffle(df_train['id_code'], df_train['diagnosis'])
_ = y.hist()

# Class statistics: inverse-frequency weights, rescaled so that they sum to n_classes.
n_classes = int(y.max() + 1)
class_counts = df_train.groupby('diagnosis').count().values.ravel()
class_weights = len(y) / class_counts  # we can use this to balance our loss function
class_weights *= n_classes / class_weights.sum()
print('class_weights:', class_weights.tolist())

In [None]:
import torchvision.utils as vutils

# Set PLOT = True to preview a grid of preprocessed images from one dataset.
PLOT = False
if PLOT:
    data_dic = {
        'train': {'csv': '../input/aptos2019-blindness-detection/train.csv', 'datatype': 'train'},
        'test': {'csv': '../input/aptos2019-blindness-detection/test.csv', 'datatype': 'test'},
        '2015': {'csv': '../input/diabetic-retinopathy-resized/new_trainLabels.csv',
                 'datatype': 'train_old'},
    }
    select = '2015'
    chosen = data_dic[select]
    sample_dataset = RetinopathyDataset(csv_file=chosen['csv'],
                                        transform=transform,
                                        datatype=chosen['datatype'])

    loader = torch.utils.data.DataLoader(sample_dataset, batch_size=64,
                                         num_workers=8, drop_last=True)
    batch = next(iter(loader))

    # First 16 images as one CHW grid, converted to HWC for imshow.
    grid = vutils.make_grid(batch['image'][:16], padding=2, normalize=True)
    grid_hwc = grid.cpu().numpy().transpose((1, 2, 0))

    plt.figure(figsize=(16, 8))
    plt.axis("off")
    plt.title("Training Images")
    _ = plt.imshow(grid_hwc)

In [None]:
# Alternative output activations, kept for reference:
#   (n_classes-1) * torch.sigmoid(y)
#   (n_classes-1) * (0.5 + 0.5 * y / (1 + y.abs()))   # linear sigmoid
def activation(y):
    """No-op activation: returns its input unchanged."""
    return y

def cont_kappa(input, targets, activation=None):
    """Continuous (differentiable) version of quadratic weighted kappa.

    Computes ``1 - n * sum_i (p_i - t_i)^2 / sum_{i,j} (p_j - t_i)^2``, where ``p`` are
    the predictions (optionally passed through ``activation``) and ``t`` the targets.

    Adapted from the Keras version:
    https://www.kaggle.com/ryomiyazaki/keras-simple-implementation-of-qwk-for-regressor
    """
    n = len(targets)
    target_row = targets.float().unsqueeze(0)            # (1, n) for 1-D targets
    pred_row = input.float().squeeze(-1).unsqueeze(0)    # (n,) or (n, 1) input -> (1, n)
    if activation is not None:
        pred_row = activation(pred_row)
    observed = ((pred_row - target_row) ** 2).sum()
    # Broadcasting against the transposed targets covers every (target, prediction) pair.
    expected = ((pred_row - target_row.t()) ** 2).sum()
    return 1 - (n * observed / expected)

In [None]:
def kappa_loss(pred, y):
    """``1 - cont_kappa``: ranges from 0 to 2 instead of 1 to -1, so lower is better."""
    return 1 - cont_kappa(pred, y)

In [None]:
def smooth_one_hot(true_labels: torch.Tensor, classes=5, smoothing=0.1):
    """Label-smoothed one-hot encoding.

    The true class gets ``1 - smoothing``; ``smoothing`` is spread evenly over the
    other ``classes - 1`` classes. ``smoothing == 0`` gives plain one-hot.

    Args:
        true_labels: 1-D integer tensor of class indices.
        classes: number of classes.
        smoothing: probability mass moved off the true class, in ``[0, 1)``.

    Returns:
        Float tensor of shape ``(len(true_labels), classes)`` on the labels' device.
    """
    assert 0 <= smoothing < 1
    on_value = 1.0 - smoothing
    with torch.no_grad():
        off_value = smoothing / (classes - 1)
        smoothed = torch.full((true_labels.size(0), classes), off_value,
                              device=true_labels.device)
        smoothed.scatter_(1, true_labels.data.unsqueeze(1), on_value)
    return smoothed


# Soft-target table used by smooth_pattern: row i is the target distribution for
# true grade i. 0.7 stays on the true grade, and the remaining mass decreases with
# distance from it; each row sums to 1.
pattern = torch.tensor([
    [0.7,    0.15,   0.075,  0.05,   0.025],
    [0.1125, 0.7,    0.1125, 0.05,   0.025],
    [0.025,  0.125,  0.7,    0.125,  0.025],
    [0.025,  0.05,   0.1125, 0.7,    0.1125],
    [0.025,  0.05,   0.075,  0.15,   0.7],
], dtype=torch.float32)

def smooth_pattern(true_labels, pattern=pattern):
    """Map integer labels to soft targets by looking up rows of ``pattern``.

    Args:
        true_labels: 1-D integer tensor of class indices.
        pattern: (K, K) table whose row ``i`` is the soft target for class ``i``;
            defaults to the module-level 5x5 ``pattern``.

    Returns:
        Tensor of shape ``(len(true_labels), K)`` on ``true_labels``' device.
    """
    with torch.no_grad():
        table = pattern.to(true_labels.device)
        # Row lookup equals one_hot(labels) @ pattern. It works for any K (the old
        # code hard-coded 5) and skips the (B, K) x (K, K) product. index_select
        # rejects out-of-range indices, including negative ones.
        return table.index_select(0, true_labels.long())

class FocalLoss(nn.Module):
    """Focal loss built on BCE-with-logits against soft targets.

    Per sample: ``ce = mean_k BCE(inputs[:, k], soft_targets[:, k])`` and
    ``loss = alpha[target] * (1 - exp(-ce)) ** gamma * ce``.

    Args:
        alpha: optional per-class weights indexed by the integer target; ``None``
            means no weighting (previously ``torch.tensor(None)`` raised).
        gamma: focusing parameter; 0 gives (weighted) BCE.
        reduction: ``'mean'``, ``'sum'``; anything else returns per-sample losses.
        smoothing_fn: maps integer targets (B,) to soft targets (B, C). ``None``
            means ``smooth_pattern``.
    """

    def __init__(self, alpha=None, gamma=2., reduction='mean', smoothing_fn=None):
        super().__init__()
        self.alpha = None if alpha is None else torch.tensor(alpha)
        self.gamma = gamma
        self.reduction = reduction
        self.smoothing_fn = smoothing_fn

    def forward(self, inputs, targets):
        smoothing_fn = self.smoothing_fn if self.smoothing_fn is not None else smooth_pattern
        smooth_targets = smoothing_fn(targets)
        # One BCE value per sample. Previously BCE was averaged over the whole batch
        # first, so (1 - pt) ** gamma was a single batch-level factor and did not
        # down-weight individual easy samples.
        CE_loss = F.binary_cross_entropy_with_logits(
            inputs, smooth_targets, reduction='none').mean(dim=-1)
        pt = torch.exp(-CE_loss)
        F_loss = (1 - pt) ** self.gamma * CE_loss
        if self.alpha is not None:
            # Cast locally (type and device) instead of overwriting self.alpha.
            alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            F_loss = alpha[targets] * F_loss

        if self.reduction == 'sum':
            return F_loss.sum()
        elif self.reduction == 'mean':
            return F_loss.mean()
        return F_loss
    
# balance between metric optimisation and classification accuracy
class MultiTaskLoss(FocalLoss):
    """FocalLoss on the class logits plus a weighted auxiliary loss on the last output.

    ``inputs[..., :-1]`` are the class logits passed to FocalLoss. ``inputs[..., -1]``
    is compared with the float targets by ``second_loss``, and that term is weighted
    by ``second_mult``.
    """

    def __init__(self, alpha=None, gamma=2.0, second_loss=F.mse_loss, second_mult=0.1):
        super().__init__(alpha, gamma)
        self.second_loss = second_loss
        self.second_mult = second_mult

    def forward(self, inputs, targets):
        class_logits, regression = inputs[..., :-1], inputs[..., -1]
        focal = super().forward(class_logits, targets)  # focal loss
        auxiliary = self.second_loss(regression, targets.float())
        return focal + self.second_mult * auxiliary


# Training

In [None]:
class train(object):
    """Training driver configured by a ``params`` dict.

    Supports plain training, training with a separate validation set
    (``get_valid``), and K-fold cross-validation over the training set (``cv``).
    Recognised keys: lr, batch_size, n_epochs, n_freeze, coef, criterion,
    num_workers, load_state, load_path, save_path, device, n_folds, early_stop,
    patience, T_max. Missing keys become None.
    """

    # Learning-rate multipliers for (stem, blocks[:16], blocks[16:], head[:6], head[6]):
    # earlier, pretrained layers get smaller learning rates than the new head.
    LR_MULTS = (0.1, 0.15, 0.2, 0.3, 1.0)

    def __init__(self, params):
        # params
        self.lr = params.get('lr')
        self.bs = params.get('batch_size')
        self.n_epochs = params.get('n_epochs')
        self.n_freeze = params.get('n_freeze')
        self.coef = params.get('coef')
        self.criterion = params.get('criterion')
        self.num_workers = params.get('num_workers')
        self.load_state = params.get('load_state')
        self.load_path = params.get('load_path')
        self.save_path = params.get('save_path')
        self.device = params.get('device')
        self.n_folds = params.get('n_folds')
        self.use_valid = False
        self.early_stop = params.get("early_stop")
        self.patience = params.get("patience")
        self.T_max = params.get("T_max")

    def get_train(self, data):
        """Set the training dataset."""
        self.train_data = data

    def get_valid(self, data):
        """Set a separate validation dataset and enable validation in ``fit``."""
        self.valid_data = data
        self.use_valid = True

    def _param_groups(self, net):
        """Optimizer param groups with per-layer learning rates for an unwrapped EfficientNet."""
        parts = (net.stem, net.blocks[:16], net.blocks[16:], net.head[:6], net.head[6])
        return [{"params": list(part.parameters()), "lr": self.lr * mult}
                for part, mult in zip(parts, self.LR_MULTS)]

    def _build_model(self):
        """Create the model, optionally restore ``load_path``, wrap for multi-GPU, move to device."""
        model = load_EfficientNet()
        if self.load_state:
            model_state = torch.load(self.load_path)
            # Saved keys may differ (fit saves a DataParallel state_dict with a
            # 'module.' prefix), so checkpoint keys are matched by position.
            mapping = dict(zip(model_state.keys(), model.state_dict().keys()))
            mapped_model_state = OrderedDict((mapping[k], v) for k, v in model_state.items())
            model.load_state_dict(mapped_model_state, strict=False)
        if torch.cuda.device_count() > 1:  # train in parallel if multiple GPUs are available
            model = nn.DataParallel(model, device_ids)
        return model.to(device)

    def _make_loaders(self, use_cv, train_idx, valid_idx):
        """Return ``(train_loader, valid_loader)``; ``valid_loader`` is None without validation.

        Validation keeps the last partial batch (``drop_last=False``) so every
        validation sample is scored exactly once.
        """
        def loader(dataset, **kwargs):
            return torch.utils.data.DataLoader(dataset, batch_size=self.bs,
                                               num_workers=self.num_workers, **kwargs)

        if use_cv:  # split train_data into train and valid folds
            return (loader(self.train_data, sampler=SubsetRandomSampler(train_idx), drop_last=True),
                    loader(self.train_data, sampler=SubsetRandomSampler(valid_idx), drop_last=False))
        train_loader = loader(self.train_data, shuffle=True, drop_last=True)
        if self.use_valid:  # separate validation data (e.g. train: 2015, valid: 2019)
            return train_loader, loader(self.valid_data, shuffle=False, drop_last=False)
        return train_loader, None

    @staticmethod
    def _predicted_classes(outputs):
        """Argmax over the class logits (every output except the last) as a float CPU tensor.

        Softmax is monotonic, so this equals the argmax of the softmax probabilities.
        """
        return outputs.detach().cpu()[:, :-1].argmax(dim=1).float()

    def _train_epoch(self, model, loader, optimizer):
        """Run one training epoch; return (mean batch loss, mean batch kappa)."""
        model.train()
        running_loss = 0.0
        kappa = 0.0
        steps = 0
        with tqdm(loader, total=len(loader)) as pbar:
            for batch in pbar:
                inputs = batch["image"].to(self.device, dtype=torch.float)
                labels = batch["labels"].to(self.device, dtype=torch.long).view(-1)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.mean().item()
                # Per-batch kappa, only to monitor training.
                y_hat = self._predicted_classes(outputs)
                kappa += quadratic_kappa(y_hat, labels.cpu(), self.coef).mean().item()
                steps += 1
                pbar.set_postfix(OrderedDict(loss=running_loss / steps,
                                             kappa_score=kappa / steps))
        steps = max(steps, 1)
        return running_loss / steps, kappa / steps

    def _validate(self, model, loader):
        """Evaluate on ``loader``.

        Returns:
            (mean batch loss, kappa over all validation samples, true labels, predictions)
            with labels/predictions as 1-D numpy arrays.
        """
        model.eval()
        running_loss = 0.0
        steps = 0
        all_true, all_pred = [], []
        with torch.no_grad(), tqdm(loader, total=len(loader)) as pbar:
            for batch in pbar:
                inputs = batch["image"].to(self.device, dtype=torch.float)
                labels = batch["labels"].to(self.device, dtype=torch.long).view(-1)
                outputs = model(inputs)
                loss = self.criterion(outputs, labels)

                running_loss += loss.mean().item()
                all_true.append(labels.cpu())
                all_pred.append(self._predicted_classes(outputs))
                steps += 1
                pbar.set_postfix(OrderedDict(loss=running_loss / steps))

        true = torch.cat(all_true) if all_true else torch.zeros(0, dtype=torch.long)
        preds = torch.cat(all_pred) if all_pred else torch.zeros(0)
        # Kappa over the whole validation set rather than an average of batch scores.
        # Pass a copy so the predictions used for the confusion matrix stay intact.
        kappa = (quadratic_kappa(preds.clone(), true, self.coef).mean().item()
                 if len(true) else float('nan'))
        return running_loss / max(steps, 1), kappa, true.numpy(), preds.numpy()

    def fit(self, use_cv=False, train_idx=None, valid_idx=None):
        """Train a fresh model.

        Args:
            use_cv: validate on ``valid_idx`` of ``train_data`` (one CV fold).
            train_idx, valid_idx: fold index arrays, used when ``use_cv``.

        Returns:
            With validation: the best validation loss tracked by early stopping when
            ``early_stop`` is set, otherwise the last epoch's validation loss.
            Without validation: None. The final model is saved to ``save_path``
            unless ``use_cv``.
        """
        since = time.time()
        model = self._build_model()
        net = model.module if isinstance(model, nn.DataParallel) else model

        # Per-layer learning rates for single and multi GPU alike (previously `plist`
        # was only built for multi-GPU, so single-GPU runs hit a NameError).
        plist = self._param_groups(net)
        #optimizer = RAdam(plist,weight_decay=1e-4, lr=self.lr)
        optimizer = optim.Adam(plist, weight_decay=1e-4, lr=self.lr)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.T_max, eta_min=self.lr / 10)

        data_loader_train, data_loader_valid = self._make_loaders(use_cv, train_idx, valid_idx)
        early_stopping = EarlyStopping(patience=self.patience, verbose=True) if self.early_stop else None

        score = None
        for epoch in range(self.n_epochs):
            if epoch == self.n_freeze:  # unfreeze the backbone
                for param in model.parameters():
                    param.requires_grad = True

            print('Epoch {}/{}'.format(epoch, self.n_epochs - 1))
            print('-' * 10)

            epoch_loss, kappa = self._train_epoch(model, data_loader_train, optimizer)
            print('Training Loss: {:.4f}'.format(epoch_loss))
            print('Training Kappa: {:.4f}'.format(kappa))
            # Step the schedule after the epoch's optimizer steps (PyTorch >= 1.1 order);
            # stepping before training skipped the initial learning rate.
            scheduler.step()

            if data_loader_valid is None:
                continue

            epoch_loss, kappa, true, preds = self._validate(model, data_loader_valid)
            score = epoch_loss
            print('Validation Loss: {:.4f}'.format(epoch_loss))
            print('Validation Kappa: {:.4f}'.format(kappa))
            # plot Confusion Matrix
            plot_cmx(true.astype(int), preds.astype(int))

            if early_stopping is not None:
                early_stopping(epoch_loss, model)
                score = early_stopping.val_loss_min
                if early_stopping.early_stop:
                    print("Early stopping")
                    break

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

        # save model except when cv
        if not use_cv:
            torch.save(model.state_dict(), self.save_path)

        # Release GPU memory on every path, including the early-stop exit in CV.
        del model, net, optimizer, data_loader_train, data_loader_valid
        gc.collect()
        torch.cuda.empty_cache()
        return score

    def cv(self):  # cross validation
        """K-fold cross-validation over ``train_data``; prints the mean/std of fold scores."""
        indices = np.arange(len(self.train_data))
        kf = KFold(n_splits=self.n_folds, random_state=1337, shuffle=True)

        scores = np.zeros(self.n_folds)
        for fold, (train_idx, valid_idx) in enumerate(kf.split(indices)):
            print('Fold:', fold)
            score = self.fit(use_cv=True, train_idx=train_idx, valid_idx=valid_idx)
            scores[fold] = np.nan if score is None else score
        print('CV mean score: {0:.4f}, std: {1:.4f}.'.format(np.mean(scores), np.std(scores)))

In [None]:
# Default hyper-parameters shared by every training stage below. Each stage
# overwrites individual entries in place before constructing `train(params)`.
params = dict(
    lr=1e-4,
    batch_size=64,
    n_epochs=20,
    n_freeze=1,
    num_workers=8,
    coef=[0.5, 1.5, 2.5, 3.5],  # same cut points as the submission thresholds
    criterion=MultiTaskLoss(gamma=2., alpha=class_weights,
                            second_loss=kappa_loss, second_mult=0.5),
    load_state=False,
    load_path=None,
    save_path=None,
    device=device,
    n_folds=5,
    early_stop=False,
    patience=3,
    T_max=7,
)

# Training stages: enable the ones to run with the flags below.
# NOTE: every stage mutates the shared `params` dict in place, so a later stage
# inherits any setting an earlier *enabled* stage changed (e.g. `early_stop`,
# `patience`, `n_epochs`) unless it overrides that key itself.
RUN_TRAIN_OLD = True   # train on the old (diabetic-retinopathy-resized) data, validate on the new data
RUN_TRAIN_NEW = False  # fine-tune on the new (APTOS 2019) data from the stage-1 checkpoint
RUN_CV_NEW = False     # cross-validate on the new data from the stage-1 checkpoint

# train:old & valid:new
if RUN_TRAIN_OLD:
    params['n_freeze'] = 1
    params['n_epochs'] = 40
    params['early_stop'] = True
    params['patience'] = 10
    params['save_path'] = "ENb4_ori_cls_old.bin"
    Mytrain = train(params)
    Mytrain.get_train(RetinopathyDataset(csv_file='../input/diabetic-retinopathy-resized/balanced_trainLabels.csv',
                       transform = transform,
                       datatype='train_old'))
    Mytrain.get_valid(RetinopathyDataset(csv_file="../input/aptos2019-blindness-detection/train.csv",
                       transform = transform,
                       datatype='train'))
    Mytrain.fit()
    # Keep `checkpoint.pt` (presumably written by EarlyStopping during fit) under a
    # stage-specific name; the later stages load it via params['load_path'].
    # os.replace raises if the file is missing, whereas `!mv` only printed an error.
    os.replace('checkpoint.pt', 'ENb4_ori_cls.pt')

# train:new
if RUN_TRAIN_NEW:
    params['n_freeze'] = 1
    params['n_epochs'] = 6
    params['lr'] = 1e-4
    params['T_max'] = 6
    params['load_state'] = True
    params['load_path'] = "ENb4_ori_cls.pt"
    params['save_path'] = "ENb4_ori_cls_new.bin"
    Mytrain = train(params)
    Mytrain.get_train(RetinopathyDataset(csv_file="../input/aptos2019-blindness-detection/train.csv",
                       transform = transform,
                       datatype='train'))
    Mytrain.fit()

# cv: new
if RUN_CV_NEW:
    params['n_freeze'] = 1
    params['n_epochs'] = 10
    params['load_state'] = True
    params['load_path'] = "ENb4_ori_cls.pt"
    Mytrain = train(params)
    Mytrain.get_train(RetinopathyDataset(csv_file="../input/aptos2019-blindness-detection/train.csv",
                       transform = transform,
                       datatype='train'))
    Mytrain.cv()

# Inference

In [None]:
load_path = "ENb4_ori_cls_new.bin"

model = load_EfficientNet()
# Load the fine-tuned weights; map_location puts them on the inference device
# regardless of the GPU they were saved from.
model_state = torch.load(load_path, map_location=device)
# The checkpoint's key names do not match this model's, so keys are paired *by
# position*. That is only valid when both state dicts list the same number of
# tensors in the same order; check the count first so a mismatch fails loudly
# instead of silently leaving part of the model uninitialised.
model_keys = list(model.state_dict().keys())
if len(model_state) != len(model_keys):
    raise ValueError(
        'Checkpoint {} has {} tensors but the model expects {}; '
        'positional key remapping would be wrong.'.format(
            load_path, len(model_state), len(model_keys)))
mapped_model_state = OrderedDict(zip(model_keys, model_state.values()))
# With matching counts every model key is covered, so strict loading is safe and
# reports any remaining key or shape mismatch.
model.load_state_dict(mapped_model_state, strict=True)

model = model.to(device)

# The trailing ';' suppresses the very long module repr and keeps IPython from
# caching the model in Out[], which would keep it alive after a later `del model`.
model.eval();

In [None]:
class TestDataset(Dataset):
    """Unlabelled test images listed in a CSV with an ``id_code`` column.

    Each item is ``transform`` applied to the PIL image ``<img_dir>/<id_code>.png``.

    Args:
        csv_file: path to a CSV containing an ``id_code`` column.
        transform: callable applied to the opened PIL image. The default is the
            notebook-level ``transform`` as bound when this class is defined.
        img_dir: directory holding the ``.png`` test images. Defaults to the
            APTOS 2019 test-image folder, so existing callers are unaffected.
    """

    def __init__(self, csv_file, transform = transform,
                 img_dir='../input/aptos2019-blindness-detection/test_images'):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.img_dir = img_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.data.loc[idx, 'id_code'] + '.png')
        img = Image.open(img_name)
        # image preprocessing
        img = self.transform(img)

        return img

In [None]:
# Test images in submission order (shuffle=False keeps rows aligned with the CSV).
test_dataset = TestDataset(
    csv_file='../input/aptos2019-blindness-detection/sample_submission.csv',
    transform=transform,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=8,
)

In [None]:
PLOT = False  # set True to display a grid of the first test batch
if PLOT:
    batch = next(iter(test_loader))

    plt.figure(figsize=(16, 8))
    plt.axis("off")
    plt.title("Test Images")
    # Use the fully qualified torchvision helper so this cell does not depend on a
    # `vutils` alias that may never have been imported.
    grid = torchvision.utils.make_grid(batch[:16], padding=2, normalize=True)
    # make_grid returns a (C, H, W) tensor; imshow expects (H, W, C).
    _ = plt.imshow(grid.cpu().numpy().transpose((1, 2, 0)))

In [None]:
def test_prediction(output):
    # calc softmax for submission
    pred = F.softmax(output, dim=-1)
    _,pred = torch.max(pred,1)
    
    return pred

In [None]:
test_bs = 32
test_preds = np.zeros((len(test_dataset), 1))
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_bs, shuffle=False, num_workers=8)

# Test-time augmentation: each batch is scored as-is, mirrored left-right and
# mirrored top-bottom. Batches are assumed to be NCHW (torchvision image tensors),
# so dim 3 is width (horizontal flip) and dim 2 is height (vertical flip).
# Flipping dim 1 would only reverse the colour channels (RGB -> BGR).
TTA_FLIP_DIMS = [None, 3, 2]  # identity, horizontal flip, vertical flip

for i, x_batch in enumerate(tqdm(test_data_loader)):
    x_batch = x_batch.to(device)
    reg_outputs = []
    cls_outputs = []
    with torch.no_grad():
        for flip_dim in TTA_FLIP_DIMS:
            x_aug = x_batch if flip_dim is None else x_batch.flip(flip_dim)
            output = model(x_aug).cpu()
            # Last column is read as the regression output; the rest are class logits.
            reg_outputs.append(output[:, -1])
            cls_outputs.append(test_prediction(output[:, :-1]).float())

    # Average the regression outputs over the views, then blend with each view's
    # class prediction (every class prediction and the averaged regression weigh equally).
    pred_reg = torch.stack(reg_outputs).mean(dim=0)
    pred = (torch.stack(cls_outputs).sum(dim=0) + pred_reg) / (len(cls_outputs) + 1)

    test_preds[i * test_bs:(i + 1) * test_bs] = pred.numpy().reshape(-1, 1)

In [None]:
# Release the model: drop the Python reference, let the garbage collector reclaim
# anything kept alive only by reference cycles, and only then hand the now-unused
# cached CUDA blocks back to the driver (empty_cache cannot free memory that is
# still referenced, so running it before gc.collect() can miss tensors).
del model
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Cut points that map the blended continuous score to grades 0-4:
# score < 0.5 -> 0, [0.5, 1.5) -> 1, [1.5, 2.5) -> 2, [2.5, 3.5) -> 3, >= 3.5 -> 4.
coef = [0.5, 1.5, 2.5, 3.5]

# np.digitize returns, for each score, the number of cut points <= score, which is
# exactly the grade above (NaN lands in the last bucket, like the old `else`).
# Writing to a new array leaves `test_preds` untouched, so this cell is re-runnable.
test_labels = np.digitize(test_preds.ravel(), coef)

sample = pd.read_csv("../input/aptos2019-blindness-detection/sample_submission.csv")
sample["diagnosis"] = test_labels.astype(int)
sample.to_csv("submission.csv", index=False)

In [None]:
# How many test images fall into each predicted grade.
sample.diagnosis.value_counts()