# Importing Libraries

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import tifffile as tiff
import cv2
import torch.nn as nn
import albumentations as A
import numpy as np
import os
import time
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

In [None]:
!nvidia-smi

# Sample Original Image & Label

In [None]:
base_path = '/kaggle/input/blood-vessel-segmentation/train'
dataset = 'kidney_1_dense'

# Directories holding the input images and their segmentation labels.
images_path = os.path.join(base_path, dataset, 'images')
labels_path = os.path.join(base_path, dataset, 'labels')


def _sorted_tif_paths(directory):
    """Return the full paths of every ``.tif`` file in ``directory``, sorted.

    Images and labels are later paired purely by list index, so both lists
    are sorted the same way. NOTE(review): this presumes the two directories
    use matching file names — confirm for each dataset used.
    """
    tif_paths = []
    for name in os.listdir(directory):
        if name.endswith('.tif'):
            tif_paths.append(os.path.join(directory, name))
    tif_paths.sort()
    return tif_paths


image_files = _sorted_tif_paths(images_path)
label_files = _sorted_tif_paths(labels_path)

def show_images(images, titles = None, cmap = 'gray'):
    """Display ``images`` side by side in a single row of subplots.

    Args:
        images: Sequence of arrays to draw with ``Axes.imshow``.
        titles: Optional sequence of subplot titles, index-aligned with
            ``images``. Titles are skipped when this is None or empty.
        cmap: Colormap passed through to ``imshow``.
    """
    count = len(images)
    _, axes = plt.subplots(1, count, figsize = (20, 10))
    # With a single column, plt.subplots returns a bare Axes rather than an array.
    axes_list = axes if isinstance(axes, np.ndarray) else [axes]
    for position, axis in enumerate(axes_list):
        axis.imshow(images[position], cmap = cmap)
        if titles:
            axis.set_title(titles[position])
        axis.axis('off')
    plt.tight_layout()
    plt.show()

# Index of the slice displayed below next to its label (was a bare magic number).
SAMPLE_INDEX = 981

# Variable names kept unchanged in case later cells reference them.
first_image = tiff.imread(image_files[SAMPLE_INDEX])
first_label = tiff.imread(label_files[SAMPLE_INDEX])

# Titles state the actual index; the previous "First Image" label was misleading.
show_images([first_image, first_label], titles = [f'Image {SAMPLE_INDEX}', f'Label {SAMPLE_INDEX}'])

# Dataset

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset pairing image files with mask files by index.

    Each item is loaded with ``preprocess_image`` / ``preprocess_mask`` and,
    when ``augmentation_transforms`` is given, passed through
    ``augmentation_transforms(image, mask)``, which must return an
    ``(image, mask)`` pair.

    Args:
        image_files: Sequence of image paths.
        mask_files: Sequence of mask paths, index-aligned with ``image_files``.
        input_size: Stored as ``self.input_size`` but not used by this class;
            in this notebook resizing is done by the augmentation callable.
        augmentation_transforms: Optional callable
            ``(image, mask) -> (image, mask)``.

    Raises:
        ValueError: if ``image_files`` and ``mask_files`` differ in length.
    """

    def __init__(self, image_files, mask_files, input_size = (256, 256), augmentation_transforms = None):
        # Fail fast: a mismatch would otherwise show up as an IndexError
        # mid-epoch (too few masks) or silently ignore masks (too many).
        if len(image_files) != len(mask_files):
            raise ValueError(
                f"image_files and mask_files must have the same length "
                f"({len(image_files)} != {len(mask_files)})"
            )
        self.image_files = image_files
        self.mask_files = mask_files
        self.input_size = input_size
        self.augmentation_transforms = augmentation_transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image = preprocess_image(self.image_files[idx])
        mask = preprocess_mask(self.mask_files[idx])

        if self.augmentation_transforms is not None:
            image, mask = self.augmentation_transforms(image, mask)

        return image, mask
        

# Preprocessing of Images

In [None]:
def preprocess_image(path):
    """Load an image file as a max-normalized, 3-channel, channels-first tensor.

    The file is read with ``cv2.IMREAD_UNCHANGED`` (original bit depth is
    kept), replicated into 3 identical channels, converted to float32 and
    divided by its own maximum value. The division is skipped when the
    maximum is 0.

    NOTE(review): assumes a 2-D single-channel image; a multi-channel file
    would gain an extra axis from the tiling below.

    Args:
        path: Path to the image file.

    Returns:
        ``torch.float32`` tensor of shape ``(3, H, W)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {path}")

    # (H, W) -> (H, W, 3): repeat the single channel three times.
    img = np.tile(img[..., None], [1, 1, 3])
    img = img.astype('float32')
    mx = np.max(img)
    if mx:
        img /= mx

    # Channels-last -> channels-first, as expected by PyTorch conv layers.
    img = np.transpose(img, (2, 0, 1))
    return torch.tensor(img)

# Preprocessing of Masks

In [None]:
def preprocess_mask(path):
    """Load a mask file as a 3-channel, channels-last float tensor scaled by 1/255.

    NOTE(review): the scaling assumes masks are stored as 0/255 values;
    confirm for each dataset. The output stays channels-last, which is the
    layout ``augment_image`` hands to albumentations as a mask.

    Args:
        path: Path to the mask file.

    Returns:
        ``torch.float32`` tensor of shape ``(H, W, 3)`` for a 2-D mask.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    msk = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if msk is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask: {path}")

    # Tile the mask itself (the previous code referenced an undefined `img`).
    msk = np.tile(msk[..., None], [1, 1, 3])
    # Convert before dividing: in-place true division on an integer array
    # (e.g. uint8) raises a numpy casting error.
    msk = msk.astype('float32')
    msk /= 255.0
    return torch.tensor(msk)

# Augmentation

In [None]:
# Cached training pipeline; built on first use by _get_augmentation_pipeline().
_augmentation_pipeline = None


def _get_augmentation_pipeline():
    """Build the training augmentation pipeline once and return the cached instance."""
    global _augmentation_pipeline
    if _augmentation_pipeline is None:
        _augmentation_pipeline = A.Compose([
            # Nearest-neighbour for the image as well, as in the original pipeline.
            A.Resize(256, 256, interpolation = cv2.INTER_NEAREST),
            A.HorizontalFlip(p = 0.5),
            A.VerticalFlip(p = 0.5),
            # Shift and scale only (rotate_limit=0); border_mode=0 is cv2.BORDER_CONSTANT.
            A.ShiftScaleRotate(scale_limit = 0.5, rotate_limit = 0, shift_limit = 0.1, p = 1, border_mode = 0),
            # p=1.0 replaces the deprecated always_apply=True.
            # NOTE(review): effectively a no-op after Resize(256, 256).
            A.RandomCrop(height = 256, width = 256, p = 1.0),
            # Equivalent of the deprecated/removed A.RandomBrightness (default limit 0.2).
            A.RandomBrightnessContrast(brightness_limit = 0.2, contrast_limit = 0, p = 1),
            A.OneOf(
                [
                    A.Blur(blur_limit = 3, p = 1),
                    A.MotionBlur(blur_limit = 3, p = 1),
                ],
                p = 0.9,
            ),
        ])
    return _augmentation_pipeline


def augment_image(image, mask):
    """Apply the random training augmentations to an image/mask pair.

    Args:
        image: Channels-first tensor ``(C, H, W)``, as produced by
            ``preprocess_image``.
        mask: Tensor handed to albumentations as a mask without reordering
            (``preprocess_mask`` in this notebook produces ``(H, W, 3)``).

    Returns:
        ``(image, mask)`` as float32 tensors resized to 256x256. The image is
        channels-first; the mask keeps its original axis order.
    """
    # albumentations works on channels-last numpy arrays.
    image_np = image.permute(1, 2, 0).numpy()
    mask_np = mask.numpy()

    augmented = _get_augmentation_pipeline()(image = image_np, mask = mask_np)

    augmented_image = torch.tensor(augmented['image'], dtype = torch.float32).permute(2, 0, 1)
    augmented_mask = torch.tensor(augmented['mask'], dtype = torch.float32)
    return augmented_image, augmented_mask

# Splitting the Dataset

In [None]:
# NOTE(review): slices are shuffled before splitting, so (presumably highly
# correlated) neighbouring slices of the same volume end up in both train and
# val, which likely inflates validation scores. Consider a contiguous split.
train_image_files, val_image_files, train_mask_files, val_mask_files = train_test_split(
    image_files, label_files, test_size = 0.2, random_state = 42
)

# Deterministic validation transform: the same 256x256 nearest-neighbour
# resize as training, but without random flips/shifts/brightness/blur, so
# validation metrics are reproducible and computed on undistorted data.
_val_resize_pipeline = A.Compose([A.Resize(256, 256, interpolation = cv2.INTER_NEAREST)])


def resize_only(image, mask):
    """Resize an image/mask pair to 256x256 without random augmentation.

    Uses the same tensor conventions as ``augment_image``: ``image`` is
    channels-first in and out, and ``mask`` keeps its axis order.
    """
    out = _val_resize_pipeline(image = image.permute(1, 2, 0).numpy(), mask = mask.numpy())
    resized_image = torch.tensor(out['image'], dtype = torch.float32).permute(2, 0, 1)
    resized_mask = torch.tensor(out['mask'], dtype = torch.float32)
    return resized_image, resized_mask


train_dataset = CustomDataset(train_image_files, train_mask_files, augmentation_transforms = augment_image)
val_dataset = CustomDataset(val_image_files, val_mask_files, augmentation_transforms = resize_only)

train_dataloader = DataLoader(train_dataset, batch_size = 8, shuffle = True)
val_dataloader = DataLoader(val_dataset, batch_size = 8, shuffle = False)