In [1]:
import pandas as pd
import numpy as np
import json
import gc
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
import pytorch_lightning as pl
import albumentations
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from albumentations.core.composition import Compose
from albumentations.pytorch import ToTensorV2
import torchvision
from typing import Any, Dict, List, Union, Optional
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.model_selection import StratifiedKFold
import cv2
import torch
from efficientnet_pytorch import EfficientNet
import numpy as np
import timm
from PIL import Image
from PIL import ImageFile
from collections import OrderedDict


In [2]:
import pytorch_lightning
pytorch_lightning.__version__

'1.1.0'

In [3]:
base_dir = r'C:\Users\Kaggle\Leaf_Classification\input'

In [4]:
### read csv and label mapping

labels = pd.read_csv(os.path.join(base_dir,'train.csv'))

with open(f'{base_dir}/label_num_to_disease_map.json') as f:
    label_mapper = json.load(f)

In [5]:
weights = (labels['label'].value_counts(normalize=True,ascending=False)).to_dict()

In [6]:
weights

{3: 0.6149460204701593,
 4: 0.12043744450156564,
 2: 0.11151095948030097,
 1: 0.10230406131700706,
 0: 0.05080151423096696}

#### Dataset definition

In [7]:
class ImageClassificationDataset(Dataset):
    def __init__(
        self,
        image_names: List,
        transforms: Compose,
        labels: Optional[List[int]],
        img_path: str = '',
        mode: str = 'train',
        labels_to_ohe: bool = False,
        n_classes: int = 5,
    ):
        """
        Image classification dataset.

        Args:
            df: dataframe with image id and bboxes
            mode: train/val/test
            img_path: path to images
            transforms: albumentations
        """

        self.mode = mode
        self.transforms = transforms
        self.img_path = img_path
        self.image_names = image_names
        if labels is not None:
            if not labels_to_ohe:
                self.labels = np.array(labels)
            else:
                self.labels = np.zeros((len(labels), n_classes))
                self.labels[np.arange(len(labels)), np.array(labels)] = 1

    def __getitem__(self, idx: int) -> Dict[str, np.array]:
        image_path = self.img_path + self.image_names[idx]
        image = cv2.imread(f'{image_path}', cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if image is None:
            raise FileNotFoundError(image_path)
        target = self.labels[idx]

        img = self.transforms(image=image)['image']
        sample = {'image_path': image_path, 'image': img, 'target': np.array(target).astype('int64')}

        return sample

    def __len__(self) -> int:
        return len(self.image_names)


### Define augmentations

In [8]:
sz = 512
sz

512

In [9]:
# augmentations taken from: https://www.kaggle.com/khyeh0719/pytorch-efficientnet-baseline-train-amp-aug
train_augs = albumentations.Compose([
            albumentations.RandomResizedCrop(sz, sz),
            albumentations.Transpose(p=0.5),
            albumentations.HorizontalFlip(p=0.5),
            albumentations.VerticalFlip(p=0.5),
            albumentations.ShiftScaleRotate(p=0.5),
            albumentations.HueSaturationValue(
                hue_shift_limit=0.2, 
                sat_shift_limit=0.2, 
                val_shift_limit=0.2, 
                p=0.5
            ),
            albumentations.RandomBrightnessContrast(
                brightness_limit=(-0.1,0.1), 
                contrast_limit=(-0.1, 0.1), 
                p=0.5
            ),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0
            ),
            albumentations.CoarseDropout(p=0.5),
            albumentations.Cutout(p=0.5), 
            ToTensorV2()],
            p=1.)
  
        
valid_augs = albumentations.Compose([
            albumentations.CenterCrop(sz, sz, p=1.),
            albumentations.Resize(sz, sz),
            albumentations.Normalize(
                mean=[0.485, 0.456, 0.406], 
                std=[0.229, 0.224, 0.225], 
                max_pixel_value=255.0, 
                p=1.0)
            ,ToTensorV2()], p=1.)

In [10]:
class conf:
    def __init__(self,size,prob,beta):
        super().__init__()
        self.cropsize = size
        self.prob = prob
        self.beta = beta
        
conf = conf(sz,1,5)
conf.prob

1

### Snapmix

In [11]:
def rand_bbox(size, lam,center=False,attcen=None):
    if len(size) == 4:
        W = size[2]
        H = size[3]
    elif len(size) == 3:
        W = size[1]
        H = size[2]
    elif len(size) == 2:
        W = size[0]
        H = size[1]
    else:
        raise Exception

    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    if attcen is None:
        # uniform
        cx = 0
        cy = 0
        if W>0 and H>0:
            cx = np.random.randint(W)
            cy = np.random.randint(H)
        if center:
            cx = int(W/2)
            cy = int(H/2)
    else:
        cx = attcen[0]
        cy = attcen[1]

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

def get_bbox(imgsize=(224,224),beta=1.0):

    r = np.random.rand(1)
    lam = np.random.beta(beta, beta)
    bbx1, bby1, bbx2, bby2 = rand_bbox(imgsize, lam)

    return [bbx1,bby1,bbx2,bby2]

In [12]:
import torch
import torch.nn as nn
import imp
import numpy as np
import os
import torch.nn.functional as F
import random
import copy


def get_spm(input,target,conf,model,classifier,classifier_weight,classifier_bias):

    imgsize = (conf.cropsize,conf.cropsize)
    bs = input.size(0)
    with torch.no_grad():
        output,fms,_,_,_,_ = model(input,target)
        clsw = classifier
        weight = classifier_weight
        bias = classifier_bias
        weight = weight.view(weight.size(0),weight.size(1),1,1)
        #fms = F.relu(fms)   ### this was already relu'ed from the conv net output
        poolfea = F.adaptive_avg_pool2d(fms,(1,1)).squeeze()
        clslogit = F.softmax(clsw.forward(poolfea))
        #print("logit shape",clslogit.shape)
        logitlist = []
        for i in range(bs):
            logitlist.append(clslogit[i,target[i]])
        clslogit = torch.stack(logitlist)

        out = F.conv2d(fms, weight, bias=bias)

        outmaps = []
        for i in range(bs):
            evimap = out[i,target[i]]
            outmaps.append(evimap)

        outmaps = torch.stack(outmaps)
        if imgsize is not None:
            outmaps = outmaps.view(outmaps.size(0),1,outmaps.size(1),outmaps.size(2))
            outmaps = F.interpolate(outmaps,imgsize,mode='bilinear',align_corners=False)

        outmaps = outmaps.squeeze()

        for i in range(bs):
            outmaps[i] -= outmaps[i].min()
            outmaps[i] /= outmaps[i].sum()
        
        del clsw,weight,bias,poolfea,out,evimap,logitlist,target,input
        gc.collect()

    
    return outmaps,clslogit



def snapmix(input,target,conf,model,classifier,classifier_weight,classifier_bias):

    r = np.random.rand(1)
    lam_a = torch.ones(input.size(0))
    lam_b = 1 - lam_a
    target_b = target.clone()

    if r < conf.prob:
        wfmaps,_ = get_spm(input,target,conf,model,classifier,classifier_weight,classifier_bias)
        bs = input.size(0)
        lam = np.random.beta(conf.beta, conf.beta)
        lam1 = np.random.beta(conf.beta, conf.beta)
        rand_index = torch.randperm(bs).cuda()
        wfmaps_b = wfmaps[rand_index,:,:]
        target_b = target[rand_index]

        same_label = target == target_b
        bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
        bbx1_1, bby1_1, bbx2_1, bby2_1 = rand_bbox(input.size(), lam1)

        area = (bby2-bby1)*(bbx2-bbx1)
        area1 = (bby2_1-bby1_1)*(bbx2_1-bbx1_1)

        if  area1 > 0 and  area>0:
            ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
            ncont = F.interpolate(ncont, size=(bbx2-bbx1,bby2-bby1), mode='bilinear', align_corners=True)
            input[:, :, bbx1:bbx2, bby1:bby2] = ncont
            lam_a = 1 - wfmaps[:,bbx1:bbx2,bby1:bby2].sum(2).sum(1)/(wfmaps.sum(2).sum(1)+1e-8)
            lam_b = wfmaps_b[:,bbx1_1:bbx2_1,bby1_1:bby2_1].sum(2).sum(1)/(wfmaps_b.sum(2).sum(1)+1e-8)
            tmp = lam_a.clone()
            lam_a[same_label] += lam_b[same_label]
            lam_b[same_label] += tmp[same_label]
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
            lam_a[torch.isnan(lam_a)] = lam
            lam_b[torch.isnan(lam_b)] = 1-lam

    return input,target,target_b,lam_a,lam_b


def as_cutmix(input,target,conf,model=None):

    r = np.random.rand(1)
    lam_a = torch.ones(input.size(0))
    lam_b = 1 - lam_a
    target_b = target.clone()

    if r < conf.prob:
        bs = input.size(0)
        lam = np.random.beta(conf.beta, conf.beta)
        rand_index = torch.randperm(bs).cuda()
        target_b = target[rand_index]

        bbx1, bby1, bbx2, bby2 = rand_bbox(input.size(), lam)
        bbx1_1, bby1_1, bbx2_1, bby2_1 = rand_bbox(input.size(), lam)

        if (bby2_1-bby1_1)*(bbx2_1-bbx1_1) > 4 and  (bby2-bby1)*(bbx2-bbx1)>4:
            ncont = input[rand_index, :, bbx1_1:bbx2_1, bby1_1:bby2_1].clone()
            ncont = F.interpolate(ncont, size=(bbx2-bbx1,bby2-bby1), mode='bilinear', align_corners=True)
            input[:, :, bbx1:bbx2, bby1:bby2] = ncont
            # adjust lambda to exactly match pixel ratio
            lam_a = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (input.size()[-1] * input.size()[-2]))
            lam_a *= torch.ones(input.size(0))
    lam_b = 1 - lam_a

    return input,target,target_b,lam_a.cuda(),lam_b.cuda()

### PL Data Module - https://www.kaggle.com/artgor/cassava-disease-identification-with-lightning

In [13]:
class CassavaDataModule(pl.LightningDataModule):
    def __init__(self,
                 df,
                 train_augs,
                 valid_augs,
                 path,
                bs=8,
                fold=0):
        super().__init__()
        self.df = df
        self.train_augs = train_augs
        self.valid_augs = valid_augs
        self.path = path
        self.bs = bs
        self.fold = fold

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        
        folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
        
        train_indexes, valid_indexes = list(folds.split(self.df, self.df['label']))[self.fold]
        
        train_df = self.df.iloc[train_indexes]
        valid_df = self.df.iloc[valid_indexes]

        
        self.train_dataset = ImageClassificationDataset(image_names=train_df['image_id'].values,
                                                        transforms=train_augs,
                                                        labels=train_df['label'].values,
                                                        img_path=self.path,
                                                        mode='train',
                                                        labels_to_ohe=False,
                                                        n_classes=5)
        self.valid_dataset = ImageClassificationDataset(image_names=valid_df['image_id'].values,
                                                        transforms=valid_augs,
                                                        labels=valid_df['label'].values,
                                                        img_path=self.path,
                                                        mode='valid',
                                                        labels_to_ohe=False,
                                                        n_classes=5)

    def train_dataloader(self):
        train_loader = torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.bs,
            num_workers=0,
            shuffle=True,
        )
        return train_loader

    def val_dataloader(self):
        valid_loader = torch.utils.data.DataLoader(
            self.valid_dataset,
            batch_size=self.bs,
            num_workers=0,
            shuffle=False,
        )

        return valid_loader

    def test_dataloader(self):
        return None

### Define Net - https://www.kaggle.com/artgor/cassava-disease-identification-with-lightning

In [14]:
timm.list_models()

['adv_inception_v3',
 'cspdarknet53',
 'cspdarknet53_iabn',
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 'cspresnext50_iabn',
 'darknet53',
 'densenet121',
 'densenet121d',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenet264',
 'densenet264d_iabn',
 'densenetblur121d',
 'dla34',
 'dla46_c',
 'dla46x_c',
 'dla60',
 'dla60_res2net',
 'dla60_res2next',
 'dla60x',
 'dla60x_c',
 'dla102',
 'dla102x',
 'dla102x2',
 'dla169',
 'dpn68',
 'dpn68b',
 'dpn92',
 'dpn98',
 'dpn107',
 'dpn131',
 'eca_vovnet39b',
 'ecaresnet18',
 'ecaresnet50',
 'ecaresnet50d',
 'ecaresnet50d_pruned',
 'ecaresnet101d',
 'ecaresnet101d_pruned',
 'ecaresnetlight',
 'ecaresnext26tn_32x4d',
 'efficientnet_b0',
 'efficientnet_b1',
 'efficientnet_b1_pruned',
 'efficientnet_b2',
 'efficientnet_b2_pruned',
 'efficientnet_b2a',
 'efficientnet_b3',
 'efficientnet_b3_pruned',
 'efficientnet_b3a',
 'efficientnet_b4',
 'efficientnet_b5',
 'efficientnet_b6',
 'efficientnet_b7',
 'efficientnet_b8',


In [15]:
class Net_upd(nn.Module):
    def __init__(self) -> None:
        """
        Model class.

        Args:
            cfg: main config
        """
        super().__init__()
        self.net = timm.create_model('resnet50', pretrained=True)
        output_dimension = list(self.net.children())[-1].in_features
        self.net.fc = nn.Linear(output_dimension , 5)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x, targets):
        lastConv = nn.Sequential(*list(self.net.children())[:-2])(x)
#         print("last conv shape",lastConv.shape)
        logits = self.net(x)
        loss = self.loss(logits, targets).view(1)
        return logits,lastConv,loss,self.net.fc,self.net.fc.weight.data.detach(),self.net.fc.bias.data.detach()

In [18]:
class LitCassava(pl.LightningModule):
    def __init__(self, model):
        super(LitCassava, self).__init__()
        self.model = model
        self.metric = pl.metrics.Accuracy()
        self.learning_rate = 1e-4

    def forward(self, x, targets, *args, **kwargs):
        return self.model(x, targets)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=0.001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=2)
        #scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,T_0 = 100,T_mult=2,eta_min=1e-6)

        return (
            [optimizer],
            [{'scheduler': scheduler, 'interval': 'epoch', 'monitor': 'valid_loss'}],
        )

    def training_step(
        self, batch: torch.Tensor, batch_idx: int
    ) -> Union[int, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]:
        image = batch['image']
        target = batch['target']
        logits,lastConv,loss,classifier,classifier_weight,classifier_bias = self(image, target)
        
#         print("weight",classifier_weight)
#         print("--------------------------------------------------------------------------------------")
#         print("bias",classifier_bias)
#         print("--------------------------------------------------------------------------------------")
        ### snapmix
        const = np.random.randint(10)
        if const<3:
            # snapmix
            #print("Snapmix envoked")
            mixed_input,target_a,target_b,lam_a,lam_b = snapmix(image,target,conf,model=self.model,
                                                                classifier=classifier,classifier_weight=classifier_weight,
                                                                classifier_bias=classifier_bias)
            _,_,loss_a,_,_,_ = self(mixed_input,target_a)
            _,_,loss_b,_,_,_ = self(mixed_input,target_b)
            loss = torch.mean(loss_a* lam_a + loss_b* lam_b)
            del mixed_input,target_a,target_b,lam_a,lam_b
            
        try:
            loss = loss[0]
            #rint("Is list",loss)
        except:
            pass
            #print("Not list",loss)
            #print("Snapmix loss")
        
        score = self.metric(logits.argmax(1), target)
        
        del lastConv,classifier_weight,classifier_bias
        gc.collect()
        
        logs = {'train_loss': loss, f'train_accuracy': score}

        return {
            'loss': loss,
            'log': logs,
            'progress_bar': logs,
            'logits': logits,
            'target': target,
            f'train_accuracy': score,
        }


    def training_epoch_end(self, outputs):
        #rint("Output shape:",outputs)
        #print([x['loss'] for x in outputs])
        _ = ([x['loss'] for x in outputs])
        avg_loss = torch.stack(_).mean()
        del _; gc.collect()
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        y_true = torch.cat([x['target'] for x in outputs])
        y_pred = torch.cat([x['logits'] for x in outputs])
        score = self.metric(y_pred.argmax(1), y_true)
        
        logs = {'train_loss': avg_loss, 'train_accuracy': score}
        return {'log': logs, 'progress_bar': logs}

    def validation_step(
        self, batch: torch.Tensor, batch_idx: int
    ) -> Union[int, Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]]:
        image = batch['image']
        target = batch['target']
        logits,lastConv, loss,_,_,_ = self(image, target)
        score = self.metric(logits.argmax(1), target)
        logs = {'valid_loss': loss, f'valid_accuracy': score}
        
        if isinstance(loss, list):
            #print("Here",loss)
            loss = loss[0]
                
        return {
            'loss': loss,
            'log': logs,
            'progress_bar': logs,
            'logits': logits,
            'target': target,
            f'valid_accuracy': score,
        }

    def validation_epoch_end(self, outputs):
        _ = ([x['loss'] for x in outputs])
        avg_loss = torch.stack(_).mean()
        del _; gc.collect()
        #avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        y_true = torch.cat([x['target'] for x in outputs])
        y_pred = torch.cat([x['logits'] for x in outputs])
        score = self.metric(y_pred.argmax(1), y_true)

        # score = torch.tensor(1.0, device=self.device)
        logs = {'valid_loss': avg_loss, f'valid_accuracy': score, 'accuracy': score}
        return {'valid_loss': avg_loss, 'log': logs, 'progress_bar': logs}

In [None]:
import time
for f in range(5):
    print(f"Running Fold:{f}")
    model = Net_upd()
    dm = CassavaDataModule(labels, train_augs, valid_augs, f'{base_dir}/train_images/',bs=8,fold=f)
    
    modelSavePath = f'C:/Users/Kaggle/Leaf_Classification/saved_models/{f}'
    if not os.path.exists(modelSavePath):
        os.makedirs(modelSavePath)
        
    trainer = pl.Trainer(
        checkpoint_callback=ModelCheckpoint(monitor='valid_accuracy',
                                            save_top_k=2, dirpath = modelSavePath,filename='{epoch}_{valid_loss:.4f}_{valid_accuracy:.4f}', mode='min'),
        accumulate_grad_batches = 1,
        gpus=1,
        max_epochs=20,
        num_sanity_val_steps=0,
        weights_summary='top',
#         precision = 16,
#         amp_backend = 'native',
        callbacks = [EarlyStopping(monitor='valid_accuracy', patience=4, mode='min')]
    )
    
    lit_model = LitCassava(model)
                              
    trainer.fit(lit_model, dm)

    del trainer,model,lit_model
    gc.collect()
    time.sleep(60*5)

Running Fold:0


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type     | Params
------------------------------------
0 | model  | Net_upd  | 23.5 M
1 | metric | Accuracy | 0     
------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

Please use self.log(...) inside the lightningModule instead.

# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
Please use self.log(...) inside the lightningModule instead.

# log on a step or aggregate epoch metric to the logger and/or progress bar
# (inside LightningModule)
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
