In [None]:
import pandas as pd
import numpy as np
import random
import os
import cv2
import gc
import warnings
from sklearn.metrics import f1_score
from sklearn.exceptions import UndefinedMetricWarning
import scipy.optimize as opt
from collections import defaultdict, Counter
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, resnet34
import math
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
import torch

%matplotlib inline

In [None]:
__all__ = ['xception']

# Download location of the pretrained Xception checkpoint
# (split across two literals only to keep the line short).
xception_url = ('https://www.dropbox.com/s/1hp'
                'lpzet9d7dv29/xception-c0a72b38.pth.tar?dl=1')

model_urls = dict(xception=xception_url)


class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution.

    A depthwise conv (one filter per input channel, ``groups=in_channels``)
    is followed by a 1x1 pointwise conv that mixes channels into
    ``out_channels``. ``kernel_size``/``stride``/``padding``/``dilation``
    apply to the depthwise step only.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=0, dilation=1, bias=False):
        super().__init__()
        # depthwise step: channel count unchanged, one filter per channel
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size,
                               stride=stride, padding=padding,
                               dilation=dilation, groups=in_channels,
                               bias=bias)
        # pointwise step: 1x1 projection to out_channels
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=1, padding=0, dilation=1, groups=1,
                                   bias=bias)

    def forward(self, x):
        return self.pointwise(self.conv1(x))


class Block(nn.Module):
    """Xception residual block.

    The residual branch is a stack of ``reps`` units, each
    ReLU -> SeparableConv2d(3x3) -> BatchNorm2d, optionally followed by a
    strided 3x3 max-pool. The skip branch is the identity, or a 1x1 strided
    conv + BatchNorm when the channel count or stride changes.

    Args:
        in_filters: number of input channels.
        out_filters: number of output channels.
        reps: number of ReLU/SeparableConv2d/BatchNorm2d units in the stack.
        strides: stride of the trailing MaxPool2d and of the 1x1 skip conv;
            1 means no pooling.
        start_with_relu: if False, the leading ReLU of the stack is dropped.
        grow_first: if True the channel change happens in the first unit,
            otherwise in the last unit.
    """
    def __init__(self, in_filters, out_filters, reps,
                 strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # A projection is needed on the skip path only when the residual
        # branch changes the channel count or the spatial resolution.
        if out_filters != in_filters or strides != 1:
            self.skip = nn.Conv2d(in_filters, out_filters,
                                  1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None
        
        # One in-place ReLU module, appended at several positions of `rep`.
        self.relu = nn.ReLU(inplace=True)
        rep = []

        # `filters` tracks the channel count flowing through the middle units.
        filters = in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters,
                                       3, stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters, filters, 3,
                                       stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(filters))
        
        if not grow_first:
            # Here `filters` is still in_filters, so the last unit does the
            # in_filters -> out_filters change.
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters, out_filters, 3,
                                       stride=1, padding=1, bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            # The first op must not modify its input in place: forward() also
            # feeds the same `inp` tensor to the skip path after self.rep(inp).
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3, strides, 1))
        # The order of `rep` fixes the `rep.<index>` parameter names that
        # load_state_dict matches against saved checkpoints.
        self.rep = nn.Sequential(*rep)

    def forward(self, inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        # Residual sum, accumulated in place into the branch output.
        x += skip
        return x


class Xception(nn.Module):
    """Xception image classifier for 3-channel input.

    Architecture as described in https://arxiv.org/pdf/1610.02357.pdf:
    a two-conv stem, residual blocks ``block1``..``block12``, two separable
    convs, global average pooling and a linear classifier.
    """

    def __init__(self, num_classes=1000):
        """
        Args:
            num_classes: size of the classifier output.
        """
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # stem
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # (in, out, reps, strides, start_with_relu, grow_first) for
        # block1..block12. They are registered in this order under the names
        # block1..block12, so parameter names and init order are preserved.
        block_cfg = [
            (64, 128, 2, 2, False, True),
            (128, 256, 2, 2, True, True),
            (256, 728, 2, 2, True, True),
        ]
        block_cfg += [(728, 728, 3, 1, True, True)] * 8
        block_cfg.append((728, 1024, 2, 2, True, False))
        for idx, (cin, cout, reps, stride, relu_first, grow) in enumerate(block_cfg, 1):
            setattr(self, 'block{}'.format(idx),
                    Block(cin, cout, reps, stride,
                          start_with_relu=relu_first, grow_first=grow))

        # head
        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)
        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

        # He-normal (fan-out) init for every conv; BatchNorm starts as identity.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))

        for idx in range(1, 13):
            x = getattr(self, 'block{}'.format(idx))(x)

        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))

        x = F.adaptive_avg_pool2d(x, (1, 1))
        return self.fc(x.view(x.size(0), -1))


def xception(pretrained=False, **kwargs):
    """Build an Xception network.

    Args:
        pretrained: if True, download the checkpoint at
            ``model_urls['xception']`` and load it with ``load_state_dict``.
            NOTE(review): the checkpoint presumably has a 1000-way ``fc``;
            combining pretrained=True with another ``num_classes`` would then
            fail to load — confirm.
        **kwargs: forwarded to the ``Xception`` constructor (e.g. num_classes).

    Returns:
        The constructed ``Xception`` module.
    """
    net = Xception(**kwargs)
    if not pretrained:
        return net
    state = model_zoo.load_url(model_urls['xception'])
    net.load_state_dict(state)
    return net

In [None]:
# Kaggle input locations for the Human Protein Atlas competition data.
_DATA_ROOT = '../input/human-protein-atlas-image-classification/'
TRAIN = _DATA_ROOT + 'train/'                   # training PNGs
TEST = _DATA_ROOT + 'test/'                     # test PNGs
LABELS = _DATA_ROOT + 'train.csv'               # Id -> Target labels
SUBMIT = _DATA_ROOT + 'sample_submission.csv'   # ids to predict

In [None]:
# Class index -> subcellular location name (indices as used in 'Target').
_label_names = [
    'Nucleoplasm',                     # 0
    'Nuclear membrane',                # 1
    'Nucleoli',                        # 2
    'Nucleoli fibrillar center',       # 3
    'Nuclear speckles',                # 4
    'Nuclear bodies',                  # 5
    'Endoplasmic reticulum',           # 6
    'Golgi apparatus',                 # 7
    'Peroxisomes',                     # 8
    'Endosomes',                       # 9
    'Lysosomes',                       # 10
    'Intermediate filaments',          # 11
    'Actin filaments',                 # 12
    'Focal adhesion sites',            # 13
    'Microtubules',                    # 14
    'Microtubule ends',                # 15
    'Cytokinetic bridge',              # 16
    'Mitotic spindle',                 # 17
    'Microtubule organizing center',   # 18
    'Centrosome',                      # 19
    'Lipid droplets',                  # 20
    'Plasma membrane',                 # 21
    'Cell junctions',                  # 22
    'Mitochondria',                    # 23
    'Aggresome',                       # 24
    'Cytosol',                         # 25
    'Cytoplasmic bodies',              # 26
    'Rods & rings',                    # 27
]
name_label_dict = dict(enumerate(_label_names))

N_CLASSES = len(name_label_dict)
print('The number of classes: {}'.format(N_CLASSES))

In [None]:
# Training labels (columns `Id`, `Target`) and the sample submission, whose
# `Id`s are predicted on and whose `Predicted` column is filled in at the end.
df = pd.read_csv(LABELS)
sub_df = pd.read_csv(SUBMIT)

# Data exploration

In [None]:
# How often each class label appears across the training `Target` strings,
# plotted from most to least frequent.
cls_counts = Counter(label for target in df['Target'].str.split() for label in target)
ranked = cls_counts.most_common(N_CLASSES)
counts_x = [count for _, count in ranked]
counts_y = [name_label_dict[int(label)] for label, _ in ranked]
plt.figure(figsize=(8, 8))
sns.barplot(y=counts_y, x=counts_x);

## Stratified K-fold split

In [None]:
# Fold
# Greedily assign every sample to one of `n_folds` folds so that each
# sample's rarest class ends up spread evenly across the folds.
n_folds = 10
# (fold, class label) -> number of samples in that fold carrying that label
fold_cls_counts = defaultdict(int)
folds = [-1] * len(df)
# Fixed shuffled visiting order (random_state=42) keeps the assignment reproducible.
for item in tqdm(df.sample(frac=1, random_state=42).itertuples(),total=len(df)):
    # The label with the lowest global count decides where this sample goes.
    cls = min(item.Target.split(), key=lambda cls: cls_counts[cls])
    fold_counts = [(f, fold_cls_counts[f, cls]) for f in range(n_folds)]
    min_count = min([count for _, count in fold_counts])
    # Per-sample reseed of Python's global `random` so ties between the
    # least-filled folds are broken deterministically.
    random.seed(item.Index)
    fold = random.choice([f for f, count in fold_counts if count == min_count])
    # item.Index is df's row label; this assumes df keeps the default
    # 0..n-1 index produced by read_csv.
    folds[item.Index] = fold
    # Update the counts for every label of the sample, not only the rarest.
    for cls in item.Target.split():
        fold_cls_counts[fold, cls] += 1
df['fold'] = folds

In [None]:
# Hold out fold `valid_idx` for validation and train on all other folds.
valid_idx = 0
in_valid = df['fold'] == valid_idx
train_df = df.loc[~in_valid, ['Id', 'Target']].reset_index(drop=True)
valid_df = df.loc[in_valid, ['Id', 'Target']].reset_index(drop=True)
print('There are {} samples in the training set.'.format(len(train_df)))
print('There are {} samples in the validation set.'.format(len(valid_df)))

In [None]:
def open_rgby(path, id, size=None): # RGBY image
    """Load one sample's four single-channel PNGs as a float32 array.

    Reads ``<path>/<id>_<color>.png`` for red, green, blue and yellow (in
    that order) as grayscale, scales pixel values by 1/255 and stacks them
    along axis 0, giving an array of shape (4, H, W).

    Args:
        path: directory containing the PNG files.
        id: sample id (file-name prefix).
        size: if given, each channel is resized with bicubic interpolation to
            height ``size`` and width ``2*size``; if None the files' own
            resolution is kept.

    Raises:
        FileNotFoundError: if a channel cannot be read. ``cv2.imread`` returns
            None instead of raising, which previously surfaced as an opaque
            AttributeError on ``None.astype``.
    """
    colors = ['red', 'green', 'blue', 'yellow']
    img = []
    for color in colors:
        fname = os.path.join(path, id + '_' + color + '.png')
        channel = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
        if channel is None:
            raise FileNotFoundError('could not read image: {}'.format(fname))
        if size is not None:
            # NOTE(review): cv2 dsize is (width, height), so this produces a
            # size x 2*size image (aspect ratio not preserved) — confirm intended.
            channel = cv2.resize(channel, (2 * size, size), interpolation=cv2.INTER_CUBIC)
        img.append(channel.astype(np.float32) / 255)

    return np.stack(img, axis=0)

In [None]:
class AtlasDataset(Dataset):
    """Dataset yielding (RGBY image, multi-hot target) pairs.

    Args:
        df: frame with an ``Id`` column and, when ``label`` is True, a
            ``Target`` column of space-separated class indices (e.g. "0 25").
            The frame is copied, so the caller's frame is not modified.
        path: directory passed to ``open_rgby``.
        size: passed to ``open_rgby`` (None keeps the original resolution).
        label: if True, ``__getitem__`` returns a float multi-hot target of
            length ``N_CLASSES``; otherwise an all-zero int placeholder.
    """
    def __init__(self, df, path, size=None, label=True):
        self.df = df.copy()
        self.path = path
        self.size = size
        self.label = label
        if self.label:
            # "0 25" -> [0, 25]
            self.df['Target'] = [[int(i) for i in s.split()] for s in self.df['Target']]

    def __getitem__(self, index):
        img = open_rgby(self.path, self.df['Id'].iloc[index], self.size)
        # np.float / np.int were aliases of the builtins float / int and were
        # removed in NumPy 1.24; using the builtins keeps the exact same dtypes.
        if self.label:
            # rows of the identity matrix summed -> multi-hot vector
            target = np.eye(N_CLASSES, dtype=float)[self.df['Target'].iloc[index]].sum(axis=0)
        else:
            target = np.zeros(N_CLASSES, dtype=int)
        return img, target

    def __len__(self):
        return len(self.df)

In [None]:
# size: `size` argument for open_rgby (output height; width is 2*size)
# bs:   samples per batch for every DataLoader
size, bs = 512, 10

In [None]:
# Training batches are reshuffled every epoch. Validation keeps a fixed order:
# evaluate() averages per-batch losses, so with shuffling the reported
# validation loss (which drives early stopping) changed from run to run.
train_loader = DataLoader(AtlasDataset(train_df, TRAIN, size), batch_size=bs, shuffle=True, num_workers=2)
valid_loader = DataLoader(AtlasDataset(valid_df, TRAIN, size), batch_size=bs, shuffle=False, num_workers=2)

In [None]:
# Unlabeled test set (zero placeholder targets); order kept so outputs line up with sub_df.
test_loader = DataLoader(AtlasDataset(sub_df, TEST, size=size, label=False),
                         batch_size=bs, shuffle=False, num_workers=2)

# Model

## ResNet50

In [None]:
class AvgPool(nn.Module):
    """Global average pooling over the spatial dims, returning (batch, channels)."""
    def forward(self, x):
        # flatten(1) rather than squeeze(): squeeze() also removed the batch
        # dimension when a batch held a single sample, so the following
        # Linear layer received a 1-D tensor.
        return F.avg_pool2d(x, x.shape[2:]).flatten(1)

class ResNet50(nn.Module):
    """torchvision ResNet-50 adapted to 4-channel (RGBY) input.

    Args:
        num_classes: number of output logits.
        pretrained: if True, load ResNet-50 weights from the local Kaggle
            dataset path and build the 4-channel conv1 from the pretrained
            3-channel filters.
    """
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        encoder = resnet50(pretrained=False)
        if pretrained:
            path = "../input/pytorch-pretrained-models/resnet50-19c8e357.pth"
            # map_location lets a GPU-saved checkpoint load on a CPU-only host
            encoder.load_state_dict(torch.load(path, map_location='cpu'))

        self.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)

        if pretrained:
            w = encoder.conv1.weight
            # Keep the RGB filters; seed the 4th (yellow) channel with the mean
            # of the red (0) and blue (2) filters.
            self.conv1.weight = nn.Parameter(
                torch.cat((w, 0.5 * (w[:, :1, :, :] + w[:, 2:, :, :])), dim=1))

        self.bn1 = encoder.bn1
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Stem order must match torchvision's ResNet (conv -> bn -> relu -> maxpool):
        # the pretrained bn1 statistics were computed on pre-activation conv outputs.
        # Note this places bn1 at index 1 of layer0, which determines the
        # `layer0.*` state_dict keys.
        self.layer0 = nn.Sequential(self.conv1, self.bn1, self.relu, self.maxpool)
        self.layer1 = encoder.layer1
        self.layer2 = encoder.layer2
        self.layer3 = encoder.layer3
        self.layer4 = encoder.layer4
        self.avgpool = AvgPool()
        self.fc = nn.Sequential(nn.Dropout(p=0.5), nn.Linear(encoder.fc.in_features, num_classes))

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = self.fc(x)
        return x

## Xception

In [None]:
class myXception(nn.Module):
    """Xception backbone adapted to 4-channel (RGBY) input with a new head.

    Every layer except ``conv1`` and ``fc`` is taken from ``xception()``.
    ``conv1`` is rebuilt with 4 input channels, and ``fc`` maps the 2048-d
    pooled features to ``num_classes`` logits.
    """

    def __init__(self, num_classes, pretrained=True):
        """ Constructor
        Args:
            num_classes: number of classes
            pretrained: passed to ``xception()`` (downloads and loads the
                checkpoint at ``model_urls['xception']``). When True, the
                pretrained 3-channel conv1 filters are also transferred into
                the new 4-channel conv1 instead of being discarded.
        """
        super().__init__()
        encoder = xception(pretrained=pretrained)

        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(4, 32, 3, 2, 0, bias=False)
        if pretrained:
            # Keep the RGB filters and seed the 4th (yellow) channel with the
            # mean of the red and blue filters (same scheme as the ResNet50 class).
            with torch.no_grad():
                w = encoder.conv1.weight
                self.conv1.weight.copy_(torch.cat((w, 0.5 * (w[:, :1] + w[:, 2:])), dim=1))
        self.bn1 = encoder.bn1
        self.relu = encoder.relu

        self.conv2 = encoder.conv2
        self.bn2 = encoder.bn2
        # relu applied in forward()

        self.block1 = encoder.block1
        self.block2 = encoder.block2
        self.block3 = encoder.block3

        self.block4 = encoder.block4
        self.block5 = encoder.block5
        self.block6 = encoder.block6
        self.block7 = encoder.block7

        self.block8 = encoder.block8
        self.block9 = encoder.block9
        self.block10 = encoder.block10
        self.block11 = encoder.block11

        self.block12 = encoder.block12

        self.conv3 = encoder.conv3
        self.bn3 = encoder.bn3

        # relu applied in forward()
        self.conv4 = encoder.conv4
        self.bn4 = encoder.bn4

        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        x = self.relu(x)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x

In [None]:
# Alternative backbone (4-channel ResNet50); uncomment to use it instead.
#model = ResNet50(num_classes=N_CLASSES)
# pretrained defaults to True, so this downloads the weights at
# model_urls['xception']. A later cell then replaces all weights with a
# fine-tuned checkpoint via load_state_dict.
model = myXception(num_classes=N_CLASSES)


# Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss on logits for multi-label targets.

    Each element's BCE-with-logits term is scaled by ``(1 - p_t) ** gamma``,
    where, for 0/1 targets, ``p_t`` is the predicted probability of the true
    label. For 2-D input the per-class terms are summed per sample, and the
    result is then averaged.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logit, target):
        y = target.float()
        # numerically stable BCE-with-logits
        shift = (-logit).clamp(min=0)
        bce = logit - logit * y + shift + \
            ((-shift).exp() + (-logit - shift).exp()).log()

        # log(1 - p_t)
        log_one_minus_pt = F.logsigmoid(-logit * (y * 2.0 - 1.0))
        focal = (log_one_minus_pt * self.gamma).exp() * bce

        if focal.dim() == 2:
            focal = focal.sum(dim=1)  # one loss per sample
        return focal.mean()


# Training

## Setup

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# exist_ok: re-running this cell must not fail with FileExistsError.
os.makedirs('models', exist_ok=True)

In [None]:
# Resume from a fine-tuned checkpoint. map_location='cpu' lets a GPU-saved
# checkpoint load on a CPU-only host; the model is still on CPU at this point
# (it is moved to `device` in the next cell).
model.load_state_dict(torch.load("../input/xception/xception_400_th40_acc39.pth", map_location='cpu'))

In [None]:
# Initial learning rate; the training loop may divide it by 5 on a
# validation plateau (via optimizer.param_groups).
lr = 2e-4
model = model.to(device)
criterion = FocalLoss()  # gamma=2
optimizer = optim.Adam(params=model.parameters(), lr=lr)

## Training step (one epoch)

In [None]:
def train_model(epoch, history=None):
    """Train the global `model` for one epoch over `train_loader`.

    Relies on the notebook globals `model`, `train_loader`, `criterion`,
    `optimizer` and `device`.

    Args:
        epoch: epoch index; used for the history row label and checkpoint name.
        history: optional DataFrame; each batch loss is stored at row
            ``epoch + batch_idx / len(train_loader)``, column 'train_loss'.

    Side effects:
        Saves the weights to ``models/epoch{epoch}.pth`` after the epoch.
    """
    model.train()
    progress = tqdm(train_loader)
    n_batches = len(train_loader)

    for step, (images, labels) in enumerate(progress):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        progress.set_description(f'train_loss (l={loss:.4f})')

        if history is not None:
            history.loc[epoch + step / n_batches, 'train_loss'] = loss.data.cpu().numpy()

        loss.backward()
        optimizer.step()

    torch.save(model.state_dict(), 'models/epoch{}.pth'.format(epoch))

In [None]:
def binarize_prediction(probabilities, threshold: float, argsorted=None,
                        min_labels=1, max_labels=10):
    """Convert per-class probabilities into a 0/1 label matrix.

    A class is predicted when it is among the ``max_labels`` most probable
    classes of its row AND its probability exceeds ``threshold``. The
    ``min_labels`` most probable classes of every row are always predicted.

    Args:
        probabilities: (n_samples, N_CLASSES) array.
        threshold: probability cut-off.
        argsorted: optional precomputed ``probabilities.argsort(axis=1)``.
        min_labels: classes forced on per row.
        max_labels: cap on threshold-selected classes per row.

    Returns:
        Array of the same shape as ``probabilities`` holding 0/1.
    """
    assert probabilities.shape[1] == N_CLASSES
    order = probabilities.argsort(axis=1) if argsorted is None else argsorted
    above_threshold = probabilities > threshold
    capped = _make_mask(order, max_labels) & above_threshold
    forced = _make_mask(order, min_labels)
    return capped | forced

def _make_mask(argsorted, top_n: int):
    mask = np.zeros_like(argsorted, dtype=np.uint8)
    col_indices = argsorted[:, -top_n:].reshape(-1)
    row_indices = [i // top_n for i in range(len(col_indices))]
    mask[row_indices, col_indices] = 1
    return mask

## Validation and threshold search

In [None]:
def evaluate(epoch, history=None):
    """Score the global `model` on `valid_loader` and pick a decision threshold.

    Relies on the notebook globals `model`, `valid_loader`, `criterion`,
    `device`, and `optimizer` (read only to print the current learning rate).

    Args:
        epoch: epoch index used in the log line and as the history row label.
        history: optional DataFrame; if given, the validation loss is stored
            at ``history.loc[epoch, 'valid_loss']``.

    Returns:
        (valid_loss, best_thr, best_f1): the mean of the per-batch losses (a
        tensor), the candidate threshold with the highest macro F1 under
        ``binarize_prediction``, and that F1 score.
    """
    model.eval()
    valid_loss = 0.
    prob_batches, target_batches = [], []

    with torch.no_grad():
        for batch_idx, (img_batch, label_batch) in enumerate(valid_loader):
            # host-side copy of the labels for sklearn scoring
            target_batches.append(label_batch.numpy().copy())
            logits = model(img_batch.to(device))
            valid_loss += criterion(logits, label_batch.to(device)).data
            prob_batches.append(torch.sigmoid(logits).cpu().numpy())

    all_predictions = np.concatenate(prob_batches)
    all_targets = np.concatenate(target_batches)

    # mean over batches (not over samples)
    valid_loss /= (batch_idx + 1)

    if history is not None:
        history.loc[epoch, 'valid_loss'] = valid_loss.cpu().numpy()

    current_lr = optimizer.state_dict()['param_groups'][0]['lr']
    print('Epoch: {}\tLR: {:.6f}\tValid Loss: {:.4f}'.format(epoch, current_lr, valid_loss))

    def get_score(y_pred):
        # silence sklearn's warnings for classes with no true/predicted samples
        with warnings.catch_warnings():
            warnings.simplefilter('ignore', category=UndefinedMetricWarning)
            return f1_score(all_targets, y_pred, average='macro')

    argsorted = all_predictions.argsort(axis=1)
    candidates = [0.05, 0.10, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45]
    metrics = {thr: get_score(binarize_prediction(all_predictions, thr, argsorted))
               for thr in candidates}
    best_thr = max(metrics, key=metrics.get)

    top5 = sorted(metrics.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(' | '.join(f'thr_{k:.2f} {v:.3f}' for k, v in top5))

    return valid_loss, best_thr, metrics[best_thr]

## Training loop

In [None]:
history_train = pd.DataFrame()
history_valid = pd.DataFrame()

n_epochs = 100
init_epoch = 0
# Number of LR decays allowed; once exceeded, training stops.
# 0 => stop at the first validation plateau without decaying the LR.
max_lr_changes = 0
valid_losses = []
threshold = {}   # epoch -> best validation threshold
macro_f1 = {}    # epoch -> macro F1 at that threshold
lr_reset_epoch = init_epoch
patience = 2
# LR decays performed so far. Starts at 0 so exactly `max_lr_changes` decays
# can happen before stopping (starting at 1 allowed only max_lr_changes - 1).
lr_changes = 0
best_valid_loss = 1000.

for epoch in range(init_epoch, n_epochs):
    torch.cuda.empty_cache()
    gc.collect()
    train_model(epoch, history_train)
    valid_loss, best_thr, best_f1 = evaluate(epoch, history_valid)
    valid_losses.append(valid_loss)
    threshold[epoch] = best_thr
    macro_f1[epoch] = best_f1

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
    elif (patience and epoch - lr_reset_epoch > patience and
          min(valid_losses[-patience:]) > best_valid_loss):
        # No improvement within the last `patience` epochs: decay the LR,
        # or stop once the decay budget is used up.
        lr_changes += 1
        if lr_changes > max_lr_changes:
            break
        lr /= 5
        print(f'lr updated to {lr}')
        lr_reset_epoch = epoch
        optimizer.param_groups[0]['lr'] = lr

## Loss curves

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(20, 5))
train_ax, valid_ax = ax

# Per-batch training loss, skipping the first 100 logged batches.
train_ax.plot(history_train['train_loss'].iloc[100:])
train_ax.set(xlabel='Epoch', ylabel='Train Loss')

# One validation loss per epoch.
valid_ax.plot(history_valid['valid_loss'])
valid_ax.set(xlabel='Epoch', ylabel='Valid Loss');


# Submit

In [None]:
# Epoch with the highest validation macro F1 (earliest one on ties).
best_epoch = max(macro_f1, key=macro_f1.get)
print('The best epoch is epoch {}'.format(best_epoch))

#model.load_state_dict(torch.load("../input/xception/epoch6.pth"))
best_ckpt = 'models/epoch{}.pth'.format(best_epoch)
model.load_state_dict(torch.load(best_ckpt))
model.eval();

In [None]:
# Inference
outputlist = []
for img_batch, _ in tqdm(test_loader):
    with torch.no_grad():
        output = torch.sigmoid(model(img_batch.to(device)))
    output = output.data.cpu().numpy()
    for i in output: 
        outputlist.append(i)

In [None]:
thr = 0.40
print('The best threshold is {:.3f}'.format(thr))
# NOTE(review): `thr` is hard-coded (presumably the "th40" of the resumed
# checkpoint) rather than taken from threshold[best_epoch] — confirm intended.

# Decode with the same rule evaluate() used to score thresholds
# (binarize_prediction: the top-1 class is always kept, at most 10 classes per
# image). A plain `p > thr` cut could emit empty predictions, which were never
# part of the validation F1 the threshold was chosen on.
test_probs = np.stack(outputlist)
test_masks = binarize_prediction(test_probs, thr)
prediction = [' '.join(str(i) for i in np.flatnonzero(mask)) for mask in test_masks]
sub_df['Predicted'] = prediction
sub_df.to_csv('submission_512_xception.csv', index=False)