In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the Kaggle input directory.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [81]:
import os
import glob
import pandas as pd
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torchvision.transforms as transforms
import timm
from sklearn.cluster import KMeans
from sklearn.metrics import mean_squared_error, pairwise_distances
from tqdm import tqdm # Use standard tqdm
import gc
import time
import collections
import warnings

# --- Configuration ---
# Input locations (read-only Kaggle datasets).
BASE_DIR = "/kaggle/input/"
# NOTE(review): "filetered" looks like a typo but presumably matches the dataset slug - confirm before changing.
TRAIN_CSV_PATH = os.path.join(BASE_DIR, "filetered/filtered_labels_train.csv")
VAL_CSV_PATH = os.path.join(BASE_DIR, "iiith-campus/labels_val.csv")
TEST_IMG_DIR = "/kaggle/input/images-test/images_test" # Default test image path for Kaggle
TRAIN_IMG_DIR = os.path.join(BASE_DIR, "iiith-campus/images_train/images_train")
VAL_IMG_DIR = os.path.join(BASE_DIR, "iiith-campus/images_val/images_val")
# Checkpoint loaded into the ConvNeXt region classifier below.
REGION_CLASSIFIER_PATH = os.path.join(BASE_DIR, "region_classifier/pytorch/default/1/final_campus_region_model.pth")
# Extracted features are cached here as .npz so reruns can skip feature extraction.
CACHE_DIR = "/kaggle/working/"
TRAIN_CACHE_FILE = os.path.join(CACHE_DIR, "train_features_cache.npz")
VAL_CACHE_FILE = os.path.join(CACHE_DIR, "val_features_cache.npz")
# timm model identifiers for the two feature backbones.
CONVNEXT_MODEL_NAME = 'convnextv2_tiny.fcmae_ft_in1k'
VIT_MODEL_NAME = 'vit_large_patch16_224.augreg_in21k_ft_in1k'
# Number of region classes; CSV Region_ID values are shifted by -1 to index them from 0.
NUM_REGIONS = 15
IMG_SIZE = 224
BATCH_SIZE = 64
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# K-Means clusters per region; also the output width of each per-region MLP.
N_ANCHORS_PER_REGION = 10
# Per-region MLP hyperparameters.
MLP_HIDDEN_DIM = 256
MLP_EPOCHS_PER_REGION = 500
MLP_LEARNING_RATE = 5e-4
MLP_WEIGHT_DECAY = 1e-5
KMEANS_N_INIT = 5
# Original row indices of validation samples that are excluded from the validation MSE.
VAL_IGNORE_INDICES = {95, 145, 146, 158, 159, 160, 161}
SEED = 42

# --- Seeding and Print Config ---
# Echo the key run settings so they are visible in the notebook output.
print(f"Using device: {DEVICE}")
print(f"Number of Anchor Points per Region: {N_ANCHORS_PER_REGION}")
print(f"Ignoring validation indices for loss calculation: {sorted(list(VAL_IGNORE_INDICES))}")

# Seed torch and numpy; on CUDA also seed the current device and request deterministic cuDNN kernels.
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# --- Data Loading, Transforms, Collate ---
class CampusDataset(Dataset):
    """Image dataset backed by a label CSV.

    The CSV must contain a 'Region_ID' column; 'filename', 'latitude' and
    'longitude' are read per row in ``__getitem__``. Two helper columns are
    added to the frame: 'original_index' (the CSV row index) and
    'region_id_0idx' (``Region_ID - 1``).

    Args:
        csv_path: Path of the label CSV.
        img_dir: Directory that holds the image files named in 'filename'.
        transform: Optional callable applied to the loaded RGB PIL image.
        target_size: Stored on the instance; not otherwise used by this class.

    Raises:
        FileNotFoundError: If ``csv_path`` does not exist (re-raised after logging).
        KeyError: If the CSV has no 'Region_ID' column.
    """

    def __init__(self, csv_path, img_dir, transform=None, target_size=(IMG_SIZE, IMG_SIZE)):
        try:
            self.df = pd.read_csv(csv_path)
        except FileNotFoundError:
            print(f"FATAL ERROR: CSV file not found at {csv_path}")
            raise
        self.df['original_index'] = self.df.index
        self.img_dir = img_dir
        self.transform = transform
        self.target_size = target_size
        if 'Region_ID' not in self.df.columns:
            print(f"FATAL ERROR: 'Region_ID' column not found in {csv_path}")
            # Same exception type as before, but now carrying a message for the traceback.
            raise KeyError(f"'Region_ID' column not found in {csv_path}")
        self.df['region_id_0idx'] = self.df['Region_ID'] - 1

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, latitude, longitude, region_id_0idx, filename, original_index)``.

        Returns ``None`` when the image file is missing, cannot be decoded, or
        the transform raises; ``collate_fn`` filters such entries out.
        """
        if idx >= len(self.df): raise IndexError("Index out of bounds")
        row = self.df.iloc[idx]
        img_name = row['filename']
        original_idx = row['original_index']
        img_path = os.path.join(self.img_dir, img_name)
        try:
            if not os.path.exists(img_path): return None
            # Close the file handle deterministically; convert() returns an
            # independent in-memory image that stays valid after the close.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')
        except Exception:
            return None
        if self.transform:
            try:
                image = self.transform(image)
            except Exception:
                return None
        latitude = row['latitude']
        longitude = row['longitude']
        region_id_0idx = row['region_id_0idx']
        return image, latitude, longitude, region_id_0idx, img_name, original_idx

def collate_fn(batch):
    """Collate dataset samples into batch tensors, skipping failed (``None``) samples.

    Returns ``None`` when no sample in the batch survived. Otherwise returns
    ``(images, latitudes, longitudes, region_ids, filenames, original_indices)``
    where images are stacked along dim 0, latitudes/longitudes are float32
    tensors, region ids and original indices are int64 tensors, and filenames
    is a tuple.
    """
    kept = [sample for sample in batch if sample is not None]
    if len(kept) == 0:
        return None
    images, lats, lons, rids, fnames, oidxs = zip(*kept)
    image_batch = torch.stack(images, 0)
    lat_batch = torch.tensor(lats, dtype=torch.float32)
    lon_batch = torch.tensor(lons, dtype=torch.float32)
    region_batch = torch.tensor(rids, dtype=torch.long)
    index_batch = torch.tensor(oidxs, dtype=torch.long)
    return image_batch, lat_batch, lon_batch, region_batch, fnames, index_batch

# Build the evaluation-time image transform shared by every dataset in this notebook.
# NOTE(review): the same transform also feeds the ViT extractor; confirm its expected normalization matches the ConvNeXt config.
try: # Transform definition
    # Instantiate an un-pretrained ConvNeXt only to read timm's preprocessing config for that model.
    dummy_convnext = timm.create_model(CONVNEXT_MODEL_NAME, pretrained=False)
    data_config = timm.data.resolve_model_data_config(dummy_convnext)
    # Override the input resolution with IMG_SIZE.
    data_config['input_size'] = (3, IMG_SIZE, IMG_SIZE)
    transform = timm.data.create_transform(**data_config, is_training=False)
    del dummy_convnext
except Exception as e:
    # Fallback: plain resize + tensor conversion + fixed mean/std normalization.
    print(f"Error creating transforms: {e}. Using basic fallback.")
    transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Datasets and loaders used only for feature extraction: shuffle=False keeps CSV order,
# and collate_fn drops samples whose image could not be loaded or transformed.
train_dataset = CampusDataset(TRAIN_CSV_PATH, TRAIN_IMG_DIR, transform=transform)
val_dataset = CampusDataset(VAL_CSV_PATH, VAL_IMG_DIR, transform=transform)
train_img_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=collate_fn)
val_img_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True, collate_fn=collate_fn)

# --- Load Base Models ---
print(f"\nLoading Base Models...")
try:
    # Region classifier: ConvNeXt with a NUM_REGIONS-way head, weights restored from a local checkpoint.
    region_classifier = timm.create_model(CONVNEXT_MODEL_NAME, pretrained=False, num_classes=NUM_REGIONS)
    # NOTE(review): torch.load unpickles arbitrary objects. The checkpoint is a trusted local file here;
    # consider weights_only=True if the installed torch version supports it.
    checkpoint = torch.load(REGION_CLASSIFIER_PATH, map_location='cpu')
    # Accept either a bare state dict or one wrapped under 'state_dict' / 'model'.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: state_dict = checkpoint['state_dict']
    elif isinstance(checkpoint, dict) and 'model' in checkpoint: state_dict = checkpoint['model']
    else: state_dict = checkpoint
    # Drop the 'module.' marker from parameter names before loading.
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
    try:
        region_classifier.load_state_dict(state_dict, strict=True)
    except RuntimeError as strict_err:
        # Do not hide a partial load: report why the strict load failed and what was skipped.
        strict_msg = str(strict_err).splitlines()[0] if str(strict_err) else repr(strict_err)
        print(f"Warning: strict checkpoint load failed ({strict_msg}); retrying with strict=False.")
        load_result = region_classifier.load_state_dict(state_dict, strict=False)
        print(f"  Non-strict load: {len(load_result.missing_keys)} missing keys, {len(load_result.unexpected_keys)} unexpected keys.")
    region_classifier = region_classifier.to(DEVICE); region_classifier.eval()

    # Feature extractor = every child module of the classifier except the last one.
    # (The original three-way branch built this same expression in every branch.)
    convnext_feature_extractor = nn.Sequential(*list(region_classifier.children())[:-1])
    convnext_feature_extractor = convnext_feature_extractor.to(DEVICE); convnext_feature_extractor.eval()

    # ViT backbone with pretrained weights; num_classes=0 is passed so no classification head is attached.
    vit_feature_extractor = timm.create_model(VIT_MODEL_NAME, pretrained=True, num_classes=0)
    vit_feature_extractor = vit_feature_extractor.to(DEVICE); vit_feature_extractor.eval()
    print("Base models loaded.")
except Exception as e: print(f"Error loading base models: {e}"); raise

# --- Feature Extraction (with Caching) ---
@torch.no_grad()
def extract_features(dataloader, convnext_model, vit_model, region_classifier_model, device, dataset_name=""):
    """Run both backbones and the region classifier over a loader and collect the results.

    Args:
        dataloader: Iterable yielding ``None`` or
            ``(images, latitudes, longitudes, true_regions, fnames, original_indices)``.
        convnext_model: Callable producing ConvNeXt features; outputs with more
            than two dims are average-pooled to 1x1 and flattened.
        vit_model: Callable producing ViT features; 3-dim outputs are reduced to
            their first token (index 0 on dim 1).
        region_classifier_model: Callable producing region logits; argmax over
            dim 1 is stored as the predicted region.
        device: Device the image batch is moved to.
        dataset_name: Label used in progress and error messages.

    Returns:
        ``(fused_features, latitudes, longitudes, true_regions, pred_regions,
        fnames, original_indices)`` where ``fused_features`` concatenates the
        ConvNeXt and ViT features along axis 1 and ``fnames`` is a list.
        If nothing could be extracted, a list of seven empty arrays is returned.
    """
    all_convnext_feats, all_vit_feats = [], []
    all_latitudes, all_longitudes = [], []
    all_true_regions, all_pred_regions = [], []
    all_fnames = []
    all_original_indices = []
    for batch in tqdm(dataloader, desc=f"Extracting features [{dataset_name}]"):
        if batch is None: continue
        images, latitudes, longitudes, true_regions, fnames, original_indices = batch
        if images is None or len(images) == 0: continue
        images = images.to(device)
        try:
            convnext_feats = convnext_model(images)
            if convnext_feats.dim() > 2:
                convnext_feats = torch.flatten(nn.functional.adaptive_avg_pool2d(convnext_feats, (1, 1)), 1)
            vit_feats = vit_model(images)
            if vit_feats.dim() == 3: vit_feats = vit_feats[:, 0]
            region_logits = region_classifier_model(images)
            pred_regions = torch.argmax(region_logits, dim=1)
            # Materialize every CPU tensor before touching the accumulators, so a
            # failure here cannot leave the per-field lists with different lengths.
            batch_cpu = (convnext_feats.cpu(), vit_feats.cpu(), latitudes.cpu(), longitudes.cpu(),
                         true_regions.cpu(), pred_regions.cpu(), original_indices.cpu())
        except Exception as e:
            print(f"\nError extracting batch features [{dataset_name}]: {e}")
            continue
        # Commit the whole batch at once.
        all_convnext_feats.append(batch_cpu[0]); all_vit_feats.append(batch_cpu[1])
        all_latitudes.append(batch_cpu[2]); all_longitudes.append(batch_cpu[3])
        all_true_regions.append(batch_cpu[4]); all_pred_regions.append(batch_cpu[5])
        all_original_indices.append(batch_cpu[6]); all_fnames.extend(fnames)
        # Release device memory between batches.
        del images, convnext_feats, vit_feats, region_logits, pred_regions, latitudes, longitudes, true_regions, original_indices, batch_cpu
        gc.collect(); torch.cuda.empty_cache()
    if not all_convnext_feats:
        print(f"FATAL: No features extracted for {dataset_name}.")
        return [np.array([])]*7
    all_convnext_feats = torch.cat(all_convnext_feats, dim=0).numpy()
    all_vit_feats = torch.cat(all_vit_feats, dim=0).numpy()
    all_latitudes = torch.cat(all_latitudes, dim=0).numpy()
    all_longitudes = torch.cat(all_longitudes, dim=0).numpy()
    all_true_regions = torch.cat(all_true_regions, dim=0).numpy()
    all_pred_regions = torch.cat(all_pred_regions, dim=0).numpy()
    all_original_indices = torch.cat(all_original_indices, dim=0).numpy()
    fused_features = np.concatenate((all_convnext_feats, all_vit_feats), axis=1)
    return fused_features, all_latitudes, all_longitudes, all_true_regions, all_pred_regions, all_fnames, all_original_indices

# -- Load or Extract Features (Using Cache) --
# Training features: reuse the .npz cache when present, otherwise extract and save.
# NOTE(review): the cache is keyed only by file name; delete it manually after changing models, transforms or data.
if os.path.exists(TRAIN_CACHE_FILE):
    print(f"\nLoading training features from cache: {TRAIN_CACHE_FILE}")
    # np.load on an .npz returns a lazily-read archive; read every array inside
    # the context manager so the underlying file is closed afterwards.
    with np.load(TRAIN_CACHE_FILE) as cache_data:
        train_features = cache_data['features']
        train_lat = cache_data['lat']
        train_lon = cache_data['lon']
        train_true_region = cache_data['true_region']
        train_pred_region = cache_data['pred_region']
        train_original_indices = cache_data['original_indices']
else: # Extract and cache
    print("\nExtracting features for Training Set..."); start_time=time.time()
    train_features, train_lat, train_lon, train_true_region, train_pred_region, _, train_original_indices = extract_features(train_img_loader, convnext_feature_extractor, vit_feature_extractor, region_classifier, DEVICE, "Train")
    print(f" Extracted train features in {time.time()-start_time:.2f}s")
    if train_features.size > 0:
        print(f"Saving training features to cache: {TRAIN_CACHE_FILE}")
        os.makedirs(CACHE_DIR, exist_ok=True)
        np.savez_compressed(TRAIN_CACHE_FILE, features=train_features, lat=train_lat, lon=train_lon, true_region=train_true_region, pred_region=train_pred_region, original_indices=train_original_indices)

# Validation features: same cache-or-extract logic.
if os.path.exists(VAL_CACHE_FILE):
    print(f"\nLoading validation features from cache: {VAL_CACHE_FILE}")
    with np.load(VAL_CACHE_FILE) as cache_data:
        val_features = cache_data['features']
        val_lat = cache_data['lat']
        val_lon = cache_data['lon']
        val_true_region = cache_data['true_region']
        val_pred_region = cache_data['pred_region']
        val_original_indices = cache_data['original_indices']
else: # Extract and cache
    print("\nExtracting features for Validation Set..."); start_time=time.time()
    val_features, val_lat, val_lon, val_true_region, val_pred_region, _, val_original_indices = extract_features(val_img_loader, convnext_feature_extractor, vit_feature_extractor, region_classifier, DEVICE, "Validation")
    print(f" Extracted val features in {time.time()-start_time:.2f}s")
    if val_features.size > 0:
        print(f"Saving validation features to cache: {VAL_CACHE_FILE}")
        os.makedirs(CACHE_DIR, exist_ok=True)
        np.savez_compressed(VAL_CACHE_FILE, features=val_features, lat=val_lat, lon=val_lon, true_region=val_true_region, pred_region=val_pred_region, original_indices=val_original_indices)

# --- Check Feature Availability & Get Dimension ---
# Abort early if either feature set is missing or empty; everything below needs both.
if 'train_features' not in locals() or train_features.size == 0: raise RuntimeError("Training features failed.")
if 'val_features' not in locals() or val_features.size == 0: raise RuntimeError("Validation features failed.")
# Width of the fused feature vector; used as the input size of the per-region MLPs.
fused_feature_dim = train_features.shape[1]; print(f"\nCorrect Fused feature dimension: {fused_feature_dim}")

# --- Generate Anchor Points ---
def generate_anchor_points(latitudes, longitudes, true_region_ids, n_regions=NUM_REGIONS, n_anchors=N_ANCHORS_PER_REGION, n_init=KMEANS_N_INIT):
    """Cluster the coordinates of each region into anchor points.

    For every region id in ``range(n_regions)`` the (latitude, longitude) pairs
    of that region are clustered with K-Means into at most ``n_anchors``
    centres (fewer when the region has fewer unique coordinates). Regions with
    no samples receive ``n_anchors`` copies of the mean of all generated
    anchors (or of all coordinates / zeros when no anchors exist).

    Returns:
        dict mapping region id -> array of anchor coordinates (one row per anchor).
    """
    print(f"\n--- Generating {n_anchors} Anchor Points per Region using K-Means ---")
    coords = np.stack((latitudes, longitudes), axis=1)
    anchor_points_by_region = {}

    for region_id in range(n_regions):
        region_coords = coords[true_region_ids == region_id]
        if len(region_coords) == 0:
            anchor_points_by_region[region_id] = None
            continue

        unique_coords = np.unique(region_coords, axis=0)
        n_clusters = min(n_anchors, len(unique_coords))
        if n_clusters < 1:
            anchor_points_by_region[region_id] = None
            continue
        if n_clusters < n_anchors:
            print(f"  Region {region_id + 1}: Using n_clusters={n_clusters}")

        if n_clusters == 1:
            anchor_points = unique_coords
        else:
            try:
                kmeans = KMeans(n_clusters=n_clusters, random_state=SEED, n_init=n_init, verbose=0)
                kmeans.fit(region_coords)
                anchor_points = kmeans.cluster_centers_
            except Exception as e:
                print(f"  Region {region_id + 1}: K-Means failed ({e}). Using unique points.")
                anchor_points = unique_coords[:n_anchors]
        anchor_points_by_region[region_id] = anchor_points

    # Mean used to fill regions that ended up without anchors.
    available_anchors = [pts for pts in anchor_points_by_region.values() if pts is not None]
    if available_anchors:
        global_anchor_mean = np.mean(np.concatenate(available_anchors, axis=0), axis=0)
    elif len(coords) > 0:
        global_anchor_mean = np.mean(coords, axis=0)
    else:
        global_anchor_mean = np.array([0.0, 0.0])

    for region_id in range(n_regions):
        if anchor_points_by_region[region_id] is not None:
            continue
        anchor_points_by_region[region_id] = np.tile(global_anchor_mean, (n_anchors, 1))
        print(f"  Region {region_id+1}: Assigning fallback anchor mean.")

    print("--- Finished Generating Anchor Points ---")
    return anchor_points_by_region
anchor_points_dict = generate_anchor_points(train_lat, train_lon, train_true_region)

# --- Assign Training Labels ---
def assign_closest_anchor_labels(latitudes, longitudes, true_region_ids, anchor_points_by_region):
    """Label each sample with the index of the nearest anchor of its own region.

    Distance is Euclidean on (latitude, longitude). Samples whose region has no
    anchors keep the label -1.

    Returns:
        Integer array with one anchor index (or -1) per sample.
    """
    print("\n--- Assigning Training Labels (Closest Anchor Index) ---")
    total = len(latitudes)
    coords = np.stack((latitudes, longitudes), axis=1)
    anchor_labels = np.full(total, -1, dtype=int)

    for sample_idx in tqdm(range(total), desc="Assigning Labels"):
        anchors = anchor_points_by_region.get(true_region_ids[sample_idx])
        if anchors is None or len(anchors) == 0:
            continue
        # One query row against all anchors of the region -> row 0 of the distance matrix.
        dist_row = pairwise_distances(coords[sample_idx:sample_idx + 1], anchors, metric='euclidean')[0]
        anchor_labels[sample_idx] = np.argmin(dist_row)

    unassigned_count = np.sum(anchor_labels == -1)
    if unassigned_count > 0:
        print(f"Warning: {unassigned_count} samples could not be assigned an anchor label.")
    return anchor_labels
train_anchor_labels = assign_closest_anchor_labels(train_lat, train_lon, train_true_region, anchor_points_dict)

# --- Define Region-Specific MLP Classifier ---
class RegionAnchorClassifierMLP(nn.Module):
    """Three-layer MLP mapping a fused feature vector to per-anchor logits.

    Layout: Linear -> BatchNorm -> ReLU -> Dropout(0.4) -> Linear (half width)
    -> BatchNorm -> ReLU -> Dropout(0.3) -> Linear (``num_anchors`` outputs).
    """

    def __init__(self, input_dim=fused_feature_dim, num_anchors=N_ANCHORS_PER_REGION, hidden_dim=MLP_HIDDEN_DIM):
        super().__init__()
        half_dim = hidden_dim // 2
        # Layers are created in forward order so parameter names (network.<idx>) stay the same.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, half_dim),
            nn.BatchNorm1d(half_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(half_dim, num_anchors),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the anchor logits for input batch ``x``."""
        return self.network(x)
print(f"\nDefined RegionAnchorClassifierMLP architecture (Input: {fused_feature_dim}, Output: {N_ANCHORS_PER_REGION}).")

# --- Train Region-Specific Anchor Classifiers ---
def train_region_anchor_classifiers(features, anchor_labels, true_region_ids, n_regions=NUM_REGIONS, feature_dim=fused_feature_dim, epochs=MLP_EPOCHS_PER_REGION, lr=MLP_LEARNING_RATE, wd=MLP_WEIGHT_DECAY):
    """Train one anchor-classification MLP per region.

    For each region id in ``range(n_regions)`` the samples whose true region
    matches and whose anchor label is not -1 are used to fit a
    ``RegionAnchorClassifierMLP`` with AdamW and cross-entropy loss.

    A region is skipped (its entry is ``None``) when it has fewer than 10
    usable samples or fewer than 2 distinct anchor labels.

    Args:
        features: 2-D array of fused features, one row per sample.
        anchor_labels: Per-sample anchor index, -1 for unlabeled samples.
        true_region_ids: Per-sample 0-based region id.
        n_regions, feature_dim, epochs, lr, wd: Training configuration.

    Returns:
        dict mapping region id -> trained model in eval mode, or ``None`` for
        skipped regions. An empty dict is returned when ``features`` is empty.
    """
    anchor_classifiers = {}
    print(f"\n--- Training {n_regions} Region-Specific Anchor Classifiers ---")
    if features.shape[0] == 0:
        print("No features to train models.")
        return {}
    min_samples_required = 10
    for region_id in range(n_regions):
        print(f"\n-- Training Classifier for Region {region_id + 1}/{n_regions} --")
        region_mask = (true_region_ids == region_id) & (anchor_labels != -1)
        region_indices = np.where(region_mask)[0]
        if len(region_indices) < min_samples_required:
            print(f"Skipping Region {region_id + 1}: Only {len(region_indices)} valid samples found.")
            anchor_classifiers[region_id] = None
            continue
        region_features_np = features[region_indices]
        region_labels_np = anchor_labels[region_indices]
        num_unique_classes = len(np.unique(region_labels_np))
        if num_unique_classes < 2:
            print(f"  Skipping Region {region_id + 1}: Fewer than 2 unique anchor labels found.")
            anchor_classifiers[region_id] = None
            continue
        region_features_tensor = torch.tensor(region_features_np, dtype=torch.float32)
        region_labels_tensor = torch.tensor(region_labels_np, dtype=torch.long)
        region_dataset = TensorDataset(region_features_tensor, region_labels_tensor)
        # Cap the batch size at the region size. With a fixed BATCH_SIZE and
        # drop_last=True, regions that passed the minimum-sample check but had
        # fewer than BATCH_SIZE samples produced zero batches and were skipped.
        # drop_last stays on so BatchNorm never sees a tiny trailing batch.
        region_batch_size = min(BATCH_SIZE, len(region_dataset))
        region_loader = DataLoader(region_dataset, batch_size=region_batch_size, shuffle=True, drop_last=True)
        if len(region_loader) == 0:
            print(f"  Skipping Region {region_id + 1}: No full batches.")
            anchor_classifiers[region_id] = None
            continue
        model = RegionAnchorClassifierMLP(input_dim=feature_dim, num_anchors=N_ANCHORS_PER_REGION).to(DEVICE)
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
        criterion = nn.CrossEntropyLoss()
        model.train()
        for epoch in range(epochs):
            epoch_loss = 0.0; epoch_samples = 0
            for inputs, labels in region_loader:
                inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                # Weight the batch loss by batch size to get a per-sample average below.
                epoch_loss += loss.item() * inputs.size(0)
                epoch_samples += inputs.size(0)
            avg_epoch_loss = epoch_loss / epoch_samples if epoch_samples > 0 else 0
            if (epoch + 1) % 100 == 0 or epoch == 0 or epoch == epochs -1:
                print(f"  Region {region_id+1} - Epoch {epoch+1}/{epochs} - Avg Loss: {avg_epoch_loss:.4f}")
        model.eval()
        anchor_classifiers[region_id] = model
        print(f"  Finished training classifier for Region {region_id + 1}.")
        gc.collect(); torch.cuda.empty_cache()
    print("\n--- Finished Training All Region Anchor Classifiers ---")
    return anchor_classifiers
trained_anchor_classifiers = train_region_anchor_classifiers(train_features, train_anchor_labels, train_true_region)

# --- Evaluate Anchor Classifier Approach ---
def evaluate_anchor_classifier_performance(features, true_lat, true_lon, predicted_regions, original_indices, anchor_classifiers, anchor_points_by_region, dataset_name="Dataset"):
    """Predict coordinates via per-region anchor classification and report MSE.

    For each sample, the classifier and anchors of its *predicted* region are
    looked up; the classifier's argmax anchor supplies the predicted
    (latitude, longitude). Samples without a usable prediction get NaN.
    When ``dataset_name == "Validation Set"``, samples whose original index is
    in ``VAL_IGNORE_INDICES`` are predicted but excluded from the MSE.

    Args:
        features: 2-D array of fused features, one row per sample.
        true_lat, true_lon: Ground-truth coordinates, indexable by sample position.
        predicted_regions: Per-sample region id predicted by the region classifier.
        original_indices: Per-sample original CSV row index.
        anchor_classifiers: dict region id -> model or ``None``.
        anchor_points_by_region: dict region id -> anchor coordinate array.
        dataset_name: Label for messages; also switches on the ignore list.

    Returns:
        ``(mse_lat, mse_lon, avg_mse, n_used, pred_lat_list, pred_lon_list)``.
        The metrics are NaN (and ``n_used`` 0) when nothing could be scored;
        both prediction lists are empty when ``features`` is empty.
    """
    final_pred_lat_list = []; final_pred_lon_list = []; indices_for_loss = []
    successful_predictions = 0; missing_classifier_count = 0; missing_anchor_count = 0
    inference_error_count = 0
    print(f"\n--- Evaluating Anchor Classifier Performance on {dataset_name} ---")
    if features.shape[0] == 0:
        print(f"Cannot evaluate {dataset_name}: No features.")
        return np.nan, np.nan, np.nan, 0, [], [] # Return empty preds
    num_samples = features.shape[0]
    with torch.no_grad():
        for i in tqdm(range(num_samples), desc=f"Predicting anchors on {dataset_name}"):
            current_original_index = original_indices[i]
            ignore_this_sample = (dataset_name == "Validation Set" and current_original_index in VAL_IGNORE_INDICES)
            pred_region_id = predicted_regions[i]
            classifier = anchor_classifiers.get(pred_region_id)
            region_anchors = anchor_points_by_region.get(pred_region_id)
            pred_lat, pred_lon = np.nan, np.nan
            prediction_successful = False
            if classifier is not None and region_anchors is not None and len(region_anchors) > 0:
                try:
                    feature_tensor = torch.tensor(features[i:i+1], dtype=torch.float32).to(DEVICE)
                    classifier.eval()
                    logits = classifier(feature_tensor)
                    predicted_anchor_index = torch.argmax(logits, dim=1).item()
                    # The MLP always emits N_ANCHORS_PER_REGION logits, but a region may hold fewer anchors.
                    if 0 <= predicted_anchor_index < len(region_anchors):
                        predicted_coords = region_anchors[predicted_anchor_index]
                        pred_lat = predicted_coords[0]; pred_lon = predicted_coords[1]
                        if np.isfinite(pred_lat) and np.isfinite(pred_lon):
                            prediction_successful = True; successful_predictions += 1
                        else:
                            pred_lat, pred_lon = np.nan, np.nan
                    else:
                        missing_anchor_count += 1
                except Exception as e:
                    # Count failures instead of hiding them, and never keep a half-set prediction.
                    pred_lat, pred_lon = np.nan, np.nan
                    prediction_successful = False
                    if inference_error_count == 0:
                        print(f"\nWarning: anchor inference failed on {dataset_name} sample {i}: {e}")
                    inference_error_count += 1
            elif classifier is None:
                missing_classifier_count += 1
            else:
                missing_anchor_count += 1
            final_pred_lat_list.append(pred_lat); final_pred_lon_list.append(pred_lon) # Append raw predictions
            if prediction_successful and not ignore_this_sample:
                indices_for_loss.append(i)
    pred_lat_arr = np.array(final_pred_lat_list); pred_lon_arr = np.array(final_pred_lon_list)
    num_valid_for_loss = len(indices_for_loss)
    ignored_valid_preds = successful_predictions - num_valid_for_loss if dataset_name == "Validation Set" else 0
    print(f"{dataset_name}: Processed {num_samples} samples.")
    print(f"  Successfully predicted finite coords via anchor classification: {successful_predictions} samples.")
    if dataset_name == "Validation Set": print(f"  Ignored {ignored_valid_preds} successfully predicted samples based on VAL_IGNORE_INDICES.")
    print(f"  Samples used for MSE calculation: {num_valid_for_loss}")
    print(f"  Samples where classifier model was missing: {missing_classifier_count}")
    print(f"  Samples where anchors were missing/invalid: {missing_anchor_count}")
    if inference_error_count > 0:
        print(f"  Samples where inference raised an error: {inference_error_count}")
    if num_valid_for_loss == 0:
        print(f"No valid & included predictions available for {dataset_name} MSE calculation.")
        return np.nan, np.nan, np.nan, 0, final_pred_lat_list, final_pred_lon_list # Return raw preds
    valid_true_lat = true_lat[indices_for_loss]; valid_pred_lat = pred_lat_arr[indices_for_loss]
    valid_true_lon = true_lon[indices_for_loss]; valid_pred_lon = pred_lon_arr[indices_for_loss]
    if len(valid_true_lat) == 0:
        print("Error: No data for MSE calc.")
        return np.nan, np.nan, np.nan, 0, final_pred_lat_list, final_pred_lon_list
    mse_lat = mean_squared_error(valid_true_lat, valid_pred_lat)
    mse_lon = mean_squared_error(valid_true_lon, valid_pred_lon)
    avg_mse = 0.5 * (mse_lat + mse_lon)
    print(f"{dataset_name} Anchor Classification Results (based on {num_valid_for_loss} valid & included predictions):")
    print(f"  Latitude MSE  : {mse_lat:.4f}"); print(f"  Longitude MSE : {mse_lon:.4f}"); print(f"  Average MSE   : {avg_mse:.4f}")
    # Return raw predictions along with metrics
    return mse_lat, mse_lon, avg_mse, num_valid_for_loss, final_pred_lat_list, final_pred_lon_list

# --- Main Execution ---
# Run evaluation (and the submission steps below) only when both feature sets
# exist and at least one region classifier entry was produced.
if ('train_features' not in locals() or train_features.size == 0 or
    'val_features' not in locals() or val_features.size == 0 or
    not trained_anchor_classifiers):
     print("\nHalting execution - features missing or anchor classifiers not trained.")
else:
    print("\n--- Evaluating Final Region Anchor Classifiers ---")
    # Evaluate Training set (don't need predictions back)
    train_mse_lat, train_mse_lon, train_avg_mse, train_valid_count, _, _ = evaluate_anchor_classifier_performance(train_features, train_lat, train_lon, train_pred_region, train_original_indices, trained_anchor_classifiers, anchor_points_dict, dataset_name="Training Set")

    # Evaluate Validation set AND get the raw predictions back
    # (the exact name "Validation Set" is what activates VAL_IGNORE_INDICES inside the evaluator)
    val_mse_lat, val_mse_lon, val_avg_mse, val_valid_count, val_pred_lat_raw, val_pred_lon_raw = evaluate_anchor_classifier_performance(val_features, val_lat, val_lon, val_pred_region, val_original_indices, trained_anchor_classifiers, anchor_points_dict, dataset_name="Validation Set")

    print("\n--- Final Summary (Anchor Classification) ---")
    train_total = len(train_features); val_total = len(val_features)
    print(f"Training Set   | Avg MSE: {train_avg_mse:.4f} | Lat MSE: {train_mse_lat:.4f} | Lon MSE: {train_mse_lon:.4f} | Samples used for MSE: {train_valid_count}/{train_total}")
    print(f"Validation Set | Avg MSE: {val_avg_mse:.4f} | Lat MSE: {val_mse_lat:.4f} | Lon MSE: {val_mse_lon:.4f} | Samples used for MSE: {val_valid_count}/{val_total}")


    # ==============================================================================
    # >> Submission CSV Generation <<
    # ==============================================================================
    import glob # Ensure glob is imported

    # --- Define Test Dataset (Copy from previous submission block) ---
    class SubmissionTestDataset(Dataset):
        """Unlabeled test images named ``img_<number>.<ext>``, ordered by that number.

        Each item is ``(image, path, valid)``. When the image cannot be opened
        or the transform raises, ``image`` is an all-zero tensor of shape
        ``(3, img_size, img_size)`` and ``valid`` is ``False``.
        """

        def __init__(self, img_dir, transform, img_size=IMG_SIZE):
            self.img_dir = img_dir
            self.transform = transform
            self.img_size = img_size
            try:
                self.image_files = glob.glob(os.path.join(self.img_dir, 'img_*.*'))
                if len(self.image_files) == 0:
                    print(f"Warning: No test images found in {self.img_dir}")
                self.image_files.sort(key=self._numeric_suffix)
                print(f"Found {len(self.image_files)} test images.")
            except Exception as e:
                # Any discovery or sorting failure leaves the dataset empty.
                print(f"Error finding/sorting test images: {e}")
                self.image_files = []

        @staticmethod
        def _numeric_suffix(path):
            """Integer after the last underscore of the file name (extension removed)."""
            stem = os.path.splitext(os.path.basename(path))[0]
            return int(stem.rsplit('_', 1)[-1])

        def _blank_image(self):
            """Placeholder returned for images that could not be loaded or transformed."""
            return torch.zeros((3, self.img_size, self.img_size))

        def __len__(self):
            return len(self.image_files)

        def __getitem__(self, idx):
            img_path = self.image_files[idx]
            try:
                image = Image.open(img_path).convert('RGB')
            except Exception:
                return self._blank_image(), img_path, False
            try:
                prepared = self.transform(image) if self.transform else image
            except Exception:
                return self._blank_image(), img_path, False
            return prepared, img_path, True

    # --- Prediction Function (Copy from previous submission block) ---
    # Make sure this function matches the models and features used
    @torch.no_grad()
    def predict_submission_coords(batch_images,
                                  convnext_feat_extractor, vit_feat_extractor, region_classifier_model,
                                  anchor_classifiers_dict, anchor_points_dict, device):
        """Predict one (latitude, longitude) pair per image in ``batch_images``.

        Features from both extractors are fused (ConvNeXt output pooled and
        flattened, ViT output reduced to its first token when 3-dim), the
        region classifier picks a region, and that region's anchor classifier
        picks an anchor whose coordinates become the prediction.

        The fallback ``(0.0, 0.0)`` is returned for a sample when its region
        has no classifier or anchors, the predicted anchor index is out of
        range, the anchor is not finite, or inference raises. If feature
        extraction fails, the whole batch gets the fallback.

        Returns:
            list of ``(lat, lon)`` tuples, one per input image; ``[]`` for an
            empty or ``None`` batch.
        """
        fallback_coord = np.array([0.0, 0.0])
        if batch_images is None or batch_images.numel() == 0: return []
        batch_images = batch_images.to(device)
        try:
            features_convnext = convnext_feat_extractor(batch_images)
            if features_convnext.dim() == 4:
                features_convnext = nn.functional.adaptive_avg_pool2d(features_convnext, (1, 1))
            features_convnext = features_convnext.reshape(features_convnext.size(0), -1)
            features_vit = vit_feat_extractor(batch_images)
            if features_vit.dim() == 3: features_vit = features_vit[:, 0]
            features = torch.cat((features_convnext, features_vit), dim=1).cpu().numpy() # FUSED Features
            region_logits = region_classifier_model(batch_images)
            predicted_region_ids = torch.argmax(region_logits, dim=1).cpu().numpy()
        except Exception as e:
            # Report the failure instead of silently emitting fallback coordinates.
            print(f"Warning: feature extraction failed for a submission batch ({e}); using fallback coordinates.")
            return [(fallback_coord[0], fallback_coord[1])] * batch_images.shape[0]

        pred_coords_batch = []
        failed_samples = 0
        for i in range(features.shape[0]):
            pred_region_id = predicted_region_ids[i]
            classifier = anchor_classifiers_dict.get(pred_region_id)
            region_anchors = anchor_points_dict.get(pred_region_id)
            pred_lat, pred_lon = fallback_coord
            if classifier is not None and region_anchors is not None and len(region_anchors) > 0:
                try:
                    feature_tensor = torch.tensor(features[i:i+1], dtype=torch.float32).to(device)
                    classifier.eval()
                    logits = classifier(feature_tensor)
                    predicted_anchor_index = torch.argmax(logits, dim=1).item()
                    if 0 <= predicted_anchor_index < len(region_anchors):
                        coords = region_anchors[predicted_anchor_index]
                        # Same finiteness rule as the evaluation path: reject NaN and +/-inf.
                        if np.isfinite(coords[0]) and np.isfinite(coords[1]):
                            pred_lat, pred_lon = coords[0], coords[1]
                except Exception:
                    pred_lat, pred_lon = fallback_coord
                    failed_samples += 1
            pred_coords_batch.append((pred_lat, pred_lon))
        if failed_samples > 0:
            print(f"Warning: anchor inference failed for {failed_samples} sample(s) in a submission batch; fallback coordinates used.")
        return pred_coords_batch

    # --- Generate Submission File ---
    print("\n--- Generating Final Submission CSV (Anchor Classification) ---")
    submission_data = [] # List to hold (id, Latitude, Longitude) tuples

    # 1. Add Validation Predictions (using the raw predictions returned by evaluate_*)
    print("Adding Validation Predictions...")
    df_val = pd.read_csv(VAL_CSV_PATH) # Load val df for count
    num_val_samples = len(df_val)
    # Predictions are used positionally (id = position), which is only valid when
    # there is exactly one prediction per CSV row.
    if len(val_pred_lat_raw) == num_val_samples and len(val_pred_lon_raw) == num_val_samples:
        for i in range(num_val_samples):
            submission_data.append((i, val_pred_lat_raw[i], val_pred_lon_raw[i]))
        print(f"Added {num_val_samples} validation predictions.")
    else:
        print(f"Warning: Length mismatch between val predictions ({len(val_pred_lat_raw)}) and val CSV ({num_val_samples}). Submission might be incorrect.")
        # Fallback: every validation row is written as (0.0, 0.0) so the id range stays complete.
        for i in range(num_val_samples): submission_data.append((i, 0.0, 0.0))


    # 2. Process Test Set
    print("\nProcessing Test Set for submission...")
    # NOTE(review): this re-assigns the same value already set in the configuration section.
    TEST_IMG_DIR = "/kaggle/input/images-test/images_test" # Default test path for Kaggle
    test_submission_dataset = SubmissionTestDataset(TEST_IMG_DIR, transform) # Use correct transform
    test_submission_loader = DataLoader(test_submission_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
    num_test_images = len(test_submission_dataset)
    # Test ids continue right after the validation ids.
    test_start_id = num_val_samples
    test_predictions_temp = {} # Store by index 0 to N_test-1

    if num_test_images > 0:
        with torch.no_grad():
            # Running position of the first image of the current batch within the sorted test set.
            current_test_idx = 0
            for images, img_paths, valid_loads in tqdm(test_submission_loader, desc="Test Submission Pred"):
                num_in_batch = len(img_paths)
                indices_in_batch = list(range(current_test_idx, current_test_idx + num_in_batch))
                # Keep only images flagged as successfully loaded.
                # presumably valid_loads is a boolean tensor built by the default collate - TODO confirm
                images_valid = images[valid_loads]
                indices_valid = [idx for idx, v in zip(indices_in_batch, valid_loads) if v]
                indices_failed = [idx for idx, v in zip(indices_in_batch, valid_loads) if not v]
                # Images that failed to load get the (0.0, 0.0) placeholder.
                for failed_idx in indices_failed: test_predictions_temp[failed_idx] = (0.0, 0.0)

                if images_valid.numel() > 0:
                    # Pass the correct feature extractors needed by predict_submission_coords
                    predicted_coords_list = predict_submission_coords(
                        images_valid, convnext_feature_extractor, vit_feature_extractor,
                        region_classifier, trained_anchor_classifiers, anchor_points_dict, DEVICE
                    )
                    # Map each prediction back to its position in the test set.
                    for i in range(len(predicted_coords_list)):
                        test_idx = indices_valid[i]; pred_lat, pred_lon = predicted_coords_list[i]
                        test_predictions_temp[test_idx] = (pred_lat, pred_lon)
                current_test_idx += num_in_batch

        # Add test predictions to submission_data from temp dict
        # (positions without a stored prediction default to (0.0, 0.0))
        for test_idx in range(num_test_images):
             pred_lat, pred_lon = test_predictions_temp.get(test_idx, (0.0, 0.0))
             submission_id = test_start_id + test_idx
             submission_data.append((submission_id, pred_lat, pred_lon))
    print(f"Collected predictions for {len(test_predictions_temp)} test samples (out of {num_test_images}).")

    # 3. Finalize and Save CSV
    print(f"\nTotal entries collected for submission: {len(submission_data)}")
    submission_data.sort(key=lambda x: x[0]) # Sort by ID

    # --- Handle NaNs ---
    # Any non-finite coordinate (NaN, None, +inf, -inf) is treated as missing:
    # an infinite value is just as unusable in the CSV as a NaN one.
    temp_df = pd.DataFrame(submission_data, columns=['id','lat','lon'])
    bad_lat_mask = ~np.isfinite(pd.to_numeric(temp_df['lat'], errors='coerce').to_numpy(dtype=float))
    bad_lon_mask = ~np.isfinite(pd.to_numeric(temp_df['lon'], errors='coerce').to_numpy(dtype=float))
    nan_count_lat = int(bad_lat_mask.sum()); nan_count_lon = int(bad_lon_mask.sum())
    if nan_count_lat > 0 or nan_count_lon > 0:
        print(f"Found {nan_count_lat} NaN lat, {nan_count_lon} NaN lon predictions.")
        # Fallback is the mean of the finite training coordinates when they are
        # available in this scope, otherwise 0.0. Restricting to finite values
        # keeps the fallback itself from being NaN/inf.
        fallback_lat, fallback_lon = 0.0, 0.0
        if 'train_lat' in locals():
            finite_train_lat = np.ravel(np.asarray(train_lat, dtype=float))
            finite_train_lat = finite_train_lat[np.isfinite(finite_train_lat)]
            if finite_train_lat.size > 0: fallback_lat = float(finite_train_lat.mean())
        if 'train_lon' in locals():
            finite_train_lon = np.ravel(np.asarray(train_lon, dtype=float))
            finite_train_lon = finite_train_lon[np.isfinite(finite_train_lon)]
            if finite_train_lon.size > 0: fallback_lon = float(finite_train_lon.mean())
        print(f"Filling NaNs with Lat={fallback_lat:.4f}, Lon={fallback_lon:.4f}")
        # Rebuild the list of (id, lat, lon) tuples, replacing only the
        # coordinates flagged as missing; row order (sorted by id) is unchanged.
        filled_submission_data = [
            (sub_id,
             fallback_lat if lat_is_bad else p_lat,
             fallback_lon if lon_is_bad else p_lon)
            for (sub_id, p_lat, p_lon), lat_is_bad, lon_is_bad
            in zip(submission_data, bad_lat_mask, bad_lon_mask)
        ]
        submission_data = filled_submission_data

    # Build the final table: one row per (id, lat, lon) tuple, with the
    # column headers used in the submission file.
    final_submission_df = pd.DataFrame(
        data=submission_data,
        columns=["id", "Latitude", "Longitude"],
    )

    # --- Save Final Submission ---
    # Relative path, so the file is written to the current working directory.
    submission_filename = "submission.csv"
    try:
        final_submission_df.to_csv(
            path_or_buf=submission_filename,
            index=False,            # ids are already an explicit column
            float_format='%.5f',    # fixed 5-decimal output for coordinates
        )
        print(f"\nCombined submission file saved successfully to {submission_filename}")
        print(f"Total entries: {len(final_submission_df)}")
        # Show both ends of the table as a quick sanity check of the id range.
        if not final_submission_df.empty:
            print("Submission Head:\n", final_submission_df.head())
            print("\nSubmission Tail:\n", final_submission_df.tail())
    except Exception as save_error:
        # Best-effort: report the failure instead of aborting the cell.
        print(f"\nERROR saving final submission file: {save_error}")

    # --- Final Cleanup ---
    gc.collect()
    # Normalise DEVICE through torch.device so the check also matches indexed
    # strings such as 'cuda:0' and torch.device objects, not only the exact
    # string 'cuda'; otherwise the cached GPU memory would silently be kept.
    if torch.device(DEVICE).type == 'cuda': torch.cuda.empty_cache()
    print("\nSubmission generation finished.")

Using device: cuda
Number of Anchor Points per Region: 10
Ignoring validation indices for loss calculation: [95, 145, 146, 158, 159, 160, 161]

Loading Base Models...


  checkpoint = torch.load(REGION_CLASSIFIER_PATH, map_location='cpu')


Base models loaded.

Loading training features from cache: /kaggle/working/train_features_cache.npz

Loading validation features from cache: /kaggle/working/val_features_cache.npz

Correct Fused feature dimension: 1792

--- Generating 10 Anchor Points per Region using K-Means ---
--- Finished Generating Anchor Points ---

--- Assigning Training Labels (Closest Anchor Index) ---


Assigning Labels: 100%|██████████| 6467/6467 [00:00<00:00, 8152.25it/s]



Defined RegionAnchorClassifierMLP architecture (Input: 1792, Output: 10).

--- Training 15 Region-Specific Anchor Classifiers ---

-- Training Classifier for Region 1/15 --
  Region 1 - Epoch 1/500 - Avg Loss: 2.3176
  Region 1 - Epoch 100/500 - Avg Loss: 0.0147
  Region 1 - Epoch 200/500 - Avg Loss: 0.0038
  Region 1 - Epoch 300/500 - Avg Loss: 0.0025
  Region 1 - Epoch 400/500 - Avg Loss: 0.0013
  Region 1 - Epoch 500/500 - Avg Loss: 0.0006
  Finished training classifier for Region 1.

-- Training Classifier for Region 2/15 --
  Region 2 - Epoch 1/500 - Avg Loss: 2.3125
  Region 2 - Epoch 100/500 - Avg Loss: 0.0087
  Region 2 - Epoch 200/500 - Avg Loss: 0.0025
  Region 2 - Epoch 300/500 - Avg Loss: 0.0009
  Region 2 - Epoch 400/500 - Avg Loss: 0.0004
  Region 2 - Epoch 500/500 - Avg Loss: 0.0050
  Finished training classifier for Region 2.

-- Training Classifier for Region 3/15 --
  Region 3 - Epoch 1/500 - Avg Loss: 2.3853
  Region 3 - Epoch 100/500 - Avg Loss: 0.0103
  Region 3 -

Predicting anchors on Training Set: 100%|██████████| 6467/6467 [00:03<00:00, 2095.71it/s]


Training Set: Processed 6467 samples.
  Successfully predicted finite coords via anchor classification: 6467 samples.
  Samples used for MSE calculation: 6467
  Samples where classifier model was missing: 0
  Samples where anchors were missing/invalid: 0
Training Set Anchor Classification Results (based on 6467 valid & included predictions):
  Latitude MSE  : 5439.5435
  Longitude MSE : 6268.1694
  Average MSE   : 5853.8564

--- Evaluating Anchor Classifier Performance on Validation Set ---


Predicting anchors on Validation Set: 100%|██████████| 369/369 [00:00<00:00, 2007.40it/s]


Validation Set: Processed 369 samples.
  Successfully predicted finite coords via anchor classification: 369 samples.
  Ignored 7 successfully predicted samples based on VAL_IGNORE_INDICES.
  Samples used for MSE calculation: 362
  Samples where classifier model was missing: 0
  Samples where anchors were missing/invalid: 0
Validation Set Anchor Classification Results (based on 362 valid & included predictions):
  Latitude MSE  : 45289.1016
  Longitude MSE : 60724.7188
  Average MSE   : 53006.9102

--- Final Summary (Anchor Classification) ---
Training Set   | Avg MSE: 5853.8564 | Lat MSE: 5439.5435 | Lon MSE: 6268.1694 | Samples used for MSE: 6467/6467
Validation Set | Avg MSE: 53006.9102 | Lat MSE: 45289.1016 | Lon MSE: 60724.7188 | Samples used for MSE: 362/369

--- Generating Final Submission CSV (Anchor Classification) ---
Adding Validation Predictions...
Added 369 validation predictions.

Processing Test Set for submission...
Found 369 test images.


Test Submission Pred: 100%|██████████| 6/6 [00:18<00:00,  3.10s/it]


Collected predictions for 369 test samples (out of 369).

Total entries collected for submission: 738

Combined submission file saved successfully to submission.csv
Total entries: 738
Submission Head:
    id       Latitude     Longitude
0   0  219686.125000  144801.12500
1   1  220190.093750  144217.21875
2   2  220190.093750  144217.21875
3   3  220070.796875  141976.50000
4   4  220314.093750  142191.28125

Submission Tail:
       id       Latitude      Longitude
733  733  220348.343750  144014.609375
734  734  218899.406250  144052.203125
735  735  218738.296875  144269.062500
736  736  218541.968750  144499.765625
737  737  218738.296875  144269.062500

Submission generation finished.
