In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import cv2
import datetime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import pretrainedmodels.utils as utils
import torch
import torch.optim as optim
from apex import amp
from easydict import EasyDict
from numpy.random import choice, seed
from PIL import Image
from pretrainedmodels import xception
from torch import nn
from torch.utils.data import Dataset, SequentialSampler
from torchvision import transforms as tt
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm_notebook as tqdm
from unet.unet_model import UNet
from custom_pytorch import custom_layers
from custom_pytorch.custom_samplers import SubsetRandomSampler

In [3]:
DATA_PATH = "../../input/pneumonothorax-data/"
TRAIN_PATH = os.path.join(DATA_PATH, 'train')
TEST_PATH = os.path.join(DATA_PATH, 'test')
MODEL_SAVE_DIR = "../../input/models/pneumonothorax"
try:
    os.makedirs(MODEL_SAVE_DIR)
except OSError:
    pass

In [4]:
train_folder_size = len(os.listdir(os.path.join(TRAIN_PATH, 'images')))
test_folder_size = len(os.listdir(os.path.join(TEST_PATH, 'images')))
print('Training folder data size:', train_folder_size)
print('Testing folder data size:', test_folder_size)

Training folder data size: 10712
Testing folder data size: 1377


In [5]:
class PneumothoraxDataset(Dataset):
    def __init__(self, path, train, feat_transforms=None, train_aug_transforms=None, metadata_encoder=None):
        super().__init__()
        self.train = train
        self.path = path
        self.aug_transforms = train_aug_transforms
        self.metadata = pd.read_csv(os.path.join(self.path, 'metadata.csv'), index_col='ImageId')
        self.feat_transforms = feat_transforms
        from sklearn.preprocessing import OneHotEncoder
        metadata_subset = self.metadata[['PatientSex', 'ViewPosition']]
        if metadata_encoder is None:
            self.encoder = OneHotEncoder(sparse=False)
            self.encoder.fit(metadata_subset.to_numpy().tolist())
        
        encoded = self.encoder.transform(metadata_subset.to_numpy().tolist())
        self.metadata = pd.DataFrame(encoded, index=list(metadata_subset.index.values))

    def __len__(self):
        return len(self.metadata)
    
    def __iter__(self):
        for fil in os.listdir(os.path.join(self.path, 'images')):
            img = Image.open(os.path.join(self.path, 'images', fil))
            metadata = self.metadata.loc[os.path.splitext(fil)[0]].values.astype(int)
            t_img = img
            if self.train:
                try:
                    mask = cv2.imread(os.path.join(self.path, 'masks', fil))
                except OSError:
                    mask = np.zeros(imgs[-1].size[::-1], np.uint8)
                if self.aug_transforms is not None:
                    transformed = self.aug_transforms(image=img, mask=mask)
                    img = transformed['image']
                    mask = transformed['mask']
                    t_img = img
                if self.feat_transforms is not None:
                    t_img = self.feat_transforms(Image.fromarray(img, 'L'))
                yield img,\
                    t_img, metadata, mask
            else:
                if self.feat_transforms is not None:
                    t_img = self.feat_transforms(Image.fromarray(img, 'L'))
                yield img,\
                    t_img, metadata
                
    def __getitem__(self, index):
        if not isinstance(index, list):
            index = [index]
        imgs = []
        metadata = []
        masks = []
        fils = os.listdir(os.path.join(self.path, 'images'))
        for fil_index in index:
            fil = fils[fil_index]
            imgs.append(Image.open(os.path.join(self.path, 'images', fil)).convert('L'))
            metadata.append(self.metadata.loc[os.path.splitext(fil)[0]].values.astype(int))
            if self.train:
                try:
                    masks.append(cv2.imread(os.path.join(self.path, 'masks', fil), 0))
                    if masks[-1] is None:
                        masks = masks[:-1]
                        raise OSError
                except OSError:
                    masks.append(np.zeros(imgs[-1].size[::-1], np.uint8))
        
        t_imgs = imgs
        if len(index) == 1:
            if self.feat_transforms is not None:
                t_imgs = self.feat_transforms(Image.fromarray(imgs[0], 'L'))
            
            metadata = metadata[0]
            if self.train:
                if self.aug_transforms is not None:
                    transformed = self.aug_transforms(
                        image=np.array(imgs[0]), mask=masks[0])
                    imgs = transformed['image']
                    masks = transformed['mask']
                else:
                    imgs = np.array(imgs[0])
                    masks = masks[0]
        else:

            if self.train:
                imgs = [np.array(im) for im in imgs]
                if self.aug_transforms is not None:
                    transformed = [self.aug_transforms(
                        image=image, mask=mask)
                                   for image, mask in zip(imgs, masks)]
                    imgs = [trans['image'] for trans in transformed]

                    masks = [trans['mask'] for trans in transformed]
            else:
                imgs = [np.array(img) for img in imgs]
            if self.feat_transforms is not None:
                t_imgs = torch.stack([self.feat_transforms(Image.fromarray(img), 'L') for img in imgs])
            metadata = metadata[0]
        if self.train:
            return imgs, t_imgs, metadata, masks
        return imgs, t_imgs, metadata
    
def collate_fn(batch): 
    images = [item[0] for item in batch]
    t_images = [item[1] for item in batch]
    metadata = [item[2] for item in batch]
    masks = [item[3] for item in batch]
    return [images, t_images, metadata, masks]

In [6]:

CONFIG = EasyDict(dict(TRAIN_SIZE=15000,
                       TRAIN_AUG_RATIO=3, #through augmentation
                       VALID_SIZE=50, # by random selection, will not participate at all in training
                       BATCH_SIZE=8,
                       RANDOM_SEED=42,
                       LR=0.0001,
                       MOMENTUM=0.9,
                       WEIGHT_DECAY=0.9
                       ))

In [7]:
seed(CONFIG.RANDOM_SEED)
from torch import functional as F
def pad_div32(tensor):
    """
    Pads last 2 dimensions of a tensor so they can be divisible with 32
    """
    h, w = tensor.size()[-2:]
    if h < 32:
        y_p = 32 - h
    else:
        y_p = 32 - (h % 32)
    if w < 32:
        x_p = 32 - w
    else:
        x_p = 32 - (w % 32)
    pads = (x_p // 2, x_p - x_p // 2, y_p // 2, y_p - y_p // 2)
    return nn.ZeroPad2d(pads)(tensor), pads

    
class NeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.bbox_layer = custom_layers.BBoxLayer(1, (2048, 10, 10), 4)
        self.features_model = xception(pretrained='imagenet').cuda()
        for param in self.features_model.parameters():
            param.requires_grad = False
        self.segmentor = UNet(1, 1)

    def forward(self, inputs, t_inputs, metadata_inputs):
        
        feature_map_outputs = self.features_model.features(t_inputs)
        bboxes_mask = self.bbox_layer(inputs, feature_map_outputs, metadata_inputs)
        bboxes_group = self.bbox_layer.get_bboxes()
        masks = []
        for bboxes, inp in zip(bboxes_group, inputs):
            mask_output = torch.zeros_like(inp)
            for bbox in bboxes:
                mask_input, pads = pad_div32(
                    inp[..., int(bbox[0][1]): int(bbox[2][1]),
                        int(bbox[0][0]): int(bbox[2][0])])
                mask = self.segmentor(mask_input.unsqueeze(0).repeat(2, 1, 1, 1))
                mask_output[..., int(bbox[0][1]): int(bbox[2][1]),
                        int(bbox[0][0]): int(bbox[2][0])] += mask[0, 0, pads[2]: -pads[3], pads[0]: -pads[1]]
            masks.append(mask_output)
        if len(masks) == 1:
            masks = masks[0]
        else:
            masks = torch.stack(masks)
        return masks, bboxes_mask

net = NeuralNet().cuda()
optimizer = optim.SGD(
    net.parameters(), lr=CONFIG.LR, momentum=CONFIG.MOMENTUM, weight_decay=CONFIG.WEIGHT_DECAY)
amp.initialize(net, optimizer)

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic


(NeuralNet(
   (bbox_layer): BBoxLayer(
     (mapping_conv): Conv2d(2048, 16, kernel_size=(1, 1), stride=(1, 1))
     (mapping_linear): Linear(in_features=100, out_features=100, bias=True)
     (mapping_activation): Sigmoid()
     (metadata_linear): Linear(in_features=4, out_features=100, bias=True)
     (metadata_activation): Sigmoid()
     (inputs_conv): Sequential(
       (0): Conv2d(1, 4, kernel_size=(3, 3), stride=(1, 1))
       (1): ZeroPad2d(padding=(1, 1, 1, 1), value=0.0)
     )
     (inputs_activation): Sigmoid()
     (bboxes_layer): Conv2d(4, 4, kernel_size=(10, 10), stride=(1, 1))
     (padding): ZeroPad2d(padding=(4, 5, 4, 5), value=0.0)
   )
   (features_model): Xception(
     (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
     (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (relu): ReLU(inplace)
     (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
     (bn2): BatchNorm2d(64, eps=1e-

In [8]:
from albumentations import (
    PadIfNeeded,
    HorizontalFlip,
    VerticalFlip,
    CenterCrop,    
    Crop,
    Compose,
    Transpose,
    RandomRotate90,
    ElasticTransform,
    GridDistortion, 
    OpticalDistortion,
    RandomSizedCrop,
    OneOf,
    CLAHE,
    RandomBrightnessContrast,    
    RandomGamma    
)
from albumentations.pytorch.transforms import ToTensor as AlbToTensor

inputs_transformations = Compose([   
    VerticalFlip(p=0.3),
    OneOf([
        ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
        GridDistortion(p=0.5),
        OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)                  
        ], p=0.8),
    CLAHE(p=0.5),
    RandomBrightnessContrast(p=0.5),    
    RandomGamma(p=0.5),
    AlbToTensor()])



In [9]:

dataset = PneumothoraxDataset(TRAIN_PATH, train=True)

valid_indices = choice(len(dataset), size=CONFIG.VALID_SIZE, replace=False)
train_indices = np.setdiff1d(range(len(dataset)), valid_indices)
train_sampler = SubsetRandomSampler(train_indices, replacement=True, num_samples=CONFIG.TRAIN_SIZE)
valid_sampler = SubsetRandomSampler(valid_indices, replacement=False)

In [10]:
train_loader = torch.utils.data.DataLoader(
    dataset, batch_size=CONFIG.BATCH_SIZE, sampler=train_sampler,
    collate_fn=collate_fn, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
valid_loader = torch.utils.data.DataLoader(
    dataset, batch_size=CONFIG.BATCH_SIZE, sampler=valid_sampler,
    collate_fn=collate_fn, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

In [11]:
from custom_pytorch.custom_layers.bbox_layer import create_bboxes_mask_from_mask

In [12]:
net.train()
loss_function = torch.nn.BCEWithLogitsLoss()
from torch import from_numpy
epochs_num = 20
feats_transforms = utils.TransformImage(net.features_model)

def get_model_name(epoch, train_loss, valid_loss):
    return "D_%s_DS_%d_Ep_%d_TL_%.2f_VL_%.2f.pkl"%(
        str(datetime.datetime.now()).replace(' ', '_'),
        CONFIG.TRAIN_SIZE, epoch, train_loss, valid_loss)

def create_augmented_images(image, t_image, mask):
    res_images = []
    res_masks = []
    res_t_images = []
    
    for cnt in range(CONFIG.TRAIN_AUG_RATIO):
        res = inputs_transformations(image=image, mask=mask)
        res_images.append(res['image'].unsqueeze(0))
        res_t_images.append(feats_transforms(Image.fromarray(res['image'].data.numpy(), 'RGB')))
        res_masks.append(res['mask'])
    return torch.stack(res_images), torch.stack(res_t_images), torch.stack(res_masks)
        
        

def perform_batch_operation(batch):
    images = batch[0]
    t_images = batch[1]
    metadatas = batch[2]
    masks = batch[3]
    loss = None
    
    for cnt, (image, t_image, mask, metadata) in enumerate(zip(images, t_images, masks, metadatas)):
        aug_images, aug_t_images, aug_masks = create_augmented_images(image, t_image, mask)
        aug_images = aug_images.float().cuda()
        aug_t_images = aug_t_images.float().cuda()
        bboxes_masks = torch.stack([from_numpy(
            create_bboxes_mask_from_mask(mask.data.numpy()[0, ...])) for mask in aug_masks]).float().cuda()
        aug_masks = aug_masks.cuda()
        metadata = torch.stack(
            [from_numpy(metadata).unsqueeze(0) for _ in range(CONFIG.TRAIN_AUG_RATIO)]).float().cuda()
        out_masks, out_bboxes_masks = net(aug_images, aug_t_images, metadata)
        l1 = 0.7 * loss_function(out_masks.view(1, -1).float(), aug_masks.view(1, -1).float())
        l2 = 0.3 * loss_function(out_bboxes_masks.view(1, -1).float(), bboxes_masks.view(1, -1).float())
        l = l1 + l2
        if loss is None:
            loss = l
        else:
            loss += l
    return loss
train_losses = {}
valid_losses = {}
step = 0
run_validation_every_n_steps = 100
for epoch in range(epochs_num):
    iterator = tqdm(train_loader, desc=f'Epoch {epoch + 1}')
    train_loss = 0
    valid_loss = None
    cnt = 0
    for batch in iterator:
        cnt += 1
        optimizer.zero_grad()
        loss = perform_batch_operation(batch)
        
        loss.backward()
        
        optimizer.step()
        loss = loss.cpu().data.numpy()
        train_losses[step] = loss
        train_loss += loss
        
        if cnt == 0 or step % run_validation_every_n_steps == 0:
            with torch.no_grad():
                valid_loss = 0
                v_cnt = 0
                for batch in tqdm(valid_loader, desc='Performing validation'):
                    v_cnt += 1
                    loss = perform_batch_operation(batch)
                    valid_loss += loss.cpu().data.numpy()
                valid_loss = valid_loss / min(
                    CONFIG.VALID_SIZE, CONFIG.BATCH_SIZE * v_cnt * CONFIG.TRAIN_AUG_RATIO)
                valid_losses[step] = valid_loss
        iterator.set_description(
            'Epoch: %d \n TLoss: %.3f, VLoss: %s'%(
                epoch + 1, train_loss / min(CONFIG.TRAIN_SIZE, CONFIG.BATCH_SIZE * CONFIG.TRAIN_AUG_RATIO * cnt),
                'None' if valid_loss is None else '%.3f'%(valid_loss)))
        step += 1
    train_loss = train_loss / CONFIG.TRAIN_SIZE
    torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            't_losses': train_losses,
            'v_losses': valid_losses,
            'config':CONFIG}, os.path.join(MODEL_SAVE_DIR, get_model_name(epoch, train_loss, valid_loss))
            )
    
            
    
    

HBox(children=(IntProgress(value=0, description='Epoch 1', max=1875), HTML(value='')))



HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, description='Epoch 2', max=1875), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, description='Epoch 3', max=1875), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, description='Performing validation', max=7), HTML(value='')))




HBox(children=(IntProgress(value=0, description='Epoch 4', max=1875), HTML(value='')))

KeyboardInterrupt: 

In [None]:
masks[0]

In [None]:
'1.2.276.0.7230010.3.1.4.8323329.14292.1517875250.605080' in dataset.metadata.index

In [None]:
len(dataset.metadata)

In [None]:
plt.plot(sorted(train_losses), [train_losses[key] for key in sorted(train_losses)])

In [None]:
from torchviz import make_dot
make_dot(out, params=dict(net.named_parameters()))

In [None]:
net.features_model

In [None]:
train_losses

In [None]:
valid_losses