In [1]:
# https://www.kaggle.com/code/awsaf49/uwmgi-unet-train-pytorch
!pip install -q wandb

!pip install -q ../input/monai/monai-0.9.1-202207251608-py3-none-any.whl

!pip install -q ../input/lib-pretrainedmodels/pretrainedmodels-0.7.4-py3-none-any.whl
!pip install -q ../input/tim-0412/timm-0.4.12-py3-none-any.whl
!pip install -q ../input/lib-efficientnet-pytorch/efficientnet_pytorch-0.7.1-py3-none-any.whl
!pip install -q ../input/segmentation-models-pytorch/segmentation_models_pytorch-0.3.0-py3-none-any.whl

# !pip install -q scikit-image
from kaggle_datasets import KaggleDatasets

[0m

In [2]:
import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, random_split
import albumentations as A

import segmentation_models_pytorch as smp
# import torchsummary

import pandas as pd
import numpy as np
import random, shutil, time, os

import sklearn
import cv2
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import albumentations as A

from glob import glob
from tqdm.notebook import tqdm
from sklearn.model_selection import KFold, GroupKFold
from sklearn.metrics import roc_auc_score
# from skimage import color
from IPython import display as ipd

import scipy
import pdb
import gc

import monai
import tifffile as tiff

from torch.cuda import amp

import warnings
warnings.filterwarnings('ignore')

print('done')

done


In [3]:
CFG = {
    'lr':6e-4,
    'shape':(320, 320),
}

reduce = 4


TRAIN = True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
def seed_everything(seed=44):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
def clear_cache():
    torch.cuda.empty_cache()
    gc.collect()

In [5]:
BASE_DIR = '../input/hubmap-organ-segmentation'

if TRAIN:
    DATA_DIR = os.path.join(BASE_DIR, 'train_images')
else:
    DATA_DIR = os.path.join(BASE_DIR, 'test_images')
            
df = pd.read_csv(os.path.join(BASE_DIR, 'train.csv'))
df['path'] = df['id'].apply(lambda fname : os.path.join(DATA_DIR, str(fname) + '.tiff'))
organ_to_class = {
    'prostate':0,
    'spleen':1,
    'lung':2,
    'kidney':3,
    'largeintestine':4
}
df['classes'] = df['organ'].apply(lambda organ : organ_to_class[organ])
df.head(5)

Unnamed: 0,id,organ,data_source,img_height,img_width,pixel_size,tissue_thickness,rle,age,sex,path,classes
0,10044,prostate,HPA,3000,3000,0.4,4,1459676 77 1462675 82 1465674 87 1468673 92 14...,37.0,Male,../input/hubmap-organ-segmentation/train_image...,0
1,10274,prostate,HPA,3000,3000,0.4,4,715707 2 718705 8 721703 11 724701 18 727692 3...,76.0,Male,../input/hubmap-organ-segmentation/train_image...,0
2,10392,spleen,HPA,3000,3000,0.4,4,1228631 20 1231629 24 1234624 40 1237623 47 12...,82.0,Male,../input/hubmap-organ-segmentation/train_image...,1
3,10488,lung,HPA,3000,3000,0.4,4,3446519 15 3449517 17 3452514 20 3455510 24 34...,78.0,Male,../input/hubmap-organ-segmentation/train_image...,2
4,10610,spleen,HPA,3000,3000,0.4,4,478925 68 481909 87 484893 105 487863 154 4908...,21.0,Female,../input/hubmap-organ-segmentation/train_image...,1


In [6]:
# https://www.kaggle.com/paulorzp/rle-functions-run-length-encode-decode
def rle_encode(img):
    '''
    img: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = img.T.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)


def rle_decode(mask_rle, wid, hei):
    shape = (wid, hei)
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T


def img_read(path):
    img = tiff.imread(path)
    return img


class Dataset2D(torch.utils.data.Dataset):
    def __init__(self, df_sub, train=True):
        self.train = train
        
        self.paths = np.array(df_sub['path'])
        self.rles = np.array(df_sub['rle'])
        self.classes = np.array(df_sub['classes'])
        self.wid = np.array(df_sub['img_width'])
        self.hei = np.array(df_sub['img_height'])
        
        
    def __len__(self):
        return len(self.paths)
    
    def transform(self, img, mask):
        trans = A.Compose([
#             A.ToFloat(max_value=65535.0), # essential because albu requires 32 bits!!! ONLY THIS can force it work with 16 bits!!

            A.VerticalFlip(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.RandomRotate90(p=0.5),


            A.ShiftScaleRotate(
                scale_limit=0.12,  # 0
                shift_limit=0.02,  # 0.05
                rotate_limit=15,
                border_mode=cv2.BORDER_CONSTANT,
                value=(1,1,1),
                always_apply=True,
                p=1,
            ),

            A.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, always_apply=True),
#             A.OneOf([
#                     A.ElasticTransform(
#                         alpha=1, 
# #                         sigma=25, 
#                         always_apply=True,
#                     ),
#                     A.GridDistortion(
#                         always_apply=True,
#                     ),
#                     A.OpticalDistortion(
#                         distort_limit=0.05, 
#                         shift_limit=0.05, 
#                         always_apply=True,
#                     ),
#                 ], p=0.3
#             ),
        ])
        return trans(image=img, mask=mask)

     
    def data_prep_aug(self, img, mask, classes):
        shape = CFG['shape']
        img = (cv2.resize(img, shape, interpolation=cv2.INTER_AREA) / img.max()).astype('float32')
        mask = (cv2.resize(mask, shape, interpolation=cv2.INTER_AREA)).astype('float32')
        
        if self.train:
            trans = self.transform(img, mask)
            img = trans['image']
            mask = trans['mask']
        
        blank_mask = np.zeros((5, shape[0], shape[1]))
        blank_mask[classes, :, :] = mask
        mask = blank_mask

        img = img.transpose(2,0,1)

        return torch.tensor(img, dtype=torch.float16, device=device), torch.tensor(mask, dtype=torch.float16, device=device)
    
    def __getitem__(self, idx):
        shape = CFG['shape']
        img = img_read(self.paths[idx])
        mask = rle_decode(self.rles[idx], self.wid[idx], self.hei[idx])

        # data preprocessing and augmentation
        img, masks = self.data_prep_aug(img, mask, self.classes[idx])
        
        return img, masks

In [8]:
# # CHECK INPUT CORRECTNESS
# train_ds = Dataset2D(df, train=True)
# train_ds_loader = torch.utils.data.DataLoader(train_ds, batch_size=1, num_workers=0)

# for i, a in enumerate(train_ds_loader):
#     if i < 3:
#         print(a[1].shape)
# #         print((a[1][0]).dtype)
#         plt.figure(figsize=(10,10))
# #         plt.subplot(1,2,1)
#         plt.subplots()
#         plt.imshow(a[0][0].cpu().detach().numpy().astype('float32').transpose(1,2,0))
# #         plt.subplot(1,2,2)
#         plt.subplots()
#         plt.imshow(a[1][0][0].cpu().detach().numpy().astype('float32'))
#     else:
#         break



In [7]:
def imshow(img, return_only=False, pause=False, show_axis=True):
    if isinstance(img, np.ndarray):
        if len(img.shape) == 4:
            img = img[0]
        if img.shape[0] == 3:
            img = img.transpose(2,0,1)
    
    if isinstance(img, torch.Tensor):
        if len(img.shape) == 4:
            img = img.cpu().detach()[0].numpy().transpose(2,0,1)
        elif len(img.shape) == 3:
            if img.shape[0] == 3:
                img = img.cpu().detach().numpy().transpose(2,0,1)
        elif len(img.shape) == 2:
            img = img.cpu().detach().numpy()
    
    if return_only:
        return img
    else:
#         plt.figure(figsize=(5,5))
        plt.subplots()
        plt.imshow(img)
        if pause:
            plt.pause(1)
        if not show_axis:
            plt.axis('off')

In [10]:
# idx = 62

# def read(idx):
#     img = cv2.imread(df['path'][idx], cv2.IMREAD_ANYDEPTH)
#     shape = CFG['shape']
#     img = (cv2.resize(img, shape, interpolation=cv2.INTER_AREA) / img.max()).astype('float32')
    
#     blank_mask = np.zeros((df.img_width[idx],  df.img_height[idx], 3))
#     blank_mask[:, :, 0] = rle_decode(df.segmentation[idx][0], df.img_width[idx], df.img_height[idx])
#     blank_mask[:, :, 1] = rle_decode(df.segmentation[idx][1], df.img_width[idx], df.img_height[idx])
#     blank_mask[:, :, 2] = rle_decode(df.segmentation[idx][2], df.img_width[idx], df.img_height[idx])
#     mask = blank_mask
#     mask = cv2.resize(mask, shape, interpolation=cv2.INTER_AREA).astype('float32').transpose(2,0,1).reshape(1, 3, shape[0], shape[1])
    
#     blank_img = np.zeros((shape[0], shape[1], 3))
#     blank_img[:, :, 0] = img
#     blank_img[:, :, 1] = img
#     blank_img[:, :, 2] = img
#     img = blank_img.transpose(2,0,1).reshape((1, 3, shape[0], shape[1]))
    
#     return img, mask
    

# def predict(idx, model=model, to_numpy=True, log=True):
#     img, mask = read(idx)
    
#     model.eval()

#     mask = torch.tensor(mask, device=device,dtype=torch.float)
#     img = torch.tensor(img, device=device,dtype=torch.float)
#     pred = torch.sigmoid(model(img))

#     pred[pred < 0.5] = 0
#     pred[pred > 0.5] = 1
    
#     if log:
#         dl = monai.losses.DiceLoss()(mask, pred)
#         print((1-dl.cpu().detach().numpy()))


#     if to_numpy:
#         pred = pred.cpu().detach().numpy()
#         img = img.cpu().detach().numpy()
#         mask = mask.cpu().detach().numpy()

#     return img, mask, pred
# img, mask, pred = predict(idx)

In [8]:
USE_WANDB = True
if USE_WANDB:
    import wandb
    from wandb.keras import WandbCallback
    secret_value = '5526656efba1b3f066b08df55e01dd2c5101c5ba'
    wandb.login(key=secret_value)
    
    # wandb.init(project='unet_tract_tumor')

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
# UNet EFFB7
activation = None
model = smp.Unet(
    encoder_name='efficientnet-b7',
    decoder_channels= (512, 256, 128, 64, 32),
    decoder_use_batchnorm=True,
    activation=activation,
    in_channels=3,
    classes=5,
)
model = model.to(device)
print('finished')

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b7-dcc49843.pth


  0%|          | 0.00/254M [00:00<?, ?B/s]

finished


In [15]:
kf = KFold(n_splits=10)
for train_ind, val_ind in kf.split(df, df['classes'], groups=None):
    print(len(train_ind))
    break

315


In [16]:
clear_cache()

In [None]:
kf = KFold(n_splits=10)

# model = monai.networks.nets.UNet(
#     spatial_dims=2,
#     in_channels=3,
#     out_channels=3,

#     channels=(32, 64, 128, 256, 512),
# #     strides=(2, 2, 2, 2),
#     num_res_units=2,
# )

epochs = 30
train_bs = 8

num_epoch_2_skip = 0

n_accumulate = 2

optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['lr'])
criterion = lambda y_pred, y_true : monai.losses.FocalLoss()(y_pred, y_true) + torch.log(torch.cosh(monai.losses.DiceLoss(sigmoid=(activation == None))(y_pred, y_true)))
# criterion = lambda y_pred, y_true : 1 - dice_coef(y_pred, y_true)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=int(epochs*39/n_accumulate), eta_min=5e-6)

# train_load * epoches * 
# 
for fold, (train_ind, val_ind) in tqdm(enumerate(kf.split(df, df['classes'], groups=None)), desc='Train '):

    EXP_NAME = f"PostRUNS_fold{fold}"
    
    if USE_WANDB:
        wandb.init(name = EXP_NAME, project="HACKING_v1", entity="kagglers", 
                    config = CFG, save_code = False, group = "UNet EffB7 LOGCOSH",notes='_6e-4_320_39, normal DA, normal scheduler, but 6e-4'
                  )
        wandb.watch(model, log=None)


    best_val_dice = -1
    for epoch in range(1, epochs+1):

        clear_cache()
        clear_cache()
        clear_cache()
        
        seed_everything(44)
        val_dice = 0
        if epoch <= num_epoch_2_skip:
            continue
            
        print(f'Epoch {epoch}/{epochs}', end='')

        model.train()
        scaler = torch.cuda.amp.GradScaler(init_scale=65536.0)

        dataset_size = 0
        running_loss = 0.0
        
        train_ds = Dataset2D(df.iloc[train_ind], train=True)
        train_ds_loader = torch.utils.data.DataLoader(train_ds, batch_size=train_bs)
        
        val_ds = Dataset2D(df.iloc[val_ind], train=False)
        val_ds_loader = torch.utils.data.DataLoader(val_ds, batch_size=2)

        
        # ========================================
        # TRAINING
        # ========================================

        pbar = tqdm(enumerate(train_ds_loader), total=len(train_ds_loader), desc='Train ')
        
        for step, (images, masks) in pbar:
            batch_size = images.size(0)

            with amp.autocast(enabled=True):
                y_pred = model(images)
                loss   = criterion(y_pred, masks)

            (scaler.scale(loss)/n_accumulate).backward()

            if (step + 1) % n_accumulate == 0:
#                 torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                scaler.step(optimizer)
                scaler.update()

                optimizer.zero_grad()

                if scheduler is not None:
                    scheduler.step()

            running_loss += (loss.item() * batch_size)
            current_loss = (loss.item())
            dataset_size += batch_size

            epoch_loss = running_loss / dataset_size
            current_lr = optimizer.param_groups[0]['lr']
            

            if np.isnan(epoch_loss):
                print('NAN LOSS')
                break
            
            if USE_WANDB:
                wandb.log({
                    'train_current_loss':current_loss,
                    'lr':current_lr,
                })
                
            pbar.set_postfix(
                train_loss=f'{epoch_loss:0.4f}',
                current_loss=f'{current_loss:0.5f}',
                lr=f'{current_lr:0.6f}',
            )
        if USE_WANDB:
            wandb.log({
                'train_epoch_loss':epoch_loss,
            })
        torch.save(model.state_dict(), 'running_pt.pt')
        
        
        clear_cache()
        clear_cache()
        clear_cache()
        
        
        # ========================================
        # Validation
        # ========================================
        
        model.eval()
        
        dataset_size = 0
        running_loss = 0.0
        
        pbar = tqdm(enumerate(val_ds_loader), total=len(val_ds_loader), desc='Valid ')

        for step, (images, masks) in pbar:  
            images = images.to(dtype=torch.float, device=device)
            masks = masks.to(dtype=torch.float, device=device)
            
            batch_size = images.size(0)

            y_pred  = model(images)
            loss    = criterion(y_pred, masks)

            running_loss += (loss.item() * batch_size)
            dataset_size += batch_size

            epoch_loss = running_loss / dataset_size

            y_pred = nn.Sigmoid()(y_pred)
            
            y_pred[y_pred < 0.5] = 0
            y_pred[y_pred > 0.5] = 1
            y_pred[y_pred == 0.5] = 1
            
            dl = monai.losses.DiceLoss()(masks, y_pred)
            val_dice += (1-dl.cpu().detach().numpy())
#             val_dice = dice_coef(masks, y_pred).cpu().detach().numpy()
            
            current_lr = optimizer.param_groups[0]['lr']
            if USE_WANDB:
                wandb.log({
                    'running_valid_loss':epoch_loss,
                })
            pbar.set_postfix(
                valid_loss=f'{epoch_loss:0.4f}',
                dice_acc=f'{val_dice}'
            )
            
        if USE_WANDB:
            wandb.log({
                'val_dice':val_dice,
                'valid_loss':epoch_loss,
            })
                
        if val_dice > best_val_dice:
            best_val_dice = val_dice
            print('saving...')
            torch.save(model.state_dict(), f'{EXP_NAME}_{epoch}.pt')
                
    if USE_WANDB:
        wandb.finish()
    break

Train : 0it [00:00, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mtonycfakename[0m ([33mkagglers[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

saving...
Epoch 2/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

saving...
Epoch 3/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 4/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 5/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 6/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 7/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 8/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 9/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

saving...
Epoch 10/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 11/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

saving...
Epoch 12/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 13/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

Epoch 14/30

Train :   0%|          | 0/40 [00:00<?, ?it/s]

Valid :   0%|          | 0/18 [00:00<?, ?it/s]

In [14]:
clear_cache()

In [18]:
l = []

val_ds = Dataset2D(df, train=False)
val_ds_loader = torch.utils.data.DataLoader(val_ds, batch_size=1)

# model.load_state_dict(torch.load('../input/hacking-uneteff7-baseline/PostRUNS_fold0_10.pt'))
model.eval()
clear_cache()
for images, labels in val_ds_loader:
    images = images.to(dtype=torch.float, device=device)

    with torch.no_grad():
        y_pred  = model(images)
    y_pred = nn.Sigmoid()(y_pred)
            
    y_pred[y_pred < 0.5] = 0
    y_pred[y_pred > 0.5] = 1
    
    l.append(y_pred)
    break
    
    


In [1]:
for i in range(len(l)):
    e = l[i][0].detach().cpu().numpy()

    c = e[np.argmax(np.sum(e, axis=(1,2)))]
    
    c[c < 0.5] = 0
    c[c > 0.5] = 1
    
    plt.imshow(c)
    c = cv2.resize(c, (df.iloc[i]['img_width'], df.iloc[i]['img_height']))
        
    if np.sum(c) < 100:
        rle = rle_encode(np.ones((df.iloc[i]['img_width'], df.iloc[i]['img_height'])))
    else:
        rle = rle_encode(c)

NameError: name 'l' is not defined