In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pathlib import Path

In [2]:

class GlacierDataset(Dataset):
    """Five-band raster tiles paired with integer class masks.

    ``base_dir`` is expected to contain the sub-directories ``Band1`` ...
    ``Band5`` and ``labels``, each holding one ``<id>.tif`` per sample.
    Sample ids are the stems of the ``*.tif`` files found in ``Band1``,
    sorted alphabetically.

    Each item is a tuple ``(image, mask)``:

    * ``image`` -- float32 tensor with the five bands stacked channels-first.
      Values are clipped to the 2nd-98th percentile (computed jointly over
      all bands of the tile) and min-max scaled into [0, 1).
    * ``mask`` -- int64 tensor with the same shape as the label raster. Raw
      label values 85 / 170 / 255 become classes 1 / 2 / 3; every other
      value becomes class 0.
    """

    # Pairs of (raw pixel value in the label raster, class index).
    _LABEL_VALUE_TO_CLASS = ((85, 1), (170, 2), (255, 3))

    def __init__(self, base_dir):
        """Index the samples available under ``base_dir``.

        Args:
            base_dir: Directory (str or path-like) holding ``Band1``..``Band5``
                and ``labels``.
        """
        self.base_dir = Path(base_dir)
        # Define paths for all 5 bands
        self.band_dirs = [self.base_dir / f"Band{i}" for i in range(1, 6)]
        self.label_dir = self.base_dir / "labels"

        # Only Band1 is scanned; the other bands and the label are assumed to
        # exist under the same id and are checked lazily in __getitem__.
        self.ids = sorted(f.stem for f in self.band_dirs[0].glob("*.tif"))

    def __len__(self):
        """Return the number of sample ids discovered in ``Band1``."""
        return len(self.ids)

    @staticmethod
    def _read_tif(path):
        """Read one raster unchanged, raising if OpenCV could not load it.

        ``cv2.imread`` signals failure by returning ``None`` rather than
        raising, so the check is made here to produce an error that names
        the offending file.

        Raises:
            FileNotFoundError: The file is missing or could not be decoded.
        """
        raster = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
        if raster is None:
            raise FileNotFoundError(f"Could not read raster file: {path}")
        return raster

    def __getitem__(self, idx):
        """Load, normalise and return the ``(image, mask)`` pair at ``idx``.

        Raises:
            FileNotFoundError: A band or label file for this id is missing
                or unreadable.
        """
        img_id = self.ids[idx]

        # 1. Load bands as float32 and stack them channels-last: (H, W, 5)
        bands = [
            self._read_tif(b_dir / f"{img_id}.tif").astype(np.float32)
            for b_dir in self.band_dirs
        ]
        image = np.stack(bands, axis=-1)

        # 2. Load label
        mask = self._read_tif(self.label_dir / f"{img_id}.tif")

        # 3. Normalisation: percentile clipping to suppress outliers, then
        #    min-max scaling. The epsilon guards against a constant tile.
        p02 = np.percentile(image, 2)
        p98 = np.percentile(image, 98)
        image = np.clip(image, p02, p98)
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)

        # 4. Convert to tensor, channels first: (H, W, C) -> (C, H, W)
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()

        # Build the class map in NumPy as int64 (what CrossEntropy expects)
        # so the raw label dtype never has to be converted by torch itself.
        classes = np.zeros(mask.shape, dtype=np.int64)
        for raw_value, class_index in self._LABEL_VALUE_TO_CLASS:
            classes[mask == raw_value] = class_index
        mask_tensor = torch.from_numpy(classes)

        return image, mask_tensor

# --- Quick sanity check of GlacierDataset ---
if __name__ == "__main__":
    ds = GlacierDataset(r"D:\GlacierHack_practice\Train")
    loader = DataLoader(ds, shuffle=True, batch_size=4)

    # Draw a single shuffled mini-batch and report the tensor shapes.
    imgs, masks = next(iter(loader))

    # Images should come out as [4, 5, 512, 512], masks as [4, 512, 512].
    print(
        f"Batch Image Shape: {imgs.shape}",
        f"Batch Mask Shape:  {masks.shape}",
        sep="\n",
    )

Batch Image Shape: torch.Size([4, 5, 512, 512])
Batch Mask Shape:  torch.Size([4, 512, 512])
