In [None]:
!pip install segmentation-models-pytorch
!pip install torchcontrib

In [None]:
import os
import sys
import zipfile

import numpy as np
import pandas as pd
import plotly.express as px
import torch
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

sys.path.append('../code/')
from data.dataset import TileDataset, PredictFromImgDataset
from data.transforms import HE_preprocess
from model_zoo.models import define_model
from utils.plots import plot_global_pred, plot_thresh_scores
from training.predict import predict_entire_mask
from training.train import fit

%load_ext autoreload
%autoreload 2

In [None]:
# Tiling parameters: together they select which pre-generated tile set is used.
tile_size = 256   # the size of tiles
reduce_fac = 4    # reduce the original images by 4 times

# Every path below lives under the same dataset root, and the tile
# archives/folders share one "<tile_size>_red_<reduce_fac>" suffix.
_DATASET_ROOT = '../input/hubmap-kidney-segmentation/'
_TILE_TAG = f'{tile_size}_red_{reduce_fac}'

# Original competition data.
MASKS = _DATASET_ROOT + 'train.csv'
DATA = _DATASET_ROOT + 'train/'

# Tile archives.
ZIP_TRAIN = f'{_DATASET_ROOT}train_{_TILE_TAG}.zip'
ZIP_MASKS = f'{_DATASET_ROOT}masks_{_TILE_TAG}.zip'

# Tile folders (same base names as the archives, without the .zip extension).
IMG_FOLDER = f'{_DATASET_ROOT}train_{_TILE_TAG}/'
MSK_FOLDER = f'{_DATASET_ROOT}masks_{_TILE_TAG}/'

# Split tiles into folds by image id (one fold per slide)

In [None]:
# Build `df_images`: one row per tile, with the fold it belongs to.
# A tile's fold is the position of its slide id in train.csv, where the slide
# id is the part of the tile name before the first underscore.
df_mask = pd.read_csv(MASKS)
slide_ids = df_mask.id.unique().tolist()
# Direct slide-id -> fold lookup (avoids a linear list search for every tile).
fold_of_slide = {slide_id: fold for fold, slide_id in enumerate(slide_ids)}

# The masks archive is not read here; opening it only checks that it exists
# and is a valid zip file.
with zipfile.ZipFile(ZIP_TRAIN, 'r') as img_arch, \
     zipfile.ZipFile(ZIP_MASKS, 'r') as msk_arch:
    # namelist() would also return directory entries; keep real files only.
    fnames = [info.filename for info in img_arch.infolist() if not info.is_dir()]

tile_slides = [fname.split("_", 1)[0] for fname in fnames]

# Fail with an explicit message if a tile belongs to a slide train.csv does not list.
unknown_slides = sorted(set(tile_slides) - set(fold_of_slide))
if unknown_slides:
    raise ValueError(
        f"Tiles in {ZIP_TRAIN} reference slide ids missing from {MASKS}: {unknown_slides}"
    )

df_images = pd.DataFrame({
    'tile_name': fnames,
    'fold': [fold_of_slide[slide] for slide in tile_slides],
})

## Visualization and training

In [None]:
# Preview the first batch of training tiles with their masks overlaid.
# The batch size is derived from the grid so that every tile in the batch has
# a subplot slot (plt.subplot raises if the index exceeds rows * cols).
grid_side = 8  # tiles per row and per column of the preview grid

viz_ds = TileDataset(df_images,
                     IMG_FOLDER,
                     MSK_FOLDER,
                     transforms=HE_preprocess(augment=True, visualize=True))
# shuffle=False: the preview always shows the first tiles of df_images.
viz_dl = DataLoader(viz_ds, batch_size=grid_side ** 2, shuffle=False, num_workers=1)
imgs, masks = next(iter(viz_dl))

plt.figure(figsize=(16, 16))
for i, (img, mask) in enumerate(zip(imgs, masks)):
    # permute(1, 2, 0) moves the first axis (presumably channels) last,
    # which is the layout imshow expects for colour images.
    img = img.permute(1, 2, 0).numpy()
    plt.subplot(grid_side, grid_side, i + 1)
    # NOTE(review): imshow ignores vmin/vmax for RGB(A) input; if the tiles are
    # floats in [0, 255] they need a cast/rescale to display correctly —
    # confirm against HE_preprocess(visualize=True).
    plt.imshow(img, vmin=0, vmax=255)
    plt.imshow(mask.squeeze().numpy(), alpha=0.5)
    plt.axis('off')
# Layout adjustment applies to the whole figure, so do it once after the loop.
plt.subplots_adjust(wspace=None, hspace=None)

# Uncomment to drop the preview objects before training:
# del viz_ds, viz_dl, imgs, masks

In [None]:
# Cross-validation training: for every fold, train on the tiles of all other
# folds, validate on the held-out fold, plot the dice history and save the weights.
encoder = "resnet101"
decoder = "Unet"
plot_global = False   # True: also predict and score the full held-out slide
first_fold = 0        # raise to skip already-trained folds when resuming a run
model_dir = "../models/"

# torch.save does not create missing directories; make sure the target exists
# up front instead of failing after a full training run.
os.makedirs(model_dir, exist_ok=True)

# Iterate over the fold ids that actually occur. Fold ids are positions in the
# slide list, so range(nunique()) would hit an empty validation fold (and miss
# the last fold) as soon as one slide has no tiles.
for fold_nb in sorted(int(fold) for fold in df_images.fold.unique()):
    print("FOLD :", fold_nb)
    if fold_nb < first_fold:
        continue
    train_df = df_images[df_images.fold != fold_nb].reset_index()
    val_df = df_images[df_images.fold == fold_nb].reset_index()
    # The slide id is the tile-name prefix; take it from the first validation tile.
    mask_name = val_df.tile_name.iloc[0].split("_", 1)[0]

    # Augmentation is enabled for training tiles only.
    train_dataset = TileDataset(train_df,
                                IMG_FOLDER, MSK_FOLDER,
                                transforms=HE_preprocess(augment=True, visualize=False))
    val_dataset = TileDataset(val_df,
                              IMG_FOLDER, MSK_FOLDER,
                              transforms=HE_preprocess(augment=False, visualize=False))

    # activation=None: the model is trained with BCEWithLogitsLoss below,
    # which expects raw logits.
    model = define_model(decoder, encoder, num_classes=1, activation=None).to("cuda")
    meter, history = fit(model,
                         train_dataset=train_dataset,
                         val_dataset=val_dataset,
                         epochs=100,
                         loss_name="BCEWithLogitsLoss",
                         swa_first_epoch=80,
                         batch_size=32)
    px.line(history, x='epoch', y='dice').show()

    torch.save(
        model.state_dict(),
        f"{model_dir}{encoder}_{decoder}_{tile_size}_{reduce_fac}_fold{fold_nb}.pt",
    )

    if plot_global:
        print("predicting global image...")
        # DATA is the folder of the original slides (same path as before,
        # now taken from the config cell instead of being repeated here).
        predict_dataset = PredictFromImgDataset(f'{DATA}{mask_name}.tiff',
                                                mask_name=mask_name,
                                                overlap_factor=4,
                                                reduce_factor=reduce_fac,
                                                transforms=HE_preprocess(augment=False, visualize=False))

        global_pred = predict_entire_mask(predict_dataset,
                                          model,
                                          batch_size=32)
        # Predictions are thresholded at 0 for the overview plot.
        plot_global_pred(mask=predict_dataset.mask, pred=global_pred > 0)
        thresholds, scores = plot_thresh_scores(mask=predict_dataset.mask, pred=global_pred)
        max_score_pos = np.argmax(scores)
        print(f"Maximum dice: {scores[max_score_pos]} with thresh {thresholds[max_score_pos]}")