In [None]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda import amp
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, MultiStepLR, CosineAnnealingWarmRestarts

import segmentation_models_pytorch as smp
import albumentations as A
from empatches import EMPatches

import cv2
import matplotlib.pyplot as plt
import os 
import random
from tqdm import tqdm
import gc

In [None]:
class CFG:
    """Central configuration namespace; values are read as class attributes (``CFG.<name>``)."""

    # Root folder; train_rles.csv and the train/<kidney>/... folders are resolved relative to it.
    data_root_path = '/path/to/data/folder'

    # Seed handed to set_seed().
    seed = 42

    # Number of epochs passed to run_training().
    epochs = 6
    # Batch sizes of the train / validation DataLoaders.
    train_batch_size = 16
    valid_batch_size = 16
    # Gradient-accumulation factor used by train_one_epoch().
    n_accumulate = 2
    # num_workers of both DataLoaders.
    workers = 8
    # Not referenced in the visible cells.
    accelerator = "gpu"
    # Patch edge length handed to EMPatches.extract_patches(patchsize=...).
    patch_size = 256
    # Patch overlap handed to EMPatches for train / validation data.
    train_overlap = 0.4
    valid_overlap = 0.1

    # Key into the seg_models dict that selects the architecture.
    seg_model = "Unet" 
    # Encoder name forwarded to the segmentation_models_pytorch constructor.
    encoder_name = 'tu-maxvit_tiny_tf_512' 
    # Initial learning rate for AdamW.
    lr = 1.0e-4 
    # NOTE(review): not referenced in the visible cells (the AdamW call passes only lr and eps).
    weight_decay = 0.001
    # eps passed to AdamW.
    eps = 0.0001
    # eta_min of the CosineAnnealingLR scheduler.
    min_lr = 1.0e-6 
    # T_max of CosineAnnealingLR; the scheduler is stepped once per optimizer step in train_one_epoch().
    T_max =  100000 

    # Use the first CUDA device when available, otherwise the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the label table from the data root.
df = pd.read_csv(f'{CFG.data_root_path}/train_rles.csv')
# 'image' is the last underscore-separated token of 'id'; it is used later as the .tif file stem.
df['image'] = df['id'].apply(lambda x: x.split('_')[-1])

In [None]:
def set_seed(seed = 42):
    """Seed the Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every generator; also exported as PYTHONHASHSEED.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Host-side generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch generators (CPU, then the current CUDA device).
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    # Disable the cuDNN autotuner and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
# Seed all RNGs once with the configured seed before any data or model code runs.
set_seed(CFG.seed)

In [None]:
# Training-time augmentation. Dataset.__getitem__ calls it with image= and mask= arrays
# whose last axis holds the 4 stacked patches.
# NOTE(review): A.RandomBrightness is deprecated in recent albumentations releases — confirm
# the installed version still provides it.
train_transform = A.Compose([
    A.RandomRotate90(p=1),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightness(p=1),
    # With probability 0.9 apply exactly one of the two blur variants.
    A.OneOf(
        [
            A.Blur(blur_limit=3, p=1),
            A.MotionBlur(blur_limit=3, p=1),
        ],
        p=0.9,
    ),
])
# Validation uses an empty pipeline, so Dataset can call both transforms the same way.
valid_transform = A.Compose([
])

In [None]:
# Keep only rows whose id contains one of the three kidney names used in this notebook.
df = df[df['id'].str.contains('kidney_1_dense|kidney_2|kidney_3_dense')].reset_index(drop=True)
# 'kidney' is the id with its final underscore-separated token removed.
df['kidney'] = df.id.apply(lambda x: x.rsplit('_',1)[0])

def create_image_path(row):
    """Build the image .tif path for one label row.

    Args:
        row: object exposing ``kidney`` and ``image`` attributes (a DataFrame row here).

    Returns:
        Path string under ``CFG.data_root_path``/train/<folder>/images/.
    """
    # Rows labelled kidney_3_dense read their images from the kidney_3_sparse folder.
    folder = 'kidney_3_sparse' if row.kidney == 'kidney_3_dense' else row.kidney
    return f'{CFG.data_root_path}/train/{folder}/images/{row.image}.tif'
def create_mask_path(row):
    """Build the label .tif path for one label row.

    Args:
        row: object exposing ``kidney`` and ``image`` attributes (a DataFrame row here).

    Returns:
        Path string under ``CFG.data_root_path``/train/<kidney>/labels/.
    """
    # Labels always live in the folder named after the row's own kidney
    # (for kidney_3_dense that is the kidney_3_dense folder).
    return f'{CFG.data_root_path}/train/{row.kidney}/labels/{row.image}.tif'

# Resolve the image and label file path of every row (row-wise apply).
df['image_path'] =  df.apply(create_image_path, axis=1)
df['mask_path'] =  df.apply(create_mask_path, axis=1)

In [None]:
def create_kidney_volume(kidney, df):
    """Load every slice of one kidney into stacked uint8 image and mask tensors.

    Args:
        kidney: kidney name; rows whose 'kidney' column contains it are selected.
        df: table with 'kidney', 'image', 'image_path' and 'mask_path' columns.

    Returns:
        (all_images, all_masks): two tensors stacked along a new leading slice axis,
        in ascending order of the 'image' column.

    Raises:
        FileNotFoundError: if an image or mask file cannot be read by OpenCV.
    """
    df = df[df['kidney'].str.contains(kidney)].sort_values('image', ascending=True).reset_index(drop=True)
    if kidney == 'kidney_2':
        # Only the slices from position 900 onwards are used for kidney_2.
        df = df.iloc[900:]

    def _read_slice(path):
        """Read one grayscale file as a uint8 tensor, failing loudly on unreadable paths."""
        array = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread signals failure by returning None rather than raising.
        if array is None:
            raise FileNotFoundError(f'Could not read image file: {path}')
        return torch.from_numpy(array.copy()).to(torch.uint8)

    all_images = []
    all_masks = []
    for i in tqdm(range(len(df))):
        row = df.iloc[i]
        all_images.append(_read_slice(row.image_path))
        all_masks.append(_read_slice(row.mask_path))

    all_images = torch.stack(all_images)
    all_masks = torch.stack(all_masks)
    return all_images, all_masks



In [None]:
# Training volumes: kidney name -> [image volume, mask volume].
train_data = {}
for kidney in ['kidney_1_dense', 'kidney_3_dense']:
    all_images, all_masks = create_kidney_volume(kidney, df)
    train_data[kidney] = [all_images, all_masks]

# kidney_2 is held out as the validation volume, stored in the same dict layout.
valid_images, valid_masks = create_kidney_volume('kidney_2', df)
valid_data = {'kidney_2': [valid_images, valid_masks]}

In [None]:
def create_grid(images):
    """Tile the first four channels of ``images`` into a 2x2 mosaic.

    Args:
        images: array indexed as [row, col, channel] with at least 4 channels.

    Returns:
        2-D array twice as tall and wide as one channel: channels 0 and 1 form the
        top row (left, right), channels 2 and 3 the bottom row.
    """
    tiles = [images[:, :, channel] for channel in range(4)]
    top = np.concatenate(tiles[:2], axis=1)
    bottom = np.concatenate(tiles[2:], axis=1)
    return np.concatenate((top, bottom), axis=0)

def perc_normalize(image, percentile_dict, kidney):
    """Rescale an image by a kidney's stored (lo, hi) pair and floor the result at 0.5.

    Args:
        image: tensor to normalise; it is converted to float32 first.
        percentile_dict: mapping kidney name -> sequence whose items 0 and 1 are lo and hi.
        kidney: key selecting the (lo, hi) pair.

    Returns:
        float32 tensor ``(image - lo) / (hi - lo)`` with every value below 0.5 raised to 0.5.
    """
    bounds = percentile_dict[kidney]
    lo, hi = bounds[0], bounds[1]
    scaled = (image.to(torch.float32) - lo) / (hi - lo)
    # NOTE(review): only a lower bound of 0.5 is applied and no upper bound — confirm intended.
    return scaled.clamp(min=0.5)

def preprocess_mask(mask):
    """Return ``mask`` as a float32 tensor divided by 255.

    The division is done out of place: ``Tensor.to`` hands back the very same tensor
    when the input is already float32, so an in-place ``/=`` would silently rescale
    the caller's data on every call.

    Args:
        mask: tensor of mask values (stored as uint8 by create_kidney_volume).

    Returns:
        New float32 tensor equal to ``mask / 255``.
    """
    return mask.to(torch.float32) / 255.0

def get_image_ids(data, truncate=3, train=True):
    """Enumerate every (kidney, axis, start slice, patch) combination as a string id.

    Ids have the form ``'{kidney}-axis{a}-{slice}_{patch}'``. The number of patches per
    slice is taken from the first slice along each axis; the last ``truncate`` slices
    along the axis are not used as start slices.

    Args:
        data: mapping kidney name -> [image volume, mask volume].
        truncate: number of trailing slices excluded as start positions.
        train: when True use CFG.train_overlap and also enumerate axes 1 and 2;
            otherwise use CFG.valid_overlap and axis 0 only.

    Returns:
        List of id strings, grouped by kidney and then by axis.
    """
    emp = EMPatches()
    overlap = CFG.train_overlap if train else CFG.valid_overlap

    # (axis number, permutation that brings that axis to the front); None = volume as stored.
    axis_views = [(0, None)]
    if train:
        axis_views += [(1, (1, 2, 0)), (2, (2, 0, 1))]

    ids = []
    for kidney in data.keys():
        volume = data[kidney][0]
        for axis, perm in axis_views:
            view = volume if perm is None else volume.permute(*perm)
            # Patch count is derived from the first slice of this view only.
            img_patches, indices = emp.extract_patches(view[0], patchsize=CFG.patch_size, overlap=overlap)
            n_patches = len(img_patches)
            print(f'axis:{axis}', kidney, n_patches)
            for slice_idx in range(view.shape[0] - truncate):
                ids.extend(f'{kidney}-axis{axis}-{slice_idx}_{patch}' for patch in range(n_patches))
    return ids

def get_patch(emp, kidney_volume, mask_volume, image_id, patch_id, overlap, percentile_dict, kidney):
    """Return one normalised image patch and the matching mask patch of a slice.

    Args:
        emp: patch extractor exposing ``extract_patches(img, patchsize=, overlap=)``.
        kidney_volume: image volume indexed by slice along its first axis.
        mask_volume: mask volume with the same layout.
        image_id: slice index into both volumes.
        patch_id: index of the wanted patch in the extractor's patch list.
        overlap: overlap value forwarded to the extractor.
        percentile_dict: per-kidney (lo, hi) pairs used by perc_normalize.
        kidney: key into ``percentile_dict``.

    Returns:
        (image_patch, mask_patch) taken at the same patch index.
    """
    slice_img = perc_normalize(kidney_volume[image_id], percentile_dict, kidney)
    slice_mask = preprocess_mask(mask_volume[image_id])

    # Both slices are tiled with identical settings, so index patch_id lines up.
    image_tiles, _ = emp.extract_patches(slice_img, patchsize=CFG.patch_size, overlap=overlap)
    mask_tiles, _ = emp.extract_patches(slice_mask, patchsize=CFG.patch_size, overlap=overlap)

    return image_tiles[patch_id], mask_tiles[patch_id]

def get_percentile_dict():
    """Compute the 2nd and 98th intensity percentiles of each kidney volume.

    Reads the module-level ``train_data`` and ``valid_data`` dicts: kidney_2 comes from
    ``valid_data``, the other two kidneys from ``train_data``.

    Returns:
        Dict mapping kidney name -> ``[lo, hi]``.
    """
    percentile_dict = {}
    for kidney in ['kidney_1_dense', 'kidney_2', 'kidney_3_dense']:
        source = valid_data if kidney == 'kidney_2' else train_data
        lo, hi = np.percentile(source[kidney][0].numpy(), (2, 98))
        percentile_dict[kidney] = [lo, hi]
    return percentile_dict

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Patch dataset that serves 2x2 mosaics built from 4 consecutive slices.

    Each item is identified by a ``'{kidney}-axis{a}-{slice}_{patch}'`` id produced by
    get_image_ids(). The same patch position is cut from 4 consecutive slices along the
    chosen axis, augmented jointly, and tiled into one image/mask pair by create_grid().
    """

    def __init__(self, data, train=True):
        """
        Args:
            data: mapping kidney name -> [image volume, mask volume].
            train: selects the training overlap, training transform and the
                multi-axis id enumeration; otherwise the validation settings are used.
        """
        self.train = train
        self.data = data
        self.image_ids = get_image_ids(self.data, train=self.train)
        self.emp = EMPatches()
        self.percentile_dict = get_percentile_dict()
        self.overlap = CFG.train_overlap if self.train else CFG.valid_overlap

    def __getitem__(self, index):
        """Build one sample.

        Returns:
            (image, mask, orig_image_id, patch_id): image and mask carry a leading
            singleton dimension; the two ids are 1-element int16 tensors.
        """
        kidney, axis, orig_image_id = self.image_ids[index].split('-')
        orig_image_id, patch_id = orig_image_id.split('_')
        orig_image_id, patch_id = int(orig_image_id), int(patch_id)
        kidney_volume = self.data[kidney][0]
        mask_volume = self.data[kidney][1]
        # Bring the requested axis to the front so slices are taken along it.
        if axis == 'axis1':
            kidney_volume = kidney_volume.permute(1,2,0)
            mask_volume = mask_volume.permute(1,2,0)
        elif axis == 'axis2':
            kidney_volume = kidney_volume.permute(2,0,1)
            mask_volume = mask_volume.permute(2,0,1)

        # Same patch position from 4 consecutive slices starting at orig_image_id.
        images = []
        masks = []
        for i in range(4):
            image_id = orig_image_id+i
            img, mask = get_patch(self.emp, kidney_volume, mask_volume, image_id, patch_id, self.overlap, self.percentile_dict, kidney)
            images.append(img)
            masks.append(mask)
        images = torch.stack(images).numpy()
        masks = torch.stack(masks).numpy()

        # The transforms receive the 4 patches stacked on the last axis, so all four
        # patches get the same augmentation.
        transform = train_transform if self.train else valid_transform
        data = transform(image=images.transpose(1,2,0), mask=masks.transpose(1,2,0))
        images, masks = data['image'], data['mask']
        image = create_grid(images)
        mask = create_grid(masks)

        # Re-binarise the mask after augmentation.
        mask = (mask>0).astype(np.int8).astype(np.float32)
        image = torch.tensor(image)
        mask = torch.tensor(mask)
        orig_image_id = torch.tensor(int(orig_image_id), dtype=torch.int16)
        # int16 rather than int8: a slice can yield more than 127 patches, and
        # torch.tensor raises on a value that does not fit the requested dtype.
        patch_id = torch.tensor(int(patch_id), dtype=torch.int16)

        return image.unsqueeze(0), mask.unsqueeze(0), orig_image_id.unsqueeze(0),  patch_id.unsqueeze(0)

    def __len__(self):
        """Number of enumerated ids."""
        return len(self.image_ids)

In [None]:
# Lookup table from the CFG.seg_model string to the segmentation_models_pytorch class
# that Model instantiates.
seg_models = {
    "Unet": smp.Unet,
    "Unet++": smp.UnetPlusPlus,
    "MAnet": smp.MAnet,
    "Linknet": smp.Linknet,
    "FPN": smp.FPN,
    "PSPNet": smp.PSPNet,
    "PAN": smp.PAN,
    "DeepLabV3": smp.DeepLabV3,
    "DeepLabV3+": smp.DeepLabV3Plus,
}

class Model(torch.nn.Module):
    """Wrapper around the segmentation network selected by ``CFG.seg_model``.

    The wrapped network is built with one input channel, one output class and no
    final activation, so ``forward`` returns the network's raw outputs.
    """

    def __init__(self):
        super().__init__()
        architecture = seg_models[CFG.seg_model]
        self.model = architecture(
            encoder_name=CFG.encoder_name,
            encoder_weights="imagenet",
            in_channels=1,
            classes=1,
            activation=None,
        )

    def forward(self, images):
        """Run ``images`` through the wrapped network and return its output unchanged."""
        return self.model(images)

In [None]:
# Candidate loss objects from segmentation_models_pytorch, all configured for binary masks.
# Only DiceLoss is used by calculate_loss(); the others are kept available for swapping in.
DiceLoss = smp.losses.DiceLoss(mode="binary")
TverskyLoss = smp.losses.TverskyLoss(mode="binary", alpha=0.7, beta=0.3)
BCELoss = smp.losses.SoftBCEWithLogitsLoss()
JaccardLoss = smp.losses.JaccardLoss(mode="binary")
FocalLoss = smp.losses.FocalLoss(mode="binary")
# Loss reference: https://smp.readthedocs.io/en/latest/losses.html

def calculate_loss(preds, masks):
    """Training/validation criterion: the module-level DiceLoss applied to (preds, masks).

    Args:
        preds: model outputs.
        masks: target masks.

    Returns:
        The value returned by ``DiceLoss(preds, masks)``.
    """
    dice = DiceLoss(preds, masks)
    return dice


In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    """Run one training pass with autocast, gradient scaling and gradient accumulation.

    Gradients are accumulated over ``CFG.n_accumulate`` batches before each optimizer
    step. A final step is also taken on the last batch, so that gradients of an
    incomplete accumulation group are applied instead of leaking into the next epoch.

    Args:
        model: network to train (switched to train mode here).
        optimizer: optimizer stepped through the GradScaler.
        scheduler: optional LR scheduler, stepped once per optimizer step; may be None.
        dataloader: sized iterable yielding (images, masks, image_id, patch_id) batches.
        device: device the batches are moved to.
        epoch: epoch number, used only for the progress-bar display.

    Returns:
        Sample-weighted mean of the per-batch loss over the epoch. This is the loss as
        returned by calculate_loss(), not the value divided for accumulation, so it is
        on the same scale as the validation loss.
    """
    model.train()
    scaler = amp.GradScaler()
    dataset_size = 0
    running_loss = 0.0
    num_steps = len(dataloader)

    pbar = tqdm(enumerate(dataloader), total=num_steps, desc='Train ')
    for step, (images, masks, image_id,  patch_id) in pbar:
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)
        batch_size = images.size(0)

        with amp.autocast(enabled=True):
            preds = model(images)
            loss = calculate_loss(preds, masks)
            # Only the value used for backward is divided; `loss` stays unscaled for reporting.
            scaled_loss = loss / CFG.n_accumulate

        scaler.scale(scaled_loss).backward()

        # Step every n_accumulate batches, and once more on the final batch to flush
        # whatever has been accumulated so far.
        is_last_step = (step + 1) == num_steps
        if (step + 1) % CFG.n_accumulate == 0 or is_last_step:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if scheduler is not None:
                scheduler.step()

        running_loss += (loss.detach().item() * batch_size)
        dataset_size += batch_size
        del loss, scaled_loss, preds
        epoch_loss = running_loss / dataset_size

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix( epoch=f'{epoch}',
                          train_loss=f'{epoch_loss:0.4f}',
                          lr=f'{current_lr:0.5f}',
                          gpu_mem=f'{mem:0.2f} GB')
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

def valid_one_epoch(model, optimizer, dataloader):
    """Evaluate the model over ``dataloader`` without computing gradients.

    Args:
        model: network to evaluate (switched to eval mode here).
        optimizer: only read for the current learning rate shown in the progress bar.
        dataloader: iterable yielding (images, masks, image_id, patch_id) batches.

    Returns:
        Sample-weighted mean of the per-batch loss.
    """
    model.eval()

    seen = 0
    loss_sum = 0.0

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks, image_id,  patch_id) in pbar:
        images = images.to(CFG.device, dtype=torch.float)
        masks = masks.to(CFG.device, dtype=torch.float)
        batch_size = images.size(0)

        # Forward pass and loss under autocast, with autograd disabled.
        with torch.cuda.amp.autocast(enabled=True), torch.no_grad():
            preds = model(images)
            loss = calculate_loss(preds, masks)

        loss_sum += (loss.detach().item() * batch_size)
        seen += batch_size
        epoch_loss = loss_sum / seen
        del loss, preds

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                         lr=f'{current_lr:0.5f}',
                         gpu_memory=f'{mem:0.2f} GB')

    # Release cached GPU memory and collect garbage before returning.
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

def run_training(model, train_dataloader, valid_dataloader, optimizer, scheduler, num_epochs):
    """Train for ``num_epochs`` epochs, checkpointing every epoch and the best one.

    After each epoch the weights are written to
    ``./results/model/<encoder>-<seg_model>_last_<epoch>epochs.pth``; whenever the
    validation loss is less than or equal to the best so far they are also written
    to ``./results/model/<encoder>-<seg_model>_best.pth``.

    Args:
        model: network to train.
        train_dataloader: batches for train_one_epoch().
        valid_dataloader: batches for valid_one_epoch().
        optimizer: optimizer used for training.
        scheduler: LR scheduler forwarded to train_one_epoch(); may be None.
        num_epochs: number of epochs to run.
    """
    best_loss      = np.inf
    # Stays None when no epoch improves (e.g. NaN validation loss or num_epochs == 0),
    # so the summary print below cannot hit an unbound variable.
    best_epoch     = None

    # Create the checkpoint folder up front so torch.save cannot fail after an epoch of work.
    os.makedirs('./results/model', exist_ok=True)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='')
        train_loss = train_one_epoch(model, optimizer, scheduler,
                                           dataloader=train_dataloader,
                                           device=CFG.device, epoch=epoch)
        torch.cuda.empty_cache()
        gc.collect()
        val_loss = valid_one_epoch(model,optimizer, valid_dataloader)
        torch.cuda.empty_cache()
        gc.collect()
        torch.save(model.state_dict(), f'./results/model/{CFG.encoder_name}-{CFG.seg_model}_last_{epoch}epochs.pth')
        if val_loss <= best_loss:
            print(f'Loss improved! val_loss:{val_loss}, previous best:{best_loss}  ')
            best_loss = val_loss
            best_epoch = epoch
            torch.save(model.state_dict(), f'./results/model/{CFG.encoder_name}-{CFG.seg_model}_best.pth')
    print(f'Best model || best_epoch {best_epoch} | best_loss:{best_loss} |')
    

In [None]:
# Build the patch datasets; each constructor enumerates its ids and recomputes the percentile table.
train_dataset = Dataset(train_data, train=True)
valid_dataset = Dataset(valid_data, train=False)
# Training batches are shuffled; validation keeps the enumeration order.
train_dataloader = DataLoader(train_dataset, batch_size=CFG.train_batch_size, shuffle=True, num_workers=CFG.workers)
valid_dataloader = DataLoader(valid_dataset, batch_size=CFG.valid_batch_size, shuffle=False, num_workers=CFG.workers)

In [None]:
# Build the network and move it to the configured device.
model = Model()
model.to(CFG.device)
# NOTE(review): CFG.weight_decay is not passed here, so AdamW falls back to its own
# default weight decay — confirm whether weight_decay=CFG.weight_decay was intended.
optimizer = AdamW(model.parameters(), lr =  CFG.lr, eps =  CFG.eps)      
# Cosine schedule from CFG.lr towards CFG.min_lr over CFG.T_max scheduler steps.
scheduler = CosineAnnealingLR(optimizer,T_max=CFG.T_max, eta_min=CFG.min_lr)
run_training(model, train_dataloader, valid_dataloader, optimizer, scheduler, num_epochs=CFG.epochs)
# Release cached GPU memory once training has finished.
torch.cuda.empty_cache()
gc.collect()