# 🌍 GeoViT: A Convolutional-Transformer Model for Geolocation Estimation

Welcome to the GeoViT project notebook!

This notebook presents the training, evaluation, and experimentation pipeline for **GeoViT**, a neural network model designed to **predict geographic locations from street-view images**. The model takes inspiration from the popular game *GeoGuessr*, which is played on Google Street View imagery, and is trained on the [OpenStreetView-5M dataset](https://huggingface.co/datasets/osv5m/osv5m).

🖊️ Authors: Alan Tran and Caleb Wolf

---

## 📌 Project Goals

1. **Train** a hybrid convolutional-transformer model that can learn geospatial patterns from street-level imagery.
2. **Evaluate** the model using geodesic distance-based metrics.
3. **Experiment** with:
   - Vision Transformer ablations (layers & attention heads)
   - Robustness to reduced image context (square vs 3:2 aspect ratio)
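A minimal sketch of a distance-based metric for goal 2. It assumes a spherical Earth (mean radius 6371 km), so the haversine great-circle distance stands in for the true ellipsoidal geodesic. The helper names `haversine_km` and `accuracy_within`, and the threshold values, are illustrative assumptions rather than the project's actual evaluation code:

```python
import torch

EARTH_RADIUS_KM = 6371.0  # mean Earth radius; assumes a spherical Earth


def haversine_km(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Great-circle distance in km between (lat, lon) pairs given in degrees.

    pred and target have shape (N, 2) with columns (latitude, longitude).
    """
    lat1, lon1 = torch.deg2rad(pred[:, 0]), torch.deg2rad(pred[:, 1])
    lat2, lon2 = torch.deg2rad(target[:, 0]), torch.deg2rad(target[:, 1])
    dlat = lat2 - lat1
    dlon = lon2 - lon1
    # Haversine formula: a = sin^2(dlat/2) + cos(lat1) cos(lat2) sin^2(dlon/2)
    a = torch.sin(dlat / 2) ** 2 + torch.cos(lat1) * torch.cos(lat2) * torch.sin(dlon / 2) ** 2
    # Clamp guards against values slightly outside [0, 1] from floating-point error
    return 2 * EARTH_RADIUS_KM * torch.asin(torch.sqrt(a.clamp(0.0, 1.0)))


def accuracy_within(pred, target, thresholds_km=(1, 25, 200, 750, 2500)):
    """Fraction of predictions within each distance threshold (thresholds are illustrative)."""
    dist = haversine_km(pred, target)
    return {t: (dist <= t).float().mean().item() for t in thresholds_km}


pred = torch.tensor([[0.0, 0.0], [48.8566, 2.3522]])
target = torch.tensor([[0.0, 90.0], [48.8566, 2.3522]])
print(haversine_km(pred, target))  # first entry is roughly 10,007.5 km (a quarter of the equator)
print(accuracy_within(pred, target))
```

Averaging `haversine_km` over a test set gives a mean or median error in kilometres, while `accuracy_within` reports how often the model lands within a given radius.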

---

## 🧠 Model Overview

- **Convolutional Frontend:** Captures local texture and object-level features.
- **Vision Transformer (ViT):** Captures global spatial dependencies.
- **Output:** Regressed GPS coordinates (latitude, longitude).
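One way the three components above could fit together, as a minimal PyTorch sketch. The class name `GeoViTSketch`, the layer sizes, the 16x-downsampling convolutional frontend, the learned positional embeddings, and the CLS-token readout are all assumptions for illustration, not the project's actual architecture:

```python
import torch
from torch import nn


class GeoViTSketch(nn.Module):
    """Hypothetical conv-frontend + transformer regressor; all sizes are assumptions."""

    def __init__(self, embed_dim=128, num_layers=4, num_heads=4, max_tokens=1024):
        super().__init__()
        # Convolutional frontend: two stride-2 convs followed by a stride-4 conv
        # reduce spatial resolution 16x and produce embed_dim channels per location.
        self.frontend = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, embed_dim, kernel_size=4, stride=4),
        )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Learned positional embeddings for up to max_tokens locations (+1 for CLS)
        self.pos_embed = nn.Parameter(torch.zeros(1, max_tokens + 1, embed_dim))
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, 2)  # regresses (latitude, longitude)

    def forward(self, x):
        feats = self.frontend(x)                   # (B, D, H/16, W/16)
        tokens = feats.flatten(2).transpose(1, 2)  # (B, N, D), one token per location
        n = tokens.shape[1]
        if n > self.pos_embed.shape[1] - 1:
            raise ValueError(f"{n} tokens exceed max_tokens")
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed[:, : n + 1]
        encoded = self.encoder(tokens)             # global self-attention over all tokens
        return self.head(self.norm(encoded[:, 0]))  # predict from the CLS token


model = GeoViTSketch()
images = torch.randn(2, 3, 256, 384)  # batch of two 3:2 RGB images
coords = model(images)                # shape (2, 2): predicted (latitude, longitude)
```

Because the positional embeddings here are indexed by flattened token position, square and 3:2 inputs assign the same embedding to different spatial locations; a fuller implementation would likely use 2D embeddings interpolated to the feature-map size.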

---

## 🧪 Experiments

### ✅ Experiment 1: ViT Ablation
- Reduce the number of transformer layers and attention heads
- Assess how much the transformer structure contributes to geolocation performance
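The layer/head ablation could be set up along these lines with PyTorch's built-in `nn.TransformerEncoder`. The embedding size, the grid of (layers, heads) pairs, and the helpers `build_encoder` and `count_params` are hypothetical. The sketch also makes one property explicit: at a fixed embedding size, `nn.MultiheadAttention` splits the same projection matrices across heads, so changing the head count leaves the parameter count unchanged, while removing layers shrinks it proportionally.

```python
import torch
from torch import nn


def build_encoder(embed_dim: int, num_layers: int, num_heads: int) -> nn.TransformerEncoder:
    """Transformer encoder stack; embed_dim must be divisible by num_heads."""
    layer = nn.TransformerEncoderLayer(
        d_model=embed_dim, nhead=num_heads, dim_feedforward=4 * embed_dim, batch_first=True
    )
    return nn.TransformerEncoder(layer, num_layers=num_layers)


def count_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# Hypothetical ablation grid: a baseline depth/head count versus reduced settings
embed_dim = 256
grid = [(6, 8), (3, 8), (6, 4), (3, 4), (1, 1)]
for num_layers, num_heads in grid:
    enc = build_encoder(embed_dim, num_layers, num_heads)
    tokens = torch.randn(2, 10, embed_dim)  # (batch, tokens, embed_dim)
    out = enc(tokens)
    print(f"layers={num_layers} heads={num_heads} "
          f"params={count_params(enc):,} out={tuple(out.shape)}")
```

Since heads do not change capacity here, a heads-only ablation mainly tests how attention is partitioned, whereas a layers ablation also changes model size.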

### ✅ Experiment 2: Robustness to Cropped Context
- Evaluate the model on square images (less context)
- Compare against the standard 3:2 aspect-ratio input
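For a 3:2 frame of height H, a square crop of side H keeps H² / (1.5 H²) = 2/3 of the pixels. A minimal sketch of such a crop on image tensors, assuming the square is taken from the center (the notebook does not say where it is cut from) and using the hypothetical helper `center_crop_to_square`:

```python
import torch


def center_crop_to_square(img: torch.Tensor) -> torch.Tensor:
    """Center-crop a (C, H, W) or (B, C, H, W) tensor to a square of side min(H, W)."""
    h, w = img.shape[-2:]
    side = min(h, w)
    top = (h - side) // 2
    left = (w - side) // 2
    return img[..., top : top + side, left : left + side]


frame = torch.rand(3, 256, 384)  # (C, H, W) 3:2 frame
square = center_crop_to_square(frame)
print(square.shape)  # torch.Size([3, 256, 256])
```

`transforms.CenterCrop` from torchvision is a ready-made alternative when the input size is fixed.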

---

In [None]:
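# Import libraries
import random

import numpy as np
import torch
from torch import nn
from torchvision import datasets, transforms  # torchvision.datasets, distinct from the Hugging Face `datasets` package used below

%matplotlib inline
import matplotlib.pyplot as plt

# Seed every random number generator used in the notebook for reproducibility
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)           # seeds the CPU and all CUDA generators
torch.cuda.manual_seed_all(seed)  # explicit; silently ignored if CUDA is unavailable

# Disable cuDNN autotuning and request deterministic kernels (may be slower)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True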
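For `DataLoader`s, the PyTorch reproducibility notes additionally recommend passing a seeded `torch.Generator` and a `worker_init_fn`, so that shuffling order and per-worker NumPy/`random` state are tied to an explicit seed. A sketch of that pattern, where `make_loader` is a hypothetical helper and a toy `TensorDataset` stands in for OSV-5M:

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id: int) -> None:
    # Each worker's torch seed is derived from the loader's base seed; reuse it
    # to seed NumPy and Python's random module inside that worker process.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, batch_size: int, seed: int = 42, num_workers: int = 0) -> DataLoader:
    g = torch.Generator()
    g.manual_seed(seed)  # fixes the shuffling order independently of the global RNG
    return DataLoader(dataset, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, worker_init_fn=seed_worker, generator=g)


# Two loaders built with the same seed shuffle the toy dataset identically
toy = TensorDataset(torch.arange(10))
order_a = [batch[0].tolist() for batch in make_loader(toy, batch_size=4)]
order_b = [batch[0].tolist() for batch in make_loader(toy, batch_size=4)]
print(order_a == order_b)
```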

In [None]:
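# Download the OSV-5M dataset from the Hugging Face Hub and extract its archives
import os
import zipfile

from huggingface_hub import snapshot_download
from datasets import load_dataset

data_dir = "datasets/osv5m"
snapshot_download(repo_id="osv5m/osv5m", local_dir=data_dir, repo_type="dataset")

# Extract every .zip archive in place, then delete it to save disk space.
# The archive is removed only after the ZipFile handle has been closed.
for root, dirs, files in os.walk(data_dir):
    for file in files:
        if file.endswith(".zip"):
            zip_path = os.path.join(root, file)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(root)
            os.remove(zip_path)

# Load through the Hugging Face `datasets` library; full=True would load the complete metadata
dataset = load_dataset("osv5m/osv5m", full=False)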
