<a href="https://colab.research.google.com/github/SolutionLr/DL/blob/main/Unet_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# import necessary modules 
from google.colab import drive
drive.mount('/content/drive')
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import torchvision
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
import torch
import cv2
from glob import glob
torch.manual_seed(0)
np.random.seed(0)

# Colour map loaded from class_dict.csv: class name -> [r, g, b] colour used
# for that class in the RGB mask images. Dict insertion order follows the CSV
# rows; process_mask() by default iterates these values in that order and uses
# the position as the class index, so CSV row k corresponds to class k.
_class_table = pd.read_csv('/content/drive/MyDrive/DL_1/U-Net/unet_dataset/class_dict.csv')
mask_data = {
    name: [r, g, b]
    for name, r, g, b in zip(
        _class_table['name'], _class_table['r'], _class_table['g'], _class_table['b']
    )
}
del _class_table  # only needed to build mask_data


class Data(Dataset):
    """Segmentation dataset pairing RGB images with colour-coded RGB masks.

    Images and masks are matched by position after sorting the file names in
    each directory, so an image and its mask must sort to the same index
    (e.g. identical or consistently-suffixed file names).

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the colour-coded mask images.
        transform: ``None``, or a mapping with ``'train'`` and ``'val'``
            callables invoked as ``transform[key](image=array)`` that return a
            dict with an ``'image'`` entry.
    """

    def __init__(self, img_dir, mask_dir, transform):
        super().__init__()
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        # os.listdir() returns entries in arbitrary, filesystem-dependent
        # order; sort both listings so index i refers to matching files.
        self.images = sorted(os.listdir(img_dir))
        self.masks = sorted(os.listdir(mask_dir))
        self.transform = transform

    def __len__(self):
        """Return the number of files found in ``img_dir``."""
        return len(self.images)

    @staticmethod
    def _read_rgb(path):
        """Read the image at *path* with OpenCV and return it in RGB order.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
                returns ``None`` rather than raising).
        """
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is None:
            raise FileNotFoundError(f"could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, index):
        """Return the ``(image, mask)`` pair at *index*.

        Without a transform, both are the raw RGB arrays read from disk.
        With a transform, the image is the output of ``transform['train']``
        cast to ``torch.float32``, and the mask is the output of
        ``transform['val']`` converted to class indices by ``process_mask``
        and returned as a ``torch.int64`` tensor.
        """
        image = self._read_rgb(os.path.join(self.img_dir, self.images[index]))
        mask = self._read_rgb(os.path.join(self.mask_dir, self.masks[index]))
        if self.transform is not None:
            # NOTE(review): image and mask go through *different* pipelines
            # ('train' vs 'val'). If 'train' contains random spatial
            # augmentation, the mask does not receive the same transform and
            # the pair becomes misaligned; a non-nearest resize of the RGB mask
            # can also create colours that match no class (they become class
            # 0 in process_mask). Confirm against the transform definitions.
            image = self.transform['train'](image=image)['image']
            mask = self.transform['val'](image=mask)['image']
            # The 'val' output must be a NumPy array (torch.from_numpy below),
            # while the 'train' output must support .to() (presumably a tensor).
            mask = torch.from_numpy(process_mask(mask))
            image = image.to(torch.float32)
            mask = mask.to(torch.int64)
        return image, mask

    def load_data(self, batch_size, shuffle=True):
        """Wrap this dataset in a ``torch.utils.data.DataLoader``.

        Args:
            batch_size: Number of samples per batch.
            shuffle: Reshuffle the samples every epoch (default ``True``, as
                before); pass ``False`` e.g. for validation/test loaders.

        Returns:
            The configured ``DataLoader``.
        """
        return torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=shuffle)


def process_mask(rgb_mask, colormap=None):
    """Convert a colour-coded RGB mask into an array of class indices.

    Args:
        rgb_mask: Array whose last axis holds the colour channels
            (e.g. an ``(H, W, 3)`` RGB image).
        colormap: Iterable of colours, one per class, in class-index order.
            Defaults to the values of the module-level ``mask_data`` dict,
            looked up at call time.

    Returns:
        Integer array of shape ``rgb_mask.shape[:-1]`` where each pixel holds
        the index of the colour it matches. Pixels matching no colour get 0,
        because ``argmax`` over an all-False row returns the first index.

    Raises:
        ValueError: if ``colormap`` is empty.
    """
    if colormap is None:
        # Resolve lazily: a default of ``mask_data.values()`` in the signature
        # is bound once at definition time, so rebuilding mask_data later
        # (e.g. re-running a notebook cell) would leave this function on the
        # stale dict.
        colormap = mask_data.values()

    # One boolean plane per class: True where every channel equals the colour.
    class_planes = [np.all(np.equal(rgb_mask, color), axis=-1) for color in colormap]
    if not class_planes:
        raise ValueError("colormap must contain at least one colour")
    return np.argmax(np.stack(class_planes, axis=-1), axis=-1)

