In [None]:
import copy
import os
from pathlib import Path

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm

!pip install pytorch_warmup torchview
import pytorch_warmup as warmup
from torchview import draw_graph

In [None]:
# Network input resolution: source frames (1640x590 here) divided by DOWNSCALE.
# For 1280x720 source frames, swap in 1280 and 720.
DOWNSCALE = 5
WIDTH = 1640 // DOWNSCALE
HEIGHT = 590 // DOWNSCALE

# Optimisation
LEARNING_RATE = 0.001
BATCH_SIZE = 64
NUM_EPOCHS = 20

# Runtime / data loading
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_WORKERS = 2
PIN_MEMORY = True

# 15x15 all-ones structuring element passed to cv2.dilate on the lane masks.
DILATION = np.ones((15, 15), dtype=np.uint8)

In [None]:
class CULane(Dataset):
    """CULane lane-segmentation dataset yielding (image, binary mask) pairs.

    Frames (``*.jpg`` under ``image_dir``) and masks (``*.png`` under ``mask_dir``)
    are found recursively and paired by sorted path order, so the two trees must
    mirror each other. Each sample is letterboxed to (WIDTH, HEIGHT). The mask is
    dilated with DILATION before resizing, then binarised to {0.0, 1.0}.

    Args:
        image_dir: root directory searched recursively for ``*.jpg`` frames.
        mask_dir: root directory searched recursively for ``*.png`` masks.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` keys.

    Raises:
        ValueError: if the number of images and masks differ.
    """

    def __init__(self, image_dir, mask_dir, transform = None):
        self.images = sorted(Path(image_dir).rglob('*.jpg'))
        self.masks = sorted(Path(mask_dir).rglob('*.png'))
        # Pairing is purely positional; a count mismatch would silently misalign pairs.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"found {len(self.images)} images under {image_dir!s} but "
                f"{len(self.masks)} masks under {mask_dir!s}; they are paired by sorted order"
            )
        self.samples = len(self.images)

        self.transform = transform

    def __len__(self):
        return self.samples

    def resize_with_pad(self, image, new_shape) -> np.ndarray:
        """Letterbox ``image`` into ``new_shape`` given as (width, height).

        Scales by the largest factor that fits both dimensions (aspect ratio is
        preserved), then pads the remainder with zeros, split evenly between
        opposite sides (the extra pixel goes to bottom/right).
        """
        target_w, target_h = new_shape
        src_h, src_w = image.shape[:2]
        # Use the limiting (smaller) ratio. Scaling by max(new)/max(orig) overshoots
        # one side when aspect ratios differ, which yields negative padding.
        ratio = min(target_w / src_w, target_h / src_h)
        new_w = min(target_w, max(1, int(src_w * ratio)))
        new_h = min(target_h, max(1, int(src_h * ratio)))
        image = cv2.resize(image, (new_w, new_h))
        delta_w = target_w - new_w
        delta_h = target_h - new_h
        top, bottom = delta_h // 2, delta_h - (delta_h // 2)
        left, right = delta_w // 2, delta_w - (delta_w // 2)
        return cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0))

    def __getitem__(self, i):
        image_path, mask_path = self.images[i], self.masks[i]
        image = cv2.imread(str(image_path))  # cv2 loads colour images in BGR order
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None rather than raising on a missing/unreadable file.
        if image is None:
            raise FileNotFoundError(f"could not read image {image_path}")
        if mask is None:
            raise FileNotFoundError(f"could not read mask {mask_path}")
        # Thicken mask strokes before downscaling (presumably so thin lanes survive the resize).
        mask = cv2.dilate(mask, DILATION)

        image = self.resize_with_pad(image, (WIDTH, HEIGHT))
        mask = self.resize_with_pad(mask, (WIDTH, HEIGHT))

        # Any non-zero pixel becomes foreground (1.0); everything else is 0.0.
        mask = (mask > 0).astype(np.float32)

        if self.transform:
            result = self.transform(image=image, mask=mask)
            image = result['image']
            mask = result['mask']

        return image, mask

In [None]:
def _unit_scale():
    """Fresh Normalize step that only rescales 0-255 pixels to [0, 1] (zero mean, unit std)."""
    return A.Normalize(
        mean=[0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0],
        max_pixel_value=255.0,
    )


# Training pipeline: geometric + occlusion augmentation, then scaling and HWC->CHW tensor.
# CoarseDropout runs before Normalize, so fill_value=1 is on the 0-255 scale (near-black holes).
# NOTE(review): VerticalFlip puts the sky at the bottom of road frames; confirm it is intended.
train_transforms = A.Compose([
    A.Rotate(limit=45, p=0.5),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.CoarseDropout(
        max_holes=40,
        min_holes=25,
        p=0.5,
        max_height=30,
        max_width=30,
        fill_value=1,
    ),
    _unit_scale(),
    ToTensorV2(),
])

# Validation pipeline: no augmentation, only scaling and tensor conversion.
val_transforms = A.Compose([_unit_scale(), ToTensorV2()])

In [None]:
dataset_path = '/kaggle/input/culane/CULane'
SPLIT_SEED = 42  # fixed so the train/val partition is identical on every Restart & Run All

full_dataset = CULane(image_dir='/kaggle/input/culane/driver_161_90frame',
                      mask_dir='/kaggle/input/culane/driver_161_90frame_labels',
                      transform=None)

# random_split returns Subsets that all wrap the SAME dataset object, so assigning
# `.dataset.transform` per subset lets the last assignment win for both splits
# (training would silently run without augmentation). Give each split its own
# shallow copy: file lists are shared, the transform attribute is not.
train_base = copy.copy(full_dataset)
train_base.transform = train_transforms
val_base = copy.copy(full_dataset)
val_base.transform = val_transforms

train_split, val_split = torch.utils.data.random_split(
    full_dataset, [0.7, 0.3], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_dataset = Subset(train_base, train_split.indices)
val_dataset = Subset(val_base, val_split.indices)

loader_args = {
    "batch_size": BATCH_SIZE,
    "num_workers": NUM_WORKERS,
    "pin_memory": PIN_MEMORY,
}

train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
val_loader = DataLoader(val_dataset, **loader_args)

In [None]:
# Sanity check: show up to 10 samples from one augmented training batch (image | target mask).
x, y = next(iter(train_loader))
y = y.unsqueeze(1)

for i in range(min(len(x), 10)):
    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(10, 10))
    # Frames were loaded with cv2 (BGR order); reverse the channels so matplotlib shows RGB.
    ax_image.imshow(x[i].permute(1, 2, 0).numpy()[..., ::-1])
    ax_image.set_title("image")
    ax_mask.imshow(y[i].squeeze(0), cmap='gray')
    ax_mask.set_title("mask")
    plt.show()

In [None]:
def save_checkpoint(state, filename):
    """Announce on stdout, then serialize ``state`` to ``filename`` with ``torch.save``."""
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(state, filename)

def load_checkpoint(checkpoint, model):
    """Copy weights from ``checkpoint['state_dict']`` into ``model`` (strict key matching)."""
    print("=> Loading checkpoint")
    weights = checkpoint["state_dict"]
    model.load_state_dict(weights)

def train_fn(loader, model, optimizer, loss_fn, scaler, lr_scheduler, epoch):
    """Run one training epoch with CUDA autocast and gradient scaling.

    Batches are moved to the global DEVICE; masks get a channel dim via
    ``unsqueeze(1)``. ``lr_scheduler`` is stepped once per batch (not per epoch).
    The tqdm bar shows the latest batch loss, the epoch and the current LR.
    """
    progress = tqdm(loader)

    for images, masks in progress:
        images = images.to(device=DEVICE)
        masks = masks.to(device=DEVICE).unsqueeze(1)

        with torch.cuda.amp.autocast():
            logits = model(images)
            batch_loss = loss_fn(logits, masks)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        lr_scheduler.step()

        current_lr = optimizer.param_groups[0]["lr"]
        progress.set_postfix(loss=batch_loss.item(), epoch=epoch, lr=current_lr)

def validate_fn(loader, model, loss_fn, device="cuda"):
    """Evaluate ``model`` on ``loader``; print and return the mean Dice score and mean loss.

    Args:
        loader: sized iterable of (images, masks) batches; masks lack the channel
            dim and get ``unsqueeze(1)`` to match the model output.
        model: segmentation network returning raw logits.
        loss_fn: criterion called as ``loss_fn(logits, targets)`` returning a scalar.
        device: device the batches are moved to.

    Returns:
        (dice_score, loss) as Python floats, each averaged over batches. Dice uses
        predictions binarised at ``sigmoid(logits) >= 0.5``.

    The model's previous train/eval mode is restored afterwards, even on error.
    """
    device_type = torch.device(device).type
    was_training = model.training
    dice_total = 0.0
    loss_total = 0.0

    model.eval()
    try:
        # Mixed precision only on CUDA; elsewhere autocast is explicitly disabled.
        with torch.no_grad(), torch.autocast(device_type=device_type, enabled=device_type == "cuda"):
            for x, y in loader:
                x = x.to(device)
                y = y.to(device).unsqueeze(1)

                logits = model(x)
                # The criterion takes logits (as in training), not thresholded probabilities.
                loss_total += loss_fn(logits, y).item()

                # Hard 0/1 predictions: probabilities below 0.5 must become 0, not stay soft.
                preds = (torch.sigmoid(logits) >= 0.5).float()
                dice_total += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()
    finally:
        model.train(was_training)

    dice_score = dice_total / len(loader)
    loss = loss_total / len(loader)

    print(f"Dice score\t{dice_score:.3f}")
    print(f"Val loss\t{loss:.3f}")

    return dice_score, loss

In [None]:
import torch.cuda.amp as amp
import torch.nn.functional as F

##
# version 2: user derived grad computation
class FocalSigmoidLossFunc(torch.autograd.Function):
    '''
    Element-wise sigmoid focal loss with a hand-written backward pass.

    forward(logits, label, alpha, gamma) returns, per element,
        -|label - p|**gamma * (alpha * label * log(p) + (1 - alpha) * (1 - label) * log(1 - p))
    with p = sigmoid(logits). log(p) and log(1 - p) are built from softplus on the
    side that keeps the exponent non-positive, for numeric stability.
    Under autocast, inputs are cast to float32 (custom_fwd).
    Gradients are returned only for ``logits``.
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, label, alpha, gamma):
        probs = torch.sigmoid(logits)
        # Focal modulating factor, negated so the final loss is non-negative.
        coeff = (label - probs).abs_().pow_(gamma).neg_()
        # log(p): softplus with beta=-1 is -log(1 + exp(-z)) = log(sigmoid(z)).
        log_probs = torch.where(logits >= 0,
                F.softplus(logits, -1, 50),
                logits - F.softplus(logits, 1, 50))
        # log(1 - p)
        log_1_probs = torch.where(logits >= 0,
                -logits + F.softplus(logits, -1, 50),
                -F.softplus(logits, 1, 50))
        ce_term1 = log_probs.mul_(label).mul_(alpha)
        ce_term2 = log_1_probs.mul_(1. - label).mul_(1. - alpha)
        ce = ce_term1.add_(ce_term2)
        loss = ce * coeff

        # save_for_backward lets autograd detect in-place modification of these tensors.
        ctx.save_for_backward(coeff, probs, ce, label)
        ctx.gamma = gamma
        ctx.alpha = alpha

        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        Gradient of the focal loss w.r.t. logits; None for label, alpha, gamma.
        '''
        coeff, probs, ce, label = ctx.saved_tensors
        gamma, alpha = ctx.gamma, ctx.alpha

        # d(coeff)/dz = -gamma * |label - p|^(gamma-1) * sign(p - label) * p * (1 - p)
        d_coeff = (label - probs).abs().pow_(gamma - 1.).mul_(gamma)
        d_coeff.mul_(probs).mul_(1. - probs)
        d_coeff = torch.where(label < probs, d_coeff.neg(), d_coeff)
        term1 = d_coeff.mul_(ce)

        # d(ce)/dz = alpha*label - p * (2*alpha*label + 1 - label - alpha).
        # probs is multiplied out-of-place: saved tensors must stay intact so a second
        # backward through the same graph (retain_graph=True) sees the same values.
        d_ce = label * alpha
        d_ce.sub_(probs * ((label * alpha).mul_(2).add_(1).sub_(label).sub_(alpha)))
        term2 = d_ce.mul(coeff)

        grads = term1.add_(term2)
        grads.mul_(grad_output)

        return grads, None, None, None


class FocalLoss(nn.Module):
    """Sigmoid focal loss on raw logits, computed by FocalSigmoidLossFunc.

    Args:
        alpha: weight of the positive-label term (1 - alpha weights the negative term).
        gamma: focusing exponent applied to |label - sigmoid(logits)|.
        reduction: 'mean', 'sum', or 'none' (element-wise loss).

    Raises:
        ValueError: if ``reduction`` is not one of the above.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self,
                 alpha=0.25,
                 gamma=2,
                 reduction='mean'):
        super().__init__()
        # Reject typos up front: an unknown value would otherwise return an unreduced loss.
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, label):
        '''
        Usage is the same as nn.BCEWithLogitsLoss:
            >>> criteria = FocalLoss()
            >>> logits = torch.randn(8, 19, 384, 384)
            >>> lbs = torch.randint(0, 2, (8, 19, 384, 384)).float()
            >>> loss = criteria(logits, lbs)
        '''
        loss = FocalSigmoidLossFunc.apply(logits, label, self.alpha, self.gamma)
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss

In [None]:
class DilatedConvBlock(nn.Module):
    """Two (dilated 3x3 conv -> BatchNorm -> LeakyReLU) stages.

    padding == dilation, so with stride 1 the spatial size is unchanged.
    Channels go ch_in -> ch_out -> ch_out.
    """
    def __init__(self, ch_in, ch_out, dilation=2):
        super().__init__()
        stages = []
        for stage_in in (ch_in, ch_out):
            stages.extend([
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1,
                          padding=dilation, dilation=dilation, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.LeakyReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
    
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> LeakyReLU) stages; padding=1 keeps H and W unchanged."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        stages = []
        for stage_in in (ch_in, ch_out):
            stages.extend([
                nn.Conv2d(stage_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(ch_out),
                nn.LeakyReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)


class UpConvBlock(nn.Module):
    """2x upsample (nn.Upsample default mode, nearest) then 3x3 conv -> BatchNorm -> LeakyReLU."""
    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        self.up = nn.Sequential(upsample, conv, nn.BatchNorm2d(ch_out), nn.LeakyReLU(inplace=True))

    def forward(self, x):
        return self.up(x)


class AttentionBlock(nn.Module):
    """Additive attention gate: scales ``x`` by a per-pixel weight in (0, 1) derived from ``g`` and ``x``.

    Args:
        F_g: channels of the gating signal ``g``.
        F_l: channels of the gated features ``x``.
        F_int: intermediate channel count of the 1x1 projections.
    """
    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(ch_in, ch_out):
            # 1x1 conv followed by BatchNorm.
            return nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(ch_out),
            )

        self.W_g = project(F_g, F_int)
        self.W_x = project(F_l, F_int)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, g, x):
        combined = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(combined)
    
    
class UpModule(nn.Module):
    """One decoder stage: upsample, attention-gate the skip, fuse, and add residuals.

    Maps a (2 * feature)-channel input to ``feature`` channels at the skip
    connection's spatial size.
    """
    def __init__(self, feature):
        super().__init__()
        self.up = UpConvBlock(feature * 2, feature)
        self.attn = AttentionBlock(feature, feature, feature // 2)
        self.conv = DilatedConvBlock(feature * 2, feature)
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x, skip_connection):
        upsampled = self.up(x)
        # 2x upsampling can fall short of the skip's H/W (e.g. when an encoder
        # pooling step floored an odd size); resize to match exactly.
        if upsampled.shape != skip_connection.shape:
            upsampled = torchvision.transforms.functional.resize(upsampled, size=skip_connection.shape[2:])

        gated = self.attn(upsampled, skip_connection)
        fused = self.conv(torch.cat((gated, upsampled), dim=1))

        # Residual connection with both the raw skip and its attention-gated version.
        return self.act(fused + skip_connection + gated)

class DownModule(nn.Module):
    """One encoder stage: DilatedConvBlock followed by a 2x2, stride-2 max-pool.

    ``forward`` returns ``(pooled, features)``, where ``features`` is the pre-pool
    activation used as the decoder's skip connection.
    """
    def __init__(self, in_channels, feature):
        super().__init__()
        self.down = DilatedConvBlock(in_channels, feature)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.down(x)
        return self.pool(features), features


class AttenUNET(nn.Module):
    """Attention U-Net: DownModule encoder, ConvBlock bottleneck, attention-gated UpModule decoder.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output map. The output is raw logits
            (no activation is applied).
        features: encoder widths, shallowest first; the decoder mirrors them in reverse.
    """
    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256)):
        super().__init__()

        # Attribute names (ups, downs, bottleneck, final_conv) are the state_dict key
        # prefixes stored in saved checkpoints; keep them stable.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        # Encoder: each stage maps the previous stage's width to `feature`.
        for feature in features:
            self.downs.append(DownModule(in_channels, feature))
            in_channels = feature

        # Decoder: one stage per encoder stage, deepest first.
        for feature in reversed(features):
            self.ups.append(UpModule(feature))

        self.bottleneck = nn.Sequential(
            ConvBlock(features[-1], features[-1] * 2),
            ConvBlock(features[-1] * 2, features[-1] * 2),
        )
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x, skip = down(x)
            skip_connections.append(skip)

        x = self.bottleneck(x)

        # The deepest skip pairs with the first (deepest) decoder stage.
        for up, skip in zip(self.ups, reversed(skip_connections)):
            x = up(x, skip)

        return self.final_conv(x)

In [None]:
# Model, criterion, optimiser and per-step learning-rate schedule.
model = AttenUNET(3, 1).to(DEVICE)
# Optional speed-up on PyTorch 2.x: model = torch.compile(model)

loss_fn = FocalLoss(gamma=2)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
num_steps = NUM_EPOCHS * len(train_loader)


def ex(step):
    """Exponential decay factor 0.999 ** step."""
    return 0.999 ** step


def lr_lambda(step):
    """LambdaLR multiplier: exponential decay, floored at 1e-2."""
    return max(ex(step), 1e-2)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Train for NUM_EPOCHS, saving a checkpoint only when the validation Dice score improves.
max_score = 0.0

for epoch in range(NUM_EPOCHS):
    train_fn(train_loader, model, optimizer, loss_fn, scaler, lr_scheduler, epoch)
    dice_score, loss = validate_fn(val_loader, model, loss_fn, device=DEVICE)

    if not dice_score > max_score:
        continue

    max_score = dice_score
    checkpoint = {"state_dict": model.state_dict()}
    save_checkpoint(checkpoint, "attn-unet-best.pth.tar")

In [None]:
# Overlay ground truth (left) and thresholded predictions (right) on the first validation batch.
model.eval()
try:
    x, y = next(iter(val_loader))
    with torch.no_grad():
        x = x.to(DEVICE)
        y = y.to(DEVICE).unsqueeze(1)
        # Hard 0/1 mask: every probability below 0.5 becomes 0, not only >= 0.5 -> 1.
        preds = (torch.sigmoid(model(x)) >= 0.5).float()

    # Iterate over the actual batch length; the first batch can be shorter than BATCH_SIZE.
    for i in range(len(x)):
        # Frames were loaded with cv2 (BGR order); reverse channels for matplotlib.
        frame = x[i].cpu().permute(1, 2, 0).numpy()[..., ::-1]
        fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(10, 10))
        ax_gt.imshow(frame)
        ax_gt.imshow(y[i].cpu().squeeze(0), cmap='gray', alpha=0.5)
        ax_gt.set_title("ground truth")
        ax_pred.imshow(frame)
        ax_pred.imshow(preds[i].cpu().squeeze(0), cmap='gray', alpha=0.5)
        ax_pred.set_title("prediction")
        plt.show()
finally:
    model.train()