In [1]:
import os, sys,shutil, gc

from glob import glob
from tqdm import tqdm

import math
import random
from collections import OrderedDict
import warnings

import albumentations as A
import cv2
from matplotlib import pyplot as plt

import numpy as np
import pandas as pd
import timm
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW,lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import fbeta_score, average_precision_score, roc_auc_score

import multiprocessing

from einops import rearrange, reduce, repeat
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.decoders.unet.decoder import UnetDecoder
from timm.models.resnet import resnet34d, seresnext26t_32x4d

from ink_helpers import (load_image,seed_everything,
                         load_fragment, DiceLoss, FocalLoss, dice_bce_loss)

warnings.simplefilter("ignore")  # silence library warnings for the rest of the notebook

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

n_cpus = multiprocessing.cpu_count()
print("cpu count:", n_cpus)

cuda
cpu count: 32


In [2]:
# Config:
# -- reproducibility / data loading --
random_seed = 42
num_workers = min(multiprocessing.cpu_count(), 12)  # never more than 12 DataLoader workers

# -- depth-channel window: slices [bottom_channel_idx, top_channel_idx) are the model inputs --
bottom_channel_idx, top_channel_idx = 29, 35
num_fluctuate_channel = 1  # training also uses the window shifted by up to +/- this many slices
num_select_channel = top_channel_idx - bottom_channel_idx  # number of input channels

# -- tiling (pixels) --
block_size, stride = 256, 128  # tile side length, step between tile origins

# -- optimisation --
loss_type = 'bcedice'  # one of: 'bce', 'focal', 'bcedice'
max_lr, weight_decay = 5.0e-5, 1.0e-3
total_epoch, batch_size = 11, 32

# fragment held out for validation
valid_id = '2c'

seed_everything(seed=random_seed)

In [3]:
all_frag_ids = ['1', '2a', '2b', '2c', '3']
id2dir = {frag_id: f'./frags/train_{frag_id}' for frag_id in all_frag_ids}
# Every fragment except the held-out validation one is used for training.
train_id_list = list(filter(lambda frag_id: frag_id != valid_id, all_frag_ids))
print('Train:', train_id_list)

Train: ['1', '2a', '2b', '3']


In [4]:
# Load every fragment once up front; the datasets below look these dicts up by fragment id.
id2images, id2frag_mask, id2ink_mask = {}, {}, {}
for frag_id in tqdm(all_frag_ids):
    (id2images[frag_id],
     id2frag_mask[frag_id],
     id2ink_mask[frag_id]) = load_fragment(frag_id)

100%|█████████████████████████████████████████████| 5/5 [01:14<00:00, 14.96s/it]


In [5]:
class InkDataSet2D(Dataset):
    '''
    Square tiles of fragment volumes, optionally paired with their ink masks.

    Items are ``(idx, image, mask)`` when ``has_label`` else ``(idx, image)``:
    image: (D, H, W) with D = top_channel_idx - bottom_channel_idx, H = W = block_size;
    mask: (1, H, W). ``idx`` indexes ``self.id_xybt_list``, whose entries are
    ``(frag_id, x, y, start_z, end_z)``, so predictions can be placed back on the fragment.

    Data is read from the module-level dicts ``id2images`` (indexed [z, y, x]),
    ``id2frag_mask`` and ``id2ink_mask`` (indexed [y, x]).

    Args:
        frag_id_list: fragment ids to tile.
        block_size: tile side length in pixels.
        channel_slip: emit every tile once per depth offset f in
            [-channel_slip, channel_slip], using slices
            [bottom_channel_idx + f, top_channel_idx + f).
        transforms: optional albumentations-style callable, applied to the
            channels-last (H, W, D) image and the (H, W) mask.
        has_label: load and return the ink mask.
        stride: step between tile origins in pixels; ``None`` uses the
            module-level ``stride``.

    Raises:
        ValueError: if a fragment is smaller than ``block_size`` in either dimension.
    '''
    def __init__(self, frag_id_list, block_size, channel_slip=0, transforms=None, has_label=True,
                 stride=None):
        self.frag_id_list = frag_id_list
        self.block_size = block_size
        self.transforms = transforms
        self.has_label = has_label
        # The parameter shadows the config global of the same name, so look that one up explicitly.
        self.stride = globals()['stride'] if stride is None else stride

        channel_windows = [(bottom_channel_idx + f, top_channel_idx + f)
                           for f in range(-channel_slip, channel_slip + 1)]
        self.id_xybt_list = [
            (frag_id, x, y, start_z, end_z)
            for frag_id in frag_id_list
            for x, y in self._tile_origins(id2frag_mask[frag_id], block_size, self.stride)
            for start_z, end_z in channel_windows
        ]

    @staticmethod
    def _tile_origins(frag_mask, block_size, stride):
        '''Return (x, y) top-left corners of the tiles to sample from one fragment.

        Origins step by ``stride``; the last column/row is clamped so its tile ends
        exactly at the border. A tile is kept only if the clamped window (the one
        ``__getitem__`` returns) contains a nonzero ``frag_mask`` pixel.
        '''
        height, width = frag_mask.shape[0], frag_mask.shape[1]
        max_x, max_y = width - block_size, height - block_size
        if max_x < 0 or max_y < 0:
            raise ValueError(
                f'fragment of shape {(height, width)} is smaller than block_size={block_size}')

        origins = []
        for raw_x in range(0, max_x + stride, stride):
            x = min(raw_x, max_x)
            for raw_y in range(0, max_y + stride, stride):
                y = min(raw_y, max_y)
                # Test the clamped window, not the unclamped one that runs off the border.
                if np.any(frag_mask[y:y + block_size, x:x + block_size] > 0):
                    origins.append((x, y))
        return origins

    def __len__(self):
        return len(self.id_xybt_list)

    def __getitem__(self, idx):
        frag_id, x, y, start_z, end_z = self.id_xybt_list[idx]
        bs = self.block_size

        image = id2images[frag_id][start_z:end_z, y:y + bs, x:x + bs]  # D,H,W
        image = np.moveaxis(image, 0, 2)  # H,W,D: channels-last for the transforms

        if not self.has_label:
            if self.transforms:
                image = self.transforms(image=image)['image']
            return idx, np.moveaxis(image, 2, 0)  # D,H,W

        mask = id2ink_mask[frag_id][y:y + bs, x:x + bs]  # H,W
        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']
        return idx, np.moveaxis(image, 2, 0), np.expand_dims(mask, 0)  # D,H,W and 1,H,W

In [6]:
def _unit_normalize():
    '''Normalize with mean 0 / std 1 for each selected channel.

    Normalize's own max_pixel_value scaling still applies.
    '''
    return A.Normalize(mean=[0] * num_select_channel, std=[1] * num_select_channel)


# Settings shared by the training and validation loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    num_workers=num_workers,
    prefetch_factor=1,
)

# --- training: random flips/rot90 + one coarse dropout hole, over all slipped channel windows ---
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=1.0),
#     A.RandomBrightnessContrast(p=0.75),
#     A.ShiftScaleRotate(p=0.75),
#     A.OneOf([
#             A.GaussNoise(var_limit=[10, 50]),
#             A.GaussianBlur(),
#             A.MotionBlur(),
#             ], p=0.4),
#     A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
    A.CoarseDropout(max_holes=1, max_width=int(block_size * 0.3), max_height=int(block_size * 0.3),
                    mask_fill_value=0, p=0.5),
    _unit_normalize(),
])
train_dataset = InkDataSet2D(
    frag_id_list=train_id_list,
    block_size=block_size,
    channel_slip=num_fluctuate_channel,
    transforms=train_transform,
    has_label=True,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)

# --- validation: the held-out fragment, fixed channel window, normalization only, no labels ---
valid_transform = A.Compose([_unit_normalize()])
valid_dataset = InkDataSet2D(
    frag_id_list=[valid_id],
    block_size=block_size,
    channel_slip=0,
    transforms=valid_transform,
    has_label=False,
)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, drop_last=False, **_loader_kwargs)

print('Train', len(train_dataloader), ', Valid', len(valid_dataloader))

Train 749 , Valid 66


# Model

In [7]:
model_name = 'UNetPlusPlus'  # options: 'UNet', 'UNetPlusPlus', 'MAnet'
backbone_name = 'xception'   # options: 'resnet34', 'timm-resnest26d', 'xception'

In [8]:
# Map the configured name to its segmentation_models_pytorch class. All three take the same
# constructor kwargs. An unknown name fails loudly instead of leaving `model` undefined or
# silently reusing a model left over from an earlier run of this kernel.
smp_architectures = {
    'UNet': smp.Unet,
    'UNetPlusPlus': smp.UnetPlusPlus,
    'MAnet': smp.MAnet,
}
if model_name not in smp_architectures:
    raise ValueError(f'unknown model_name {model_name!r}; expected one of {sorted(smp_architectures)}')

model = smp_architectures[model_name](
    encoder_name=backbone_name,
    encoder_weights=None,            # no pretrained encoder weights
    in_channels=num_select_channel,  # one input channel per selected depth slice
    classes=1,                       # a single ink logit map
    activation=None,                 # raw logits; validation applies its own sigmoid
)
model.to(device);

In [9]:
# a = torch.from_numpy( np.random.choice(256, (1, 12, 256, 256))).float().cuda()
    
# with torch.no_grad():
#     with torch.cuda.amp.autocast(enabled=True):
#         output = model(a)

In [10]:
# Loss selection: an unknown name fails here instead of reusing a stale `criterion`.
if loss_type == 'bce':
    criterion = nn.BCEWithLogitsLoss()
elif loss_type == 'focal':
    criterion = FocalLoss(alpha=1, gamma=2, use_logits=True)
elif loss_type == 'bcedice':
    criterion = dice_bce_loss
else:
    raise ValueError(f"unknown loss_type {loss_type!r}; expected 'bce', 'focal' or 'bcedice'")

optimizer = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# One-cycle LR stepped once per batch: ramps from max_lr/div_factor to max_lr over the first
# 10% of steps, then cosine-anneals down to max_lr/div_factor/final_div_factor.
scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    epochs=total_epoch,
    steps_per_epoch=len(train_dataloader),
    max_lr=max_lr,
    pct_start=0.1,
    anneal_strategy="cos",
    div_factor=1.0e3,
    final_div_factor=1.0e1,
)
scaler = GradScaler()  # loss scaling for the autocast (mixed-precision) forward pass
Sig = nn.Sigmoid()


def _predict_valid_fragment(model):
    '''Run `model` over `valid_dataloader` and stitch tile probabilities into one map.

    Returns a float array shaped like ``id2frag_mask[valid_id]`` holding, per pixel, the
    mean sigmoid probability over all tiles covering it. Pixels outside the fragment
    mask, or not covered by any tile, are 0. Leaves `model` in train mode.
    '''
    frag_mask = id2frag_mask[valid_id]
    prob_sum = np.zeros(frag_mask.shape, dtype=float)
    hit_count = np.zeros(frag_mask.shape)
    xybt_list = valid_dataset.id_xybt_list

    model.eval()
    with torch.no_grad():
        for idx, img in tqdm(valid_dataloader):
            img = img.to(device).float()
            with autocast():
                probs = Sig(model(img).float())
            # One device->host copy per batch instead of one per tile.
            probs = probs[:, 0].cpu().numpy()
            for tile_probs, whole_idx in zip(probs, idx.tolist()):
                x, y = map(int, xybt_list[whole_idx][1:3])
                prob_sum[y:y + block_size, x:x + block_size] += tile_probs
                hit_count[y:y + block_size, x:x + block_size] += 1
    model.train()

    # Tiles can overlap, so average. max(., 1) avoids 0/0 (NaN) where no tile landed.
    preds = prob_sum / np.maximum(hit_count, 1)
    preds[frag_mask == 0] = 0
    return preds


loss_list = [1] * 10  # rolling window of the last 10 batch losses, shown in the progress bar
for epoch in range(total_epoch):

    # training
    gc.collect()
    with tqdm(enumerate(train_dataloader), total=len(train_dataloader)) as pbar:
        pbar.set_description(f"Ep{epoch:02d}")
        for step, (idx, img, target) in pbar:
            img, target = img.to(device).float(), target.to(device).float()

            optimizer.zero_grad()
            with autocast():
                outputs = model(img).float()
            loss = criterion(outputs, target)  # loss evaluated in fp32, outside autocast

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            loss_list = loss_list[1:] + [loss.item()]
            pbar.set_postfix(
                OrderedDict(
                    LR=f"{scheduler.get_last_lr()[0]:.1e}",
                    Loss=f"{sum(loss_list) / len(loss_list):.3f}",
                )
            )

    # validation: stitch tile predictions for the held-out fragment and score its pixels
    valid_frag_mask = id2frag_mask[valid_id]
    valid_ink_mask = id2ink_mask[valid_id]
    valid_ink_predicts = _predict_valid_fragment(model)

    in_fragment = valid_frag_mask != 0
    valid_ink_predicts_flat = valid_ink_predicts[in_fragment].ravel()
    valid_ink_mask_flat = valid_ink_mask[in_fragment].ravel()

    map_score = average_precision_score(valid_ink_mask_flat, valid_ink_predicts_flat)
    auc_score = roc_auc_score(valid_ink_mask_flat, valid_ink_predicts_flat)
    fhalf_score = fbeta_score(valid_ink_mask_flat, valid_ink_predicts_flat > 0.5, beta=0.5)
    print(f'Valid: mAP {map_score:.3f}, AUC {auc_score:.3f}, F0.5 {fhalf_score:.3f}, ')

# save weights; torch.save does not create missing directories, so ensure ./weights exists
# rather than failing after the whole training run.
weights_dir = './weights'
os.makedirs(weights_dir, exist_ok=True)
torch.save(
    model.state_dict(),
    os.path.join(
        weights_dir,
        f'SMP-{model_name}-{backbone_name}-block{block_size}-channel{bottom_channel_idx}-to{top_channel_idx}'
        f'-slip{num_fluctuate_channel}-loss{loss_type}-lr{max_lr}-wd{weight_decay}-bs{batch_size}'
        f'-valid{valid_id}-step{total_epoch*len(train_dataloader)}-seed{random_seed}-epoch{total_epoch}.pth',
    ),
)

gc.collect()

Ep00: 100%|███████████| 749/749 [04:16<00:00,  2.92it/s, LR=4.9e-05, Loss=0.551]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.05it/s]


Valid: mAP 0.308, AUC 0.661, F0.5 0.337, 


Ep01: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=4.9e-05, Loss=0.502]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.07it/s]


Valid: mAP 0.354, AUC 0.706, F0.5 0.365, 


Ep02: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=4.6e-05, Loss=0.467]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.04it/s]


Valid: mAP 0.404, AUC 0.734, F0.5 0.410, 


Ep03: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=4.0e-05, Loss=0.426]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.09it/s]


Valid: mAP 0.415, AUC 0.757, F0.5 0.414, 


Ep04: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=3.3e-05, Loss=0.382]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.09it/s]


Valid: mAP 0.413, AUC 0.757, F0.5 0.415, 


Ep05: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=2.5e-05, Loss=0.333]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.10it/s]


Valid: mAP 0.442, AUC 0.748, F0.5 0.455, 


Ep06: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=1.8e-05, Loss=0.350]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.05it/s]


Valid: mAP 0.419, AUC 0.755, F0.5 0.434, 


Ep07: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=1.0e-05, Loss=0.298]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.09it/s]


Valid: mAP 0.453, AUC 0.766, F0.5 0.473, 


Ep08: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=4.9e-06, Loss=0.290]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.09it/s]


Valid: mAP 0.451, AUC 0.764, F0.5 0.469, 


Ep09: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=1.2e-06, Loss=0.293]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.13it/s]


Valid: mAP 0.446, AUC 0.760, F0.5 0.465, 


Ep10: 100%|█████| 749/749 [04:15<00:00,  2.94it/s, LR=5.0e-09, Loss=0.278]
100%|█████████████████████████████████████| 66/66 [00:09<00:00,  7.06it/s]


Valid: mAP 0.449, AUC 0.764, F0.5 0.469, 


0

In [11]:
# F0.5 on the validation fragment's pixels as a function of the binarisation threshold.
for threshold in (0.1, 0.2, 0.3, 0.4, 0.5, 0.6):
    f_half = fbeta_score(valid_ink_mask_flat, valid_ink_predicts_flat > threshold, beta=0.5)
    print(threshold, f_half)

0.1 0.38526101839112054
0.2 0.4167324410499356
0.3 0.4442678252809358
0.4 0.46323474240634593
0.5 0.46914952980162417
0.6 0.4600835181509361


In [None]:
# Visual check on the held-out fragment: ground truth (top) vs. thresholded prediction (bottom).
plot_threshold = 0.5  # same cut-off as the F0.5 printed during validation
fig, (ax_true, ax_pred) = plt.subplots(2, 1)
ax_true.imshow(valid_ink_mask)
ax_true.set_title(f'Ground-truth ink mask (fragment {valid_id})')
ax_pred.imshow(valid_ink_predicts > plot_threshold)
ax_pred.set_title(f'Predicted ink (p > {plot_threshold})')
for ax in (ax_true, ax_pred):
    ax.axis('off')
plt.show()