# AgroVision Training and Evaluation

This notebook imports the ML package from `backend/src` and runs training and evaluation by calling its functions directly.
Do not run `train.py` as a script from a notebook; import and call `train()` from `src.train.train` instead.


In [None]:
import os
from pathlib import Path

# This notebook is expected to start one directory below the project root,
# while the `src.*` imports and relative paths such as "config/config.yaml"
# used in later cells resolve against the root. Only step up when `src/` is
# not already visible, so re-running this cell does not keep climbing the
# directory tree.
if not (Path.cwd() / "src").is_dir():
    os.chdir(Path.cwd().parent)
print("Working directory is now:", Path.cwd())

In [None]:

from src.train.train import train
from src.train.evaluate import evaluate
from src.utils.io import load_config


In [None]:
cfg = load_config("config/config.yaml")

# Fill in training defaults without overwriting values set in the config file.
# NOTE(review): an empty `training:` section may load as None rather than a
# dict (depends on how load_config parses the file — confirm). setdefault
# would then hand back that None, so replace it with an empty dict first.
training_cfg = cfg.get("training")
if training_cfg is None:
    training_cfg = {}
    cfg["training"] = training_cfg
training_cfg.setdefault("ignore_index", 0)
training_cfg.setdefault("min_labeled_fraction", 0.05)

# Optional: avoid Windows/Jupyter multiprocessing issues
# training_cfg["num_workers"] = 0

model, train_metrics = train(cfg)
eval_results = evaluate(model, cfg)
# Last expression of the cell: displayed by Jupyter.
train_metrics, eval_results


In [None]:
# Checks: class distribution, label mapping, and predictions
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from src.data.dataset import CropDataset
from src.utils.io import resolve_path

if "cfg" not in globals():
    raise RuntimeError("Run the config cell first.")
if "model" not in globals():
    raise RuntimeError("Run the training cell first.")

processed_dir = resolve_path(cfg["paths"]["data_processed"])
masks_path = processed_dir / "val_masks.npy"
if not masks_path.exists():
    raise FileNotFoundError(f"Missing {masks_path}")

# Memory-map the mask array so it is read lazily, slice by slice, below.
masks = np.load(masks_path, mmap_mode="r")
print("val_masks.npy shape:", masks.shape, "dtype:", masks.dtype)

# Class distribution from raw mask ids, accumulated over slices of
# `chunk_size` entries along the first axis to bound peak memory.
counts = {}
chunk_size = 32
for start in range(0, masks.shape[0], chunk_size):
    vals, cts = np.unique(masks[start:start + chunk_size], return_counts=True)
    for v, c in zip(vals.tolist(), cts.tolist()):
        counts[int(v)] = counts.get(int(v), 0) + int(c)

total_pixels = sum(counts.values())
sorted_counts = sorted(counts.items())  # keys are unique, so this sorts by id
print("Unique raw class ids in masks:", [k for k, _ in sorted_counts])
print("Background pixel ratio:", counts.get(0, 0) / total_pixels if total_pixels else 0.0)

# Normalise cfg["classes"] to int keys once. The mask ids counted above are
# ints, so a numeric-string key such as "3" would otherwise never match them
# and its name would be reported as "Unknown". Keys that are not integers at
# all are reported and skipped instead of aborting the whole check.
cfg_classes = {}
for raw_key, class_info in (cfg.get("classes") or {}).items():
    try:
        cfg_classes[int(raw_key)] = class_info
    except (TypeError, ValueError):
        print("Warning: ignoring non-integer class key in cfg:", raw_key)

cfg_class_ids = sorted(cfg_classes)
extra = [cid for cid, _ in sorted_counts if cid not in cfg_classes]
missing = [cid for cid in cfg_class_ids if cid not in counts]
print("Extra ids not in cfg:", extra)
print("Missing ids not in masks:", missing)

print("Top classes by pixel count:")
for cid, cnt in sorted(counts.items(), key=lambda x: x[1], reverse=True)[:10]:
    class_info = cfg_classes.get(cid)
    name = class_info.get("name", "Unknown") if isinstance(class_info, dict) else "Unknown"
    print(f"  {cid} ({name}): {cnt} pixels ({cnt/total_pixels:.4%})")

val_dataset = CropDataset("val", cfg)
# Fail with a clear message instead of a bare StopIteration from next() below.
if len(val_dataset) == 0:
    raise RuntimeError("Validation dataset is empty; nothing to visualize.")

# Warn if the dataset's class_map values are not exactly 0..K-1
# (presumably these are the indices the model predicts — confirm).
contig_ids = sorted(val_dataset.class_map.values())
if contig_ids != list(range(len(contig_ids))):
    print("Warning: non-contiguous class_map values:", contig_ids)

device = next(model.parameters()).device
# One random batch of up to 4 tiles; num_workers=0 keeps loading in-process.
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(val_loader))
images = batch["image"].to(device)
gt_masks = batch["mask"].cpu().numpy()

model.eval()
with torch.inference_mode():
    logits = model(images)
# Predicted index per pixel = argmax over dim 1 (assumed to be the class axis).
pred_masks = torch.argmax(logits, dim=1).cpu().numpy()

# Raw class id -> colour from cfg["classes"] ([0, 0, 0] when an entry has no
# "color"); entries whose key cannot be turned into an int are skipped.
color_map = {}
for class_key, class_info in cfg.get("classes", {}).items():
    try:
        class_color = class_info.get("color", [0, 0, 0])
        color_map[int(class_key)] = class_color
    except (TypeError, ValueError):
        continue

# Contiguous index -> raw class id mapping exposed by the dataset.
index_to_raw = val_dataset.index_to_raw
# Label value that the mask renderer paints white.
ignore_index = int(cfg.get("training", {}).get("ignore_index", 0))

def _mask_to_rgb(mask_contig):
    """Render a 2-D mask of contiguous class indices as an RGB uint8 image.

    Each index in the module-level ``index_to_raw`` is translated to its raw
    id and coloured from ``color_map`` (black when no colour is configured).
    Pixels equal to ``ignore_index`` are painted white last, overriding any
    class colour.
    """
    height, width = mask_contig.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for idx, raw in index_to_raw.items():
        canvas[mask_contig == idx] = color_map.get(int(raw), [0, 0, 0])
    canvas[mask_contig == ignore_index] = [255, 255, 255]
    return canvas

band_names = [entry.get("name") for entry in cfg.get("bands", [])]
band_idx = {}
for pos, label in enumerate(band_names):
    if label:
        band_idx[label] = pos
# Use the bands named B04, B03, B02 (as R, G, B) when all three are
# configured; otherwise fall back to the first three channels.
rgb_indices = [band_idx[b] for b in ("B04", "B03", "B02") if b in band_idx]
if len(rgb_indices) != 3:
    rgb_indices = [0, 1, 2]

def _to_rgb(image_tensor):
    """Turn a 3-D (bands-first) image tensor into a channels-last RGB array.

    Selects the channels listed in the module-level ``rgb_indices`` and
    min-max stretches each one independently to [0, 1]. The 1e-6 floor on
    the range avoids division by zero for flat channels.
    """
    selected = image_tensor.detach().cpu().numpy()[rgb_indices, :, :]
    hwc = np.transpose(selected, (1, 2, 0))
    lo = hwc.min(axis=(0, 1), keepdims=True)
    span = hwc.max(axis=(0, 1), keepdims=True) - lo
    return (hwc - lo) / np.clip(span, 1e-6, None)

# One row per sample (at most 4): input image, ground truth, prediction.
n = min(4, images.shape[0])
# squeeze=False keeps `axes` 2-D even when n == 1, so axes[i, j] always works
# without manually re-adding the row dimension.
fig, axes = plt.subplots(n, 3, figsize=(12, 4 * n), squeeze=False)

for i in range(n):
    panels = (
        ("Image", _to_rgb(images[i])),
        ("GT mask", _mask_to_rgb(gt_masks[i])),
        ("Pred mask", _mask_to_rgb(pred_masks[i])),
    )
    for ax, (title, picture) in zip(axes[i], panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis("off")

plt.tight_layout()


In [None]:
# Quick visualization of one validation tile
import numpy as np
import matplotlib.pyplot as plt
import torch

from src.data.dataset import CropDataset

if "cfg" not in globals() or "model" not in globals():
    raise RuntimeError("Run the training cell first.")

val_dataset = CropDataset("val", cfg)
device = next(model.parameters()).device
sample = val_dataset[0]
image = sample["image"]
mask = sample["mask"].cpu().numpy()

ignore_index = int(cfg.get("training", {}).get("ignore_index", 0))

# Put the model in inference mode (e.g. dropout disabled, batch-norm using
# running statistics). Without this, the prediction depends on whatever mode
# training or an earlier cell left the model in.
model.eval()
with torch.inference_mode():
    pred_logits = model(image.unsqueeze(0).to(device))
pred_mask = torch.argmax(pred_logits, dim=1).squeeze(0).cpu().numpy()

# Raw class id -> colour from cfg["classes"]; non-integer keys are skipped.
color_map = {}
for class_key, class_info in cfg.get("classes", {}).items():
    try:
        class_color = class_info.get("color", [0, 0, 0])
        color_map[int(class_key)] = class_color
    except (TypeError, ValueError):
        continue

# Contiguous index -> raw class id mapping exposed by the dataset.
index_to_raw = val_dataset.index_to_raw

def _mask_to_rgb_quick(mask_contig):
    """Colour a 2-D contiguous-index mask using the module-level mappings.

    Indices are mapped through ``index_to_raw`` to raw ids and coloured via
    ``color_map`` (default black). Pixels equal to ``ignore_index`` are then
    overwritten with white.
    """
    rows, cols = mask_contig.shape
    painted = np.zeros((rows, cols, 3), dtype=np.uint8)
    for contig_idx, raw_idx in index_to_raw.items():
        class_color = color_map.get(int(raw_idx), [0, 0, 0])
        painted[mask_contig == contig_idx] = class_color
    painted[mask_contig == ignore_index] = [255, 255, 255]
    return painted

band_names = [band.get("name") for band in cfg.get("bands", [])]
band_idx = {label: pos for pos, label in enumerate(band_names) if label}
# Bands named B04, B03, B02 become R, G, B; fall back to channels 0-2 unless
# all three are present.
rgb_indices = [band_idx[key] for key in ("B04", "B03", "B02") if key in band_idx]
rgb_indices = rgb_indices if len(rgb_indices) == 3 else [0, 1, 2]

def _to_rgb_quick(image_tensor):
    """Channels-last RGB view of a bands-first image, stretched to [0, 1].

    Uses the channels in the module-level ``rgb_indices``; each channel is
    min-max normalised on its own, with a 1e-6 floor on the range.
    """
    chosen = image_tensor.cpu().numpy()[rgb_indices, :, :]
    hwc = np.moveaxis(chosen, 0, -1)
    floor = hwc.min(axis=(0, 1), keepdims=True)
    ceiling = hwc.max(axis=(0, 1), keepdims=True)
    return (hwc - floor) / np.clip(ceiling - floor, 1e-6, None)

# Side-by-side view of the tile: input image, ground truth, prediction.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    ("Image", _to_rgb_quick(image)),
    ("GT", _mask_to_rgb_quick(mask)),
    ("Prediction", _mask_to_rgb_quick(pred_mask)),
]
for ax, (title, picture) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(title)
    ax.axis("off")
plt.tight_layout()
