**Place this notebook in the directory that contains the `TorchSpatial` folder and run it from there. TorchSpatial is imported as a package, and its modules use relative imports internally, so the notebook will not work from any other location.**

**Model Training and Testing**

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

from sklearn.model_selection import train_test_split

In [3]:
import TorchSpatial.modules.trainer as trainer
from TorchSpatial.modules.encoder_selector import get_loc_encoder
import TorchSpatial.modules.model as premade_models

In [4]:
import importlib
# Print the path of the trainer module that was imported, to confirm which copy of the package is in use
print(trainer.__file__)


/Users/bolongtang/Downloads/TorchSpatial/modules/trainer.py


In [5]:
# For easy reloading, bypassing cache in case trainer.py gets edited
importlib.reload(trainer) 

# Same for model.py; as the last expression of the cell, the reloaded module is echoed as the cell output
importlib.reload(premade_models)

<module 'TorchSpatial.modules.model' from '/Users/bolongtang/Downloads/TorchSpatial/modules/model.py'>

In [6]:
# - Configuration
# Seed PyTorch's global RNG so the synthetic data, the model initialisation and the
# DataLoader shuffling are repeatable after "Restart Kernel -> Run All".
# The value matches the random_state passed to train_test_split further down.
SEED = 42
torch.manual_seed(SEED)

device = "cpu"
num_classes = 500 # birdsnap class count
# Image and location embeddings are multiplied element-wise before decoding,
# so they must share one dimensionality.
img_dim = loc_dim = embed_dim = 784 # Assumed, can change.
coord_dim = 2 # each location is a (lat, lon) pair

**Random Data With Patterns**

In [7]:

# - fake dataset: each row = (img_emb[784], latlon[2], class_index)
# Only the inputs are created here; the class labels `y` are derived from the
# coordinates in the next cell, so no random placeholder labels are generated.
N = 2048
img = torch.randn(N, img_dim) * 10                         # [N,784] Gaussian noise, std 10

lat = torch.rand(N) * 180 - 90   # uniform in [-90, 90)
lon = torch.rand(N) * 360        # uniform in [0, 360)
# NOTE(review): confirm that the (lat, lon) column order and the 0-360 longitude
# range are what the chosen location encoder expects.
loc = torch.stack([lat, lon], dim=1)                       # [N,2]

In [8]:
# Structured relationship: class depends on latitude and longitude bands
num_lat_bands = 10
num_lon_bands = 10
lat_band = ((lat + 90) // (180 / num_lat_bands)).long()
lon_band = (lon // (360 / num_lon_bands)).long()
y = (lat_band * num_lon_bands + lon_band) % num_classes

In [9]:
# Sanity check: image embeddings, coordinates and labels must have the same number of rows
print(*(tensor.shape for tensor in (img, loc, y)))

torch.Size([2048, 784]) torch.Size([2048, 2]) torch.Size([2048])


In [10]:
# Hold out 20% of the rows for testing; the three arrays are split with the same
# shuffled indices so image, location and label stay aligned.
split_parts = train_test_split(img, loc, y, test_size=0.2, random_state=42, shuffle=True)
Ximg_tr, Ximg_te, Xloc_tr, Xloc_te, y_tr, y_te = split_parts


In [11]:
# Each dataset is a plain list of (image embedding, location, label) samples
train_data = [sample for sample in zip(Ximg_tr, Xloc_tr, y_tr)]
test_data = [sample for sample in zip(Ximg_te, Xloc_te, y_te)]

In [12]:
# - Dataloader (loads image embeddings)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
train_loader2 = DataLoader(train_data, batch_size=32, shuffle=True) # For demonstration only; used below to not iterate through the actual train_loader, which is an iterator and each row only be accessed once per refill
test_loader  = DataLoader(test_data, batch_size=32, shuffle=False)

In [13]:
# Inspect the first training batch only. Expected contents, in print order:
#   img_b - random floats
#   loc_b - lat [-90, 90] and lon [0, 360]
#   y_b   - integer class ids derived from the lat/lon bands (not random)
for batch in train_loader2:
    img_b, loc_b, y_b = batch
    for tensor in (img_b, loc_b, y_b):
        print(tensor.shape, type(tensor), tensor)
    break

torch.Size([32, 784]) <class 'torch.Tensor'> tensor([[ -9.9439,  18.0293,  10.1675,  ...,   0.2642, -10.9086,  -1.5463],
        [ 10.6064,  -0.1239,   4.8373,  ...,  -3.8627,   1.7590,   8.2846],
        [ -6.2455,   8.5587,   0.7948,  ...,  -7.0476,  -1.3425,  -7.0909],
        ...,
        [-14.2805,  11.9501,   7.2585,  ..., -11.6572,   5.7310,  -2.8456],
        [ -2.3232,  -3.4536,  15.2739,  ...,  -9.6967,  -6.8765,  -1.6661],
        [ 23.0095,  -6.4897,  10.7217,  ...,  10.3794, -13.2433,  -1.1273]])
torch.Size([32, 2]) <class 'torch.Tensor'> tensor([[-48.8777, 261.4194],
        [ 54.4361, 348.2578],
        [ 82.2791, 282.2256],
        [-84.3648, 115.8486],
        [-34.4906, 222.1841],
        [  6.9828, 116.2994],
        [ -7.0862, 169.4892],
        [ 43.2347,  39.8236],
        [-10.1921,  32.9706],
        [ 14.0709, 115.6781],
        [ 30.7118, 156.0988],
        [-21.1795, 130.9596],
        [ 41.5294,   0.5229],
        [-61.1150, 219.3945],
        [ 59.7967, 210

In [14]:
# - location encoder
# Allowed: Space2Vec-grid, Space2Vec-theory, xyz, NeRF, Sphere2Vec-sphereC, Sphere2Vec-sphereC+, Sphere2Vec-sphereM, Sphere2Vec-sphereM+, Sphere2Vec-dfs, rbf, rff, wrap, wrap_ffn, tile_ffn
# `overrides` is a dictionary that replaces specific default parameters of the encoder.
# ex. loc_encoder = get_loc_encoder(name = "Space2Vec-grid", overrides = {"max_radius":7800, "min_radius":15, "spa_embed_dim":784, "device": device})
encoder_overrides = {
    "coord_dim": coord_dim,
    "spa_embed_dim": loc_dim,
    # Needed when device = 'cpu' was chosen above and CUDA is not set up: the default is
    # device="cuda", which fails with "AssertionError: Torch not compiled with CUDA enabled".
    "device": device,
}
loc_encoder = get_loc_encoder(name="Space2Vec-grid", overrides=encoder_overrides)

In [15]:
# - model: MLP decoder that maps the fused embedding to one logit per class
decoder = premade_models.ThreeLayerMLP(
    input_dim=embed_dim,
    hidden_dim=1024,
    category_count=num_classes,
)
decoder = decoder.to(device)

In [16]:
# - Criterion
criterion = nn.CrossEntropyLoss()

# - Optimizer: one Adam instance over the encoder's `ffn` parameters and the decoder parameters
trainable_params = [*loc_encoder.ffn.parameters(), *decoder.parameters()]
optimizer = Adam(trainable_params, lr=1e-3)


In [17]:
# - train()
epochs = 50
train_kwargs = dict(
    epochs=epochs,
    batch_count_print_avg_loss=30,   # reporting interval passed to the trainer (see log output below)
    loc_encoder=loc_encoder,
    dataloader=train_loader,
    decoder=decoder,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)
trainer.train(**train_kwargs)

[epoch 1, batch    30] loss: 5.290
[epoch 2, batch    30] loss: 4.560
[epoch 3, batch    30] loss: 4.474
[epoch 4, batch    30] loss: 4.276
[epoch 5, batch    30] loss: 4.035
[epoch 6, batch    30] loss: 3.652
[epoch 7, batch    30] loss: 3.340
[epoch 8, batch    30] loss: 2.854
[epoch 9, batch    30] loss: 2.567
[epoch 10, batch    30] loss: 2.375
[epoch 11, batch    30] loss: 2.112
[epoch 12, batch    30] loss: 1.883
[epoch 13, batch    30] loss: 1.639
[epoch 14, batch    30] loss: 1.552
[epoch 15, batch    30] loss: 1.433
[epoch 16, batch    30] loss: 1.283
[epoch 17, batch    30] loss: 1.224
[epoch 18, batch    30] loss: 1.189
[epoch 19, batch    30] loss: 1.020
[epoch 20, batch    30] loss: 1.141
[epoch 21, batch    30] loss: 0.920
[epoch 22, batch    30] loss: 1.014
[epoch 23, batch    30] loss: 0.930
[epoch 24, batch    30] loss: 0.959
[epoch 25, batch    30] loss: 0.978
[epoch 26, batch    30] loss: 0.815
[epoch 27, batch    30] loss: 0.869
[epoch 28, batch    30] loss: 0.896
[

In [18]:
# - test
def evaluate(loc_encoder, decoder, dataloader, device):
    """Evaluate a location-encoder / decoder pair on `dataloader`.

    Every batch must unpack into (image embedding, location, label). The location is
    encoded with `trainer.forward_with_np_array`, multiplied element-wise with the
    image embedding, and the product is decoded into class logits.

    Args:
        loc_encoder: location encoder; switched to eval mode.
        decoder: classifier over the fused embedding; switched to eval mode.
        dataloader: iterable of (img, loc, label) batches.
        device: device the batch tensors are moved to.

    Returns:
        dict with keys
            "total"    - number of evaluated samples,
            "top1_acc" - top-1 accuracy in percent,
            "top3_acc" - top-3 accuracy in percent,
            "mrr"      - mean reciprocal rank of the true class over all classes.
        All metrics are 0.0 when the dataloader yields no samples.
    """
    loc_encoder.eval()
    decoder.eval()

    total = 0
    correct_top1 = 0
    correct_top3 = 0
    mrr_sum = 0.0

    with torch.no_grad():
        for img_b, loc_b, y_b in dataloader:
            img_b, loc_b, y_b = img_b.to(device), loc_b.to(device), y_b.to(device)

            img_embedding = img_b
            loc_embedding = trainer.forward_with_np_array(batch_data=loc_b, model=loc_encoder)

            loc_img_interaction_embedding = torch.mul(loc_embedding, img_embedding)
            logits = decoder(loc_img_interaction_embedding)

            # Top-1
            pred = logits.argmax(dim=1)
            correct_top1 += (pred == y_b).sum().item()

            # Top-3 accuracy (k is capped so the call also works with fewer than 3 classes)
            k = min(3, logits.size(1))
            top3_idx = logits.topk(k, dim=1).indices                     # [B, k]
            correct_top3 += (top3_idx == y_b.unsqueeze(1)).any(dim=1).sum().item()

            # MRR (full ranking over all classes)
            ranking = logits.argsort(dim=1, descending=True)             # [B, C]
            positions = ranking.argsort(dim=1)                           # [B, C] where positions[b, c] = rank index (0-based)
            true_pos0 = positions.gather(1, y_b.view(-1, 1)).squeeze(1)  # [B]
            mrr_sum += (1.0 / (true_pos0.float() + 1.0)).sum().item()

            total += y_b.size(0)

    if total == 0:
        return {"total": 0, "top1_acc": 0.0, "top3_acc": 0.0, "mrr": 0.0}

    return {
        "total": total,
        "top1_acc": 100.0 * correct_top1 / total,
        "top3_acc": 100.0 * correct_top3 / total,
        "mrr": mrr_sum / total,
    }


metrics = evaluate(loc_encoder, decoder, test_loader, device)
total = metrics["total"]
top1_acc = metrics["top1_acc"]
top3_acc = metrics["top3_acc"]
mrr = metrics["mrr"]

print(f"Top-1 Accuracy on {total} test images: {top1_acc:.2f}%")
print(f"Top-3 Accuracy on {total} test images: {top3_acc:.2f}%")
print(f"MRR on {total} test images: {mrr:.4f}")

Top-1 Accuracy on 410 test images: 6.10%
Top-3 Accuracy on 410 test images: 15.37%
MRR on 410 test images: 0.1521


**Model saving**

In [19]:
from pathlib import Path
def save_model(loc_encoder, decoder, optimizer, epoch, path):
    """Write a training checkpoint to `path`.

    Missing parent directories of `path` are created first. The checkpoint is a dict
    holding the epoch number plus the state dicts of the location encoder, the decoder
    and the optimizer, stored under the keys "epoch", "loc_encoder", "decoder" and
    "optimizer".

    Args:
        loc_encoder: module whose `state_dict()` is saved under "loc_encoder".
        decoder: module whose `state_dict()` is saved under "decoder".
        optimizer: optimizer whose `state_dict()` is saved under "optimizer".
        epoch: value stored as-is under "epoch".
        path: destination file (str or Path).
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        "epoch": epoch,
        "loc_encoder": loc_encoder.state_dict(),
        "decoder": decoder.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, destination)

In [20]:
# Persist the encoder, decoder and optimizer state together with the number of epochs run
checkpoint_path = "TorchSpatial/checkpoints/final.pt"
save_model(loc_encoder, decoder, optimizer, epoch=epochs, path=checkpoint_path)

**Use Saved Models**

In [21]:
# Re-create freshly initialised objects with the same configuration as during training,
# so the saved state dicts can be loaded into them.
encoder_overrides = {
    "coord_dim": coord_dim,
    "spa_embed_dim": loc_dim,
    # Needed when device = 'cpu' was chosen above and CUDA is not set up: the default is
    # device="cuda", which fails with "AssertionError: Torch not compiled with CUDA enabled".
    "device": device,
}
loc_encoder = get_loc_encoder(name="Space2Vec-grid", overrides=encoder_overrides)

decoder = premade_models.ThreeLayerMLP(
    input_dim=embed_dim,
    hidden_dim=1024,
    category_count=num_classes,
)
decoder = decoder.to(device)

trainable_params = [*loc_encoder.ffn.parameters(), *decoder.parameters()]
optimizer = Adam(trainable_params, lr=1e-3)

In [22]:
# The checkpoint holds only state dicts and an integer epoch, so it can be read with
# weights_only=True. That restricts unpickling to tensors and plain containers, which
# prevents arbitrary code execution if the file was tampered with.
# map_location places the loaded tensors on the configured device.
ckpt = torch.load(
    "TorchSpatial/checkpoints/final.pt",
    map_location=device,
    weights_only=True,
)

In [23]:
# Restore the trained weights into the freshly built modules.
encoder_weights = ckpt["loc_encoder"]
loc_encoder.load_state_dict(encoder_weights)

decoder_weights = ckpt["decoder"]
# Kept as the last expression so the key-matching report is shown as the cell output
decoder.load_state_dict(decoder_weights)

<All keys matched successfully>

In [24]:
# Restore the optimizer state and compute the epoch index that follows the saved one
optimizer_state = ckpt["optimizer"]
optimizer.load_state_dict(optimizer_state)

last_finished_epoch = ckpt["epoch"]
start_epoch = last_finished_epoch + 1

In [25]:
# - test (models restored from the checkpoint)
for module in (loc_encoder, decoder):
    module.eval()

total = 0
correct_top1 = 0
correct_top3 = 0
mrr_sum = 0

with torch.no_grad():
    for batch in test_loader:
        img_b, loc_b, y_b = (tensor.to(device) for tensor in batch)
        labels_col = y_b.unsqueeze(1)                                # [B, 1]

        # Fuse the encoded location with the image embedding and decode to class logits
        loc_embedding = trainer.forward_with_np_array(batch_data=loc_b, model=loc_encoder)
        logits = decoder(torch.mul(loc_embedding, img_b))

        total += y_b.size(0)

        # Top-1: highest-scoring class equals the label
        correct_top1 += (logits.argmax(dim=1) == y_b).sum().item()

        # Top-3: label is among the three highest-scoring classes
        top3_idx = logits.topk(3, dim=1).indices                     # [B, 3]
        correct_top3 += (top3_idx == labels_col).any(dim=1).sum().item()

        # MRR: 0-based rank of the true class within the full descending ranking
        ranking = logits.argsort(dim=1, descending=True)             # [B, C]
        positions = ranking.argsort(dim=1)                           # [B, C], positions[b, c] = rank of class c
        true_rank0 = positions.gather(1, y_b.view(-1, 1)).squeeze(1) # [B]
        mrr_sum += (1.0 / (true_rank0.float() + 1.0)).sum().item()

if total:
    top1_acc = 100.0 * correct_top1 / total
    top3_acc = 100.0 * correct_top3 / total
    mrr = mrr_sum / total
else:
    # No samples evaluated: report zeros instead of dividing by zero
    top1_acc = top3_acc = mrr = 0.0

print(f"Top-1 Accuracy on {total} test images: {top1_acc:.2f}%")
print(f"Top-3 Accuracy on {total} test images: {top3_acc:.2f}%")
print(f"MRR on {total} test images: {mrr:.4f}")

# The results below match those of the final model before it was saved

Top-1 Accuracy on 410 test images: 6.10%
Top-3 Accuracy on 410 test images: 15.37%
MRR on 410 test images: 0.1521
