<a href="https://colab.research.google.com/github/TomohiroYazaki/Hacking_the_Human_Body/blob/main/Hacking_the_Human_Body_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Competition archive stored under the Google Drive mount point.
DRIVE_PATH = '/content/drive/MyDrive/Kaggle/202207_Hacking_the_Human_Body/Data/hubmap-organ-segmentation.zip'
# Temporary local copy of the archive; it is unzipped into /content/DATA and then deleted.
CURRENT_PATH = '/content/hubmap-organ-segmentation.zip'

In [None]:
# Remove the output tree left by a previous run.
# (The second command is redundant: MODEL lives inside OUTPUT, which is already gone.)
!rm -f -r /content/OUTPUT
!rm -f -r /content/OUTPUT/MODEL

In [None]:
# Create DATA only if it does not exist yet, then recreate the output directories.
# NOTE(review): the mkdir paths are relative, so this assumes the working directory is /content.
![ -r /content/DATA ] || mkdir DATA
!mkdir OUTPUT
!mkdir OUTPUT/MODEL

In [None]:
%%time
# Copy the archive from Drive only when the data has not been extracted yet,
# then unzip it into /content/DATA and delete the local archive.
# The last two commands are no-ops when no archive was copied.
![ -f /content/DATA/train.csv ] || cp $DRIVE_PATH $CURRENT_PATH
![ -f $CURRENT_PATH ] && unzip -qq $CURRENT_PATH -d /content/DATA
![ -f $CURRENT_PATH ] && rm $CURRENT_PATH

In [None]:
# Delete the /content/sample_data directory, which this notebook does not use.
!rm -f -r /content/sample_data

In [None]:
# Install the third-party packages imported below.
# NOTE(review): no versions are pinned, so results may change between runs.
!pip install timm

In [None]:
!pip install --upgrade albumentations

In [None]:
!pip install colorama

In [None]:
!pip install -q segmentation_models_pytorch

In [None]:
!pip install -qU wandb

In [None]:
!pip install --upgrade opencv-python

In [None]:
# NOTE(review): spams is presumably required by staintools — confirm.
!pip install staintools
!pip install spams

**---------- Import ----------**

In [None]:
import numpy as np
import pandas as pd
import random
import os
import gc
from tqdm import tqdm
import copy
from collections import defaultdict
import pickle

import time
from datetime import datetime
import pytz

# sklearn
from sklearn.model_selection import StratifiedKFold, KFold, StratifiedGroupKFold

# file
import json
import tifffile as tiff

# visualization
import matplotlib.pyplot as plt
from IPython import display as ipd

# PyTorch 
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

# Albumentations for augmentations
import cv2
import albumentations as A

# For colored terminal text
from colorama import Fore, Back, Style
b_ = Fore.BLUE
r_ = Fore.RED
sr_ = Style.RESET_ALL
g_ = Fore.GREEN
br_ = Back.LIGHTRED_EX

# staintools
import staintools

In [None]:
import wandb

# Read the API key through a context manager so the file handle is closed,
# and strip surrounding whitespace (a trailing newline would otherwise be
# sent as part of the key).
with open('/content/drive/MyDrive/Kaggle/wandb_key', 'r') as wandb_key_file:
    wandb_key = wandb_key_file.read().strip()
wandb.login(key=wandb_key)

In [None]:
# Timestamp of this run in Tokyo time; reused below as the wandb group name
# and as the name of the result directory on Drive.
tokyo_now = datetime.now(pytz.timezone('Asia/Tokyo'))
run_start_time = tokyo_now.strftime('%Y%m%d-%H:%M:%S')
print(run_start_time)

**---------- Utilities ----------**

In [None]:
# Batch size and epoch count are defined first because several CFG entries derive from them.
train_bs = 16
epochs = 100

# Central experiment configuration; it is also logged to wandb and pickled with the results.
CFG = {
    'seed'            : 21,
    'debug'           : False, # set debug:False for Full Training
    'in_channels'     : 3,
    'backbone'        : 'timm-efficientnet-b0', # encoder name passed to segmentation_models_pytorch
    'encoder_weights' : 'imagenet', # use `imagenet` pre-trained weights for encoder initialization
    'train_bs'        : train_bs,
    'valid_bs'        : train_bs*2,
    'img_size'        : [224, 224],
    'epochs'          : epochs,
    'lr'              : 2e-3,
    'scheduler'       : 'CosineAnnealingLR', # see fetch_scheduler for the accepted names
    'min_lr'          : 1e-6,
    # NOTE(review): 30000 looks like an assumed sample count per epoch — confirm against the dataset size.
    'T_max'           : int(30000/train_bs*epochs)+50,
    'T_0'             : 25,
    'warmup_epochs'   : 0,
    'wd'              : 1e-6,
    'n_accumulate'    : max(1, 32//train_bs), # gradient-accumulation steps
    'n_fold'          : 5,
    'learning_fold'   : [0], # folds that are actually trained and evaluated
    'num_classes'     : 1,
    'device'          : torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'model'           : 'Unet',
    # Organ names must match the `organ` column exactly. The evaluation cells
    # query lower-case names (e.g. organ=="prostate"), so the filter list is
    # lower-case too; the former 'Prostate' entry matched no rows.
    'organs'          : ['prostate','spleen','lung','kidney','largeintestine'],
    'stain'           : True, # apply stain normalisation when loading images
}

CFG['tags'] = ['baseline']

# Debug mode: short run on a single fold.
if CFG['debug']:
    CFG['epochs'] = 3
    CFG['learning_fold'] = [0]

In [None]:
def set_seed(seed = 21):
    '''Seed every random number generator used by the notebook (for REPRODUCIBILITY).

    Seeds NumPy, Python's `random` module and PyTorch (CPU and CUDA), puts cuDNN
    into deterministic mode with benchmarking disabled, and stores the seed in
    the PYTHONHASHSEED environment variable.

    seed: integer seed shared by all generators.
    '''
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    # The cuDNN backend needs these two switches for repeatable results
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(CFG['seed'])

In [None]:
# https://www.kaggle.com/paulorzp/rle-functions-run-length-encode-decode
def mask2rle(img):
    '''
    Run-length encode a binary mask.

    img: numpy array, 1 - mask, 0 - background
    Returns a space-separated string of alternating 1-based run starts and run
    lengths, with the mask scanned in transposed (column-by-column) order.
    '''
    flat = img.T.flatten()
    # Pad both ends with background so runs touching the border still produce a transition
    padded = np.concatenate([[0], flat, [0]])
    changes = np.where(padded[1:] != padded[:-1])[0] + 1
    # Every second entry is a run end: turn it into a run length
    changes[1::2] -= changes[::2]
    return ' '.join(map(str, changes))
 
def rle2mask(mask_rle, shape=(1600,256)):
    '''
    Decode a run-length string into a binary mask.

    mask_rle: run-length as string formatted "start length start length ..." (starts are 1-based)
    shape: (width,height) of array to return
    Returns numpy uint8 array, 1 - mask, 0 - background, of shape (shape[1], shape[0])
    '''
    tokens = mask_rle.split()
    starts = np.asarray(tokens[0::2], dtype=int) - 1   # to 0-based offsets
    lengths = np.asarray(tokens[1::2], dtype=int)
    ends = starts + lengths
    flat = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for begin, end in zip(starts, ends):
        flat[begin:end] = 1
    # The encoding is column-major, hence reshape followed by transpose
    return flat.reshape(shape).T

In [None]:
def load_img(id):
    '''Read the training image `<id>.tiff` from /content/DATA/train_images with tifffile.'''
    image_path = f'/content/DATA/train_images/{id}.tiff'
    return tiff.imread(image_path)

def load_msk(df, id):
    '''Decode the mask of the image `id` from its row in `df`.

    df: dataframe with the columns 'id', 'rle', 'img_height' and 'img_width'.
    id: image id to look up.
    Returns the binary mask produced by rle2mask.
    Raises KeyError if `df` contains no row with that id.
    '''
    matches = df[df['id'] == id]
    if matches.empty:
        raise KeyError(f'no row with id {id!r}')
    # Positional access: the filtered frame keeps its original index labels,
    # so label-based `[0]` only worked when the match happened to sit at label 0.
    row = matches.iloc[0]
    return rle2mask(row['rle'], (row['img_height'], row['img_width']))

In [None]:
# Reference image whose stain appearance all training images are mapped to.
test_path = '/content/DATA/test_images/10078.tiff'
target = staintools.read_image(test_path)

# Standardize brightness (optional, can improve the tissue mask calculation)
target = staintools.LuminosityStandardizer.standardize(target)

# Stain normalize
# `normalizer` is a notebook-level global used by load_stain_img below.
# NOTE(review): the same fit is repeated in a later cell of the Data Processing section.
normalizer = staintools.StainNormalizer(method='vahadane')
normalizer.fit(target)

def load_stain_img(id):
    '''Load the training image `id` and stain-normalise it.

    The image is read with staintools, brightness-standardised and then
    transformed with the notebook-level `normalizer`.
    Returns the transformed image.
    '''
    image_path = '/content/DATA/train_images/' + str(id) +'.tiff'
    standardized = staintools.LuminosityStandardizer.standardize(staintools.read_image(image_path))
    return normalizer.transform(standardized)

**---------- Data Processing ----------**

In [None]:
# Training metadata; later cells read the columns id, organ, rle, img_height and img_width.
df = pd.read_csv('/content/DATA/train.csv')
df.head()

In [None]:
# Keep only the organs listed in CFG['organs'] (exact, case-sensitive match on the `organ` column).
# This overwrites `df`, so re-running the cell after changing CFG requires reloading the CSV first.
df = df[df['organ'].isin(CFG['organs'])].reset_index(drop=True)
df.head()

In [None]:
#df.info()

In [None]:
# Test metadata, loaded for inspection only.
df_test = pd.read_csv('/content/DATA/test.csv')
df_test.head()

In [None]:
# Pick one random training image and the reference test image for a visual comparison.
train_path = '/content/DATA/train_images/' + str(df['id'].sample().iloc[-1]) + '.tiff'
test_path = '/content/DATA/test_images/10078.tiff'
# Raw pixel data for plotting
train_img = tiff.imread(train_path)
test_img = tiff.imread(test_path)

# The same two images read through staintools for the normalisation demo below
to_transform = staintools.read_image(train_path)
target = staintools.read_image(test_path)

In [None]:
%%time
# Standardize brightness (optional, can improve the tissue mask calculation)
target = staintools.LuminosityStandardizer.standardize(target)

# Stain normalize
# NOTE(review): this rebinds the global `normalizer` that load_stain_img uses;
# it repeats the fit from the Utilities section, here wrapped in %%time.
normalizer = staintools.StainNormalizer(method='vahadane')
normalizer.fit(target)

In [None]:
%%time
# Apply the same two steps as load_stain_img to the sampled training image.
to_transform = staintools.LuminosityStandardizer.standardize(to_transform)
transformed1 = normalizer.transform(to_transform)

In [None]:
# Side by side: the raw training image, the reference test image and the
# stain-normalised training image.
plt.figure(figsize=(20, 20))
panels = [("train", train_img), ('test', test_img), ('stain', transformed1)]
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    shown = plt.imshow(panel_image)
shown

In [None]:
# Assign every row a validation fold, stratified by organ.
# The `fold` column is what prepare_loaders and pred_each_situation query.
skf = StratifiedKFold(n_splits=CFG['n_fold'], shuffle=True, random_state=CFG['seed'])
for fold, (train_idx, val_idx) in enumerate(skf.split(df, df['organ'])):
    # rows in this split's validation part get the split number as their fold
    df.loc[val_idx, 'fold'] = fold

In [None]:
# Albumentations pipelines for training and validation.
# Currently both only resize to CFG['img_size']; the commented-out entries are
# augmentations that were tried and disabled.
data_transforms = {
    'train': A.Compose([
        A.Resize(*CFG['img_size'], interpolation=cv2.INTER_NEAREST),
        #A.HorizontalFlip(p=0.5),
        #A.VerticalFlip(p=0.5),
        #A.RandomRotate90(p=0.5),

        #A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05, rotate_limit=10, p=0.5),
        #A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=0.5),
        #A.GridDistortion(num_steps=5, distort_limit=0.03, p=0.5),
        #A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
        #A.OneOf([
        #    A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
        #    A.OpticalDistortion(distort_limit=0.05, shift_limit=0.05, p=1.0),
        #    A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0)
        #], p=0.25),

        #A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=1),#20,30,20 randomly shift hue, saturation and value
        #A.RandomGamma(gamma_limit=(60, 140), p=1),#80, 120 randomly apply a gamma transform
        #A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=1),#0.2, 0.2, randomly change brightness and contrast

        #A.OpticalDistortion(distort_limit=0.5, shift_limit=0.5, p=1),#0.05 0.05 simulate optical distortion
        #A.GaussNoise(var_limit=(10.0, 200.0), p=1.0),#10.0 50.0
        #A.MotionBlur(blur_limit=(3, 7), p=1.0),#3 7

        #A.CoarseDropout(max_holes=8, max_height=CFG['img_size'][0]//20, max_width=CFG['img_size'][1]//20,
        #                 min_holes=5, fill_value=0, mask_fill_value=0, p=0.5),
        ], p=1.0),
    
    'valid': A.Compose([
        A.Resize(*CFG['img_size'], interpolation=cv2.INTER_NEAREST),
        ], p=1.0)
}

# Summarise the active training transforms as {'A_<TransformName>': '<argument string>'}
# by splitting each transform's string form at its first '('; the dict is merged
# into CFG before training so the augmentation setup is logged with the run.
data_transforms_dict = {}
for transform in data_transforms['train']:
    s = str(transform)
    data_transforms_dict['A_'+s[:s.find('(')]] = s[s.find('(')+1:-1]

In [None]:
class BuildDataset(torch.utils.data.Dataset):
    '''Dataset over the rows of a metadata dataframe.

    In training mode each item is an (image, mask) tensor pair; otherwise only
    the image tensor is returned. Images are loaded with load_stain_img when
    CFG['stain'] is set, else with load_img, and are returned channels-first.
    '''

    def __init__(self, df, training_mode=True, transforms=None):
        '''
        df: dataframe with an 'id' column (plus 'rle', 'img_height', 'img_width' in training mode)
        training_mode: whether masks are decoded and returned
        transforms: optional albumentations-style callable applied to image (and mask)
        '''
        self.df             = df
        self.training_mode  = training_mode
        self.transforms     = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]
        img_id = row['id']
        img = load_stain_img(img_id) if CFG['stain'] else load_img(img_id)

        if not self.training_mode:
            if self.transforms:
                img = self.transforms(image=img)['image']
            # move channels first
            return torch.tensor(np.transpose(img, (2, 0, 1)))

        msk = rle2mask(row['rle'], (row['img_height'], row['img_width']))
        if self.transforms:
            augmented = self.transforms(image=img, mask=msk)
            img, msk = augmented['image'], augmented['mask']
        # image: channels first; mask: add a leading channel axis
        img = np.transpose(img, (2, 0, 1))
        msk = np.expand_dims(msk, 0)
        return torch.tensor(img), torch.tensor(msk)

In [None]:
def prepare_loaders(fold, debug=False):
    '''Build the training and validation DataLoaders for one fold.

    Rows of the notebook-level `df` whose `fold` equals the argument form the
    validation set; all other rows form the training set.

    fold: fold number used for validation.
    debug: if True, keep only the first 8 rows of each split and use batch size 8.
    Returns (train_loader, valid_loader).
    '''
    train_df = df.query('fold!=@fold').reset_index(drop=True)
    valid_df = df.query('fold==@fold').reset_index(drop=True)
    if debug:
        train_df, valid_df = train_df.head(8), valid_df.head(8)

    worker_count = os.cpu_count()
    train_loader = DataLoader(
        BuildDataset(train_df, transforms=data_transforms['train']),
        batch_size=8 if debug else CFG['train_bs'],
        num_workers=worker_count, shuffle=True, pin_memory=True, drop_last=False)
    valid_loader = DataLoader(
        BuildDataset(valid_df, transforms=data_transforms['valid']),
        batch_size=8 if debug else CFG['valid_bs'],
        num_workers=worker_count, shuffle=False, pin_memory=True)

    return train_loader, valid_loader

In [None]:
# Smoke test: build small debug loaders for fold 0.
train_loader, valid_loader = prepare_loaders(fold=0, debug=True)

In [None]:
# Fetch one batch and display the image and mask tensor sizes.
imgs, msks = next(iter(train_loader))
imgs.size(), msks.size()

**---------- Model ----------**

In [None]:
import segmentation_models_pytorch as smp

def build_model():
    '''Create the segmentation network described by CFG and move it to CFG['device'].

    CFG['model'] selects the architecture: 'UnetPlusPlus' builds
    smp.UnetPlusPlus, any other value builds smp.Unet. The encoder name,
    encoder weights, input channels and number of output classes are taken
    from CFG. No activation is attached, so the model outputs raw logits.
    '''
    # The architecture is chosen by CFG['model']. The previous check compared
    # CFG['backbone'] (the encoder name) with 'UnetPlusPlus', so that branch
    # could never be reached with a usable encoder.
    architectures = {'UnetPlusPlus': smp.UnetPlusPlus}
    model_cls = architectures.get(CFG.get('model'), smp.Unet)

    model = model_cls(
        encoder_name=CFG['backbone'],      # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
        encoder_weights=CFG['encoder_weights'],
        in_channels=CFG['in_channels'],
        classes=CFG['num_classes'],        # model output channels (number of classes in your dataset)
        activation=None,
    )

    model.to(CFG['device'])
    return model

def load_model(path):
    '''Build a fresh model, load the state dict stored at `path` and switch to eval mode.

    path: file written by torch.save(model.state_dict(), ...).
    Returns the model in evaluation mode.
    '''
    model = build_model()
    # Map the checkpoint onto the device the model lives on, so weights saved
    # on a GPU can also be loaded in a CPU-only runtime.
    model_device = next(model.parameters()).device
    model.load_state_dict(torch.load(path, map_location=model_device))
    model.eval()
    return model

In [None]:
# Weight of each loss term in `criterion`; None disables the term.
loss_def ={
    'JaccardLoss'   : None,
    'DiceLoss'      : 1.0,
    'BCELoss'       : None,
    'LovaszLoss'    : None,
    'TverskyLoss'   : None,
    'FocalLoss'     : None,
}

# Stored in CFG so the loss setup is logged together with the rest of the config.
CFG['loss'] = loss_def

# Loss objects referenced by name in `criterion`.
JaccardLoss = smp.losses.JaccardLoss(mode='multilabel')
DiceLoss    = smp.losses.DiceLoss(mode='multilabel')
BCELoss     = smp.losses.SoftBCEWithLogitsLoss()
LovaszLoss  = smp.losses.LovaszLoss(mode='multilabel', per_image=False)
TverskyLoss = smp.losses.TverskyLoss(mode='multilabel', log_loss=False)
FocalLoss = smp.losses.FocalLoss(mode="multilabel")

def dice_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    '''Mean Dice coefficient between ground truth and thresholded predictions.

    y_true: ground-truth tensor (cast to float32).
    y_pred: prediction tensor; values above `thr` count as foreground.
    dim: dimensions summed over for each item before averaging.
    epsilon: smoothing term added to numerator and denominator.
    Returns a scalar tensor: the per-item scores averaged over dims 1 and 0.
    '''
    target = y_true.to(torch.float32)
    binarized = (y_pred > thr).to(torch.float32)
    overlap = (target * binarized).sum(dim=dim)
    total = target.sum(dim=dim) + binarized.sum(dim=dim)
    per_item = (2*overlap + epsilon) / (total + epsilon)
    return per_item.mean(dim=(1,0))

def iou_coef(y_true, y_pred, thr=0.5, dim=(2,3), epsilon=0.001):
    '''Mean intersection-over-union between ground truth and thresholded predictions.

    y_true: ground-truth tensor (cast to float32).
    y_pred: prediction tensor; values above `thr` count as foreground.
    dim: dimensions summed over for each item before averaging.
    epsilon: smoothing term added to numerator and denominator.
    Returns a scalar tensor: the per-item scores averaged over dims 1 and 0.
    '''
    target = y_true.to(torch.float32)
    binarized = (y_pred > thr).to(torch.float32)
    overlap = target * binarized
    intersection = overlap.sum(dim=dim)
    union = (target + binarized - overlap).sum(dim=dim)
    per_item = (intersection + epsilon) / (union + epsilon)
    return per_item.mean(dim=(1,0))

def criterion(y_pred, y_true):
    '''Weighted sum of the loss terms enabled in `loss_def`.

    Every entry of `loss_def` with a non-None weight contributes
    weight * <loss>(y_pred, y_true); names without a matching loss object are ignored.
    Returns the accumulated loss (the integer 0 if no term is enabled).
    '''
    loss_fns = {
        'JaccardLoss': JaccardLoss,
        'DiceLoss': DiceLoss,
        'BCELoss': BCELoss,
        'LovaszLoss': LovaszLoss,
        'TverskyLoss': TverskyLoss,
        'FocalLoss': FocalLoss,
    }
    loss = 0
    for name, weight in loss_def.items():
        if weight is None or name not in loss_fns:
            continue
        loss += weight*loss_fns[name](y_pred, y_true)

    return loss

#def criterion(y_pred, y_true):
#    return 0.5*BCELoss(y_pred, y_true) + 0.5*TverskyLoss(y_pred, y_true)

**---------- Learning ----------**

In [None]:
# Run a garbage-collection pass before training starts.
gc.collect()

In [None]:
from timm.scheduler import CosineLRScheduler
def fetch_scheduler(optimizer):
    '''Create the learning-rate scheduler selected by CFG["scheduler"].

    optimizer: optimizer whose learning rate the scheduler controls.
    Returns the scheduler, or None when CFG["scheduler"] is None.
    Raises ValueError for a scheduler name that is not handled here
    (previously this surfaced as an UnboundLocalError on return).
    '''
    name = CFG["scheduler"]
    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG["T_max"],
                                              eta_min=CFG["min_lr"])
    if name == 'CosineAnnealingWarmRestarts':
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CFG["T_0"],
                                                        eta_min=CFG["min_lr"])
    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer,
                                              mode='min',
                                              factor=0.1,
                                              patience=7,
                                              threshold=0.0001,
                                              min_lr=CFG["min_lr"],)
    if name == 'ExponentialLR':
        return lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
    if name == 'CosineLRScheduler':
        # timm scheduler: stepped once per epoch with an explicit epoch index,
        # with a warm-up lasting 20% of the epochs
        return CosineLRScheduler(optimizer, t_initial=CFG["epochs"], lr_min=1e-4,
                                 warmup_t=round(CFG["epochs"]*0.2), warmup_lr_init=5e-5, warmup_prefix=True)

    raise ValueError(f'Unknown scheduler {name!r} in CFG["scheduler"]')

In [None]:
# Preview of the learning-rate schedule on a throwaway model.
# NOTE: this steps the scheduler once per epoch, whereas train_one_epoch steps
# every scheduler except CosineLRScheduler once per optimizer update, so the
# x-axis here is "scheduler steps", not necessarily epochs.
model = torch.nn.Linear(1, 1)  # dummy model, only needed to create an optimizer
# Start from the configured learning rate so the curve has the scale used in training.
optimizer = torch.optim.Adam(model.parameters(), lr=CFG["lr"])

scheduler = fetch_scheduler(optimizer)

lrs = []
for i in range(CFG["epochs"]):
    lrs.append(optimizer.param_groups[0]['lr'])
    if scheduler is None:
        # no scheduler configured: the learning rate stays constant
        continue
    if CFG["scheduler"] == 'CosineLRScheduler':
        scheduler.step(i+1)
    else:
        scheduler.step()

plt.plot(lrs)
plt.show()

In [None]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    '''Run one training epoch with mixed precision and gradient accumulation.

    model: network to train (switched to train mode).
    optimizer: optimizer updated every CFG["n_accumulate"] batches and once
        more at the end of the epoch if a partial window is pending.
    scheduler: LR scheduler or None. It is stepped after every optimizer
        update, except for 'CosineLRScheduler', which is stepped by the caller.
    dataloader: iterable of (images, masks) batches.
    device: device the batches are moved to.
    epoch: current epoch number (not used inside this function).
    Returns the mean per-sample training loss of the epoch (0.0 for an empty loader).
    '''
    model.train()
    scaler = amp.GradScaler()

    dataset_size = 0
    running_loss = 0.0
    epoch_loss = 0.0  # defined up front so an empty dataloader does not break the return
    n_accumulate = CFG["n_accumulate"]
    n_steps = len(dataloader)

    # Start the epoch without stale gradients
    optimizer.zero_grad()

    pbar = tqdm(enumerate(dataloader), total=n_steps, desc='Train ')
    for step, (images, masks) in pbar:
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)

        batch_size = images.size(0)

        with amp.autocast(enabled=True):
            y_pred = model(images)
            batch_loss = criterion(y_pred, masks)
            # scale down so the accumulated gradient is an average over the window
            loss = batch_loss / n_accumulate

        scaler.scale(loss).backward()

        # Update after every full accumulation window, and also after the last
        # batch so a trailing partial window is neither dropped nor carried
        # into the next epoch.
        if (step + 1) % n_accumulate == 0 or (step + 1) == n_steps:
            scaler.step(optimizer)
            scaler.update()

            # zero the parameter gradients
            optimizer.zero_grad()

            if scheduler is not None and CFG["scheduler"] != 'CosineLRScheduler':
                scheduler.step()

        # Track the unscaled loss so the reported value is comparable with the
        # validation loss (which is not divided by n_accumulate).
        running_loss += (batch_loss.item() * batch_size)
        dataset_size += batch_size

        epoch_loss = running_loss / dataset_size

        mem = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        current_lr = optimizer.param_groups[0]['lr']
        pbar.set_postfix(train_loss=f'{epoch_loss:0.4f}',
                        lr=f'{current_lr:0.5f}',
                        gpu_mem=f'{mem:0.2f} GB')
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

In [None]:
@torch.no_grad()
def valid_one_epoch(model, dataloader, device, epoch):
    '''Evaluate the model on the validation loader without tracking gradients.

    model: network to evaluate (switched to eval mode).
    dataloader: iterable of (images, masks) batches.
    device: device the batches are moved to.
    epoch: current epoch number (not used inside this function).
    Returns (epoch_loss, val_scores): the mean per-sample loss and an array
    [dice, jaccard] averaged over the batches.

    NOTE: the progress bar reads the learning rate from the notebook-level
    `optimizer` variable, which is not a parameter of this function.
    '''
    model.eval()

    seen = 0
    loss_sum = 0.0
    val_scores = []

    pbar = tqdm(enumerate(dataloader), total=len(dataloader), desc='Valid ')
    for step, (images, masks) in pbar:
        images = images.to(device, dtype=torch.float)
        masks = masks.to(device, dtype=torch.float)
        n_items = images.size(0)

        logits = model(images)
        loss_sum += criterion(logits, masks).item() * n_items
        seen += n_items
        epoch_loss = loss_sum / seen

        # metrics are computed on probabilities
        probs = nn.Sigmoid()(logits)
        val_scores.append([
            dice_coef(masks, probs).cpu().detach().numpy(),
            iou_coef(masks, probs).cpu().detach().numpy(),
        ])

        reserved_gb = torch.cuda.memory_reserved() / 1E9 if torch.cuda.is_available() else 0
        pbar.set_postfix(valid_loss=f'{epoch_loss:0.4f}',
                         lr=f"{optimizer.param_groups[0]['lr']:0.5f}",
                         gpu_memory=f'{reserved_gb:0.2f} GB')

    val_scores = np.mean(val_scores, axis=0)
    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss, val_scores

In [None]:
def run_training(model, optimizer, scheduler, device, num_epochs):
    '''Train `model` for `num_epochs` epochs, checkpointing the best and the latest weights.

    Relies on notebook-level variables that must be set before the call:
    `train_loader`, `valid_loader`, `run` (the active wandb run) and `fold`
    (used in the checkpoint file names).

    model, optimizer, scheduler: passed through to train_one_epoch.
    device: device the batches are moved to.
    num_epochs: number of epochs to run.
    Returns (model, history): the model with the best-Dice weights loaded, and a
    dict of per-epoch lists under 'Train Loss', 'Valid Loss', 'Valid Dice'
    and 'Valid Jaccard'.
    '''
    # To automatically log gradients
    wandb.watch(model, log_freq=100)

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_dice      = -np.inf
    best_jaccard   = -np.inf  # defined up front so the final report works even if no epoch improves
    best_epoch     = -1
    history = defaultdict(list)

    for epoch in range(1, num_epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{num_epochs}', end='\n')
        train_loss = train_one_epoch(model, optimizer, scheduler,
                                     dataloader=train_loader,
                                     device=device, epoch=epoch)

        val_loss, val_scores = valid_one_epoch(model, valid_loader,
                                               device=device,
                                               epoch=epoch)
        val_dice, val_jaccard = val_scores

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)
        history['Valid Dice'].append(val_dice)
        history['Valid Jaccard'].append(val_jaccard)

        # Log the metrics
        wandb.log({"Train Loss": train_loss,
                   "Valid Loss": val_loss,
                   "Valid Dice": val_dice,
                   "Valid Jaccard": val_jaccard,
                   "LR":optimizer.param_groups[0]['lr']})

        print(f"{b_}Valid Dice: {val_dice:0.4f} | Valid Jaccard: {val_jaccard:0.4f}{sr_}")

        # Keep the weights with the best validation Dice (ties count as an improvement)
        if val_dice >= best_dice:
            print(f"{r_}Valid Score Improved ({best_dice:0.4f} ---> {val_dice:0.4f})")
            best_dice    = val_dice
            best_jaccard = val_jaccard
            best_epoch   = epoch
            run.summary["Best Dice"]    = best_dice
            run.summary["Best Jaccard"] = best_jaccard
            run.summary["Best Epoch"]   = best_epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"/content/OUTPUT/MODEL/best_epoch-{fold:02d}.bin"
            torch.save(model.state_dict(), PATH)
            # Upload the checkpoint to the wandb run
            wandb.save(PATH, base_path="/content/OUTPUT/MODEL")
            print(f"Model Saved{sr_}")

        # The timm scheduler is stepped per epoch here; all others are stepped in train_one_epoch
        if CFG["scheduler"] == 'CosineLRScheduler':
            scheduler.step(epoch+1)

        # Always keep the latest weights as well
        PATH = f"/content/OUTPUT/MODEL/last_epoch-{fold:02d}.bin"
        torch.save(model.state_dict(), PATH)

        print()

    # Store the configuration next to the checkpoints
    with open("/content/OUTPUT/CFG.pickle", 'wb') as f:
        pickle.dump(CFG, f)
    wandb.save("/content/OUTPUT/CFG.pickle")

    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best Score: {:.4f}".format(best_jaccard))

    # load best model weights
    model.load_state_dict(best_model_wts)

    return model, history

In [None]:
# Build the model, optimizer and scheduler once up front.
# NOTE(review): the fold loop in the next cell creates all three again for every trained fold.
model = build_model()
optimizer = optim.Adam(model.parameters(), lr=CFG["lr"], weight_decay=CFG["wd"])
scheduler = fetch_scheduler(optimizer)

In [None]:
if torch.cuda.is_available():
    print("cuda: {}\n".format(torch.cuda.get_device_name()))
    # Record the GPU model under its own key. CFG['device'] must stay a
    # torch.device: build_model and the tensor `.to(CFG['device'])` calls
    # fail if it is replaced by a display name.
    CFG['device_name'] = torch.cuda.get_device_name()

# Add the augmentation summary so it is logged together with the rest of the config
CFG.update(data_transforms_dict)

# Copy of the config for wandb with the device stored as a plain string
wandb_config = {**CFG, 'device': str(CFG['device'])}

for fold in range(CFG["n_fold"]):
    print('#'*15)
    print(f'### Fold: {fold}')
    print('#'*15)
    # Only the folds listed in CFG['learning_fold'] are trained
    if fold in CFG['learning_fold']:
        run = wandb.init(project='Hacking_the_Human_Body',
                        group=run_start_time,
                        name=f"fold-{fold}",
                        config=wandb_config,
                        tags=CFG["tags"],
                        )
        # run_training reads `train_loader`, `valid_loader`, `run` and `fold` as globals
        train_loader, valid_loader = prepare_loaders(fold=fold, debug=CFG["debug"])
        model     = build_model()
        optimizer = optim.Adam(model.parameters(), lr=CFG["lr"], weight_decay=CFG["wd"])
        scheduler = fetch_scheduler(optimizer)
        model, history = run_training(model, optimizer, scheduler,
                                    device=CFG["device"],
                                    num_epochs=CFG["epochs"])
        # The run stays open here; it is finished in a later cell after the evaluation plots are saved
        display(ipd.IFrame(run.url, width=1000, height=720))

In [None]:
def pred_each_situation(df):
    '''Score the best checkpoint of each trained fold on that fold's rows of `df`.

    For every fold listed in CFG['learning_fold'], the rows of `df` belonging
    to the fold are predicted one image at a time with the fold's
    best_epoch checkpoint.

    df: dataframe with a `fold` column plus the columns BuildDataset needs.
    Returns a list of [dice, jaccard, image, mask, prediction] entries (numpy
    values), sorted by dice from best to worst.
    '''
    results_for_each_data = []
    for fold in range(CFG["n_fold"]):
        if fold not in CFG['learning_fold']:
            continue
        test_df = df.query("fold==@fold").reset_index(drop=True)
        if test_df.empty:
            # nothing to score for this fold; skip loading the checkpoint
            continue
        test_dataset = BuildDataset(test_df, transforms=data_transforms['valid'])
        test_loader = DataLoader(test_dataset, batch_size=1, num_workers=os.cpu_count(), shuffle=False, pin_memory=True)

        model = load_model(f"/content/OUTPUT/MODEL/best_epoch-{fold:02d}.bin")
        pbar = tqdm(enumerate(test_loader), total=len(test_loader), desc='Test ')
        # Inference only: no autograd graph is needed
        with torch.no_grad():
            for step, (images, masks) in pbar:
                image  = images.to(CFG["device"], dtype=torch.float)
                mask   = masks.to(CFG["device"], dtype=torch.float)

                y_pred = model(image)
                # binarise the probabilities at 0.5
                y_pred = (nn.Sigmoid()(y_pred)>0.5).double()
                val_dice = dice_coef(mask, y_pred).cpu().detach().numpy()
                val_jaccard = iou_coef(mask, y_pred).cpu().detach().numpy()
                results_for_each_data.append([val_dice, val_jaccard, image.cpu().detach().numpy(), mask.cpu().detach().numpy(), y_pred.cpu().detach().numpy()])
        gc.collect()

    # Best Dice first. This replaces a pairwise swap sort that deep-copied the
    # image arrays on every swap.
    results_for_each_data.sort(key=lambda result: float(result[0]), reverse=True)

    return results_for_each_data

In [None]:
def plot_imgs_msks_preds(results_for_each_data, filename):
    '''Plot up to five results as rows of image / mask overlay / prediction overlay.

    results_for_each_data: entries as produced by pred_each_situation
        ([dice, jaccard, image, mask, prediction]).
    filename: name of the image file written to /content/OUTPUT/ and uploaded to wandb.
    '''
    fig = plt.figure(figsize=(5*5, 15*3))
    image_num = min(len(results_for_each_data), 5)

    def as_hwc_uint8(batch):
        # first item of the batch, channels moved last, cast for display
        return np.transpose(batch[0], (1, 2, 0)).astype('uint8')

    for idx in range(image_num):
        result = results_for_each_data[idx]
        img = as_hwc_uint8(result[2])
        msk = as_hwc_uint8(result[3])
        pred = as_hwc_uint8(result[4])
        first_cell = (idx*3)+1

        ax1 = fig.add_subplot(5, 3, first_cell)
        ax1.set_title('image')
        ax1.imshow(img)

        # mask and prediction are drawn semi-transparently on top of the image
        overlays = [('mask', msk), ('pred='+str(result[0]), pred)]
        for offset, (title, overlay) in enumerate(overlays, start=1):
            ax = fig.add_subplot(5, 3, first_cell+offset)
            ax.set_title(title)
            ax.imshow(img)
            ax.imshow(overlay[...,0], cmap='coolwarm', alpha=0.5)

    fig.tight_layout()
    fig.show()
    file_pass = "/content/OUTPUT/"+filename
    fig.savefig(file_pass)
    wandb.save(file_pass)

In [None]:
%%time
# Score each organ separately (skipped in debug mode); every result list is sorted best Dice first.
# The organ names here must match the values of the `organ` column exactly.
if not CFG["debug"]:
    prostate = pred_each_situation(df.query('organ=="prostate"').reset_index(drop=True))
    spleen = pred_each_situation(df.query('organ=="spleen"').reset_index(drop=True))
    lung = pred_each_situation(df.query('organ=="lung"').reset_index(drop=True))
    kidney = pred_each_situation(df.query('organ=="kidney"').reset_index(drop=True))
    largeintestine	 = pred_each_situation(df.query('organ=="largeintestine"').reset_index(drop=True))

In [None]:
# For each organ, plot the five best-scoring ([:5]) and the five worst-scoring ([-5:]) images.
if not CFG["debug"]:
    plot_imgs_msks_preds(prostate[:5],"prostate_best.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(prostate[-5:],"prostate_worst.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(spleen[:5],"spleen_best.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(spleen[-5:],"spleen_worst.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(lung[:5],"lung_best.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(lung[-5:],"lung_worst.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(kidney[:5],"kidney_best.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(kidney[-5:],"kidney_worst.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(largeintestine[:5],"largeintestine_best.png")

In [None]:
if not CFG["debug"]:
    plot_imgs_msks_preds(largeintestine[-5:],"largeintestine_worst.png")

In [None]:
# Close the wandb run now that the evaluation plots have been saved to it.
run.finish()

In [None]:
# Store the run URL in the output directory so it is copied to Drive with the results.
wandb_url = run.url
with open('/content/OUTPUT/wandb_url.txt', 'w') as f:
    f.write(wandb_url)

In [None]:
LOG_PATH = '/content/drive/MyDrive/Kaggle/202207_Hacking_the_Human_Body/Result/' +  run_start_time
!mkdir $LOG_PATH
#!cp /content/log.log $LOG_PATH
#!cp -r /content/TEST $LOG_PATH
#!cp -r /content/OOF $LOG_PATH
#!cp -r /content/DATA $LOG_PATH
!cp -r /content/OUTPUT $LOG_PATH