# 🌍 GeoViT: A Convolutional-Transformer Model for Geolocation Estimation

Welcome to the GeoViT project notebook!

This notebook presents the training, evaluation, and experimentation pipeline for **GeoViT**, a neural network model designed to **predict geographic locations from street-level images**. The model takes inspiration from the popular game *GeoGuessr* and is trained using the [OpenStreetView-5M dataset](https://huggingface.co/datasets/osv5m/osv5m), a large collection of crowd-sourced street-view photos.

🖊️ Authors: Alan Tran and Caleb Wolf

---

## 📌 Project Goals

1. **Train** a hybrid convolutional-transformer model that can learn geospatial patterns from street-level imagery.
2. **Evaluate** the model using geodesic distance-based metrics.
3. **Experiment** with:
   - Vision Transformer ablations (layers & attention heads)
   - Robustness to reduced image context (square vs 3:2 aspect ratio)

---

## 🧠 Model Overview

- **Convolutional Frontend:** Captures local texture and object-level features.
- **Vision Transformer (ViT):** Captures global spatial dependencies.
- **Output:** Class scores over S2 geographic cells (level 10); the highest-scoring cell is the predicted location

---

## 🧪 Experiments

### ✅ Experiment 1: ViT Ablation
- Reduce number of transformer layers and attention heads
- Assess contribution of transformer structure to geolocation performance

### ✅ Experiment 2: Robustness to Cropped Context
- Evaluate model on square images (less context)
- Compare against standard aspect ratio input

---

In [None]:
import os
import glob
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import torchvision.models as models
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm
import timm
from s2sphere import LatLng, CellId
import heapq

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Import data: locations of the annotation CSVs and image folders
DATA_ROOT = './osv5m/'
TRAIN_CSV, TEST_CSV = (
    os.path.join(DATA_ROOT, fname) for fname in ('train_mini.csv', 'test_mini.csv')
)
TRAIN_IMG_DIR, TEST_IMG_DIR = (
    os.path.join(DATA_ROOT, dname) for dname in ('train_images', 'test_images')
)

# Set global parameters
EPOCHS, BATCH_SIZE = 10, 64
L = 10                 # S2 cell level used to build the class labels
LABEL_COL = 's2_cell'  # column that holds each row's S2 cell token
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# New column maps images to an s2 grid cell
def add_s2_cell_column(
    df: pd.DataFrame,
    lat_col: str = 'latitude',
    lon_col: str = 'longitude',
    level: int = 10,
    new_col: str = 's2_cell'
) -> pd.DataFrame:
    """Tag every row with the S2 cell that contains its coordinates.

    Args:
        df: Input frame. It is copied, never modified in place.
        lat_col: Column holding latitude in degrees.
        lon_col: Column holding longitude in degrees.
        level: S2 level of the parent cell to report.
        new_col: Name of the column to write.

    Returns:
        A copy of `df` whose `new_col` holds the S2 token (hex string) of the
        level-`level` cell containing each (lat_col, lon_col) point.
    """
    def _cell_token(lat, lon):
        leaf_cell = CellId.from_lat_lng(LatLng.from_degrees(lat, lon))
        return leaf_cell.parent(level).to_token()

    result = df.copy()
    result[new_col] = [
        _cell_token(lat, lon)
        for lat, lon in zip(result[lat_col], result[lon_col])
    ]
    return result

# Load the train/test annotation CSVs and tag every row with its level-L S2 cell token
train_df, test_df = (
    add_s2_cell_column(pd.read_csv(csv_file), level=L, new_col=LABEL_COL)
    for csv_file in (TRAIN_CSV, TEST_CSV)
)

# Build a global cell -> class-index mapping from the training rows only
# (pd.factorize keeps order of first appearance), so all splits share one label space.
cells, classes = pd.factorize(train_df[LABEL_COL])
class_to_idx = dict(zip(classes, range(len(classes))))

In [None]:
# Define CNN + ViT hybrid model for geospatial classification
class CNN_ViT_Hybrid(nn.Module):
    """ResNet-50 feature extractor followed by pretrained ViT-Small encoder blocks.

    Forward pass:
    1. The last ResNet feature map is pooled to a 14x14 grid.
    2. The grid is flattened into 196 tokens and projected to the ViT embedding width.
    3. The pretrained ViT position embeddings are added.
    4. The tokens run through the ViT blocks and its final norm.
    5. The tokens are mean-pooled and mapped to `num_classes` logits.

    Args:
        num_classes: Number of output logits (one per geographic cell in this notebook).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Conv feature extractor (ResNet50). features_only=True makes the model
        # return a list of feature maps; forward() uses the last one.
        self.cnn = timm.create_model("resnet50", pretrained=True, features_only=True)
        cnn_out_channels = self.cnn.feature_info[-1]['num_chs']

        # Pretrained ViT-Small. forward() only uses its blocks, pos_embed and
        # final norm; the classification head is discarded.
        self.vit = timm.create_model("vit_small_patch16_224", pretrained=True)
        self.vit.head = nn.Identity()  # remove classifier

        # Fusion: pool the CNN map to a 14x14 token grid (224 / 16 = 14 for
        # ViT-Small/16), then project CNN channels to the ViT width.
        self.pool = nn.AdaptiveAvgPool2d((14, 14))
        self.proj = nn.Linear(cnn_out_channels, self.vit.embed_dim)

        self.dropout = nn.Dropout(p=0.2)  # Dropout regularization, applied before the classifier
        self.classifier = nn.Linear(self.vit.embed_dim, num_classes)

    def forward(self, x):
        """Map an image batch to class logits of shape (B, num_classes)."""
        # Get last feature map from CNN
        x = self.cnn(x)[-1]  # shape (B, C, H, W)

        # Pool to fixed 14 x 14 size
        x = self.pool(x)  # shape (B, C, 14, 14)

        # Flatten and transpose to patch sequence format that matches ViT input
        x = x.flatten(2).transpose(1, 2)  # (B, C, H*W) -> (B, H*W, C)

        # Project to ViT embedding dim
        x = self.proj(x)  # shape (B, 196, D)

        # Add the pretrained patch position embeddings. The class-token slot(s)
        # sit at the front of pos_embed, so take the trailing 196 entries.
        # Without this, self-attention is permutation-invariant and the
        # spatial layout of the feature map is lost.
        x = x + self.vit.pos_embed[:, -x.size(1):, :]

        # Feed through ViT encoder blocks, then the final norm that the
        # pretrained blocks were trained to be followed by
        x = self.vit.blocks(x)
        x = self.vit.norm(x)
        x = x.mean(dim=1)  # Global average pooling over tokens

        return self.classifier(self.dropout(x))
    
# Define the geospatial dataset class
class GeoDataset(Dataset):
    """Image dataset whose labels are the class indices of each row's S2 cell.

    Args:
        csv_path: Annotation CSV. It must contain an `id` column, plus either
            `label_col` or the latitude/longitude columns to derive it from.
        images_root: Directory whose immediate subdirectories hold `<id>.jpg` files.
        class_to_idx: Mapping from S2 cell token to class index, built from the
            training split so that all splits share one label space.
        transforms: Optional callable applied to each loaded PIL image.
        label_col: Column holding the S2 cell token. If the CSV lacks it, the
            column is computed with `add_s2_cell_column` at `level`.
        level: S2 level used when `label_col` has to be computed. It must match
            the level used to build `class_to_idx`.
    """

    def __init__(self, csv_path, images_root, class_to_idx, transforms=None,
                 label_col='s2_cell', level=10):
        # load annotations
        df = pd.read_csv(csv_path)

        # class_to_idx is keyed by S2 tokens, which this notebook only adds to
        # in-memory frames, not the CSV. Derive them here the same way.
        if label_col not in df.columns:
            df = add_s2_cell_column(df, level=level, new_col=label_col)

        # map the cell token to the consistent label index; drop any rows whose
        # cell was never seen in training
        df['label'] = df[label_col].map(class_to_idx)
        df = df[df['label'].notna()].reset_index(drop=True)
        df['label'] = df['label'].astype(int)
        self.df = df

        # class tokens ordered by index; identical for every split built from
        # the same class_to_idx
        self.classes = sorted(class_to_idx, key=class_to_idx.get)

        # build a map from image-ID -> full path (one directory level deep)
        all_files = glob.glob(os.path.join(images_root, '*', '*.jpg'))
        self.id2path = {
            os.path.splitext(os.path.basename(p))[0]: p
            for p in all_files
        }
        self.transforms = transforms

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row    = self.df.iloc[idx]
        img_id = str(row['id'])
        label  = int(row['label'])
        # Close the underlying file promptly; convert() returns an independent image.
        with Image.open(self.id2path[img_id]) as raw:
            img = raw.convert('RGB')
        if self.transforms:
            img = self.transforms(img)
        return img, label

In [None]:
import copy  # used below to give the validation split its own transforms

# Transformations for training data augmentation (better generalization)
train_transforms = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomApply([
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02)
    ], p=0.5),  # apply color variations 50% of the time (simulates lighting changes)
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Transformations for validation/test data (no augmentation)
test_transforms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Instantiate the datasets
train_ds = GeoDataset(TRAIN_CSV, TRAIN_IMG_DIR, class_to_idx, train_transforms)
test_ds  = GeoDataset(TEST_CSV,  TEST_IMG_DIR,  class_to_idx, test_transforms)

# Hold out 10% of the training rows for validation (90% train / 10% validation)
num_val = int(0.1 * len(train_ds))
num_train = len(train_ds) - num_val
train_subset, val_subset = random_split(train_ds, [num_train, num_val], generator=torch.Generator().manual_seed(42))

# Both random_split subsets wrap train_ds and would share its random
# augmentations. Re-point the validation indices at a shallow copy that uses
# the deterministic test-time transforms, so validation metrics are not
# perturbed by flips/jitter. The copy shares df and id2path, which are read-only.
val_view = copy.copy(train_ds)
val_view.transforms = test_transforms
val_subset = torch.utils.data.Subset(val_view, val_subset.indices)

train_loader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_subset,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds,      batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Model
model = CNN_ViT_Hybrid(num_classes=len(train_ds.classes)).to(device)
criterion = nn.CrossEntropyLoss()

# Freeze ResNet
for param in model.cnn.parameters():
    param.requires_grad = False

# Freeze ViT
for param in model.vit.parameters():
    param.requires_grad = False

# Create optimizer only for trainable layers (projection + classifier)
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=1e-4
)

In [None]:
# Get dimensions of training batch: peek at the first batch only (no-op if the loader is empty)
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, labels = first_batch
    print(images.shape, labels.shape)

In [None]:
# Training and evaluation functions
def _resolve_default(value, global_name):
    """Return `value`, or the module-level global `global_name` when value is None.

    The global is looked up at call time, so later reassignments are picked up
    (e.g. the new `optimizer` created when layers are unfrozen).
    """
    return globals()[global_name] if value is None else value


def _epoch_metrics(running_loss, correct, total):
    """Turn summed loss and correct-prediction counts into (mean loss, accuracy)."""
    if total == 0:
        # An empty loader would otherwise surface as a bare ZeroDivisionError.
        raise ValueError("loader yielded no samples; check the dataset and its labels")
    return running_loss / total, correct / total


def train_one_epoch(model, loader, optimizer=None, criterion=None, device=None):
    """Train `model` for one pass over `loader`.

    Args:
        model: Module producing class logits.
        loader: Iterable of (images, labels) batches.
        optimizer, criterion, device: Default to the notebook globals of the
            same name when omitted.

    Returns:
        Tuple (mean per-sample loss, accuracy) over the epoch.

    Raises:
        ValueError: If `loader` yields no samples.
    """
    optimizer = _resolve_default(optimizer, 'optimizer')
    criterion = _resolve_default(criterion, 'criterion')
    device = _resolve_default(device, 'device')

    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # assumes criterion returns a batch mean (CrossEntropyLoss default);
        # weighting by batch size makes the epoch figure a per-sample mean
        running_loss += loss.item() * images.size(0)
        correct += outputs.argmax(1).eq(labels).sum().item()
        total += labels.size(0)
    return _epoch_metrics(running_loss, correct, total)


def evaluate(model, loader, criterion=None, device=None):
    """Evaluate `model` on `loader` without gradient tracking.

    Args:
        model: Module producing class logits.
        loader: Iterable of (images, labels) batches.
        criterion, device: Default to the notebook globals of the same name
            when omitted.

    Returns:
        Tuple (mean per-sample loss, accuracy).

    Raises:
        ValueError: If `loader` yields no samples.
    """
    criterion = _resolve_default(criterion, 'criterion')
    device = _resolve_default(device, 'device')

    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * images.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    return _epoch_metrics(running_loss, correct, total)

In [None]:
best_val_acc = 0.0
best_val_loss = float('inf')
START_EPOCH = 0     # set to the number of already-completed epochs to resume
UNFREEZE_EPOCH = 3  # backbones are fine-tuned end-to-end from this epoch on

# --- Training Loop ---
for epoch in range(START_EPOCH, EPOCHS):
    print(f"\n🌍 Epoch {epoch+1}/{EPOCHS}")

    # Unfreeze the backbones once UNFREEZE_EPOCH is reached. Checking ">=" plus
    # "still frozen" keeps this firing exactly once, even when START_EPOCH is
    # already past the threshold.
    still_frozen = not all(p.requires_grad for p in model.parameters())
    if epoch >= UNFREEZE_EPOCH and still_frozen:
        print("🔓 Unfreezing layers...")
        for param in model.cnn.parameters():
            param.requires_grad = True
        for param in model.vit.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5
        )

    train_loss, train_acc = train_one_epoch(model, train_loader)
    val_loss, val_acc     = evaluate(model, val_loader)

    print(f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f}, Acc: {val_acc:.4f}")

    # Save whenever either metric improves, but track each best independently
    # so an improvement in one never overwrites the other with a worse value.
    improved = val_loss < best_val_loss or val_acc > best_val_acc
    best_val_loss = min(best_val_loss, val_loss)
    best_val_acc = max(best_val_acc, val_acc)
    if improved:
        torch.save(model.state_dict(), f"hybrid_best_model_epoch{epoch+1}.pth")
        print("✅ Saved best model.")

In [None]:
# --- Final Test ---
# Scores the in-memory `model` on the held-out test split. This model holds the
# weights from the last training epoch, not a reloaded "best" checkpoint.
final_metrics = evaluate(model, test_loader)
test_loss, test_acc = final_metrics
print(f"\n✅ Final Test Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}")