In [8]:
import os
from typing import Sequence, Tuple

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from torchmetrics import Dice
from torchmetrics.classification import BinaryJaccardIndex

from segmentation_models_pytorch import DeepLabV3Plus

In [4]:
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [3]:
# Run on the first visible CUDA GPU when available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DEVICE

device(type='cuda')

In [6]:
# Per-channel normalisation statistics in R, G, B order (images are converted to RGB
# before normalisation). Presumably computed on this dataset's training images — TODO
# confirm; they are not the usual ImageNet values.
MEAN, STD = (
    np.array([0.585, 0.466, 0.445]),
    np.array([0.224, 0.208, 0.198]),
)

In [7]:
class ImageDatasetForSegmentation(Dataset):
    """In-memory image/mask dataset for segmentation.

    ``path`` must contain ``images/`` and ``masks/`` subdirectories where an image and
    its mask share the same file name; images without a matching mask are ignored.
    Every pair is read once, resized to 256x256 and cached in memory. Augmentation
    (training only) and normalisation are applied lazily in ``__getitem__``.

    Args:
        path: Dataset root directory.
        is_train: If True, apply random flips, rotation and brightness/contrast
            jitter before normalisation; otherwise only normalise.
    """

    def __init__(self, path: str, is_train: bool = True) -> None:
        super().__init__()

        self.resize = A.Resize(256, 256)

        # Augmentations run on the raw uint8 image and Normalize comes last:
        # albumentations clips float images to [0, 1] inside RandomBrightnessContrast,
        # which would wipe out every negative value of an already-normalised image.
        augmentations = [
            A.HorizontalFlip(),
            A.VerticalFlip(),
            # Exposed corners are filled with 0 (black) in raw pixel space.
            # NOTE(review): newer albumentations releases rename `value` to `fill` —
            # adjust if the installed version warns about it.
            A.Rotate(45, border_mode=cv2.BORDER_CONSTANT, value=0),
            A.RandomBrightnessContrast(0.2),
        ] if is_train else []
        self.transform = A.Compose([
            *augmentations,
            A.Normalize(mean=MEAN, std=STD),
            ToTensorV2(),
        ])

        self.images, self.masks = [], []

        images_dir = os.path.join(path, 'images')
        masks_dir = os.path.join(path, 'masks')
        unreadable = []

        # os.listdir order is arbitrary; sort so sample indices are reproducible.
        for file in tqdm(sorted(os.listdir(images_dir))):
            mask_path = os.path.join(masks_dir, file)
            if not os.path.exists(mask_path):
                continue

            image_bgr = cv2.imread(os.path.join(images_dir, file))
            mask_raw = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            # cv2.imread signals unreadable / non-image files by returning None.
            if image_bgr is None or mask_raw is None:
                unreadable.append(file)
                continue

            image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            # Scale 0..255 mask to 0..1; float32 keeps the in-memory cache half the
            # size of the float64 that plain `/ 255` would produce.
            mask = mask_raw.astype(np.float32) / 255.0

            resized = self.resize(image=image, mask=mask)
            self.images.append(resized['image'])
            self.masks.append(resized['mask'])

        if unreadable:
            import warnings
            warnings.warn(
                f'Skipped {len(unreadable)} unreadable image/mask pair(s) '
                f'in {path!r}: {unreadable[:5]}'
            )

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``index``-th (image, mask) pair after augmentation/normalisation.

        Both are float tensors as produced by ``ToTensorV2``: the image is
        channel-first; the mask presumably (H, W) — TODO confirm against the loss.
        """
        transformed = self.transform(image=self.images[index], mask=self.masks[index])
        return transformed['image'].float(), transformed['mask'].float()

    def __len__(self) -> int:
        """Number of cached image/mask pairs."""
        return len(self.images)

In [9]:
# Inverse of A.Normalize(mean=MEAN, std=STD). Normalize computes
# (x - mean * max_pixel_value) / (std * max_pixel_value); with max_pixel_value=1,
# mean=-MEAN/STD and std=1/STD this becomes x * STD + MEAN, i.e. it maps a normalised
# image back to the [0, 1]-scaled pixel range.
denormalize = A.Normalize(
    mean=-MEAN / STD,
    std=np.reciprocal(STD),
    max_pixel_value=1.0,
)


def restore_image(image: torch.Tensor) -> np.ndarray:
    """Undo the dataset normalisation so an image tensor can be displayed.

    Args:
        image: Normalised image in channel-first (C, H, W) layout, as produced by
            ``ToTensorV2``. May live on any device and may require grad.

    Returns:
        Channel-last (H, W, C) numpy array with the MEAN/STD normalisation reversed.
    """
    # Tensor.numpy() refuses CUDA tensors and tensors that require grad; detach and
    # copy to host memory first (both are no-ops for plain CPU tensors).
    host_image = image.detach().cpu()
    return denormalize(image=host_image.permute(1, 2, 0).numpy())['image']


def plot_image_grid(images: Sequence[torch.Tensor | np.ndarray]):
    fig, axs = plt.subplots(1, len(images), gridspec_kw={'wspace': 0, 'hspace': 0})

    for ax, image in zip(axs, images):
        ax.imshow(image)
        ax.axis('off')