# In [None]:
# tabs intentional (your preference)
import os, csv, math, json, random
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, f1_score

# -------------------
# Paths
# -------------------
# Every input and output path hangs off BASE_DIR, so relocating the
# project only requires editing that one line.
BASE_DIR = Path("/content/drive/MyDrive")  # <- change if needed
RESEARCH = BASE_DIR / "research"
GRID_DIR = RESEARCH.joinpath("keypoints", "train")       # one sub-folder per task id
LABEL_CSV = RESEARCH.joinpath("dataset", "train_labels.csv")
OUT_DIR = RESEARCH.joinpath("runs", "grid_cnn")          # checkpoint + meta.json go here
OUT_DIR.mkdir(parents=True, exist_ok=True)

# One shared seed for the Python, NumPy and PyTorch RNGs.
SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# -------------------
# Data utils
# -------------------
def read_labels(csv_fp):
	"""Read a ``task_id -> label`` mapping from a CSV file.

	The file needs a header row with ``task_id`` and ``label`` columns.
	``task_id`` is stripped of surrounding whitespace and ``label`` is parsed
	with ``int``. If a task id appears on several rows, the last row wins.

	Args:
		csv_fp: path (str or Path) to the CSV file.

	Returns:
		dict[str, int]: label per task id.
	"""
	labels = {}
	# "utf-8-sig" drops a leading byte-order mark if present (common in
	# spreadsheet exports) and reads plain UTF-8 unchanged. With "utf-8" the
	# BOM would stick to the first header name and r["task_id"] would KeyError.
	with open(csv_fp, newline="", encoding="utf-8-sig") as f:
		for row in csv.DictReader(f):
			labels[str(row["task_id"]).strip()] = int(row["label"])
	return labels

def find_samples(grid_root, labels, fname="grid_G32.npy"):
	"""Pair each labelled task with its grid file, skipping tasks that lack one.

	Args:
		grid_root: directory (str or Path) holding one sub-directory per task id.
		labels: mapping ``task_id -> label`` (e.g. the output of ``read_labels``).
		fname: file name looked up inside each task directory.

	Returns:
		list of ``(task_id, Path, label)`` tuples, in ``labels`` iteration order.
	"""
	root = Path(grid_root)  # accept plain strings as well as Path objects
	samples = []
	for tid, y in labels.items():
		fp = root / tid / fname
		if fp.exists():
			samples.append((tid, fp, y))
	return samples

class GridDataset(Dataset):
	"""Dataset of per-task grids stored as ``.npy`` files.

	Each sample is a ``(task_id, file_path, label)`` tuple. ``__getitem__``
	returns ``(x, y, task_id)``:
	  * ``x``: float32 tensor ``[C', H, W]``, where ``C' = C + 2`` when
	    ``add_coord`` appends row/column coordinate channels.
	  * ``y``: float32 tensor of shape ``[1]``.

	Args:
		samples: list of ``(task_id, path, label)`` tuples.
		mean, std: optional per-channel statistics of length >= ``C'``. When
			both are set, each channel becomes ``(x - mean) / (std + 1e-6)``.
			They are stored as float32 so outputs stay float32.
		add_coord: append the two coordinate channels (before normalization).
	"""

	def __init__(self, samples, mean=None, std=None, add_coord=True):
		self.samples = samples
		self.add_coord = add_coord
		self.set_norm(mean, std)

	def __len__(self):
		return len(self.samples)

	@staticmethod
	def _as_stats(v):
		# Coerce to a flat float32 array. float64 stats would otherwise promote
		# every sample to float64 and mismatch a float32 model.
		return None if v is None else np.asarray(v, dtype=np.float32).reshape(-1)

	def set_norm(self, mean, std):
		"""Set (or clear, by passing None) the per-channel normalization stats."""
		self._mean = self._as_stats(mean)
		self._std = self._as_stats(std)

	def _load_grid(self, fp):
		"""Load one grid and return it channels-first as float32 ``[C, H, W]``.

		Accepts ``[H, W]`` (one channel), ``[C, 32, 32]`` or ``[32, 32, C]``.
		Any 3-D array whose first two axes are both 32 is treated as
		channels-last and transposed, so ``[32, 32, 32]`` is read as ``[H, W, C]``.

		Raises:
			ValueError: if the stored array is not 2-D or 3-D.
		"""
		arr = np.load(fp)
		if arr.ndim == 2:
			arr = arr[None, ...]  # [1,H,W]
		if arr.ndim != 3:
			raise ValueError(f"{fp}: expected a 2-D or 3-D grid, got shape {arr.shape}")
		if arr.shape[0] == 32 and arr.shape[1] == 32:
			arr = np.transpose(arr, (2, 0, 1))  # [32,32,C] -> [C,32,32]
		return arr.astype(np.float32)

	def __getitem__(self, i):
		tid, fp, y = self.samples[i]
		x = self._load_grid(fp)  # [C,H,W]
		C, H, W = x.shape

		if self.add_coord:
			# Row and column coordinates, each spanning [0, 1], one channel apiece.
			ii = np.linspace(0, 1, H, dtype=np.float32)
			jj = np.linspace(0, 1, W, dtype=np.float32)
			yy, xx = np.meshgrid(ii, jj, indexing="ij")
			coord = np.stack([yy, xx], axis=0)  # [2,H,W]
			x = np.concatenate([x, coord], axis=0)
			C += 2

		if self._mean is not None and self._std is not None:
			n_stats = min(self._mean.shape[0], self._std.shape[0])
			if n_stats < C:
				# A shorter stats vector would otherwise broadcast silently
				# (length 1) or fail with an opaque broadcasting error.
				raise ValueError(
					f"normalization stats cover {n_stats} channels but sample {tid!r} has {C}"
				)
			# Per-channel standardization; the epsilon guards constant channels (std == 0).
			x = (x - self._mean[:C, None, None]) / (self._std[:C, None, None] + 1e-6)

		return torch.from_numpy(x), torch.tensor([y], dtype=torch.float32), tid

def compute_channel_stats(dataset):
	"""Compute per-channel mean and std over every pixel of every sample.

	The pass always runs over raw, un-normalized tensors. If the dataset
	exposes ``set_norm`` and currently has ``_mean``/``_std`` set, they are
	cleared for the pass and restored afterwards, even on error. Any channels
	the dataset appends (e.g. coordinates) are included like the others.

	Args:
		dataset: sized, indexable dataset whose items are ``(x, y, tid)``, with
			``x`` a tensor of shape ``[C, H, W]``. All samples must share ``C``.

	Returns:
		(mean, std): float32 numpy arrays of length ``C``.

	Raises:
		ValueError: if the dataset is empty.
	"""
	if len(dataset) == 0:
		raise ValueError("compute_channel_stats: dataset is empty")

	saved = (getattr(dataset, "_mean", None), getattr(dataset, "_std", None))
	clear_norm = hasattr(dataset, "set_norm") and any(s is not None for s in saved)
	if clear_norm:
		dataset.set_norm(None, None)
	try:
		sum_c = None
		sumsq_c = None
		n_pix = 0
		for k in range(len(dataset)):
			x = dataset[k][0].numpy()
			C = x.shape[0]
			# Accumulate in float64: var = E[x^2] - E[x]^2 is sensitive to round-off.
			flat = x.reshape(C, -1).astype(np.float64)
			if sum_c is None:
				sum_c = np.zeros(C, dtype=np.float64)
				sumsq_c = np.zeros(C, dtype=np.float64)
			sum_c += flat.sum(axis=1)
			sumsq_c += (flat ** 2).sum(axis=1)
			n_pix += flat.shape[1]
	finally:
		if clear_norm:
			dataset.set_norm(*saved)

	mean = sum_c / n_pix
	var = np.maximum(sumsq_c / n_pix - mean ** 2, 0.0)  # clamp tiny negative round-off
	return mean.astype(np.float32), np.sqrt(var).astype(np.float32)

# -------------------
# Model
# -------------------
class ConvBlock(nn.Module):
	"""Conv2d (no bias) -> BatchNorm2d -> in-place ReLU.

	Args:
		c_in: input channel count.
		c_out: output channel count.
		k: convolution kernel size.
		s: convolution stride.
		p: convolution padding.
	"""

	def __init__(self, c_in, c_out, k=3, s=1, p=1):
		super().__init__()
		# bias=False: the BatchNorm that follows re-centres each channel,
		# so a convolution bias would be cancelled anyway.
		self.conv = nn.Conv2d(
			in_channels=c_in,
			out_channels=c_out,
			kernel_size=k,
			stride=s,
			padding=p,
			bias=False,
		)
		self.bn = nn.BatchNorm2d(num_features=c_out)
		self.act = nn.ReLU(inplace=True)

	def forward(self, x):
		z = self.conv(x)
		z = self.bn(z)
		return self.act(z)

class ResidBlock(nn.Module):
	"""Two stacked ConvBlocks with an identity skip: ``x + b2(b1(x))``.

	Args:
		c: channel count of both input and output (equal so the sum is valid).
		widen: multiplier for the hidden channel count between the two convs.
	"""

	def __init__(self, c, widen=1):
		super().__init__()
		hidden = widen * c
		self.b1 = ConvBlock(c, hidden)
		self.b2 = ConvBlock(hidden, c)

	def forward(self, x):
		branch = self.b2(self.b1(x))
		return branch + x

class JointGridCNN(nn.Module):
	"""Small residual CNN with a squeeze-and-excitation gate and an MLP head.

	The stem halves the spatial size twice (stride-2 convs; 32 -> 16 -> 8 for
	a 32x32 input). A channel gate rescales the stem features, which are then
	global-average pooled and mapped to one logit per sample.

	Args:
		in_ch: number of input channels.
		base: width of the first stage; the second stage uses ``2 * base``.
		p_drop: dropout probability, applied before each Linear of the head.

	forward returns a tensor ``[N, 1]`` of raw logits.
	"""

	def __init__(self, in_ch, base=48, p_drop=0.2):
		super().__init__()
		wide = base * 2
		squeezed = wide // 4  # SE bottleneck width (reduction factor 4)
		# Keep the submodule order fixed: state_dict keys (stem.N / se.N / head.N)
		# depend on it, and saved checkpoints rely on those keys.
		self.stem = nn.Sequential(
			ConvBlock(in_ch, base, k=5, s=2, p=2),  # 32->16
			ResidBlock(base),
			ConvBlock(base, wide, k=3, s=2, p=1),   # 16->8
			*(ResidBlock(wide) for _ in range(2)),
		)
		self.se = nn.Sequential(
			nn.AdaptiveAvgPool2d(1),
			nn.Conv2d(wide, squeezed, kernel_size=1),
			nn.ReLU(inplace=True),
			nn.Conv2d(squeezed, wide, kernel_size=1),
			nn.Sigmoid(),
		)
		self.head = nn.Sequential(
			nn.Dropout(p_drop),
			nn.Linear(wide, 128),
			nn.ReLU(inplace=True),
			nn.Dropout(p_drop),
			nn.Linear(128, 1),  # logit
		)

	def forward(self, x):
		feats = self.stem(x)
		gate = self.se(feats)  # [N, 2*base, 1, 1], values in (0, 1)
		feats = feats * gate
		pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)  # [N, 2*base]
		return self.head(pooled)

# -------------------
# Train / Eval
# -------------------
def eval_metrics(y_true, y_score, threshold=None):
	"""Compute ROC-AUC, PR-AUC (average precision) and F1 at a decision threshold.

	Args:
		y_true: binary ground-truth labels. Any array-like is accepted and flattened.
		y_score: predicted scores or probabilities, same length as ``y_true``.
		threshold: decision threshold for F1. If None, the threshold that
			maximizes F1 on the precision-recall curve of these same
			predictions is used, i.e. it is tuned in-sample.

	Returns:
		dict with float ``roc_auc``, ``pr_auc``, ``f1`` and the ``thr`` used
		(samples with ``score >= thr`` count as positive).

	Raises:
		ValueError: propagated from scikit-learn, e.g. when ``y_true`` holds a
			single class (ROC-AUC is undefined then).
	"""
	# Accept lists / column vectors: the ">=" comparison below needs an ndarray.
	y_true = np.asarray(y_true).ravel()
	y_score = np.asarray(y_score).ravel()

	roc = roc_auc_score(y_true, y_score)
	pr = average_precision_score(y_true, y_score)
	if threshold is None:
		# precision/recall have one more entry than thresholds: the final
		# (precision=1, recall=0) point has no threshold of its own, hence
		# the fallback when the argmax lands there.
		p, r, t = precision_recall_curve(y_true, y_score)
		f1s = (2 * p * r) / np.maximum(p + r, 1e-9)
		best_idx = int(np.nanargmax(f1s))
		best_thr = 0.5 if best_idx >= len(t) else t[best_idx]
	else:
		best_thr = threshold
	y_pred = (y_score >= best_thr).astype(np.int32)
	f1 = f1_score(y_true, y_pred)
	return {"roc_auc": float(roc), "pr_auc": float(pr), "f1": float(f1), "thr": float(best_thr)}

def train_one_epoch(model, loader, opt, loss_fn, dev):
	"""Run one optimization pass over ``loader`` and return the mean training loss.

	Args:
		model: module producing logits shaped like the targets.
		loader: iterable of ``(x, y, _)`` batches.
		opt: optimizer over ``model``'s parameters.
		loss_fn: callable ``(logit, y) -> scalar``. Assumed to average over the
			batch (mean reduction), since the weighting below relies on that.
		dev: device each batch is moved to.

	Returns:
		float: sample-weighted mean loss over the samples actually processed,
		or NaN if the loader yielded no batches.
	"""
	model.train()
	total = 0.0
	n_seen = 0
	for x, y, _ in loader:
		x = x.to(dev, non_blocking=True)
		y = y.to(dev, non_blocking=True)
		opt.zero_grad(set_to_none=True)
		loss = loss_fn(model(x), y)
		loss.backward()
		opt.step()
		batch = x.size(0)
		total += loss.item() * batch
		n_seen += batch
	# Normalize by samples actually seen: len(loader.dataset) over-counts with
	# drop_last=True or a subset sampler, and is 0 for an empty dataset.
	return total / n_seen if n_seen else float("nan")

@torch.no_grad()
def predict(model, loader, dev):
	"""Score every batch of ``loader`` with ``model`` in eval mode, without autograd.

	Args:
		model: module returning one logit per sample.
		loader: iterable of ``(x, y, _)`` batches.
		dev: device the inputs are moved to. Labels are copied to CPU as given.

	Returns:
		(labels, probs): two flat numpy arrays in loader order. ``labels`` are
		the loader's targets and ``probs`` is ``sigmoid(logit)`` per sample.
	"""
	model.eval()
	target_chunks = []
	prob_chunks = []
	for inputs, targets, _ in loader:
		probs = torch.sigmoid(model(inputs.to(dev, non_blocking=True)))
		prob_chunks.append(probs.cpu().numpy().ravel())
		target_chunks.append(targets.cpu().numpy().ravel())
	return np.concatenate(target_chunks), np.concatenate(prob_chunks)

# -------------------
# Orchestration
# -------------------
labels = read_labels(LABEL_CSV)
samples = find_samples(GRID_DIR, labels)
# Explicit raise rather than assert: asserts are stripped under `python -O`.
if not samples:
	raise RuntimeError(f"No matching grid files found. (looked under {GRID_DIR} for ids in {LABEL_CSV})")

# stratified 80/20 split on the label, so both splits keep the class ratio
tids = [s[0] for s in samples]
ys = [s[2] for s in samples]
train_ids, val_ids = train_test_split(
	np.arange(len(samples)), test_size=0.2, random_state=SEED, stratify=ys
)

train_samples = [samples[i] for i in train_ids]
val_samples   = [samples[i] for i in val_ids]

# Optional: split val into tune/eval if you want to pick threshold on a separate split
USE_TWO_VALS = False
if USE_TWO_VALS:
	val_tune_idx, val_hold_idx = train_test_split(
		np.arange(len(val_samples)), test_size=0.5, random_state=SEED, stratify=[s[2] for s in val_samples]
	)
	val_tune = [val_samples[i] for i in val_tune_idx]
	val_hold = [val_samples[i] for i in val_hold_idx]
else:
	val_tune, val_hold = val_samples, None

# build datasets; normalization stats come from the training split only
ds_train = GridDataset(train_samples, mean=None, std=None, add_coord=True)
mean, std = compute_channel_stats(ds_train)
ds_train.set_norm(mean, std)
ds_val_tune = GridDataset(val_tune, mean=mean, std=std, add_coord=True)

# Shared loader settings. Pinned memory only pays off for host->CUDA copies;
# requesting it without CUDA just produces a warning.
_loader_kw = dict(num_workers=2, pin_memory=torch.cuda.is_available(), persistent_workers=True)
dl_train = DataLoader(ds_train, batch_size=128, shuffle=True, **_loader_kw)
dl_val_tune = DataLoader(ds_val_tune, batch_size=256, shuffle=False, **_loader_kw)

if val_hold is not None:
	ds_val_hold = GridDataset(val_hold, mean=mean, std=std, add_coord=True)
	dl_val_hold = DataLoader(ds_val_hold, batch_size=256, shuffle=False, **_loader_kw)

# model
# infer input channels from one sample (includes any appended coord channels)
x0, _, _ = ds_train[0]
in_ch = x0.shape[0]
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = JointGridCNN(in_ch=in_ch, base=48, p_drop=0.2).to(dev)

# class imbalance handling: with 0/1 labels, pos_weight = #neg / #pos in the training split
pos_ratio = float(np.mean([s[2] for s in train_samples]))
pos_weight = torch.tensor([(1.0 - pos_ratio) / max(pos_ratio, 1e-6)], device=dev)
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

EPOCHS = 10
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Cosine schedule spans the whole run; tied to EPOCHS so the two stay in sync.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

best_pr = -1.0
best_state = None

for epoch in range(1, EPOCHS+1):
	train_loss = train_one_epoch(model, dl_train, opt, loss_fn, dev)
	yv, sv = predict(model, dl_val_tune, dev)
	metrics = eval_metrics(yv, sv)  # tunes threshold on val_tune
	sched.step()

	print(f"epoch {epoch:02d} | loss {train_loss:.4f} | roc {metrics['roc_auc']:.4f} | pr {metrics['pr_auc']:.4f} | f1 {metrics['f1']:.4f} | thr {metrics['thr']:.3f}")

	if metrics["pr_auc"] > best_pr:
		best_pr = metrics["pr_auc"]
		best_state = {
			"epoch": epoch,
			# state_dict() returns references to the live tensors, which later
			# optimizer steps update in place. Snapshot copies (on CPU) so this
			# entry really holds the best epoch's weights, not the last epoch's.
			"model": {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()},
			"mean": mean, "std": std,
			"threshold": metrics["thr"],
			"in_ch": in_ch
		}

# final eval
if best_state is None:
	# Happens only if no epoch ran or every validation PR-AUC was NaN.
	raise RuntimeError("No checkpoint selected: no epoch produced a validation PR-AUC above the initial best.")

# Restore the selected epoch so the holdout numbers, the saved file and any
# later use of `model` all refer to the same weights.
model.load_state_dict(best_state["model"])
if val_hold is not None:
	yh, sh = predict(model, dl_val_hold, dev)
	final_m = eval_metrics(yh, sh, threshold=best_state["threshold"])
	print(f"[HOLDOUT] roc {final_m['roc_auc']:.4f} | pr {final_m['pr_auc']:.4f} | f1 {final_m['f1']:.4f} | thr {final_m['thr']:.3f}")

# save checkpoint
# NOTE(review): "mean"/"std" are numpy arrays, so reloading probably needs
# torch.load(..., weights_only=False) on PyTorch versions that default to
# weights_only=True — confirm against the loader side.
ckpt_fp = OUT_DIR / "jointgridcnn_best.pth"
torch.save(best_state, ckpt_fp)
with open(OUT_DIR / "meta.json", "w", encoding="utf-8") as f:
	json.dump({
		"best_epoch": best_state["epoch"],
		"threshold": best_state["threshold"],
		"train_size": len(train_samples),
		"val_size": len(val_samples),
		"use_two_vals": USE_TWO_VALS
	}, f, indent=2)

print(f"Saved: {ckpt_fp}")