In [34]:
import glob
from typing import Tuple

import pandas as pd

In [50]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [35]:
import cv2

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [52]:
from unet import UNet

In [45]:
class ImageDatasetForSegmentation(Dataset):
    """Dataset of (image, mask) pairs for binary segmentation.

    For every value in ``data['Patient']`` the constructor collects all files
    matching ``{path}/{patient}*/*_mask.tif`` as masks and derives the paired
    image path by dropping the trailing ``_mask`` from the mask file name.

    Args:
        data: Frame with a ``Patient`` column; each value is used as a
            directory-name prefix under ``path``.
        path: Root directory that holds the per-patient folders.
        is_train: When True, random horizontal/vertical flips and a rotation
            (limit 45) are applied before tensor conversion. When False, only
            tensor conversion is applied.
    """

    # Suffix that distinguishes a mask file from its image file.
    _MASK_SUFFIX = '_mask.tif'

    def __init__(self, data: pd.DataFrame, path: str, is_train: bool = True) -> None:
        super().__init__()

        if is_train:
            self.transform = A.Compose([
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.Rotate(45, border_mode=cv2.BORDER_CONSTANT, value=0),
                ToTensorV2()
            ])
        else:
            # Wrapped in Compose so train and eval share one call interface
            # (the documented way to invoke albumentations transforms).
            self.transform = A.Compose([ToTensorV2()])

        self.images, self.masks = [], []

        for patient in data['Patient']:
            # glob order is filesystem-dependent; sort so that indices are
            # reproducible across machines and runs.
            for mask_path in sorted(glob.glob(f'{path}/{patient}*/*_mask.tif')):
                self.masks.append(mask_path)
                self.images.append(self._image_path_for(mask_path))

    @staticmethod
    def _image_path_for(mask_path: str) -> str:
        """Return the image path paired with ``mask_path``.

        Only the trailing ``_mask.tif`` is rewritten, so an occurrence of
        ``_mask`` elsewhere in the path (e.g. in a directory name) is kept.
        """
        suffix = ImageDatasetForSegmentation._MASK_SUFFIX
        if mask_path.endswith(suffix):
            return mask_path[:-len(suffix)] + '.tif'
        return mask_path

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load, transform and return the ``index``-th (image, mask) pair as float tensors.

        Raises:
            FileNotFoundError: If the image or the mask file cannot be read.
        """
        image_path, mask_path = self.images[index], self.masks[index]

        # cv2.imread signals failure by returning None instead of raising.
        image_bgr = cv2.imread(image_path)
        if image_bgr is None:
            raise FileNotFoundError(f'Could not read image: {image_path}')
        mask_raw = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask_raw is None:
            raise FileNotFoundError(f'Could not read mask: {mask_path}')

        image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
        # Scale the 8-bit mask by 1/255 (maps 0/255 masks to 0/1).
        mask = mask_raw / 255
        # NOTE(review): the image itself is not rescaled in this class, so it
        # keeps its 0-255 value range after the float cast — confirm the model
        # expects that.
        transformed = self.transform(image=image, mask=mask)
        return transformed['image'].float(), transformed['mask'].float()

    def __len__(self) -> int:
        """Return the number of (image, mask) pairs."""
        return len(self.images)

In [46]:
# Per-split tables, one CSV per split.
train_data = pd.read_csv(filepath_or_buffer='data/train.csv')
val_data = pd.read_csv(filepath_or_buffer='data/val.csv')
test_data = pd.read_csv(filepath_or_buffer='data/test.csv')

In [47]:
# Only the training split gets augmentation; validation and test are built
# with is_train=False.
train_dataset = ImageDatasetForSegmentation(data=train_data, path='data', is_train=True)
val_dataset = ImageDatasetForSegmentation(data=val_data, path='data', is_train=False)
test_dataset = ImageDatasetForSegmentation(data=test_data, path='data', is_train=False)

In [51]:
# Shuffle only the training loader. Validation and test keep a fixed order so
# that evaluation batches (including the smaller final batch) are identical
# from run to run.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

In [None]:
# DEVICE is not defined in any earlier cell, so define it here to keep
# Restart & Run All working; falls back to CPU when CUDA is unavailable.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = UNet().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]. This presumes UNet
# ends with a sigmoid — confirm; if it returns raw logits, use
# nn.BCEWithLogitsLoss instead.
loss_func = nn.BCELoss()