#### Various Custom Image Datasets

In [None]:
########## Normal Image Dataset ##########
class ImageDataset(Dataset):
    """Image dataset that reads file names and labels from a DataFrame.

    Each row of ``df`` must provide a ``'filename'`` column (path relative to
    ``dir_path``) and a ``'genus'`` column, which is returned as the label.

    Args:
        df: DataFrame with ``'filename'`` and ``'genus'`` columns.
        transform: Sequence of transforms. ``transform[0]`` is applied to every
            image by default. ``transform[1]`` (if present) is applied to images
            whose label is in ``special_label``. ``None`` or an empty sequence
            means images are returned untransformed (as RGB PIL images).
        dir_path: Directory that the ``'filename'`` entries are relative to.
        special_label: Optional label, or collection of labels, that should
            receive ``transform[1]``, e.g. minority classes that need extra
            augmentation. ``None`` disables the special handling.
        mode: Stored as ``self.mode``; not used by this class.
    """

    def __init__(self, df, transform, dir_path, special_label=None, mode="train"):
        # drop=True: don't carry the old index along as an extra 'index' column.
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.dir_path = dir_path
        # Store labels as a set. `label in series` would test the Series *index*,
        # and `label in "abc"` would do a substring test.
        if special_label is None:
            self.special_label = None
        elif isinstance(special_label, str):
            self.special_label = {special_label}
        else:
            self.special_label = set(special_label)
        self.mode = mode

    def __len__(self):
        return len(self.df)

    def _pick_transform(self, label):
        """Return the transform to apply to an image with ``label``, or ``None``."""
        if not self.transform:
            return None
        is_special = self.special_label is not None and label in self.special_label
        # Special labels get the second transform; fall back to the first one
        # if only one transform was supplied.
        if is_special and len(self.transform) > 1:
            return self.transform[1]
        return self.transform[0]

    def __getitem__(self, idx):
        img_path = self.df.loc[idx, 'filename']
        label = self.df.loc[idx, 'genus']
        # `with` closes the file handle right away instead of leaving it to the GC.
        # convert() returns a new image, so it stays valid after the file closes.
        with Image.open(os.path.join(self.dir_path, img_path)) as img:
            image = img.convert("RGB")
        transform = self._pick_transform(label)
        if transform is not None:
            image = transform(image)
        return image, label

In [None]:
########## TIFF Image Dataset ##########
class TifImageDataset(Dataset):
    def __init__(self, df, img_dir, label_cols=None, transform=None, to_rgb=True, max_channels=3):
        """
        Custom PyTorch dataset to handle .tif/.tiff and standard images.

        Args:
            df (pd.DataFrame): dataframe with a 'name' column holding file names
                (relative to ``img_dir``) and, optionally, label columns.
            img_dir (str): path to the folder containing the images.
            label_cols (list or str): label column(s) in df. If None, only the
                image is returned.
            transform (Compose): torchvision augmentations/transforms.
            to_rgb (bool): if True, convert every image to 3 channels (RGB);
                if False, keep the channel layout as loaded.
            max_channels (int): for multi-band TIFFs, keep only the first n channels.
        """
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.label_cols = label_cols
        self.transform = transform
        self.to_rgb = to_rgb
        self.max_channels = max_channels

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _to_uint8(img):
        """Scale a non-uint8 array to uint8 by dividing by its maximum value.

        NaNs are treated as 0 and negative values clip to 0. An array whose
        maximum is <= 0 (e.g. all zeros) maps to zeros instead of NaN garbage.
        """
        img = np.nan_to_num(np.asarray(img, dtype=np.float64), nan=0.0)
        peak = img.max() if img.size else 0.0
        if peak > 0:
            img = img / peak
        img = np.clip(img, 0, 1)
        return (img * 255).astype(np.uint8)

    def _prepare_array(self, img):
        """Turn a raw TIFF array into a uint8 array that PIL can wrap."""
        img = np.asarray(img)
        # Multi-band: keep only the first `max_channels` bands.
        if img.ndim == 3 and img.shape[2] > self.max_channels:
            img = img[:, :, :self.max_channels]
        # (H, W, 1) -> (H, W): PIL cannot build an image from one trailing channel.
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]
        if img.dtype != np.uint8:
            img = self._to_uint8(img)
        # Grayscale -> three identical channels when RGB output is requested.
        if img.ndim == 2 and self.to_rgb:
            img = np.stack([img] * 3, axis=-1)
        return img

    def _load_tif_image(self, path):
        """Load an image from TIF/TIFF with tifffile and return it as a PIL image."""
        image = Image.fromarray(self._prepare_array(tifffile.imread(path)))
        # Covers 2-band (LA) and 4-band (RGBA) results, which are not RGB yet.
        if self.to_rgb and image.mode != "RGB":
            image = image.convert("RGB")
        return image

    def _labels_for(self, idx):
        """Return the label value(s) of row ``idx`` as a float32 tensor."""
        values = np.asarray(self.df.loc[idx, self.label_cols], dtype=np.float32)
        return torch.from_numpy(values)

    def __getitem__(self, idx):
        img_name = self.df.loc[idx, 'name']
        img_path = os.path.join(self.img_dir, img_name)
        ext = os.path.splitext(img_name)[1].lower()

        # ========== Handle multiple formats (.tif, .png, .jpg) ==========
        if ext in ('.tif', '.tiff'):
            image = self._load_tif_image(img_path)
        else:
            # `with` closes the file handle instead of leaving it to the GC;
            # convert()/copy() load the pixels so the image outlives the file.
            with Image.open(img_path) as img:
                image = img.convert("RGB") if self.to_rgb else img.copy()
        # ================================================================

        if self.transform:
            image = self.transform(image)

        if self.label_cols is not None:
            return image, self._labels_for(idx)

        return image

#### Dataset Usage

In [None]:
train_data = ImageDataset(
    train_set,
    transform=[train_transform, minor_transform],
    dir_path=train_img,
    special_label=minor_class,
)
# NOTE(review): with no special_label, every validation image gets transform[0]
# (`train_transform`). If that transform includes random augmentation, pass a
# deterministic eval transform here instead.
val_data = ImageDataset(test_set, transform=[train_transform, minor_transform], dir_path=train_img)

#### DataLoader

# The training dataset is `train_data`; `train_dataset` was never defined (NameError).
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# Validation metrics don't depend on order, so keep iteration deterministic.
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)