**Place and run this notebook in the directory that contains the TorchSpatial folder: TorchSpatial is imported as a package, and relative imports are used within the package.**

**Model Training and Testing**

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

from sklearn.model_selection import train_test_split

In [2]:
import TorchSpatial.modules.trainer as trainer
from TorchSpatial.modules.encoder_selector import get_loc_encoder
import TorchSpatial.modules.model as premade_models

In [3]:
import importlib
# Show which trainer.py file the TorchSpatial import resolved to.
print(trainer.__file__)


/Users/bolongtang/Downloads/TorchSpatial/modules/trainer.py


In [4]:
# Re-import the modules so that edits made to trainer.py / model.py on disk take
# effect without restarting the kernel.
# NOTE: get_loc_encoder was imported by name from encoder_selector, which is not
# reloaded here, so edits to that module are not picked up by this cell.
importlib.reload(trainer) 

# Last expression of the cell, so the notebook displays the reloaded module's repr.
importlib.reload(premade_models)

<module 'TorchSpatial.modules.model' from '/Users/bolongtang/Downloads/TorchSpatial/modules/model.py'>

**Import Data**

In [5]:
import TorchSpatial.utils.datasets as data_import
import importlib

In [6]:
# Re-import the dataset utilities so that on-disk edits to datasets.py are picked up.
importlib.reload(data_import)

<module 'TorchSpatial.utils.datasets' from '/Users/bolongtang/Downloads/TorchSpatial/utils/datasets.py'>

In [7]:
# Options handed to load_dataset: which dataset to read and which meta data variant to use.
params = dict(dataset="birdsnap", meta_type="orig_meta", regress_dataset=[])
# Both flags are forwarded unchanged to load_dataset for the train and the test load.
train_remove_invalid = eval_remove_invalid = True

In [8]:
# Load the "train" split of the dataset described by `params`, also requesting the
# CNN predictions and CNN features that later cells read from the returned mapping.
train_data = data_import.load_dataset(
    params=params,
    eval_split="train",
    train_remove_invalid=train_remove_invalid,
    eval_remove_invalid=eval_remove_invalid,
    load_cnn_predictions=True,
    load_cnn_features=True,
    load_cnn_features_train=True,
)

Loading birdsnap_with_loc_2019.json - train
   using meta data: orig_meta
	 46386 total entries
	 43426 entries with images
	 19133 entries with meta data
Loading birdsnap_with_loc_2019.json - train
   using meta data: orig_meta
	 46386 total entries
	 43426 entries with images
	 19133 entries with meta data


In [9]:
# Load the "test" split with the same options as the training load, again requesting
# the CNN predictions and CNN features that later cells read from the returned mapping.
test_data = data_import.load_dataset(
    params=params,
    eval_split="test",
    train_remove_invalid=train_remove_invalid,
    eval_remove_invalid=eval_remove_invalid,
    load_cnn_predictions=True,
    load_cnn_features=True,
    load_cnn_features_train=True,
)

Loading birdsnap_with_loc_2019.json - train
   using meta data: orig_meta
	 46386 total entries
	 43426 entries with images
	 19133 entries with meta data
Loading birdsnap_with_loc_2019.json - test
   using meta data: orig_meta
	 2443 total entries
	 2262 entries with images
	 816 entries with meta data


In [13]:
# - birdsnap dataset
dataset = "birdsnap"
task = "Classification"
device = "cpu"
N = 19133  # training rows with meta data, per the load_dataset log above
num_classes = 500  # birdsnap class count
embed_dim = 2048  # width of the birdsnap image-feature vectors
img_dim = embed_dim
loc_dim = embed_dim  # same width: the location and image embeddings are multiplied element-wise later
coord_dim = 2  # each location is a (lon, lat) pair

# load_dataset exposes the split selected by `eval_split` under "val_*" keys, for the
# train load as well as the test load.
# NOTE(review): the targets below come from "val_preds"; the key name suggests CNN
# prediction scores rather than ground-truth labels - confirm this is the intended target.
img_tr = train_data["val_feats"]  # shape=(19133, 2048)
loc_tr = train_data["val_locs"]  # shape=(19133, 2)
y_tr = train_data["val_preds"]  # shape=(19133, 500)

img_te = test_data["val_feats"]  # shape=(816, 2048)
loc_te = test_data["val_locs"]  # shape=(816, 2)
y_te = test_data["val_preds"]  # shape=(816, 500)

In [14]:
# Sanity check: the three training arrays should share the same first dimension.
print(img_tr.shape, loc_tr.shape, y_tr.shape)

(19133, 2048) (19133, 2) (19133, 500)


In [15]:
# Sanity check: the three test arrays should share the same first dimension.
print(img_te.shape, loc_te.shape, y_te.shape)

(816, 2048) (816, 2) (816, 500)


In [16]:
# Pair up each sample's (image features, location, target) for the DataLoader.
# The lists are built from the arrays extracted above rather than by indexing
# `train_data` / `test_data` themselves: this cell rebinds those names from mappings
# to lists, so indexing them with string keys would raise a TypeError if the cell
# were run a second time. Reading from img_tr / loc_tr / y_tr (and the *_te arrays)
# yields the same samples and keeps the cell safe to re-run.
train_data = list(zip(img_tr, loc_tr, y_tr))
test_data = list(zip(img_te, loc_te, y_te))

In [17]:
# - Dataloaders (each batch is an (image features, locations, targets) triple)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# For demonstration only: a second loader over the same training data, used by the
# preview cell below so that the preview does not touch train_loader itself.
# Note that a DataLoader is re-iterable (every `for` loop over it starts a fresh
# pass), so previewing a batch from train_loader would not use up any rows.
train_loader2 = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=32, shuffle=False)

In [18]:
# Preview one shuffled training batch, printing shape, type and values of each part:
#   img_b - [32, 2048] float image features
#   loc_b - [32, 2] coordinates in (lon, lat) order; the printed longitudes are signed
#   y_b   - per-class target rows (expected [32, 500], matching y_tr), not integer class ids
for img_b, loc_b, y_b in train_loader2:
    for batch_part in (img_b, loc_b, y_b):
        print(batch_part.shape, type(batch_part), batch_part)
    break

torch.Size([32, 2048]) <class 'torch.Tensor'> tensor([[0.1851, 0.0532, 0.2775,  ..., 0.1338, 1.2624, 0.4498],
        [0.3632, 0.5042, 0.0191,  ..., 0.6604, 0.0580, 0.7077],
        [0.1710, 0.3779, 0.0968,  ..., 0.2376, 0.7921, 1.1272],
        ...,
        [0.0679, 0.0562, 0.7301,  ..., 0.2016, 0.1728, 0.0412],
        [0.2140, 0.3429, 0.0867,  ..., 0.0068, 0.5626, 0.0616],
        [0.6310, 0.1920, 0.0062,  ..., 0.0628, 0.4519, 2.1073]])
torch.Size([32, 2]) <class 'torch.Tensor'> tensor([[-122.4751,   37.8095],
        [  17.6211,   59.8419],
        [ -90.7098,   -0.3993],
        [   9.5993,   52.8801],
        [-121.3746,   38.3951],
        [ -76.3279,   38.0480],
        [ -83.9242,   39.3878],
        [ -83.4506,   41.5058],
        [ -99.1151,   26.5146],
        [-110.8801,   31.7251],
        [ -97.7631,   30.4610],
        [ -92.6603,   47.1678],
        [ -87.8515,   18.2137],
        [-118.3077,   33.7495],
        [ -75.3359,   37.9095],
        [ -74.0052,   45.4860],
 

In [19]:
# - location encoder
# Allowed names: Space2Vec-grid, Space2Vec-theory, xyz, NeRF, Sphere2Vec-sphereC,
# Sphere2Vec-sphereC+, Sphere2Vec-sphereM, Sphere2Vec-sphereM+, Sphere2Vec-dfs, rbf,
# rff, wrap, wrap_ffn, tile_ffn
# `overrides` is a dictionary that replaces specific encoder params, e.g.
#   get_loc_encoder(name="Space2Vec-grid", overrides={"max_radius": 7800, "min_radius": 15, "spa_embed_dim": 2048, "device": device})
# "device" is passed explicitly because the default is device="cuda"; with device = 'cpu'
# and no CUDA setup, leaving it out raises
# "AssertionError: Torch not compiled with CUDA enabled".
loc_encoder = get_loc_encoder(
    name="Space2Vec-grid",
    overrides={
        "coord_dim": coord_dim,
        "spa_embed_dim": loc_dim,
        "device": device,
    },
)

In [20]:
# - model
# Decoder head: takes embed_dim-wide inputs and is configured with num_classes categories.
decoder = premade_models.ThreeLayerMLP(
    input_dim=embed_dim,
    hidden_dim=1024,
    category_count=num_classes,
).to(device)

In [21]:
# - Criterion
criterion = nn.CrossEntropyLoss()

# - Optimizer
# Optimizes the location encoder's `ffn` sub-module together with the decoder.
trainable_params = [*loc_encoder.ffn.parameters(), *decoder.parameters()]
optimizer = Adam(params=trainable_params, lr=1e-3)


In [22]:
# - train()
# Run training for `epochs` epochs over train_loader; the log below shows a loss
# line every 10 batches, matching batch_count_print_avg_loss.
epochs = 10
trainer.train(
    task=task,
    epochs=epochs,
    batch_count_print_avg_loss=10,
    loc_encoder=loc_encoder,
    dataloader=train_loader,
    decoder=decoder,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)

[epoch 1, batch    10] loss: 6.214
[epoch 1, batch    20] loss: 6.210
[epoch 1, batch    30] loss: 6.195
[epoch 1, batch    40] loss: 6.158
[epoch 1, batch    50] loss: 6.124
[epoch 1, batch    60] loss: 5.997
[epoch 1, batch    70] loss: 5.800
[epoch 1, batch    80] loss: 5.478
[epoch 1, batch    90] loss: 5.078
[epoch 1, batch   100] loss: 4.840
[epoch 1, batch   110] loss: 4.548
[epoch 1, batch   120] loss: 4.007
[epoch 1, batch   130] loss: 3.493
[epoch 1, batch   140] loss: 3.290
[epoch 1, batch   150] loss: 2.820
[epoch 1, batch   160] loss: 2.601
[epoch 1, batch   170] loss: 2.402
[epoch 1, batch   180] loss: 2.285
[epoch 1, batch   190] loss: 2.209
[epoch 1, batch   200] loss: 1.978
[epoch 1, batch   210] loss: 1.930
[epoch 1, batch   220] loss: 1.844
[epoch 1, batch   230] loss: 1.638
[epoch 1, batch   240] loss: 1.657
[epoch 1, batch   250] loss: 1.730
[epoch 1, batch   260] loss: 1.584
[epoch 1, batch   270] loss: 1.519
[epoch 1, batch   280] loss: 1.263
[epoch 1, batch   29

In [24]:
# - test
# Put both modules in eval mode, then accumulate Top-1 / Top-3 hits and reciprocal
# ranks over the test loader without tracking gradients.
loc_encoder.eval()
decoder.eval()

total = correct_top1 = correct_top3 = mrr_sum = 0

with torch.no_grad():
    for img_b, loc_b, y_b in test_loader:
        img_b = img_b.to(device)
        loc_b = loc_b.to(device)
        y_b = y_b.to(device)

        # Image features are used as-is; only the locations go through the encoder.
        img_embedding = img_b
        loc_embedding = trainer.forward_with_np_array(batch_data=loc_b, model=loc_encoder)

        # Fuse the two embeddings element-wise and score every class.
        loc_img_interaction_embedding = loc_embedding * img_embedding
        logits = decoder(loc_img_interaction_embedding)

        # Collapse each [C]-wide target row to a class index.
        # NOTE(review): the targets originate from "val_preds" - confirm that their
        # argmax is the label these metrics should be measured against.
        y_b = y_b.argmax(dim=1)
        total += y_b.size(0)

        # Top-1: the highest-scoring class equals the target.
        pred = logits.argmax(dim=1)
        correct_top1 += (pred == y_b).sum().item()

        # Top-3: the target is among the three highest-scoring classes.
        top3_idx = logits.topk(3, dim=1).indices                     # [B, 3]
        correct_top3 += (top3_idx == y_b.unsqueeze(1)).any(dim=1).sum().item()

        # MRR: argsort of the descending ranking inverts the permutation, so
        # positions[b, c] is the 0-based rank of class c for sample b.
        ranking = logits.argsort(dim=1, descending=True)             # [B, C]
        positions = ranking.argsort(dim=1)                           # [B, C]
        true_pos0 = positions.gather(1, y_b.view(-1, 1)).squeeze(1)  # [B]
        mrr_sum += (1.0 / (true_pos0.float() + 1.0)).sum().item()

# Report 0.0 for every metric if the loader yielded no samples.
if total:
    top1_acc = 100.0 * correct_top1 / total
    top3_acc = 100.0 * correct_top3 / total
    mrr = mrr_sum / total
else:
    top1_acc = top3_acc = mrr = 0.0

print(f"Top-1 Accuracy on {total} test images: {top1_acc:.2f}%")
print(f"Top-3 Accuracy on {total} test images: {top3_acc:.2f}%")
print(f"MRR on {total} test images: {mrr:.4f}")

Top-1 Accuracy on 816 test images: 83.09%
Top-3 Accuracy on 816 test images: 96.81%
MRR on 816 test images: 0.8994


**Model saving**

In [25]:
from pathlib import Path
def save_model(loc_encoder, decoder, optimizer, epoch, path):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys "epoch", "loc_encoder", "decoder" and
    "optimizer"; the last three hold the ``state_dict()`` of the matching argument.
    Missing parent directories of ``path`` are created first.

    Args:
        loc_encoder: object whose ``state_dict()`` is stored under "loc_encoder".
        decoder: object whose ``state_dict()`` is stored under "decoder".
        optimizer: object whose ``state_dict()`` is stored under "optimizer".
        epoch: value stored unchanged under "epoch".
        path: destination file, as a string or path-like object.
    """
    checkpoint_path = Path(path)
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {"epoch": epoch}
    components = (
        ("loc_encoder", loc_encoder),
        ("decoder", decoder),
        ("optimizer", optimizer),
    )
    for key, component in components:
        checkpoint[key] = component.state_dict()

    torch.save(checkpoint, checkpoint_path)

In [26]:
# Save the trained modules and the optimizer; `epoch` records the `epochs` value used
# for training. The relative path is resolved against the notebook's working directory.
save_model(
    loc_encoder=loc_encoder,
    decoder=decoder,
    optimizer=optimizer,
    epoch=epochs,
    path="TorchSpatial/checkpoints/final.pt"
)

**Use Saved Models**

In [27]:
# Re-create the encoder, decoder and optimizer with the same configuration used for
# training, so that the state dicts stored in the checkpoint can be loaded into them.
# "device" is passed explicitly because the default is device="cuda"; with device = 'cpu'
# and no CUDA setup, leaving it out raises
# "AssertionError: Torch not compiled with CUDA enabled".
loc_encoder = get_loc_encoder(
    name="Space2Vec-grid",
    overrides={
        "coord_dim": coord_dim,
        "spa_embed_dim": loc_dim,
        "device": device,
    },
)
decoder = premade_models.ThreeLayerMLP(
    input_dim=embed_dim,
    hidden_dim=1024,
    category_count=num_classes,
).to(device)
trainable_params = [*loc_encoder.ffn.parameters(), *decoder.parameters()]
optimizer = Adam(params=trainable_params, lr=1e-3)

In [28]:
# Read back the checkpoint written by save_model; map_location places its tensors on `device`.
ckpt = torch.load("TorchSpatial/checkpoints/final.pt", map_location=device)

In [29]:
# Restore the saved weights into the freshly built modules.
loc_encoder.load_state_dict(ckpt["loc_encoder"])
# Last expression of the cell; its return value is what the notebook displays.
decoder.load_state_dict(ckpt["decoder"])

<All keys matched successfully>

In [30]:
# Restore the optimizer state as well, which is what continued training would need.
optimizer.load_state_dict(ckpt["optimizer"])
# One past the epoch value stored in the checkpoint; not referenced again in the cells shown here.
start_epoch = ckpt["epoch"] + 1

In [31]:
# - test
# Same evaluation as the earlier test cell, now run on the modules restored from the
# checkpoint: accumulate Top-1 / Top-3 hits and reciprocal ranks without gradients.
loc_encoder.eval()
decoder.eval()

total = correct_top1 = correct_top3 = mrr_sum = 0

with torch.no_grad():
    for img_b, loc_b, y_b in test_loader:
        img_b = img_b.to(device)
        loc_b = loc_b.to(device)
        y_b = y_b.to(device)

        # Image features are used as-is; only the locations go through the encoder.
        img_embedding = img_b
        loc_embedding = trainer.forward_with_np_array(batch_data=loc_b, model=loc_encoder)

        # Fuse the two embeddings element-wise and score every class.
        loc_img_interaction_embedding = loc_embedding * img_embedding
        logits = decoder(loc_img_interaction_embedding)

        # Collapse each [C]-wide target row to a class index.
        # NOTE(review): the targets originate from "val_preds" - confirm that their
        # argmax is the label these metrics should be measured against.
        y_b = y_b.argmax(dim=1)
        total += y_b.size(0)

        # Top-1: the highest-scoring class equals the target.
        pred = logits.argmax(dim=1)
        correct_top1 += (pred == y_b).sum().item()

        # Top-3: the target is among the three highest-scoring classes.
        top3_idx = logits.topk(3, dim=1).indices                     # [B, 3]
        correct_top3 += (top3_idx == y_b.unsqueeze(1)).any(dim=1).sum().item()

        # MRR: argsort of the descending ranking inverts the permutation, so
        # positions[b, c] is the 0-based rank of class c for sample b.
        ranking = logits.argsort(dim=1, descending=True)             # [B, C]
        positions = ranking.argsort(dim=1)                           # [B, C]
        true_pos0 = positions.gather(1, y_b.view(-1, 1)).squeeze(1)  # [B]
        mrr_sum += (1.0 / (true_pos0.float() + 1.0)).sum().item()

# Report 0.0 for every metric if the loader yielded no samples.
if total:
    top1_acc = 100.0 * correct_top1 / total
    top3_acc = 100.0 * correct_top3 / total
    mrr = mrr_sum / total
else:
    top1_acc = top3_acc = mrr = 0.0

print(f"Top-1 Accuracy on {total} test images: {top1_acc:.2f}%")
print(f"Top-3 Accuracy on {total} test images: {top3_acc:.2f}%")
print(f"MRR on {total} test images: {mrr:.4f}")

# Results below match the final model pre-saving

Top-1 Accuracy on 816 test images: 83.09%
Top-3 Accuracy on 816 test images: 96.81%
MRR on 816 test images: 0.8994
