In [1]:
!pip install segmentation_models_pytorch



In [2]:
import numpy as np
import pandas as pd

import torch, cv2
import torch.nn as nn

from glob import glob

import albumentations as A
from albumentations.pytorch import ToTensorV2

from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp

pd.set_option('display.max_colwidth', 200)

In [3]:
class CFG:
    target_size = 1
    model_name = 'Unet'
    backbone = 'resnext50_32x4d'
    in_chans = 3
    grid_size = [512,512]
    epochs = 50
    lr = 1e-5

    train_aug = A.Compose([
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.1),
        A.GaussianBlur(p=0.1),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2() 
    ]) 
    
    val_aug = A.Compose([
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ])

In [4]:
mask_path = sorted(glob(f'/kaggle/input/blood-vessel-segmentation/train/*/labels/*tif'))
df = pd.DataFrame({'mask_path': mask_path})

df['dataset'] = df.mask_path.map(lambda x: x.split('/')[-3])
df['slice'] = df.mask_path.map(lambda x: x.split('/')[-1].replace('.tif', ''))

df = df[~df.dataset.str.contains('kidney_3_sparse')]
df['image_path'] = df.mask_path.str.replace('label', 'image')
df['image_path'] = df.image_path.str.replace('kidney_3_dense', 'kidney_3_sparse')

train_df = df[~df.dataset.str.contains('kidney_3')]
val_df = df[df.dataset.str.contains('kidney_3')]

In [5]:
class Load_Data(Dataset):
    def __init__(self, tmp_df, mode, transform=None):
        super().__init__()
        
        data = []
        for _, row in tmp_df.iterrows():
            data_tmp = []
            image = cv2.imread(row['image_path'], cv2.IMREAD_GRAYSCALE)
            mask = cv2.imread(row['mask_path'], cv2.IMREAD_GRAYSCALE)
            tmp = np.stack([image, mask]).transpose((1,2,0)).astype(np.uint8)

            grid_size = CFG.grid_size

            num_i = image.shape[0] // grid_size[0]
            num_j = image.shape[1] // grid_size[1]

            for i in range(num_i):
                for j in range(num_j):
                    data_tmp.append(tmp[i*grid_size[0]:(i+1)*grid_size[0], j*grid_size[1]:(j+1)*grid_size[1], :])

            for i in range(num_i):
                data_tmp.append(tmp[i*grid_size[0]:(i+1)*grid_size[0], tmp.shape[1]-grid_size[1]:, :])

            for j in range(num_j):
                data_tmp.append(tmp[tmp.shape[0]-grid_size[0]:, j*grid_size[1]:(j+1)*grid_size[1], :])

            data_tmp.append(tmp[tmp.shape[0]-grid_size[0]:, tmp.shape[1]-grid_size[1]:, :])

            data_tmp = np.array(data_tmp)

            if mode == 'train':
                data.append(data_tmp[np.random.randint(data_tmp.shape[0], size=2)])
            else:
                data.append(data_tmp)

        data = np.stack(data).reshape(-1, *grid_size, 2)
        self.images, self.masks = data[:, :, :, 0], data[:, :, :, 1] 
        self.masks = (self.masks>127)
        self.transform = transform
        
    def __len__(self):
        return self.images.shape[0]
    
    def __getitem__(self, idx):
        
        image = np.expand_dims(self.images[idx], axis=-1)
        image = np.repeat(image, 3, axis=-1)
            
        transformed = self.transform(image=image, mask=self.masks[idx])
        image, mask = transformed['image'], transformed['mask']
        return image, mask

In [6]:
class CustomModel(nn.Module):
    def __init__(self, CFG, weight=None):
        super(CustomModel, self).__init__()
        self.model = smp.Unet(
            encoder_name = CFG.backbone,
            encoder_weights = weight,
            in_channels = CFG.in_chans, 
            classes = CFG.target_size, 
            activation = None
        )
    def forward(self, image):
        output = self.model(image)
        return output

In [7]:
model = CustomModel(CFG, 'imagenet')

dice_loss = smp.losses.DiceLoss(mode='binary')
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

CustomModel(
  (model): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        

In [8]:
# val_load = Load_Data(val_df, 'val', transform=CFG.val_aug)
# val_loader = DataLoader(val_load, batch_size=8, shuffle=False)

for epoch in range(CFG.epochs):
    
    print('epochs:', epoch+1)
    
    train_load = Load_Data(train_df, 'train', transform=CFG.train_aug)
    train_loader = DataLoader(train_load, batch_size=8, shuffle=True)
    
    model.train()

    train_loss = []
    for i, (images, masks) in enumerate(train_loader):
        
        optimizer.zero_grad()
        masks_pred = model(images.to(device))
        loss = dice_loss(masks_pred.to(device), masks.to(device)).to(device)
        train_loss.append(loss.detach().cpu().numpy())
        loss.backward()
        optimizer.step()  
        
    print('train_loss_avg:', np.mean(np.array(train_loss)))
     
#     model.eval()
    
#     val_loss = []
#     with torch.no_grad():
#         for i, (images, masks) in enumerate(val_loader):

#             masks_pred = model(images.to(device))
#             loss = dice_loss(masks_pred.to(device), masks.to(device)).to(device)
#             val_loss.append(loss.detach().cpu().numpy())
            
#     print('val_loss_avg:', np.mean(np.array(val_loss)))
        
    if (epoch+1)%2 == 0:
        torch.save(model.state_dict(), f'/kaggle/working/{CFG.backbone}_{epoch+1}.pt')     

epochs: 1
train_loss_avg: 0.93506503
epochs: 2
train_loss_avg: 0.80515456
epochs: 3
train_loss_avg: 0.56566954
epochs: 4
train_loss_avg: 0.33602047
epochs: 5
train_loss_avg: 0.20133957
epochs: 6
train_loss_avg: 0.13337101
epochs: 7
train_loss_avg: 0.094398074
epochs: 8
train_loss_avg: 0.07475319
epochs: 9
train_loss_avg: 0.06417417
epochs: 10
train_loss_avg: 0.05951814
epochs: 11
train_loss_avg: 0.053899456
epochs: 12
train_loss_avg: 0.050104883
epochs: 13
train_loss_avg: 0.04854886
epochs: 14
train_loss_avg: 0.046506647
epochs: 15


KeyboardInterrupt: 