In [None]:
!pip install torchsummary
!pip install torchgeometry
!pip install torchvision
from torchsummary import summary
from torchgeometry.losses import one_hot

In [None]:
import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
import time
import imageio
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Resize, PILToTensor, ToPILImage, Compose, InterpolationMode, Normalize
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from collections import OrderedDict
import wandb
import random
import copy
import torchvision

In [None]:
!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Training configuration shared by the cells below.
batch_size = 16                                    # samples per DataLoader batch
learning_rate = 1e-4                               # step size handed to the optimizer
display_step = 50                                  # log training progress every N batches
checkpoint_path = "/kaggle/working/unet_model.pth"  # where the best checkpoint is written

In [None]:
# Images are resized with bicubic interpolation.
image_transform = Compose(
    [
        Resize(
            size=(224, 224),
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ),
    ]
)

# Masks are resized with nearest-neighbour interpolation.
mask_transform = Compose(
    [
        Resize(
            size=(224, 224),
            interpolation=InterpolationMode.NEAREST_EXACT,
        ),
    ]
)

In [None]:
# Photometric one-of group (wrapped with p=0.5).
_photometric_ops = A.OneOf(
    [
        A.RandomGamma(gamma_limit=(50, 120), p=1),
        A.CLAHE(clip_limit=2),
        A.Sharpen(p=1),
        A.Equalize(p=1),
        A.ColorJitter(brightness=0.5, contrast=0.5, saturation=2, hue=0.05, p=1),
    ],
    p=0.5,
)

# Geometric one-of group (wrapped with p=0.5). Rotation fills exposed borders
# with 0 in both the image and the mask.
_geometric_ops = A.OneOf(
    [
        A.Perspective(keep_size=True, fit_output=False, p=1),
        A.Rotate(limit=15, p=1, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
    ],
    p=0.5,
)

# Training-time augmentation pipeline; DataClass calls it with image= and mask=.
A_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Transpose(p=0.5),
        _photometric_ops,
        _geometric_ops,
        A.ElasticTransform(p=0.5),
        A.RandomSunFlare(
            num_flare_circles_lower=0,
            num_flare_circles_upper=1,
            src_radius=10,
            p=0.3,
        ),
    ]
)

In [None]:
class DataClass(Dataset):
    """Paired image / mask dataset for training and validation.

    Args:
        images_list: paths of the input images.
        masks_list: paths of the colour-coded masks; ``masks_list[i]`` must
            belong to ``images_list[i]``.
        image_transform: callable applied to the opened PIL image.
        mask_transform: callable applied to the opened PIL mask.
        train: when true, the module-level ``A_transform`` augmentation is
            applied jointly to the image and its mask.

    ``__getitem__`` returns ``(data, label)`` where ``data`` is a normalised
    float32 ``(3, H, W)`` tensor and ``label`` is an int64 ``(H, W)`` tensor of
    class indices.
    """

    def __init__(self, images_list, masks_list, image_transform, mask_transform, train=False):
        super(DataClass, self).__init__()
        self.images_list = images_list
        self.masks_list = masks_list
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.train = train

    def __getitem__(self, index):
        img_path = self.images_list[index]
        mask_path = self.masks_list[index]

        # Force three channels: Normalize and the label decoding below both
        # index channel 2, so grayscale/RGBA/palette files would otherwise
        # crash. The context managers close the file handles promptly.
        with Image.open(img_path) as img_file:
            data = self.image_transform(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            label = self.mask_transform(mask_file.convert("RGB"))

        data = np.array(data)
        label = np.array(label)

        if self.train:
            # Image and mask go through the same random augmentation together.
            A_transformed = A_transform(image=data, mask=label)
            data = A_transformed['image']
            label = A_transformed['mask']

        data = Image.fromarray(data)
        label = Image.fromarray(label)

        data = PILToTensor()(data).type(torch.float32)
        label = PILToTensor()(label) / 255

        # NOTE(review): `data` is still in the 0..255 range here, while these
        # mean/std values look like statistics for 0..1-scaled inputs. The
        # test-time dataset in this notebook does the same, so train and test
        # are consistent - confirm whether a /255 was intended in both places.
        data = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(data)

        # Decode the colour mask into class indices: binarise each channel,
        # then give channel 2 a tiny constant so argmax picks it only where
        # channels 0 and 1 are both off. Result: 0 where channel 0 is on,
        # 1 where only channel 1 is on, 2 everywhere else.
        label = torch.where(label > 0.65, 1.0, 0.0)
        label[2, :, :] = 0.0001
        label = torch.argmax(label, 0).type(torch.int64)

        return data, label

    def __len__(self):
        return len(self.images_list)

In [None]:
# Fractions of the dataset used for training / validation.
# (Only train_size is read by the split logic visible in this notebook.)
train_size, valid_size = 0.8, 0.2

In [None]:
ENCODER = 'resnet152'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = [0,1,2]
# The network must emit raw logits: dice_score() applies softmax itself and
# nn.CrossEntropyLoss applies log-softmax internally. With 'softmax2d' here the
# outputs were normalised twice, which squashes the logits and weakens the
# gradients. argmax-based predictions are unaffected by this change.
ACTIVATION = None

model = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

In [None]:
images_path = "/kaggle/input/bkai-igh-neopolyp/train/train/"
masks_path = "/kaggle/input/bkai-igh-neopolyp/train_gt/train_gt/"

# os.listdir returns entries in arbitrary order, so sort both listings to make
# images_list[i] and masks_list[i] line up.
# NOTE(review): this assumes each mask file shares its name with its image -
# confirm against the dataset layout.
images_list = sorted(os.listdir(images_path))
masks_list = sorted(os.listdir(masks_path))
assert len(images_list) == len(masks_list), "every image needs exactly one mask"

images_list = [images_path + image_name for image_name in images_list]
masks_list = [masks_path + mask_name for mask_name in masks_list]

# Shuffle the (image, mask) PAIRS and split the shuffled pairs. Previously the
# shuffled list was discarded and the split was taken from the unshuffled
# listings. A dedicated seeded RNG keeps the split reproducible across runs.
combined_list = list(zip(images_list, masks_list))
random.Random(42).shuffle(combined_list)

n_train = int(train_size * len(combined_list))
train_pairs = combined_list[:n_train]
valid_pairs = combined_list[n_train:]

train_images_list = [image for image, _ in train_pairs]
train_masks_list = [mask for _, mask in train_pairs]
valid_images_list = [image for image, _ in valid_pairs]
valid_masks_list = [mask for _, mask in valid_pairs]

train_set = DataClass(train_images_list, train_masks_list, image_transform, mask_transform, True)
valid_set = DataClass(valid_images_list, valid_masks_list, image_transform, mask_transform, False)

train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Validation metrics do not depend on batch order, so no shuffling is needed.
valid_dataloader = DataLoader(valid_set, batch_size=batch_size, shuffle=False)

In [None]:
def dice_score(input, target, smooth, weights):
    """Class-weighted soft Dice score averaged over the batch.

    Args:
        input: network output of shape (B, C, H, W); softmax is applied over
            dim 1 inside this function.
        target: integer class map passed to ``one_hot`` with ``C`` classes.
        smooth: constant added to the denominator.
        weights: per-class weights multiplied into the (B, C) Dice matrix
            before summing over the class axis.

    Returns:
        Scalar tensor: mean over the batch of the weighted per-sample score.
    """
    probs = F.softmax(input, dim=1)

    # One-hot encode the labels so they can be compared channel by channel.
    n_classes = input.shape[1]
    target_one_hot = one_hot(target, num_classes=n_classes,
                             device=input.device, dtype=input.dtype)

    # Reduce over the two spatial axes, keeping (batch, class).
    spatial_axes = (2, 3)
    overlap = (probs * target_one_hot).sum(dim=spatial_axes)
    total = (probs + target_one_hot).sum(dim=spatial_axes)

    per_class = (2. * overlap) / (total + smooth)
    per_sample = (per_class * weights).sum(dim=1)

    return per_sample.mean()

class CEDiceLoss(nn.Module):
    """Weighted cross-entropy plus (1 - weighted Dice) loss.

    Args:
        weights: per-class weights of shape (1, C); they must sum to 1.

    ``forward`` expects ``input`` as raw scores of shape (B, C, H, W) and
    ``target`` as an integer class map whose last two dims match ``input``.

    Raises:
        TypeError: ``input`` is not a tensor.
        ValueError: shapes, devices, or weights are inconsistent.
    """

    def __init__(self, weights) -> None:
        super(CEDiceLoss, self).__init__()
        self.eps: float = 1e-6
        self.weights: torch.Tensor = weights

    def forward(
            self,
            input: torch.Tensor,
            target: torch.Tensor) -> torch.Tensor:

        if not torch.is_tensor(input):
            raise TypeError("Input type is not a torch.Tensor. Got {}"
                            .format(type(input)))
        if not len(input.shape) == 4:
            raise ValueError("Invalid input shape, we expect BxNxHxW. Got: {}"
                             .format(input.shape))
        if not input.shape[-2:] == target.shape[-2:]:
            # Report both shapes; the old message printed input.shape only.
            raise ValueError("input and target shapes must be the same. Got: {} and {}"
                             .format(input.shape, target.shape))
        if not input.device == target.device:
            raise ValueError(
                "input and target must be in the same device. Got: {} and {}" .format(
                    input.device, target.device))
        if not self.weights.shape[1] == input.shape[1]:
            raise ValueError("The number of weights must equal the number of classes")
        # Compare with a tolerance: a float32 sum of weights that add up to 1
        # on paper is not guaranteed to equal 1 exactly.
        if abs(torch.sum(self.weights).item() - 1.0) > 1e-6:
            raise ValueError("The sum of all weights must equal 1")

        # Cross-entropy expects a 1-D weight vector of length C, while
        # self.weights is stored as (1, C) for broadcasting in dice_score.
        celoss = nn.functional.cross_entropy(
            input, target, weight=self.weights.reshape(-1))

        dicescore = dice_score(input, target, self.eps, self.weights)

        return 1 - dicescore + celoss

In [None]:
def save_model(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path`` as one checkpoint.

    The file holds a dict with the keys ``"model"`` and ``"optimizer"``.
    """
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )
    print('model saved')

def load_model(model, optimizer, path):
    """Restore model and optimizer state from a checkpoint written by save_model.

    Args:
        model: module whose parameters are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        path: checkpoint file holding the ``"model"`` and ``"optimizer"`` keys.

    Returns:
        The same ``(model, optimizer)`` objects, for convenience.
    """
    # Deserialize onto the CPU so a checkpoint saved on a GPU machine can be
    # opened on a CPU-only one; load_state_dict then copies the values into the
    # parameters wherever they currently live.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print('model loaded')
    return model, optimizer

In [None]:
def train(train_dataloader,
          valid_dataloader,
          epoch, display_step):
    """Run one training epoch followed by one validation pass.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_function``,
    ``weights`` and ``device``.

    Args:
        train_dataloader: iterable of ``(data, targets)`` training batches.
        valid_dataloader: iterable of ``(data, targets)`` validation batches.
        epoch: zero-based epoch index, used only for log messages.
        display_step: print training progress every ``display_step`` batches.

    Returns:
        ``(train_loss, valid_loss, train_dice, valid_dice)``, each averaged
        over the number of batches of its own loader (0 for an empty loader).
    """
    start_time = time.time()
    train_loss_epoch = 0
    test_loss_epoch = 0
    train_dice_score_epoch = 0
    test_dice_score_epoch = 0

    train_batches = 0
    model.train()
    for i, (data, targets) in enumerate(train_dataloader):

        data, targets = data.to(device), targets.to(device).long()

        optimizer.zero_grad()
        outputs = model(data)

        loss = loss_function(outputs, targets)
        loss.backward()

        optimizer.step()

        # The Dice score is only a metric here; keep it out of autograd.
        with torch.no_grad():
            score = dice_score(outputs.detach(), targets, 1e-6, weights)
        train_loss_epoch += loss.item()
        train_dice_score_epoch += score.item()
        train_batches += 1
        if (i+1) % display_step == 0:
            print('Train Epoch: {} [{}/{} ({}%)]\tLoss: {:.4f}\tScore: {:.4f}'.format(
                epoch + 1, (i+1) * len(data), len(train_dataloader.dataset), 100 * (i+1) * len(data) / len(train_dataloader.dataset), loss.item(), score.item()
            ))

    if train_batches:
        train_loss_epoch /= train_batches
        train_dice_score_epoch /= train_batches
    print(f"Done epoch #{epoch+1}, time for this epoch: {time.time()-start_time}s")

    valid_batches = 0
    model.eval()
    with torch.no_grad():
        for data, target in valid_dataloader:
            data, target = data.to(device), target.to(device).long()
            test_output = model(data)
            test_loss = loss_function(test_output, target)
            test_score = dice_score(test_output, target, 1e-6, weights)
            test_loss_epoch += test_loss.item()
            test_dice_score_epoch += test_score.item()
            valid_batches += 1

    # Average over the VALIDATION batch count. The previous code divided by the
    # training loop's `i + 1`, which under-reported validation loss and score
    # whenever the two loaders had different lengths.
    if valid_batches:
        test_loss_epoch /= valid_batches
        test_dice_score_epoch /= valid_batches

    return train_loss_epoch , test_loss_epoch, train_dice_score_epoch, test_dice_score_epoch

In [None]:
# Wrap only once: re-running this cell used to nest DataParallel wrappers,
# which adds another "module." prefix to every state_dict key and breaks
# checkpoint loading.
if not isinstance(model, nn.DataParallel):
    model = nn.DataParallel(model)
model.to(device)
summary(model, (3, 224, 224))

In [None]:
# Per-class weights, shape (1, C), shared by the loss and the Dice metric.
# Created directly on `device` instead of calling .cuda(), so the notebook
# also runs when the CPU fallback was selected.
weights = torch.tensor([[0.35, 0.58, 0.07]], device=device)
loss_function = CEDiceLoss(weights)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Never paste the API key into the notebook: read it from the environment.
# When WANDB_API_KEY is unset this passes key=None and wandb uses its own
# credential lookup instead.
wandb.login(
    key = os.environ.get("WANDB_API_KEY")
)

wandb.init(
    project = "PolypSegment_2",
)

In [None]:
# Train for a fixed number of epochs; checkpoint whenever the validation Dice
# score beats the best one seen so far, and log all four metrics to wandb.
last_score = 0
epochs = 50
for epoch in range(epochs):
    (
        train_loss_epoch,
        test_loss_epoch,
        train_dice_score_epoch,
        test_dice_score_epoch,
    ) = train(train_dataloader, valid_dataloader, epoch, display_step)

    if last_score < test_dice_score_epoch:
        save_model(model, optimizer, checkpoint_path)
        last_score = test_dice_score_epoch

    wandb.log(
        {
            "Train loss": train_loss_epoch,
            "Valid loss": test_loss_epoch,
            "Train score": train_dice_score_epoch,
            "Valid score": test_dice_score_epoch,
        }
    )
    

In [None]:
# Take one training batch and compare ground truth with the model prediction.
img, mask = next(iter(train_dataloader))

model.eval()
with torch.no_grad():
    predict = model(img.to(device))

# Fix the channel count explicitly: F.one_hot otherwise infers it from the
# largest label present, so a mask without the highest class would produce a
# 1- or 2-channel array that imshow cannot render as RGB.
num_classes = predict.shape[1]
pred_labels = torch.argmax(predict, dim=1).cpu()

# Undo the dataset's Normalize step (applied on 0..255 values) so imshow gets
# floats in [0, 1] instead of clipping almost every pixel.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img_vis = ((img * norm_std + norm_mean) / 255.0).clamp(0, 1)

# Never ask for more rows than the batch holds.
n_rows = min(4, img.shape[0])
fig, arr = plt.subplots(n_rows, 3, figsize=(16, 12), squeeze=False)
arr[0][0].set_title('Image')
arr[0][1].set_title('Segmentation')
arr[0][2].set_title('Predict')

for i in range(n_rows):
    arr[i][0].imshow(img_vis[i].permute(1, 2, 0));

    arr[i][1].imshow(F.one_hot(mask[i], num_classes=num_classes).float())

    arr[i][2].imshow(F.one_hot(pred_labels[i], num_classes=num_classes).float())

In [None]:
# Test-time preprocessing: bicubic resize to 224x224, then PIL -> tensor.
transform = Compose(
    [
        Resize(
            size=(224, 224),
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ),
        PILToTensor(),
    ]
)

class UNetTestDataClass(Dataset):
    """Unlabelled test-image dataset.

    Args:
        images_path: directory holding the test images (with or without a
            trailing separator).
        transform: callable turning the opened PIL image into a tensor.

    ``__getitem__`` returns ``(data, img_path, h, w)``: the normalised float32
    tensor, the source path, and the ORIGINAL image height and width.
    """

    def __init__(self, images_path, transform):
        super(UNetTestDataClass, self).__init__()

        # sorted(): os.listdir order is arbitrary; a stable order makes indices
        # reproducible. os.path.join: works whether or not images_path ends in
        # a separator (plain concatenation required the trailing slash).
        images_list = sorted(os.listdir(images_path))
        images_list = [os.path.join(images_path, name) for name in images_list]

        self.images_list = images_list
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.images_list[index]
        # Force RGB so Normalize always sees three channels, and close the
        # file handle as soon as the pixels are loaded.
        with Image.open(img_path) as img_file:
            data = img_file.convert("RGB")
        # PIL reports size as (width, height).
        w, h = data.size

        data = self.transform(data)
        data = data.type(torch.float32)
        # Same normalisation constants as the training dataset in this notebook.
        data = Normalize(mean = (0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))(data)

        return data, img_path, h, w

    def __len__(self):
        return len(self.images_list)

In [None]:
# A dedicated name for the directory: the old loop variable `path` overwrote
# the directory string, so re-running this cell failed.
test_images_path = '/kaggle/input/bkai-igh-neopolyp/test/test/'
unet_test_dataset = UNetTestDataClass(test_images_path, transform)
# Inference does not depend on sample order, so no shuffling is needed.
test_dataloader = DataLoader(unet_test_dataset, batch_size=8, shuffle=False)

# Only the image tensor of the first batch is needed for the preview.
img = next(iter(test_dataloader))[0]

model.eval()
with torch.no_grad():
    predict = model(img.to(device))

# Explicit channel count: F.one_hot would otherwise infer it from the largest
# predicted label and could return fewer than three channels.
num_classes = predict.shape[1]
pred_labels = torch.argmax(predict, dim=1).cpu()

# Undo the dataset's Normalize step (applied on 0..255 values) so imshow gets
# floats in [0, 1] instead of clipping almost every pixel.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
img_vis = ((img * norm_std + norm_mean) / 255.0).clamp(0, 1)

# Never ask for more rows than the batch holds.
n_rows = min(5, img.shape[0])
fig, arr = plt.subplots(n_rows, 2, figsize=(16, 12), squeeze=False)
arr[0][0].set_title('Image');
arr[0][1].set_title('Predict');

for i in range(n_rows):
    arr[i][0].imshow(img_vis[i].permute(1, 2, 0));
    arr[i][1].imshow(F.one_hot(pred_labels[i], num_classes=num_classes).float())

In [None]:
# Predict every test image and save the mask as a PNG at the original size.
model.eval()
# exist_ok avoids the check-then-create race of isdir() + mkdir().
os.makedirs("/kaggle/working/predicted_masks", exist_ok=True)
for batch_images, batch_paths, batch_heights, batch_widths in test_dataloader:
    with torch.no_grad():
        predicted_mask = model(batch_images.to(device))

    # Always encode with the model's class count. Without num_classes,
    # F.one_hot sizes the last axis from the largest predicted label, so an
    # image predicted without the highest class was saved with the wrong
    # number of channels.
    num_classes = predicted_mask.shape[1]
    pred_labels = torch.argmax(predicted_mask, dim=1).cpu()

    for i in range(len(batch_paths)):
        image_id = batch_paths[i].split('/')[-1].split('.')[0]
        filename = image_id + ".png"
        one_hot_mask = F.one_hot(pred_labels[i], num_classes=num_classes).permute(2, 0, 1).float()
        # Nearest-neighbour resize back to the source resolution keeps the
        # mask colours pure (no blended class values).
        restore_size = Resize((batch_heights[i].item(), batch_widths[i].item()),
                              interpolation=InterpolationMode.NEAREST_EXACT)
        mask2img = restore_size(ToPILImage()(one_hot_mask))
        mask2img.save(os.path.join("/kaggle/working/predicted_masks/", filename))

In [None]:
def rle_to_string(runs):
    """Join run-length values into a single space-separated string."""
    return ' '.join(map(str, runs))

def rle_encode_one_mask(mask):
    """Run-length encode one mask channel as a 'start length ...' string.

    Every non-zero pixel counts as foreground. Starts are 1-based positions in
    the flattened array. An all-zero mask yields an empty string.
    """
    flat = mask.flatten()          # flatten() copies, so the caller's mask is untouched
    flat[flat > 0] = 255

    # Foreground touching either end has no transition there, so add a zero
    # pixel on both sides to make every run start and end detectable.
    touches_edge = bool(flat[0] or flat[-1])
    if touches_edge:
        flat = np.pad(flat, 1)

    # Positions where the value changes, converted to 1-based indices; the
    # padding shifts everything by one, hence the smaller offset.
    offset = 1 if touches_edge else 2
    runs = np.flatnonzero(flat[1:] != flat[:-1]) + offset

    # Entries alternate (start, end); turn each end into a run length.
    runs[1::2] -= runs[:-1:2]
    return rle_to_string(runs)

def mask2string(dir):
    """Run-length encode the first two RGB channels of every mask in ``dir``.

    Args:
        dir: directory containing the predicted mask images.

    Returns:
        Dict with two parallel lists: ``'ids'`` holds ``"<name>_<channel>"``
        for channel 0 and 1 of each file (``<name>`` is the file name up to
        its first dot), and ``'strings'`` holds the matching RLE strings.

    Raises:
        ValueError: a directory entry cannot be decoded as an image.
    """
    ids = []
    strings = []
    # sorted(): os.listdir order is arbitrary; keep the output deterministic.
    for file_name in sorted(os.listdir(dir)):
        image_name = file_name.split('.')[0]
        path = os.path.join(dir, file_name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None; fail with a clear
            # message instead of "'NoneType' object is not subscriptable".
            raise ValueError(f"Could not read mask image: {path}")
        img = img[:, :, ::-1]  # OpenCV loads BGR; flip to RGB channel order
        for channel in range(2):
            ids.append(f'{image_name}_{channel}')
            strings.append(rle_encode_one_mask(img[:, :, channel]))
    return {
        'ids': ids,
        'strings': strings,
    }

MASK_DIR_PATH = '/kaggle/working/predicted_masks' # change this to the path to your output mask folder
# Pass the constant directly: the old `dir = MASK_DIR_PATH` shadowed the
# built-in dir() for the rest of the notebook session.
res = mask2string(MASK_DIR_PATH)
# Build the submission table in one step from the two parallel lists.
df = pd.DataFrame({'Id': res['ids'], 'Expected': res['strings']})
df.to_csv(r'output.csv', index=False)