In [1]:
import os, sys,shutil, gc

from glob import glob
from tqdm import tqdm

import math
import random
from collections import OrderedDict
import warnings

import albumentations as A
import cv2
from matplotlib import pyplot as plt

import numpy as np
import pandas as pd
import timm
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import AdamW,lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from sklearn.metrics import fbeta_score, average_precision_score, roc_auc_score

import multiprocessing

from ink_helpers import (load_image,seed_everything,
                         load_fragment,)

# Suppress every warning category for the rest of the kernel session.
warnings.simplefilter("ignore")

# Prefer the GPU whenever CUDA is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print(device)
print("cpu count:", multiprocessing.cpu_count())

cuda
cpu count: 32


In [2]:
# Config:
random_seed = 42
# DataLoader worker processes, capped at 12.
num_workers = min(12, multiprocessing.cpu_count())

# Depth window [bottom_channel_idx, top_channel_idx) of slices stacked as model
# input channels; the training dataset also shifts it by up to
# +/- num_fluctuate_channel slices.
bottom_channel_idx, top_channel_idx = 29, 35
num_fluctuate_channel = 1
num_select_channel = top_channel_idx - bottom_channel_idx

# Square tile edge in pixels; tile origins every quarter tile (75% overlap).
block_size = 512
stride = block_size // 4

# Optimisation.
loss_type = 'bce'  # one of: 'bce', 'focal'
max_lr, weight_decay = 1.0e-5, 1.0e-3
total_epoch = 12
batch_size = 24

# Fragment held out for validation.
valid_id = '2c'


# Seed RNGs through the project helper from ink_helpers; presumably it covers
# random / numpy / torch (and possibly cudnn flags) -- TODO confirm in ink_helpers.
seed_everything(seed=random_seed)

In [3]:
all_frag_ids = ['1', '2a', '2b', '2c', '3']
# Data directory for each fragment id.
id2dir = dict((frag_id, f'./frags/train_{frag_id}') for frag_id in all_frag_ids)
# Train on every fragment except the held-out validation one.
train_id_list = list(filter(lambda frag_id: frag_id != valid_id, all_frag_ids))
print('Train:', train_id_list)

Train: ['1', '2a', '2b', '3']


In [5]:
# Load every fragment once up front; lookups below are keyed by fragment id.
id2images = {}
id2frag_mask = {}
id2ink_mask = {}
for frag_id in tqdm(all_frag_ids):
    # presumably (image stack, fragment mask, ink labels) -- see ink_helpers.load_fragment
    images, frag_mask, ink_mask = load_fragment(frag_id)
    id2images[frag_id], id2frag_mask[frag_id], id2ink_mask[frag_id] = images, frag_mask, ink_mask

100%|█████████████████████████████████████████| 5/5 [01:13<00:00, 14.71s/it]


In [6]:
class InkDataSet2D(Dataset):
    '''
    Tiled 2D dataset: a window of depth slices from a fragment stacked as channels.

    Items:
        has_label=True  -> (idx, image, mask); image: (D, H, W), mask: (1, H, W)
        has_label=False -> (idx, image);       image: (D, H, W)
    with H = W = block_size and D = top - bottom of the depth window.

    Args:
        frag_id_list: fragment ids to tile.
        block_size: tile edge length in pixels.
        channel_slip: also emit every tile with its depth window shifted by each
            offset in [-channel_slip, +channel_slip].
        transforms: optional callable invoked as transforms(image=..., mask=...)
            (or transforms(image=...) when has_label=False) on channels-last
            (H, W, D) arrays; must return a dict with 'image' (and 'mask').
        has_label: whether to return the ink-mask tile.
        tile_stride: step between tile origins; defaults to the global `stride`.
        channel_range: (bottom, top) end-exclusive depth window; defaults to the
            globals (bottom_channel_idx, top_channel_idx).
        images, frag_masks, ink_masks: dicts keyed by fragment id; default to the
            globals id2images, id2frag_mask, id2ink_mask (ink_masks is only
            needed when has_label=True).

    Raises:
        ValueError: if a shifted depth window falls outside a fragment's stack.
    '''
    def __init__(self, frag_id_list, block_size, channel_slip=0, transforms=None, has_label=True,
                 tile_stride=None, channel_range=None,
                 images=None, frag_masks=None, ink_masks=None):
        self.frag_id_list = frag_id_list
        self.block_size = block_size
        self.transforms = transforms
        self.has_label = has_label

        # Fall back to the notebook-level globals so existing call sites work unchanged.
        if tile_stride is None:
            tile_stride = stride
        if channel_range is None:
            channel_range = (bottom_channel_idx, top_channel_idx)
        if images is None:
            images = id2images
        if frag_masks is None:
            frag_masks = id2frag_mask
        if ink_masks is None and has_label:
            ink_masks = id2ink_mask
        self.images = images
        self.frag_masks = frag_masks
        self.ink_masks = ink_masks

        bottom, top = channel_range
        bt_pairs = [(bottom + f, top + f) for f in range(-channel_slip, channel_slip + 1)]

        # get xy positions
        id_xybt_list = []
        for frag_id in frag_id_list:
            # Slicing silently truncates/wraps out-of-range depth indices, which
            # would yield tiles with the wrong channel count; fail loudly instead.
            depth = images[frag_id].shape[0]
            if bottom - channel_slip < 0 or top + channel_slip > depth:
                raise ValueError(
                    f"depth window [{bottom}, {top}) +/- {channel_slip} does not fit "
                    f"the {depth}-slice stack of fragment {frag_id!r}"
                )

            frag_mask = frag_masks[frag_id]
            height, width = frag_mask.shape[:2]
            xs = self._tile_origins(width, block_size, tile_stride)
            ys = self._tile_origins(height, block_size, tile_stride)
            # Filter on exactly the window __getitem__ reads, so tiles clamped
            # to the right/bottom edge are judged on their real contents.
            xy_pairs = [
                (x, y)
                for x in xs
                for y in ys
                if np.any(frag_mask[y:y + block_size, x:x + block_size] > 0)
            ]
            id_xybt_list += [(frag_id, *xy, *bt) for xy in xy_pairs for bt in bt_pairs]
        self.id_xybt_list = id_xybt_list

    @staticmethod
    def _tile_origins(length, block_size, step):
        '''Tile start offsets along one axis: every `step` pixels plus one tile
        flush with the far edge; [0] when length <= block_size.'''
        last = max(length - block_size, 0)
        return list(range(0, last, step)) + [last]

    def __len__(self):
        return len(self.id_xybt_list)

    def __getitem__(self, idx):
        frag_id, x, y, start_z, end_z = self.id_xybt_list[idx]
        rows = slice(y, y + self.block_size)
        cols = slice(x, x + self.block_size)

        image = self.images[frag_id][start_z:end_z, rows, cols]  # D,H,W
        image = np.moveaxis(image, 0, 2)  # H,W,D (channels-last for the transforms)

        if not self.has_label:
            if self.transforms:
                image = self.transforms(image=image)['image']
            return idx, np.moveaxis(image, 2, 0)  # D,H,W

        mask = self.ink_masks[frag_id][rows, cols]  # H,W
        if self.transforms:
            transformed = self.transforms(image=image, mask=mask)
            image, mask = transformed['image'], transformed['mask']
        return idx, np.moveaxis(image, 2, 0), np.expand_dims(mask, 0)  # D,H,W and 1,H,W

In [7]:
# CoarseDropout hole edge: 30% of the tile.
_cutout = int(block_size * 0.3)

# Training augmentations, applied in this order. Normalize with mean 0 / std 1
# per channel only rescales by its max_pixel_value (albumentations default 255).
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=1.0),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf(
            [
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
            ],
            p=0.4,
        ),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        # At most one hole per tile; the ink mask is filled with 0 inside it too.
        A.CoarseDropout(max_holes=1, max_width=_cutout, max_height=_cutout,
                        mask_fill_value=0, p=0.5),
        A.Normalize(mean=[0] * num_select_channel, std=[1] * num_select_channel),
    ]
)

# Loader settings shared by the train and valid loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    pin_memory=True,
    num_workers=num_workers,
    prefetch_factor=1,
)

# Training: every non-validation fragment, with depth-window jitter.
train_dataset = InkDataSet2D(
    frag_id_list=train_id_list,
    block_size=block_size,
    channel_slip=num_fluctuate_channel,
    transforms=train_transform,
    has_label=True,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)

# Validation: held-out fragment, no augmentation, no labels (scored on the full mask).
valid_transform = A.Compose(
    [A.Normalize(mean=[0] * num_select_channel, std=[1] * num_select_channel)]
)
valid_dataset = InkDataSet2D(
    frag_id_list=[valid_id],
    block_size=block_size,
    channel_slip=0,
    transforms=valid_transform,
    has_label=False,
)
valid_dataloader = DataLoader(valid_dataset, shuffle=False, drop_last=False, **_loader_kwargs)

print('Train', len(train_dataloader), ', Valid', len(valid_dataloader), )

Train 1004 , Valid 85


In [8]:
class UNet2D(nn.Module):
    """Four-level 2D U-Net returning per-pixel logits.

    Maps ``(N, num_channels, H, W)`` to ``(N, num_classes, H, W)``. Each encoder
    level halves H and W with a 2x2 max-pool and each decoder level doubles them
    again, so H and W must be divisible by 16 for the skip concatenations to align
    (e.g. a 512x512 tile reaches 32x32 at the bottleneck).

    Submodules are registered as ``contracting_<level>1`` (conv block) /
    ``contracting_<level>2`` (pool), ``middle``, ``expansive_<level>1``
    (transposed conv) / ``expansive_<level>2`` (conv block) and ``output``;
    these names form the ``state_dict`` keys of saved weights.
    """

    def __init__(self, num_channels, num_classes):
        super().__init__()
        self.num_classes = num_classes
        widths = (64, 128, 256, 512)

        # Encoder: conv block to widths[level-1] channels, then 2x down-sampling.
        in_ch = num_channels
        for level, out_ch in enumerate(widths, start=1):
            setattr(self, f"contracting_{level}1", self.conv_block(in_channels=in_ch, out_channels=out_ch))
            setattr(self, f"contracting_{level}2", nn.MaxPool2d(kernel_size=2, stride=2))
            in_ch = out_ch

        # Bottleneck doubles the deepest width (512 -> 1024).
        self.middle = self.conv_block(in_channels=in_ch, out_channels=2 * in_ch)

        # Decoder: 2x up-sampling, then fuse with the matching encoder features.
        for level, out_ch in enumerate(reversed(widths), start=1):
            setattr(self, f"expansive_{level}1", self._up_conv(2 * out_ch, out_ch))
            setattr(self, f"expansive_{level}2", self.conv_block(in_channels=2 * out_ch, out_channels=out_ch))

        self.output = nn.Conv2d(
            in_channels=widths[0], out_channels=num_classes, kernel_size=3, stride=1, padding=1
        )

    @staticmethod
    def _up_conv(in_channels, out_channels):
        """3x3 stride-2 transposed convolution that exactly doubles H and W."""
        return nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU then BatchNorm."""
        layers = []
        for conv_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(conv_in, out_channels, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(num_features=out_channels),
            ]
        return nn.Sequential(*layers)

    def forward(self, X):
        skips = []
        features = X
        # Encoder: keep each pre-pool activation for the matching decoder level.
        for level in range(1, 5):
            features = getattr(self, f"contracting_{level}1")(features)
            skips.append(features)
            features = getattr(self, f"contracting_{level}2")(features)

        features = self.middle(features)

        # Decoder: upsample, concatenate [upsampled, skip] on channels, fuse.
        for level, skip in zip(range(1, 5), reversed(skips)):
            features = getattr(self, f"expansive_{level}1")(features)
            features = getattr(self, f"expansive_{level}2")(torch.cat((features, skip), dim=1))

        return self.output(features)

In [9]:
# One output channel: a single ink logit per pixel, placed on the chosen device.
model = UNet2D(num_channels=num_select_channel, num_classes=1).to(device);

In [10]:
# Loss. NOTE: `FocalLoss` is defined in a later cell of this notebook; that cell
# must be executed before this one when loss_type == 'focal' (fresh-kernel hazard).
if loss_type == 'bce':
    criterion = nn.BCEWithLogitsLoss()
elif loss_type == 'focal':
    criterion = FocalLoss(alpha=1, gamma=2, use_logits=True)
else:
    raise ValueError(f"Unknown loss_type {loss_type!r}; expected 'bce' or 'focal'")

optimizer = AdamW(model.parameters(), lr=max_lr, weight_decay=weight_decay)
# One-cycle schedule stepped per batch: warm up from max_lr/div_factor over the
# first 10% of steps, then cosine-anneal to (max_lr/div_factor)/final_div_factor.
scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    epochs=total_epoch,
    steps_per_epoch=len(train_dataloader),
    max_lr=max_lr,
    pct_start=0.1,
    anneal_strategy="cos",
    div_factor=1.0e3,
    final_div_factor=1.0e1,
)
scaler = GradScaler()
Sig = nn.Sigmoid()

# Validation geometry does not change between epochs, so look it up once.
valid_frag_mask = id2frag_mask[valid_id]
valid_ink_mask = id2ink_mask[valid_id]
valid_in_frag = valid_frag_mask != 0
valid_ink_mask_flat = valid_ink_mask[valid_in_frag]
valid_xybt_list = valid_dataset.id_xybt_list

loss_list = [1] * 10  # sliding window for the smoothed loss shown in the progress bar
for epoch in range(total_epoch):

    # training
    model.train()
    gc.collect()
    with tqdm(enumerate(train_dataloader), total=len(train_dataloader)) as pbar:
        pbar.set_description(f"Epoch {epoch:02d}")
        for step, (idx, img, target) in pbar:
            img, target = img.to(device).float(), target.to(device).float()

            optimizer.zero_grad()
            with autocast():
                outputs = model(img).float()
            loss = criterion(outputs, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            loss_list = loss_list[1:] + [loss.item()]
            pbar.set_postfix(
                OrderedDict(
                    LR=f"{scheduler.get_last_lr()[0]:.2e}",
                    Loss=f"{sum(loss_list)/10:.4f}",
                )
            )

    # validation: average overlapping tile probabilities into a full-fragment map
    valid_ink_predicts = np.zeros(valid_frag_mask.shape, dtype=float)
    valid_ink_count = np.zeros(valid_frag_mask.shape)

    model.eval()
    with torch.no_grad():
        for idx, img in tqdm(valid_dataloader):
            img = img.to(device).float()
            with autocast():
                outputs = Sig(model(img).float())
            # One device->host copy per batch instead of one per tile.
            tile_probs = outputs[:, 0].cpu().numpy()
            for tile_prob, whole_idx in zip(tile_probs, idx.tolist()):
                x, y = map(int, valid_xybt_list[whole_idx][1:3])
                valid_ink_predicts[y:y+block_size, x:x+block_size] += tile_prob
                valid_ink_count[y:y+block_size, x:x+block_size] += 1

    # Pixels no tile covered keep a 0 prediction instead of dividing by zero.
    valid_ink_predicts /= np.maximum(valid_ink_count, 1)
    valid_ink_predicts[~valid_in_frag] = 0

    valid_ink_predicts_flat = valid_ink_predicts[valid_in_frag]

    map_score = average_precision_score(valid_ink_mask_flat, valid_ink_predicts_flat)
    auc_score = roc_auc_score(valid_ink_mask_flat, valid_ink_predicts_flat)
    fhalf_score = fbeta_score(valid_ink_mask_flat, valid_ink_predicts_flat>0.5, beta=0.5)
    print(f'Valid: mAP {map_score:.3f}, AUC {auc_score:.3f}, F0.5 {fhalf_score:.3f}, ')

# save weights (only reached once every epoch has finished)
os.makedirs('./weights', exist_ok=True)
torch.save(
    model.state_dict(),
    f'./weights/2DUNet-block{block_size}-channel{bottom_channel_idx}-to{top_channel_idx}'
    f'-slip{num_fluctuate_channel}-loss{loss_type}-lr{max_lr}-wd{weight_decay}-bs{batch_size}'
    f'-valid{valid_id}-step{total_epoch*len(train_dataloader)}-seed{random_seed}-epoch{total_epoch}.pth'
)

gc.collect()

Epoch 00: 100%|█| 1004/1004 [18:53<00:00,  1.13s/it, LR=9.34e-06, Loss=0.664
100%|███████████████████████████████████████| 85/85 [00:41<00:00,  2.02it/s]


Valid: mAP 0.287, AUC 0.650, F0.5 0.272, 


Epoch 01: 100%|█| 1004/1004 [18:52<00:00,  1.13s/it, LR=9.86e-06, Loss=0.597
100%|███████████████████████████████████████| 85/85 [00:42<00:00,  2.02it/s]


Valid: mAP 0.340, AUC 0.689, F0.5 0.348, 


Epoch 02: 100%|█| 1004/1004 [18:52<00:00,  1.13s/it, LR=9.33e-06, Loss=0.509
100%|███████████████████████████████████████| 85/85 [00:42<00:00,  2.02it/s]


Valid: mAP 0.363, AUC 0.699, F0.5 0.389, 


Epoch 03: 100%|█| 1004/1004 [18:52<00:00,  1.13s/it, LR=8.43e-06, Loss=0.466
100%|███████████████████████████████████████| 85/85 [00:41<00:00,  2.03it/s]


Valid: mAP 0.365, AUC 0.705, F0.5 0.388, 


Epoch 04: 100%|█| 1004/1004 [18:48<00:00,  1.12s/it, LR=7.24e-06, Loss=0.397
100%|█████████████████████████████████| 85/85 [00:41<00:00,  2.03it/s]


Valid: mAP 0.380, AUC 0.715, F0.5 0.394, 


Epoch 05: 100%|█| 1004/1004 [18:49<00:00,  1.12s/it, LR=5.87e-06, Loss
100%|█████████████████████████████████| 85/85 [00:42<00:00,  2.02it/s]


Valid: mAP 0.369, AUC 0.705, F0.5 0.395, 


Epoch 06: 100%|█| 1004/1004 [18:47<00:00,  1.12s/it, LR=4.42e-06, Loss
100%|█████████████████████████████████| 85/85 [00:41<00:00,  2.04it/s]


Valid: mAP 0.377, AUC 0.716, F0.5 0.398, 


Epoch 07: 100%|█| 1004/1004 [18:50<00:00,  1.13s/it, LR=3.02e-06, Loss
100%|███████████████████████████████████| 85/85 [00:41<00:00,  2.02it/s]


Valid: mAP 0.407, AUC 0.735, F0.5 0.418, 


Epoch 08: 100%|█| 1004/1004 [18:47<00:00,  1.12s/it, LR=1.79e-06, Loss=0
100%|███████████████████████████████████| 85/85 [00:41<00:00,  2.04it/s]


Valid: mAP 0.407, AUC 0.729, F0.5 0.430, 


Epoch 09: 100%|█| 1004/1004 [18:47<00:00,  1.12s/it, LR=8.22e-07, Loss=0
100%|███████████████████████████████████| 85/85 [00:41<00:00,  2.04it/s]


Valid: mAP 0.406, AUC 0.727, F0.5 0.421, 


Epoch 10:   5%| | 47/1004 [00:55<18:40,  1.17s/it, LR=7.85e-07, Loss=0.3


KeyboardInterrupt: 

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - p_t) ** gamma * BCE``.

    Args:
        alpha: constant scale applied to every element (not a per-class weight).
        gamma: focusing exponent; ``gamma=0`` reduces to ``alpha * BCE``.
        use_logits: if True ``inputs`` are raw logits, otherwise probabilities in [0, 1].
        reduce: if True return the mean over all elements, else the per-element loss.
    """
    def __init__(self, alpha=1, gamma=2, use_logits=False, reduce=True):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.use_logits = use_logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        # `reduction="none"` replaces the deprecated `reduce=False` keyword.
        if self.use_logits:
            bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        else:
            bce = F.binary_cross_entropy(inputs, targets, reduction="none")
        # For 0/1 targets BCE = -log(p_t), so exp(-BCE) recovers p_t.
        pt = torch.exp(-bce)
        focal = self.alpha * (1 - pt) ** self.gamma * bce
        return torch.mean(focal) if self.reduce else focal

In [None]:
# Ground-truth ink vs thresholded prediction on the validation fragment, taken
# from the most recent completed validation pass of the training cell.
fig, (ax_true, ax_pred) = plt.subplots(2, 1)
ax_true.imshow(valid_ink_mask)
ax_true.set_title(f'Ground-truth ink (fragment {valid_id})')
ax_pred.imshow(valid_ink_predicts > 0.5)  # same 0.5 threshold as the F0.5 score
ax_pred.set_title(f'Predicted ink, p > 0.5 (fragment {valid_id})')
for ax in (ax_true, ax_pred):
    ax.axis('off')
fig.tight_layout()
plt.show()