In [None]:
import os
import win32com.client
def resolve_shortcut(path):
    """Return the target path stored in a Windows ``.lnk`` shortcut file.

    Args:
        path: filesystem path of an existing ``.lnk`` file.

    Returns:
        The shortcut's ``TargetPath`` string.

    Raises:
        FileNotFoundError: ``path`` does not exist, or the shortcut has an empty target.
    """
    # WScript.Shell.CreateShortcut opens an existing shortcut OR creates a new, empty one in
    # memory, so a missing .lnk would otherwise silently resolve to "" instead of failing.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Shortcut not found: {path}")
    shell = win32com.client.Dispatch("WScript.Shell")
    target = shell.CreateShortCut(path).Targetpath
    if not target:
        raise FileNotFoundError(f"Shortcut has no target path: {path}")
    return target

# Resolve the Google Drive shortcuts (.lnk files) to the real folder paths.
# os.path.join avoids the doubled separators produced by raw strings like r"\\name.lnk".
data_path = resolve_shortcut(r"G:\My Drive\dc4data.lnk")
benthic_path = resolve_shortcut(os.path.join(data_path, "benthic_datasets.lnk"))
coralbleaching_path = resolve_shortcut(os.path.join(data_path, "coral_bleaching.lnk"))

# Coralscapes may be a real sub-folder of the data folder; otherwise follow its shortcut.
coralscapes_path = os.path.join(data_path, "coralscapes")
if not os.path.exists(coralscapes_path):
    coralscapes_path = resolve_shortcut(os.path.join(data_path, "coralscapes.lnk"))

# Fail fast if any resolved folder is missing.
for p in [data_path, benthic_path, coralbleaching_path, coralscapes_path]:
    if not os.path.exists(p):
        raise FileNotFoundError(f"Path does not exist: {p}")
    print(f"Path exists: {p}")
    

Path exists: G:\.shortcut-targets-by-id\1v4g4qOrbisBvrpqOxLrYn96nd_gPG_Ge\dc4data
Path exists: G:\.shortcut-targets-by-id\1mx2OJcVKp1mRbTbjezqWucDXpbGrd_OA\benthic_datasets
Path exists: G:\.shortcut-targets-by-id\1jGkNA1n0znoxKnQBHTJZuPgvkiu_OBM8\coral_bleaching
Path exists: G:\.shortcut-targets-by-id\1v4g4qOrbisBvrpqOxLrYn96nd_gPG_Ge\dc4data\coralscapes


## Benthic Dataset

Eight reef-support subsets, each loaded as a paired image/mask segmentation dataset.

In [6]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import os
from datasets import load_dataset, Dataset as HFDataset
from torch.utils.data import Dataset
from PIL import Image
import pyarrow.parquet as pq
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, Union, Sequence, List

# Benthic reef-support subsets. They live under the `benthic_path` shortcut target resolved in
# the first cell, so the Drive shortcut-id path is not repeated here. Order matters: the
# dataset cell below indexes this list positionally.
BENTHIC_SUBSETS = [
    "SEAFLOWER_BOLIVAR",
    "SEAFLOWER_COURTOWN",
    "SEAVIEW_PAC_USA",
    "SEAVIEW_IDN_PHL",
    "SEAVIEW_PAC_AUS",
    "TETES_PROVIDENCIA",
    "SEAVIEW_ATL",
    "UNAL_BLEACHING_TAYRONA",
]
benthic_paths = [
    os.path.join(benthic_path, "mask_labels", "reef_support", subset)
    for subset in BENTHIC_SUBSETS
]

class SegmentationDataset(Dataset):
    """Image/mask pairs read from two folders for semantic segmentation.

    Every file in ``img_dir`` whose extension is one of ``.jpg .jpeg .png .bmp .tif .tiff``
    (case-insensitive) is an image. Its mask is ``<image_stem>_mask.png`` in ``mask_dir``.
    Images are sorted by file name, so indices are reproducible.

    Args:
        img_dir: folder containing the input images.
        mask_dir: folder containing the ``*_mask.png`` files.
        transform: optional callable applied to the RGB image. When ``mask_transform`` is
            not given, it is also applied to the mask (the original behaviour).
        mask_transform: optional callable applied to the mask *instead of* ``transform``.
            Use it when the image pipeline must not touch label values (e.g. normalization).

    ``__getitem__`` returns ``(image, mask)``: an RGB PIL image and the mask in its on-disk
    mode (PIL), unless the transforms convert them.

    Raises:
        FileNotFoundError: ``img_dir`` has no image files (at construction), or an image
            has no matching mask (in ``__getitem__``).
    """

    def __init__(self, img_dir, mask_dir, transform=None, mask_transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.mask_transform = mask_transform

        # Keep only typical image files; sorted for reproducibility.
        exts = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
        self.images = sorted(f for f in os.listdir(img_dir) if f.lower().endswith(exts))

        if not self.images:
            raise FileNotFoundError(f"No images found in {img_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # Mask naming convention: "name.ext" -> "name_mask.png".
        mask_path = os.path.join(self.mask_dir, f"{Path(img_name).stem}_mask.png")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(
                f"Mask not found for {img_name}. Expected: {mask_path} "
                "(pattern '<image_stem>_mask.png')."
            )

        # Decode eagerly inside `with` so the file handles are released here. Returning the
        # lazily opened mask would keep its file open until the image is garbage-collected.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        with Image.open(mask_path) as mask_file:
            mask = mask_file.copy()  # loaded copy, same mode/palette as on disk

        mask_tf = self.mask_transform if self.mask_transform is not None else self.transform
        if self.transform is not None:
            image = self.transform(image)
        if mask_tf is not None:
            mask = mask_tf(mask)

        return image, mask


def get_mask(benthic_folder):
    """Path of the ``masks_stitched`` sub-folder inside a benthic subset folder."""
    return os.path.join(benthic_folder, "masks_stitched")

def get_image(benthic_folder):
    """Path of the ``images`` sub-folder inside a benthic subset folder."""
    return os.path.join(benthic_folder, "images")


# DATASETS: one SegmentationDataset per benthic subset. Positions follow `benthic_paths`.
def _benthic_dataset(position):
    """Build the dataset for benthic_paths[position] from its images/ and masks_stitched/ folders."""
    folder = benthic_paths[position]
    return SegmentationDataset(get_image(folder), get_mask(folder))


SEAFLOWER_BOLIVAR = _benthic_dataset(0)
SEAFLOWER_COURTOWN = _benthic_dataset(1)
SEAVIEW_PAC_USA = _benthic_dataset(2)
SEAVIEW_IDN_PHL = _benthic_dataset(3)
SEAVIEW_PAC_AUS = _benthic_dataset(4)
TETES_PROVIDENCIA = _benthic_dataset(5)
SEAVIEW_ATL = _benthic_dataset(6)
UNAL_BLEACHING_TAYRONA = _benthic_dataset(7)

## Coralscapes

In [2]:
# --- helper: load masks from parquet parts by dataset index ---
class _ParquetMasksByIndex:
    """Look up PNG-encoded masks stored in one or more Parquet files by dataset index.

    Each table must have an ``index`` column (the sample's position in the image dataset)
    and a column of PNG bytes (``column_png``). All tables are read into memory once, at
    construction.

    Args:
        parquet_dir_or_paths: a directory (all ``*.parquet`` files in it, sorted by name),
            a single Parquet file, or a sequence of Parquet files.
        column_png: name of the column holding the PNG-encoded mask bytes.

    Raises:
        FileNotFoundError: no Parquet files were found/given, or a given file is missing.
        ValueError: a table lacks the required columns, or one dataset index appears in
            more than one row (the index -> mask mapping would be ambiguous).
    """

    def __init__(self, parquet_dir_or_paths: Union[str, Path, Sequence[Union[str, Path]]],
                 column_png: str = "label_health_rgb_png"):
        paths = self._resolve_paths(parquet_dir_or_paths)

        self._tables = [pq.read_table(p) for p in paths]
        for t in self._tables:
            if "index" not in t.column_names or column_png not in t.column_names:
                raise ValueError(f"Parquet must have 'index' and '{column_png}'. Got: {t.column_names}")
        self._colname = column_png

        # Build dataset index -> (table_id, row_id). Duplicates used to be silently
        # overwritten by the last shard; reject them instead.
        self._map = {}
        for tid, t in enumerate(self._tables):
            for rid, ds_idx in enumerate(t["index"].to_pylist()):
                key = int(ds_idx)
                if key in self._map:
                    prev_tid, prev_rid = self._map[key]
                    raise ValueError(
                        f"Duplicate dataset index {key} in parquet masks "
                        f"(table {prev_tid} row {prev_rid} and table {tid} row {rid})."
                    )
                self._map[key] = (tid, rid)

    @staticmethod
    def _resolve_paths(src) -> List[Path]:
        """Normalize a directory / single file / sequence of files into a list of Parquet paths."""
        if isinstance(src, (str, Path)):
            p = Path(src)
            if p.is_dir():
                paths = sorted(p.glob("*.parquet"))
                if not paths:
                    raise FileNotFoundError(f"No parquet files in directory: {p}")
                return paths
            if not p.exists():
                raise FileNotFoundError(f"Parquet file not found: {p}")
            return [p]

        paths = [Path(x) for x in src]
        if not paths:
            raise FileNotFoundError("No parquet files given (empty path sequence).")
        for p in paths:
            if not p.exists():
                raise FileNotFoundError(f"Parquet file not found: {p}")
        return paths

    def get_mask_pil(self, ds_index: int) -> Image.Image:
        """Decode the mask stored for ``ds_index`` and return it as an RGB PIL image.

        Raises:
            KeyError: no row has this dataset index.
            ValueError: the stored mask cell is null.
        """
        try:
            tid, rid = self._map[ds_index]
        except KeyError:
            raise KeyError(f"No mask stored for dataset index {ds_index}") from None
        cell = self._tables[tid][self._colname][rid].as_py()
        if cell is None:
            raise ValueError(f"Mask cell '{self._colname}' is null for dataset index {ds_index}")
        if isinstance(cell, memoryview):
            cell = cell.tobytes()
        elif isinstance(cell, bytearray):
            cell = bytes(cell)
        return Image.open(BytesIO(cell)).convert("RGB")
    
class CoralScapesImagesMasks(Dataset):
    """
    Coralscapes images paired with masks exported to Parquet.

    Images:
      - Either from HF split (set split="train"/"validation"/"test")
      - Or from local Arrow shards (set arrow_paths=[...]); several shards are
        concatenated in the given order.
    Masks:
      - From your Parquet export (dir or list), matched by dataset index: sample ``idx``
        uses the Parquet row whose ``index`` column equals ``idx``.

    Args:
        parquet_dir_or_paths: Parquet directory, file, or list of files holding the masks.
        split: split name of the ``EPFL-ECEO/coralscapes`` Hugging Face dataset.
        arrow_paths: local Arrow files to read the images from instead.
        img_transform: optional callable applied to the RGB image.
        mask_transform: optional callable applied to the RGB mask.

    Raises:
        ValueError: not exactly one of ``split`` / ``arrow_paths`` is given, the split does
            not exist, or ``arrow_paths`` is empty.
        FileNotFoundError: an Arrow file is missing.
    """
    def __init__(self,
                 parquet_dir_or_paths: Union[str, Path, Sequence[Union[str, Path]]],
                 split: Optional[str] = None,
                 arrow_paths: Optional[Sequence[Union[str, Path]]] = None,
                 img_transform: Optional[Callable] = None,
                 mask_transform: Optional[Callable] = None):
        if (split is None) == (arrow_paths is None):
            raise ValueError("Specify exactly one image source: either `split` (HF) OR `arrow_paths` (local).")

        # Image source
        if split is not None:
            # Loads all splits, then selects the requested one.
            ds_all = load_dataset("EPFL-ECEO/coralscapes")
            if split not in ds_all:
                raise ValueError(f"Split '{split}' not found. Available: {list(ds_all.keys())}")
            self.img_ds: HFDataset = ds_all[split]
        else:
            # `concatenate_datasets` is a module-level function of `datasets`, not a method of
            # `Dataset`; `HFDataset.concatenate_datasets(...)` raises AttributeError.
            from datasets import concatenate_datasets

            paths = [Path(p) for p in arrow_paths]
            if not paths:
                raise ValueError("`arrow_paths` must contain at least one Arrow file.")
            for p in paths:
                if not p.exists():
                    raise FileNotFoundError(f"Arrow file not found: {p}")
            shards = [HFDataset.from_file(p.as_posix()) for p in paths]
            self.img_ds = shards[0] if len(shards) == 1 else concatenate_datasets(shards)

        # Masks
        self.masks = _ParquetMasksByIndex(parquet_dir_or_paths)

        self.img_tf = img_transform
        self.mask_tf = mask_transform

    def __len__(self):
        return len(self.img_ds)

    def __getitem__(self, idx: int):
        # Masks are keyed by non-negative dataset index, so map negative indices to
        # positions first to keep the image and mask lookups consistent.
        if idx < 0:
            idx += len(self.img_ds)
        rec = self.img_ds[idx]
        img: Image.Image = rec["image"].convert("RGB")
        mask: Image.Image = self.masks.get_mask_pil(idx)  # keyed by dataset index

        if self.img_tf is not None:
            img = self.img_tf(img)
        if self.mask_tf is not None:
            mask = self.mask_tf(mask)
        return img, mask


# --- Coralscapes datasets: HF images + per-split Parquet mask exports ---
# Export folders are relative to the notebook's working directory. os.path.join keeps the
# separator correct on any OS (the cells below also use portable separators).
_CORALSCAPES_PARQUET_ROOT = os.path.join("data_preprocessing", "coralscapes_export", "parquet")
TRAIN_PARQUET_DIR = os.path.join(_CORALSCAPES_PARQUET_ROOT, "train")
VAL_PARQUET_DIR   = os.path.join(_CORALSCAPES_PARQUET_ROOT, "validation")
TEST_PARQUET_DIR  = os.path.join(_CORALSCAPES_PARQUET_ROOT, "test")

CORALSCAPES_train = CoralScapesImagesMasks(split="train", parquet_dir_or_paths=TRAIN_PARQUET_DIR)
CORALSCAPES_val   = CoralScapesImagesMasks(split="validation", parquet_dir_or_paths=VAL_PARQUET_DIR)
CORALSCAPES_test  = CoralScapesImagesMasks(split="test", parquet_dir_or_paths=TEST_PARQUET_DIR)


## Coral Bleaching

In [3]:
# Coral-bleaching inputs. Images live under the `coralbleaching_path` shortcut target resolved
# in the first cell (instead of a hard-coded Drive shortcut id); masks come from the local
# preprocessing output.
coral_bleaching_images = os.path.join(coralbleaching_path, "reef_support", "UNAL_BLEACHING_TAYRONA", "images")
coral_bleaching_combined_masks = os.path.join("data_preprocessing", "coralbleaching", "combined_masks")
coral_bleaching_single_masks = os.path.join("data_preprocessing", "coralbleaching", "single_masks")

In [4]:
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

def pil_to_tensor(img):
    """Convert a PIL image to a float32 tensor of shape (3, H, W) with values in [0, 1].

    The image is converted to RGB first, so every input yields exactly 3 channels.
    """
    # np.array (not np.asarray) makes a writable copy: np.asarray on a PIL image returns a
    # read-only array, and torch.from_numpy then emits a "non-writable NumPy array" UserWarning.
    a = np.array(img.convert("RGB"), dtype=np.uint8)  # (H, W, 3)
    return torch.from_numpy(a).permute(2, 0, 1).float() / 255.0  # (3, H, W)

class CoralBleachingDataset(Dataset):
    """Coral-bleaching images, each paired with one mask file.

    Mask lookup per image (keys are lower-cased file stems; extensions png/jpg/jpeg):
      1. ``<stem>_combined`` in ``combined_dir``;
      2. otherwise a mask in ``single_dir/bleached_blue`` whose stem contains the image stem;
      3. otherwise such a mask in ``single_dir/non_bleached_red``.
    In steps 2-3, stems that start with the image stem followed by a non-alphanumeric
    character (e.g. ``img1_mask`` for ``img1``) win over looser substring matches
    (``img10_mask``). Ties are broken alphabetically so the pairing is deterministic.
    Images without any mask are skipped.

    ``__getitem__`` returns ``(image, mask)`` as converted by ``pil_to_tensor``.
    """

    _EXTS = ("*.png", "*.jpg", "*.jpeg")

    def __init__(self, images_dir, combined_dir, single_dir):
        self.images_dir = Path(images_dir)
        self.combined_dir = Path(combined_dir)
        self.single_bleached = Path(single_dir) / "bleached_blue"
        self.single_non = Path(single_dir) / "non_bleached_red"

        self.images = sorted(p for ext in self._EXTS for p in self.images_dir.glob(ext))
        self.pairs = self._match_pairs()

    def _index_dir(self, folder):
        """Map lower-cased stem -> path for every mask file in `folder` (missing folder -> {})."""
        return {p.stem.lower(): p for ext in self._EXTS for p in folder.glob(ext)}

    @staticmethod
    def _pick_candidate(key, index):
        """Return the best mask path in `index` for image stem `key`, or None if none contains it."""
        matches = sorted(k for k in index if key in k)
        if not matches:
            return None
        # Prefer "<key>" or "<key><separator>..." so "img1" does not grab "img10_mask".
        for k in matches:
            if k.startswith(key) and (len(k) == len(key) or not k[len(key)].isalnum()):
                return index[k]
        return index[matches[0]]

    def _match_pairs(self):
        combined = self._index_dir(self.combined_dir)
        bleached = self._index_dir(self.single_bleached)
        non_bleached = self._index_dir(self.single_non)

        pairs = []
        for img in self.images:
            key = img.stem.lower()
            mask = combined.get(f"{key}_combined")
            if mask is None:
                mask = self._pick_candidate(key, bleached)
            if mask is None:
                mask = self._pick_candidate(key, non_bleached)
            if mask is not None:
                pairs.append((img, mask))
        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        ip, mp = self.pairs[i]
        # Convert inside `with` so both files are closed as soon as they are decoded.
        with Image.open(ip) as img_file:
            x = pil_to_tensor(img_file)
        with Image.open(mp) as mask_file:
            y = pil_to_tensor(mask_file)
        return x, y  # (3,H,W), (3,H,W)

def pad_collate(batch):
    """Zero-pad a batch of ``(image, mask)`` tensors to one spatial size and stack them.

    Args:
        batch: non-empty sequence of ``(img, mask)`` pairs, each a ``(C, H, W)`` tensor.
            Heights/widths may differ between samples, and the mask channel count may
            differ from the image's.

    Returns:
        ``(xb, yb)`` of shapes ``(B, C_img, H_max, W_max)`` and ``(B, C_mask, H_max, W_max)``.
        ``H_max``/``W_max`` are taken over all images AND masks, so both stay aligned.
        Each tensor is placed at the top-left; the bottom/right padding is zero.

    Raises:
        ValueError: ``batch`` is empty.
    """
    if not batch:
        raise ValueError("pad_collate received an empty batch")
    imgs, masks = zip(*batch)
    H = max(t.shape[1] for t in (*imgs, *masks))
    W = max(t.shape[2] for t in (*imgs, *masks))
    # new_zeros keeps each side's dtype/device; masks get their own channel count instead of
    # being forced (and silently broadcast) to the image's.
    xb = imgs[0].new_zeros((len(imgs), imgs[0].shape[0], H, W))
    yb = masks[0].new_zeros((len(masks), masks[0].shape[0], H, W))
    for i, (x, y) in enumerate(zip(imgs, masks)):
        xb[i, :, : x.shape[1], : x.shape[2]] = x
        # Pad the mask with its OWN size; using the image's h/w crashed on mismatched masks.
        yb[i, :, : y.shape[1], : y.shape[2]] = y
    return xb, yb

# ---- use it ----
# Reuse the path constants from the previous cell instead of repeating the literals.
dataset = CoralBleachingDataset(
    images_dir=coral_bleaching_images,
    combined_dir=coral_bleaching_combined_masks,
    single_dir=coral_bleaching_single_masks,
)
if len(dataset) == 0:
    raise RuntimeError(
        f"No image/mask pairs found for images in {coral_bleaching_images}; "
        "check the mask folders and naming."
    )
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0, collate_fn=pad_collate)

# Smoke test: pull one padded batch and show its shapes.
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # -> (B,3,H_max,W_max) (B,3,H_max,W_max)


  return torch.from_numpy(a).permute(2,0,1).float()/255.0  # (3,H,W)


torch.Size([8, 3, 3221, 4296]) torch.Size([8, 3, 3221, 4296])


# Combine ALL into ONE

In [7]:
# --- 1) Common PIL->tensor transform for both image and mask ---
import numpy as np
import torch
from PIL import Image

def pil_to_tensor_rgb(img):
    """Convert a PIL image to a float32 tensor of shape (3, H, W) with values in [0, 1].

    Tensors are returned unchanged, so the function is safe to apply twice.

    Raises:
        ValueError: ``img`` is None (usually a dataset ``__getitem__`` bug).
    """
    if img is None:
        raise ValueError("Received None instead of a PIL.Image. Check your dataset/__getitem__.")
    if isinstance(img, torch.Tensor):
        return img  # already a tensor
    # np.array makes a writable copy; np.asarray on a PIL image is read-only, and
    # torch.from_numpy would then emit a "non-writable NumPy array" UserWarning.
    a = np.array(img.convert("RGB"), dtype=np.uint8)  # (H, W, 3)
    return torch.from_numpy(a).permute(2, 0, 1).float() / 255.0  # (3, H, W)


class ToTensorPair:
    """Callable that converts an ``(image, mask)`` pair with ``pil_to_tensor_rgb``.

    Both elements go through the same conversion, image first, and come back as a 2-tuple.
    """

    def __call__(self, img: Image.Image, mask: Image.Image) -> tuple[torch.Tensor, torch.Tensor]:
        converted = tuple(map(pil_to_tensor_rgb, (img, mask)))
        return converted

# --- 2) A tiny wrapper to enforce a uniform transform across heterogeneous datasets ---
from torch.utils.data import Dataset

class PairTransformWrapper(Dataset):
    """
    Dataset view that post-processes the ``(image, mask)`` samples of another dataset.

    ``img_tf`` is applied to the image and ``mask_tf`` to the mask (image first). A
    transform left as ``None`` passes its element through unchanged. Length and indexing
    follow the wrapped dataset.
    """

    def __init__(self, base_ds: Dataset, img_tf=None, mask_tf=None):
        self.base, self.img_tf, self.mask_tf = base_ds, img_tf, mask_tf

    def __len__(self):
        return len(self.base)

    @staticmethod
    def _apply(fn, value):
        """Return fn(value), or value itself when fn is None."""
        return value if fn is None else fn(value)

    def __getitem__(self, idx: int):
        sample_img, sample_mask = self.base[idx]
        return self._apply(self.img_tf, sample_img), self._apply(self.mask_tf, sample_mask)

# --- 4) Standardize outputs: wrap PIL-returning datasets so *all* output tensors ---
to_tensor = ToTensorPair()  # pair converter; the wrappers below convert per element instead


def _as_tensor_pairs(ds):
    """Wrap `ds` so both image and mask are converted with pil_to_tensor_rgb."""
    return PairTransformWrapper(ds, img_tf=pil_to_tensor_rgb, mask_tf=pil_to_tensor_rgb)


# Benthic sets return PIL images (they were built without a transform):
BOLIVAR_t = _as_tensor_pairs(SEAFLOWER_BOLIVAR)
COURTOWN_t = _as_tensor_pairs(SEAFLOWER_COURTOWN)
PAC_USA_t = _as_tensor_pairs(SEAVIEW_PAC_USA)
IDN_PHL_t = _as_tensor_pairs(SEAVIEW_IDN_PHL)
PAC_AUS_t = _as_tensor_pairs(SEAVIEW_PAC_AUS)
TETES_t = _as_tensor_pairs(TETES_PROVIDENCIA)
ATL_t = _as_tensor_pairs(SEAVIEW_ATL)
TAYRONA_t = _as_tensor_pairs(UNAL_BLEACHING_TAYRONA)

# Coralscapes returns PIL images as well:
CS_train_t = _as_tensor_pairs(CORALSCAPES_train)
CS_val_t = _as_tensor_pairs(CORALSCAPES_val)
CS_test_t = _as_tensor_pairs(CORALSCAPES_test)

# CoralBleachingDataset already yields (3, H, W) tensors, so it is used as-is.
BLEACH_all_t = dataset

# --- 5) Concatenate EVERYTHING into one mega dataset ---
from torch.utils.data import ConcatDataset, DataLoader

# NOTE(review): TAYRONA_t (benthic UNAL_BLEACHING_TAYRONA) and BLEACH_all_t both point at
# UNAL_BLEACHING_TAYRONA image folders, possibly the same photos with different masks. The
# Coralscapes train/validation/test splits are also merged here. Re-split before using
# ALL_DATA for any held-out evaluation.
ALL_DATA = ConcatDataset([
    BOLIVAR_t, COURTOWN_t, PAC_USA_t, IDN_PHL_t, PAC_AUS_t, TETES_t, ATL_t, TAYRONA_t,
    CS_train_t, CS_val_t, CS_test_t,
    BLEACH_all_t,
])

# --- 6) Use the padded collate (handles variable sizes) with a seeded, reproducible shuffle ---
ALL_DATA_SEED = 0

loader_all = DataLoader(
    ALL_DATA,
    batch_size=8,
    shuffle=True,
    num_workers=0,
    collate_fn=pad_collate,   # defined in the coral-bleaching cell
    generator=torch.Generator().manual_seed(ALL_DATA_SEED),  # same batch order on every run
)

# Smoke test: one padded batch -> (B,3,Hmax,Wmax) (B,3,Hmax,Wmax)
xb, yb = next(iter(loader_all))
print(xb.shape, yb.shape)


torch.Size([8, 3, 2441, 3281]) torch.Size([8, 3, 2441, 3281])
