# Imports

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import os
os.sys.path.append('../fastai/') #fastai version 1

import random
from pathlib import Path
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from fastai import *
from fastai.vision import *
from fastai.vision.image import *

from resnet import Resnet4Channel

# Global Variables

In [None]:
# NOTE(review): hard-coded GPU index 1 — this fails on single-GPU machines; adjust as needed.
torch.cuda.set_device(1)
# Input size is fixed (sz), so let cuDNN benchmark and pick the fastest kernels.
torch.backends.cudnn.benchmark=True

DP = Path('/home/Deep_Learner/work/datasets/human-protein-atlas-image-classification/')
STAGE_ONE_DATA = DP/'stage1_data'
TRAIN_PNGS = 'train_pngs'
TRAIN_LABELS = 'labels.csv'
TEST_PNGS = 'test_pngs'
TRAIN_CSV = 'train.csv'
SAMPLE_SUBMISSION_CSV = 'sample_submission.csv'


# Directory intended for submission files (created if missing).
# STAGE_ONE_DATA is already rooted at DP, so join onto it directly.
SUBMISSIONS = STAGE_ONE_DATA/'submissions'
SUBMISSIONS.mkdir(parents=True, exist_ok=True)


# NOTE(review): not referenced in this section, and its order differs from the
# red/green/blue/yellow order used by open_4_channel — confirm intended use.
filter_colors = ['blue', 'green', 'red', 'yellow']

# Target class index (0..27) -> location name; enumerate supplies the index.
IdToCatDict = dict(enumerate([
    'Nucleoplasm',
    'Nuclear_membrane',
    'Nucleoli',
    'Nucleoli_fibrillar_center',
    'Nuclear_speckles',
    'Nuclear_bodies',
    'Endoplasmic_reticulum',
    'Golgi_apparatus',
    'Peroxisomes',
    'Endosomes',
    'Lysosomes',
    'Intermediate_filaments',
    'Actin_filaments',
    'Focal_adhesion_sites',
    'Microtubules',
    'Microtubule_ends',
    'Cytokinetic_bridge',
    'Mitotic_spindle',
    'Microtubule_organizing_center',
    'Centrosome',
    'Lipid_droplets',
    'Plasma_membrane',
    'Cell_junctions',
    'Mitochondria',
    'Aggresome',
    'Cytosol',
    'Cytoplasmic_bodies',
    'Rods_&_rings',
]))

sz = 224         # image side length, passed as `size` to ImageDataBunch.create
bs = 16          # batch size
val_split = 0.2  # fraction of training images, passed as `valid_pct`

# Seed every RNG the pipeline may draw from: numpy (used for the train/valid
# split), and Python's `random` and torch, which fastai augmentations and model
# weight initialisation may use. Seeding numpy alone leaves runs irreproducible.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

# Data

In [None]:
def open_4_channel(fname):
    """Load one sample stored as four grayscale PNGs into a 4-channel fastai `Image`.

    The channel files are `<fname>_red.png`, `<fname>_green.png`,
    `<fname>_blue.png` and `<fname>_yellow.png`. Each is read as grayscale,
    converted to float32 and divided by 255, then the four are stacked in that
    (red, green, blue, yellow) order along the last axis.

    Args:
        fname: path to the sample, with or without a trailing '.png'.

    Returns:
        fastai `Image` wrapping the float tensor produced by `pil2tensor`.

    Raises:
        FileNotFoundError: if any of the four channel files cannot be read.
    """
    fname = str(fname)
    # strip extension before adding color
    if fname.endswith('.png'):
        fname = fname[:-4]
    colors = ['red', 'green', 'blue', 'yellow']
    flags = cv2.IMREAD_GRAYSCALE
    img = []
    for color in colors:
        channel_path = fname + '_' + color + '.png'
        channel = cv2.imread(channel_path, flags)
        # cv2.imread returns None (instead of raising) for a missing or
        # unreadable file; fail loudly with the path rather than with an
        # AttributeError on `.astype`.
        if channel is None:
            raise FileNotFoundError('Could not read image channel: {}'.format(channel_path))
        img.append(channel.astype(np.float32) / 255)

    x = np.stack(img, axis=-1)
    return Image(pil2tensor(x, np.float32).float())


class ImageMulti4Channel(ImageMultiDataset):
    """`ImageMultiDataset` whose images are loaded with `open_4_channel`.

    All constructor arguments are forwarded unchanged to the parent class; the
    only difference is the image opener assigned afterwards.
    """

    def __init__(self, fns, labels, classes=None, **kwargs):
        super(ImageMulti4Channel, self).__init__(fns, labels, classes, **kwargs)
        # Assigned after the parent constructor runs — presumably so it takes
        # precedence over any opener the parent sets; confirm against fastai.
        self.image_opener = open_4_channel

In [None]:
# One row per training sample: `Id` names the image files (see open_4_channel),
# and `Target` holds space-separated class indices drawn from '0'..'27'.
df = pd.read_csv(STAGE_ONE_DATA/TRAIN_CSV)
fns = pd.Series([img_id + '.png' for img_id in df.Id])
labels = [target.split(' ') for target in df.Target]
trn_ds, val_ds = ImageMulti4Channel.from_folder(
    path=STAGE_ONE_DATA,
    folder=TRAIN_PNGS,
    fns=fns,
    labels=labels,
    valid_pct=val_split,
    classes=[str(class_idx) for class_idx in range(28)],
)

In [None]:
# The test set has no ground truth: the sample submission supplies the Ids and
# presumably placeholder `Predicted` values, used here only as dummy labels
# over a single dummy class.
df_test = pd.read_csv(STAGE_ONE_DATA/SAMPLE_SUBMISSION_CSV)
fns_test = pd.Series([img_id + '.png' for img_id in df_test.Id])
# str() guards against pandas having parsed the column as numbers.
labels_test = [str(placeholder).split(' ') for placeholder in df_test.Predicted]
test_ds, _ = ImageMulti4Channel.from_folder(
    path=STAGE_ONE_DATA,
    folder=TEST_PNGS,
    fns=fns_test,
    labels=labels_test,
    valid_pct=0,  # nothing should be held out from the test set
    classes=['0'],
)

In [None]:
#test_ids = list(sorted({fname.split('_')[0] for fname in os.listdir(STAGE_ONE_DATA/TEST_PNGS)}))
#test_ds,_ = ImageMulti4Channel.from_folder(
#    path = STAGE_ONE_DATA, 
#    folder = TEST_PNGS, 
#    fns = pd.Series(test_ids), 
#    labels = [['0'] for _ in range(len(test_ids))],
#    valid_pct=0,
#    classes=['0'],
#)

In [None]:
# Training augmentation: horizontal and vertical flips, rotation up to 30
# degrees and mild lighting jitter; zoom (max_zoom=1) and warp (max_warp=0)
# are switched off. The second returned transform list is discarded.
trn_tfms, _ = get_transforms(
    do_flip=True,
    flip_vert=True,
    max_rotate=30.,
    max_zoom=1,
    max_lighting=0.05,
    max_warp=0.,
)

In [None]:
# (means, standard deviations) passed to `normalize`, one entry per channel —
# presumably in open_4_channel's red, green, blue, yellow order; confirm.
protein_stats = (
    [0.08069, 0.05258, 0.05487, 0.08282],  # means
    [0.13704, 0.10145, 0.15313, 0.13814],  # standard deviations
)

In [None]:
# Augment only the training set (validation gets an empty transform list),
# then normalise every split with the per-channel protein statistics.
data = (ImageDataBunch
        .create(trn_ds, val_ds, test_ds=test_ds, path=STAGE_ONE_DATA, bs=bs,
                ds_tfms=(trn_tfms, []), num_workers=8, size=sz)
        .normalize(protein_stats))

# Model

# Loss Function

In [None]:
class FocalLoss(nn.Module):
    """Multi-label focal loss computed from raw logits.

    Each element's binary cross-entropy is scaled by (1 - p_t) ** gamma, where
    p_t is the predicted probability of the target value, so confidently
    correct elements contribute less. Targets are assumed to be 0/1.

    Args:
        gamma: focusing exponent; gamma=0 reduces to plain BCE-with-logits.

    The result is summed over dim 1 (classes) and averaged over the remaining
    (batch) dimension.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, input, target):
        if target.size() != input.size():
            raise ValueError("Target size ({}) must be the same as input size ({})"
                             .format(target.size(), input.size()))

        # Numerically stable BCE-with-logits: shifting by max(-x, 0) keeps
        # both exponentials bounded by 1.
        shift = (-input).clamp(min=0)
        log_term = ((-shift).exp() + (-input - shift).exp()).log()
        bce = input - input * target + shift + log_term

        # logsigmoid(-x) when target is 1 and logsigmoid(x) when it is 0,
        # i.e. log(1 - p_t) for binary targets.
        log_one_minus_pt = F.logsigmoid(-input * (target * 2.0 - 1.0))
        modulating = (log_one_minus_pt * self.gamma).exp()

        per_sample = (modulating * bce).sum(dim=1)
        return per_sample.mean()

# Train

In [None]:
# Model from the project-local `resnet` module; depth 50 selects a ResNet-50
# encoder. The class name suggests 4-channel input — confirm in resnet.py.
resnet50 = Resnet4Channel(
    encoder_depth=50,
)

In [None]:
# F1 is F-beta with beta=1; thresh=0.5 is the cut-off fbeta uses to binarise predictions.
f1_score = partial(fbeta, beta=1, thresh=0.5)

In [None]:
# Multi-label setup: each output is an independent binary problem trained with
# BCE-with-logits. To try the FocalLoss class defined in this notebook
# instead, pass loss_func=FocalLoss().
learn = ClassificationLearner(
    data=data,
    model=resnet50,
    loss_func=F.binary_cross_entropy_with_logits,
    path=STAGE_ONE_DATA,
    metrics=[f1_score],
)

In [None]:
# Make all layer groups trainable before the LR search and training below.
learn.unfreeze()

In [None]:
# Learning-rate range test; the recorded loss curve is plotted in the next cell.
learn.lr_find()

In [None]:
# Plot loss against learning rate from lr_find to choose `lr` below.
learn.recorder.plot()

In [None]:
lr = 1e-2  # peak learning rate for fit_one_cycle, presumably read off the lr_find plot

In [None]:
# Train for 20 epochs with the one-cycle schedule peaking at `lr`.
# NOTE(review): in fastai v1 a slice with no start, like slice(lr), is
# typically expanded to lr/3 for earlier layer groups and lr for the last —
# confirm this is intended rather than passing `lr` directly.
learn.fit_one_cycle(20, slice(lr))

In [None]:
# Save the trained weights as 'resnet50_basic' (fastai stores them under
# learn.path, presumably in a `models/` subfolder — confirm).
learn.save('resnet50_basic')

# Predictions

In [None]:
# Predict on the test set; the second return value (presumably the dummy
# targets) is discarded.
preds,_ = learn.get_preds(DatasetType.Test)

In [None]:
# Convert per-class scores into the submission format: for each test image,
# the space-separated indices of the classes scoring above 0.5.
# NOTE(review): 0.5 is a probability cut-off only if get_preds returns
# sigmoid-activated scores rather than raw logits — confirm for the fastai
# version in use, especially if switching to FocalLoss.
pred_labels = [' '.join(str(i) for i in np.nonzero(row > 0.5)[0]) for row in np.array(preds)]
assert len(pred_labels) == len(df_test), 'expected one prediction row per test Id'

# Ids come from df_test, the frame test_ds was built from. This replaces the
# previously undefined `test_ids`/`path`, and assumes from_folder with
# valid_pct=0 keeps rows in their original order — confirm.
df = pd.DataFrame({'Id': df_test.Id.values, 'Predicted': pred_labels})
df.to_csv(SUBMISSIONS/'protein_predictions.csv', header=True, index=False)