In [None]:
from matplotlib import pyplot as plt
from PIL import Image
import numpy as np
import os
import cv2
import torch
from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torchmetrics.classification import BinaryF1Score
import pickle

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
MASKS_PATH = 'datasets/train_updated_titiles/masks/train_mask_{num:03d}.png'
IMAGES_PATH = 'datasets/train_updated_titiles/images/train_image_{num:03d}.png'
NUM_TRAIN_IMAGES = 21  # image/mask files are numbered 000 .. NUM_TRAIN_IMAGES - 1

# Load every (image, mask) pair. After the second loop below each entry of
# `data` is an (image, mask, area_fraction) tuple.
data = []
total_area = 0
for num in range(NUM_TRAIN_IMAGES):
    im_file = IMAGES_PATH.format(num=num)
    mask_file = MASKS_PATH.format(num=num)
    im_cv = np.array(Image.open(im_file))
    # The network is built with 3 input channels: replicate single-channel
    # images (2-D arrays or 1-2 channel stacks) and drop extra channels such as alpha.
    if im_cv.ndim == 2:
        im_cv = np.repeat(im_cv[:, :, None], 3, axis=2)
    elif im_cv.shape[2] < 3:
        im_cv = np.repeat(im_cv[:, :, :1], 3, axis=2)
    elif im_cv.shape[2] > 3:
        im_cv = im_cv[:, :, :3]
    mask_cv = np.array(Image.open(mask_file))
    print(num, im_cv.shape, mask_cv.shape)
    data.append((im_cv, mask_cv))
    total_area += im_cv.shape[0] * im_cv.shape[1]

# Attach each image's share of the total pixel area; get_batch uses it to
# draw proportionally more crops from larger images.
for i, (img, mask) in enumerate(data):
    area = img.shape[0] * img.shape[1]
    percent = area / total_area
    data[i] += (percent,)
    print(img.shape[:2], percent)

In [None]:
def random_crop(img, mask, size):
    """Cut a randomly rotated square patch out of ``img`` and ``mask``.

    A square neighbourhood large enough to contain the rotated patch is cut
    around a random centre, rotated by a random whole-degree angle about its
    own centre, and the central patch is cropped from the rotated result.

    Args:
        img: H x W x C image array.
        mask: H x W label array aligned with ``img``.
        size: requested side length; the patch is ``2 * (size // 2)`` pixels
            wide (i.e. exactly ``size`` for even sizes).

    Returns:
        (cropped, cropped_mask): arrays of shape (2*(size//2), 2*(size//2), C)
        and (2*(size//2), 2*(size//2)).

    Raises:
        ValueError: if the image is too small to contain the rotated patch.
    """
    half_size = size // 2
    angle = np.random.randint(360)
    rad = np.pi * angle / 180
    # Half-extent of the axis-aligned bounding box of the rotated patch:
    # half_size * (|cos a| + |sin a|) == half_size * sqrt(2) * max(|cos(pi/4 - a)|, |sin(pi/4 - a)|).
    d = half_size * 2 ** 0.5
    dist = int(np.ceil(d * max(abs(np.cos(np.pi / 4 - rad)), abs(np.sin(np.pi / 4 - rad)))))
    dist += 1  # one pixel of slack at the border
    if img.shape[0] <= 2 * dist or img.shape[1] <= 2 * dist:
        raise ValueError(
            f"image of shape {img.shape[:2]} is too small for a rotated crop of size {size}"
        )
    x = np.random.randint(dist, img.shape[1] - dist)
    y = np.random.randint(dist, img.shape[0] - dist)
    area = img[y - dist:y + dist, x - dist:x + dist, :]
    area_mask = mask[y - dist:y + dist, x - dist:x + dist]

    (h, w) = area.shape[:2]
    (cX, cY) = (w // 2, h // 2)
    M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
    rotated = cv2.warpAffine(area, M, (w, h))
    # Nearest-neighbour keeps the mask a label map; the default bilinear
    # interpolation would blend label values along object boundaries.
    rotated_mask = cv2.warpAffine(area_mask, M, (w, h), flags=cv2.INTER_NEAREST)
    cropped = rotated[cY - half_size:cY + half_size, cX - half_size:cX + half_size, :]
    cropped_mask = rotated_mask[cY - half_size:cY + half_size, cX - half_size:cX + half_size]
    return cropped, cropped_mask


def get_batch(length, size, min_percent=0, dataset=None, max_attempts=10000):
    """Sample roughly ``length`` random rotated crops across the training images.

    Each image contributes ``round(length * percent)`` crops, where ``percent``
    is the third element of its tuple (its share of total pixel area), so the
    number of returned crops can differ slightly from ``length``.

    Args:
        length: requested total number of crops.
        size: crop side length passed to ``random_crop``.
        min_percent: a crop is accepted only if ``cropped_mask.sum()`` is
            strictly greater than ``size * size * min_percent``.  Because the
            comparison is strict, the default of 0 rejects crops whose mask
            sums to 0.
        dataset: iterable of (img, mask, percent) tuples; defaults to the
            global ``data``.
        max_attempts: crops drawn per sample before giving up.

    Returns:
        (res_x, res_y): stacked crops of shape (n, S, S, C) and masks of shape
        (n, S, S, 1).

    Raises:
        RuntimeError: if no crop passes the mask limit within ``max_attempts``
            draws (previously this looped forever).
    """
    if dataset is None:
        dataset = data
    res_x = []
    res_y = []
    limit = size * size * min_percent
    for img, mask, percent in dataset:
        for _ in range(int(round(length * percent))):
            for _attempt in range(max_attempts):
                cropped, cropped_mask = random_crop(img, mask, size)
                if cropped_mask.sum() > limit:
                    break
            else:
                raise RuntimeError(
                    f"no crop of size {size} with mask sum > {limit} found "
                    f"in {max_attempts} attempts"
                )
            res_x.append(cropped)
            res_y.append(cropped_mask)
    res_x = np.array(res_x)
    res_y = np.array(res_y)
    # Add a trailing channel axis: (n, S, S) -> (n, S, S, 1).
    res_y = res_y.reshape(res_y.shape + (1,))
    return res_x, res_y
        

# Visual check of random_crop: draw len(data) independent 512-px crops from
# the same training image (index 4), image on the left and mask on the right.
cnt = len(data)
fig, axes = plt.subplots(cnt, 2, figsize=(10, 5 * cnt))
img, mask, percent = data[4]
for ax_img, ax_mask in axes:
    cropped, cropped_mask = random_crop(img, mask, 512)
    ax_img.imshow(cropped)
    ax_mask.imshow(cropped_mask)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.cuda.amp as amp


## Soft Dice Loss for binary segmentation
##
# v1: pytorch autograd
class SoftDiceLossV1(nn.Module):
    '''
    Soft-dice loss for binary segmentation, differentiated by plain autograd.

    The score is computed over all elements of the batch at once:
        loss = 1 - (2 * sum(s * y) + smooth) / (sum(s^p + y^p) + smooth),
    where s = sigmoid(logits).
    '''
    def __init__(self, p=1, smooth=1):
        super().__init__()
        self.p = p
        self.smooth = smooth

    def forward(self, logits, labels):
        '''
        inputs:
            logits: raw (pre-sigmoid) scores, any shape
            labels: targets with the same shape as logits
        output:
            loss: 0-dim tensor
        '''
        probs = logits.sigmoid()
        overlap = torch.sum(probs * labels)
        total = torch.sum(probs ** self.p + labels ** self.p)
        return 1. - (2 * overlap + self.smooth) / (total + self.smooth)


##
# v2: self-derived grad formula
class SoftDiceLossV2(nn.Module):
    '''
    Soft-dice loss for binary segmentation using the hand-written gradient in
    SoftDiceLossV2Func.

    All elements of the batch are flattened into one row, so the dice score
    is computed globally over the batch.
    '''
    def __init__(self,
                 p=1,
                 smooth=1):
        super(SoftDiceLossV2, self).__init__()
        self.p = p
        self.smooth = smooth

    def forward(self, logits, labels):
        '''
        inputs:
            logits: raw (pre-sigmoid) scores, any shape
            labels: targets with the same number of elements as logits
        output:
            loss: tensor of shape (1,)
        '''
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors,
        # are accepted too.
        logits = logits.reshape(1, -1)
        labels = labels.reshape(1, -1)
        loss = SoftDiceLossV2Func.apply(logits, labels, self.p, self.smooth)
        return loss


class SoftDiceLossV2Func(torch.autograd.Function):
    '''
    Row-wise soft-dice loss with an explicitly derived backward pass.

    For each row n, with s = sigmoid(logits):
        N = 2 * sum(s * y) + smooth
        D = sum(s^p + y^p) + smooth
        loss_n = 1 - N / D
        dloss_n/dlogits = p * N * s^p * (1 - s) / D^2 - 2 * y * s * (1 - s) / D
    '''
    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, logits, labels, p, smooth):
        '''
        inputs:
            logits: (N, L) raw scores
            labels: (N, L) targets
            p: exponent used in the denominator
            smooth: additive smoothing constant
        outputs:
            loss: (N,) one dice loss per row
        '''
        probs = torch.sigmoid(logits)
        numer = 2 * (probs * labels).sum(dim=1) + smooth
        denor = (probs.pow(p) + labels.pow(p)).sum(dim=1) + smooth
        loss = 1. - numer / denor

        ctx.save_for_backward(probs, labels, numer, denor)
        ctx.p = p
        ctx.smooth = smooth
        return loss

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        '''
        compute the gradient of the soft-dice loss w.r.t. logits
        '''
        probs, labels, numer, denor = ctx.saved_tensors
        p = ctx.p

        numer = numer.reshape(-1, 1)
        denor = denor.reshape(-1, 1)
        # grad_output is (N,): make it a column so row n's upstream gradient
        # scales row n (a bare (N,) would broadcast along the L axis instead).
        grad_output = grad_output.reshape(-1, 1)

        # Everything below is out-of-place so the saved tensors stay intact and
        # backward can run again (e.g. with retain_graph=True).
        term1 = 2 * labels * probs * (1. - probs) / denor
        term2 = p * numer * probs.pow(p) * (1. - probs) / denor.pow(2)

        grads = (term2 - term1) * grad_output

        return grads, None, None, None




In [None]:
class BinaryDiceLoss(nn.Module):
    r"""Soft Dice loss for binary targets.

    ``predict`` is used as-is (no sigmoid is applied here), so it is expected
    to already hold probabilities.

    Args:
        p: exponent in the denominator \sum{x^p} + \sum{y^p}, default: 2
        epsilon: smoothing constant added to numerator and denominator; it also
            makes an all-zero prediction on an all-zero target score a loss of 0.

    Forward args:
        predict: tensor of shape [N, *]
        target: tensor with the same shape as predict

    Returns:
        0-dim loss tensor: the per-sample dice loss averaged over the batch.
    """

    def __init__(self, p=2, epsilon=1e-6):
        super().__init__()
        self.p = p  # pow degree
        self.epsilon = epsilon

    def forward(self, predict, target):
        predict = predict.flatten(1)
        target = target.flatten(1)

        # Smooth the *doubled* intersection. Adding epsilon before doubling
        # made an empty sample score 1 - 2*eps/eps = -1 instead of 0.
        num = 2 * torch.sum(predict * target, dim=1) + self.epsilon
        den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.epsilon
        loss = 1 - num / den

        return loss.mean()  # over batch

#PyTorch
class IoULoss(nn.Module):
    """Soft IoU (Jaccard) loss on raw logits: ``1 - (I + smooth) / (U + smooth)``.

    ``weight`` and ``size_average`` are accepted for API compatibility but unused.
    """
    def __init__(self, weight=None, size_average=True):
        super(IoULoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        """
        inputs: raw (pre-sigmoid) scores, any shape
        targets: targets with the same number of elements as inputs
        smooth: additive smoothing constant
        returns: 0-dim loss tensor computed over all elements jointly
        """
        # Remove this sigmoid if the model already ends in a sigmoid or equivalent activation.
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors; reshape also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Intersection is the soft true-positive count; union = |A| + |B| - intersection.
        intersection = (inputs * targets).sum()
        total = (inputs + targets).sum()
        union = total - intersection

        IoU = (intersection + smooth) / (union + smooth)

        return 1 - IoU


#PyTorch
ALPHA = 0.5  # default weight of false positives
BETA = 0.5   # default weight of false negatives

class TverskyLoss(nn.Module):
    """Tversky loss on raw logits: ``1 - (TP + s) / (TP + alpha*FP + beta*FN + s)``.

    With alpha = beta = 0.5 this is the soft dice loss, up to smoothing.
    ``weight`` and ``size_average`` are accepted for API compatibility but unused.
    """
    def __init__(self, weight=None, size_average=True):
        super(TverskyLoss, self).__init__()

    def forward(self, inputs, targets, smooth=1, alpha=ALPHA, beta=BETA):
        """
        inputs: raw (pre-sigmoid) scores, any shape
        targets: targets with the same number of elements as inputs
        smooth: additive smoothing constant
        alpha, beta: weights of the false-positive / false-negative terms
        returns: 0-dim loss tensor computed over all elements jointly
        """
        # Remove this sigmoid if the model already ends in a sigmoid or equivalent activation.
        inputs = torch.sigmoid(inputs)

        # Flatten label and prediction tensors; reshape also accepts non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        # Soft true positives, false positives and false negatives.
        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()

        Tversky = (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)

        return 1 - Tversky

class MiniUnet(nn.Module):
    """Fixed-depth U-Net (3 pooling steps) producing one raw-logit channel.

    Inputs whose height/width are not multiples of 16 are zero-padded on the
    bottom/right before the encoder. The output is cropped back to the input's
    spatial size, so output pixels line up with input pixels.
    """
    def __init__(self, first_size=64):
        super().__init__()
        # Encoder: channel width doubles at each level.
        self.block1 = self.block(3, first_size)
        self.block2 = self.block(first_size, first_size * 2)
        self.block3 = self.block(first_size * 2, first_size * 4)
        self.block4 = self.block(first_size * 4, first_size * 8)

        # Decoder blocks consume [upsampled, skip] concatenations.
        self.block5 = self.block(first_size * 8, first_size * 4)
        self.block6 = self.block(first_size * 4, first_size * 2)
        self.block7 = self.block(first_size * 2, first_size)
        self.pool = nn.MaxPool2d(2)

        self.up1 = self.up(first_size * 8)
        self.up2 = self.up(first_size * 4)
        self.up3 = self.up(first_size * 2)
        self.conv = nn.Conv2d(first_size, 1, kernel_size=1, bias=False)

    def block(self, in_channels, out_channels):
        """Two 3x3 conv -> BatchNorm -> ReLU stages; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def up(self, in_channels):
        """2x transposed-conv upsampling that halves the channel count."""
        return nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

    def cat(self, out_up, out_down):
        """Concatenate decoder and skip features along the channel axis."""
        return torch.cat((out_up, out_down), dim=1)

    def forward(self, x):
        h, w = x.shape[2], x.shape[3]
        pad = (0, (16 - w % 16) % 16, 0, (16 - h % 16) % 16)
        if sum(pad) > 0:
            x = F.pad(input=x, pad=pad, mode='constant', value=0.0)

        out1 = self.block1(x)
        out_pool1 = self.pool(out1)

        out2 = self.block2(out_pool1)
        out_pool2 = self.pool(out2)

        out3 = self.block3(out_pool2)
        out_pool3 = self.pool(out3)

        out4 = self.block4(out_pool3)

        out_up1 = self.up1(out4)
        out_cat1 = self.cat(out_up1, out3)
        out5 = self.block5(out_cat1)

        out_up2 = self.up2(out5)
        out_cat2 = self.cat(out_up2, out2)
        out6 = self.block6(out_cat2)

        out_up3 = self.up3(out6)
        out_cat3 = self.cat(out_up3, out1)
        out7 = self.block7(out_cat3)

        out = self.conv(out7)
        # Drop the padded rows/columns so the output matches the input size.
        return out[:, :, :h, :w]

class MaxiUnet(nn.Module):
    """Configurable-depth U-Net producing raw (un-activated) output channels.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        first_size: channel width of the first encoder block; doubles per level.
        layers: number of encoder levels (``layers - 1`` pooling steps).

    Inputs of any height/width are accepted. They are zero-padded on the
    bottom/right up to a multiple of ``2 ** (layers - 1)``, and the output is
    cropped back to the input size.
    """
    def __init__(self, in_channels=3, out_channels=1, first_size=64, layers=4):
        super().__init__()
        forward_blocks = [self.block(in_channels, first_size)]
        backward_blocks = []
        upscale_blocks = []
        size = first_size
        for i in range(1, layers):
            block = self.block(size, size * 2)
            forward_blocks.append(block)
            block = self.block(size * 2, size)
            backward_blocks.append(block)
            size *= 2
            block = self.up(size)
            upscale_blocks.append(block)
        # The deepest (bottleneck) block is not followed by pooling.
        forward_blocks[-1].use_pool = False
        # The decoder runs deepest level first.
        backward_blocks.reverse()
        upscale_blocks.reverse()
        self.forward_blocks = nn.ModuleList(forward_blocks)
        self.backward_blocks = nn.ModuleList(backward_blocks)
        self.upscale_blocks = nn.ModuleList(upscale_blocks)
        self.pool = nn.MaxPool2d(2)
        self.out_block = nn.Conv2d(first_size, out_channels, kernel_size=1, bias=False)

    def block(self, in_channels, out_channels):
        """Two 3x3 conv -> BatchNorm -> ReLU stages, flagged to pool afterwards."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        block.use_pool = True
        return block

    def up(self, in_channels):
        """2x transposed-conv upsampling that halves the channel count."""
        return nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)

    def cat(self, out_up, out_down):
        """Concatenate decoder and skip features along the channel axis."""
        return torch.cat((out_up, out_down), dim=1)

    def forward(self, x):
        h, w = x.shape[-2:]
        # Every pooling step must halve the size exactly, otherwise the
        # upsampled features no longer match their skip connections.
        multiple = 2 ** (len(self.forward_blocks) - 1)
        pad_h = (multiple - h % multiple) % multiple
        pad_w = (multiple - w % multiple) % multiple
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='constant', value=0.0)

        outs = []
        for block in self.forward_blocks:
            x = block(x)
            outs.append(x)
            if block.use_pool:
                x = self.pool(x)
        outs.pop()      # the bottleneck output is not a skip connection
        outs.reverse()  # deepest skip first, matching the decoder order
        for out, block, up in zip(outs, self.backward_blocks, self.upscale_blocks):
            x = up(x)
            x = block(self.cat(x, out))
        return self.out_block(x)[:, :, :h, :w]

# Build the 5-level U-Net with a first-block width of 112 channels and restore
# previously trained weights. The file holds a state_dict written with
# pickle.dump, so it is read back with pickle.load.
# NOTE(review): pickle.load can execute arbitrary code; only load weight files
# from a trusted source.
model = MaxiUnet(first_size=112, layers=5, out_channels=1)
with open('weights_5_112.pkl', 'rb') as f:
    model.load_state_dict(pickle.load(f))
model.to(device)
model

In [None]:
%%time
# Loss and optimiser shared by the training cells below.
criterion = SoftDiceLossV1()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=1e-4)
model.to(device)


In [None]:
%%time
model.train()

QUOTA = 1024 * 1024  # pixel budget per batch: batch_size ~= QUOTA / (size * size)
epochs = 50
metric = BinaryF1Score()
metric.to(device)

batch_count = 10
for size in (128, 256, 512, 1024):
    batch_size = QUOTA // size // size

    # One fixed set of crops per resolution, reused for every epoch.
    train_x, train_y = get_batch(batch_size * batch_count, size)
    train_x = train_x / 255
    # get_batch rounds each image's share separately, so it may return more or
    # fewer than batch_size * batch_count crops. Train on exactly what was
    # returned so the collected outputs always line up with `targets`.
    n_samples = train_x.shape[0]
    targets = torch.tensor(train_y, device=device).permute(0, 3, 1, 2)
    for epoch in range(1, epochs+1):
        ep_loss = 0
        outputs = []
        count = 0
        for start in range(0, n_samples, batch_size):
            stop = start + batch_size
            x = torch.tensor(train_x[start:stop], dtype=torch.float, device=device)
            y = torch.tensor(train_y[start:stop], dtype=torch.float, device=device)
            # NHWC (numpy layout) -> NCHW (PyTorch layout)
            x = x.permute(0, 3, 1, 2)
            y = y.permute(0, 3, 1, 2)

            optimizer.zero_grad()
            output = model(x)
            loss = criterion(output.squeeze(1), y.squeeze(1))
            loss.backward()
            optimizer.step()

            outputs.append(output.detach())
            ep_loss += loss.item()
            count += 1
        outputs = torch.cat(outputs)
        print(f"Epoch {epoch}: loss={ep_loss/count}, f1score={metric(outputs, targets)}")

In [None]:
%%time

# Make sure BatchNorm/Dropout are in training mode: the evaluation cell below
# calls model.eval(), and re-running this cell afterwards would otherwise
# train with BatchNorm in eval mode.
model.train()

QUOTA = 1024 * 1024  # pixel budget per batch: batch_size ~= QUOTA / (size * size)
epochs = 50
metric = BinaryF1Score()
metric.to(device)
for size in (128, 256):
    batch_size = QUOTA // size // size
    for epoch in range(1, epochs+1):
        # A fresh random batch of crops every epoch.
        train_x, train_y = get_batch(batch_size, size)
        train_x = train_x / 255
        x = torch.tensor(train_x, dtype=torch.float, device=device)
        y = torch.tensor(train_y, dtype=torch.float, device=device)
        # NHWC (numpy layout) -> NCHW (PyTorch layout)
        x = x.permute(0, 3, 1, 2)
        y = y.permute(0, 3, 1, 2)
        batch_x = x
        batch_y = y

        optimizer.zero_grad()
        output = model(batch_x)
        loss = criterion(output.squeeze(1), batch_y.squeeze(1))
        loss.backward()
        optimizer.step()
        print(f"Epoch {epoch}: loss={loss.item()}, f1score={metric(output.detach(), torch.tensor(train_y, device=device).permute(0,3,1,2))}")

In [None]:
# Show up to 10 random samples from the most recent training batch: input
# image, ground-truth mask, raw model output, and that output thresholded at 0.5.
# Images, masks and predictions all come from the same batch tensors (x, y,
# output), so the rows stay aligned even when that batch is only part of
# train_x / train_y.
cnt = min(10, x.shape[0])
indices = np.random.permutation(x.shape[0])[:cnt]

plot_x = x.detach().cpu().permute(0, 2, 3, 1).numpy()
plot_y = y.detach().cpu().permute(0, 2, 3, 1).numpy()
plot_y = plot_y.reshape(plot_y.shape[:3])
plot_p = output.detach().cpu().permute(0, 2, 3, 1).numpy()
plot_p = plot_p.reshape(plot_p.shape[:3])

# squeeze=False keeps `axes` 2-D even when a single row is drawn.
fig, axes = plt.subplots(cnt, 4, figsize=(20, 5 * cnt), squeeze=False)
for im, mask, pred, (ax_img, ax_mask, ax_pred, ax_round) in zip(plot_x[indices], plot_y[indices], plot_p[indices], axes):
    ax_img.imshow(im)
    ax_mask.imshow(mask)
    ax_pred.imshow(pred)
    # NOTE(review): `output` holds raw logits (the model ends in a plain conv
    # and the losses apply sigmoid themselves), so `> 0.5` corresponds to a
    # probability of about 0.62; confirm this threshold is intended.
    ax_round.imshow(np.array(pred > 0.5, dtype=np.uint8))

In [None]:
# Persist the trained weights as a pickled state_dict, the same format the
# model-construction cell reads back with pickle.load.
weights = model.state_dict()
with open('weights_5_112.pkl', 'wb') as weights_file:
    pickle.dump(weights, weights_file)


In [None]:
def gen_mask(lx, cx, rx, size):
    """Build the 1-D blending weights for one tile of width ``size``.

    The tile is centred at ``cx``; ``lx`` / ``rx`` are the centres of the
    neighbouring tiles, or ``None`` at an image border. Where the tile overlaps
    a neighbour, its weight ramps linearly: 0 -> 1 across the left overlap and
    1 -> 0 across the right overlap. Elsewhere the weight is 1. The neighbour's
    own mask ramps the opposite way over the same span, so the two weights sum
    to 1 in the overlap.

    Returns:
        float array of length ``size``.
    """
    mask = np.ones(size)
    half_size = size // 2
    if rx is not None:
        # Overlap with the right neighbour spans [rx - half_size, cx + half_size).
        overlap = (cx + half_size) - (rx - half_size)
        # Tiles that only touch (overlap <= 0) keep full weight; without this
        # guard `mask[-0:] = <empty>` raised a broadcasting error.
        if overlap > 0:
            mask[-overlap:] = np.linspace(1, 0, overlap)

    if lx is not None:
        # Overlap with the left neighbour spans [cx - half_size, lx + half_size).
        overlap = (lx + half_size) - (cx - half_size)
        if overlap > 0:
            mask[:overlap] = np.linspace(0, 1, overlap)
    return mask

def _tile_centers(lo, hi, half_size):
    """Tile centres from ``lo`` to ``hi`` inclusive, at most one tile width apart."""
    span = hi - lo
    if span <= 0:
        # The image is exactly one tile wide along this axis.
        return [lo]
    # About span // half_size gaps (at least one), evenly spread with integer rounding.
    count = max(span // half_size, 1)
    step = span / count
    centers = [lo + int(round(i * step)) for i in range(count)]
    centers.append(hi)
    return centers


def process_image(model, img, size):
    """Run ``model`` over ``img`` in overlapping tiles and blend the results.

    Tiles of ``size`` x ``size`` pixels are placed so they cover the whole
    image. Each tile's output is weighted with linear ramps (``gen_mask``) and
    the weighted tiles are summed, so overlapping regions cross-fade smoothly.

    Args:
        model: network called on (N, C, size, size) float batches; channel 0
            of its output is used. Assumes the output has the same spatial size
            as the input (TODO confirm for other models).
        img: H x W x C image whose values are divided by 255 before inference.
        size: tile side length (expected to be even).

    Returns:
        float32 array of shape (H, W) holding the blended raw model outputs
        (no sigmoid is applied here).
    """
    orig_h, orig_w = img.shape[:2]
    # Images smaller than one tile are zero-padded (bottom/right) up to the
    # tile size; the padding is cropped off the result at the end.
    pad_h = max(size - orig_h, 0)
    pad_w = max(size - orig_w, 0)
    if pad_h or pad_w:
        img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)))

    half_size = size // 2
    center_x = _tile_centers(half_size, img.shape[1] - half_size, half_size)
    center_y = _tile_centers(half_size, img.shape[0] - half_size, half_size)

    # Row-major tile order: all tiles of the first row, then the next row, ...
    tiles = np.array([
        img[cy - half_size:cy + half_size, cx - half_size:cx + half_size, :]
        for cy in center_y
        for cx in center_x
    ])

    QUOTA = 1024 * 1024
    batch_size = max(QUOTA // size // size, 1)
    outputs = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, tiles.shape[0], batch_size):
            x = torch.tensor(tiles[i:i + batch_size], dtype=torch.float, device=device) / 255
            x = x.permute(0, 3, 1, 2)
            outputs.append(model(x).cpu().numpy())
    outputs = np.vstack(outputs)[:, 0]  # (tiles, size, size), first output channel

    res = np.zeros(img.shape[:2], dtype=np.float32)
    tile = 0
    for row, cy in enumerate(center_y):
        # Blend the tiles of this row horizontally into one full-width strip.
        line = np.zeros((size, img.shape[1]))
        for col, cx in enumerate(center_x):
            lx = center_x[col - 1] if col > 0 else None
            rx = center_x[col + 1] if col + 1 < len(center_x) else None
            weights_x = gen_mask(lx, cx, rx, size).reshape((1, size))
            # Accumulate (not assign) so overlapping tiles are summed with
            # complementary weights instead of the later tile overwriting the
            # earlier one.
            line[:, cx - half_size:cx + half_size] += outputs[tile] * weights_x
            tile += 1

        # Blend the strips vertically in the same way.
        ly = center_y[row - 1] if row > 0 else None
        ry = center_y[row + 1] if row + 1 < len(center_y) else None
        weights_y = gen_mask(ly, cy, ry, size).reshape((size, 1))
        res[cy - half_size:cy + half_size, :] += weights_y * line

    return res[:orig_h, :orig_w]


# Evaluate on the training images: tile-predict each one with 256-px tiles,
# threshold, print the F1 score against its ground-truth mask, and plot
# image / ground-truth mask / predicted mask.
model.eval()

cnt = len(data)
fig, axes = plt.subplots(cnt*3, 1, figsize=(15,10*3*cnt))
for i in range(cnt):
    res = process_image(model, data[i][0], 256)
    # NOTE(review): process_image returns the model's raw outputs, which are
    # logits here (the model ends in a plain conv and the losses apply sigmoid),
    # so `> 0.5` means a probability of about 0.62 rather than 0.5; confirm the
    # threshold is intended.
    mask = np.array(res > 0.5, dtype=np.uint8)
    # NOTE(review): the ground truth is passed as the first argument and the
    # prediction second; presumably BinaryF1Score expects (preds, target) —
    # verify. For binary 0/1 inputs swapping them leaves F1 unchanged
    # (FP and FN trade places).
    print(f"train_mask_{i:03d}.png", BinaryF1Score()(torch.tensor(data[i][1]), torch.tensor(mask)).item())
    axes[3*i].imshow(data[i][0])
    axes[3*i + 1].imshow(data[i][1])
    axes[3*i + 2].imshow(np.array(mask>0, dtype=np.uint8))


In [None]:
from glob import glob

# Predict a binary mask for every test image and save it under RESULT_DIR.
TEST_IMAGES_GLOB = 'datasets/images/test_image_*.png'
RESULT_DIR = 'result'
os.makedirs(RESULT_DIR, exist_ok=True)  # saving fails if the directory is missing

test_data = []
files = glob(TEST_IMAGES_GLOB)
files.sort()
if not files:
    raise FileNotFoundError(f"no test images matched {TEST_IMAGES_GLOB}")
for im_file in files:
    name = os.path.basename(im_file)
    im_cv = np.array(Image.open(im_file))
    # The network is built with 3 input channels: replicate single-channel
    # images and drop extra channels such as alpha.
    if im_cv.ndim == 2:
        im_cv = np.repeat(im_cv[:, :, None], 3, axis=2)
    elif im_cv.shape[2] < 3:
        im_cv = np.repeat(im_cv[:, :, :1], 3, axis=2)
    elif im_cv.shape[2] > 3:
        im_cv = im_cv[:, :, :3]
    test_data.append((name, im_cv))

# Inference must not update BatchNorm statistics, even if a training cell ran last.
model.eval()

cnt = len(test_data)
# squeeze=False keeps `axes` 2-D even for a single test image.
fig, axes = plt.subplots(cnt, 2, figsize=(15,7*cnt), squeeze=False)
for i, (name, image) in enumerate(test_data):
    res = process_image(model, image, 512)
    # NOTE(review): `res` holds blended raw logits, so `> 0.5` is a probability
    # threshold of about 0.62; confirm this is intended.
    mask = np.array(res > 0.5, dtype=np.uint8)
    axes[i][0].imshow(image)
    # Name the output after its source file (test_image_XXX -> test_mask_XXX)
    # rather than its position in the sorted list, so a gap in the numbering
    # cannot pair a mask with the wrong image.
    mask_name = name.replace('test_image_', 'test_mask_', 1)
    Image.fromarray(mask).save(os.path.join(RESULT_DIR, mask_name))
    axes[i][1].imshow(np.array(mask>0, dtype=np.uint8))