In [52]:
!pip install torchgeometry
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import pandas as pd
import torch
from torchvision.transforms import functional as T
import torch.nn.functional as F
import torch.nn as nn
from torchvision.transforms import Resize, PILToTensor, ToPILImage, Compose, InterpolationMode
from torchgeometry.losses import one_hot
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import os
from PIL import Image



In [53]:
# Enable cuDNN benchmark mode (PyTorch setting that lets cuDNN choose
# convolution algorithms by benchmarking them).
torch.backends.cudnn.benchmark = True
# Weights & Biases client; used by the training cell for login and metric logging.
!pip install wandb
import wandb



In [54]:
# Pick the compute backend in priority order: CUDA GPU, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [55]:
# Number of passes over the training loader performed by the training loop.
epochs = 100
# Fraction of the dataset assigned to the training subset; the remainder is
# used as the validation subset.
train_split = 0.8

In [56]:
class RandomGamma:
    """Transform that gamma-adjusts its input with probability ``p``.

    Args:
        gamma_range: ``(low, high)`` bounds of the uniform distribution the
            gamma value is sampled from.
        p: probability that the adjustment is applied to a given input.
    """

    def __init__(self, gamma_range=(0.7, 1.3), p=0.2):
        self.p = p
        self.gamma_range = gamma_range

    def __call__(self, img):
        """Return ``img`` unchanged or gamma-adjusted (gain fixed at 1)."""
        # One draw decides whether this input is augmented at all.
        apply_adjustment = torch.rand(1).item() < self.p
        if not apply_adjustment:
            return img

        # A second draw picks the gamma uniformly inside the configured range.
        sampled_gamma = torch.empty(1).uniform_(*self.gamma_range).item()
        return T.adjust_gamma(img, sampled_gamma, gain=1)

In [57]:
class CustomDataset(Dataset):
    """Paired image / segmentation-mask dataset read from two directories.

    Samples are paired by position after sorting both directory listings, so
    the mask files must sort in the same order as the image files (e.g. by
    sharing file names).

    Args:
        images_path: directory containing the input images.
        masks_path: directory containing the ground-truth masks.
        transform: callable applied to both the image and the mask; it must
            return a channel-first tensor.

    Raises:
        ValueError: if the two directories contain a different number of files.
    """

    def __init__(self, images_path, masks_path, transform):
        super(CustomDataset, self).__init__()
        # os.listdir returns entries in arbitrary order; sort both listings so
        # that index i refers to the same sample in both directories.
        self.images_list = [
            os.path.join(images_path, image_name)
            for image_name in sorted(os.listdir(images_path))
        ]
        self.masks_list = [
            os.path.join(masks_path, mask_name)
            for mask_name in sorted(os.listdir(masks_path))
        ]
        if len(self.images_list) != len(self.masks_list):
            raise ValueError(
                f"Found {len(self.images_list)} images but "
                f"{len(self.masks_list)} masks; cannot pair them."
            )
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index``.

        ``image`` is the transformed image divided by 255; ``label`` is an
        int64 map of class indices: 0 where the red mask channel exceeds the
        threshold, otherwise 1 where the green channel does, otherwise 2.
        """
        img_path = self.images_list[index]
        mask_path = self.masks_list[index]

        # Force three channels and release the file handles right away.
        with Image.open(img_path) as img_file:
            data = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            label = mask_file.convert("RGB")

        # The transform may contain random ops (flips, ...). Replay the same
        # torch RNG state for the mask so it receives exactly the same random
        # decisions as the image and the two stay spatially aligned.
        # NOTE(review): photometric ops in the pipeline are replayed on the
        # mask as well; this relies on the thresholding below to absorb them.
        rng_state = torch.get_rng_state()
        data = self.transform(data)
        torch.set_rng_state(rng_state)
        label = self.transform(label)

        # NOTE(review): if the transform already ends in ToTensor this puts the
        # image in [0, 1/255]. Kept as-is because the test dataset and the
        # plotting cells of this notebook assume the same scaling.
        data = data / 255.0

        # Binarise every mask channel.
        label = torch.where(label > 0.65, 1.0, 0.0)

        # Give the third channel a tiny positive value so that argmax selects
        # index 2 wherever neither of the first two channels is set.
        if label.shape[0] > 2:
            label[2, :, :] = 0.0001

        # Collapse the channel axis into class indices.
        label = torch.argmax(label, dim=0).type(torch.int64)

        return data, label

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.images_list)

In [58]:
# Training-time preprocessing and augmentation. It is bound to its own name so
# that the `transforms` module is not overwritten and this cell can be re-run.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    RandomGamma(gamma_range=(0.7, 1.3), p=0.2),
    transforms.ToTensor(),
])
dataset = CustomDataset(
    '/kaggle/input/bkai-igh-neopolyp/train/train/',
    '/kaggle/input/bkai-igh-neopolyp/train_gt/train_gt/',
    train_transform,
)

# Seeded split so the train/validation partition is the same on every run.
# NOTE(review): both subsets wrap the same dataset object, so validation
# samples also go through the random augmentations above.
n_train = int(train_split * len(dataset))
n_val = len(dataset) - n_train
train_dataset, val_dataset = random_split(
    dataset,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(42),
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# Validation order does not influence the loss, so keep it fixed.
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)


In [59]:
!pip install segmentation-models-pytorch
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="efficientnet-b7",
    encoder_weights="imagenet",
    in_channels=3,
    classes=3
)
model.to(device)
#print(model)




UnetPlusPlus(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 64, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          64, 64, kernel_size=(3, 3), stride=[1, 1], groups=64, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(64, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          64, 16, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          16, 64, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticS

In [60]:
class DiceLoss(nn.Module):
    """Class-weighted soft Dice loss plus class-weighted cross-entropy.

    Args:
        weights: per-class weights with ``C`` elements, given either as a
            ``(C,)`` or a ``(1, C)`` tensor on the same device as the logits.
    """

    def __init__(self, weights):
        super(DiceLoss, self).__init__()
        # Guards the Dice denominator and is the offset added to the one-hot target.
        self.eps: float = 1e-6
        self.weights: torch.Tensor = weights

    def forward(self, input: torch.Tensor, target: torch.Tensor):
        """Compute the combined loss.

        Args:
            input: raw logits of shape ``(B, C, H, W)``.
            target: int64 class indices of shape ``(B, H, W)``.

        Returns:
            Scalar tensor: ``mean(1 - weighted_dice) + weighted_cross_entropy``.
        """
        # Cross-entropy expects a 1-D weight vector; flattening accepts both
        # the (C,) and the (1, C) layout.
        class_weights = self.weights.reshape(-1)

        # cross entropy loss
        celoss = F.cross_entropy(input, target, weight=class_weights)

        # compute softmax over the classes axis
        input_soft = F.softmax(input, dim=1)

        # One-hot encode the labels as (B, C, H, W) and add a small offset to
        # every entry.
        num_classes = input.shape[1]
        target_one_hot = (
            F.one_hot(target, num_classes=num_classes)
            .permute(0, 3, 1, 2)
            .to(dtype=input.dtype)
            + self.eps
        )

        # Per-sample, per-class Dice score over the spatial axes.
        dims = (2, 3)
        intersection = torch.sum(input_soft * target_one_hot, dims)
        cardinality = torch.sum(input_soft + target_one_hot, dims)
        dice_score = 2. * intersection / (cardinality + self.eps)

        # Weighted sum over classes -> one Dice value per sample.
        dice_score = torch.sum(dice_score * class_weights, dim=1)

        return torch.mean(1. - dice_score) + celoss


In [None]:
# Per-class weights (in class-index order) shared by the Dice and
# cross-entropy terms; created on the selected device instead of assuming CUDA.
weights = torch.tensor([[0.4, 0.55, 0.05]], device=device)
criterion = DiceLoss(weights)
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_loss_array = []
test_loss_array = []
best_val_loss = float("inf")

# Never commit the W&B API key to the notebook. Provide it through the
# WANDB_API_KEY environment variable (e.g. exported from a Kaggle secret).
# With key=None wandb falls back to its own credential lookup.
wandb.login(
    key = os.environ.get("WANDB_API_KEY"),
)
wandb.init(
    project = "BKAI_graph"
)

# Training loop
for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    for images, masks in train_loader:  # images: (B, C, H, W), masks: (B, H, W)
        images, masks = images.to(device), masks.to(device)
        outputs = model(images)

        loss = criterion(outputs, masks.long())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # ---- validation pass ----
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            test_loss += criterion(outputs, targets.long()).item()

    # Per-batch means; these are what gets reported and stored.
    avg_train_loss = train_loss / len(train_loader)
    avg_val_loss = test_loss / len(val_loader)

    # Keep only the checkpoint with the lowest validation loss so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        checkpoint = {
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
        }
        save_path = 'model.pth'
        torch.save(checkpoint, save_path)

    train_loss_array.append(avg_train_loss)
    test_loss_array.append(avg_val_loss)

    print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")
    wandb.log({
        "Epoch": epoch + 1,
        "Train Loss": avg_train_loss,
        "Validation Loss": avg_val_loss,
    })

# Close the run so the summary is flushed to the W&B server.
wandb.finish()



VBox(children=(Label(value='0.024 MB of 0.024 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
Epoch,▁▂▃▃▄▅▆▆▇█
Train Loss,█▄▄▄▃▃▁▁▁▁
Validation Loss,█▇▆▃▃▃▂▁▂▁

0,1
Epoch,10.0
Train Loss,1.37456
Validation Loss,1.33847


In [None]:
# Fetch one validation batch for the qualitative check in the next cell.
img, mask = next(iter(val_loader))

In [None]:
model.eval()
with torch.no_grad():
    predict = model(img.to(device))

# Show at most four samples; the batch may contain fewer.
n_rows = min(4, img.shape[0])
# Fix the channel count so the one-hot images always have one channel per
# class, even when a class is absent from a particular sample.
num_classes = predict.shape[1]

fig, arr = plt.subplots(n_rows, 3, figsize=(16, 12), squeeze=False)
arr[0][0].set_title('Image')
arr[0][1].set_title('Segmentation')
arr[0][2].set_title('Predict')

for i in range(n_rows):
    # The dataset divides images by 255; undo that scaling for display.
    arr[i][0].imshow((img * 255)[i].cpu().numpy().transpose(1, 2, 0))

    # Ground truth: class indices rendered as one colour channel per class.
    arr[i][1].imshow(F.one_hot(mask[i], num_classes=num_classes).float())

    # Prediction: most likely class per pixel, rendered the same way.
    predicted_classes = torch.argmax(predict[i], dim=0).cpu()
    arr[i][2].imshow(F.one_hot(predicted_classes, num_classes=num_classes).float());

**Submission**

In [None]:
class TestDataset(Dataset):
    """Unlabelled image dataset used for inference.

    Args:
        images_path: directory containing the test images.
        transform: callable applied to each image; it must return a tensor.
    """

    def __init__(self, images_path, transform):
        super(TestDataset, self).__init__()
        # Sort for a deterministic order; os.path.join tolerates a directory
        # path with or without a trailing separator.
        self.images_list = [
            os.path.join(images_path, image_name)
            for image_name in sorted(os.listdir(images_path))
        ]
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, path, height, width)`` for sample ``index``.

        ``height`` and ``width`` are the size of the image on disk, taken
        before the transform is applied. ``image`` is the transformed image
        divided by 255.
        """
        img_path = self.images_list[index]
        # Force three channels and release the file handle right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        w, h = image.size  # PIL reports size as (width, height)
        data = self.transform(image) / 255
        return data, img_path, h, w

    def __len__(self):
        """Number of test images."""
        return len(self.images_list)

In [None]:
path = '/kaggle/input/bkai-igh-neopolyp/test/test/'
# Deterministic inference-time preprocessing: the same resize and tensor
# conversion as the training pipeline, but without random augmentation.
test_transform = Compose([Resize((224, 224)), T.to_tensor])
test_dataset = TestDataset(path, test_transform)
# Sample order is irrelevant for inference, so keep it fixed.
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [None]:
# Fetch the images of one test batch for the preview in the next cell.
img = next(iter(test_loader))[0]

In [None]:
model.eval()
with torch.no_grad():
    predict = model(img.to(device))

# Show at most four samples; the batch may contain fewer.
n_rows = min(4, img.shape[0])
# Fix the channel count so the one-hot image always has one channel per class.
num_classes = predict.shape[1]

fig, arr = plt.subplots(n_rows, 2, figsize=(16, 12), squeeze=False)
arr[0][0].set_title('Image');
arr[0][1].set_title('Predict');

for i in range(n_rows):
    # The dataset divides images by 255; undo that scaling for display.
    arr[i][0].imshow((img * 255)[i].cpu().numpy().transpose(1, 2, 0));
    predicted_classes = torch.argmax(predict[i], 0).cpu()
    arr[i][1].imshow(F.one_hot(predicted_classes, num_classes=num_classes).float());

In [None]:
model.eval()
output_dir = "/kaggle/working/predicted_masks"
os.makedirs(output_dir, exist_ok=True)

for img, batch_paths, H, W in test_loader:
    with torch.no_grad():
        predicted_mask = model(img.to(device))

    # Fix the channel count so every saved mask has one channel per class,
    # even when a class is absent from a prediction.
    num_classes = predicted_mask.shape[1]
    # Most likely class per pixel, moved to the CPU for the PIL conversion.
    class_maps = torch.argmax(predicted_mask, dim=1).cpu()

    for i in range(len(batch_paths)):
        image_id = batch_paths[i].split('/')[-1].split('.')[0]
        filename = image_id + ".png"
        # (H, W, C) one-hot -> (C, H, W) float image with one channel per class.
        one_hot_mask = F.one_hot(class_maps[i], num_classes=num_classes).permute(2, 0, 1).float()
        # Scale back to the size of the source image; nearest-neighbour keeps
        # the mask values discrete.
        resize_to_original = Resize((H[i].item(), W[i].item()), interpolation=InterpolationMode.NEAREST)
        mask2img = resize_to_original(ToPILImage()(one_hot_mask))
        mask2img.save(os.path.join(output_dir, filename))

In [None]:
def rle_to_string(runs):
    """Join run-length values into one space-separated string."""
    return ' '.join(str(x) for x in runs)


def rle_encode_one_mask(mask):
    """Run-length encode the non-zero pixels of ``mask``.

    The mask is flattened in row-major order and every pixel greater than
    zero counts as foreground. The result lists ``start length`` pairs with
    1-based start positions, e.g. ``"2 2 6 1"``.

    Args:
        mask: array-like of any shape; it is not modified.

    Returns:
        The encoded string; empty when the mask has no foreground pixel
        (including a mask with zero elements).
    """
    foreground = np.asarray(mask).flatten() > 0

    # Pad one background pixel on each side so that every run, including one
    # touching the first or last pixel, has both a start and an end transition.
    padded = np.concatenate(([False], foreground, [False]))

    # Index i of the comparison marks a change between padded[i] and
    # padded[i + 1]; the +1 turns it into a 1-based position in the unpadded
    # mask. Even slots hold run starts, odd slots hold exclusive run ends.
    boundaries = np.flatnonzero(padded[1:] != padded[:-1]) + 1

    # Convert each run end into a run length.
    boundaries[1::2] = boundaries[1::2] - boundaries[::2]
    return rle_to_string(boundaries)

def mask2string(dir):
    """Run-length encode the first two colour channels of every mask in ``dir``.

    Each file contributes two entries, ``<stem>_0`` for the red channel and
    ``<stem>_1`` for the green channel, where ``<stem>`` is the file name up
    to its first dot.

    Args:
        dir: directory containing the mask images.

    Returns:
        dict with parallel lists ``'ids'`` and ``'strings'``.

    Raises:
        ValueError: if a file in ``dir`` cannot be read as an image.
    """
    ids = []
    strings = []
    # Sorted so that the output order does not depend on the file system.
    for file_name in sorted(os.listdir(dir)):
        image_stem = file_name.split('.')[0]
        file_path = os.path.join(dir, file_name)

        bgr = cv2.imread(file_path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise ValueError(f"Could not read mask image: {file_path}")
        rgb = bgr[:, :, ::-1]  # OpenCV loads BGR; flip to RGB channel order

        for channel in range(2):
            ids.append(f'{image_stem}_{channel}')
            strings.append(rle_encode_one_mask(rgb[:, :, channel]))

    return {
        'ids': ids,
        'strings': strings,
    }


MASK_DIR_PATH = '/kaggle/working/predicted_masks' # change this to the path to your output mask folder

# Encode every predicted mask and write the submission file with the columns
# Id and Expected.
res = mask2string(MASK_DIR_PATH)
df = pd.DataFrame({'Id': res['ids'], 'Expected': res['strings']})
df.to_csv(r'output.csv', index=False)