## Load libraries

In [1]:
!pip install -r requirements.txt

[31mfloyd-cli 0.11.17 has requirement click<7,>=6.7, but you'll have click 7.0 which is incompatible.[0m
[33mYou are using pip version 10.0.1, however version 19.2.3 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [1]:
import sys
import os
import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn as nn
import torch.utils.data as D
from torch.optim.lr_scheduler import ExponentialLR
import torch.nn.functional as F
from torch.autograd import Variable

from torchvision import transforms

from ignite.engine import Events
from scripts.ignite import create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Loss, Accuracy
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.handlers import  EarlyStopping, ModelCheckpoint
from ignite.contrib.handlers import LinearCyclicalScheduler, CosineAnnealingScheduler

import random

from tqdm import tqdm_notebook

from sklearn.model_selection import train_test_split

from efficientnet_pytorch import EfficientNet, utils as enet_utils

from scripts.evaluate import eval_model_per_cell, eval_model_per_cell_10
from scripts.transforms import gen_transform_train, gen_transform_validation, gen_transform_test_multi
from scripts.plates_leak import apply_plates_leak

import gc

import warnings
warnings.filterwarnings('ignore')

## Define dataset and model

In [2]:
# --- Data locations -----------------------------------------------------------
# The image tree and the CSV metadata share the same root directory.
path_data = '/storage/rxrxai'
img_dir = path_data

# --- Runtime --------------------------------------------------------------------
device = 'cuda'
torch.manual_seed(0)  # fixed seed for the torch RNG
batch_size = 4

# --- Model / learning-rate settings ---------------------------------------------
model_name = 'efficientnet-b4'
init_lr, end_lr = 3e-4, 1e-7

In [3]:
class ImagesDS(D.Dataset):
    """Dataset that yields both imaging sites of a well plus its ``id_code``.

    Each item is a tuple ``(img1, img2, id_code)``. ``img1`` and ``img2`` are
    built by loading one PNG per requested channel for site 1 and site 2
    respectively, passing each through ``transform_validation`` and joining
    the results with ``torch.cat``.

    Args:
        df: DataFrame with at least the columns ``experiment``, ``well``,
            ``plate`` and ``id_code``.
        img_dir: Root directory of the image tree.
        mode: Sub-directory of ``img_dir`` to read from (e.g. ``'train'`` or
            ``'test'``).
        validation: Stored on the instance; not consulted by this class.
        channels: Channel numbers to load, in order. Copied on construction
            so the dataset never aliases the caller's (or the default)
            sequence.
    """

    # Transform applied to every single-channel image.
    # NOTE(review): the shape of its output is defined in scripts.transforms
    # and is not visible here.
    transform_validation = gen_transform_test_multi(crop_size=448)

    def __init__(self, df, img_dir=img_dir, mode='train', validation=False, channels=(1, 2, 3, 4, 5, 6)):
        self.records = df.to_records(index=False)
        self.mode = mode
        self.img_dir = img_dir
        self.len = df.shape[0]
        self.validation = validation
        # Immutable default + defensive copy: avoids the shared mutable
        # default-argument pitfall while still exposing a list.
        self.channels = list(channels)

    def _get_img_path(self, index, channel, site):
        """Return the PNG path for record ``index``, ``channel`` and ``site``."""
        record = self.records[index]
        return '/'.join([
            self.img_dir,
            self.mode,
            record.experiment,
            f'Plate{record.plate}',
            f'{record.well}_s{site}_w{channel}.png',
        ])

    @staticmethod
    def _load_img_as_tensor(file_name, transform):
        """Open ``file_name`` with PIL and return ``transform(image)``.

        The ``with`` block closes the file handle; no explicit garbage
        collection is needed (a full ``gc.collect()`` per image is costly).
        """
        with Image.open(file_name) as img:
            return transform(img)

    def _load_site(self, index, site):
        """Load all configured channels of one site and concatenate them."""
        return torch.cat([
            self._load_img_as_tensor(
                self._get_img_path(index, ch, site),
                ImagesDS.transform_validation,
            )
            for ch in self.channels
        ])

    def __getitem__(self, index):
        """Return ``(site-1 tensor, site-2 tensor, id_code)`` for ``index``."""
        img1 = self._load_site(index, 1)
        img2 = self._load_site(index, 2)
        return img1, img2, self.records[index].id_code

    def __len__(self):
        """Number of rows in the DataFrame given at construction."""
        return self.len

In [4]:
# Read the test metadata and derive a 'category' column holding the part of
# each experiment name that precedes the first '-'.
df_test = pd.read_csv(f'{path_data}/test.csv').assign(
    category=lambda frame: frame['experiment'].map(
        lambda experiment: experiment.split('-')[0]
    )
)

In [5]:
class EfficientNetTwoInputs(nn.Module):
    """EfficientNet backbone shared by two 6-channel inputs, plus a linear head.

    Both inputs go through the same backbone (stored as ``self.resnet``; the
    attribute name is kept so existing checkpoints still load). The two
    feature vectors are concatenated and classified by ``self.fc``.
    """

    def __init__(self):
        super().__init__()
        self.classes = 1108
        in_channels = 6

        model = EfficientNet.from_pretrained(model_name, num_classes=self.classes)
        num_ftrs = model._fc.in_features
        # Drop the stock classifier so the backbone returns pooled features.
        model._fc = nn.Identity()

        # Replace the 3-channel stem with a 6-channel one. The output width is
        # read from the pretrained stem rather than hard-coded (48 is only
        # correct for efficientnet-b4), so other `model_name` values work too.
        trained_kernel = model._conv_stem.weight
        new_conv = enet_utils.Conv2dStaticSamePadding(
            in_channels,
            trained_kernel.shape[0],
            kernel_size=(3, 3),
            stride=(2, 2),
            bias=False,
            image_size=512,
        )
        with torch.no_grad():
            # Initialise every input channel with the mean of the pretrained
            # RGB kernels.
            mean_kernel = torch.mean(trained_kernel, 1)
            new_conv.weight[:, :] = torch.stack([mean_kernel] * in_channels, dim=1)
        model._conv_stem = new_conv

        self.resnet = model
        self.fc = nn.Linear(num_ftrs * 2, self.classes)

    def forward(self, x1, x2):
        """Return class logits for a pair of input batches.

        Args:
            x1: First input batch; batch size is taken from dimension 0.
            x2: Second input batch, processed by the same backbone.

        Returns:
            Output of ``self.fc`` on the concatenated backbone features.
        """
        batch = x1.size(0)

        # detach() blocks gradients from reaching the backbone, so only
        # ``self.fc`` would receive gradients from this forward pass.
        feats1 = self.resnet(x1).view(batch, -1).detach()
        feats2 = self.resnet(x2).view(batch, -1).detach()

        # No manual `del` / gc.collect(): locals are released on return and a
        # full collection on every forward pass only slows inference down.
        return self.fc(torch.cat((feats1, feats2), 1))

In [6]:
# Create the `models` directory (no-op if it already exists). Unlike the shell
# call, this raises on failure instead of only printing a message.
os.makedirs('models', exist_ok=True)

In [7]:
# Create the `/artifacts` output directory (no-op if it already exists).
# Raising here on failure avoids discovering an unwritable output location
# only after the long inference loop has finished.
os.makedirs('/artifacts', exist_ok=True)

In [8]:
# Evaluate each cell type ('category') separately: a fresh model is built per
# cell and handed to eval_model_per_cell_10 together with a per-cell weights
# path and a per-cell submission path.
all_preds = []
cells = df_test['category'].unique()

for cell in cells:
    cat_test_df = df_test.loc[df_test['category'] == cell].copy()

    print('\n' + '=' * 40)
    print("CURRENT CATEGORY:", cell)
    print('-' * 40)

    model = EfficientNetTwoInputs().to(device)

    cat_test_ds = ImagesDS(cat_test_df, mode='test', validation=True)
    cat_test_loader = D.DataLoader(
        cat_test_ds,
        batch_size=1,
        shuffle=False,
        num_workers=8,
    )

    cell_preds, preds = eval_model_per_cell_10(
        model,
        cat_test_loader,
        f'/storage/enet4{cell}.pth',
        path_data,
        cat_test_df.drop(columns=['category']),  # drop() already returns a new frame
        sub_file=f'/artifacts/submission_{cell}.csv',
        n=2,
    )
    all_preds.extend(cell_preds)


CURRENT CATEGORY: HEPG2
----------------------------------------
Loaded pretrained weights for efficientnet-b4

iteration 0



HBox(children=(IntProgress(value=0, max=4429), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Read every per-cell submission CSV from /artifacts, stack them, and write
# the combined file with only the id_code and sirna columns.
submissions = pd.concat(
    [pd.read_csv(f'/artifacts/submission_{cell}.csv') for cell in cells]
)
submissions.to_csv(
    '/artifacts/submission.csv',
    index=False,
    columns=['id_code','sirna'],
)

#### Apply plates leak

In [None]:
# Post-process the predictions accumulated in `all_preds` by the per-cell loop.
# NOTE(review): apply_plates_leak is imported from scripts.plates_leak; what it
# returns or writes is not visible in this notebook — confirm before relying on it.
apply_plates_leak(all_preds)