# Data Loader for Mac with Google Drive

Install the Google Drive for desktop app and sync your folders, then update the paths below to match your Google Drive location.

Typical Mac Google Drive path: `~/Library/CloudStorage/GoogleDrive-<your-email>/My Drive/`

In [None]:
import os
from pathlib import Path

# UPDATE THIS: root of your synced Google Drive folder on macOS.
GOOGLE_DRIVE_ROOT = Path.home() / "Library/CloudStorage/GoogleDrive-jokas.jasas@gmail.com/My Drive"

# Dataset folders inside the Drive root (adjust folder names as needed).
data_path = GOOGLE_DRIVE_ROOT / "dc4data"
benthic_path = data_path / "benthic_datasets"
coralbleaching_path = data_path / "coral_bleaching"
coralscapes_path = data_path / "coralscapes"

# HuggingFace cache directory (kept on Google Drive to save local disk space).
HF_CACHE_DIR = GOOGLE_DRIVE_ROOT / ".hf_cache"
if GOOGLE_DRIVE_ROOT.is_dir():
    HF_CACHE_DIR.mkdir(exist_ok=True)
else:
    # mkdir() without parents=True raises FileNotFoundError when the Drive
    # root is missing, which would abort this cell before the path checks
    # below get a chance to explain what is wrong. The root itself is not
    # created here on purpose: it must come from Google Drive Desktop.
    print(f"✗ Google Drive root NOT found: {GOOGLE_DRIVE_ROOT}")
    print("  HuggingFace cache directory was not created")

print(f"HuggingFace cache: {HF_CACHE_DIR}")

# Report which of the expected data folders are actually present.
for p in (data_path, benthic_path, coralbleaching_path, coralscapes_path):
    if p.exists():
        print(f"✓ Path exists: {p}")
    else:
        print(f"✗ Path NOT found: {p}")
        print("  Please check if folder is synced in Google Drive Desktop")

## Benthic Dataset

In [2]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import os
from pathlib import Path

# Folder that holds one sub-folder per benthic survey; adjust it to match
# your Google Drive structure.
benthic_base = benthic_path.joinpath("mask_labels/reef_support")

# Survey sub-folders. Keep the order stable: the dataset cell below picks
# entries of benthic_paths by position.
benthic_folders = [
    "SEAFLOWER_BOLIVAR",
    "SEAFLOWER_COURTOWN",
    "SEAVIEW_PAC_USA",
    "SEAVIEW_IDN_PHL",
    "SEAVIEW_PAC_AUS",
    "TETES_PROVIDENCIA",
    "SEAVIEW_ATL",
    "UNAL_BLEACHING_TAYRONA",
]

benthic_paths = [benthic_base.joinpath(survey) for survey in benthic_folders]

class SegmentationDataset(Dataset):
    """Image/mask pairs stored in two directories.

    Every image ``<stem>.<ext>`` found in ``img_dir`` is paired with the mask
    ``<stem>_mask.png`` in ``mask_dir``. Only the image directory is scanned
    at construction time; a missing mask is reported when the sample is
    accessed.

    Args:
        img_dir: directory containing the images.
        mask_dir: directory containing the ``*_mask.png`` files.
        transform: optional callable applied to the image and, separately,
            to the mask.

    Raises:
        FileNotFoundError: if ``img_dir`` holds no file with an image
            extension (at construction), or if the mask of the requested
            sample does not exist (on item access).
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform

        # Extensions are compared case-insensitively; the listing is sorted
        # so that indices are reproducible.
        image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")
        names = [
            entry
            for entry in os.listdir(img_dir)
            if entry.lower().endswith(image_suffixes)
        ]
        names.sort()
        self.images = names

        if len(self.images) == 0:
            raise FileNotFoundError(f"No images found in {img_dir}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]

        # "name.ext" -> "name_mask.png"
        mask_path = os.path.join(self.mask_dir, f"{Path(img_name).stem}_mask.png")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(
                f"Mask not found for {img_name}. Expected: {mask_path} "
                "(pattern '<image_stem>_mask.png')."
            )

        image = Image.open(os.path.join(self.img_dir, img_name)).convert("RGB")
        mask = Image.open(mask_path)  # kept in the mode it is stored in

        apply = self.transform
        if apply:
            image = apply(image)
            mask = apply(mask)

        return image, mask


def get_mask(benthic_folder):
    """Return the path of the ``masks_stitched`` sub-folder of *benthic_folder*."""
    return os.path.join(benthic_folder, "masks_stitched")

def get_image(benthic_folder):
    """Return the path of the ``images`` sub-folder of *benthic_folder*."""
    return os.path.join(benthic_folder, "images")


def _benthic_dataset(survey_dir):
    """Build the SegmentationDataset for one survey folder (images + stitched masks)."""
    return SegmentationDataset(get_image(survey_dir), get_mask(survey_dir))


# One dataset per survey; positions follow the order of benthic_folders.
SEAFLOWER_BOLIVAR = _benthic_dataset(benthic_paths[0])
SEAFLOWER_COURTOWN = _benthic_dataset(benthic_paths[1])
SEAVIEW_PAC_USA = _benthic_dataset(benthic_paths[2])
SEAVIEW_IDN_PHL = _benthic_dataset(benthic_paths[3])
SEAVIEW_PAC_AUS = _benthic_dataset(benthic_paths[4])
TETES_PROVIDENCIA = _benthic_dataset(benthic_paths[5])
SEAVIEW_ATL = _benthic_dataset(benthic_paths[6])
UNAL_BLEACHING_TAYRONA = _benthic_dataset(benthic_paths[7])

print(f"✓ Loaded {len(benthic_folders)} benthic datasets")

✓ Loaded 8 benthic datasets


## Coral Scapes

In [None]:
from datasets import Dataset as HFDataset, concatenate_datasets
from torch.utils.data import Dataset
from PIL import Image
import pyarrow.parquet as pq
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, Union, Sequence

class _ParquetMasksByIndex:
    """Look up PNG-encoded masks stored in Parquet files by their ``index`` value.

    Every Parquet file must contain an ``index`` column and the PNG column
    named by ``column_png``. All files are read into memory at construction
    and a dictionary from ``index`` value to (table, row) position is built.
    When the same index value occurs more than once, the last occurrence
    wins.

    Args:
        parquet_dir_or_paths: a directory (all ``*.parquet`` files inside it,
            sorted), a single file path, or a sequence of file paths.
        column_png: name of the column holding the encoded mask bytes.

    Raises:
        FileNotFoundError: if a path does not exist or a directory contains
            no ``*.parquet`` file.
        ValueError: if a file lacks one of the required columns.
    """

    def __init__(self, parquet_dir_or_paths: Union[str, Path, Sequence[Union[str, Path]]],
                 column_png: str = "label_health_rgb_png"):
        parquet_files = self._resolve_files(parquet_dir_or_paths)

        self._tables = [pq.read_table(parquet_file) for parquet_file in parquet_files]
        for table in self._tables:
            present = table.column_names
            if not ("index" in present and column_png in present):
                raise ValueError(f"Parquet must have 'index' and '{column_png}'. Got: {table.column_names}")
        self._colname = column_png

        # index value -> (table position, row position)
        self._map = {}
        for table_id, table in enumerate(self._tables):
            for row_id, ds_idx in enumerate(table["index"].to_pylist()):
                self._map[int(ds_idx)] = (table_id, row_id)

    @staticmethod
    def _resolve_files(source):
        """Normalize the constructor argument to a list of existing Parquet paths."""
        if not isinstance(source, (str, Path)):
            files = [Path(item) for item in source]
            for candidate in files:
                if not candidate.exists():
                    raise FileNotFoundError(f"Parquet file not found: {candidate}")
            return files

        location = Path(source)
        if location.is_dir():
            files = sorted(location.glob("*.parquet"))
            if not files:
                raise FileNotFoundError(f"No parquet files in directory: {location}")
            return files

        if not location.exists():
            raise FileNotFoundError(f"Parquet file not found: {location}")
        return [location]

    def get_mask_pil(self, ds_index: int) -> Image.Image:
        """Decode and return the mask stored under ``ds_index`` as an RGB PIL image.

        Raises:
            KeyError: if no row with that index value was loaded.
        """
        table_id, row_id = self._map[ds_index]
        payload = self._tables[table_id][self._colname][row_id].as_py()
        # Image.open needs a file-like object over plain bytes.
        if isinstance(payload, memoryview):
            payload = payload.tobytes()
        elif isinstance(payload, bytearray):
            payload = bytes(payload)
        decoded = Image.open(BytesIO(payload))
        return decoded.convert("RGB")
    
class CoralScapesImagesMasks(Dataset):
    """
    CoralScapes image/mask pairs read from local files (no HuggingFace download needed!)

    Images: all ``*.arrow`` shards in ``arrow_dir``, in sorted file order
    Masks: local Parquet files, fetched via ``_ParquetMasksByIndex`` using
           the same integer that indexes the image

    Args:
        parquet_dir_or_paths: forwarded unchanged to ``_ParquetMasksByIndex``.
        arrow_dir: directory containing the Arrow shards.
        img_transform: optional callable applied to the RGB image.
        mask_transform: optional callable applied to the RGB mask.

    Raises:
        FileNotFoundError: if ``arrow_dir`` is missing or holds no
            ``*.arrow`` file.
    """
    def __init__(self,
                 parquet_dir_or_paths: Union[str, Path, Sequence[Union[str, Path]]],
                 arrow_dir: Union[str, Path],
                 img_transform: Optional[Callable] = None,
                 mask_transform: Optional[Callable] = None):

        # Images: local Arrow shards
        arrow_root = Path(arrow_dir)
        if not arrow_root.exists():
            raise FileNotFoundError(f"Arrow directory not found: {arrow_root}")

        shard_files = sorted(arrow_root.glob("*.arrow"))
        if len(shard_files) == 0:
            raise FileNotFoundError(f"No arrow files found in: {arrow_root}")

        print(f"Loading from {len(shard_files)} Arrow files in {arrow_root}")
        shards = [HFDataset.from_file(str(shard_file)) for shard_file in shard_files]
        if len(shards) > 1:
            self.img_ds = concatenate_datasets(shards)
        else:
            self.img_ds = shards[0]
        print(f"✓ Loaded {len(self.img_ds)} images")

        # Masks: Parquet lookup
        self.masks = _ParquetMasksByIndex(parquet_dir_or_paths)

        self.img_tf = img_transform
        self.mask_tf = mask_transform

    def __len__(self):
        return len(self.img_ds)

    def __getitem__(self, idx: int):
        record = self.img_ds[idx]
        image = record["image"].convert("RGB")
        mask = self.masks.get_mask_pil(idx)

        if self.img_tf is not None:
            image = self.img_tf(image)
        if self.mask_tf is not None:
            mask = self.mask_tf(mask)
        return image, mask


# Everything below is read from LOCAL data (no downloads needed!)

# Mask Parquet exports; these are relative paths, resolved against the
# current working directory.
TRAIN_PARQUET_DIR = "data_preprocessing/coralscapes_export/parquet/train"
VAL_PARQUET_DIR   = "data_preprocessing/coralscapes_export/parquet/validation"
TEST_PARQUET_DIR  = "data_preprocessing/coralscapes_export/parquet/test"

# Image Arrow shards on Google Drive.
TRAIN_ARROW_DIR = coralscapes_path.joinpath("train")
VAL_ARROW_DIR = coralscapes_path.joinpath("validation")
TEST_ARROW_DIR = coralscapes_path.joinpath("test")

# One dataset per split, loaded from local files - NO HUGGINGFACE DOWNLOAD!
CORALSCAPES_train = CoralScapesImagesMasks(
    arrow_dir=TRAIN_ARROW_DIR,
    parquet_dir_or_paths=TRAIN_PARQUET_DIR,
)
CORALSCAPES_val = CoralScapesImagesMasks(
    arrow_dir=VAL_ARROW_DIR,
    parquet_dir_or_paths=VAL_PARQUET_DIR,
)
CORALSCAPES_test = CoralScapesImagesMasks(
    arrow_dir=TEST_ARROW_DIR,
    parquet_dir_or_paths=TEST_PARQUET_DIR,
)

print("✓ All CoralScapes datasets loaded from local files!")

## Coral Bleaching

In [None]:
# UPDATE: location of the coral bleaching images on your Google Drive.
coral_bleaching_images = coralbleaching_path.joinpath("reef_support/UNAL_BLEACHING_TAYRONA/images")
# Mask folders are relative paths, resolved against the current working directory.
coral_bleaching_combined_masks = "data_preprocessing/coralbleaching/combined_masks"
coral_bleaching_single_masks = "data_preprocessing/coralbleaching/single_masks"

In [None]:
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from torchvision.transforms import Resize

# Shared transform used by pil_to_tensor: every PIL input becomes 640x640.
resize = Resize(size=(640, 640))

def pil_to_tensor(img):
    """Turn a PIL image into a float tensor resized to 640x640.

    The image is resized with the module-level ``resize`` transform,
    converted to RGB and returned channel-first (3, H, W) with values
    scaled from 0-255 to 0.0-1.0.

    A ``torch.Tensor`` input is returned unchanged: it is neither resized
    nor rescaled.

    Raises:
        ValueError: if ``img`` is None.
    """
    if isinstance(img, torch.Tensor):
        return img
    if img is None:
        raise ValueError("Received None instead of a PIL.Image.")

    rgb = resize(img).convert("RGB")
    # Copy so the tensor is built from an array this function owns.
    pixels = np.asarray(rgb, dtype=np.uint8).copy()
    channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
    return channels_first.float() / 255.0


class CoralBleachingDataset(Dataset):
    """Coral-bleaching images paired with a mask found by file-name matching.

    For every image the lookup key is its lower-cased stem with each
    occurrence of ``"_corr"`` removed. A mask is then searched in this order:

    1. ``combined_dir``: a file whose lower-cased stem is exactly
       ``"<key>_combined"``.
    2. ``single_dir/bleached_blue``: the first file whose lower-cased stem
       contains the key or is contained in it.
    3. ``single_dir/non_bleached_red``: same rule as step 2.

    Images for which no mask is found are left out of the dataset.

    Mask candidates are visited in sorted order, so the chosen mask does not
    depend on the order in which the file system lists a directory.

    Args:
        images_dir: directory with the input images.
        combined_dir: directory with the ``*_combined`` masks.
        single_dir: directory that may contain the ``bleached_blue`` and
            ``non_bleached_red`` sub-folders; each is skipped when absent.
    """

    # Glob patterns recognized as image/mask files.
    _PATTERNS = ("*.png", "*.jpg", "*.jpeg")

    def __init__(self, images_dir, combined_dir, single_dir):
        self.images_dir = Path(images_dir)
        self.combined_dir = Path(combined_dir)
        self.single_bleached = Path(single_dir) / "bleached_blue"
        self.single_non = Path(single_dir) / "non_bleached_red"

        self.images = sorted(
            path
            for pattern in self._PATTERNS
            for path in self.images_dir.glob(pattern)
        )

        self.pairs = self._match_pairs()
        print(f"Found {len(self.pairs)} image-mask pairs")

    @classmethod
    def _index_dir(cls, directory):
        """Map lower-cased stem -> path for every matching file in *directory*."""
        index = {}
        for pattern in cls._PATTERNS:
            # Sorted so that, when two files share a lower-cased stem, the
            # surviving one does not depend on directory listing order.
            for path in sorted(directory.glob(pattern)):
                index[path.stem.lower()] = path
        return index

    @staticmethod
    def _first_fuzzy_match(key, entries):
        """Return the first path in *entries* whose stem overlaps *key*, else None."""
        for stem, path in entries:
            if key in stem or stem in key:
                return path
        return None

    def _match_pairs(self):
        """Build the list of (image_path, mask_path) pairs described on the class."""
        combined = self._index_dir(self.combined_dir)
        # Fallback folders in priority order, each as a stem-sorted list so
        # that "first match" is well defined.
        fallbacks = [
            sorted(self._index_dir(folder).items()) if folder.exists() else []
            for folder in (self.single_bleached, self.single_non)
        ]

        pairs = []
        for img in self.images:
            key = img.stem.lower().replace("_corr", "")

            mask = combined.get(f"{key}_combined")
            if mask is None:
                for entries in fallbacks:
                    mask = self._first_fuzzy_match(key, entries)
                    if mask is not None:
                        break

            if mask is not None:
                pairs.append((img, mask))

        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        image_path, mask_path = self.pairs[i]
        # pil_to_tensor copies the pixel data, so the files can be closed as
        # soon as the conversion is done.
        with Image.open(image_path) as image:
            x = pil_to_tensor(image)
        with Image.open(mask_path) as mask:
            y = pil_to_tensor(mask)
        return x, y

def pad_collate(batch):
    """Collate (image, mask) tensor pairs of different sizes into zero-padded batches.

    Args:
        batch: sequence of ``(image, mask)`` pairs. Images are (C, H, W)
            tensors; masks are (C, H, W), or any shape that broadcasts to
            that, such as (1, H, W) or (H, W).

    Returns:
        ``(xb, yb)``. ``xb`` has shape (N, C, max image H, max image W) and
        ``yb`` has shape (N, C, max mask H, max mask W), where C is the
        channel count of the first image. Each sample sits in the top-left
        corner of its slot; the remainder is zero. Dtypes follow the first
        image and the first mask respectively.
    """
    imgs, masks = zip(*batch)
    C = imgs[0].shape[0]

    xb = torch.zeros(
        len(imgs),
        C,
        max(t.shape[1] for t in imgs),
        max(t.shape[2] for t in imgs),
        dtype=imgs[0].dtype,
    )
    # Size and fill the mask batch from the masks' own spatial dimensions.
    # Reusing the image's H/W breaks as soon as a mask and its image differ
    # in size. Negative indices keep 2-D (H, W) masks working.
    yb = torch.zeros(
        len(masks),
        C,
        max(m.shape[-2] for m in masks),
        max(m.shape[-1] for m in masks),
        dtype=masks[0].dtype,
    )

    for i, (x, y) in enumerate(zip(imgs, masks)):
        xb[i, :, :x.shape[1], :x.shape[2]] = x
        yb[i, :, :y.shape[-2], :y.shape[-1]] = y
    return xb, yb

# Build the bleaching dataset from the folders configured above.
dataset = CoralBleachingDataset(
    single_dir=coral_bleaching_single_masks,
    combined_dir=coral_bleaching_combined_masks,
    images_dir=coral_bleaching_images,
)

# Smoke test: pull one padded batch, unless nothing could be paired.
if len(dataset) == 0:
    print("⚠ No coral bleaching pairs found - skipping this dataset")
else:
    loader = DataLoader(
        dataset,
        collate_fn=pad_collate,
        num_workers=0,
        shuffle=True,
        batch_size=8,
    )
    xb, yb = next(iter(loader))
    print(f"✓ Coral Bleaching dataset: {xb.shape}, {yb.shape}")

# Combine ALL into ONE

In [None]:
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, ConcatDataset, DataLoader

def pil_to_tensor_rgb(img):
    """Convert a PIL image to a float RGB tensor without resizing.

    Args:
        img: a PIL image, or an already-converted ``torch.Tensor``.

    Returns:
        A (3, H, W) float tensor with values scaled from 0-255 to 0.0-1.0.
        A tensor input is returned unchanged.

    Raises:
        ValueError: if ``img`` is None.
    """
    if img is None:
        raise ValueError("Received None instead of a PIL.Image.")
    if isinstance(img, torch.Tensor):
        return img
    # np.array always copies, so torch.from_numpy receives a writable array.
    # np.asarray may hand back a read-only view of the image buffer, which
    # torch.from_numpy warns about. This matches pil_to_tensor, which copies.
    a = np.array(img.convert("RGB"), dtype=np.uint8)
    return torch.from_numpy(a).permute(2, 0, 1).float() / 255.0


class ToTensorPair:
    """Callable that converts an (image, mask) pair with ``pil_to_tensor_rgb``."""

    def __call__(self, img: Image.Image, mask: Image.Image):
        img_tensor = pil_to_tensor_rgb(img)
        mask_tensor = pil_to_tensor_rgb(mask)
        return img_tensor, mask_tensor


class PairTransformWrapper(Dataset):
    """Apply independent transforms to the two items of a paired dataset.

    Args:
        base_ds: dataset whose items are ``(image, mask)`` pairs.
        img_tf: optional callable applied to the image; skipped when None.
        mask_tf: optional callable applied to the mask; skipped when None.
    """

    def __init__(self, base_ds: Dataset, img_tf=None, mask_tf=None):
        self.base, self.img_tf, self.mask_tf = base_ds, img_tf, mask_tf

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx: int):
        img, mask = self.base[idx]
        img = img if self.img_tf is None else self.img_tf(img)
        mask = mask if self.mask_tf is None else self.mask_tf(mask)
        return img, mask


def _as_rgb_tensors(ds):
    """Wrap *ds* so that image and mask are both converted by pil_to_tensor_rgb."""
    return PairTransformWrapper(ds, img_tf=pil_to_tensor_rgb, mask_tf=pil_to_tensor_rgb)


# Benthic surveys
BOLIVAR_t = _as_rgb_tensors(SEAFLOWER_BOLIVAR)
COURTOWN_t = _as_rgb_tensors(SEAFLOWER_COURTOWN)
PAC_USA_t = _as_rgb_tensors(SEAVIEW_PAC_USA)
IDN_PHL_t = _as_rgb_tensors(SEAVIEW_IDN_PHL)
PAC_AUS_t = _as_rgb_tensors(SEAVIEW_PAC_AUS)
TETES_t = _as_rgb_tensors(TETES_PROVIDENCIA)
ATL_t = _as_rgb_tensors(SEAVIEW_ATL)
TAYRONA_t = _as_rgb_tensors(UNAL_BLEACHING_TAYRONA)

# CoralScapes splits
CS_train_t = _as_rgb_tensors(CORALSCAPES_train)
CS_val_t = _as_rgb_tensors(CORALSCAPES_val)
CS_test_t = _as_rgb_tensors(CORALSCAPES_test)

# The bleaching dataset converts to tensors itself, so it is used as is.
BLEACH_all_t = dataset

# Single dataset: benthic surveys, then CoralScapes splits, then bleaching.
ALL_DATA = ConcatDataset(
    [BOLIVAR_t, COURTOWN_t, PAC_USA_t, IDN_PHL_t, PAC_AUS_t, TETES_t, ATL_t, TAYRONA_t]
    + [CS_train_t, CS_val_t, CS_test_t]
    + [BLEACH_all_t]
)

# pad_collate zero-pads every sample up to the largest one in its batch.
loader_all = DataLoader(
    ALL_DATA,
    collate_fn=pad_collate,
    num_workers=0,
    shuffle=True,
    batch_size=8,
)

# Smoke test: fetch one batch and report its shapes.
xb, yb = next(iter(loader_all))
print(f"✓ Combined dataset: {xb.shape}, {yb.shape}")

# YOLO Pipeline

In [None]:
from torch.utils.data import Dataset

class CoralTrainWrapper(Dataset):
    """
    Wraps a dataset of (image, RGB mask) tensor pairs to produce a dict with:
    image: (3,H,W) float tensor
    mask: (1,H,W) float tensor, 1 where any mask channel is non-zero
    label: scalar long tensor (0=healthy, 1=bleached)

    Colour convention: blue-dominant mask pixels mark bleached coral and
    red-dominant pixels mark healthy coral. This agrees with the
    ``bleached_blue`` / ``non_bleached_red`` mask folders and with
    ``CoralHealthDataset`` elsewhere in this file.

    Args:
        base_ds: dataset yielding ``(image, mask)`` pairs of (3,H,W) tensors.
        bleached_fraction: a sample is labelled bleached when more than this
            fraction of its mask pixels is blue-dominant.
    """
    def __init__(self, base_ds, bleached_fraction=0.01):
        self.base = base_ds
        self.bleached_fraction = bleached_fraction

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, mask = self.base[idx]
        img = img.float()
        mask = mask.float()

        # Binary coral presence mask (1 = any mask colour, 0 = background)
        coral_mask = (mask.sum(dim=0, keepdim=True) > 0).float()

        # Bleaching label from the colour mask (blue > red -> bleached).
        # The earlier red > blue test assigned label 1 to healthy samples.
        blue_pixels = mask[2] > mask[0]
        bleached = blue_pixels.float().mean() > self.bleached_fraction
        label = torch.tensor(int(bleached), dtype=torch.long)

        return {"image": img, "mask": coral_mask, "label": label}

# Samples become dicts with "image", "mask" and "label" entries.
wrapped_data = CoralTrainWrapper(ALL_DATA)

# Rebinds loader_all, replacing the padded loader created earlier.
# NOTE(review): no collate_fn is passed here, so default collation is used;
# presumably that needs equally sized samples per batch - confirm.
loader_all = DataLoader(
    wrapped_data,
    num_workers=0,
    shuffle=True,
    batch_size=8,
)

print(f"✓ YOLO wrapper ready with {len(wrapped_data)} samples")

In [None]:
import torch
from coral_yolo.models.coral_classifier import CoralClassifier
from coral_yolo.losses.classification_loss import CoralClassificationLoss
from coral_yolo.engine.metrics import ClsPRF1

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Two-class classifier built from the yolo11s.pt weights.
model = CoralClassifier(
    freeze_backbone=True,
    num_classes=2,
    yolo_weights="yolo11s.pt",
)
model = model.to(device)

# Optimizer, loss and metric objects for training.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3, weight_decay=5e-4)
criterion = CoralClassificationLoss()
metric = ClsPRF1()

print("✓ Model initialized")

# ResNet Coral Health Classification (2 Classes)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import numpy as np

class CoralHealthDataset(Dataset):
    """
    Two-class coral health dataset derived from (image, RGB mask) pairs.

    The label comes from the mask colours:
        0: Healthy   - mean of the red channel is greater than the blue mean
        1: Unhealthy - otherwise (including equal means)

    Args:
        base_ds: dataset yielding ``(image, mask)`` pairs; both are indexed
            as channel-first tensors.
        transform: callable applied to the image only. When omitted, the
            image is resized to 224x224 and normalized with
            mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225].
    """
    def __init__(self, base_ds, transform=None):
        self.base = base_ds
        if transform:
            self.transform = transform
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        # img: (3, H, W) tensor; mask: (3, H, W) RGB tensor
        img, mask = self.base[idx]

        # Only the image goes through the transform; the mask is used solely
        # to derive the label.
        img = self.transform(img)

        red_mean = mask[0].mean().item()
        blue_mean = mask[2].mean().item()

        # Red dominant -> 0 (healthy); everything else -> 1 (unhealthy)
        label = 0 if red_mean > blue_mean else 1

        return img, label


# Classification view over the combined dataset.
resnet_dataset = CoralHealthDataset(ALL_DATA)

# 80/20 train/validation split.
# NOTE(review): random_split is called without a generator, so the partition
# follows the global torch RNG state - seed it if the split must be repeatable.
train_size = int(0.8 * len(resnet_dataset))
val_size = len(resnet_dataset) - train_size
resnet_train, resnet_val = random_split(resnet_dataset, lengths=[train_size, val_size])

# Only the training loader shuffles.
resnet_train_loader = DataLoader(resnet_train, num_workers=0, shuffle=True, batch_size=32)
resnet_val_loader = DataLoader(resnet_val, num_workers=0, shuffle=False, batch_size=32)

print("✓ ResNet dataset prepared:")
print(f"  Train samples: {len(resnet_train)}")
print(f"  Val samples: {len(resnet_val)}")
print("  Classes: 0=Healthy, 1=Unhealthy")

In [None]:
import sys
sys.path.append('/Users/jokubas/Desktop/Coral-reefs-DBL4')

from resnet import CoralResNet, CoralTrainer
import torch.nn as nn
import torch.optim as optim

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Two-class ResNet model
resnet_model = CoralResNet(
    freeze_backbone=True,
    pretrained=True,
    num_classes=2,
)

# Optimizer, loss and LR schedule (halve the LR after 2 epochs without
# improvement of the monitored value, which is minimized).
optimizer = optim.AdamW(resnet_model.parameters(), weight_decay=1e-4, lr=1e-3)
criterion = nn.CrossEntropyLoss()
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Trainer receives the model together with the target device.
trainer = CoralTrainer(
    class_names=["Healthy", "Unhealthy"],
    device=device,
    model=resnet_model,
)

print("✓ ResNet model initialized")
print(f"  Total parameters: {sum(param.numel() for param in resnet_model.parameters()):,}")
print(f"  Trainable parameters: {sum(param.numel() for param in resnet_model.parameters() if param.requires_grad):,}")

In [None]:
# Run ResNet training for 10 epochs with the objects configured above.
print("Starting ResNet training...")
print("=" * 60)

history = trainer.fit(
    epochs=10,
    train_loader=resnet_train_loader,
    val_loader=resnet_val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
)

print("\n✓ Training complete!")
print("Best model saved to: resnet/best_model.pth")