# GrainSet: notebook quickstart

This notebook runs the repository code end to end (datalist → config → dataset → model → loss/optimizer) and performs a single training step: one forward pass, one backward pass, and one optimizer update.

If you see missing-module errors (e.g. for `albumentations`, `timm`, or `cv2`, which is provided by the `opencv-python` package), install the dependencies first (see `requirements.txt`).

In [None]:
from __future__ import annotations

import os
import sys
from pathlib import Path

# Make imports work no matter where the notebook is opened from: if the current
# directory has no `src/` but its parent does (e.g. the notebook was opened from a
# subdirectory of the repo), treat the parent as the repository root.
repo_root = Path.cwd()
if not (repo_root / "src").exists() and (repo_root.parent / "src").exists():
    repo_root = repo_root.parent
# Work from the repo root so any relative paths used by the repo code resolve there.
os.chdir(repo_root)

# Put `src/` first on the import path exactly once. Re-running this cell must not
# stack duplicate entries, and `src/` must still take precedence over any installed
# package with the same name (e.g. a third-party `datasets`).
src_dir = str(repo_root / "src")
sys.path[:] = [src_dir, *(p for p in sys.path if p != src_dir)]

print("repo_root:", repo_root)


## Prepare datalist (train/val/test splits)

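The training code expects split files under `runs/datalist/` (it also tolerates `runs/datalist/datalist/` if you extracted the zip as-is). The cell below extracts `datasets/datalist.zip` into `runs/datalist/` when no `*_train.txt` split file is found there.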

In [None]:
import zipfile

# Split files (`*_train.txt`, ...) are expected under runs/datalist/, possibly one
# level deeper (runs/datalist/datalist/) when the zip is extracted with its folder.
runs_datalist = repo_root / "runs" / "datalist"
runs_datalist.mkdir(parents=True, exist_ok=True)

# Extract datasets/datalist.zip only if no train split is present yet, so re-running
# the cell is cheap. rglob is recursive, so the nested layout is detected too.
has_any_split = any(runs_datalist.rglob("*_train.txt"))
if not has_any_split:
    zip_path = repo_root / "datasets" / "datalist.zip"
    if not zip_path.exists():
        raise FileNotFoundError(f"Missing {zip_path}.")
    with zipfile.ZipFile(zip_path) as z:
        z.extractall(runs_datalist)
    # Fail here with a clear message if the archive did not provide any train split,
    # instead of with a less obvious error later when the splits are read.
    if not any(runs_datalist.rglob("*_train.txt")):
        raise FileNotFoundError(
            f"Extracted {zip_path} into {runs_datalist}, "
            "but no '*_train.txt' split files were found."
        )

print("datalist files:", len(list(runs_datalist.rglob("*.txt"))))


## Load config + pick dataset path

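This repo includes a sample wheat dataset under `datasets/wheat/wheat/`. The cell below uses `datasets/wheat/` if it contains a `train/` folder, and falls back to the nested `datasets/wheat/wheat/` when only that one does.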

In [None]:
from config import cfg

# Start from the wheat experiment config, then switch the phase to training.
cfg.update_from_file(str(repo_root / "configs" / "wheat.yaml"))
cfg.PHASE = "train"

# The sample data may live directly in datasets/wheat/ or one level deeper in
# datasets/wheat/wheat/. Descend only when the nested folder has `train/` and the
# top-level folder does not.
data_root = repo_root / "datasets" / "wheat"
if (data_root / "wheat" / "train").exists() and not (data_root / "train").exists():
    data_root /= "wheat"
cfg.DATASET.PATH = str(data_root)

# Notebook-friendly overrides: no DataLoader worker processes (often troublesome
# under Windows/Jupyter) and small batches for train and test.
cfg.DATASET.WORKERS, cfg.TRAIN.BATCH, cfg.TEST.BATCH = 0, 8, 8

print("cfg.DATASET.PATH:", cfg.DATASET.PATH)
print("cfg.MODEL.NAME:", cfg.MODEL.NAME)
print("cfg.DATASET.CLASS_NUMS:", cfg.DATASET.CLASS_NUMS)


## Build dataloaders (using repo dataset + augmentations)

In [None]:
import torch

from datasets.reader import get_imglists
from datasets.GrainDataset import GrainDataset
from datasets.augment import get_transforms

# Quick-start knobs: subset sizes keep this cell fast; SEED makes sampling and
# shuffling reproducible under Restart & Run All.
TRAIN_SUBSET = 256
VAL_SUBSET = 128
SEED = 0

train_imgs = get_imglists(root=cfg.DATASET.PATH, split="train", phase="train")
val_imgs = get_imglists(root=cfg.DATASET.PATH, split="val", phase="train")

# Keep this quick: sample a small subset. `.sample(n=..., random_state=...)`
# presumably means get_imglists returns a pandas DataFrame — TODO confirm.
train_imgs_small = train_imgs.sample(n=min(TRAIN_SUBSET, len(train_imgs)), random_state=SEED)
val_imgs_small = val_imgs.sample(n=min(VAL_SUBSET, len(val_imgs)), random_state=SEED)

train_transform, val_transform = get_transforms()
train_ds = GrainDataset(train_imgs_small, mode="train", transforms=train_transform)
val_ds = GrainDataset(val_imgs_small, mode="val", transforms=val_transform)

train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=cfg.TRAIN.BATCH,
    shuffle=True,
    num_workers=int(cfg.DATASET.WORKERS),
    pin_memory=torch.cuda.is_available(),
    drop_last=True,
    # Seeded generator: the shuffle order is reproducible. Augmentations may still
    # draw from their own RNG.
    generator=torch.Generator().manual_seed(SEED),
)

# With drop_last=True, fewer samples than one batch yields an empty loader, and
# next(iter(...)) would raise a bare StopIteration. Report the real cause instead.
if len(train_loader) == 0:
    raise ValueError(
        f"Not enough training samples for one batch of {cfg.TRAIN.BATCH}: "
        f"found {len(train_ds)} under {cfg.DATASET.PATH!r}."
    )

images, labels, filenames = next(iter(train_loader))
print("batch images:", images.shape, images.dtype)
print("batch labels:", labels.shape, labels[:8].tolist())
print("example file:", filenames[0])


## Instantiate a model and run a forward/backward step

In [None]:
from models.model import get_model
from solver import get_loss, get_optimizer
from utils.misc import accuracy

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Model, loss and optimizer are all built from the loaded config.
model = get_model(cfg, num_classes=cfg.DATASET.CLASS_NUMS).to(device)
criterion = get_loss(loss_name=cfg.OPTIM.LOSS, device=device)
optimizer = get_optimizer(model, optim_name=cfg.OPTIM.NAME, learn_rate=cfg.OPTIM.INIT_LR)

# One training step. iter() creates a fresh iterator, so with shuffle=True this is
# generally not the same batch that was printed in the previous cell.
model.train()
images, labels, _ = next(iter(train_loader))
images, labels = (batch_part.to(device) for batch_part in (images, labels))

logits = model(images)
loss = criterion(logits, labels)
# Top-1 / top-2 accuracy on this batch, measured before the parameter update.
top1, top2 = accuracy(logits, labels, topk=(1, 2))

optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

print("loss:", loss.item())
print("top1:", top1.item(), "top2:", top2.item())
