## Preprocess


In [None]:
# preprocess.py
import os
import numpy as np
import pandas as pd

# Backend CSV contract columns (Project A).
# Candidate header names for each required field. `_find_col` tries every
# candidate as an exact match first (in list order) and only then falls back
# to a case-insensitive comparison.
_PATH_COLS = ["image_path", "filepath", "image", "path", "file_name"]
_LAT_COLS = ["Latitude", "latitude", "lat"]
_LON_COLS = ["Longitude", "longitude", "lon"]


def _find_col(df_cols, candidates):
    """
    Pick the column from ``df_cols`` that matches one of ``candidates``.

    Candidates are tried in order. An exact (case-sensitive) hit always wins;
    only when no candidate matches exactly is a case-insensitive comparison
    attempted. Returns the column name as it appears in ``df_cols``, or
    ``None`` when nothing matches.
    """
    available = list(df_cols)
    exact_names = set(available)

    exact_hit = next((name for name in candidates if name in exact_names), None)
    if exact_hit is not None:
        return exact_hit

    # Case-insensitive fallback. When two columns differ only by case, the
    # later one in ``df_cols`` is the one that is kept.
    by_lowercase = {}
    for column in available:
        by_lowercase[column.lower()] = column

    for name in candidates:
        key = name.lower()
        if key in by_lowercase:
            return by_lowercase[key]
    return None


def _prepare_from_csv(csv_path: str):
    """
    Load image paths and [lat, lon] targets from a CSV file.

    The path, latitude and longitude columns are located with ``_find_col``.
    Relative image paths are resolved against the directory that contains the
    CSV. Returns ``(paths, y)`` where ``paths`` is a list of strings and ``y``
    is a float32 array with one ``[lat, lon]`` row (raw degrees) per image.

    Raises:
        KeyError: if any of the three required columns cannot be found.
    """
    frame = pd.read_csv(csv_path)
    header = frame.columns

    path_col = _find_col(header, _PATH_COLS)
    lat_col = _find_col(header, _LAT_COLS)
    lon_col = _find_col(header, _LON_COLS)

    if any(col is None for col in (path_col, lat_col, lon_col)):
        raise KeyError(
            f"CSV missing required columns. Found: {list(frame.columns)}. "
            f"Need path in {_PATH_COLS}, lat in {_LAT_COLS}, lon in {_LON_COLS}."
        )

    csv_dir = os.path.dirname(os.path.abspath(csv_path))

    def _resolve(raw_path):
        # Trim surrounding whitespace, then anchor relative paths at the CSV's folder.
        cleaned = raw_path.strip()
        if os.path.isabs(cleaned):
            return cleaned
        return os.path.join(csv_dir, cleaned)

    paths = [_resolve(raw) for raw in frame[path_col].astype(str).tolist()]
    coords = frame[[lat_col, lon_col]].astype(np.float32).values  # raw degrees
    return paths, coords


def _parse_hf_spec(spec: str):
    """
    Split a HuggingFace dataset spec into ``(dataset_name, split)``.

    Supported:
      hf://dataset_name
      hf://dataset_name:split
      hf:dataset_name[:split]   (short prefix form)

    Surrounding whitespace is ignored. The split defaults to ``"train"`` both
    when no ``:split`` suffix is present and when the suffix is empty
    (e.g. ``"hf://name:"``), so an empty split string is never returned.

    Raises:
        ValueError: if the dataset name is empty.
    """
    s = spec.strip()
    # Check the longer prefix first so "hf://x" is not left with a leading "//".
    for prefix in ("hf://", "hf:"):
        if s.startswith(prefix):
            s = s[len(prefix):]
            break

    name, _, split = s.partition(":")
    name = name.strip()
    split = split.strip() or "train"

    if not name:
        raise ValueError(f"Bad HuggingFace spec: {spec}")
    return name, split


def _prepare_from_hf(spec: str):
    """
    Load images and [lat, lon] targets from a HuggingFace dataset spec.

    ``spec`` is parsed by ``_parse_hf_spec``. The dataset must expose the
    columns ``image``, ``latitude`` and ``longitude``. Returns ``(X, y)`` where
    ``X`` is the dataset's ``image`` column and ``y`` is a float32 array with
    one ``[lat, lon]`` row per example.

    Raises:
        ImportError: if the optional ``datasets`` package is not installed.
        KeyError: if an expected column is missing.
    """
    # Imported lazily: the backend environment may not ship `datasets`.
    try:
        from datasets import load_dataset
    except ImportError as e:
        raise ImportError(
            "HF loading requested but `datasets` is not installed. "
            "Install locally via `pip install datasets`. "
            "For submission, pass a real CSV path."
        ) from e

    ds_name, split = _parse_hf_spec(spec)
    ds = load_dataset(ds_name, split=split)

    required_columns = ("image", "latitude", "longitude")
    missing = [col for col in required_columns if col not in ds.column_names]
    if missing:
        raise KeyError(f"Unexpected columns: {ds.column_names}. Expected image/latitude/longitude.")

    images = ds["image"]  # typically PIL Images via HF datasets Image feature
    lat = np.asarray(ds["latitude"], dtype=np.float32)
    lon = np.asarray(ds["longitude"], dtype=np.float32)
    targets = np.stack([lat, lon], axis=1)
    return images, targets


def prepare_data(csv_path: str):
    """
    Entry point for loading data.

    Submission contract: csv_path is a CSV on disk.
    Local convenience: supports HF spec hf://...:split

    An existing ``*.csv`` file always takes the CSV route, even if its name
    happens to start with ``hf:``. Otherwise an ``hf://`` / ``hf:`` prefix
    selects the HuggingFace loader, and anything else is treated as a CSV path.
    """
    existing_csv = os.path.isfile(csv_path) and csv_path.lower().endswith(".csv")
    if not existing_csv and csv_path.startswith(("hf://", "hf:")):
        return _prepare_from_hf(csv_path)

    # Real CSV on disk, or the default assumption for any other string.
    return _prepare_from_csv(csv_path)


## Model

In [None]:
# model.py
import cv2
import math
from typing import Tuple, Any
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.models import MobileNet_V3_Small_Weights

# -----------------------------
# HARD-CODED BOUNDS
# -----------------------------
# Latitude/longitude bounding box in degrees. Model uses it to lay out the
# grid cells and to clamp decoded predictions.
LAT_MIN = 39.95009994506836
LAT_MAX = 39.9530029296875
LON_MIN = -75.1928939819336
LON_MAX = -75.18990325927734

# Recommended grid for this campus extent:
# ~480 classes; ~13.5m x ~12.8m per cell (roughly)
# GRID_H counts cells along latitude, GRID_W along longitude
# (Model.num_cells = GRID_H * GRID_W).
GRID_H = 24
GRID_W = 20

# Logits are divided by this temperature before the softmax in
# Model.decode_latlon.
SOFTMAX_TEMP = 0.85

# Preprocessing geometry: the short image side is resized to RESIZE_SHORT,
# then an INPUT_SIZE x INPUT_SIZE centre crop is taken.
INPUT_SIZE = 224
RESIZE_SHORT = 256

# Per-channel (R, G, B) normalisation statistics applied after scaling
# pixels to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# The same statistics as (3, 1, 1) tensors so they broadcast over a CHW image.
_MEAN_T = torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1)
_STD_T = torch.tensor(IMAGENET_STD, dtype=torch.float32).view(3, 1, 1)


class GeM(nn.Module):
    """
    Generalized-mean (GeM) pooling over the two trailing spatial dimensions.

    Computes ``mean(x ** p) ** (1 / p)`` per channel and returns a tensor with
    spatial size 1x1. ``p`` is stored as a learnable parameter when
    ``learn_p`` is true and as a fixed buffer otherwise.
    """

    def __init__(self, p: float = 3.0, eps: float = 1e-6, learn_p: bool = True):
        super().__init__()
        exponent = torch.tensor([p], dtype=torch.float32)
        if learn_p:
            self.p = nn.Parameter(exponent)
        else:
            self.register_buffer("p", exponent)
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Keep the exponent inside a numerically reasonable range.
        exponent = self.p.clamp(min=1.0, max=6.0)

        # A fractional power of a negative number is NaN, so rectify the
        # activations and add eps to make every value strictly positive.
        strictly_positive = F.relu(x) + self.eps

        # Generalized mean: power, spatial average, inverse power.
        averaged = F.adaptive_avg_pool2d(strictly_positive.pow(exponent), output_size=1)
        return averaged.pow(1.0 / exponent)


def _build_mobilenet_v3_small():
    """Build torchvision's MobileNetV3-Small using its ``DEFAULT`` weight set."""
    default_weights = MobileNet_V3_Small_Weights.DEFAULT
    return torchvision.models.mobilenet_v3_small(weights=default_weights)


def _deg2rad(x: torch.Tensor) -> torch.Tensor:
    """Convert a tensor of angles from degrees to radians, element-wise."""
    radians_per_degree = math.pi / 180.0
    return x * radians_per_degree


def haversine_m(pred_deg: torch.Tensor, tgt_deg: torch.Tensor) -> torch.Tensor:
    """
    Great-circle (haversine) distance in metres between paired coordinates.

    Args:
        pred_deg: tensor whose column 0 is latitude and column 1 is longitude,
            both in degrees.
        tgt_deg: tensor with the same layout as ``pred_deg``.

    Returns:
        One distance per row, computed on a sphere of radius 6,371,000 m.

    The square roots are guarded so that the function stays differentiable
    when a prediction coincides with its target: d/da sqrt(a) is infinite at
    a == 0, and multiplying it by the zero inner derivative would otherwise
    produce NaN gradients. Forward values are unaffected by the guard, and
    NaN inputs still propagate to NaN outputs.
    """
    R = 6_371_000.0
    deg_to_rad = math.pi / 180.0

    def _safe_sqrt(v: torch.Tensor) -> torch.Tensor:
        # sqrt whose gradient is 0 (rather than inf) exactly at v == 0.
        is_zero = v == 0
        root = torch.sqrt(torch.where(is_zero, torch.ones_like(v), v))
        return torch.where(is_zero, torch.zeros_like(v), root)

    lat1 = pred_deg[:, 0] * deg_to_rad
    lon1 = pred_deg[:, 1] * deg_to_rad
    lat2 = tgt_deg[:, 0] * deg_to_rad
    lon2 = tgt_deg[:, 1] * deg_to_rad

    dlat = lat2 - lat1
    dlon = lon2 - lon1

    a = torch.sin(dlat * 0.5) ** 2 + torch.cos(lat1) * torch.cos(lat2) * (torch.sin(dlon * 0.5) ** 2)
    # Rounding can push `a` marginally outside [0, 1].
    a = a.clamp(0.0, 1.0)

    c = 2.0 * torch.atan2(_safe_sqrt(a), _safe_sqrt((1.0 - a).clamp(0.0, 1.0)))
    return R * c


class Model(nn.Module):
    """
    Image-to-GPS regressor: backbone features, GeM pooling, then a grid-cell
    classifier combined with a per-cell offset regressor.

    Submission entry-point requirements:
      - instantiable without args
      - backend calls predict(batch) if available else forward(batch)
      - outputs [lat, lon] in raw degrees (not normalized)

    The bounding box LAT_MIN..LAT_MAX x LON_MIN..LON_MAX is divided into
    GRID_H x GRID_W cells. ``cell_head`` scores every cell, ``offset_head``
    predicts a (lat, lon) offset inside every cell, and ``decode_latlon``
    blends the per-cell candidate positions with the softmax probabilities.
    """

    def __init__(self):
        super().__init__()
        self.lat_min = float(LAT_MIN)
        self.lat_max = float(LAT_MAX)
        self.lon_min = float(LON_MIN)
        self.lon_max = float(LON_MAX)

        self.grid_h = int(GRID_H)
        self.grid_w = int(GRID_W)
        self.num_cells = self.grid_h * self.grid_w

        # Size of one cell in degrees along each axis.
        dlat = (self.lat_max - self.lat_min) / float(self.grid_h)
        dlon = (self.lon_max - self.lon_min) / float(self.grid_w)

        lat_edges = torch.linspace(self.lat_min, self.lat_max, self.grid_h + 1, dtype=torch.float32)
        lon_edges = torch.linspace(self.lon_min, self.lon_max, self.grid_w + 1, dtype=torch.float32)
        lat_centers = (lat_edges[:-1] + lat_edges[1:]) * 0.5
        lon_centers = (lon_edges[:-1] + lon_edges[1:]) * 0.5

        # Row-major layout: cell_id = lat_index * grid_w + lon_index.
        yy, xx = torch.meshgrid(lat_centers, lon_centers, indexing="ij")
        centers = torch.stack([yy.reshape(-1), xx.reshape(-1)], dim=-1)  # [K,2]

        self.register_buffer("cell_centers_deg", centers)
        self.register_buffer("cell_size_deg", torch.tensor([dlat, dlon], dtype=torch.float32))

        base = _build_mobilenet_v3_small()
        self.backbone_features = base.features
        self.gem = GeM(p=3.0, learn_p=True)

        # Probe the pooled feature width with a dummy image. The probe runs in
        # eval mode so that normalisation layers do not fold the all-zero
        # dummy batch into their pretrained running statistics; the previous
        # mode is restored afterwards.
        was_training = self.backbone_features.training
        self.backbone_features.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, INPUT_SIZE, INPUT_SIZE)
            feat = self.backbone_features(dummy)
            pooled = self.gem(feat).flatten(1)
            feat_dim = pooled.shape[1]
        self.backbone_features.train(was_training)

        self.dropout = nn.Dropout(p=0.15)
        self.cell_head = nn.Linear(feat_dim, self.num_cells)
        self.offset_head = nn.Linear(feat_dim, self.num_cells * 2)
        self.softmax_temp = float(SOFTMAX_TEMP)

    # -------- robust image decoding (paths, PIL, HF dicts with bytes/path) --------
    def _normalize_np_image(self, arr: np.ndarray) -> np.ndarray:
        """
        Coerce an image array to HxWx3 uint8.

        Grayscale is replicated to three channels and a fourth channel is
        dropped. Non-uint8 data whose maximum is <= 1.5 is treated as [0, 1]
        and scaled by 255; other non-uint8 data is clipped to [0, 255].
        Any other layout yields a black INPUT_SIZE x INPUT_SIZE image.
        """
        if arr.ndim == 2:
            arr = np.stack([arr, arr, arr], axis=-1)
        if arr.ndim == 3 and arr.shape[2] == 4:
            arr = arr[:, :, :3]
        if arr.ndim != 3 or arr.shape[2] != 3:
            return np.zeros((INPUT_SIZE, INPUT_SIZE, 3), dtype=np.uint8)

        if arr.dtype != np.uint8:
            amax = float(np.max(arr)) if arr.size else 1.0
            if amax <= 1.5:
                arr = np.clip(arr * 255.0, 0, 255).astype(np.uint8)
            else:
                arr = np.clip(arr, 0, 255).astype(np.uint8)
        return arr

    def _to_rgb_uint8(self, item: Any) -> np.ndarray:
        """
        Decode one input item into an RGB uint8 array.

        Accepts a file path string, a dict carrying ``path``, ``bytes`` or
        ``array`` (checked in that order), or anything ``np.asarray`` can
        convert (e.g. a PIL image). Decoding failures are deliberately
        best-effort: they return a black INPUT_SIZE x INPUT_SIZE image
        instead of raising.
        """
        if isinstance(item, str):
            img = cv2.imread(item, cv2.IMREAD_COLOR)
            if img is None:
                return np.zeros((INPUT_SIZE, INPUT_SIZE, 3), dtype=np.uint8)
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if isinstance(item, dict):
            if "path" in item and item["path"]:
                img = cv2.imread(str(item["path"]), cv2.IMREAD_COLOR)
                if img is None:
                    return np.zeros((INPUT_SIZE, INPUT_SIZE, 3), dtype=np.uint8)
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if "bytes" in item and item["bytes"] is not None:
                buf = np.frombuffer(item["bytes"], dtype=np.uint8)
                img = cv2.imdecode(buf, cv2.IMREAD_COLOR)
                if img is None:
                    return np.zeros((INPUT_SIZE, INPUT_SIZE, 3), dtype=np.uint8)
                return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            if "array" in item and item["array"] is not None:
                return self._normalize_np_image(np.asarray(item["array"]))

        try:
            arr = np.asarray(item)  # PIL image -> np array
            return self._normalize_np_image(arr)
        except Exception:
            # Best-effort fallback: an undecodable item becomes a black image.
            return np.zeros((INPUT_SIZE, INPUT_SIZE, 3), dtype=np.uint8)

    def _preprocess_rgb_uint8(self, img: np.ndarray) -> torch.Tensor:
        """
        Resize, centre-crop and normalise one RGB uint8 image.

        The short side is resized to RESIZE_SHORT, an INPUT_SIZE square is
        cropped from the centre (zero-padded at the bottom/right if the image
        is still too small), pixels are scaled to [0, 1] and normalised with
        the ImageNet mean/std. Returns a float32 CHW tensor.
        """
        h, w = img.shape[:2]
        scale = RESIZE_SHORT / float(min(h, w) + 1e-6)
        new_w = int(round(w * scale))
        new_h = int(round(h * scale))
        img = cv2.resize(
            img, (new_w, new_h),
            interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LINEAR
        )

        h2, w2 = img.shape[:2]
        top = max(0, (h2 - INPUT_SIZE) // 2)
        left = max(0, (w2 - INPUT_SIZE) // 2)
        img = img[top:top + INPUT_SIZE, left:left + INPUT_SIZE, :]

        if img.shape[0] != INPUT_SIZE or img.shape[1] != INPUT_SIZE:
            pad_h = INPUT_SIZE - img.shape[0]
            pad_w = INPUT_SIZE - img.shape[1]
            img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, borderType=cv2.BORDER_CONSTANT, value=0)

        x = torch.from_numpy(img).to(torch.float32) / 255.0
        x = x.permute(2, 0, 1).contiguous()
        x = (x - _MEAN_T) / _STD_T
        return x

    def _load_and_preprocess_one(self, item: Any) -> torch.Tensor:
        """Decode a single item and return its preprocessed CHW tensor."""
        img = self._to_rgb_uint8(item)
        return self._preprocess_rgb_uint8(img)

    def _batch_to_tensor(self, batch) -> torch.Tensor:
        """
        Turn any supported batch container into an NCHW tensor.

        - A tensor is returned unchanged (taken as already preprocessed).
        - A floating-point NHWC NumPy array is taken as already preprocessed
          and only permuted to NCHW.
        - An integer NHWC NumPy array holds raw images, so each image goes
          through the standard resize/crop/normalise pipeline.
        - Lists/tuples (and object/str arrays) are processed item by item;
          lists of tensors are stacked as they are.
        - Anything else is treated as a single image.
        """
        if torch.is_tensor(batch):
            return batch

        if isinstance(batch, np.ndarray):
            if batch.ndim == 4 and batch.shape[-1] == 3:
                if np.issubdtype(batch.dtype, np.floating):
                    return torch.from_numpy(batch).permute(0, 3, 1, 2).contiguous()
                if batch.shape[0] == 0:
                    return torch.empty(0, 3, INPUT_SIZE, INPUT_SIZE)
                # Raw pixel data: passing it on as-is would reach the network
                # unnormalised and with an integer dtype.
                return torch.stack([self._load_and_preprocess_one(img) for img in batch], dim=0)
            if batch.dtype == object or batch.dtype.type is np.str_:
                batch = batch.tolist()

        if isinstance(batch, (list, tuple)):
            if len(batch) == 0:
                return torch.empty(0, 3, INPUT_SIZE, INPUT_SIZE)
            if torch.is_tensor(batch[0]):
                return torch.stack(list(batch), dim=0)
            xs = [self._load_and_preprocess_one(b) for b in batch]
            return torch.stack(xs, dim=0)

        return torch.stack([self._load_and_preprocess_one(batch)], dim=0)

    # -------- model core --------
    def forward_raw(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run the network on a preprocessed NCHW tensor.

        Returns ``(logits, offsets)``: cell logits of width ``num_cells`` and
        raw (pre-tanh) offsets of width ``num_cells * 2``.
        """
        feat = self.backbone_features(x)
        feat = self.gem(feat).flatten(1)
        feat = self.dropout(feat)
        logits = self.cell_head(feat)
        offsets = self.offset_head(feat)
        return logits, offsets

    def decode_latlon(self, logits: torch.Tensor, offsets_raw: torch.Tensor) -> torch.Tensor:
        """
        Convert head outputs to [lat, lon] in degrees.

        Each cell contributes a candidate position (cell centre plus a
        tanh-squashed offset limited to half a cell). Candidates are averaged
        with the temperature-scaled softmax probabilities and the result is
        clamped to the bounding box.
        """
        B, K = logits.shape
        p = F.softmax(logits / self.softmax_temp, dim=-1)

        off = torch.tanh(offsets_raw).view(B, K, 2) * 0.5  # cell units in [-0.5,0.5]
        cand = self.cell_centers_deg.view(1, K, 2) + off * self.cell_size_deg.view(1, 1, 2)
        pred = torch.sum(p.unsqueeze(-1) * cand, dim=1)

        pred_lat = pred[:, 0].clamp(self.lat_min, self.lat_max)
        pred_lon = pred[:, 1].clamp(self.lon_min, self.lon_max)
        return torch.stack([pred_lat, pred_lon], dim=-1)

    def forward(self, batch):
        """
        Predict [lat, lon] in degrees for a batch in any supported format.

        Floating-point inputs are moved to the parameters' device and cast to
        the parameters' dtype, so e.g. a float64 NumPy batch does not fail
        with a dtype mismatch in the first layer. Non-floating tensors are
        only moved to the device.
        """
        x = self._batch_to_tensor(batch)
        reference = next(self.parameters())
        if x.is_floating_point():
            x = x.to(device=reference.device, dtype=reference.dtype, non_blocking=True)
        else:
            x = x.to(reference.device, non_blocking=True)
        logits, offsets = self.forward_raw(x)
        return self.decode_latlon(logits, offsets)

    def predict(self, batch) -> np.ndarray:
        """Switch to eval mode and return predictions as a float32 NumPy array."""
        self.eval()
        with torch.inference_mode():
            pred = self.forward(batch)
        return pred.detach().cpu().numpy().astype(np.float32)

    # -------- training helpers --------
    def encode_cell_and_offset(self, tgt_latlon_deg: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Map target coordinates to (cell_id, offset) training labels.

        ``cell_id`` is the row-major index of the cell containing the target;
        targets outside the bounding box are assigned to the nearest border
        cell. ``offset`` is the displacement from that cell's centre in cell
        units, clamped to [-0.5, 0.5].
        """
        lat = tgt_latlon_deg[:, 0]
        lon = tgt_latlon_deg[:, 1]
        dlat = self.cell_size_deg[0]
        dlon = self.cell_size_deg[1]

        i = torch.floor((lat - self.lat_min) / dlat).to(torch.long).clamp(0, self.grid_h - 1)
        j = torch.floor((lon - self.lon_min) / dlon).to(torch.long).clamp(0, self.grid_w - 1)
        cell_id = i * self.grid_w + j

        centers = self.cell_centers_deg[cell_id]
        off_lat = (lat - centers[:, 0]) / (dlat + 1e-12)
        off_lon = (lon - centers[:, 1]) / (dlon + 1e-12)
        off = torch.stack([off_lat, off_lon], dim=-1).clamp(-0.5, 0.5)
        return cell_id, off

    def compute_loss(
        self,
        logits: torch.Tensor,
        offsets_raw: torch.Tensor,
        tgt_latlon_deg: torch.Tensor,
        w_ce: float = 1.0,
        w_off: float = 2.0,
        w_geo: float = 1.0,
        label_smoothing: float = 0.05,
        huber_beta_m: float = 10.0,
        add_mixture_geo: bool = True,
        w_mix_geo: float = 0.3,
    ):
        """
        Weighted training loss.

        Terms:
          - ce:      cross-entropy on the target cell (with label smoothing).
          - off:     smooth-L1 between the predicted and target offset of the
                     target cell, in cell units.
          - geo:     Huber-style penalty on the haversine distance (metres) of
                     the target-cell prediction; skipped when ``w_geo`` is 0.
          - mix_geo: the same penalty on the fully decoded (softmax-blended)
                     prediction; used only when ``add_mixture_geo`` is true
                     and ``w_mix_geo`` is non-zero.

        Returns ``(loss, stats)``. ``stats`` holds Python floats for every
        term plus ``dist_m_mean_truecell``, the mean distance in metres of the
        target-cell prediction.
        """
        B, K = logits.shape
        device = logits.device

        cell_id, off_tgt = self.encode_cell_and_offset(tgt_latlon_deg)

        ce = F.cross_entropy(logits, cell_id, label_smoothing=label_smoothing)

        # Only the offsets predicted for the ground-truth cell are supervised.
        off_all = torch.tanh(offsets_raw).view(B, K, 2) * 0.5
        idx = torch.arange(B, device=device)
        off_pred = off_all[idx, cell_id]
        off = F.smooth_l1_loss(off_pred, off_tgt, beta=0.05)

        # Always used by off-loss, so keep this in graph
        pred_true = self.cell_centers_deg[cell_id] + off_pred * self.cell_size_deg.view(1, 2)

        loss = (w_ce * ce) + (w_off * off)

        # ---- GEO term: only compute WITH grad if weight != 0 ----
        geo = torch.zeros((), device=device)
        if w_geo != 0.0:
            dist_m = haversine_m(pred_true, tgt_latlon_deg)
            geo = F.smooth_l1_loss(dist_m, torch.zeros_like(dist_m), beta=huber_beta_m)
            loss = loss + (w_geo * geo)
            dist_for_stats = dist_m
        else:
            with torch.no_grad():
                dist_for_stats = haversine_m(pred_true, tgt_latlon_deg)

        # ---- MIX term: only compute WITH grad if weight != 0 AND enabled ----
        mix_geo = torch.zeros((), device=device)
        if add_mixture_geo and (w_mix_geo != 0.0):
            pred_mix = self.decode_latlon(logits, offsets_raw)
            dist_mix = haversine_m(pred_mix, tgt_latlon_deg)
            mix_geo = F.smooth_l1_loss(dist_mix, torch.zeros_like(dist_mix), beta=huber_beta_m)
            loss = loss + (w_mix_geo * mix_geo)

        stats = {
            "loss": float(loss.detach().cpu()),
            "ce": float(ce.detach().cpu()),
            "off": float(off.detach().cpu()),
            "geo": float(geo.detach().cpu()),
            "mix_geo": float(mix_geo.detach().cpu()),
            "dist_m_mean_truecell": float(dist_for_stats.mean().detach().cpu()),
        }
        return loss, stats

def get_model():
    """Factory for the submission entry point: build a ``Model`` with no arguments."""
    instance = Model()
    return instance

sanity test

In [None]:
# Sanity check: the haversine distance from a point to itself must be zero.
p = torch.tensor([[39.95, -75.19]], dtype=torch.float32)
print(haversine_m(p, p))  # should be tensor([0.])

tensor([0.])


# Training section

In [14]:
# Training Cell for HuggingFace Dataset - FIXED VERSION
# Dataset: tianyi-in-the-bush/penncampus_image2gps (see DATASET_NAME below)
#
# FIXES:
# 1. Validation loop now unpacks 3 values (batch_x, batch_y, batch_idx) to match Dataset
# 2. Added numerical-stability safeguards to prevent NaN issues
# 3. Added non-finite checks (targets, outputs, loss, gradients) for debugging

import copy
import torch
import random
import numpy as np
import torch.nn as nn
from tqdm.auto import tqdm
import torch.optim as optim
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

# ============== Configuration ==============
DATASET_NAME = "tianyi-in-the-bush/penncampus_image2gps"  # HuggingFace hub id passed to load_dataset
BATCH_SIZE = 64
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4  # Reduced from 3e-4 for stability
WEIGHT_DECAY = 1e-4
# NOTE(review): VAL_SPLIT is not referenced anywhere in the visible code; the
# dataset's own "test" split is used for validation instead.
VAL_SPLIT = 0.1
SEED = 42
NUM_WORKERS = 0
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): USE_AMP is not referenced anywhere in the visible code (no
# autocast / GradScaler), so training appears to run in full precision.
USE_AMP = True # Use Mixed Precision

# Loss weights
# A weight of 0.0 makes Model.compute_loss skip the corresponding haversine
# term, so with these values only cross-entropy and offset loss are trained.
W_CE = 1.0
W_OFF = 2.0
W_GEO = 0.0
W_MIX_GEO = 0.0
LABEL_SMOOTHING = 0.05

print(f"Using device: {DEVICE}")

# ============== Set Seeds ==============
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs (plus every CUDA device when available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed every RNG before the model is created and the loaders start shuffling.
set_seed(SEED)

# ============== Dataset Class ==============
class GeoDataset(Dataset):
    """
    Torch dataset wrapping a dataset of ``image`` / ``latitude`` /
    ``longitude`` records.

    Images are preprocessed with the model's own
    ``_load_and_preprocess_one`` so that training inputs match what the model
    produces for itself at inference time.
    """

    def __init__(self, hf_dataset, model):
        self.dataset = hf_dataset
        self.model = model

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image_tensor, [lat, lon] float32 tensor, idx)``."""
        record = self.dataset[idx]

        # Image -> preprocessed tensor, via the model's pipeline.
        pixels = self.model._load_and_preprocess_one(record["image"])

        # Coordinates -> float32 target.
        coords = [float(record[key]) for key in ("latitude", "longitude")]
        target = torch.tensor(coords, dtype=torch.float32)

        # The index is returned as well so callers can report bad samples.
        return pixels, target, idx


# ============== Load Dataset ==============
print(f"Loading dataset: {DATASET_NAME}")
train_hf = load_dataset(DATASET_NAME, split="train")
# The dataset's "test" split is what this script uses as its validation set.
val_hf = load_dataset(DATASET_NAME, split="test")
print(f"Train samples: {len(train_hf)}, Val samples: {len(val_hf)}")

# ============== Initialize Model ==============
model = get_model()
model = model.to(DEVICE)

# Create datasets
# The model is handed to the datasets only for its image-preprocessing method.
train_dataset = GeoDataset(train_hf, model)
val_dataset = GeoDataset(val_hf, model)

# Create dataloaders
# Training batches are shuffled and the last incomplete batch is dropped;
# validation keeps dataset order and every sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    drop_last=True
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

# ============== Optimizer & Scheduler ==============
# A single AdamW parameter group: every model parameter shares the same
# learning rate and weight decay.
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Add warmup scheduler for more stable training
def get_lr_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """
    Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The learning-rate multiplier rises linearly from 0 to 1 over
    ``num_warmup_steps`` scheduler steps, then follows half a cosine period
    down to 0 at ``num_training_steps``.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        in_warmup = current_step < num_warmup_steps
        if in_warmup:
            return float(current_step) / warmup_span
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine_factor = 0.5 * (1.0 + np.cos(np.pi * progress))
        return max(0.0, cosine_factor)

    return optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# The scheduler is stepped once per optimizer step, so both horizons are
# expressed in batches rather than epochs.
num_training_steps = NUM_EPOCHS * len(train_loader)
num_warmup_steps = len(train_loader)  # 1 epoch warmup
scheduler = get_lr_scheduler_with_warmup(optimizer, num_warmup_steps, num_training_steps)


# ============== Training Functions ==============
def train_one_epoch(model, loader, optimizer, scheduler, device):
    """
    Train ``model`` for one pass over ``loader``.

    A batch is skipped (no optimizer or scheduler step) when its targets, the
    model outputs or the loss are non-finite. Non-finite gradients abort
    training with a ``RuntimeError``. The scheduler is stepped once per
    successful optimizer step.

    Returns ``(mean_loss, mean_dist_m)`` averaged over the batches that were
    actually trained on, or ``(inf, inf)`` if every batch was skipped.
    """
    model.train()
    loss_sum = 0.0
    dist_sum = 0.0
    steps_done = 0

    progress = tqdm(loader, desc="Training")

    for inputs, targets, sample_ids in progress:  # dataset yields (x, y, idx)
        # Guard 1: targets must be finite.
        if not torch.isfinite(targets).all():
            bad_rows = ~torch.isfinite(targets).all(dim=1)
            print("Non-finite targets at:", sample_ids[bad_rows])
            print(targets[bad_rows])
            continue  # skip the batch rather than stop the epoch

        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()

        logits, offsets = model.forward_raw(inputs)

        # Guard 2: model outputs must be finite.
        outputs_finite = torch.isfinite(logits).all() and torch.isfinite(offsets).all()
        if not outputs_finite:
            print(f"Non-finite model outputs at batch {steps_done}")
            print(f"  Logits has NaN: {torch.isnan(logits).any()}, Inf: {torch.isinf(logits).any()}")
            print(f"  Offsets has NaN: {torch.isnan(offsets).any()}, Inf: {torch.isinf(offsets).any()}")
            continue

        loss, stats = model.compute_loss(
            logits, offsets, targets,
            w_ce=W_CE, w_off=W_OFF, w_geo=W_GEO,
            label_smoothing=LABEL_SMOOTHING,
            add_mixture_geo=(W_MIX_GEO != 0.0),
            w_mix_geo=W_MIX_GEO
        )

        # Guard 3: the loss must be finite.
        if not torch.isfinite(loss):
            print(f"Non-finite loss at batch {steps_done}")
            print(stats)
            continue

        loss.backward()

        # Guard 4: report every parameter with non-finite gradients, then stop.
        found_bad_grad = False
        for name, param in model.named_parameters():
            grad = param.grad
            if grad is None or torch.isfinite(grad).all():
                continue
            found_bad_grad = True
            nonfinite = ~torch.isfinite(grad)
            print("Non-finite grad:", name,
                  "count:", int(nonfinite.sum().item()), "/", grad.numel(),
                  "example:", grad[nonfinite][:5])
        if found_bad_grad:
            raise RuntimeError("Stopping: non-finite gradients")

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        scheduler.step()

        loss_sum += stats["loss"]
        dist_sum += stats["dist_m_mean_truecell"]
        steps_done += 1

        progress.set_postfix({
            "loss": f"{stats['loss']:.4f}",
            "dist_m": f"{stats['dist_m_mean_truecell']:.2f}",
            "lr": f"{scheduler.get_last_lr()[0]:.2e}"
        })

    if steps_done == 0:
        return float('inf'), float('inf')
    return loss_sum / steps_done, dist_sum / steps_done


@torch.no_grad()
def validate(model, loader, device):
    """
    Evaluate ``model`` on ``loader`` without tracking gradients.

    Batches whose model outputs are non-finite are excluded from every
    metric; the number of excluded batches is printed so that the omission is
    visible instead of silently shrinking the evaluation set.

    Returns ``(mean_batch_loss, mean_dist_m, median_dist_m)``. The distances
    are per-sample haversine errors of the fully decoded prediction. Returns
    ``(inf, inf, inf)`` if no batch could be evaluated.
    """
    model.eval()
    total_loss = 0.0
    all_dists = []
    num_batches = 0
    num_skipped = 0

    pbar = tqdm(loader, desc="Validation")

    # The dataset yields (x, y, idx); the index is not needed here.
    for batch_x, batch_y, _ in pbar:
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)

        logits, offsets = model.forward_raw(batch_x)

        # Skip if non-finite outputs
        if not torch.isfinite(logits).all() or not torch.isfinite(offsets).all():
            num_skipped += 1
            continue

        loss, stats = model.compute_loss(
            logits, offsets, batch_y,
            w_ce=W_CE, w_off=W_OFF, w_geo=W_GEO,
            label_smoothing=LABEL_SMOOTHING,
            add_mixture_geo=True, w_mix_geo=W_MIX_GEO
        )

        # Compute actual prediction distances
        pred = model.decode_latlon(logits, offsets)
        dist = haversine_m(pred, batch_y)
        all_dists.append(dist.cpu())

        total_loss += stats["loss"]
        num_batches += 1

    if num_skipped:
        print(f"Validation skipped {num_skipped} batch(es) with non-finite model outputs")

    if num_batches == 0:
        return float('inf'), float('inf'), float('inf')

    all_dists = torch.cat(all_dists)
    median_dist = float(all_dists.median())
    mean_dist = float(all_dists.mean())

    return total_loss / num_batches, mean_dist, median_dist


# ============== Training Loop ==============
# Tracks the lowest validation mean distance seen so far, together with a
# copy of the weights that produced it.
best_val_dist = float("inf")
best_model_state = None

print("\n" + "="*50)
print("Starting Training")
print("="*50)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 30)

    # Train
    train_loss, train_dist = train_one_epoch(model, train_loader, optimizer, scheduler, DEVICE)

    # Validate
    val_loss, val_mean_dist, val_median_dist = validate(model, val_loader, DEVICE)

    current_lr = scheduler.get_last_lr()[0]

    # "Train Dist" is the error of the target-cell prediction, whereas the
    # validation distances come from the fully decoded prediction, so the two
    # figures are not directly comparable.
    print(f"Train Loss: {train_loss:.4f} | Train Dist: {train_dist:.2f}m")
    print(f"Val Loss: {val_loss:.4f} | Val Mean Dist: {val_mean_dist:.2f}m | Val Median Dist: {val_median_dist:.2f}m")
    print(f"Learning Rate: {current_lr:.6f}")

    # Save best model
    # deepcopy is required: state_dict() returns references to the live
    # tensors, which later optimizer steps would overwrite.
    if val_mean_dist < best_val_dist:
        best_val_dist = val_mean_dist
        best_model_state = copy.deepcopy(model.state_dict())
        print(f"*** New best model! Val Mean Dist: {best_val_dist:.2f}m ***")

# ============== Load Best Model & Save ==============
print("\n" + "="*50)
print(f"Training Complete! Best Val Mean Distance: {best_val_dist:.2f}m")
print("="*50)

# Load best weights
# best_model_state stays None if no epoch improved on the initial `inf`; in
# that case the final weights are saved as they are.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# Save model
# The checkpoint stores the weights together with the grid shape and bounding
# box used by this run.
SAVE_PATH = "mobilenet_v3_grid_best.pth"
torch.save({
    "model_state_dict": model.state_dict(),
    "best_val_dist": best_val_dist,
    "config": {
        "grid_h": GRID_H,
        "grid_w": GRID_W,
        "lat_min": LAT_MIN,
        "lat_max": LAT_MAX,
        "lon_min": LON_MIN,
        "lon_max": LON_MAX,
    }
}, SAVE_PATH)
print(f"Model saved to {SAVE_PATH}")

Using device: cuda
Loading dataset: tianyi-in-the-bush/penncampus_image2gps


README.md:   0%|          | 0.00/497 [00:00<?, ?B/s]

data/train-00000-of-00002.parquet:   0%|          | 0.00/255M [00:00<?, ?B/s]

data/train-00001-of-00002.parquet:   0%|          | 0.00/250M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/128M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1013 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/254 [00:00<?, ? examples/s]

Train samples: 1013, Val samples: 254

Starting Training

Epoch 1/5
------------------------------


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Train Loss: 6.9012 | Train Dist: 6.68m
Val Loss: 5.9378 | Val Mean Dist: 109.66m | Val Median Dist: 116.71m
Learning Rate: 0.000100
*** New best model! Val Mean Dist: 109.66m ***

Epoch 2/5
------------------------------


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Train Loss: 5.0143 | Train Dist: 4.31m
Val Loss: 4.3588 | Val Mean Dist: 99.53m | Val Median Dist: 102.09m
Learning Rate: 0.000085
*** New best model! Val Mean Dist: 99.53m ***

Epoch 3/5
------------------------------


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Train Loss: 3.8982 | Train Dist: 3.31m
Val Loss: 3.8720 | Val Mean Dist: 94.42m | Val Median Dist: 93.39m
Learning Rate: 0.000050
*** New best model! Val Mean Dist: 94.42m ***

Epoch 4/5
------------------------------


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Train Loss: 3.5260 | Train Dist: 2.97m
Val Loss: 3.7055 | Val Mean Dist: 92.52m | Val Median Dist: 92.17m
Learning Rate: 0.000015
*** New best model! Val Mean Dist: 92.52m ***

Epoch 5/5
------------------------------


Training:   0%|          | 0/15 [00:00<?, ?it/s]

Validation:   0%|          | 0/4 [00:00<?, ?it/s]

Train Loss: 3.4054 | Train Dist: 2.97m
Val Loss: 3.6438 | Val Mean Dist: 91.71m | Val Median Dist: 90.91m
Learning Rate: 0.000000
*** New best model! Val Mean Dist: 91.71m ***

Training Complete! Best Val Mean Distance: 91.71m
Model saved to mobilenet_v3_grid_best.pth


# Testing section