In [None]:
import os
import sys
import yaml
import random
import numpy as np
import pandas as pd
from pathlib import Path
import skimage.io
import cv2
import glob

from sklearn.metrics import cohen_kappa_score

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.nn.modules.batchnorm import _BatchNorm
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
import albumentations as album

sys.path.append('/kaggle/input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master')
from efficientnet_pytorch import model as enet

## Utils

In [None]:
def seed_everything(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus every CUDA device), exports ``PYTHONHASHSEED`` and configures
    cuDNN for deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Hash randomisation of the running interpreter is fixed at start-up, so
    # this only influences processes launched afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Covers all visible GPUs, not only the current one; silently ignored
    # when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The autotuner may select different kernels from run to run, which
    # defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False

In [None]:
class EasyDict(dict):
    """
    Dictionary whose entries can also be read and written as attributes.

    Nested dicts (including dicts found inside lists or tuples) are converted
    to the same class on assignment; tuples are stored as lists. Mappings
    passed to the constructor or to ``update`` are never modified.

    Get attributes
    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively
    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1

    Bullet-proof
    >>> EasyDict({})
    {}
    >>> EasyDict(d={})
    {}
    >>> EasyDict(None)
    {}
    >>> d = {'a': 1}
    >>> EasyDict(**d)
    {'a': 1}

    Set attributes
    >>> d = EasyDict()
    >>> d.foo = 3
    >>> d.foo
    3
    >>> d.bar = {'prop': 'value'}
    >>> d.bar.prop
    'value'
    >>> d
    {'foo': 3, 'bar': {'prop': 'value'}}
    >>> d.bar.prop = 'newer'
    >>> d.bar.prop
    'newer'

    Values extraction
    >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]})
    >>> isinstance(d.bar, list)
    True
    >>> from operator import attrgetter
    >>> list(map(attrgetter('x'), d.bar))
    [1, 3]
    >>> list(map(attrgetter('y'), d.bar))
    [2, 4]
    >>> d = EasyDict()
    >>> list(d.keys())
    []
    >>> d = EasyDict(foo=3, bar=dict(x=1, y=2))
    >>> d.foo
    3
    >>> d.bar.x
    1

    Still like a dict though
    >>> o = EasyDict({'clean':True})
    >>> list(o.items())
    [('clean', True)]

    And like a class
    >>> class Flower(EasyDict):
    ...     power = 1
    ...
    >>> f = Flower()
    >>> f.power
    1
    >>> f = Flower({'height': 12})
    >>> f.height
    12
    >>> f['power']
    1
    >>> sorted(f.keys())
    ['height', 'power']

    update and pop items
    >>> d = EasyDict(a=1, b='2')
    >>> e = EasyDict(c=3.0, a=9.0)
    >>> d.update(e)
    >>> d.c
    3.0
    >>> d['c']
    3.0
    >>> d.get('c')
    3.0
    >>> d.update(a=4, b=4)
    >>> d.b
    4
    >>> d.pop('a')
    4
    >>> d.a
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'a'
    >>> d.pop('missing', 0)
    0

    Inputs are left untouched
    >>> src = {'a': 1}
    >>> EasyDict(src, b=2)
    {'a': 1, 'b': 2}
    >>> src
    {'a': 1}
    """
    def __init__(self, d=None, **kwargs):
        # Work on a private copy so merging kwargs never leaks into the
        # caller's mapping.
        d = dict(d) if d is not None else {}
        if kwargs:
            d.update(kwargs)
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes: mirror non-dunder class-level names (except the
        # overridden dict methods) into the instance so they appear as keys.
        for k in self.__class__.__dict__.keys():
            if not (k.startswith('__') and k.endswith('__')) and k not in ('update', 'pop'):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        # Convert nested plain dicts so attribute access works at any depth.
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x)
                     if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        # Keep the attribute view and the item view in sync.
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    __setitem__ = __setattr__

    def update(self, e=None, **f):
        """Merge ``e`` and keyword entries into this dict without mutating ``e``."""
        merged = dict(e) if e else {}
        merged.update(f)
        for k, v in merged.items():
            setattr(self, k, v)

    def pop(self, k, d=None):
        """Remove ``k`` and return its value, or ``d`` when ``k`` is absent."""
        # Only instance attributes can be deleted; a missing key must fall
        # through to the default instead of raising AttributeError.
        if k in self.__dict__:
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)

## Dataset

In [None]:
def get_transforms(cfg):
    """Build the augmentation pipeline described by ``cfg['transforms']``.

    Args:
        cfg: Mapping with a ``'transforms'`` entry. When that entry is truthy
            it must be a mapping whose values expose ``.name`` and ``.params``.

    Returns:
        An ``album.Compose`` of the configured transforms, or ``None`` when
        the ``'transforms'`` entry is empty or falsy.
    """
    specs = cfg['transforms']
    if not specs:
        return None

    def resolve(spec):
        # Prefer a transform shipped with albumentations.
        if hasattr(album, spec.name):
            return getattr(album, spec.name)
        # SECURITY NOTE(review): eval executes whatever string the config
        # provides; only use configs from a trusted source.
        return eval(spec.name)

    pipeline = []
    for spec in specs.values():
        pipeline.append(resolve(spec)(**spec.params))
    return album.Compose(pipeline)


class CustomDataset(Dataset):
    """Dataset that loads PNG images by id, inverts them and optionally augments them.

    Args:
        df: Table with an ``image_id`` column.
        labels: Indexable labels aligned with the rows of ``df``; only read
            when ``cfg.is_train`` is truthy.
        cfg: Config exposing ``is_train``, ``img_size.height``,
            ``img_size.width`` and a ``'transforms'`` entry.
        image_path: Directory containing the ``<image_id>.png`` files.
    """
    def __init__(self, df, labels, cfg, image_path='/kaggle/test_images'):
        self.cfg = cfg
        self.image_ids = df['image_id'].values
        self.labels = labels
        self.transforms = get_transforms(self.cfg)
        self.is_train = cfg.is_train
        self.image_path = image_path

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return the channels-first float32 image (and its label when training).

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        image_id = self.image_ids[idx]
        file_path = f'{self.image_path}/{image_id}.png'
        image = cv2.imread(file_path)
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f'Could not read image: {file_path}')

        # Stretch intensities to the full 0-255 range, then invert. An
        # all-black image has no range to stretch, so skip the division.
        peak = image.max()
        scale = 255.0 / float(peak) if peak > 0 else 1.0
        image = 255 - (image * scale).astype(np.uint8)
        # NOTE(review): cv2.resize expects dsize as (width, height); the order
        # used here only matters for non-square sizes. Confirm against the
        # training pipeline before changing it.
        image = cv2.resize(image, dsize=(self.cfg.img_size.height, self.cfg.img_size.width))
        if self.transforms:
            image = self.transforms(image=image)['image']
        # Channels-last -> channels-first.
        image = image.transpose(2, 0, 1).astype(np.float32)

        if self.is_train:
            label = self.labels[idx]
            return image, label
        else:
            return image

## Layers

In [None]:
class AvgPool(nn.Module):
    """Average pooling whose kernel spans every dimension after the first two."""

    def forward(self, x):
        # Kernel equals the full extent of the trailing dims, so each of them
        # collapses to size 1.
        spatial_size = tuple(x.shape[2:])
        return F.avg_pool2d(x, kernel_size=spatial_size)


class MaxPool(nn.Module):
    """Max pooling whose kernel spans every dimension after the first two."""

    def forward(self, x):
        # Kernel equals the full extent of the trailing dims, so each of them
        # collapses to size 1.
        spatial_size = tuple(x.shape[2:])
        return F.max_pool2d(x, kernel_size=spatial_size)


class AdaptiveConcatPool2d(nn.Module):
    """Concatenate adaptive max pooling and adaptive average pooling along dim 1.

    Args:
        sz: Output size handed to both pooling layers; ``(1, 1)`` when falsy.
    """

    def __init__(self, sz=None):
        super().__init__()
        output_size = sz if sz else (1, 1)
        self.ap = nn.AdaptiveAvgPool2d(output_size)
        self.mp = nn.AdaptiveMaxPool2d(output_size)

    def forward(self, x):
        # Max-pooled features come first, average-pooled features second.
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)


# https://www.kaggle.com/c/bengaliai-cv19/discussion/123432
def gem(x, p=3, eps=1e-6):
    """Generalized-mean pooling over the last two dimensions of ``x``.

    Args:
        x: Input tensor; values below ``eps`` are clamped up to ``eps``.
        p: Exponent applied before averaging (its inverse is applied after).
        eps: Lower clamp bound.

    Returns:
        The pooled tensor, with the last two dimensions reduced to size 1.
    """
    window = (x.size(-2), x.size(-1))
    powered = x.clamp(min=eps).pow(p)
    return F.avg_pool2d(powered, window).pow(1. / p)


class GeM(nn.Module):
    """Generalized-mean pooling layer with a learnable exponent.

    Args:
        p: Initial value of the exponent parameter.
        eps: Lower clamp bound forwarded to ``gem``.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Single-element parameter so the exponent is trained with the model.
        self.p = Parameter(torch.ones(1) * p)

    def forward(self, x):
        return gem(x, p=self.p, eps=self.eps)

    def __repr__(self):
        exponent = self.p.data.tolist()[0]
        return f'{self.__class__.__name__}(p={exponent:.4f}, eps={self.eps})'
    
    
# Maps the pooling-layer name used in the config to its class; every key is
# the class's own name.
layer_encoder = {
    layer_cls.__name__: layer_cls
    for layer_cls in (AvgPool, MaxPool, AdaptiveConcatPool2d, GeM)
}

## Model

In [None]:
def _efficientnet(model_name, pretrained):
    """Create an EfficientNet through the bundled ``efficientnet_pytorch`` package.

    Args:
        model_name: Architecture identifier handed to the package.
        pretrained: When truthy, build via ``from_pretrained`` with
            ``advprop=True``; otherwise build via ``from_name``.

    Returns:
        The constructed model.
    """
    if not pretrained:
        return enet.EfficientNet.from_name(model_name)
    return enet.EfficientNet.from_pretrained(model_name, advprop=True)


def efficientnet_b0(model_name='efficientnet-b0', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B0 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b1(model_name='efficientnet-b1', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B1 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b2(model_name='efficientnet-b2', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B2 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b3(model_name='efficientnet-b3', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B3 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b4(model_name='efficientnet-b4', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B4 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b5(model_name='efficientnet-b5', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B5 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b6(model_name='efficientnet-b6', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B6 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)


def efficientnet_b7(model_name='efficientnet-b7', pretrained=False):
    """Build an EfficientNet through ``_efficientnet``; the default name is the B7 variant."""
    return _efficientnet(model_name=model_name, pretrained=pretrained)

## Factory

In [None]:
# Maps the model name used in the config to its builder; every key is the
# builder function's own name.
model_encoder = {
    builder.__name__: builder
    for builder in (
        efficientnet_b0,
        efficientnet_b1,
        efficientnet_b2,
        efficientnet_b3,
        efficientnet_b4,
        efficientnet_b5,
        efficientnet_b6,
        efficientnet_b7,
    )
}


def set_channels(child, cfg):
    """Adapt a convolution so it accepts ``cfg.model.n_channels`` input channels.

    The existing kernels are sliced (fewer channels) or repeated cyclically
    (more channels) along the input-channel axis, and ``in_channels`` is
    updated to match. The weight is always replaced, regardless of
    ``cfg.model.pretrained``, so the module's attribute and its weight shape
    never disagree.

    Args:
        child: Module exposing a 4-D ``weight`` and an ``in_channels``
            attribute. Assumes ``groups == 1`` - TODO confirm for every
            supported backbone.
        cfg: Config exposing ``model.n_channels``.

    Raises:
        ValueError: If ``cfg.model.n_channels`` is smaller than 1.
    """
    n_channels = int(cfg.model.n_channels)
    if n_channels < 1:
        raise ValueError(f'n_channels must be >= 1, got {n_channels}')

    weight = child.weight.data
    base_channels = weight.shape[1]
    if n_channels <= base_channels:
        new_weight = weight[:, :n_channels, :, :]
    else:
        # Ceiling division: tile often enough to cover any channel count
        # (a single concatenation can at most double the channels).
        repeats = -(-n_channels // base_channels)
        new_weight = weight.repeat(1, repeats, 1, 1)[:, :n_channels, :, :]

    child.in_channels = n_channels
    # Assign through .data so the Parameter object itself stays the same.
    child.weight.data = new_weight.contiguous()


def replace_channels(model, cfg):
    """Locate the first convolution of ``model`` and pass it to ``set_channels``.

    The layer is selected from the prefix of ``cfg.model.name``. Models whose
    name matches none of the known prefixes are left untouched.

    Args:
        model: Backbone whose input convolution should be adapted.
        cfg: Config exposing ``model.name`` (plus whatever ``set_channels`` reads).
    """
    name = cfg.model.name
    if name.startswith('densenet'):
        stem = model.features[0]
    elif name.startswith('efficientnet'):
        stem = model._conv_stem
    elif name.startswith('mobilenet'):
        stem = model.features[0][0]
    elif name.startswith('se_resnext'):
        stem = model.layer0.conv1
    elif name.startswith(('resnet', 'resnex', 'wide_resnet')):
        stem = model.conv1
    elif name.startswith('resnest'):
        stem = model.conv1[0]
    else:
        # Unknown architecture family: nothing to adapt.
        return
    set_channels(stem, cfg)


def replace_fc(model, cfg):
    """Swap the classification head of ``model`` for a fresh ``nn.Linear``.

    The head is selected from the prefix of ``cfg.model.name``; the new layer
    keeps the old head's ``in_features``. Unknown names leave the model as is.

    Args:
        model: Backbone whose head should be replaced.
        cfg: Config exposing ``model.name``, ``model.metric`` and (when
            ``metric`` is falsy) ``model.n_classes``.

    Returns:
        The same ``model`` object.
    """
    # With ``metric`` enabled the head keeps a fixed 1000-dimensional output.
    classes = 1000 if cfg.model.metric else cfg.model.n_classes
    name = cfg.model.name

    if name.startswith('densenet'):
        model.classifier = nn.Linear(model.classifier.in_features, classes)
    elif name.startswith('efficientnet'):
        model._fc = nn.Linear(model._fc.in_features, classes)
    elif name.startswith('mobilenet'):
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, classes)
    elif name.startswith('se_resnext'):
        model.last_linear = nn.Linear(model.last_linear.in_features, classes)
    elif name.startswith(('resnet', 'resnex', 'wide_resnet', 'resnest')):
        model.fc = nn.Linear(model.fc.in_features, classes)
    return model


def replace_pool(model, cfg):
    """Install the pooling layer described by ``cfg.model.avgpool`` on ``model``.

    The layer class is looked up in ``layer_encoder`` and instantiated with
    the configured params; the attribute it is assigned to depends on the
    prefix of ``cfg.model.name``. Unknown names leave the model as is.

    Args:
        model: Backbone whose pooling layer should be replaced.
        cfg: Config exposing ``model.name``, ``model.avgpool.name`` and
            ``model.avgpool.params``.

    Returns:
        The same ``model`` object.
    """
    pool_cfg = cfg.model.avgpool
    pooling = layer_encoder[pool_cfg.name](**pool_cfg.params)

    name = cfg.model.name
    if name.startswith('efficientnet'):
        model._avg_pooling = pooling
    elif name.startswith('se_resnext'):
        model.avg_pool = pooling
    elif name.startswith(('resnet', 'resnex', 'wide_resnet', 'resnest')):
        model.avgpool = pooling
    return model


def get_model(cfg):
    """Build the network described by ``cfg.model``.

    The backbone is created without pretrained weights, its input convolution
    is adapted when ``cfg.model.n_channels`` differs from 3, its head is
    replaced, and its pooling layer is replaced when ``cfg.model.avgpool`` is
    truthy.

    Args:
        cfg: Config exposing ``model.name``, ``model.n_channels`` and
            ``model.avgpool`` (plus whatever the helper functions read).

    Returns:
        The assembled model.
    """
    build = model_encoder[cfg.model.name]
    model = build(pretrained=False)

    if cfg.model.n_channels != 3:
        replace_channels(model, cfg)

    model = replace_fc(model, cfg)
    return replace_pool(model, cfg) if cfg.model.avgpool else model


def get_loss(cfg):
    """Instantiate the ``torch.nn`` loss named by ``cfg.loss.name``.

    Args:
        cfg: Config exposing ``loss.name`` (a class name in ``torch.nn``) and
            ``loss.params`` (keyword arguments, or ``None`` for none).

    Returns:
        The constructed loss module.
    """
    loss_cls = getattr(nn, cfg.loss.name)
    # An empty ``params:`` entry in YAML loads as None; treat it as "no kwargs".
    return loss_cls(**(cfg.loss.params or {}))


def get_dataloader(df, labels, cfg):
    """Wrap a ``CustomDataset`` in a ``DataLoader``.

    Args:
        df: Table forwarded to ``CustomDataset``.
        labels: Labels forwarded to ``CustomDataset``.
        cfg: Config forwarded to ``CustomDataset``; its ``loader`` mapping is
            expanded into the ``DataLoader`` keyword arguments.

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(CustomDataset(df, labels, cfg), **cfg.loader)


def get_optim(cfg, parameters):
    """Instantiate the ``torch.optim`` optimizer named by ``cfg.optimizer.name``.

    Args:
        cfg: Config exposing ``optimizer.name`` (a class name in
            ``torch.optim``) and ``optimizer.params`` (keyword arguments, or
            ``None`` for none).
        parameters: Parameters handed to the optimizer.

    Returns:
        The constructed optimizer.
    """
    optim_cls = getattr(torch.optim, cfg.optimizer.name)
    # An empty ``params:`` entry in YAML loads as None; treat it as "no kwargs".
    return optim_cls(params=parameters, **(cfg.optimizer.params or {}))


def get_scheduler(cfg, optimizer):
    """Instantiate the LR scheduler named by ``cfg.scheduler.name``.

    Args:
        cfg: Config exposing ``scheduler.name`` (a class name in
            ``torch.optim.lr_scheduler``) and ``scheduler.params`` (keyword
            arguments, or ``None`` for none).
        optimizer: Optimizer the scheduler controls.

    Returns:
        The constructed scheduler.
    """
    # ReduceLROnPlateau lives in the same module and takes the same call
    # shape, so one lookup covers every scheduler.
    scheduler_cls = getattr(torch.optim.lr_scheduler, cfg.scheduler.name)
    # An empty ``params:`` entry in YAML loads as None; treat it as "no kwargs".
    return scheduler_cls(optimizer, **(cfg.scheduler.params or {}))

## Metrics

In [None]:
def quadratic_weighted_kappa(y_hat, y):
    """Cohen's kappa between predictions and targets with quadratic weighting.

    Args:
        y_hat: Predicted labels.
        y: Reference labels.

    Returns:
        The score computed by ``cohen_kappa_score``.
    """
    return cohen_kappa_score(y1=y_hat, y2=y, weights='quadratic')

## Resize

In [None]:
def resize(df, data_dir, save_dir, size=(512, 512)):
    """Convert every TIFF listed in ``df`` into a fixed-size PNG.

    Does nothing when ``data_dir`` does not exist, so the notebook still runs
    in environments where the input images are not mounted.

    Args:
        df: Table with an ``image_id`` column.
        data_dir: Directory containing ``<image_id>.tiff`` files.
        save_dir: Output directory for ``<image_id>.png``; created on demand.
        size: Target size handed to ``cv2.resize``.

    Raises:
        IOError: If a PNG could not be written.
    """
    if not os.path.exists(data_dir):
        return

    os.makedirs(save_dir, exist_ok=True)
    target_size = tuple(size)

    for img_id in df['image_id']:
        load_path = f'{data_dir}/{img_id}.tiff'
        save_path = f'{save_dir}/{img_id}.png'

        biopsy = skimage.io.MultiImage(load_path)
        # Uses the last entry of the multi-image - presumably the
        # lowest-resolution level of the slide; TODO confirm.
        img = cv2.resize(biopsy[-1], target_size)
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(save_path, img):
            raise IOError(f'Could not write image: {save_path}')

## Submit

In [None]:
def submit(sample, data_dir, log_path):
    """Run inference on the test set and write predictions into ``sample``.

    When ``data_dir`` does not exist, ``sample`` is returned unchanged.
    Reads the module-level ``test_df`` and ``cfg`` defined in the Main section.

    Args:
        sample: Submission table; its ``isup_grade`` column is overwritten
            in place with the predicted class indices.
        data_dir: Directory with the raw test images; only its existence is
            checked here.
        log_path: ``Path`` of the run directory holding ``weight_best.pt``.

    Returns:
        The same ``sample`` object.
    """
    if not os.path.exists(data_dir):
        return sample

    print('run inference')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    test_loader = get_dataloader(test_df, labels=None, cfg=cfg.data.test)
    model = get_model(cfg).to(device)
    # map_location lets weights saved on a GPU load on a CPU-only machine,
    # matching the device fallback above.
    state = torch.load(log_path / 'weight_best.pt', map_location=device)
    model.load_state_dict(state)
    model.eval()

    all_preds = []
    # No gradients are needed for inference; skipping them saves memory.
    with torch.no_grad():
        for images in test_loader:
            images = images.to(device)
            preds = model(images.float())
            all_preds.append(preds.cpu().numpy())

    preds_label = np.concatenate(all_preds).argmax(1)
    sample['isup_grade'] = preds_label.astype(int)
    return sample

## Main

In [None]:
# Competition data.
root = Path('/kaggle/input/prostate-cancer-grade-assessment/')

# Attached training-run dataset holding config.yml and weight_best.pt.
# Sorting makes the choice deterministic when several runs are attached.
run_dirs = sorted(glob.glob('/kaggle/input/sub-*'))
if not run_dirs:
    raise FileNotFoundError(
        "No dataset matching '/kaggle/input/sub-*' found; attach the training "
        "run that contains config.yml and weight_best.pt."
    )
log_path = Path(run_dirs[0])

test_df = pd.read_csv(root / 'test.csv')
data_dir = '/kaggle/input/prostate-cancer-grade-assessment/test_images'
# Must point at the directory CustomDataset reads its PNG files from.
save_dir = '/kaggle/test_images/'

with open(log_path / 'config.yml', 'r') as yf:
    cfg = EasyDict(yaml.safe_load(yf))

In [None]:
# Convert the test TIFFs to PNGs; a no-op when the test images are not mounted.
resize(df=test_df, data_dir=data_dir, save_dir=save_dir)

In [None]:
# Fill the sample submission with predictions (left unchanged when the test
# images are not mounted) and write the submission file.
sample_df = submit(
    sample=pd.read_csv(root / 'sample_submission.csv'),
    data_dir=data_dir,
    log_path=log_path,
)
sample_df.to_csv('submission.csv', index=False)
sample_df.head()