In [None]:
# 1. Install Python 3.10
!sudo apt-get update -y
!sudo apt-get install python3.10 python3.10-dev python3.10-distutils -y

# 2. Point Colab to Python 3.10
!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
!sudo update-alternatives --config python3

# 3. Install pip for Python 3.10
!curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10

# 4. Reinstall packages
!python3 -m pip install --upgrade pip
!python3 -m pip install timm faiss-gpu s2sphere tqdm


In [None]:
# =========================================================
# Improved GeoFuse with Hard Negative Mining on OpenStreetView-5M
# =========================================================

# Install dependencies with error handling
import subprocess
import sys

def install_package(package):
    """pip-install `package` into the running interpreter and report the outcome.

    A failed install is printed rather than raised, so a loop over several
    packages keeps going after one failure.
    """
    pip_cmd = [sys.executable, "-m", "pip", "install", package]
    try:
        subprocess.check_call(pip_cmd)
    except subprocess.CalledProcessError as err:
        print(f"‚ùå Failed to install {package}: {err}")
    else:
        print(f"‚úÖ Successfully installed {package}")


# faiss-gpu and faiss-cpu both provide the same importable `faiss` module, so
# installing both lets whichever comes second overwrite the first. Install exactly
# one, chosen by whether an NVIDIA driver tool is visible on this machine.
# NOTE(review): faiss-gpu PyPI wheels exist only for older Python versions; if that
# install fails, the following cell falls back to faiss-cpu.
import shutil

faiss_package = "faiss-gpu" if shutil.which("nvidia-smi") else "faiss-cpu"
packages = ["timm", faiss_package, "s2sphere", "tqdm", "albumentations", "scikit-learn"]
for pkg in packages:
    install_package(pkg)

In [None]:
# Fallback: install the official CPU FAISS wheel only when no FAISS build is
# importable yet (e.g. the faiss-gpu install above failed). The bare `faiss` name
# on PyPI is not one of the official `faiss-cpu` / `faiss-gpu` distributions and
# could overwrite a working install, so it is not used.
try:
    import faiss  # noqa: F401
except ImportError:
    install_package("faiss-cpu")

In [None]:
import os, random, warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import transforms
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

import faiss
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

import math

def latlon_to_s2cell(lat, lon, level=10):
    """Approximate an S2 cell id by bucketing lat/lon into a 2^level x 2^level grid.

    Fallback encoder for when a real S2 implementation is unavailable. Returns an
    int in [0, 4**level). Latitude +90 and longitude +180 (and anything outside
    the valid ranges) are clamped into the edge bins, so the poles and the
    antimeridian do not spill into a neighbouring row of the grid.
    """
    bins = 1 << level
    lat_bin = min(max(int((lat + 90.0) / 180.0 * bins), 0), bins - 1)
    lon_bin = min(max(int((lon + 180.0) / 360.0 * bins), 0), bins - 1)
    return lat_bin * bins + lon_bin


# NOTE(review): this silences *every* warning for the rest of the session,
# including library deprecation notices; consider narrowing the filter.
warnings.filterwarnings('ignore')

In [None]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device), and
    force deterministic cuDNN kernels with autotuning disabled."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(42)

In [None]:
def setup_gsv_dataset(api_key, locations, image_size="640x640",
                      output_dir='/content/osv_subset', timeout=30):
    """Download one Google Street View Static API image per location and write metadata.csv.

    Args:
        api_key: Google Maps Platform key with the Street View Static API enabled.
        locations: iterable of (lat, lon) pairs in degrees.
        image_size: "WIDTHxHEIGHT" string sent as the API `size` parameter.
        output_dir: dataset root; images go to `<output_dir>/images/gsv_XXXX.jpg`
            and metadata to `<output_dir>/metadata.csv`.
        timeout: per-request timeout in seconds, so one stalled request cannot
            hang the notebook.

    Failed locations are reported and skipped. Error messages show only the
    exception type, because request exceptions can embed the full URL, which
    contains the API key.
    """
    import requests

    images_dir = os.path.join(output_dir, 'images')
    os.makedirs(images_dir, exist_ok=True)

    url = "https://maps.googleapis.com/maps/api/streetview"
    metadata = []
    for i, (lat, lon) in enumerate(locations):
        params = {
            'size': image_size,
            'location': f"{lat},{lon}",
            'heading': '0',  # Can randomize for variety
            'pitch': '0',
            'key': api_key,
            'fov': '90',
            # Without this flag the API answers 200 with a generic "no imagery"
            # placeholder when there is no panorama nearby; with it that case is
            # an HTTP error and the location is skipped.
            'return_error_code': 'true',
        }
        try:
            response = requests.get(url, params=params, timeout=timeout)
            if response.status_code != 200:
                print(f"Failed to download image for {lat}, {lon}: HTTP {response.status_code}")
                continue

            filename = f"gsv_{i:04d}.jpg"
            with open(os.path.join(images_dir, filename), 'wb') as f:
                f.write(response.content)
        except (requests.RequestException, OSError) as e:
            print(f"Failed to download image for {lat}, {lon}: {type(e).__name__}")
            continue

        metadata.append({
            'filename': filename,
            'lat': lat,
            'lon': lon,
            'source': 'google_street_view'
        })

    # Explicit columns keep the CSV header even when nothing was downloaded.
    pd.DataFrame(metadata, columns=['filename', 'lat', 'lon', 'source']).to_csv(
        os.path.join(output_dir, 'metadata.csv'), index=False)
    print(f"Downloaded {len(metadata)} images from Google Street View")

# Option 2: Use existing datasets (YFCC100M, etc.)
def download_yfcc100m_subset():
    """Placeholder for fetching a GPS-tagged subset of YFCC100M.

    Not implemented yet: performs no work and returns None.
    """
    return None

# Option 3: Use Mapillary dataset
def setup_mapillary_dataset():
    """Placeholder for building the dataset from Mapillary street-level imagery
    (would need Mapillary API access).

    Not implemented yet: performs no work and returns None.
    """
    return None

# Option 4: Custom dataset from uploaded folder
def setup_custom_dataset(folder_path):
    """Copy a user-provided dataset into /content/osv_subset.

    Expected layout:
        folder_path/
          ├── images/
          │   ├── img1.jpg
          │   └── img2.jpg
          └── metadata.csv   (optional; otherwise built from image EXIF GPS)

    `folder_path` may carry surrounding whitespace or quotes and a leading "~"
    (typical of an input() prompt); these are normalised first. If it already is
    /content/osv_subset, nothing is copied, because copying a file onto itself
    raises shutil.SameFileError.

    Returns True on success, False if `folder_path/images` does not exist.
    """
    import shutil

    dest_root = '/content/osv_subset'
    dest_images = os.path.join(dest_root, 'images')
    dest_metadata = os.path.join(dest_root, 'metadata.csv')

    folder_path = os.path.expanduser(str(folder_path).strip().strip('"\''))
    print(f"Setting up custom dataset from {folder_path}")

    images_path = os.path.join(folder_path, 'images')
    metadata_path = os.path.join(folder_path, 'metadata.csv')

    if not os.path.exists(images_path):
        print(f"‚ùå Images folder not found at {images_path}")
        return False

    already_in_place = os.path.abspath(folder_path) == os.path.abspath(dest_root)

    if not already_in_place:
        shutil.copytree(images_path, dest_images, dirs_exist_ok=True)

    if os.path.exists(metadata_path):
        if not already_in_place:
            shutil.copy(metadata_path, dest_metadata)
        print("‚úÖ Existing metadata copied")
    else:
        # Generate metadata from EXIF GPS data (random coordinates when absent).
        generate_metadata_from_images(dest_images)

    return True

def generate_metadata_from_images(images_folder, output_csv='/content/osv_subset/metadata.csv'):
    """Build a metadata CSV from the EXIF GPS tags of the images in `images_folder`.

    Scans `images_folder` (non-recursively) for .jpg/.jpeg/.png files, matching
    extensions case-insensitively. Images without usable GPS EXIF get uniformly
    random coordinates (lat in [-60, 70], lon in [-180, 180]) so the demo pipeline
    still runs; those rows have source 'random_fallback' instead of 'exif_gps'.
    Images that cannot be opened are reported and skipped.

    Args:
        images_folder: directory containing the images.
        output_csv: path of the CSV written (columns filename, lat, lon, source).
    """
    from PIL import Image

    gps_ifd_tag = 0x8825  # EXIF pointer tag of the GPSInfo IFD
    extensions = ('.jpg', '.jpeg', '.png')

    if os.path.isdir(images_folder):
        image_files = sorted(
            os.path.join(images_folder, name)
            for name in os.listdir(images_folder)
            if name.lower().endswith(extensions)
        )
    else:
        image_files = []

    metadata = []
    for img_path in tqdm(image_files, desc="Extracting GPS from images"):
        try:
            with Image.open(img_path) as image:
                # Depending on the Pillow version, exif[GPSInfo] may be the raw IFD
                # offset (an int) instead of the tag dict; get_ifd() always returns
                # the dict ({} when the image has no GPS block).
                gps_data = image.getexif().get_ifd(gps_ifd_tag)

            lat, lon = parse_gps_data(gps_data) if gps_data else (None, None)
            source = 'exif_gps'

            # If no GPS data, generate random coordinates (for demo)
            if lat is None or lon is None:
                lat = random.uniform(-60, 70)
                lon = random.uniform(-180, 180)
                source = 'random_fallback'
                print(f"‚ö†Ô∏è No GPS data for {os.path.basename(img_path)}, using random coordinates")

            metadata.append({
                'filename': os.path.basename(img_path),
                'lat': lat,
                'lon': lon,
                'source': source
            })

        except Exception as e:
            print(f"Failed to process {img_path}: {e}")

    # Explicit columns keep the header even when no image could be processed.
    pd.DataFrame(metadata, columns=['filename', 'lat', 'lon', 'source']).to_csv(output_csv, index=False)
    print(f"Generated metadata for {len(metadata)} images")

def parse_gps_data(gps_data):
    """Extract decimal (lat, lon) from an EXIF GPS IFD mapping.

    Reads GPS tags 1/2 (GPSLatitudeRef / GPSLatitude) and 3/4 (GPSLongitudeRef /
    GPSLongitude). Returns (None, None) when `gps_data` is not a mapping (for
    example a raw IFD offset int) or either coordinate is missing. An element
    may also be None when convert_gps_coord cannot parse that coordinate.
    """
    try:
        lat_ref, lat_data = gps_data.get(1), gps_data.get(2)
        lon_ref, lon_data = gps_data.get(3), gps_data.get(4)
    except AttributeError:
        return None, None

    if not lat_data or not lon_data:
        return None, None

    return convert_gps_coord(lat_data, lat_ref), convert_gps_coord(lon_data, lon_ref)

def convert_gps_coord(coord_data, ref):
    """Convert an EXIF (degrees, minutes, seconds) triple to signed decimal degrees.

    `ref` is the hemisphere letter as str or bytes ('N'/'S'/'E'/'W', with any
    trailing NUL or space ignored); 'S' and 'W' make the result negative.
    Returns None if the triple cannot be read as three numbers or the result is
    not finite.
    """
    import math

    try:
        degrees, minutes, seconds = (float(part) for part in coord_data[:3])
    except (TypeError, ValueError, IndexError, ZeroDivisionError):
        return None

    decimal = degrees + minutes / 60 + seconds / 3600
    if not math.isfinite(decimal):
        return None

    if isinstance(ref, bytes):
        ref = ref.decode('ascii', errors='ignore')
    if isinstance(ref, str):
        ref = ref.strip('\x00 ').upper()

    if ref in ('S', 'W'):
        decimal = -decimal

    return decimal

# Interactive dataset setup
def interactive_dataset_setup():
    """Prompt for a dataset source and run the matching setup helper.

    Only option 2 (Google Street View) and option 3 (custom folder) are
    implemented; options 1, 4 and 5 print a notice. The API key is read with
    getpass so it is not echoed into the cell output.

    Returns True when a setup helper ran successfully, False for unimplemented
    options, an invalid choice, or an invalid sample count.
    """
    import getpass

    print("üåç GeoFuse Dataset Setup")
    print("=" * 40)

    options = {
        "1": "Download OpenStreetView-5M (automatic)",
        "2": "Use Google Street View API (requires API key)",
        "3": "Upload custom dataset",
        "4": "Generate synthetic data",
        "5": "Use existing Colab files"
    }

    for key, value in options.items():
        print(f"{key}. {value}")

    choice = input("\nSelect an option (1-5): ").strip()

    if choice == "1":
        print("Option 1 (Download OpenStreetView-5M) is not fully implemented in this notebook.")
        return False

    elif choice == "2":
        api_key = getpass.getpass("Enter your Google Street View API key: ").strip()
        raw_count = input("Number of samples to download (max 100): ").strip()
        try:
            num_samples = int(raw_count)
        except ValueError:
            print(f"Invalid number of samples: {raw_count!r}")
            return False
        if num_samples <= 0:
            print("Number of samples must be a positive integer.")
            return False

        # Generate random locations (you can customize this)
        locations = [(random.uniform(-60, 70), random.uniform(-180, 180))
                     for _ in range(min(num_samples, 100))]

        setup_gsv_dataset(api_key, locations)
        return True

    elif choice == "3":
        print("Please upload your dataset folder and specify the path:")
        folder_path = input("Enter folder path: ").strip()
        return setup_custom_dataset(folder_path)

    elif choice == "4":
        print("Option 4 (Generate synthetic data) is not fully implemented in this notebook.")
        return False

    elif choice == "5":
        print("Option 5 (Use existing Colab files) is not fully implemented in this notebook.")
        return False

    else:
        print("Invalid choice. Please select a valid option (1-5).")
        return False


# Enhanced file upload helper
def upload_and_extract_dataset(dest_dir='/content/osv_subset'):
    """Upload a dataset ZIP through the Colab file picker and extract it into `dest_dir`.

    The first uploaded file whose name ends in .zip (case-insensitive) is
    extracted. Returns True after a successful extraction. Returns False when
    not running in Colab, when no ZIP was uploaded, or when the upload or
    extraction failed.
    """
    import io
    import zipfile

    try:
        from google.colab import files
        print("üì§ Upload your dataset ZIP file:")
        uploaded = files.upload()  # {filename: file bytes}

        for filename, content in uploaded.items():
            if filename.lower().endswith('.zip'):
                print(f"Extracting {filename}...")
                os.makedirs(dest_dir, exist_ok=True)
                # ZipFile.extractall strips absolute paths and ".." components,
                # so archive members cannot be written outside dest_dir.
                with zipfile.ZipFile(io.BytesIO(content)) as archive:
                    archive.extractall(dest_dir)
                print("‚úÖ Dataset extracted successfully")
                return True

        print("‚ùå No ZIP file found in upload")
        return False

    except ImportError:
        print("‚ùå File upload only available in Google Colab")
        return False
    except Exception as e:
        print(f"‚ùå Upload failed: {e}")
        return False

# Dataset validation
def validate_dataset(root='/content/osv_subset'):
    """Check that `root` holds a usable dataset and print any problems found.

    Checks performed:
      * `root/metadata.csv` exists and has filename/lat/lon columns;
      * lat/lon are present, numeric, and within [-90, 90] / [-180, 180]
        (missing or non-numeric values count as invalid);
      * `root/images` contains at least one .jpg/.jpeg/.png file, with
        extensions matched case-insensitively.

    Returns True when no issues were found, otherwise False.
    """
    metadata_csv = os.path.join(root, 'metadata.csv')
    images_dir = os.path.join(root, 'images')
    issues = []

    # Check metadata
    if not os.path.exists(metadata_csv):
        issues.append("‚ùå metadata.csv not found")
    else:
        try:
            df = pd.read_csv(metadata_csv)
            required_cols = ['filename', 'lat', 'lon']
            missing_cols = [col for col in required_cols if col not in df.columns]
            if missing_cols:
                issues.append(f"‚ùå Missing columns in metadata: {missing_cols}")

            # Check coordinate ranges; NaN (missing or unparseable) fails the check too.
            for col, limit, name in (('lat', 90, 'latitude'), ('lon', 180, 'longitude')):
                if col not in df.columns:
                    continue
                values = pd.to_numeric(df[col], errors='coerce')
                n_invalid = int((values.isna() | (values.abs() > limit)).sum())
                if n_invalid > 0:
                    issues.append(f"‚ö†Ô∏è {n_invalid} invalid {name} values")

        except Exception as e:
            issues.append(f"‚ùå Error reading metadata: {e}")

    # Check images
    if not os.path.exists(images_dir):
        issues.append("‚ùå Images folder not found")
    else:
        image_files = [f for f in os.listdir(images_dir)
                       if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        if len(image_files) == 0:
            issues.append("‚ùå No image files found")
        else:
            print(f"‚úÖ Found {len(image_files)} image files")

    if issues:
        print("Dataset validation issues:")
        for issue in issues:
            print(f"  {issue}")
        return False

    print("‚úÖ Dataset validation passed")
    return True

# Main execution
if __name__ == "__main__":
    # Provision the dataset. The active path asks for a source interactively.
    # Alternative when the data arrives as a ZIP through the Colab file picker:
    #     dataset_ready = upload_and_extract_dataset()
    dataset_ready = interactive_dataset_setup()

    # Only validate what a setup helper actually produced.
    if dataset_ready:
        validate_dataset()

In [None]:
# --- Step 1: Download OpenStreetView-5M subset with error handling ---
def download_data(repo_dir='/content/osv5m', subset_dir='/content/osv_subset',
                  repo_url='https://github.com/gastruc/osv5m.git',
                  archive='sample_data.zip'):
    """Clone the OSV-5M repository into `repo_dir`, then unzip `archive` into `subset_dir`.

    Each step is skipped when its target path already exists. Commands receive
    explicit target paths (`git clone ... repo_dir`, `cwd=repo_dir` for unzip)
    instead of using os.chdir. This means the clone always lands in `repo_dir`,
    and a failing step can no longer leave the process working directory inside
    the repository.
    NOTE(review): confirm that `archive` actually exists in the repository.

    Returns True on success, False if either command fails or cannot be run.
    """
    try:
        if not os.path.exists(repo_dir):
            subprocess.run(['git', 'clone', repo_url, repo_dir], check=True)

        if not os.path.exists(subset_dir):
            subprocess.run(['unzip', '-q', archive, '-d', subset_dir], cwd=repo_dir, check=True)

        print("‚úÖ Data downloaded successfully")
        return True
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"‚ùå Error downloading data: {e}")
        return False


if not download_data():
    print("Please manually download the dataset")

In [None]:
# --- Step 2: Improved metadata processing ---
def load_and_process_metadata(csv_path="/content/osv_subset/metadata.csv"):
    """Load the metadata CSV and keep only rows with usable coordinates.

    Processing steps:
      1. Coerce lat/lon to numbers; unparseable values become NaN.
      2. Drop rows missing filename, lat or lon.
      3. Drop rows outside lat [-90, 90] / lon [-180, 180].

    Returns the cleaned DataFrame with a fresh 0..n-1 index. Returns None (after
    printing the reason) when the file cannot be read or lacks one of those
    columns.
    """
    try:
        raw = pd.read_csv(csv_path)
    except Exception as e:
        print(f"‚ùå Error loading metadata: {e}")
        return None

    print("Metadata shape:", raw.shape)
    print("Metadata columns:", raw.columns.tolist())
    print("Metadata head:\n", raw.head())

    missing = [col for col in ('filename', 'lat', 'lon') if col not in raw.columns]
    if missing:
        print(f"‚ùå Error loading metadata: missing columns {missing}")
        return None

    # Data validation
    cleaned = raw.assign(
        lat=pd.to_numeric(raw['lat'], errors='coerce'),
        lon=pd.to_numeric(raw['lon'], errors='coerce'),
    ).dropna(subset=['lat', 'lon', 'filename'])
    in_range = cleaned['lat'].between(-90, 90) & cleaned['lon'].between(-180, 180)
    return cleaned[in_range].reset_index(drop=True)

meta = load_and_process_metadata()
if meta is None:
    raise ValueError("Failed to load metadata")
if meta.empty:
    # Every row was filtered out; downstream cells would fail with far less
    # obvious errors (no classes, empty loaders).
    raise ValueError("Metadata contains no rows with a valid filename, lat and lon")

In [None]:
# Improved S2 cell generation with error handling.
# The S2 geometry classes must be imported explicitly. If s2sphere is not
# installed, fall back to an equirectangular lat/lon grid so the pipeline still
# produces integer cell ids.
try:
    from s2sphere import CellId, LatLng
except ImportError:
    CellId = LatLng = None


def latlon_to_s2cell(lat, lon, level=12):  # Increased level for finer granularity
    """Map (lat, lon) in degrees to an integer cell id at `level`.

    With s2sphere available, returns the id of the level-`level` S2 cell that
    contains the point. Otherwise returns the bucket index of a
    2^level x 2^level lat/lon grid, with edge values clamped into the last bin.
    Returns None (after printing the error) if the inputs cannot be converted.
    """
    try:
        lat_f, lon_f = float(lat), float(lon)
        if CellId is None:
            bins = 1 << level
            lat_bin = min(max(int((lat_f + 90.0) / 180.0 * bins), 0), bins - 1)
            lon_bin = min(max(int((lon_f + 180.0) / 360.0 * bins), 0), bins - 1)
            return lat_bin * bins + lon_bin
        ll = LatLng.from_degrees(lat_f, lon_f)
        return CellId.from_lat_lng(ll).parent(level).id()
    except Exception as e:
        print(f"Error converting lat={lat}, lon={lon} to S2 cell: {e}")
        return None

# Attach a cell id to every row; rows whose coordinates could not be converted
# get None and are dropped. `assign` returns a new frame, which avoids
# writing into a filtered slice of the metadata.
meta = meta.assign(
    s2_cell=[latlon_to_s2cell(lat, lon) for lat, lon in zip(meta["lat"], meta["lon"])]
)
n_rows_before = len(meta)
meta = meta.dropna(subset=['s2_cell'])  # Remove invalid S2 cells
if meta.empty:
    raise ValueError(f"Could not compute an S2 cell for any of the {n_rows_before} metadata rows")


In [None]:
# Build items list with validation: keep metadata rows whose image file exists
# and remember the filenames that are missing on disk.
images_root = "/content/osv_subset/images"
items, missing_files = [], []
for record in meta.itertuples():
    candidate_path = os.path.join(images_root, record.filename)
    if not os.path.exists(candidate_path):
        missing_files.append(record.filename)
        continue
    items.append({
        "img_path": candidate_path,
        "lat": float(record.lat),
        "lon": float(record.lon),
        "s2_cell": int(record.s2_cell),
    })

if missing_files:
    print(f"‚ö†Ô∏è Found {len(missing_files)} missing image files")

print(f"‚úÖ Processed {len(items)} valid items")

In [None]:
# Create label mapping for S2 cells: every distinct cell id gets a dense integer
# label, assigned in ascending cell-id order.
unique_s2_cells = sorted({entry["s2_cell"] for entry in items})
s2_to_label = dict(zip(unique_s2_cells, range(len(unique_s2_cells))))
label_to_s2 = dict(enumerate(unique_s2_cells))

In [None]:
# Update items with labels: attach each item's dense class label in place.
for entry in items:
    entry["label"] = s2_to_label[entry["s2_cell"]]

In [None]:
# Improved train/val split with stratification: first count samples per class.
from collections import Counter

label_counts = Counter(entry["label"] for entry in items)

In [None]:
# Stratified ~85/15 split per class. max(1, ...) keeps at least one training
# example per class, so a class with a single sample goes entirely to train;
# only classes with two or more samples appear in validation.
from collections import defaultdict

random.shuffle(items)

# Group once (O(n)) instead of rescanning all items for every class.
items_by_label = defaultdict(list)
for item in items:
    items_by_label[item["label"]].append(item)

train_items, val_items = [], []
for label in label_counts.keys():
    label_items = items_by_label[label]
    split = max(1, int(0.85 * len(label_items)))
    train_items.extend(label_items[:split])
    val_items.extend(label_items[split:])

print(f"Train items: {len(train_items)}, Val items: {len(val_items)}")
print(f"Number of classes: {len(unique_s2_cells)}")
if not val_items:
    print("Warning: validation split is empty (every class has a single sample)")


In [None]:
# --- Step 3: Enhanced Dataset with better augmentations ---
class GeoImageDataset(Dataset):
    """Map-style dataset yielding (image, index, label, lat, lon) tuples.

    Each item is a dict with "img_path", "label", "lat" and "lon". `transform`
    may be either of:
      * an albumentations transform or pipeline, called as
        transform(image=np_array)["image"];
      * any callable taking a PIL image (e.g. a torchvision transform).
    The calling convention is detected from the transform object itself, so an
    albumentations validation pipeline works with is_training=False.
    """

    def __init__(self, items, transform=None, is_training=True):
        self.items = items
        self.transform = transform
        # Kept for API compatibility; no longer used to pick the calling convention.
        self.is_training = is_training

    def __len__(self):
        return len(self.items)

    def _apply_transform(self, img):
        """Run self.transform on `img` using the calling convention it expects."""
        if self.transform is None:
            return img
        # albumentations only accepts keyword data (image=...) and returns a dict.
        if type(self.transform).__module__.startswith("albumentations"):
            return self.transform(image=np.array(img))["image"]
        return self.transform(img)

    def __getitem__(self, idx):
        try:
            item = self.items[idx]
            with Image.open(item["img_path"]) as raw:
                img = raw.convert("RGB")
            img = self._apply_transform(img)
            return img, idx, item["label"], item["lat"], item["lon"]
        except Exception as e:
            print(f"Error loading image at index {idx}: {e}")
            # Return a dummy image so one bad file does not abort the epoch.
            # NOTE(review): the placeholder carries label 0 and (0.0, 0.0), so it
            # is silently trained/evaluated as class 0.
            dummy_img = torch.zeros(3, 384, 384)
            return dummy_img, idx, 0, 0.0, 0.0


In [None]:
# Enhanced augmentations using Albumentations
def _random_resized_crop(size, **kwargs):
    """Build A.RandomResizedCrop producing size x size crops on any albumentations version.

    Newer releases take `size=(h, w)`, while older ones only accept
    `height=`/`width=` and reject `size`. Passing both spellings together relies
    on version-specific deprecation handling, so try the modern form first and
    fall back to the legacy form.
    """
    try:
        return A.RandomResizedCrop(size=(size, size), **kwargs)
    except (TypeError, ValueError):
        return A.RandomResizedCrop(height=size, width=size, **kwargs)


def get_train_transforms(size=384):
    """Training pipeline.

    Steps: random resized crop to size x size, flips/rotations, colour jitter,
    blur, fog and sun-flare, then ImageNet normalisation and conversion to a
    CHW tensor.
    """
    # NOTE(review): VerticalFlip / RandomRotate90 create upside-down or sideways
    # street scenes that never occur at inference; consider dropping them.
    return A.Compose([
        _random_resized_crop(size, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.RandomRotate90(p=0.3),
        A.OneOf([
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.8),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.8),
        ], p=0.9),
        A.OneOf([
            A.GaussianBlur(blur_limit=(3, 7), p=0.5),
            A.MotionBlur(blur_limit=(3, 7), p=0.5),
        ], p=0.3),
        A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.3, p=0.2),
        A.RandomSunFlare(flare_roi=(0, 0, 1, 0.5), angle_lower=0.5, p=0.2),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])


def get_val_transforms(size=384):
    """Validation pipeline: deterministic resize to size x size, ImageNet normalisation, CHW tensor."""
    return A.Compose([
        # Resize takes height/width; `size=` is not one of its arguments.
        A.Resize(height=size, width=size),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ])

train_ds, val_ds = (
    GeoImageDataset(train_items, transform=get_train_transforms(384), is_training=True),
    GeoImageDataset(val_items, transform=get_val_transforms(384), is_training=False),
)

In [None]:

# Weighted sampling for imbalanced classes
def get_weighted_sampler(dataset):
    """Return a sampler that draws every class equally often in expectation.

    Each item is weighted 1 / (number of items sharing its label). The sampler
    yields len(dataset.items) indices per epoch, with replacement (the
    WeightedRandomSampler default).
    """
    labels = [entry["label"] for entry in dataset.items]

    tally = {}
    for lbl in labels:
        tally[lbl] = tally.get(lbl, 0) + 1

    weights = [1.0 / tally[lbl] for lbl in labels]
    return WeightedRandomSampler(weights, len(weights))

# Class-balanced sampling for training; validation is read in order.
train_sampler = get_weighted_sampler(train_ds)
# The model's descriptor head uses BatchNorm1d, which raises in train mode on a
# batch of one sample. The sampler yields len(train_ds) indices per epoch, so
# drop the trailing batch only when it would hold exactly one sample.
train_drop_last = len(train_ds) > 16 and len(train_ds) % 16 == 1
train_loader = DataLoader(train_ds, batch_size=16, sampler=train_sampler, num_workers=2,
                          pin_memory=True, drop_last=train_drop_last)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)


In [None]:

# --- Step 4: Enhanced Model Definition ---
# Since we can't download the actual geofuse.py, let's define a mock model
class GeoFuse(nn.Module):
    """Geolocation model: timm backbone -> MLP descriptor head -> optional classifier.

    forward(x) returns a 4-tuple (descriptor, logits, backbone_features, None).
    `logits` is None when `num_classes` is falsy; the last slot is a placeholder.

    Args:
        backbone_name: timm model name, created with pretrained weights and no
            classification head (num_classes=0).
        num_classes: number of output classes; falsy disables the classifier.
        embed_dim: descriptor dimensionality.
        input_size: spatial size of the dummy image used to infer the backbone's
            output width.
    """

    def __init__(self, backbone_name="efficientnet_b3", num_classes=None, embed_dim=512, input_size=384):
        super().__init__()
        import timm
        self.backbone = timm.create_model(backbone_name, pretrained=True, num_classes=0)

        backbone_out_dim = self._infer_backbone_dim(input_size)

        self.descriptor_head = nn.Sequential(
            nn.Linear(backbone_out_dim, embed_dim),
            nn.BatchNorm1d(embed_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(embed_dim, embed_dim)
        )

        self.classifier = nn.Linear(embed_dim, num_classes) if num_classes else None
        self.embed_dim = embed_dim

    def _infer_backbone_dim(self, input_size):
        """Return the backbone's output width, measured with one dummy forward pass.

        The pass runs in eval mode: in train mode BatchNorm layers would update
        their pretrained running statistics with random noise, even under
        torch.no_grad(). The previous mode is restored afterwards.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, input_size, input_size)
                return self.backbone(dummy_input).shape[1]
        finally:
            self.backbone.train(was_training)

    def forward(self, x):
        features = self.backbone(x)
        descriptor = self.descriptor_head(features)

        logits = None
        if self.classifier is not None:
            logits = self.classifier(descriptor)

        return descriptor, logits, features, None

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# One output class per S2 cell seen in the data.
model = GeoFuse(backbone_name="efficientnet_b3", num_classes=len(unique_s2_cells), embed_dim=512)
model = model.to(device)

In [None]:
# Loss functions and optimizer: label-smoothed cross-entropy for the cell
# classifier, and AdamW over all model parameters.
criterion_cls = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, betas=(0.9, 0.999), weight_decay=1e-4)

In [None]:
# Learning rate scheduler: cosine annealing with warm restarts. The first cycle
# spans 5 scheduler steps, each later cycle is twice as long, and the LR floor
# is 1e-6.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2, eta_min=1e-6)

In [None]:
# --- Step 5: Improved Hard Negative Miner ---
class HardNegativeMiner:
    """Find hard negatives by nearest-neighbour search over L2-normalised descriptors.

    extract_descriptors() embeds a dataset and builds a FAISS inner-product index
    (cosine similarity, since the vectors are unit-norm). mine_hard_negatives()
    then returns, for each dataset index, its most similar other samples.
    """

    def __init__(self, model, device, batch_size=32):
        self.model = model
        self.device = device
        self.batch_size = batch_size
        self.index = None
        self.descriptors = None
        self.labels = None

    def extract_descriptors(self, dataset):
        """Embed every sample of `dataset` in order, store descriptors and labels, and build the index.

        Expects dataset items shaped like (img, idx, label, lat, lon).
        NOTE(review): if `dataset` applies random training augmentations, the
        descriptors are computed on augmented views.
        The model's previous train/eval mode is restored afterwards.
        Returns the (N, D) float32 descriptor matrix.
        """
        was_training = self.model.training
        self.model.eval()
        loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=2)

        descriptors, labels = [], []
        try:
            with torch.no_grad():
                for imgs, _idxs, lbls, _lats, _lons in tqdm(loader, desc="Extracting descriptors"):
                    desc, _, _, _ = self.model(imgs.to(self.device))
                    descriptors.append(F.normalize(desc, p=2, dim=1).cpu())
                    labels.extend(lbls.numpy())
        finally:
            self.model.train(was_training)

        self.descriptors = np.ascontiguousarray(torch.cat(descriptors, dim=0).numpy(), dtype="float32")
        self.labels = np.array(labels)

        # Inner product on unit-norm vectors equals cosine similarity.
        self.index = faiss.IndexFlatIP(self.descriptors.shape[1])
        self.index.add(self.descriptors)

        return self.descriptors

    def mine_hard_negatives(self, top_k=20, same_class_ratio=0.3):
        """Return {sample index: [negative indices]}, most similar first.

        Candidates come from each sample's 3 * top_k nearest neighbours, with the
        sample itself excluded. Up to int(top_k * same_class_ratio) negatives
        share the sample's label (intra-class); the rest have other labels.
        Each list holds at most top_k entries.
        Raises ValueError if extract_descriptors() has not been called.
        """
        if self.index is None:
            raise ValueError("Must extract descriptors first")

        num_samples = len(self.descriptors)
        # FAISS pads results with -1 when k exceeds the index size; -1 would then
        # silently index the last sample, so never ask for more than exist.
        k = min(top_k * 3, self.index.ntotal)
        if num_samples == 0 or k == 0:
            return {}

        # One batched search instead of one FAISS call per sample.
        _, neighbours = self.index.search(self.descriptors, k)

        max_same = int(top_k * same_class_ratio)
        hard_negatives = {}
        for i, row in enumerate(neighbours):
            current_label = self.labels[i]
            same_class_negs, diff_class_negs = [], []

            for idx in row:
                if idx < 0 or idx == i:
                    continue
                if self.labels[idx] == current_label:
                    same_class_negs.append(int(idx))
                else:
                    diff_class_negs.append(int(idx))

            # Mix same-class and different-class negatives
            num_same_class = min(len(same_class_negs), max_same)
            num_diff_class = min(len(diff_class_negs), top_k - num_same_class)
            hard_negatives[i] = same_class_negs[:num_same_class] + diff_class_negs[:num_diff_class]

        return hard_negatives

# Initialize hard negative miner. It keeps a reference to `model`, so each
# re-mining round uses the current weights.
hn_miner = HardNegativeMiner(model=model, device=device, batch_size=32)


In [None]:

# --- Step 6: Enhanced Training Loop ---
def _hard_negative_loss(model, desc, idxs, hard_negatives, dataset, num_negs=3, margin=0.5):
    """Hinge loss pushing each anchor's cosine similarity to its mined negatives below 1 - margin.

    Args:
        desc: L2-normalised batch descriptors; row r belongs to dataset index idxs[r].
        hard_negatives: {dataset index: [candidate negative dataset indices]}.
        dataset: source of negative images, indexed by those dataset indices.

    Up to `num_negs` candidates per anchor are sampled. Their images are embedded
    in ONE batched forward pass and treated as constants. The "positive" is the
    anchor itself (cosine 1), so each pair contributes
    relu(cos(anchor, neg) - 1 + margin).
    Returns the sum over pairs divided by the batch size, as a scalar tensor
    (0 when no anchor has negatives).
    """
    anchor_rows, neg_imgs = [], []
    for row, anchor_idx in enumerate(idxs.tolist()):
        candidates = hard_negatives.get(anchor_idx, [])
        if not candidates:
            continue
        for neg_idx in random.sample(candidates, min(num_negs, len(candidates))):
            anchor_rows.append(row)
            neg_imgs.append(dataset[int(neg_idx)][0])

    if not neg_imgs:
        return desc.new_zeros(())

    # The descriptor head contains BatchNorm1d, which raises in train mode on a
    # batch of one and would otherwise fold the negatives into its running
    # statistics. Embed negatives in eval mode, then restore the previous mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            neg_desc, _, _, _ = model(torch.stack(neg_imgs).to(desc.device))
    finally:
        model.train(was_training)
    neg_desc = F.normalize(neg_desc, p=2, dim=1)

    neg_sim = (desc[anchor_rows] * neg_desc).sum(dim=1)
    return F.relu(neg_sim - 1.0 + margin).sum() / len(idxs)


def train_epoch(model, train_loader, optimizer, criterion_cls, hard_negatives=None, epoch=0,
                alpha=0.2, num_negs=3, margin=0.5):
    """Run one training epoch and return averaged metrics.

    Per-batch loss = criterion_cls(logits, labels) + alpha * hard-negative hinge
    loss; the hinge term is only computed when `hard_negatives` is given.
    Negative images are read from `train_loader.dataset`, so the mined indices
    must refer to that dataset's ordering. Gradients are clipped to norm 1.0.

    Returns {"loss", "cls_loss", "ret_loss"} (means over batches) and
    "accuracy" (% of training samples classified correctly).
    """
    dev = next(model.parameters()).device
    dataset = train_loader.dataset
    model.train()

    total_loss = total_cls_loss = total_ret_loss = 0.0
    correct = total = num_batches = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for imgs, idxs, labels, _lats, _lons in pbar:
        imgs, labels = imgs.to(dev), labels.to(dev)

        # Forward pass
        desc, logits, _, _ = model(imgs)
        desc = F.normalize(desc, p=2, dim=1)

        # Classification loss
        loss_cls = criterion_cls(logits, labels)

        # Hard negative mining loss
        if hard_negatives is not None:
            loss_retrieval = _hard_negative_loss(model, desc, idxs, hard_negatives, dataset,
                                                 num_negs=num_negs, margin=margin)
        else:
            loss_retrieval = loss_cls.new_zeros(())

        total_batch_loss = loss_cls + alpha * loss_retrieval

        # Backward pass
        optimizer.zero_grad()
        total_batch_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Statistics
        num_batches += 1
        total_loss += total_batch_loss.item()
        total_cls_loss += loss_cls.item()
        total_ret_loss += loss_retrieval.item()

        preds = logits.argmax(1)
        correct += (preds == labels).sum().item()
        total += len(labels)

        pbar.set_postfix({
            "loss": total_batch_loss.item(),
            "cls_loss": loss_cls.item(),
            "ret_loss": loss_retrieval.item(),
            "acc": 100 * correct / total
        })

    denom = max(num_batches, 1)
    return {
        "loss": total_loss / denom,
        "cls_loss": total_cls_loss / denom,
        "ret_loss": total_ret_loss / denom,
        "accuracy": 100 * correct / total if total else 0.0
    }

def validate_epoch(model, val_loader, criterion_cls):
    """Evaluate classification loss and accuracy over `val_loader` without gradients.

    Returns {"loss": mean batch loss, "accuracy": % correct,
    "predictions": [...], "labels": [...]}. An empty loader (for example when no
    class had enough samples for a validation split) gives loss NaN and
    accuracy 0.0 instead of raising ZeroDivisionError.
    """
    dev = next(model.parameters()).device
    model.eval()
    total_loss = 0.0
    correct = total = num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for imgs, _idxs, labels, _lats, _lons in tqdm(val_loader, desc="Validation"):
            imgs, labels = imgs.to(dev), labels.to(dev)

            _desc, logits, _, _ = model(imgs)
            loss = criterion_cls(logits, labels)

            preds = logits.argmax(1)
            correct += (preds == labels).sum().item()
            total += len(labels)
            total_loss += loss.item()
            num_batches += 1

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = 100 * correct / total if total else 0.0
    avg_loss = total_loss / num_batches if num_batches else float("nan")

    return {
        "loss": avg_loss,
        "accuracy": accuracy,
        "predictions": all_preds,
        "labels": all_labels
    }

In [None]:

# --- Training Loop ---
# Hard negatives are re-mined from the current model every 3 epochs (epochs 1,
# 4, 7, ...). Whenever validation accuracy improves, a checkpoint is written to
# best_geofuse_osv5m.pth.
num_epochs = 15
best_val_acc = 0
hard_negatives = None
train_history, val_history = [], []

print("üöÄ Starting training...")

for epoch in range(num_epochs):
    if not epoch % 3:
        print(f"\nüîç Mining hard negatives for epoch {epoch+1}...")
        try:
            hn_miner.extract_descriptors(train_ds)
            hard_negatives = hn_miner.mine_hard_negatives(top_k=15, same_class_ratio=0.3)
        except Exception as e:
            # Mining is best-effort: fall back to the classification loss only.
            print(f"‚ö†Ô∏è Hard negative mining failed: {e}")
            hard_negatives = None
        else:
            print(f"‚úÖ Mined hard negatives for {len(hard_negatives)} samples")

    train_metrics = train_epoch(model, train_loader, optimizer, criterion_cls, hard_negatives, epoch)
    train_history.append(train_metrics)

    val_metrics = validate_epoch(model, val_loader, criterion_cls)
    val_history.append(val_metrics)

    # The scheduler advances once per epoch.
    scheduler.step()

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Train - Loss: {train_metrics['loss']:.4f}, Acc: {train_metrics['accuracy']:.2f}%")
    print(f"Val   - Loss: {val_metrics['loss']:.4f}, Acc: {val_metrics['accuracy']:.2f}%")
    print(f"LR: {scheduler.get_last_lr()[0]:.2e}")

    if val_metrics['accuracy'] > best_val_acc:
        best_val_acc = val_metrics['accuracy']
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_val_acc': best_val_acc,
                's2_to_label': s2_to_label,
                'label_to_s2': label_to_s2,
            },
            "best_geofuse_osv5m.pth",
        )
        print(f"üíæ New best model saved! Val Acc: {best_val_acc:.2f}%")


In [None]:
# Final model save: last-epoch weights, optimizer/scheduler state, label maps,
# and the complete per-epoch metric history.
torch.save(
    {
        'epoch': num_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_acc': val_metrics['accuracy'],
        's2_to_label': s2_to_label,
        'label_to_s2': label_to_s2,
        'train_history': train_history,
        'val_history': val_history,
    },
    "final_geofuse_osv5m.pth",
)

print(
    "\nüéâ Training completed!",
    f"Best validation accuracy: {best_val_acc:.2f}%",
    f"Final validation accuracy: {val_metrics['accuracy']:.2f}%",
    sep="\n",
)


In [None]:
# --- Evaluation and Visualization ---
def plot_training_history(train_history, val_history):
    """Draw a 2x2 grid of per-epoch curves, save it to training_history.png, and show it.

    Panels:
      * top-left: train vs val loss;
      * top-right: train vs val accuracy;
      * bottom-left: train classification loss;
      * bottom-right: train retrieval loss.
    NOTE(review): train "loss" includes the weighted retrieval term while val
    "loss" is classification-only, so the top-left panel is not like-for-like.
    """
    def series(history, key):
        return [record[key] for record in history]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    panels = [
        (axes[0, 0], 'Loss Curves', 'Loss', [
            (series(train_history, 'loss'), 'Train Loss', 'blue'),
            (series(val_history, 'loss'), 'Val Loss', 'red'),
        ]),
        (axes[0, 1], 'Accuracy Curves', 'Accuracy (%)', [
            (series(train_history, 'accuracy'), 'Train Acc', 'blue'),
            (series(val_history, 'accuracy'), 'Val Acc', 'red'),
        ]),
        (axes[1, 0], 'Classification Loss', 'Loss', [
            (series(train_history, 'cls_loss'), 'Classification Loss', 'green'),
        ]),
        (axes[1, 1], 'Retrieval Loss', 'Loss', [
            (series(train_history, 'ret_loss'), 'Retrieval Loss', 'orange'),
        ]),
    ]

    for ax, title, ylabel, curves in panels:
        for values, curve_label, color in curves:
            ax.plot(values, label=curve_label, color=color)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:

# Plot training history and point at the saved artifacts.
plot_training_history(train_history, val_history)

print(
    "‚úÖ Training analysis completed. Check 'training_history.png' for visualizations.",
    "‚úÖ Models saved as 'best_geofuse_osv5m.pth' and 'final_geofuse_osv5m.pth'",
    sep="\n",
)