In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as transforms
from torchvision.transforms import ToTensor, ToPILImage
from torchvision.transforms.functional import to_tensor, to_pil_image

import albumentations as A
from albumentations import HorizontalFlip, Compose, Resize, Normalize
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm

from PIL import Image, ImageOps, ImageEnhance

import segmentation_models_pytorch as smp

from sklearn.model_selection import train_test_split

In [None]:
from google.colab import drive

# Mount Google Drive into the Colab runtime.
drive.mount('/content/drive')

# Root folder of the dataset; the "/train" and "/test" sub-folders are derived from it below.
# NOTE(review): this path is not located under the '/content/drive' mount point used above —
# confirm that the data really lives here.
drive_dir = "/datasets/data/data"

In [None]:
# Locations of the training images and their segmentation masks.
train_dir = drive_dir + "/train"

train_img_dir = train_dir + "/imgs"
train_mask_dir = train_dir + "/masks"

# Sorted listings so the order does not depend on the filesystem.
train_masks = sorted(os.listdir(train_mask_dir))
train_val_ims = sorted(os.listdir(train_img_dir))

# Hold out 20% of the images for validation. The seed is fixed so that
# "Restart & Run All" always reproduces the same train/val partition.
train_imgs, val_imgs = train_test_split(train_val_ims, test_size=0.2, random_state=42)

In [None]:
# Folder layout of the held-out test split (images only).
test_dir = f"{drive_dir}/test"
test_img_dir = f"{test_dir}/imgs"

# Alphabetically ordered file names of every test image.
test_imgs = sorted(os.listdir(test_img_dir))

In [None]:
class SSDataset(torch.utils.data.Dataset):
    """Semantic-segmentation dataset backed by an explicit list of file names.

    Every entry of ``image_list`` is a file name inside ``img_dir``; when
    ``msk_dir`` is given, the mask is read from the file of the same name in
    that folder and its first (red) channel is used as the label map.

    Args:
        img_dir: Folder containing the input images.
        msk_dir: Folder containing the masks, or ``None`` for unlabeled data.
        image_list: File names served by this dataset (defines its length and order).
        preprocessor: Object exposing ``preprocess(images=..., segmentation_maps=...,
            return_tensors='pt')``; it receives the PIL image and the mask array.
        mode: ``'train'`` enables augmentation; any other value disables it.
    """

    def __init__(self, img_dir, msk_dir, image_list, preprocessor, mode='train'):
        self.transforms = transforms  # kept for backward compatibility; not used by the class
        self.imgs = image_list
        self.img_dir, self.msk_dir = img_dir, msk_dir
        self.labels = list(range(0, 13))
        self.mode = mode
        self.preprocessor = preprocessor
        # Probability of the vertical flip, which is applied to image AND mask together.
        self.flip_prob = 0.5
        # Photometric augmentation only: it never moves pixels, so it is safe
        # to apply to the image without touching the mask.
        self.augmentation = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2)
        ])

    def _file_name(self, idx):
        """Return the file name of sample ``idx`` as listed in ``image_list``."""
        return self.imgs[idx]

    def __getitem__(self, idx):
        """Load, augment and preprocess sample ``idx``.

        Returns:
            The mapping produced by the preprocessor, with every tensor value
            squeezed in place (removes the singleton batch dimension).
        """
        # Use the name from the list handed to the constructor; deriving it from
        # ``idx`` would make the train and validation datasets read the same files.
        file_name = self._file_name(idx)
        image = Image.open(os.path.join(self.img_dir, file_name)).convert("RGB")

        mask = None
        if self.msk_dir is not None:
            mask = Image.open(os.path.join(self.msk_dir, file_name)).convert("RGB")

        if self.mode == 'train':
            # Geometric augmentation must hit image and mask identically,
            # otherwise the labels no longer line up with the pixels.
            if torch.rand(1).item() < self.flip_prob:
                image = ImageOps.flip(image)
                if mask is not None:
                    mask = ImageOps.flip(mask)
            image = self.augmentation(image)

        # The PIL image (0-255) is handed to the preprocessor unchanged so that
        # rescaling / normalisation happen exactly once, inside the preprocessor.
        if mask is not None:
            label_map = np.array(mask)[:, :, 0]
            input_dict = self.preprocessor.preprocess(images=image, segmentation_maps=label_map, return_tensors='pt')
        else:
            input_dict = self.preprocessor.preprocess(images=image, return_tensors='pt')

        for value in input_dict.values():
            if isinstance(value, torch.Tensor):
                value.squeeze_()

        return input_dict

    def __len__(self):
        """Number of files in ``image_list``."""
        return len(self.imgs)

In [None]:
from transformers import OneFormerImageProcessor, OneFormerForUniversalSegmentation

# Image processor that ships with the pretrained OneFormer checkpoint.
preprocessor = OneFormerImageProcessor.from_pretrained(
    "shi-labs/oneformer_cityscapes_dinat_large",
    num_text=1,
)

In [None]:
### Loader settings
batch_size = 4

### Dataset: both splits read from the training folders, each with its own file list
train_dataset = SSDataset(
    train_img_dir,
    train_mask_dir,
    train_imgs,
    preprocessor=preprocessor,
    mode='train',
)
val_dataset = SSDataset(
    train_img_dir,
    train_mask_dir,
    val_imgs,
    preprocessor=preprocessor,
    mode='val',
)

### Dataloader: only the training split is shuffled
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
def Dice(pred,target):
    """Smoothed Dice coefficient computed over all elements of both tensors.

    Args:
        pred: Prediction tensor; used as given.
        target: Reference tensor; cast to float before use.

    Returns:
        Scalar tensor ``(2 * sum(pred * target) + eps) / (sum(pred^2) + sum(target^2) + eps)``.
    """
    eps = 1e-5  # keeps the ratio defined when both tensors are all zeros
    target_f = target.float()
    overlap = (pred * target_f).sum()
    target_energy = (target_f * target_f).sum()
    pred_energy = (pred * pred).sum()
    numerator = 2 * overlap + eps
    denominator = pred_energy + target_energy + eps
    return numerator / denominator

def DiceLoss(pred, target):
    """Dice loss: one minus the Dice coefficient of ``pred`` and ``target``."""
    dice_score = Dice(pred, target)
    return 1 - dice_score

### Criterion & Metric
metric = Dice
criterion = DiceLoss

### Model Definition
num_classes = 13
# Fall back to the CPU so the notebook still runs on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ignore_mismatched_sizes=True lets the checkpoint load even though the label
# mapping below has a different number of classes than the pretrained head.
model = OneFormerForUniversalSegmentation.from_pretrained(
    "shi-labs/oneformer_cityscapes_dinat_large",
    id2label={i: i for i in range(num_classes)},
    label2id={i: i for i in range(num_classes)},
    ignore_mismatched_sizes=True,
).to(device)

### Optim & Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
# NOTE(review): T_max=10 is shorter than the 12-epoch training run in this notebook,
# so the cosine schedule starts rising again after epoch 10 — confirm this is intended.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-6)

In [None]:
num_train_epochs = 12


def one_hot_encode(class_map, n_classes):
    """Convert a (B, H, W) integer class map into a (B, n_classes, H, W) float one-hot tensor."""
    index = class_map.long().unsqueeze(1)
    encoded = torch.zeros(
        index.size(0), n_classes, index.size(2), index.size(3),
        dtype=torch.float32, device=index.device,
    )
    return encoded.scatter_(1, index, 1)


def upsampled_logits(pixel_values, output_size):
    """Run the model and bilinearly resize its logits to ``output_size`` (H, W)."""
    # NOTE(review): `batch["labels"]` and `outputs.logits` follow the SegFormer-style API;
    # confirm that the OneFormer model/processor loaded in this notebook expose them.
    logits = model(pixel_values).logits
    # Resize to the label resolution instead of assuming a fixed x4 factor.
    return F.interpolate(logits, size=output_size, mode='bilinear', align_corners=False)


# Training loop
for epoch in range(num_train_epochs):
    print(f"Epoch {epoch + 1}/{num_train_epochs}")

    model.train()
    loop = tqdm(train_dataloader, leave=True)
    for batch in loop:
        pixel_values = batch["pixel_values"].to(device)
        targets = batch["labels"].to(device)
        one_hot_labels = one_hot_encode(targets, num_classes)

        optimizer.zero_grad()
        logits = upsampled_logits(pixel_values, targets.shape[-2:])

        # Dice is defined on values in [0, 1], so feed class probabilities, not raw logits.
        loss = criterion(logits.softmax(dim=1), one_hot_labels)
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

    model.eval()
    total_dice = 0.0
    total_batches = 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader, leave=True):
            pixel_values = batch["pixel_values"].to(device)
            targets = batch["labels"].to(device)
            one_hot_labels = one_hot_encode(targets, num_classes)

            # Score the whole batch; indexing [0] here would compare only the
            # first sample (via broadcasting) against every label map of the batch.
            logits = upsampled_logits(pixel_values, targets.shape[-2:])
            hard_predictions = one_hot_encode(logits.argmax(dim=1), num_classes)

            total_dice += metric(hard_predictions, one_hot_labels).item()
            total_batches += 1

    # Guard against an empty validation loader.
    avg_dice = total_dice / total_batches if total_batches else float('nan')
    print(f"Validation Dice Score: {avg_dice:.4f}")

    lr_scheduler.step()

# Compute Metric : DICE

In [None]:
def rle_encode(mask_image):
    """Run-length encode a binary mask.

    The mask is flattened in row-major order and encoded as alternating
    ``start, length`` values, where ``start`` is the 1-based position of the
    first pixel of a run of foreground pixels.

    Args:
        mask_image: Array-like binary mask (foreground != 0); it is not modified.

    Returns:
        1-D integer ``numpy`` array ``[start_1, length_1, start_2, length_2, ...]``;
        empty when the mask has no foreground pixel (or no pixel at all).
    """
    pixels = np.asarray(mask_image).ravel()
    # Pad with background on both sides so that runs touching the first or last
    # pixel are detected without overwriting (and thereby losing) those pixels.
    padded = np.concatenate([[0], pixels, [0]])
    # Positions where the value changes, expressed as 1-based pixel indices:
    # even entries are run starts, odd entries are the first index after a run.
    runs = np.where(padded[1:] != padded[:-1])[0] + 1
    # Turn each "end" position into the length of its run.
    runs[1::2] -= runs[::2]
    return runs

# Inference

In [None]:
### batch_size=1, shuffle=False: sample i of the loader corresponds to test_imgs[i]
test_dataset = SSDataset(test_img_dir, None, test_imgs, preprocessor, "test")
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
test_model = model
test_model.eval()


out_dict = []  # list of (ImageId, EncodedPixels) rows for the submission file
output = []
inputs = []

# Inference only: no autograd graph is needed, which saves memory and time.
with torch.no_grad():
    for i, data in enumerate(tqdm(test_dataloader)):
        # Key access works whether the collate function returns a plain dict or a BatchFeature.
        image = data["pixel_values"].to(device)
        # NOTE(review): `.logits` follows the SegFormer-style output API; confirm that the
        # model loaded in this notebook exposes it.
        logits = test_model(image).logits

        # (600, 800) is presumably the original test image size — TODO confirm.
        logits = F.interpolate(logits, size=(600, 800), mode='bilinear')
        # Class index per pixel for the single image of the batch, moved to the
        # CPU so it can be compared and encoded with numpy.
        prediction = logits.argmax(dim=1)[0].cpu().numpy()

        # splitext drops only the extension, even if the name contains other dots.
        image_name = os.path.splitext(test_imgs[i])[0]

        # RLE encoding: one binary mask per class
        for j in range(13):
            mask_label = (prediction == j).astype(np.float32)
            encode = rle_encode(mask_label)
            out_dict.append((f'{image_name}_{j}', ' '.join(str(_) for _ in encode)))

In [None]:
## create csv
# pandas is already imported as `pd` at the top of the notebook; aliasing it to
# `pdb` would shadow the name of Python's debugger module.
output_path = '/kaggle/working/result.csv'

# Make sure the target folder exists (it is not guaranteed outside Kaggle).
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# Passing the column names at construction also works when out_dict is empty.
df = pd.DataFrame(out_dict, columns=['ImageId', 'EncodedPixels'])
df = df.set_index('ImageId')

df.to_csv(output_path)