# **1) Setup: mount Drive + install deps + set paths**

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# >>>> EDIT THIS to match where *your* 'preprocessed_data' lives in MyDrive
ROOT = "/content/drive/MyDrive/Data/preprocessed_data/preprocessed_data"   # e.g. .../preprocessed_data
YEAR = "2020"                                       # as in your Drive
REGIONS = None                                      # or e.g. ["0","1","2","3","4"]
MONTHS = (11,12,1,2,3,4,5,6,7)

# Install required libs (per Colab session)
!pip -q install rasterio torch numpy


Mounted at /content/drive
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.3/22.3 MB[0m [31m86.2 MB/s[0m eta [36m0:00:00[0m
[?25h

# **2) Data loader that matches your layout (year → region → months; separate label tree)**

In [2]:
# ==== Step 2 (REPLACE THIS WHOLE CELL) ====
from pathlib import Path
import glob
import numpy as np
import torch
from torch.utils.data import Dataset
import rasterio
from collections import Counter

class WheatTilesDataset(Dataset):
    """
    data/<YEAR>/<REGION>/<MONTH>/<TILE_ID>.tif   # ~11–13 bands, 64×64 (usually)
    label/<YEAR>/<REGION>/<TILE_ID>.tif          # 2 layers: [0]=valid, [1]=wheat

    LAZY version:
      - Index is built from filenames only (fast).
      - Only a small probe (limit) is used to infer bands & size.
      - Size/band fixes happen at read-time (pad/trim), not during __init__.
    """
    def __init__(self,
        root_preprocessed: str,
        year: str = "2020",
        regions=None,                           # e.g., ["0","1","2","3","4"] or None for all
        month_order=(11,12,1,2,3,4,5,6,7),
        temporal_layout=False,                  # True -> [T,B,64,64]; False -> [C,64,64]
        normalize=True,
        band_stats=None,                        # None or {band:(mean,std)} or {(t,b):(mean,std)}
        require_complete=True,                  # only keep tiles with ALL months present
        # band & size handling
        target_bands: int | None = None,        # None => probe few files to detect modal count
        target_size: tuple[int,int] | None = (64, 64),  # None => probe few files to infer
        size_policy: str = "pad",               # "pad" center pad/crop at read-time (recommended)
        probe_limit: int = 20                   # how many samples to open when probing
    ):
        self.root = Path(root_preprocessed)
        self.year = str(year)
        self.DATA = self.root / "data" / self.year
        self.LABEL = self.root / "label" / self.year

        self.months = tuple(month_order)
        self.temporal_layout = temporal_layout
        self.normalize = normalize
        self.band_stats = band_stats
        self.require_complete = require_complete
        self.size_policy = size_policy

        # Regions (filenames only)
        all_regions = sorted([p.name for p in self.DATA.iterdir() if p.is_dir()])
        self.regions = all_regions if regions is None else [r for r in regions if (self.DATA / r).exists()]

        # Build index from labels (filenames only)
        self.index = self._build_index()
        if not self.index:
            raise RuntimeError("No tiles found. Check ROOT/YEAR/regions structure.")

        # Probe a FEW files to infer bands/size if needed (fast)
        self._probe_bands_size(target_bands, target_size, probe_limit)

        # Sanity: labels should have 2 layers (open ONE label only)
        with rasterio.open(self.index[0]["label_path"]) as dsl:
            if dsl.count != 2:
                raise RuntimeError("Labels must have 2 layers (valid, wheat).")

    # ---------- helpers ----------
    def _build_index(self):
        idx = []
        for region in self.regions:
            label_dir = self.LABEL / region
            if not label_dir.exists():
                print(f"[WARN] missing label dir: {label_dir}"); continue
            for lab_fp in sorted(glob.glob(str(label_dir / "*.tif"))):
                tile_id = Path(lab_fp).stem
                month_paths = {}
                complete = True
                for m in self.months:
                    m_fp = self.DATA / region / str(m) / f"{tile_id}.tif"
                    if m_fp.exists(): month_paths[m] = str(m_fp)
                    else: complete = False
                if self.require_complete and not complete:
                    continue
                if not self.require_complete and len(month_paths) == 0:
                    continue
                idx.append({"region": region, "tile_id": tile_id,
                            "label_path": str(lab_fp), "month_paths": month_paths})
        return idx

    def _probe_bands_size(self, target_bands, target_size, limit):
        # Decide bands
        if target_bands is None:
            counts = Counter()
            seen = 0
            for rec in self.index:
                for m in self.months:
                    p = rec["month_paths"].get(m)
                    if p:
                        with rasterio.open(p) as ds:
                            counts[ds.count] += 1
                        seen += 1
                        break
                if seen >= limit: break
            if not counts:
                raise RuntimeError("Could not detect band counts.")
            self.num_bands = counts.most_common(1)[0][0]
        else:
            self.num_bands = int(target_bands)

        # Decide size
        if target_size is None:
            sizes = Counter()
            seen = 0
            for rec in self.index:
                for m in self.months:
                    p = rec["month_paths"].get(m)
                    if p:
                        with rasterio.open(p) as ds:
                            sizes[(ds.height, ds.width)] += 1
                        seen += 1
                        break
                if seen >= limit: break
            if not sizes:
                raise RuntimeError("Could not infer tile size.")
            self.H, self.W = sizes.most_common(1)[0][0]
        else:
            self.H, self.W = target_size

    def __len__(self): return len(self.index)

    def _fix_band_count(self, arr):
        B,H,W = arr.shape
        tb = self.num_bands
        if B == tb: return arr
        if B > tb:  return arr[:tb]
        pad = np.zeros((tb - B, H, W), dtype=np.float32)
        return np.concatenate([arr, pad], axis=0)

    def _fix_size(self, arr):
        # Always pad/crop CENTER to target (size_policy="pad")
        H,W = arr.shape[1:]
        th, tw = self.H, self.W
        if (H,W) == (th,tw): return arr
        out = np.zeros((arr.shape[0], th, tw), dtype=np.float32)
        h = min(H, th); w = min(W, tw)
        sy = (H - h)//2 if H>h else 0
        sx = (W - w)//2 if W>w else 0
        dy = (th - h)//2 if th>h else 0
        dx = (tw - w)//2 if tw>w else 0
        out[:, dy:dy+h, dx:dx+w] = arr[:, sy:sy+h, sx:sx+w]
        return out

    def _normalize(self, arrTBHW):
        T,B,H,W = arrTBHW.shape
        out = arrTBHW.copy()
        if self.band_stats is None:
            # per-tile min-max per band across time
            for b in range(B):
                band = out[:, b]
                vmin = np.nanmin(band); vmax = np.nanmax(band)
                out[:, b] = 0.0 if vmax <= vmin else (band - vmin)/(vmax - vmin)
            return out
        keyed_tb = any(isinstance(k, tuple) and len(k) == 2 for k in self.band_stats.keys())
        if keyed_tb:
            for t in range(T):
                for b in range(B):
                    mean, std = self.band_stats.get((t,b), (0.0,1.0))
                    if std == 0: std = 1.0
                    out[t,b] = (out[t,b] - mean)/std
        else:
            for b in range(B):
                mean, std = self.band_stats.get(b, (0.0,1.0))
                if std == 0: std = 1.0
                out[:,b] = (out[:,b] - mean)/std
        return out

    def _read_stack(self, month_paths):
        imgs = []
        for m in self.months:
            if m not in month_paths:
                arr = np.zeros((self.num_bands, self.H, self.W), dtype=np.float32)
            else:
                with rasterio.open(month_paths[m]) as ds:
                    arr = ds.read(out_dtype="float32")     # [B,H,W]
                arr = self._fix_band_count(arr)
                arr = self._fix_size(arr)
            imgs.append(arr)
        arrTBHW = np.stack(imgs, axis=0)                    # [T,B,H,W]
        if self.normalize: arrTBHW = self._normalize(arrTBHW)
        if self.temporal_layout: return arrTBHW
        T,B,H,W = arrTBHW.shape
        return arrTBHW.reshape(T*B, H, W)                   # [C,H,W]

    def _read_labels(self, label_path):
        with rasterio.open(label_path) as ds:
            lab = ds.read(out_dtype="float32")              # [2,H,W]
        lab = np.clip(lab, 0, 1)
        lab = self._fix_size(lab)
        return lab[0:1], lab[1:2]

    def __getitem__(self, i):
        rec = self.index[i]
        x = self._read_stack(rec["month_paths"])
        valid, wheat = self._read_labels(rec["label_path"])
        x = np.nan_to_num(x, nan=0.0)
        valid = np.nan_to_num(valid, nan=0.0)
        wheat = np.nan_to_num(wheat, nan=0.0)
        return {
            "x": torch.from_numpy(x),
            "valid_mask": torch.from_numpy(valid),
            "wheat_mask": torch.from_numpy(wheat),
            "tile_id": rec["tile_id"],
            "region": rec["region"]
        }


# **3) Quick audit: regions, months, bands, label layers**

In [3]:
from collections import defaultdict

DATA_ROOT = Path(ROOT) / "data" / YEAR
LABEL_ROOT = Path(ROOT) / "label" / YEAR

regions = REGIONS or sorted([p.name for p in DATA_ROOT.iterdir() if p.is_dir()])
print("Regions found:", regions)

def tiles_in(folder):
    return set(Path(fp).stem for fp in glob.glob(str(folder / "*.tif")))

for r in regions:
    print(f"\n== Region {r} ==")
    # imagery coverage
    month_sets = {}
    for m in MONTHS:
        mdir = DATA_ROOT / r / str(m)
        month_sets[m] = tiles_in(mdir) if mdir.exists() else set()
    inter = set.intersection(*[s for s in month_sets.values()]) if month_sets else set()
    union = set.union(*[s for s in month_sets.values()]) if month_sets else set()
    print(" tiles present in ALL months:", len(inter))
    print(" tiles in ANY month:", len(union))

    # peek one month file for bands/shape
    sample = next(iter(glob.iglob(str(DATA_ROOT / r / str(MONTHS[0]) / "*.tif"))), None)
    if sample:
        with rasterio.open(sample) as ds:
            print(" image bands:", ds.count, "| size:", (ds.height, ds.width))

    # labels are 2-layer?
    label_dir = LABEL_ROOT / r
    two_ok = True
    for lf in list(glob.iglob(str(label_dir / "*.tif")))[:10]:
        with rasterio.open(lf) as dsl:
            if dsl.count != 2:
                print(" !! non-2-layer label:", lf); two_ok = False; break
    print(" labels have 2 layers:", two_ok)


Regions found: ['0', '1', '2', '3', '4']

== Region 0 ==
 tiles present in ALL months: 807
 tiles in ANY month: 807
 image bands: 13 | size: (64, 64)
 labels have 2 layers: True

== Region 1 ==
 tiles present in ALL months: 589
 tiles in ANY month: 591
 image bands: 13 | size: (64, 21)
 labels have 2 layers: True

== Region 2 ==
 tiles present in ALL months: 596
 tiles in ANY month: 596
 image bands: 13 | size: (64, 64)
 labels have 2 layers: True

== Region 3 ==
 tiles present in ALL months: 723
 tiles in ANY month: 723
 image bands: 13 | size: (64, 64)
 labels have 2 layers: True

== Region 4 ==
 tiles present in ALL months: 969
 tiles in ANY month: 993
 image bands: 13 | size: (64, 64)
 labels have 2 layers: True


# **4) Build dataset + DataLoader (works directly from Drive)**

In [None]:
# ==== Fast subset picker (put this ABOVE your build/loader cell) ====
from pathlib import Path
import glob

# how many per region to start with (make it small, you can raise later)
K_PER_REGION = 32

DATA_ROOT = Path(ROOT) / "data" / YEAR
LABEL_ROOT = Path(ROOT) / "label" / YEAR
REGIONS_EFF = REGIONS or sorted([p.name for p in DATA_ROOT.iterdir() if p.is_dir()])

def tiles_in(dirpath: Path):
    return set(Path(fp).stem for fp in glob.glob(str(dirpath / "*.tif")))

def pick_complete_tiles_per_region(regions, months, k_per_region):
    keep = set()   # set of (region, tile_id)
    counts = {}
    for r in regions:
        # label tile_ids define the universe
        label_ids = tiles_in(LABEL_ROOT / r)
        # intersect across all months
        month_sets = []
        for m in months:
            mdir = DATA_ROOT / r / str(m)
            month_sets.append(tiles_in(mdir) if mdir.exists() else set())
        if month_sets:
            complete = set.intersection(label_ids, *month_sets)
        else:
            complete = set()
        chosen = sorted(list(complete))[:k_per_region]
        for t in chosen:
            keep.add((r, t))
        counts[r] = len(chosen)
    return keep, counts

KEEP, COUNTS = pick_complete_tiles_per_region(REGIONS_EFF, MONTHS, K_PER_REGION)
print("Picked per region:", COUNTS, "| total:", sum(COUNTS.values()))


Picked per region: {'0': 32, '1': 32, '2': 32, '3': 32, '4': 32} | total: 160


In [None]:
# ==== Step 4 (temporal layout: keep 9 and 13 as separate dims) ====
from pathlib import Path
import glob
from torch.utils.data import Subset, DataLoader
import time

# 0) Small subset by filenames (unchanged)
K_PER_REGION = 32
DATA_ROOT = Path(ROOT) / "data" / YEAR
LABEL_ROOT = Path(ROOT) / "label" / YEAR
REGIONS_EFF = REGIONS or sorted([p.name for p in DATA_ROOT.iterdir() if p.is_dir()])

def tiles_in(dirpath: Path):
    return set(Path(fp).stem for fp in glob.glob(str(dirpath / "*.tif")))

def pick_complete_tiles_per_region(regions, months, k_per_region):
    keep = set(); counts = {}
    for r in regions:
        label_ids = tiles_in(LABEL_ROOT / r)
        month_sets = []
        for m in months:
            mdir = DATA_ROOT / r / str(m)
            month_sets.append(tiles_in(mdir) if mdir.exists() else set())
        complete = set.intersection(label_ids, *month_sets) if month_sets else set()
        chosen = sorted(list(complete))[:k_per_region]
        counts[r] = len(chosen)
        for t in chosen: keep.add((r,t))
    return keep, counts

KEEP, COUNTS = pick_complete_tiles_per_region(REGIONS_EFF, MONTHS, K_PER_REGION)
print("Picked per region:", COUNTS, "| total:", sum(COUNTS.values()))

# 1) Build dataset with temporal layout
ds_full = WheatTilesDataset(
    root_preprocessed=ROOT,
    year=YEAR,
    regions=REGIONS,
    month_order=MONTHS,
    temporal_layout=True,     # <<<<<< keep [T,B,H,W] = [9,13,64,64]
    normalize=True,
    band_stats=None,
    require_complete=True,
    target_bands=None,
    target_size=(64,64),
    size_policy="pad",
    probe_limit=12
)
print("Full tiles (after indexing):", len(ds_full))

# 2) Subset to the chosen (region, tile_id)s
keep_set = set(KEEP)
keep_idx = [i for i, rec in enumerate(ds_full.index)
            if (rec["region"], rec["tile_id"]) in keep_set]
ds = Subset(ds_full, keep_idx)
print("Subset tiles:", len(ds))

# 3) Inspect one sample — should be [9,13,64,64]
t0=time.time(); s = ds[0]; t1=time.time()
print("One sample:", round(t1-t0,3),"sec | x:", s["x"].shape,
      "| valid:", s["valid_mask"].shape, "| wheat:", s["wheat_mask"].shape)

# 4) DataLoader — batches will be [N, 9, 13, 64, 64]
loader = DataLoader(ds, batch_size=8, shuffle=True, num_workers=0, pin_memory=False)

t2=time.time(); b = next(iter(loader)); t3=time.time()
print("First batch:", round(t3-t2,3),"sec")
print("Batch x:", b["x"].shape)                 # -> torch.Size([8, 9, 13, 64, 64])
print("Batch valid:", b["valid_mask"].shape)    # -> torch.Size([8, 1, 64, 64])
print("Batch wheat:", b["wheat_mask"].shape)    # -> torch.Size([8, 1, 64, 64])


Picked per region: {'0': 32, '1': 32, '2': 32, '3': 32, '4': 32} | total: 160
Full tiles (after indexing): 3684
Subset tiles: 160
One sample: 0.27 sec | x: torch.Size([9, 13, 64, 64]) | valid: torch.Size([1, 64, 64]) | wheat: torch.Size([1, 64, 64])
First batch: 23.681 sec
Batch x: torch.Size([8, 9, 13, 64, 64])
Batch valid: torch.Size([8, 1, 64, 64])
Batch wheat: torch.Size([8, 1, 64, 64])


# **5) (Optional) Compute & use global band stats**

In [None]:
stats = compute_band_stats(ROOT, year=YEAR, regions=REGIONS, month_order=MONTHS, sample_limit=500)
print("Example:", list(stats.items())[:3])

# Save/load if you want (not required)
import json, os
with open('/content/band_stats.json', 'w') as f: json.dump(stats, f)

with open('/content/band_stats.json') as f:
    band_stats = {int(k): tuple(v) for k,v in json.load(f).items()}

ds_norm = WheatTilesDataset(
    root_preprocessed=ROOT,
    year=YEAR,
    regions=REGIONS,
    month_order=MONTHS,
    temporal_layout=False,
    normalize=True,
    band_stats=band_stats,
    require_complete=True
)
print("Norm’d sample:", ds_norm[0]["x"].shape)


NameError: name 'compute_band_stats' is not defined