<a id="toc"></a>
# Table of Contents
1. [Install MLComp library (offline version)](#install_mlcomp_library)
1. [Import required libraries](#import_required_libraries)
1. [Load models](#load_models)
  1. [Load segmentation models](#load_segmentation_models)
  1. [Load classification models](#load_classification_models)
1. [Define models' mean aggregator](#define_models_mean_aggregator)
1. [Create TTA transforms, datasets, loaders](#create_tta_etc)
1. [Save predictions](#save_predictions)
1. [Draw histogram of predictions](#draw_histogram_of_predictions)

<a id="install_mlcomp_library"></a>
# Install MLComp library (offline version)
[Back to Table of Contents](#toc)

Because the competition does not allow submitting kernels that use an internet connection, all required packages are installed offline from the attached datasets.

In [None]:
# Run the setup.py bundled with the attached MLComp dataset (local files only).
# NOTE(review): no setup command (e.g. `install`) is passed here — presumably
# this setup.py handles that itself; confirm against the dataset contents.
!python ../input/mlcomp/mlcomp/mlcomp/setup.py

In [None]:
# Create a local directory that is later passed to pip via --find-links.
!mkdir -p /tmp/pip/cache/
# The archives are stored with an .xyz extension in the dataset; copy them
# under .tar.gz names (presumably so pip recognises them as source archives).
!cp ../input/segmentation-models-zip-003/efficientnet_pytorch-0.4.0.xyz /tmp/pip/cache/efficientnet_pytorch-0.4.0.tar.gz
!cp ../input/segmentation-models-zip-003/pretrainedmodels-0.7.4.xyz /tmp/pip/cache/pretrainedmodels-0.7.4.tar.gz
!cp ../input/segmentation-models-zip-003/segmentation_models_pytorch-0.0.3.xyz /tmp/pip/cache/segmentation_models_pytorch-0.0.3.tar.gz

In [None]:
# --no-index turns off package-index lookups; --find-links makes pip resolve
# the requested packages from the archives staged in /tmp/pip/cache/.
!pip install --no-index --find-links /tmp/pip/cache/ efficientnet-pytorch
!pip install --no-index --find-links /tmp/pip/cache/ segmentation-models-pytorch

In [None]:
# Copy the EfficientNet b0/b4/b5/b7 weight files into a torch checkpoint
# directory, presumably so the pretrained-weight lookups below find them
# locally instead of downloading.
# NOTE(review): confirm /tmp/.cache/torch/checkpoints/ is the cache directory
# torch actually uses in this environment.
!mkdir -p /tmp/.cache/torch/checkpoints/
!cp ../input/efficientnet-pytorch-b0-b7/efficientnet-b0-355c32eb.pth /tmp/.cache/torch/checkpoints/
!cp ../input/efficientnet-pytorch-b0-b7/efficientnet-b4-6ed6700e.pth /tmp/.cache/torch/checkpoints/
!cp ../input/efficientnet-pytorch-b0-b7/efficientnet-b5-b6417697.pth /tmp/.cache/torch/checkpoints/
!cp ../input/efficientnet-pytorch-b0-b7/efficientnet-b7-dcc49843.pth /tmp/.cache/torch/checkpoints/

In [None]:
# Install pretrainedmodels from its unpacked source directory; pip's stdout is
# discarded to keep the notebook output short.
!pip install ../input/pretrainedmodels/pretrainedmodels-0.7.4/pretrainedmodels-0.7.4/ > /dev/null
# Directory appended to sys.path below — presumably the location of the
# `senet_unet_model_code` module imported later in this notebook.
package_path = '../input/senetunetmodelcode'

import sys
sys.path.append(package_path)

<a id="import_required_libraries"></a>
# Import required libraries
[Back to Table of Contents](#toc)

In [None]:
import os

import warnings
warnings.filterwarnings('ignore')

import matplotlib.pyplot as plt
import numpy as np
import cv2
import albumentations as A
import pandas as pd
from tqdm import tqdm_notebook

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.jit import load
import torch.utils.data as data

from mlcomp.contrib.transform.albumentations import ChannelTranspose
from mlcomp.contrib.dataset.classify import ImageDataset
from mlcomp.contrib.transform.rle import rle2mask, mask2rle
from mlcomp.contrib.transform.tta import TtaWrap

import segmentation_models_pytorch as smp
from senet_unet_model_code import Unet
from efficientnet_pytorch import EfficientNet

<a id="load_models"></a>
# Load models
[Back to Table of Contents](#toc)

<a id="load_segmentation_models"></a>
## Load segmentation models
[Back to Table of Contents](#toc)

In [None]:
# Checkpoint holding the trained weights of the EfficientNet-B5 U-Net.
ckpt_path = "../input/unet-efficientnet-baseline/unet_efficientnet_model.pth"
device = torch.device("cuda")

# Build the architecture on the GPU and switch it to inference mode.
unet_efficientnet = smp.Unet(
    "efficientnet-b5",
    encoder_weights="imagenet",
    classes=4,
    activation="sigmoid",
).cuda()
unet_efficientnet.to(device)
unet_efficientnet.eval()

# map_location hands every storage back unchanged, so the checkpoint is
# deserialized without being remapped to another device.
state = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
unet_efficientnet.load_state_dict(state["state_dict"])

In [None]:
# Checkpoint holding the trained weights of the SE-ResNeXt50 U-Net.
ckpt_path = "../input/senetmodels/senext50_30_epochs_high_threshold.pth"
device = torch.device("cuda")

# No pretrained encoder weights are requested: every parameter comes from
# the checkpoint loaded below.
senext50_30_epochs_high_threshold = Unet(
    'se_resnext50_32x4d',
    encoder_weights=None,
    classes=4,
    activation=None,
).cuda()
senext50_30_epochs_high_threshold.to(device)
senext50_30_epochs_high_threshold.eval()

# map_location hands every storage back unchanged (no device remapping).
state = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
senext50_30_epochs_high_threshold.load_state_dict(state["state_dict"])

In [None]:
# `load` is torch.jit.load (see the imports cell): each file is deserialized
# as a TorchScript module and then moved to the GPU.
unet_se_resnext50_32x4d = \
    load('/kaggle/input/severstalmodels/unet_se_resnext50_32x4d.pth').cuda()
unet_mobilenet2 = load('/kaggle/input/severstalmodels/unet_mobilenet2.pth').cuda()
unet_resnet34 = load('/kaggle/input/severstalmodels/unet_resnet34.pth').cuda()

<a id="load_classification_models"></a>
## Load classification models
[Back to Table of Contents](#toc)

In [None]:
# ResNet34 classifier, loaded with torch.jit.load (imported as `load`) and
# moved to the GPU.
clf_resnet34 = load('/kaggle/input/severstalmodels/resnet34_classify.pth').cuda()

In [None]:
# Trained weights of the 4-class EfficientNet-B0 classifier.
ckpt_path = "/kaggle/input/pytorch-multi-label-classification/efficientnet_b0_model.pth"
device = torch.device("cuda")

# Instantiate the network with a 4-output head, then move it to the GPU
# and switch to inference mode.
clf_efficientnet_b0 = EfficientNet.from_pretrained(
    'efficientnet-b0',
    num_classes=4,
).cuda()
clf_efficientnet_b0.to(device)
clf_efficientnet_b0.eval()

# map_location hands every storage back unchanged (no device remapping).
state = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
clf_efficientnet_b0.load_state_dict(state["state_dict"])

In [None]:
# Trained weights of the 4-class EfficientNet-B4 classifier.
ckpt_path = "/kaggle/input/clf-efficientnet-b4/efficientnet_b4_model.pth"
device = torch.device("cuda")

# Instantiate the network with a 4-output head, then move it to the GPU
# and switch to inference mode.
clf_efficientnet_b4 = EfficientNet.from_pretrained(
    'efficientnet-b4',
    num_classes=4,
).cuda()
clf_efficientnet_b4.to(device)
clf_efficientnet_b4.eval()

# map_location hands every storage back unchanged (no device remapping).
state = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
clf_efficientnet_b4.load_state_dict(state["state_dict"])

In [None]:
# Trained weights of the 4-class EfficientNet-B5 classifier.
ckpt_path = "/kaggle/input/pytorch-multi-label-classification-effnet-b5/efficientnet_b5_model.pth"
device = torch.device("cuda")

# Instantiate the network with a 4-output head, then move it to the GPU
# and switch to inference mode.
clf_efficientnet_b5 = EfficientNet.from_pretrained(
    'efficientnet-b5',
    num_classes=4,
).cuda()
clf_efficientnet_b5.to(device)
clf_efficientnet_b5.eval()

# map_location hands every storage back unchanged (no device remapping).
state = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
clf_efficientnet_b5.load_state_dict(state["state_dict"])

In [None]:
# Trained weights of the 4-class EfficientNet-B7 classifier.
ckpt_path = "/kaggle/input/pytorch-multi-label-classification-effnet-b7/efficientnet_b7_model.pth"
device = torch.device("cuda")

# Instantiate the network with a 4-output head, then move it to the GPU
# and switch to inference mode.
clf_efficientnet_b7 = EfficientNet.from_pretrained(
    'efficientnet-b7',
    num_classes=4,
).cuda()
clf_efficientnet_b7.to(device)
clf_efficientnet_b7.eval()

# map_location hands every storage back unchanged (no device remapping).
state = torch.load(
    ckpt_path,
    map_location=lambda storage, loc: storage,
)
clf_efficientnet_b7.load_state_dict(state["state_dict"])

<a id="define_models_mean_aggregator"></a>
# Define models' mean aggregator
[Back to Table of Contents](#toc)

In [None]:
class Model:
    """Ensemble wrapper that averages the outputs of several models."""

    def __init__(self, models):
        # Sequence of callables; each one is invoked with the same input batch.
        self.models = models

    def __call__(self, x):
        """Run every wrapped model on ``x`` and return the mean of their outputs.

        Parameters:
            x: Input batch; it is moved to the GPU with ``x.cuda()`` first.

        Returns:
            The outputs of all models stacked along a new leading dimension
            and averaged over that dimension.
        """
        batch = x.cuda()
        # Inference only: no autograd graph is recorded for the forward passes.
        with torch.no_grad():
            outputs = [net(batch) for net in self.models]
        return torch.stack(outputs).mean(dim=0)

# Segmentation ensemble: the outputs of these networks are averaged by Model.
segmentation_models = [
    unet_se_resnext50_32x4d,
    unet_mobilenet2,
    unet_resnet34,
    unet_efficientnet,
    senext50_30_epochs_high_threshold,
]
model = Model(segmentation_models)

# Classification ensemble: the outputs of these networks are averaged by Model.
classification_models = [
    clf_efficientnet_b0,
    clf_efficientnet_b4,
    clf_efficientnet_b5,
    clf_efficientnet_b7,
    clf_resnet34,
]
model_clf = Model(classification_models)

<a id="create_tta_etc"></a>
# Create TTA transforms, datasets, loaders
[Back to Table of Contents](#toc)

In [None]:
def create_transforms(additional):
    """Build the albumentations pipeline for one TTA variant.

    Parameters:
        additional: Iterable of transforms placed before the mandatory steps.

    Returns:
        An ``A.Compose`` of ``additional`` followed by ``A.Normalize`` and
        ``ChannelTranspose``.
    """
    pipeline = [
        *additional,
        # Steps every variant needs, whatever augmentation precedes them.
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.230, 0.225, 0.223)),
        ChannelTranspose(),
    ]
    return A.Compose(pipeline)

img_folder = '/kaggle/input/severstal-steel-defect-detection/test_images'
batch_size = 2
num_workers = 0

# One pipeline per TTA variant: the plain image and its horizontal flip.
transforms = [
    create_transforms(extra)
    for extra in ([], [A.HorizontalFlip(p=1)])
]

# Each dataset applies its pipeline and is wrapped in TtaWrap, which is
# given the same pipeline through `tfms`.
datasets = []
for tfm in transforms:
    base_dataset = ImageDataset(img_folder=img_folder, transforms=tfm)
    datasets.append(TtaWrap(base_dataset, tfms=tfm))

# shuffle=False keeps all loaders in the same order, so zipping them later
# pairs batches of the same images.
loaders = [
    DataLoader(
        dataset,
        num_workers=num_workers,
        batch_size=batch_size,
        shuffle=False,
    )
    for dataset in datasets
]

In [None]:
def close(mask):
    """Apply a morphological closing to ``mask``.

    The closing uses a 5x5 rectangular structuring element.

    Parameters:
        mask: Input mask, passed unchanged to ``cv2.morphologyEx``.

    Returns:
        The result of ``cv2.morphologyEx`` with ``cv2.MORPH_CLOSE``.
    """
    rect_5x5 = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
    return cv2.morphologyEx(mask, cv2.MORPH_CLOSE, rect_5x5)

### Loaders' mean aggregator

In [None]:
# Per-class post-processing parameters; list index k belongs to class k + 1.
thresholds = [0.5, 0.75, 0.5, 0.75]      # pixel-probability cut-off for the mask
clf_thresholds = [0.5, 0.5, 0.75, 0.85]  # classifier score a mask must exceed to be kept
min_area = [600, 1000, 1200, 2000]       # minimum number of positive pixels

res = []
# Number of batches per loader. len(DataLoader) rounds up, so a final partial
# batch is counted too (len(dataset) // batch_size would miss it).
total = len(loaders[0])
# zip(*loaders) yields, per step, one batch from every TTA loader.
for loaders_batch in tqdm_notebook(zip(*loaders), total=total):
    preds = []
    image_file = []
    features_no_tta = None
    for tta_idx, batch in enumerate(loaders_batch):
        features = batch['features'].cuda()
        p = torch.sigmoid(model(features))
        # Undo the TTA transform so all predictions are aligned before averaging
        p = datasets[tta_idx].inverse(p)
        preds.append(p)
        image_file = batch['image_file']
        # The first loader carries no extra augmentation; its features are
        # the ones given to the classifiers.
        if tta_idx == 0:
            features_no_tta = features

    # Mean over the TTA variants
    preds = torch.stack(preds).mean(dim=0).detach().cpu().numpy()

    clf_preds = torch.sigmoid(model_clf(features_no_tta)).detach().cpu().numpy()

    # Per-image post-processing
    for p, clf_pred, file_path in zip(preds, clf_preds, image_file):
        file_name = os.path.basename(file_path)
        for class_idx, (pixel_thr, clf_thr, area_thr) in enumerate(
            zip(thresholds, clf_thresholds, min_area)
        ):
            mask = (p[class_idx] > pixel_thr).astype(np.uint8)

            too_small = mask.sum() < area_thr
            # Classifier veto: drop masks the classifier does not confirm
            rejected = clf_pred[class_idx] <= clf_thr
            if too_small or rejected:
                mask = np.zeros_like(mask)
            else:
                mask = close(mask)

            res.append({
                'ImageId_ClassId': file_name + '_' + str(class_idx + 1),
                'EncodedPixels': mask2rle(mask)
            })

<a id="save_predictions"></a>
# Save predictions
[Back to Table of Contents](#toc)

In [None]:
# Pin the column order explicitly: the CSV layout then does not depend on
# dict key ordering, and the header is written even when `res` is empty.
df = pd.DataFrame(res, columns=['ImageId_ClassId', 'EncodedPixels'])
# Replace missing values with an empty string before writing.
df = df.fillna('')
df.to_csv('submission.csv', index=False)

<a id="draw_histogram_of_predictions"></a>
# Draw histogram of predictions
[Back to Table of Contents](#toc)

In [None]:
# Split on the LAST underscore only, so image file names that contain '_'
# are not truncated and the class id is always the final component.
df['Image'] = df['ImageId_ClassId'].map(lambda x: x.rsplit('_', 1)[0])
df['Class'] = df['ImageId_ClassId'].map(lambda x: x.rsplit('_', 1)[1])
# A row is "empty" when its run-length encoding is falsy (e.g. '').
df['empty'] = df['EncodedPixels'].map(lambda x: not x).astype(bool)
# Number of non-empty masks per class.
df.loc[~df['empty'], 'Class'].value_counts()

### Visualization

In [None]:
%matplotlib inline

# Re-read the saved submission and preview only its first 40 rows.
df = pd.read_csv('submission.csv').head(40)
# Split on the LAST underscore only, so image file names that contain '_'
# are not truncated and the class id is always the final component.
df['Image'] = df['ImageId_ClassId'].map(lambda x: x.rsplit('_', 1)[0])
df['Class'] = df['ImageId_ClassId'].map(lambda x: x.rsplit('_', 1)[1])

for row in df.itertuples():
    # Rows without a string encoding (empty CSV cells are not read back as
    # str) have no mask to show.
    if not isinstance(row.EncodedPixels, str):
        continue
    mask = rle2mask(row.EncodedPixels, (1600, 256))
    if mask.sum() == 0:
        continue

    # Read the image only for rows that are actually plotted.
    img_path = os.path.join(img_folder, row.Image)
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals an unreadable/missing file by returning None.
        print('Could not read image: ' + img_path)
        continue
    # OpenCV decodes to BGR channel order, matplotlib displays RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    fig, axes = plt.subplots(1, 2, figsize=(20, 60))
    axes[0].imshow(img / 255)
    # Scale the mask values so they are visible with the default colormap.
    axes[1].imshow(mask * 60)
    axes[0].set_title(row.Image)
    axes[1].set_title(row.Class)
    plt.show()