In [None]:

import os
import pandas as pd
import numpy as np
from PIL import Image
import cv2
import time
import imageio
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.transforms import Resize, PILToTensor, ToPILImage, Compose, InterpolationMode
from collections import OrderedDict
import wandb
import torchvision.transforms as transforms
import random
import albumentations as A
import timm
from albumentations.pytorch import ToTensorV2

In [None]:
# IPython shell escape: run `nvidia-smi -L` to list the GPUs visible to this session.
!nvidia-smi -L

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# --- Experiment configuration ---

# Number of classes predicted per pixel.
num_classes = 3

# Epoch budget.
epochs = 40

# Optimisation settings.
learning_rate = 1e-4
batch_size = 4
display_step = 50

# Checkpoint locations: file to write, and a previously trained model.
checkpoint_path = '/kaggle/working/resunet_model.pth'
pretrained_path = "/kaggle/input/abcdef/resunet_model.pth"

# History containers for loss / accuracy curves (each name gets its own list).
loss_epoch_array, train_accuracy, test_accuracy, valid_accuracy = [], [], [], []

In [None]:
# Training pipeline: resize, random horizontal/vertical flips, occasional blur,
# then conversion to a torch tensor.
transform = A.Compose(
    [
        A.Resize(height=800, width=1120),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Blur(blur_limit=(2, 6), p=0.2),
        ToTensorV2(),
    ]
)

# Deterministic pipelines: resize and tensor conversion only.
test_transform = A.Compose([A.Resize(height=800, width=1120), ToTensorV2()])
label_transform = A.Compose([A.Resize(height=800, width=1120), ToTensorV2()])


In [None]:
class UNetDataClass(Dataset):
    """Paired image / mask dataset for 3-class segmentation.

    Args:
        images_path: Directory containing the input images.
        masks_path: Directory containing the colour-coded masks. The i-th file
            (in sorted order) is taken as the mask of the i-th image, so both
            directories are expected to use matching file names.
        transform: Callable invoked as ``transform(image=..., mask=...)`` that
            returns a dict with ``"image"`` and ``"mask"`` tensors.

    Raises:
        ValueError: If the two directories hold a different number of files.

    Each item is ``(data, label)``: ``data`` is the transformed image divided by
    255 and ``label`` is an int64 class-index map.
    """

    def __init__(self, images_path, masks_path, transform):
        super().__init__()

        # os.listdir returns entries in arbitrary order. Sort both listings so
        # that an image is always paired with the mask of the same rank.
        # os.path.join also works when the directory has no trailing separator.
        self.images_list = [
            os.path.join(images_path, name) for name in sorted(os.listdir(images_path))
        ]
        self.masks_list = [
            os.path.join(masks_path, name) for name in sorted(os.listdir(masks_path))
        ]
        if len(self.images_list) != len(self.masks_list):
            raise ValueError(
                "Number of images ({}) and masks ({}) differ".format(
                    len(self.images_list), len(self.masks_list)
                )
            )
        self.transform = transform

    def __getitem__(self, index):
        # Force 3-channel RGB so grayscale / RGBA / palette files keep the
        # channel layout the rest of the pipeline indexes into.
        with Image.open(self.images_list[index]) as image_file:
            data = np.array(image_file.convert("RGB"))
        with Image.open(self.masks_list[index]) as mask_file:
            label = np.array(mask_file.convert("RGB"))

        augmented = self.transform(image=data, mask=label)

        # Scale 0..255 values to 0..1.
        data = augmented["image"] / 255
        label = augmented["mask"] / 255

        # Binarise every mask channel (last axis), then overwrite channel 2 with a
        # small constant: argmax over the channel axis yields 0 or 1 where one of
        # the first two channels is set, and 2 where neither is.
        label = torch.where(label > 0.1, 1.0, 0.0)
        label[:, :, 2] = 0.0001
        label = torch.argmax(label, 2).type(torch.int64)
        return data, label

    def __len__(self):
        return len(self.images_list)

In [None]:
# Kaggle input folders holding the training images and their ground-truth masks.
# Both end with "/" so that file names can be appended directly.
images_path = "/kaggle/input/bkai-igh-neopolyp/train/train/"
masks_path =  "/kaggle/input/bkai-igh-neopolyp/train_gt/train_gt/"

In [None]:
# Full labelled dataset; every sample is passed through `transform` on access.
unet_dataset = UNetDataClass(images_path, masks_path, transform)

In [None]:
# Fetch one sample (index 21) to sanity-check an image / mask pair.
img, mask = unet_dataset[21]

In [None]:
# Show the sampled image next to its class-index mask, using explicit Axes and
# titles so the figure is readable on its own.
fig, (ax_img, ax_mask) = plt.subplots(1, 2)
ax_img.imshow(img.permute(1, 2, 0))  # move the leading channel axis last for matplotlib
ax_img.set_title("Image")
ax_mask.imshow(mask)
ax_mask.set_title("Mask")
plt.show()

In [None]:
# Fractions of the dataset assigned to the training and validation subsets.
train_size = 0.85
valid_size = 0.15

In [None]:
# Derive the validation length from the training length so the two always add up
# to len(unet_dataset). Truncating both fractions independently can lose a sample,
# in which case random_split raises because the lengths do not sum to the dataset size.
n_total = len(unet_dataset)
n_train = int(train_size * n_total)
n_valid = n_total - n_train
train_set, valid_set = random_split(unet_dataset, [n_train, n_valid])

In [None]:
# Display the number of training samples produced by the split.
len(train_set)

In [None]:
# Training batches are reshuffled on every pass; validation order stays fixed.
# NOTE(review): valid_set is a subset of unet_dataset, so it goes through the same
# `transform` (random flips / blur) as the training data - confirm this is intended.
train_dataloader = DataLoader(dataset=train_set, shuffle=True, batch_size=batch_size)
valid_dataloader = DataLoader(dataset=valid_set, shuffle=False, batch_size=batch_size)

In [None]:
def unet_block(in_channels, out_channels):
    """Build a two-stage ``Conv2d(3x3) -> ReLU`` stack.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both use stride 1 and padding 1, so the spatial size
    of the input is preserved.

    Returns:
        nn.Sequential with layers ``[Conv2d, ReLU, Conv2d, ReLU]``.
    """
    layers = []
    for stage_in in (in_channels, out_channels):
        layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1))
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)

class ResUnet(nn.Module):
    """U-Net style decoder on top of a pretrained timm ``resnet50`` backbone.

    Args:
        n_classes: Number of output channels (one logit map per class).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        # features_only=True: the backbone returns its intermediate feature maps
        # (five of them are unpacked in forward).
        self.backbone = timm.create_model("resnet50", pretrained=True, features_only=True)
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear")
        # Bottleneck applied to the deepest feature map.
        self.block_neck = unet_block(2048, 1024)
        # Decoder stages: (skip channels + upsampled channels) -> output channels.
        self.block_up1 = unet_block(1024 + 1024, 512)
        self.block_up2 = unet_block(512 + 512, 256)
        self.block_up3 = unet_block(256 + 256, 128)
        self.block_up4 = unet_block(128 + 64, 64)
        # 1x1 convolution producing the per-class logits.
        self.conv_cls = nn.Conv2d(64, self.n_classes, 1)

    def forward(self, x):
        """Return per-class logits for the image batch ``x``."""
        x1, x2, x3, x4, x5 = self.backbone(x)
        x = self.block_neck(x5)

        # Walk back up from the deepest skip connection to the shallowest: upsample
        # by 2, concatenate with the skip along the channel axis, then refine.
        decoder_stages = (
            (x4, self.block_up1),
            (x3, self.block_up2),
            (x2, self.block_up3),
            (x1, self.block_up4),
        )
        for skip, block in decoder_stages:
            x = block(torch.cat([skip, self.upsample(x)], dim=1))

        # The original note marks the logits here as size/2; upsample once more.
        x = self.conv_cls(x)
        return self.upsample(x)

In [None]:
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a quantity.

    Attributes are ``val`` (last value), ``sum`` (weighted sum), ``count``
    (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def accuracy_function(preds, targets):
    """Return the fraction of elements where ``preds`` equals ``targets``.

    Both tensors are flattened before the element-wise comparison; the result
    is a 0-dim tensor.
    """
    hits = (preds.flatten() == targets.flatten()).sum()
    return hits / targets.numel()


In [None]:
def save_model(model, optimizer, path):
    """Write the model and optimizer state dicts to ``path`` as one checkpoint.

    The file holds a dict with the keys ``"model"`` and ``"optimizer"``.
    """
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
        path,
    )

def load_model(model, optimizer, path):
    """Restore model and optimizer state from a checkpoint written by ``save_model``.

    Args:
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state is overwritten in place.
        path: Checkpoint file holding a dict with ``"model"`` and ``"optimizer"``.

    Returns:
        The same ``(model, optimizer)`` objects, for convenience.
    """
    # Deserialize onto the CPU so a checkpoint saved from a GPU can also be read
    # on a machine without CUDA. load_state_dict then copies the values into the
    # existing parameters / optimizer state on whatever device they live on.
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return model, optimizer

In [None]:
# Ask PyTorch's CUDA caching allocator to release unused cached memory.
torch.cuda.empty_cache()

In [None]:
import torchmetrics

# Build the network and restore previously trained weights.
model = ResUnet(num_classes).to(device)
# NOTE: this rebinds checkpoint_path (an output path in the configuration cell)
# to the input file the weights are loaded from.
checkpoint_path = "/kaggle/input/resunetsss/modelUNet_ep_30.pth"
# map_location=device keeps this cell runnable on a CPU-only session even when the
# weights were saved from a GPU.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)


# Optimizer and step-wise learning-rate decay (x0.7 every 5 scheduler steps).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
n_eps = 30
learing_rate_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)


# Evaluation metrics, kept on the same device as the predictions.
dice_fn = torchmetrics.Dice(num_classes=num_classes, average="macro").to(device)
iou_fn = torchmetrics.JaccardIndex(num_classes=num_classes, task="multiclass", average="macro").to(device)

# Running averages, reset at the start of every epoch by the training loop.
acc_meter = AverageMeter()
train_loss_meter = AverageMeter()
dice_meter = AverageMeter()
iou_meter = AverageMeter()

In [None]:
class DiceLoss(nn.Module):
    """Differentiable soft Dice loss for multi-class segmentation.

    The overlap is computed from softmax probabilities rather than from hard
    class decisions. A metric such as ``torchmetrics.Dice`` converts predictions
    to discrete labels first, so it yields no gradient and cannot drive training.

    Args:
        num_classes: Number of classes ``C`` (size of dim 1 of the logits).
        average: ``"macro"`` (mean of per-class Dice), ``"micro"`` (Dice over all
            classes pooled), ``"weighted"`` (per-class Dice weighted by the number
            of target pixels) or ``"none"`` / ``None`` (per-class loss vector).
        eps: Smoothing term that keeps the ratio defined for absent classes.

    Raises:
        ValueError: If ``average`` is not one of the supported modes.
    """

    _AVERAGES = ("macro", "micro", "weighted", "none", None)

    def __init__(self, num_classes, average="macro", eps=1e-6):
        super(DiceLoss, self).__init__()
        if average not in self._AVERAGES:
            raise ValueError("Unsupported average mode: {!r}".format(average))
        self.num_classes = num_classes
        self.average = average
        self.eps = eps

    def forward(self, y_pred, y_true):
        """Return ``1 - Dice`` for logits ``y_pred`` (B, C, ...) and labels ``y_true`` (B, ...)."""
        probs = torch.softmax(y_pred, dim=1)

        # One-hot encode the labels and move the class axis next to the batch
        # axis so that it lines up with the probabilities.
        target = F.one_hot(y_true.long(), num_classes=self.num_classes)
        target = target.movedim(-1, 1).to(probs.dtype)

        # Reduce over everything except the class axis.
        reduce_dims = (0,) + tuple(range(2, probs.dim()))
        intersection = (probs * target).sum(dim=reduce_dims)
        support = target.sum(dim=reduce_dims)
        cardinality = probs.sum(dim=reduce_dims) + support

        if self.average == "micro":
            dice = (2.0 * intersection.sum() + self.eps) / (cardinality.sum() + self.eps)
        else:
            per_class = (2.0 * intersection + self.eps) / (cardinality + self.eps)
            if self.average == "macro":
                dice = per_class.mean()
            elif self.average == "weighted":
                dice = (per_class * support).sum() / support.sum().clamp_min(self.eps)
            else:
                dice = per_class

        return 1.0 - dice

In [None]:
class Loss_fn(nn.Module):
    """Training objective: cross-entropy plus twice the Dice loss.

    Args:
        num_classes: Number of segmentation classes, forwarded to ``DiceLoss``.
        average: Averaging mode forwarded to ``DiceLoss``.
    """

    def __init__(self, num_classes, average="macro"):
        super(Loss_fn, self).__init__()
        # Forward the caller's averaging mode instead of a hard-coded "macro".
        self.dice_fn = DiceLoss(num_classes=num_classes, average=average).to(device)
        self.cep = nn.CrossEntropyLoss().to(device)

    def forward(self, y_pred, y_target):
        """Combine both terms for logits ``y_pred`` and class-index targets ``y_target``."""
        return self.cep(y_pred, y_target) + 2 * self.dice_fn(y_pred, y_target)


criterion = Loss_fn(num_classes)

In [None]:
# Authenticate with Weights & Biases without embedding the API key in the notebook.
# The key is read from the WANDB_API_KEY environment variable (for example filled
# from a notebook secret); when it is unset, key=None leaves the lookup to wandb.
# SECURITY: the key that used to be hard-coded in this cell must be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
from tqdm import tqdm

wandb.init(
    project = "PolypSegment"
)
for ep in range(1, 1+n_eps):
    acc_meter.reset()
    train_loss_meter.reset()
    dice_meter.reset()
    iou_meter.reset()

    # ---- Training pass ----
    model.train()
    for batch_id, (x, y) in enumerate(tqdm(train_dataloader), start=1):
        optimizer.zero_grad()
        n = x.shape[0]
        x = x.to(device).float()
        y = y.to(device).long()
        y_hat = model(x)  # (B, C, H, W)
        loss = criterion(y_hat, y)  # (B, C, H, W) vs (B, H, W)
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            # argmax over the class axis already gives (B, H, W). Do not squeeze:
            # that would also drop the batch axis when a batch holds one sample.
            y_hat_mask = y_hat.argmax(dim=1)
            dice_score = dice_fn(y_hat_mask, y)
            iou_score = iou_fn(y_hat_mask, y)
            accuracy = accuracy_function(y_hat_mask, y)

            train_loss_meter.update(loss.item(), n)
            iou_meter.update(iou_score.item(), n)
            dice_meter.update(dice_score.item(), n)
            acc_meter.update(accuracy.item(), n)

    # ---- Validation pass ----
    # Switch to eval mode so normalisation layers use their running statistics
    # and are not updated with validation data.
    model.eval()
    valid_loss_sum = 0.0
    valid_samples = 0
    with torch.no_grad():
        for (x, y) in valid_dataloader:
            n = x.shape[0]
            x = x.to(device).float()
            y = y.to(device).long()
            test_output = model(x)
            test_loss = criterion(test_output, y)
            valid_loss_sum += test_loss.item() * n
            valid_samples += n
    # Per-sample mean, so it is on the same scale as the averaged training loss.
    test_loss_epoch = valid_loss_sum / max(valid_samples, 1)

    print("EP {}, learning rate = {}, train loss = {}, accuracy = {}, IoU = {}, dice = {}".format(
        ep,learing_rate_scheduler.get_last_lr(), train_loss_meter.avg, acc_meter.avg, iou_meter.avg, dice_meter.avg
    ))
    wandb.log({"Train loss": train_loss_meter.avg, "Valid loss": test_loss_epoch})

    # Keep the weights of the last epochs only.
    if ep >= 25:
        torch.save(model.state_dict(), "modelUNet_ep_{}.pth".format(ep))

    learing_rate_scheduler.step()

# Create submission

In [None]:
# Preprocessing for the PIL test images: bilinear resize to 800x1120, then tensor
# conversion (the test dataset divides the result by 255).
# NOTE(review): this rebinds `transform`, which earlier named the albumentations
# training pipeline.
transform = Compose(
    [
        Resize((800, 1120), interpolation=InterpolationMode.BILINEAR),
        PILToTensor(),
    ]
)

In [None]:
class UNetTestDataClass(Dataset):
    """Unlabelled images for inference.

    Args:
        images_path: Directory containing the test images (with or without a
            trailing separator).
        transform: Callable applied to the PIL image; its result is divided by 255.

    Each item is ``(data, img_path, h, w)`` where ``h`` and ``w`` are the
    height and width of the image before ``transform`` is applied.
    """

    def __init__(self, images_path, transform):
        super().__init__()

        # os.path.join keeps paths valid whether or not images_path ends with a
        # separator; sorting makes the sample order deterministic.
        self.images_list = [
            os.path.join(images_path, name) for name in sorted(os.listdir(images_path))
        ]
        self.transform = transform

    def __getitem__(self, index):
        img_path = self.images_list[index]
        # Force 3-channel RGB so grayscale / RGBA files still produce the
        # three-channel tensor the model is fed; close the file handle promptly.
        with Image.open(img_path) as image_file:
            data = image_file.convert("RGB")
        w, h = data.size  # PIL reports (width, height)
        data = self.transform(data) / 255
        return data, img_path, h, w

    def __len__(self):
        return len(self.images_list)

In [None]:
# Test images. Shuffling serves no purpose at inference time and only makes the
# processing order non-reproducible, so batches are read in dataset order.
path = '/kaggle/input/bkai-igh-neopolyp/test/test/'
unet_test_dataset = UNetTestDataClass(path, transform)
test_dataloader = DataLoader(unet_test_dataset, batch_size=4, shuffle=False)

In [None]:
# Take only the first test batch and keep its images in `img`.
# NOTE(review): the loop targets rebind the module-level `path` (the test directory
# string) to this batch's file paths, so the cell that builds the test dataset
# cannot simply be re-run afterwards without restoring `path`.
for i, (data, path, h, w) in enumerate(test_dataloader):
    img = data
    break

In [None]:
# fig, arr = plt.subplots(5, 2, figsize=(16, 12))
# arr[0][0].set_title('Image');
# arr[0][1].set_title('Predict');
# model.to("cpu")
# model.eval()
# with torch.no_grad():
#     predict = model(img)

# for i in range(5):
#     arr[i][0].imshow(img[i].permute(1, 2, 0));
#     arr[i][1].imshow(F.one_hot(torch.argmax(predict[i], 0).cpu()).float())

In [None]:
# Run inference over the test set and store one colour mask per image, resized
# back to the image's original height and width.
model.to(device)
model.eval()

output_dir = "/kaggle/working/predicted_masks"
os.makedirs(output_dir, exist_ok=True)

for batch_idx, (images, image_paths, heights, widths) in enumerate(test_dataloader):
    print(batch_idx)
    with torch.no_grad():
        logits = model(images.to(device))
    class_maps = torch.argmax(logits, dim=1).cpu()  # (B, H, W) class indices

    for i, image_path in enumerate(image_paths):
        image_id = os.path.splitext(os.path.basename(image_path))[0]
        # Always emit 3 channels: without num_classes, one_hot sizes its output
        # from the largest index present, which breaks the RGB conversion
        # whenever the highest class is missing from a prediction.
        one_hot = F.one_hot(class_maps[i], num_classes=3).permute(2, 0, 1).float()
        # Nearest-neighbour keeps the mask values discrete while resizing.
        restore_size = Resize(
            (heights[i].item(), widths[i].item()),
            interpolation=InterpolationMode.NEAREST,
        )
        mask2img = restore_size(ToPILImage()(one_hot))
        mask2img.save(os.path.join(output_dir, image_id + ".png"))

In [None]:
def rle_to_string(runs):
    """Join the run-length values into a single space-separated string."""
    return ' '.join(map(str, runs))

def rle_encode_one_mask(mask):
    """Run-length encode one mask channel as a ``"start length ..."`` string.

    The mask is flattened, every value > 0 counts as foreground, and each run of
    foreground pixels is reported as a 1-based start position followed by the
    run length. The input array itself is not modified (``flatten`` copies).
    """
    flat = mask.flatten()
    flat[flat > 0] = 255

    # A run touching either end has no transition at that end; surround the data
    # with zeros so every run produces both a start and an end transition.
    touches_border = bool(flat[0] or flat[-1])
    if touches_border:
        padded = np.zeros([flat.size + 2], dtype=flat.dtype)
        padded[1:-1] = flat
        flat = padded

    # Indices where the value changes, converted to 1-based pixel positions
    # (the padding shifts every index by one, hence the smaller offset).
    offset = 1 if touches_border else 2
    rle = np.flatnonzero(flat[1:] != flat[:-1]) + offset

    # Odd entries are run ends: turn them into lengths (end - start).
    rle[1::2] -= rle[:-1:2]
    return rle_to_string(rle)

def mask2string(dir):
    """Encode every mask image found in ``dir`` as run-length strings.

    Each file is read with OpenCV and its channel order reversed; channels 0 and 1
    of the result are then encoded separately with ``rle_encode_one_mask``.

    Args:
        dir: Directory containing the predicted mask images.

    Returns:
        dict with ``'ids'`` (``"<file stem>_<channel>"``) and ``'strings'`` (the
        matching encodings), two entries per image, in sorted file-name order.

    Raises:
        ValueError: If a directory entry cannot be read as an image.
    """
    strings = []
    ids = []
    for file_name in sorted(os.listdir(dir)):
        mask_id = os.path.splitext(file_name)[0]
        mask_path = os.path.join(dir, file_name)

        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(mask_path)
        if image is None:
            raise ValueError(f"Could not read mask image: {mask_path}")
        image = image[:, :, ::-1]  # reverse the channel order

        for channel in range(2):
            ids.append(f'{mask_id}_{channel}')
            strings.append(rle_encode_one_mask(image[:, :, channel]))

    return {
        'ids': ids,
        'strings': strings,
    }


MASK_DIR_PATH = '/kaggle/working/predicted_masks' # change this to the path to your output mask folder

# Encode the saved masks and write the submission file. The directory is passed
# directly instead of through a module-level `dir`, which would shadow the builtin.
res = mask2string(MASK_DIR_PATH)
df = pd.DataFrame({'Id': res['ids'], 'Expected': res['strings']})
df.to_csv(r'output.csv', index=False)