# Import

In [1]:
import torch 
import torch.nn as nn  
import numpy as np
from tqdm import tqdm
import os,sys,cv2
from torch.cuda.amp import autocast
import matplotlib.pyplot as plt
import albumentations as A
import segmentation_models_pytorch as smp
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from torch.nn.parallel import DataParallel
import torch.nn.functional as F
from glob import glob
import ssl

ssl._create_default_https_context = ssl._create_unverified_context

# config

In [2]:
class CFG:
    # ============== pred target =============
    target_size = 1

    # ============== model CFG =============
    model_name = 'Unet'
    backbone = 'resnext101_32x8d' #'se_resnext101_32x4d' #'densenet161' #'se_resnext50_32x4d' #'timm-resnest50d_4s2x40d'

    in_chans = 5 # 65
    # ============== training CFG =============
    image_size = 256 #512
    input_size = 256
    drop_egde_pixel = 0
    tile_size = image_size
    stride = tile_size // 2
    assert stride>drop_egde_pixel

    train_batch_size = 16 # 32
    valid_batch_size = train_batch_size * 2

    epochs = 50
    lr = 6e-5

    # ============== fold =============
    valid_id = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



    # ============== augmentation =============
    train_aug_list = [
        A.RandomResizedCrop(
            input_size, input_size, scale=(0.8,1.25)),
        A.ShiftScaleRotate(p=0.75),
        #A.HorizontalFlip(p=0.3),
        #A.VerticalFlip(p=0.3),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        ToTensorV2(transpose_mask=True),
    ]
    train_aug = A.Compose(train_aug_list)
    valid_aug_list = [
        ToTensorV2(transpose_mask=True),
    ]
    valid_aug = A.Compose(valid_aug_list)

# Model

In [3]:
class CustomModel(nn.Module):
    def __init__(self, CFG, weight=None):
        super().__init__()
        self.CFG = CFG
        self.encoder = smp.Unet(
            encoder_name=CFG.backbone, 
            encoder_weights=weight,
            in_channels=CFG.in_chans,
            classes=CFG.target_size,
            activation=None,
        )

    def forward(self, image):
        output = self.encoder(image)
        # output = output.squeeze(-1)
        return output[:,0]#.sigmoid()


def build_model(weight="imagenet"):
    from dotenv import load_dotenv
    load_dotenv()

    print('model_name', CFG.model_name)
    print('backbone', CFG.backbone)

    model = CustomModel(CFG, weight)

    return model.cuda()

# Functions

In [4]:
def min_max_normalization(x:torch.Tensor)->torch.Tensor:
    """input.shape=(batch,f1,...)"""
    shape=x.shape
    if x.ndim>2:
        x=x.reshape(x.shape[0],-1)
    
    min_=x.min(dim=-1,keepdim=True)[0]
    max_=x.max(dim=-1,keepdim=True)[0]
    if min_.mean()==0 and max_.mean()==1:
        return x.reshape(shape)
    
    x=(x-min_)/(max_-min_+1e-9)
    return x.reshape(shape)

class Data_loader(Dataset):
    def __init__(self,path,s="/images/"):
        self.paths=glob(path+f"{s}*.tif")
        self.paths.sort()
        self.bool=s=="/labels/"
    
    def __len__(self):
        return len(self.paths)
    
    def __getitem__(self,index):
        img=cv2.imread(self.paths[index],cv2.IMREAD_GRAYSCALE)
        img=torch.from_numpy(img)
        if self.bool:
            img=img.to(torch.bool)
        else:
            img=img.to(torch.uint8)
        return img

def load_data(path,s):
    data_loader=Data_loader(path,s)
    data_loader=DataLoader(data_loader, batch_size=16, num_workers=2)
    data=[]
    for x in tqdm(data_loader):
        data.append(x)
    return torch.cat(data,dim=0)


#https://www.kaggle.com/code/kashiwaba/sennet-hoa-train-unet-simple-baseline
def dice_coef(pred:torch.Tensor,target:torch.Tensor,TH=0.5,epsilon=1e-5):
    if torch.any(pred<0) or torch.any(pred>1):
        pred=pred.sigmoid()
    target = target.unsqueeze(1).to(torch.float32)
    pred = (pred>TH).to(torch.float32)
    inter = (target*pred).sum()
    den = target.sum() + pred.sum()
    dice = ((2*inter+epsilon)/(den+epsilon)).mean()
    return dice

    
class Kaggld_Dataset(Dataset):
    def __init__(self,x:list,y:list,arg=False):
        super(Dataset,self).__init__()
        self.x=x#list[(C,H,W),...]
        self.y=y#list[(C,H,W),...]
        self.image_size=CFG.image_size
        self.in_chans=CFG.in_chans
        self.arg=arg
        if arg:
            self.transform=CFG.train_aug
        else: 
            self.transform=CFG.valid_aug

    def __len__(self) -> int:
        return sum([y.shape[0]-self.in_chans for y in self.y])
    
    def __getitem__(self,index):
        i=0
        for x in self.x:
            if index>x.shape[0]-self.in_chans:
                index-=x.shape[0]-self.in_chans
                i+=1
            else:
                break
        x=self.x[i]
        y=self.y[i]
        
        x_index=np.random.randint(0,x.shape[1]-self.image_size)
        y_index=np.random.randint(0,x.shape[2]-self.image_size)

        x=x[index:index+self.in_chans,x_index:x_index+self.image_size,y_index:y_index+self.image_size].to(torch.float32)
        y=y[index+self.in_chans//2,x_index:x_index+self.image_size,y_index:y_index+self.image_size].to(torch.float32)

        data = self.transform(image=x.numpy().transpose(1,2,0), mask=y.numpy())
        x = data['image']
        y = data['mask']
        if self.arg:
            i=np.random.randint(4)
            x=x.rot90(i,dims=(1,2))
            y=y.rot90(i,dims=(0,1))
            for i in range(3):
                if np.random.randint(2):
                    x=x.flip(dims=(i,))
                    if i>=1:
                        y=y.flip(dims=(i-1,))
        return x,y


# Load data 

In [5]:
def load_data(path,s):
    data_loader=Data_loader(path,s)
    data_loader=DataLoader(data_loader, batch_size=16, num_workers=2)
    data=[]
    for x in tqdm(data_loader):
        data.append(x)
    return torch.cat(data,dim=0)

In [7]:
train_x=[]
train_y=[]

root_path="../data/blood-vessel-segmentation/"
paths=glob(root_path+"train/*")
paths.sort()
for i,path in enumerate(paths[1:]):
    if path=="../data/blood-vessel-segmentation/train/kidney_3_dense":
        continue
    x=load_data(path,"/images/")
    print(x.shape)
    y=load_data(path,"/labels/")
    train_x.append(x)
    train_y.append(y)

    #(C,H,W)

    #aug
    train_x.append(x.permute(1,2,0))
    train_y.append(y.permute(1,2,0))
    train_x.append(x.permute(2,0,1))
    train_y.append(y.permute(2,0,1))

val_x=load_data(paths[0],"/images/")
val_y=load_data(paths[0],"/labels/")


  0%|          | 0/88 [00:00<?, ?it/s]

100%|██████████| 88/88 [01:43<00:00,  1.18s/it]


torch.Size([1397, 1928, 1928])


100%|██████████| 88/88 [01:02<00:00,  1.41it/s]
100%|██████████| 139/139 [01:11<00:00,  1.94it/s]


torch.Size([2217, 1041, 1511])


100%|██████████| 139/139 [00:39<00:00,  3.55it/s]
100%|██████████| 65/65 [00:54<00:00,  1.19it/s]


torch.Size([1035, 1706, 1510])


100%|██████████| 65/65 [00:32<00:00,  1.99it/s]
100%|██████████| 143/143 [00:07<00:00, 18.40it/s]
100%|██████████| 143/143 [00:07<00:00, 19.29it/s]


# Training

In [8]:
class FocalLoss(nn.modules.loss._WeightedLoss):

    def __init__(self, gamma=0, size_average=None, ignore_index=-100,
                 reduce=None, balance_param=1.0):
        super(FocalLoss, self).__init__(size_average)
        self.gamma = gamma
        self.size_average = size_average
        self.ignore_index = ignore_index
        self.balance_param = balance_param

    def forward(self, input, target):
        
        assert len(input.shape) == len(target.shape)
        assert input.size(0) == target.size(0)
        assert input.size(1) == target.size(1)

        logpt = - F.binary_cross_entropy_with_logits(input, target)
        pt = torch.exp(logpt)

        focal_loss = -((1 - pt) ** self.gamma) * logpt
        balanced_focal_loss = self.balance_param * focal_loss
        return balanced_focal_loss

In [9]:
train_dataset=Kaggld_Dataset(train_x,train_y,arg=True)
train_dataset = DataLoader(train_dataset, batch_size=16, num_workers=2, shuffle=True, pin_memory=True)
val_dataset=Kaggld_Dataset([val_x],[val_y])
val_dataset = DataLoader(val_dataset, batch_size=16, num_workers=2, shuffle=False, pin_memory=True)

model=build_model()
#model=DataParallel(model)

loss_fn= smp.losses.DiceLoss(mode='binary') #smp.losses.DiceLoss(mode='binary')#FocalLoss(gamma=2) #nn.BCEWithLogitsLoss()
optimizer=torch.optim.AdamW(model.parameters(),lr=CFG.lr)
scaler=torch.cuda.amp.GradScaler()
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=CFG.lr,
                                                steps_per_epoch=len(train_dataset), epochs=CFG.epochs+1,
                                                pct_start=0.1,)

model_name Unet
backbone resnext101_32x8d


In [10]:
model

CustomModel(
  (encoder): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(5, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      

In [56]:
for epoch in range(CFG.epochs):
    time=tqdm(range(len(train_dataset)))
    losss=0
    scores=0
    for i,(x,y) in enumerate(train_dataset):
        x=x.cuda()
        y=y.cuda()
        x=min_max_normalization(x)

        with autocast():
            pred=model(x)
            loss=loss_fn(pred,y)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        score=dice_coef(pred.detach(),y)
        losss=(losss*i+loss.item())/(i+1)
        scores=(scores*i+score)/(i+1)
        time.set_description(f"epoch:{epoch},loss:{losss:.4f},score:{scores:.4f},lr{optimizer.param_groups[0]['lr']:.4e}")
        time.update()
        del loss,pred
    time.close()
    val_losss=0
    val_scores=0
    time=tqdm(range(len(val_dataset)))
    for i,(x,y) in enumerate(val_dataset):
        x=x.cuda()
        y=y.cuda()
        x=min_max_normalization(x)

        with autocast():
            with torch.no_grad():
                pred=model(x)
                loss=loss_fn(pred,y)
        score=dice_coef(pred.detach(),y)
        val_losss=(val_losss*i+loss.item())/(i+1)
        val_scores=(val_scores*i+score)/(i+1)
        time.set_description(f"val-->loss:{val_losss:.4f},score:{val_scores:.4f}")
        time.update()

    time.close()
    torch.save(model.state_dict(),f"./checkpoints/ver11_{CFG.backbone}_{epoch}_loss{losss:.2f}_score{scores:.2f}_val_loss{val_losss:.2f}_val_score{val_scores:.2f}.pt")

time.close()

epoch:0,loss:0.9853,score:0.2713,lr2.4057e-06:   2%|▏         | 23/1170 [00:08<07:13,  2.64it/s]


epoch:0,loss:0.9673,score:0.2864,lr1.6463e-05: 100%|██████████| 1170/1170 [03:49<00:00,  5.09it/s]
val-->loss:0.9085,score:0.4234: 100%|██████████| 143/143 [00:09<00:00, 14.36it/s]
epoch:1,loss:0.7848,score:0.5867,lr4.4410e-05: 100%|██████████| 1170/1170 [03:49<00:00,  5.09it/s]
val-->loss:0.4875,score:0.7311: 100%|██████████| 143/143 [00:09<00:00, 14.74it/s]
epoch:2,loss:0.3296,score:0.8005,lr5.9907e-05: 100%|██████████| 1170/1170 [03:48<00:00,  5.13it/s]
val-->loss:0.2857,score:0.8397: 100%|██████████| 143/143 [00:09<00:00, 14.52it/s]
epoch:3,loss:0.2694,score:0.8130,lr5.9839e-05: 100%|██████████| 1170/1170 [03:47<00:00,  5.13it/s]
val-->loss:0.2378,score:0.8737: 100%|██████████| 143/143 [00:09<00:00, 14.82it/s]
epoch:4,loss:0.2443,score:0.8285,lr5.9301e-05: 100%|██████████| 1170/1170 [03:48<00:00,  5.12it/s]
val-->loss:0.2115,score:0.8823: 100%|██████████| 143/143 [00:09<00:00, 15.35it/s]
epoch:5,loss:0.2318,score:0.8410,lr5.8392e-05: 100%|██████████| 1170/1170 [03:47<00:00,  5.13it