## Load libraries

In [1]:
# Install the project's Python dependencies listed in requirements.txt.
# NOTE(review): `!pip` runs whichever pip is first on the shell PATH, which may not be the
# kernel's interpreter (the output below reports a floyd-cli/click conflict); on IPython >= 7.3
# `%pip install -r requirements.txt` targets the kernel's environment directly.
!pip install -r requirements.txt

[31mfloyd-cli 0.11.17 has requirement click<7,>=6.7, but you'll have click 7.0 which is incompatible.[0m
[33mYou are using pip version 10.0.1, however version 19.2.3 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.[0m


In [1]:
import sys
import os
import numpy as np
import pandas as pd

from PIL import Image

import torch
import torch.nn as nn
import torch.utils.data as D
from torch.optim.lr_scheduler import ExponentialLR
import torch.nn.functional as F
from torch.autograd import Variable

from torchvision import transforms
from torchvision import models

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
# from scripts.ignite import create_supervised_evaluator, create_supervised_trainer
from ignite.metrics import Loss, Accuracy
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.handlers import  EarlyStopping, ModelCheckpoint
from ignite.contrib.handlers import LinearCyclicalScheduler, CosineAnnealingScheduler

import random

from tqdm import tqdm_notebook

from sklearn.model_selection import train_test_split

from efficientnet_pytorch import EfficientNet, utils as enet_utils

from scripts.evaluate import eval_model
from scripts.plates_leak import apply_plates_leak

import gc

import warnings
warnings.filterwarnings('ignore')

In [2]:
# List the dataset directory to check that the CSVs and image folders used below are present.
!ls /storage/rxrxai

pixel_stats.csv		       test.csv		  train.zip
pixel_stats.csv.zip	       test.zip		  train384
pixel_stats_agg.csv	       test_controls.csv  train_controls.csv
recursion_dataset_license.pdf  train		  training_aug.csv
sample_submission.csv	       train.csv	  validation.csv
test			       train.csv.zip


## Define dataset and model

In [3]:
# Seed every RNG this notebook imports (Python `random`, numpy, torch) so runs are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Images and metadata CSVs live in the same directory.
path_data = '/storage/rxrxai'
img_dir = path_data
# Aggregated pixel statistics with `cell`, `channel`, `mean` and `std` columns (see the cells below).
stats_df = pd.read_csv(path_data + '/pixel_stats_agg.csv')

# NOTE(review): the model actually trained below is resnet18; this name is only used by
# EfficientNetTwoInputs and in checkpoint filenames.
model_name = 'efficientnet-b1'
device = 'cuda'
batch_size = 16
init_lr = 3e-4   # Adam's initial LR and the start value of each cosine-annealing cycle
end_lr = 1e-7    # end value of each cosine-annealing cycle
classes = 1108   # number of output classes (distinct `sirna` labels)

In [4]:
# Pixel std/mean of channel 1, aggregated over all cell types ('ALL').
std_mean = stats_df.loc[(stats_df['cell'] == 'ALL') & (stats_df['channel'] == 1.),
                        ['std', 'mean']]
std_mean

Unnamed: 0,std,mean
0,6.905682,5.845692


In [5]:
# Standard deviation from the first (only) selected stats row.
std_mean['std'].iloc[0]

6.905682371466988

In [6]:
# Mean from the first (only) selected stats row.
std_mean['mean'].iloc[0]

5.845691592341334

In [7]:
def _channel_mean_std(stats, channel):
    """Return ``(mean, std)`` for one image channel from the aggregated pixel stats.

    Selects the row whose ``cell`` is ``'ALL'`` and whose ``channel`` equals ``float(channel)``.

    Raises:
        ValueError: if no such row exists (clearer than the bare ``IndexError`` of ``.iloc[0]``).
    """
    row = stats[(stats['cell'] == 'ALL') & (stats['channel'] == float(channel))]
    if row.empty:
        raise ValueError(f"no 'ALL' pixel stats for channel {channel}")
    return row.iloc[0]['mean'], row.iloc[0]['std']


def build_channel_transforms(stats, channels=range(1, 7), pixel_scale=255.):
    """Build one ``ToTensor`` + ``Normalize`` pipeline per image channel.

    Args:
        stats: aggregated pixel-statistics frame with ``cell``, ``channel``, ``mean``, ``std``.
        channels: channel numbers to build transforms for.
        pixel_scale: factor ``ToTensor`` divides pixel values by. ``ToTensor`` maps 8-bit images
            to [0, 1], while these stats are in raw intensity units (channel-1 mean is ~5.8),
            so mean and std are divided by the same factor before normalising.
            NOTE(review): assumes 8-bit PNGs; ``pixel_scale=1.`` reproduces the old behaviour.

    Returns:
        dict mapping channel number -> transform.
    """
    result = {}
    for channel in channels:
        mean, std = _channel_mean_std(stats, channel)
        result[channel] = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[mean / pixel_scale], std=[std / pixel_scale]),
        ])
    return result


# Looked up by channel number in ImagesDS._load_img_as_tensor. Built inside a function so the
# notebook-level `std_mean` shown in earlier cells is no longer overwritten.
channel_transforms = build_channel_transforms(stats_df)

class ImagesDS(D.Dataset):
    """Dataset of multi-channel site-1 well images described by a metadata frame.

    Each item concatenates the single-channel PNGs of site 1 for one row (in ``channels``
    order) into one tensor; each channel is normalised by the matching entry of the
    module-level ``channel_transforms`` dict.

    Args:
        df: metadata frame with ``experiment``, ``plate`` and ``well`` columns, plus ``sirna``
            when ``mode == 'train'`` or ``id_code`` otherwise.
        img_dir: dataset root; images are read from
            ``<img_dir>/<mode>/<experiment>/Plate<plate>/<well>_s<site>_w<channel>.png``.
        mode: image sub-directory; ``'train'`` yields ``(image, int(sirna))``, any other value
            yields ``(image, id_code)``.
        validation: kept for API compatibility; currently unused (it only gated a disabled
            site-1/site-2 swap augmentation).
        channels: channel numbers to load, in stacking order.
    """

    def __init__(self, df, img_dir=img_dir, mode='train', validation=False, channels=(1, 2, 3, 4, 5, 6)):
        self.records = df.to_records(index=False)
        self.mode = mode
        self.img_dir = img_dir
        self.len = df.shape[0]
        self.validation = validation
        # Per-instance copy: the old list default was one object shared by every instance.
        self.channels = list(channels)

    def _get_img_path(self, index, channel, site):
        """Return the PNG path for one channel of one site of record ``index``."""
        record = self.records[index]
        # plate may arrive as a float (e.g. after NaN-introducing filtering), hence int().
        file_name = f'{record.well}_s{site}_w{channel}.png'
        return '/'.join([self.img_dir, self.mode, record.experiment, f'Plate{int(record.plate)}', file_name])

    @staticmethod
    def _load_img_as_tensor(file_name, channel):
        """Open one single-channel image and apply that channel's transform from ``channel_transforms``."""
        with Image.open(file_name) as img:
            return channel_transforms[channel](img)

    def __getitem__(self, index):
        """Return ``(image, label)``: label is ``int(sirna)`` in train mode, else ``id_code``."""
        image = torch.cat([self._load_img_as_tensor(self._get_img_path(index, ch, 1), ch)
                           for ch in self.channels])
        record = self.records[index]
        if self.mode == 'train':
            return image, int(record.sirna)
        return image, record.id_code

    def __len__(self):
        return self.len

In [8]:
# Metadata frames for training, validation (hold-out) and testing.
df_train_all = pd.read_csv(path_data + '/train.csv')
df_val = pd.read_csv(path_data + '/validation.csv').drop(['ds', 'cell', 'aug'], axis=1)
df_test = pd.read_csv(path_data + '/test.csv')


def _well_keys(df):
    """Return a MultiIndex of ``(experiment, plate, well)`` per row, with ``plate`` cast to int.

    The cast keeps keys comparable even if one frame stores ``plate`` as float.
    """
    return pd.MultiIndex.from_arrays([df['experiment'], df['plate'].astype(int), df['well']])


# Hold every validation well out of training: an anti-join on the keys ImagesDS uses to locate
# images. The previous `df_train[~df_train.isin(df_val)].dropna()` did not do this, because
# DataFrame.isin(DataFrame) compares element-wise on matching index/column labels. It dropped
# rows that shared any value with the same-index validation row, and it kept validation wells
# stored at other index positions.
df_train = df_train_all[~_well_keys(df_train_all).isin(_well_keys(df_val))]

# pytorch training dataset & loader
ds = ImagesDS(df_train, mode='train', validation=False)
loader = D.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=8)

# pytorch validation dataset & loader; order does not change averaged metrics, so no shuffling
ds_val = ImagesDS(df_val, mode='train', validation=True)
val_loader = D.DataLoader(ds_val, batch_size=batch_size, shuffle=False, num_workers=8)

# pytorch test dataset & loader
ds_test = ImagesDS(df_test, mode='test', validation=True)
tloader = D.DataLoader(ds_test, batch_size=1, shuffle=False, num_workers=8)

In [9]:
# class DenseNetTwoInputs(nn.Module):
#     def __init__(self):
#         super(DenseNetTwoInputs, self).__init__()
#         self.classes = 1108
        
#         model = models.densenet121(pretrained=True)
#         num_ftrs = model.classifier.in_features
#         model.classifier = nn.Identity()

#         # let's make our model work with 6 channels
#         trained_kernel = model.features.conv0.weight
#         new_conv = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
#         with torch.no_grad():
#             new_conv.weight[:,:] = torch.stack([torch.mean(trained_kernel, 1)]*6, dim=1)
#         model.features.conv0 = new_conv
        
#         self.densenet = model
#         self.fc = nn.Linear(num_ftrs * 2, self.classes)

#     def forward(self, x1, x2):
#         x1_out = self.densenet(x1)
#         x2_out = self.densenet(x2)
   
#         N, _, _, _ = x1.size()
#         x1_out = x1_out.view(N, -1)
#         x2_out = x2_out.view(N, -1)
        
#         out = torch.cat((x1_out, x2_out), 1)
#         out = self.fc(out)

#         return out 
    
# model = DenseNetTwoInputs()
# model.train()

In [None]:
class EfficientNetTwoInputs(nn.Module):
    """Two-input EfficientNet classifier with one shared backbone.

    Both inputs go through the same backbone (its ``_fc`` head replaced by ``nn.Identity``).
    The two feature vectors are concatenated and mapped to logits by one linear layer.

    Args:
        num_classes: number of output logits.
        in_channels: input image channels. The pretrained stem kernel is averaged over its
            input channels and that average is replicated this many times.
        image_size: image size passed to the static "same"-padding stem convolution.
    """

    def __init__(self, num_classes=1108, in_channels=6, image_size=512):
        super(EfficientNetTwoInputs, self).__init__()
        self.classes = num_classes

        model = EfficientNet.from_pretrained(model_name, num_classes=num_classes)
        num_ftrs = model._fc.in_features
        model._fc = nn.Identity()

        # Rebuild the stem for `in_channels` inputs. Output width, kernel size and stride are
        # copied from the pretrained stem instead of being hard-coded for one variant.
        trained_kernel = model._conv_stem.weight
        out_channels, _, kernel_h, kernel_w = trained_kernel.shape
        new_conv = enet_utils.Conv2dStaticSamePadding(
            in_channels, out_channels, kernel_size=(kernel_h, kernel_w),
            stride=model._conv_stem.stride, bias=False, image_size=image_size)
        with torch.no_grad():
            new_conv.weight[:, :] = torch.stack([torch.mean(trained_kernel, 1)] * in_channels, dim=1)
        model._conv_stem = new_conv

        # Attribute name kept as `resnet` so existing state_dict keys still match.
        self.resnet = model
        self.fc = nn.Linear(num_ftrs * 2, self.classes)

    def forward(self, x1, x2):
        """Return logits of shape ``(N, classes)`` for a batch of image pairs ``(x1, x2)``."""
        n = x1.size(0)
        x1_out = self.resnet(x1).view(n, -1)
        x2_out = self.resnet(x2).view(n, -1)
        return self.fc(torch.cat((x1_out, x2_out), 1))
   
# Model actually trained: ImageNet-pretrained ResNet-18 with a `classes`-way head and a first
# convolution that accepts 6-channel input. (`EfficientNetTwoInputs()` is the unused alternative.)
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = torch.nn.Linear(num_ftrs, classes)

# Rebuild conv1 for 6 channels, copying its hyper-parameters from the pretrained layer rather
# than hard-coding them. Every new input channel starts from the RGB-averaged pretrained kernel.
in_channels = 6
old_conv = model.conv1
trained_kernel = old_conv.weight
new_conv = nn.Conv2d(in_channels, old_conv.out_channels, kernel_size=old_conv.kernel_size,
                     stride=old_conv.stride, padding=old_conv.padding,
                     bias=old_conv.bias is not None)
with torch.no_grad():
    new_conv.weight[:, :] = torch.stack([torch.mean(trained_kernel, 1)] * in_channels, dim=1)
model.conv1 = new_conv
model.train();  # trailing `;` suppresses the very long module repr in the cell output

In [11]:
# Adam over all model parameters, starting at `init_lr`; the LR is then driven by the scheduler.
optimizer = torch.optim.Adam(params=model.parameters(), lr=init_lr)
# Multi-class cross-entropy on the integer `sirna` labels returned by ImagesDS.
criterion = nn.CrossEntropyLoss()

In [12]:
# Metrics reported by the validation engine: average loss and accuracy.
metrics = dict(loss=Loss(criterion), accuracy=Accuracy())

# `trainer` updates `model` in place; `val_evaluator` scores it with `metrics`.
trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
val_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)

#### Early stopping (currently disabled)

In [13]:
# handler = EarlyStopping(patience=30, score_function=lambda engine: engine.state.metrics['accuracy'], trainer=trainer)
# val_evaluator.add_event_handler(Events.COMPLETED, handler)

#### Learning-rate scheduler (one cosine-annealing cycle per epoch)

In [14]:
# Cosine-anneal the LR from init_lr toward end_lr, one cycle per pass over `loader`
# (i.e. it restarts every epoch); applied before each training iteration.
scheduler = CosineAnnealingScheduler(optimizer, 'lr',
                                     start_value=init_lr,
                                     end_value=end_lr,
                                     cycle_size=len(loader))
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# @trainer.on(Events.ITERATION_COMPLETED)
# def print_lr(engine):
#     epoch = engine.state.epoch
#     iteration = engine.state.iteration
    
#     if epoch < 2 and iteration % 100 == 0:
#         print(f'Iteration {iteration} | LR {optimizer.param_groups[0]["lr"]}')

#### Compute and display metrics

In [15]:
@trainer.on(Events.EPOCH_COMPLETED)
def compute_and_display_val_metrics(engine):
    """After each training epoch, evaluate on `val_loader` and print average loss and accuracy.

    The metrics are also stored as ``engine.state.val_metrics = (epoch, metrics)``. Handlers
    registered later for the same event can then reuse this epoch's results instead of
    running validation again.
    """
    epoch = engine.state.epoch
    # Copy so later evaluator runs cannot mutate the cached dict.
    val_metrics = dict(val_evaluator.run(val_loader).metrics)
    engine.state.val_metrics = (epoch, val_metrics)
    print("Validation Results - Epoch: {} | Average Loss: {:.4f} | Accuracy: {:.4f} "
          .format(epoch, val_metrics['loss'], val_metrics['accuracy']))

#### Checkpoint only the best epoch (by validation accuracy)

In [16]:
# Create the checkpoint directory used by get_saved_model_path (portable `mkdir -p models`).
os.makedirs('models', exist_ok=True)

In [17]:
def get_saved_model_path(epoch, epoch_offset=49, name=None):
    """Return the checkpoint path ``models/Model_<name>_<epoch + epoch_offset>.pth``.

    Args:
        epoch: epoch number of the current ``trainer.run`` (1 for the first epoch).
        epoch_offset: added to ``epoch`` in the filename.
            NOTE(review): 49 looks like a leftover from resuming an earlier run; this notebook
            loads no checkpoint before training, so confirm the offset is still wanted.
        name: model name used in the filename; defaults to the global ``model_name``.
            NOTE(review): ``model_name`` is 'efficientnet-b1' but the trained model is resnet18.
    """
    if name is None:
        name = model_name
    return f'models/Model_{name}_{epoch + epoch_offset}.pth'

# Best-checkpoint bookkeeping (accuracy, epoch, file path), updated after every epoch.
best_acc, best_epoch, best_epoch_file = 0., 1, ''

@trainer.on(Events.EPOCH_COMPLETED)
def save_best_epoch_only(engine):
    """Keep only the checkpoint of the epoch with the best validation accuracy on disk.

    Updates the globals ``best_acc``, ``best_epoch`` and ``best_epoch_file``. They are reset on
    epoch 1, so each fresh ``trainer.run`` starts tracking from scratch. The new checkpoint is
    written before the superseded one is removed, so a failed save never leaves no checkpoint.
    """
    global best_acc, best_epoch, best_epoch_file
    epoch = engine.state.epoch

    if epoch == 1:
        best_acc, best_epoch, best_epoch_file = 0., 1, ''

    # Reuse this epoch's validation metrics if another handler already cached them as
    # engine.state.val_metrics = (epoch, metrics); otherwise run validation here.
    cached = getattr(engine.state, 'val_metrics', None)
    if cached is not None and cached[0] == epoch:
        val_metrics = cached[1]
    else:
        val_metrics = val_evaluator.run(val_loader).metrics

    if val_metrics['accuracy'] > best_acc:
        prev_best_epoch_file = get_saved_model_path(best_epoch)

        best_acc = val_metrics['accuracy']
        best_epoch = epoch
        best_epoch_file = get_saved_model_path(best_epoch)
        print(f'\nEpoch: {best_epoch} - New best accuracy! Accuracy: {best_acc}\n\n\n')
        torch.save(model.state_dict(), best_epoch_file)

        # Only now remove the superseded checkpoint (same path -> it was just overwritten).
        if prev_best_epoch_file != best_epoch_file and os.path.exists(prev_best_epoch_file):
            os.remove(prev_best_epoch_file)

#### Progress bar (for interactive notebook use; comment out when running as a script)

In [18]:
# tqdm progress bar over training iterations, labelling the trainer's per-iteration output as 'loss'.
pbar = ProgressBar(bar_format='')
pbar.attach(trainer, output_transform=lambda batch_loss: {'loss': batch_loss})

#### Train

In [19]:
# Run training for 50 epochs; the EPOCH_COMPLETED handlers registered earlier validate and checkpoint.
print('Training started\n')
trainer.run(data=loader, max_epochs=50)

Training started



HBox(children=(IntProgress(value=0, max=2251), HTML(value='')))

KeyboardInterrupt: 

#### Evaluate

In [None]:
# Predict on the test set using the best checkpoint path recorded during training.
# Fail fast with a clear message if training stopped before any checkpoint was written
# (best_epoch_file is '' until save_best_epoch_only saves one).
assert best_epoch_file and os.path.exists(best_epoch_file), \
    'no best-epoch checkpoint found; run training first'
all_preds, _ = eval_model(model, tloader, best_epoch_file, path_data)

In [None]:
# Post-process the test predictions with the project's plate-leak helper (scripts/plates_leak.py).
# NOTE(review): the result is not assigned, so this presumably modifies `all_preds` in place
# or writes its own output — confirm against the helper's implementation.
apply_plates_leak(all_preds)