In [1]:
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
from pathlib import Path


In [2]:

class AugmentedGlacierDataset(Dataset):
    """Five-band glacier imagery paired with a four-class segmentation mask.

    ``base_dir`` is expected to contain ``Band1`` … ``Band5`` and ``labels``
    sub-directories, each holding one ``<id>.tif`` per sample. Sample ids are
    the stems of the ``*.tif`` files found in ``Band1``, sorted.

    Args:
        base_dir: Root directory of the split (str or path-like).
        transform: Optional callable invoked as
            ``transform(image=<H, W, 5 float32>, mask=<H, W uint8>)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries
            (the Albumentations calling convention).
    """

    # Raw label value -> class index. Any other raw value stays class 0.
    _LABEL_TO_CLASS = {85: 1, 170: 2, 255: 3}

    def __init__(self, base_dir, transform=None):
        self.base_dir = Path(base_dir)
        self.band_dirs = [self.base_dir / f"Band{i}" for i in range(1, 6)]
        self.label_dir = self.base_dir / "labels"
        self.ids = sorted(f.stem for f in self.band_dirs[0].glob("*.tif"))
        self.transform = transform

    def __len__(self):
        """Return the number of sample ids discovered in ``Band1``."""
        return len(self.ids)

    def _imread(self, path):
        """Read ``path`` with OpenCV, keeping the file's own depth/channels."""
        return cv2.imread(str(path), cv2.IMREAD_UNCHANGED)

    def _load(self, path):
        """Read a raster and fail loudly when it cannot be decoded.

        Raises:
            FileNotFoundError: if the reader returned ``None`` (OpenCV's way
                of signalling a missing or unreadable file).
        """
        raster = self._imread(path)
        if raster is None:
            raise FileNotFoundError(f"Could not read raster: {path}")
        return raster

    def __getitem__(self, idx):
        """Load, normalise and (optionally) augment one sample.

        Returns:
            ``(image, mask)``. Without a transform: a float32 tensor of shape
            (5, H, W) and an int64 tensor of shape (H, W). With a transform:
            whatever the transform returned for the image, and its mask cast
            to int64 (tensor or ndarray, matching what the transform gave).

        Raises:
            FileNotFoundError: if any band or the label file cannot be read.
        """
        img_id = self.ids[idx]
        tif_name = f"{img_id}.tif"

        # Load: one raster per band, stacked channel-last -> (H, W, 5).
        bands = [
            self._load(band_dir / tif_name).astype(np.float32)
            for band_dir in self.band_dirs
        ]
        image = np.stack(bands, axis=-1)
        mask = self._load(self.label_dir / tif_name)

        # Normalize: clip to the 2nd-98th percentile (computed over all bands
        # jointly), then min-max scale; the epsilon guards constant images.
        p02, p98 = np.percentile(image, 2), np.percentile(image, 98)
        image = np.clip(image, p02, p98)
        image = (image - image.min()) / (image.max() - image.min() + 1e-6)
        # Pin the dtype so the transform path matches the `.float()` below.
        image = image.astype(np.float32, copy=False)

        # Map Labels (0, 85, 170, 255 -> 0, 1, 2, 3).
        # Kept as uint8 while augmenting: the class ids fit, and OpenCV-backed
        # geometric transforms are not expected to accept int64 arrays
        # -- NOTE(review): confirm against the installed albumentations/OpenCV.
        mask_mapped = np.zeros(mask.shape, dtype=np.uint8)
        for raw_value, class_id in self._LABEL_TO_CLASS.items():
            mask_mapped[mask == raw_value] = class_id

        if self.transform:
            # Albumentations expects numpy (H, W, C)
            augmented = self.transform(image=image, mask=mask_mapped)
            image = augmented['image']
            mask_mapped = augmented['mask']
            # Restore the int64 class-index dtype callers received before.
            if isinstance(mask_mapped, torch.Tensor):
                mask_mapped = mask_mapped.long()
            else:
                mask_mapped = np.asarray(mask_mapped).astype(np.int64)
        else:
            # Manual conversion if no transform provided: HWC -> CHW.
            image = torch.from_numpy(image.transpose(2, 0, 1)).float()
            mask_mapped = torch.from_numpy(mask_mapped).long()

        return image, mask_mapped

# --- Define Transforms ---
train_transform = A.Compose([
    # Flips, quarter-turn rotations and transposition, each applied
    # independently with probability 0.5.
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.Transpose(p=0.5),
    # Elastic Transform is great for "melting" shapes like glaciers.
    # 30% of samples get exactly one of the two distortions below.
    A.OneOf([
        # `alpha_affine` is no longer passed: the installed Albumentations
        # warns on that argument (see the warning echoed in the cell output).
        # NOTE(review): 50 appears to have been the default in releases that
        # accepted it, so dropping it should not change behaviour there --
        # confirm against the pinned version.
        A.ElasticTransform(alpha=1, sigma=50, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
    ], p=0.3),
    ToTensorV2(),  # Converts to PyTorch Tensor (C, H, W)
])

if __name__ == "__main__":
    # Smoke test: build the dataset, pull one augmented batch, report its shape.
    import os

    # The training directory can be overridden without editing the notebook;
    # the default is the original local path.
    train_dir = os.environ.get("GLACIER_TRAIN_DIR", r"D:\GlacierHack_practice\Train")

    ds = AugmentedGlacierDataset(train_dir, transform=train_transform)
    if len(ds) == 0:
        # Fail with a readable message instead of an obscure sampler/iterator
        # error when the directory is missing or holds no Band1/*.tif files.
        raise FileNotFoundError(f"No samples found under: {train_dir}")
    loader = DataLoader(ds, batch_size=2, shuffle=True)

    imgs, masks = next(iter(loader))
    print("✅ Augmentation pipeline ready.")
    print(f"Output Tensor Shape: {imgs.shape}")

  A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),


✅ Augmentation pipeline ready.
Output Tensor Shape: torch.Size([2, 5, 512, 512])
