In [1]:
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import os
import cv2
import gc
import random
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as D

import time

import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm

## config

In [2]:
bs =  16
SEED = 2020

TRAIN = '/data/game/cancer/data/train/1/train'
MASKS = '/data/game/cancer/data/train/1/mask'

NUM_WORKERS = 0
DEVICE = "cuda:0"
EPOCHES = 60
model_name = "se_resnext50_32x4d"

nfolds = 4

## seed

In [3]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(SEED)

## MODEL

In [4]:
def get_model():
    model = smp.Unet(
        encoder_name=model_name ,
        encoder_weights="imagenet",
        in_channels=3,     
        classes=1,   
    )
    return model

## Dataset

In [5]:
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

def img2tensor(img,dtype:np.dtype=np.float32):
    if img.ndim==2 : img = np.expand_dims(img,2)
    img = np.transpose(img,(2,0,1))
    return torch.from_numpy(img.astype(dtype, copy=False))

class CancerDataset(Dataset):
    def __init__(self, fold=0, train=True, tfms=None):
        self.fnames = [fname for fname in os.listdir(TRAIN)]
        self.train = train
        self.tfms = tfms
        
    def __len__(self):
        return len(self.fnames)
    
    def __getitem__(self, idx):
        fname = self.fnames[idx]
        img = cv2.cvtColor(cv2.imread(os.path.join(TRAIN,fname)), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(os.path.join(MASKS,fname),cv2.IMREAD_GRAYSCALE)
        mask = np.divide(mask, 255.0)
        img = cv2.resize(img, (256,256))
        mask = cv2.resize(mask, (256,256))
        if self.tfms is not None:
            augmented = self.tfms(image=img,mask=mask)
            img,mask = augmented['image'],augmented['mask']
        return img2tensor((img/255.0 - mean)/std),img2tensor(mask)
    
def get_aug(p=1.0):
    return Compose([
        HorizontalFlip(),
        VerticalFlip(),
        RandomRotate90(),
        ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.9, 
                         border_mode=cv2.BORDER_REFLECT),
        OneOf([
            OpticalDistortion(p=0.3),
            GridDistortion(p=.1),
            IAAPiecewiseAffine(p=0.3),
        ], p=0.3),
        OneOf([
            HueSaturationValue(10,15,10),
            CLAHE(clip_limit=2),
            RandomBrightnessContrast(),            
        ], p=0.3),
    ], p=p)

## SoftDiceLoss

In [6]:
class SoftDiceLoss(nn.Module):
    def __init__(self, smooth=1., dims=(-2,-1)):
        super(SoftDiceLoss, self).__init__()
        self.smooth = smooth
        self.dims = dims

    def forward(self, x, y):
        tp = (x * y).sum(self.dims)
        fp = (x * (1 - y)).sum(self.dims)
        fn = ((1 - x) * y).sum(self.dims)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        dc = dc.mean()

        return 1 - dc

## metrics

In [7]:
def pixel_accuracy(preds, label):
    
    valid = (label >= 0)
    acc_sum = (valid * (preds == label)).sum()
    print(valid, acc_sum)
    valid_sum = valid.sum()
    acc = float(acc_sum) / (valid_sum + 1e-10)
    return acc

def np_dice_score(probability, mask):
    p = probability.reshape(-1)
    t = mask.reshape(-1)

    p = p>0.5
    t = t>0.5
    uion = p.sum() + t.sum()

    overlap = (p*t).sum()
    dice = 2*overlap/(uion+0.001)
    return dice

In [8]:
def train(model, train_loader, criterion, optimizer):
    losses = []
    for i, (image, target) in tqdm(enumerate(train_loader)):
        image, target = image.to(DEVICE), target.float().to(DEVICE)
        optimizer.zero_grad()
        output = model(image)
        
        loss = criterion(output, target)
        
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        
    return np.array(losses).mean()

def validation(model, val_loader, criterion):
    val_probability, val_mask = [], []
    model.eval()
    with torch.no_grad():
        for image, target in tqdm(val_loader):
            image, target = image.to(DEVICE), target.float().to(DEVICE)
            output = model(image)

            output_ny = output.sigmoid().data.cpu().numpy()
            target_np = target.data.cpu().numpy()

            val_probability.append(output_ny)
            val_mask.append(target_np)

    val_probability = np.concatenate(val_probability)
    val_mask = np.concatenate(val_mask)

#     return pixel_accuracy(val_probability, val_mask)
    return np_dice_score(val_probability, val_mask)

In [9]:
for fold in range(nfolds):
    train_ds = CancerDataset(fold=fold, train=True, tfms=get_aug())
    valid_ds = CancerDataset(fold=fold, train=False)

    print(len(train_ds), len(valid_ds))
    
    train_loader = D.DataLoader(
        train_ds, batch_size=bs, shuffle=True, num_workers=0)

    val_loader = D.DataLoader(
        valid_ds, batch_size=bs, shuffle=False, num_workers=0)

    model = get_model()
    model.to(DEVICE)

    
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)
    lr_step = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
    loss_fn = SoftDiceLoss()

    header = r'''
            Train | Valid
    Epoch |  Loss |  pa (Best) | Time
    '''
    print(header)
    #          Epoch         metrics            time
    raw_line = '{:6d}' + '\u2502{:7.4f}'*3 + '\u2502{:6.2f}'

    best_pa = 0
    for epoch in range(1, EPOCHES+1):
        start_time = time.time()
        model.train()
        train_loss = train(model, train_loader, loss_fn, optimizer)
        val_pa = validation(model, val_loader, loss_fn)
        lr_step.step(val_pa)

        if val_pa >= best_pa:
            best_pa = val_pa
            torch.save(model.state_dict(), f'{model_name}_fold.pth')

        print(raw_line.format(epoch, train_loss, val_pa, best_pa, (time.time()-start_time)/60**1))


    del train_loader, val_loader, train_ds, valid_ds

    gc.collect()
    break

403 403


0it [00:00, ?it/s]


            Train | Valid
    Epoch |  Loss |  pa (Best) | Time
    


26it [01:01,  2.35s/it]
100%|██████████| 26/26 [00:49<00:00,  1.91s/it]
0it [00:00, ?it/s]

     1│ 0.2576│ 0.6606│ 0.6606│  1.87


26it [00:57,  2.21s/it]
100%|██████████| 26/26 [00:48<00:00,  1.86s/it]
0it [00:00, ?it/s]

     2│ 0.1911│ 0.6411│ 0.6606│  1.76


26it [00:56,  2.18s/it]
100%|██████████| 26/26 [00:47<00:00,  1.84s/it]
0it [00:00, ?it/s]

     3│ 0.3304│ 0.6705│ 0.6705│  1.76


26it [00:56,  2.18s/it]
100%|██████████| 26/26 [00:47<00:00,  1.84s/it]
0it [00:00, ?it/s]

     4│ 0.5398│ 0.6142│ 0.6705│  1.74


26it [00:57,  2.21s/it]
100%|██████████| 26/26 [00:47<00:00,  1.84s/it]
0it [00:00, ?it/s]

     5│ 0.6391│ 0.6131│ 0.6705│  1.76


26it [00:57,  2.21s/it]
100%|██████████| 26/26 [00:57<00:00,  2.20s/it]
0it [00:00, ?it/s]

     6│ 0.3931│ 0.6043│ 0.6705│  1.91


21it [00:49,  2.36s/it]


KeyboardInterrupt: 