In [6]:
%%writefile disaster-ai/src/cv_dataset.py
from pathlib import Path
from PIL import Image
import json, random
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class XBDDamageDataset(Dataset):
    """Tile-level damage classification dataset built from post-disaster PNGs.

    Every ``*post*.png`` file found recursively under ``root`` is one sample.
    The files are sorted, shuffled deterministically with ``seed`` and split
    80/20: ``split="train"`` selects the first 80%, any other ``split`` value
    selects the remaining 20%.

    Each sample's class is the most frequent building damage level listed in
    the sibling ``<image stem>_label.json`` file (see
    ``_label_from_geojson``).

    Args:
        root: Directory that is searched recursively for image tiles.
        split: ``"train"`` for the training portion; anything else yields the
            validation portion.
        img_size: Side length (pixels) every tile is resized to.
        seed: Seed of the shuffle that precedes the train/validation split.
    """

    # Damage level name -> integer class index.
    _DAMAGE_CLASSES = {
        "no-damage": 0,
        "minor-damage": 1,
        "major-damage": 2,
        "destroyed": 3,
    }

    def __init__(self, root, split="train", img_size=224, seed=42):
        self.root = Path(root)
        # Find all images with "post" in the filename (after-disaster tiles).
        # Sorting first makes the shuffle independent of filesystem order.
        self.items = sorted(self.root.rglob("*post*.png"))
        # A private generator yields the same permutation as seeding the
        # module-level RNG, but does not disturb global random state.
        random.Random(seed).shuffle(self.items)

        # 80% training, 20% validation
        n_train = int(0.8 * len(self.items))
        if split == "train":
            self.items = self.items[:n_train]
        else:
            self.items = self.items[n_train:]

        # Transform images (resize + normalize); the mean/std values are the
        # commonly used ImageNet channel statistics.
        self.tf = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _label_from_geojson(self, img_path: Path):
        """Return the majority damage class (0-3) for the tile at ``img_path``.

        The label file is expected next to the image and named
        ``<image stem>_label.json``. Each feature contributes one vote taken
        from ``properties["subtype"]`` (or ``properties["damage"]``); values
        outside the four known damage levels are ignored. Ties resolve to the
        least severe class. Returns 0 (no-damage) when the label file is
        missing or contains no recognised damage value.
        """
        label_name = img_path.stem.replace(".png", "") + "_label.json"
        label_json = img_path.with_name(label_name)
        try:
            raw = label_json.read_text(encoding="utf-8")
        except FileNotFoundError:
            return 0  # default: no-damage if label missing
        data = json.loads(raw)

        features = data.get("features") or []
        if isinstance(features, dict):
            # Some label files (e.g. xBD's native layout) nest the feature
            # lists under coordinate-system keys; both lists describe the same
            # buildings, so only one of them is counted.
            features = features.get("xy") or features.get("lng_lat") or []

        mapping = self._DAMAGE_CLASSES
        counts = [0] * len(mapping)
        for feature in features:
            if not isinstance(feature, dict):
                continue
            props = feature.get("properties")
            if not isinstance(props, dict):
                continue  # tolerate a null / malformed "properties" entry
            dmg = props.get("subtype") or props.get("damage")
            if isinstance(dmg, str) and dmg in mapping:
                counts[mapping[dmg]] += 1

        # With no votes every count is 0 and index() yields 0 (no-damage).
        return counts.index(max(counts))

    def __len__(self):
        """Return the number of tiles in the selected split."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the tile at position ``idx``.

        ``image`` is the resized, normalised RGB tile as a tensor and
        ``label`` is a scalar ``torch.long`` tensor holding the damage class.
        """
        path = self.items[idx]
        # convert() returns an independent, fully loaded image, so the
        # underlying file can be closed right away.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        label = self._label_from_geojson(path)
        return self.tf(rgb), torch.tensor(label, dtype=torch.long)

Writing disaster-ai/src/cv_dataset.py
