# AgroVision Training and Evaluation


# About this notebook
This notebook guides you through training and evaluating the AgroVision semantic segmentation model using the `agrovision_core` library. It is organized into small, focused sections:

1. Setup: ensure we run from the repository root so paths resolve correctly. ✅
2. Imports: bring in the functions used for training and evaluation.
3. Run training & evaluation: runs model training (can be long; see guidance below).
4. Data inspection: quick checks on the processed mask files and class distribution.
5. Visualization: visualize a single validation tile and the model prediction.

Quick tips:
- If you want to run a fast smoke test, set `cfg['training']['epochs'] = 1` and `training_cfg['num_workers'] = 0` before calling `train(cfg)`.
- Running full training may take many minutes/hours depending on hardware; use the run-instructions cell to customize.

Make sure you are using the project's Python environment (e.g. via `uv run jupyter lab` or the project's virtual env).

In [1]:
# Set the working directory to the repository root so relative paths (config, data, outputs) resolve reliably.
# This is useful when running the notebook from inside the `notebooks/` folder in VS Code.
import os
from pathlib import Path

ROOT = Path.cwd().parent
os.chdir(ROOT)
print("Working directory is now:", Path.cwd())

Working directory is now: d:\trying\AgroVision


In [2]:
# Quick explanation of imports below: these bring training, evaluation, and config helpers from the library.
# - `train` performs model training and returns (model, metrics)
# - `evaluate` runs evaluation and returns a metrics dict
# - `load_config` reads the YAML configuration used to control training/evaluation
from agrovision_core.train.train import train
from agrovision_core.train.evaluate import evaluate
from agrovision_core.utils.io import load_config

In [3]:
# Load configuration: this loads `config/config.yaml` into `cfg` (a Python dict)
# You can inspect `cfg` to see dataset paths, class ids, training hyperparams, etc.
cfg = load_config("config/config.yaml")

# Ensure training defaults are set and safe for quick debugging
training_cfg = cfg.setdefault("training", {})
training_cfg.setdefault("ignore_index", 0)
training_cfg.setdefault("min_labeled_fraction", 0.05)

# Optional: avoid Windows/Jupyter multiprocessing issues by using a single worker
# training_cfg["num_workers"] = 0



0.05

In [4]:
# WARNING: The next line will run model training, which can be long depending on your machine.
# For a quick smoke test, set `cfg['training']['epochs'] = 1` before calling train(cfg).
model, train_metrics = train(cfg)
# Run evaluation on the trained model (optional) — this returns a dict of metrics.
eval_results = evaluate(model, cfg)
train_metrics, eval_results

Using class weights (normalized mean=1): [0.0, 0.017464593052864075, 0.02802138775587082, 0.45504868030548096, 0.036044325679540634, 2.4706315994262695, 0.22541329264640808, 0.41647788882255554, 0.14953896403312683, 0.3745090961456299, 1.9349637031555176, 1.4807058572769165, 5.0264573097229, 0.3847229480743408]


OSError: [Errno 22] Invalid argument

# Guidance before running training
# - Make any configuration changes now (e.g., reduce epochs for a faster run)
# - If you get CUDA errors, set device to 'cpu' in cfg: cfg['model']['device'] = 'cpu'
# - If you only want to evaluate an existing checkpoint, skip the `train` call and call `evaluate(model=None, cfg=cfg, checkpoint_path='path/to/checkpoint.pth')`


In [None]:
cfg = load_config("config/config.yaml")

training_cfg = cfg.setdefault("training", {})
training_cfg.setdefault("ignore_index", 0)
training_cfg.setdefault("min_labeled_fraction", 0.05)

# Optional: avoid Windows/Jupyter multiprocessing issues
# training_cfg["num_workers"] = 0

model, train_metrics = train(cfg)
eval_results = evaluate(model, cfg)
train_metrics, eval_results


In [None]:
# Create a validation dataset and load a small batch for spot-checking
processed_dir = resolve_path(cfg["paths"]["data_processed"])
val_dataset = CropDataset(processed_dir, split="val")
contig_ids = sorted(val_dataset.index_to_raw.keys())
if contig_ids != list(range(len(contig_ids))):
    print("Warning: non-contiguous class_map values:", contig_ids)

device = next(model.parameters()).device
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=True, num_workers=0)
batch = next(iter(val_loader))
images = batch["image"].to(device)
gt_masks = batch["mask"].cpu().numpy()

model.eval()
with torch.inference_mode():
    logits = model(images)
pred_masks = torch.argmax(logits, dim=1).cpu().numpy()

# Build color map from cfg (raw ids)
color_map = {}
for raw_id, info in cfg.get("classes", {}).items():
    try:
        color_map[int(raw_id)] = info.get("color", [0, 0, 0])
    except (TypeError, ValueError):
        pass

# What this check does
# - Confirms that the processed validation masks exist and can be opened
# - Shows which raw class ids appear in the masks and checks against the `classes` config
# - Loads a small validation batch and runs a forward pass to produce predictions
# Use these outputs to validate data processing, label mappings, and quick model sanity checks


In [None]:
# Quick visualization of one validation tile
# This cell is intended to show a single image, the ground truth mask, and the model's prediction.
# Requirements: `cfg`, `model` and a prepared `val_dataset` or run the data-check cell above first.

import numpy as np
import matplotlib.pyplot as plt
import torch

from agrovision_core.data.dataset import CropDataset
from agrovision_core.utils.io import resolve_path

if "cfg" not in globals() or "model" not in globals():
    raise RuntimeError("Run the config cell and training cell first.")

processed_dir = resolve_path(cfg["paths"]["data_processed"])
val_dataset = CropDataset(processed_dir, split="val")
device = next(model.parameters()).device
sample = val_dataset[0]
image = sample["image"]
mask = sample["mask"].cpu().numpy()

ignore_index = int(cfg.get("training", {}).get("ignore_index", 0))

# Run the model on one image (no gradient needed)
with torch.inference_mode():
    pred_logits = model(image.unsqueeze(0).to(device))
pred_mask = torch.argmax(pred_logits, dim=1).squeeze(0).cpu().numpy()

# Build color map for raw ids from the config
color_map = {}
for raw_id, info in cfg.get("classes", {}).items():
    try:
        color_map[int(raw_id)] = info.get("color", [0, 0, 0])
    except (TypeError, ValueError):
        pass

index_to_raw = val_dataset.index_to_raw

def _mask_to_rgb_quick(mask_contig):
    h, w = mask_contig.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for contig_id, raw_id in index_to_raw.items():
        rgb[mask_contig == contig_id] = color_map.get(int(raw_id), [0, 0, 0])
    rgb[mask_contig == ignore_index] = [255, 255, 255]
    return rgb

band_names = [b.get("name") for b in cfg.get("bands", [])]
band_idx = {name: i for i, name in enumerate(band_names) if name}
rgb_indices = [band_idx[name] for name in ("B04", "B03", "B02") if name in band_idx]
if len(rgb_indices) != 3:
    rgb_indices = [0, 1, 2]

def _to_rgb_quick(image_tensor):
    arr = image_tensor.cpu().numpy()
    rgb = arr[rgb_indices, :, :]
    rgb = np.stack(rgb, axis=-1)
    min_val = rgb.min(axis=(0, 1), keepdims=True)
    max_val = rgb.max(axis=(0, 1), keepdims=True)
    rgb = (rgb - min_val) / np.clip(max_val - min_val, 1e-6, None)
    return rgb

# Plot results
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(_to_rgb_quick(image))
axes[0].set_title("Image")
axes[1].imshow(_mask_to_rgb_quick(mask))
axes[1].set_title("GT")
axes[2].imshow(_mask_to_rgb_quick(pred_mask))
axes[2].set_title("Prediction")
for ax in axes:
    ax.axis("off")
plt.tight_layout()
