### Pips and Imports

In [None]:
# Pip installs: run this cell only once per computer (or once per fresh runtime/environment).
!pip install tqdm
!pip install transformers timm
!pip install --upgrade kagglehub
!pip install -q kaggle

In [None]:
# Imports: run this cell every time you open this notebook.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image
import kagglehub
from tiny_vit import tiny_vit_21m_224
import os
import math
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder
from collections import Counter
from torch.utils.data import WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [None]:
# Download the dataset from KaggleHub and remember where it lives on disk.
# If the dataset was already downloaded by an earlier run it is not downloaded
# again; the call simply returns the existing local directory. Run this cell
# every time, because later cells need `path`.
path = kagglehub.dataset_download("ubitquitin/geolocation-geoguessr-images-50k")

print("Path to dataset files:", path)

In [None]:
# These are the important file paths.

# Root folder of the image dataset. Defaults to the KaggleHub download above;
# point it at your own folder to train on a different dataset.
DATA_PATH = path + '/compressed_dataset'

# Where your own weights are saved to and loaded from.
MODEL_PATH = '/content/tiny_vit_model_geoguessr_(0.0048 - 92.88%).pth'

# Weights of the already-trained model. Start from these if this is your first
# time training. NOTE: this currently points at the same file as MODEL_PATH.
PRETRAINED_PATH = "/content/tiny_vit_model_geoguessr_(0.0048 - 92.88%).pth"

In [None]:
# === Device setup ===
# Run on the GPU (CUDA) when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Transform ===
# Every image is resized, converted to a tensor and normalised before it
# reaches the model.
transform = transforms.Compose([
    # Fixed input resolution (not an aspect ratio); presumably what
    # tiny_vit_21m_224 expects -- confirm against the tiny_vit module.
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# === Dataset ===
# ImageFolder treats each sub-folder of DATA_PATH as one class (one country).
dataset = ImageFolder(DATA_PATH, transform=transform)
class_names = dataset.classes    # country names, indexed by class id
num_classes = len(class_names)   # the model predicts a class id, not a name

# === Train/test split ===
# Hold out 20% of the images so accuracy is measured on pictures the model has
# not been trained on. The split is seeded: without a fixed seed every session
# would draw a different partition, so a checkpoint resumed from an earlier
# session would already have trained on images that are now in the "test" set,
# inflating the reported accuracy.
SPLIT_SEED = 42
total_size = len(dataset)
train_size = int(0.8 * total_size)
test_size = total_size - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# === Class weights (training data only) ===
# Some countries have far more images than others. Each class gets a weight
# inversely proportional to its frequency in the training subset so the sampler
# below draws every country about equally often.
train_indices = train_dataset.indices  # positions in the original dataset
train_targets = [dataset.targets[i] for i in train_indices]
class_counts = Counter(train_targets)

# "Balanced" weighting: n_samples / (n_classes * class_count). The counts come
# from the training subset, so n_samples is train_size rather than total_size.
# Classes missing from the training subset get weight 0.
class_weights = [
    train_size / (num_classes * class_counts[i]) if class_counts[i] > 0 else 0.0
    for i in range(num_classes)
]

sample_weights = [class_weights[label] for label in train_targets]
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)

# === DataLoaders ===
# Serve the datasets in batches. Training batches are drawn through the
# weighted sampler; test batches are read in order.
BATCH_SIZE = 64
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)  # for training
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)  # for testing

# === Model === (the network definition lives here)
class GeoGuessrAI(nn.Module):
    """TinyViT backbone whose classification head is replaced to output `num_classes` scores."""

    def __init__(self, num_classes=124):
        super().__init__()
        backbone = tiny_vit_21m_224(pretrained=True)
        # Replace the backbone's head with a fresh linear layer sized for our
        # label set. Weights saved for one class count cannot be loaded into a
        # model built with a different class count (e.g. 124 vs. 125).
        head_inputs = backbone.head.in_features
        backbone.head = nn.Linear(head_inputs, num_classes)
        self.base_model = backbone

    def forward(self, x):
        """Run the backbone on a batch of images and return its output."""
        return self.base_model(x)

# Build the model with one output per country and move it to the chosen device.
model = GeoGuessrAI(num_classes=num_classes).to(device)

# === Load Pretrained Weights ===
# Read the checkpoint and actually apply it to the model (previously the
# state dict was loaded from disk but never used). `map_location=device` lets a
# checkpoint saved on a GPU be loaded on a CPU-only machine.
# Your own weights can still be loaded later from MODEL_PATH.
state_dict = torch.load(PRETRAINED_PATH, map_location=device)
model.load_state_dict(state_dict)


# === Loss & Optimizer ===
# Cross-entropy compares the model's scores for every country with the correct
# label; the more confidently wrong the model is, the larger the loss.
criterion = nn.CrossEntropyLoss()
# AdamW updates the model's weights from the gradients of the loss. `lr` is the
# starting learning rate and `weight_decay` the regularisation strength.
optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
# Halve the learning rate when the monitored loss has not improved for more
# than `patience` epochs.
# NOTE: `verbose` is deprecated in recent PyTorch releases; drop it if the
# installed version warns about or rejects it.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2, verbose=True)


In [None]:
# Run this cell to load your own weights (do this if you have already started
# training the model and want to pick up from where you left off).
# `map_location=device` remaps the stored tensors onto the current device, so a
# checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

In [None]:
EPOCHS = 1000  # number of full passes over the training data
best_loss = float('inf')  # lowest average training loss seen so far in this run

for epoch in range(EPOCHS):
    # ---- Training phase ----
    model.train()
    running_loss = 0.0  # sum of the per-batch losses for this epoch

    # tqdm only wraps the loader to draw a progress bar.
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for images, labels in loop:
        # Move the batch to the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        # Standard optimisation step: clear old gradients, forward pass,
        # compute the loss, backpropagate, update the weights.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        loop.set_postfix(loss=batch_loss)  # shown next to the progress bar

    avg_loss = running_loss / len(train_loader)
    print(f"\nEpoch {epoch+1} complete. Average loss: {avg_loss:.4f}")

    # ---- Checkpointing ----
    # Overwrite the saved weights only when this epoch's average training loss
    # beats the best one seen so far.
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"New best model saved with loss {best_loss:.4f}")
    else:
        print(f"Loss did not improve. Keeping previous model with loss {best_loss:.4f}")

    # Let the scheduler adjust the learning rate based on this epoch's loss.
    scheduler.step(avg_loss)

    # ---- Evaluation phase ----
    model.eval()
    correct = 0  # number of correct guesses
    total = 0    # number of images evaluated
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            # The index of the highest score is the predicted class id.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Test Accuracy after Epoch {epoch+1}: {accuracy:.2f}%\n")

In [None]:
# Replace this placeholder with the path of the image you want classified.
input_photo_path = 'Enter the image filepath you want'

# Rebuild the dataset object and invert its name -> id mapping so a predicted
# class id can be turned back into a country name.
dataset = ImageFolder(DATA_PATH, transform=transform)
idx_to_class = {
    class_id: country
    for country, class_id in dataset.class_to_idx.items()
}

# === Prediction Function ===
def test_image(image_path, model, transform, device, idx_to_class):
    """Predict the class name for a single image file.

    Args:
        image_path: Path of the image to classify.
        model: The network to run; it is switched to evaluation mode.
        transform: Preprocessing applied to the image; its result is
            unsqueezed into a batch of one and moved to `device`.
        device: Device the input batch is moved to before the forward pass.
        idx_to_class: Mapping from predicted class index to class name.

    Returns:
        The entry of `idx_to_class` for the highest-scoring class.
    """
    # Always convert to 3-channel RGB: grayscale, palette, LA or CMYK files
    # would otherwise yield tensors whose channel count does not match a
    # 3-channel normalisation. The `with` block closes the file once the pixel
    # data has been copied by `convert`.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    # Preprocess, add the batch dimension and move to the target device.
    image = transform(image).unsqueeze(0).to(device)
    model.eval()  # evaluation mode

    with torch.no_grad():
        outputs = model(image)
        # Index of the highest score along the class dimension.
        predicted_class = outputs.argmax(dim=1)

    return idx_to_class[predicted_class.item()]

# Classify the chosen image and report the predicted country.
predicted_country = test_image(
    image_path=input_photo_path,
    model=model,
    transform=transform,
    device=device,
    idx_to_class=idx_to_class,
)
print(f"Predicted country: {predicted_country}")