**Place and run this notebook in the directory that contains the `TorchSpatial` folder. TorchSpatial is imported as a package, and its modules use relative imports internally, so the notebook must sit next to the package rather than inside it.**

**Model Training and Testing**

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

from sklearn.model_selection import train_test_split

In [2]:
import TorchSpatial.modules.trainer as trainer
from TorchSpatial.modules.encoder_selector import get_loc_encoder
import TorchSpatial.modules.model as premade_models

In [3]:
# importlib provides reload(), which the following cells use to pick up module edits.
import importlib
# Show which trainer.py on disk was actually imported.
print(trainer.__file__)


/Users/bolongtang/Downloads/TorchSpatial/modules/trainer.py


In [4]:
# Re-execute the already-imported modules so that edits made to trainer.py or
# model.py on disk take effect without restarting the kernel.
importlib.reload(trainer) 

# Last expression of the cell, so the reloaded module object is echoed as output.
importlib.reload(premade_models)

<module 'TorchSpatial.modules.model' from '/Users/bolongtang/Downloads/TorchSpatial/modules/model.py'>

**Import Data**

In [5]:
import TorchSpatial.utils.datasets as data_import
import importlib

In [6]:
# Reload the dataset utilities so edits to utils/datasets.py are picked up without a kernel restart.
importlib.reload(data_import)

<module 'TorchSpatial.utils.datasets' from '/Users/bolongtang/Downloads/TorchSpatial/utils/datasets.py'>

In [34]:
# - Experiment configuration (elevation regression dataset)
dataset = "mosaiks_elevation" 
task = "Regression" # "Classification" or "Regression"
# Equals the number of training rows printed further down; not referenced by any other visible cell.
N = 19924
device = "cpu"
num_classes = 1 # 1 for regression
# 2048 is the width of the precomputed image features (they print as (19924, 2048) below).
# embed_dim is only a placeholder here: a later cell recomputes it from the task.
img_dim = loc_dim = embed_dim = 2048
coord_dim = 2 # (lon, lat) pairs

In [21]:
# Dataset selection passed as `params` to data_import.load_dataset in the loading cells.
params = {"dataset": dataset, "regress_dataset": [dataset]}
# Both flags are forwarded unchanged to load_dataset; presumably they control
# filtering of invalid samples — confirm in utils/datasets.py.
train_remove_invalid = False
eval_remove_invalid = False

In [22]:
# Load the "train" evaluation split together with its precomputed CNN outputs.
train_load_kwargs = dict(
    params=params,
    eval_split="train",
    train_remove_invalid=train_remove_invalid,
    eval_remove_invalid=eval_remove_invalid,
    load_cnn_predictions=True,
    load_cnn_features=True,
    load_cnn_features_train=True,
)
train_data = data_import.load_dataset(**train_load_kwargs)

In [24]:
# Load the "test" evaluation split together with its precomputed CNN outputs.
test_load_kwargs = dict(
    params=params,
    eval_split="test",
    train_remove_invalid=train_remove_invalid,
    eval_remove_invalid=eval_remove_invalid,
    load_cnn_predictions=True,
    load_cnn_features=True,
    load_cnn_features_train=True,
)
test_data = data_import.load_dataset(**test_load_kwargs)

In [23]:
# Inspect the loaded dict. With eval_split="train" the "val_*" arrays shown in the output have 19924 rows.
train_data

{'train_locs': array([[  41.77756 ,   39.75818 ],
        [  28.294346,   12.246856],
        [ 145.11217 ,  -41.215046],
        ...,
        [-102.722855,   60.976524],
        [ -10.351114,   21.973127],
        [-123.078514,   53.51189 ]], shape=(19924, 2), dtype=float32),
 'val_locs': array([[ 45.923065,  33.360394],
        [ 78.96314 ,  37.159798],
        [124.10919 ,  40.032883],
        ...,
        [ 43.719475,  19.045477],
        [-62.08039 , -11.030687],
        [ 80.90506 ,  55.102493]], shape=(19924, 2), dtype=float32),
 'train_labels': array([2120.29427785,  541.21146495,  244.59355366, ...,  420.89562191,
         369.87917377,  739.34824927], shape=(19924,)),
 'val_labels': array([ 114.9841604 , 1769.09281109,   72.95645193, ..., 1372.25071236,
         167.19592743,  137.47612588], shape=(19924,)),
 'train_feats': array([[7.09883804e+01, 1.05888466e+02, 1.13701385e+02, ...,
         1.69064121e+01, 1.89716518e-02, 1.50931213e+02],
        [1.07514885e+02, 8.81267471

In [27]:
# Inspect the loaded dict. With eval_split="test" the "val_*" arrays shown in the output
# have 4981 rows, while the "train_*" arrays still have 19924.
test_data

{'train_locs': array([[ 142.49539 ,   59.577072],
        [ -53.238483,   48.464687],
        [  32.825478,   62.949966],
        ...,
        [  64.63981 ,   26.615644],
        [-121.90786 ,   53.281975],
        [  28.404524,   47.151432]], shape=(19924, 2), dtype=float32),
 'val_locs': array([[ -9.910397,  29.639257],
        [104.30443 ,  26.591017],
        [ 86.16613 ,  36.983994],
        ...,
        [116.83735 ,  46.03483 ],
        [ 66.92603 ,  50.512848],
        [ 20.499146,  66.660126]], shape=(4981, 2), dtype=float32),
 'train_labels': array([ 169.26715921,  100.24858892,  272.4877168 , ..., 1101.6675496 ,
        1130.8944546 ,   95.21605049], shape=(19924,)),
 'val_labels': array([ 245.05220033, 2063.15844245, 5047.17393594, ...,  995.83144531,
         369.2575738 ,  373.21170245], shape=(4981,)),
 'train_feats': array([[1.90226273e+02, 1.51160278e+02, 2.42391922e+02, ...,
         8.84392319e+01, 3.88942289e+00, 3.69282593e+02],
        [1.42180939e+02, 8.47650833e+

In [28]:
# NOTE: the "val_*" keys are read on purpose. In the dict loaded with
# eval_split="train" the "val_*" arrays have 19924 rows, so they appear to hold
# whichever split was requested via eval_split — confirm in utils/datasets.py.
img_tr = train_data["val_feats"] # shape=(19924, 2048)
loc_tr = train_data["val_locs"] # shape=(19924, 2)
y_tr = train_data["val_labels"] # shape=(19924,)

In [29]:
# Sanity check: the three training arrays should share the same first dimension (one row per sample).
print(img_tr.shape, loc_tr.shape, y_tr.shape)

(19924, 2048) (19924, 2) (19924,)


In [30]:
# Same key convention as for the training arrays: in the dict loaded with
# eval_split="test" the "val_*" arrays have 4981 rows (the test samples).
img_te = test_data["val_feats"] # shape=(4981, 2048)
loc_te = test_data["val_locs"] # shape=(4981, 2)
y_te = test_data["val_labels"] # shape=(4981,)

In [31]:
# Sanity check: the three test arrays should share the same first dimension (one row per sample).
print(img_te.shape, loc_te.shape, y_te.shape)

(4981, 2048) (4981, 2) (4981,)


In [35]:
# Width of the fused (image + location) embedding that is fed to the decoder.
if task == "Classification":
    # The evaluation code multiplies the two embeddings element-wise, so the
    # width stays that of the image embedding (img_dim must equal loc_dim).
    embed_dim = img_dim
elif task == "Regression":
    # The evaluation code concatenates the two embeddings along the feature axis.
    embed_dim = img_dim + loc_dim
else:
    # Fail loudly instead of silently keeping whatever embed_dim held before.
    raise ValueError(f'Unknown task {task!r}; expected "Classification" or "Regression"')

In [37]:
# Pair up the per-sample arrays as (image features, location, label) tuples that DataLoader can index.
# NOTE(review): this rebinds train_data / test_data from the loaded dicts to lists, so
# re-running the earlier cells that index them by key requires reloading the data first.
train_data = list(zip(img_tr, loc_tr, y_tr))
test_data  = list(zip(img_te, loc_te, y_te))

In [38]:
# - Dataloader (loads image embeddings)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
train_loader2 = DataLoader(train_data, batch_size=32, shuffle=True) # For demonstration only; used below to not iterate through the actual train_loader, which is an iterator and each row only be accessed once per refill
test_loader  = DataLoader(test_data, batch_size=32, shuffle=False)

In [None]:
# Peek at the first mini-batch only: shape, type and values of each component.
for preview_batch in train_loader2:
    img_b, loc_b, y_b = preview_batch
    for component in (img_b, loc_b, y_b):
        print(component.shape, type(component), component)
    break

torch.Size([32, 2048]) <class 'torch.Tensor'> tensor([[8.6213e+01, 1.7304e+02, 1.3278e+02,  ..., 1.4729e+01, 2.6515e-03,
         1.4438e+02],
        [1.3062e+02, 5.3233e+01, 4.8042e+01,  ..., 2.7860e+01, 3.5443e+00,
         7.8044e+01],
        [9.9071e+01, 3.2816e+01, 7.3385e+01,  ..., 2.6624e+01, 1.9787e-02,
         1.1771e+02],
        ...,
        [1.5240e+02, 1.0245e+02, 1.7994e+02,  ..., 7.6576e+01, 1.1382e+00,
         3.0921e+02],
        [1.6880e+02, 2.3754e+02, 1.9398e+02,  ..., 1.7495e+00, 4.0724e-04,
         6.9914e+01],
        [1.0461e+02, 8.0904e+01, 1.0658e+02,  ..., 2.0065e+01, 9.8399e-02,
         1.3300e+02]], dtype=torch.float64)
torch.Size([32, 2]) <class 'torch.Tensor'> tensor([[  47.4380,  -16.8295],
        [ -71.2115,    3.7447],
        [-100.5744,   54.4031],
        [ 117.2918,    5.9539],
        [  90.6559,   56.6701],
        [  25.4848,    6.7068],
        [  -9.1805,   11.1950],
        [ 139.2313,   52.3583],
        [  77.1865,   15.5750],
      

In [40]:
# - location encoder
# Supported names: Space2Vec-grid, Space2Vec-theory, xyz, NeRF, Sphere2Vec-sphereC,
# Sphere2Vec-sphereC+, Sphere2Vec-sphereM, Sphere2Vec-sphereM+, Sphere2Vec-dfs, rbf,
# rff, wrap, wrap_ffn, tile_ffn
# `overrides` is a dict that replaces individual default parameters of the encoder, e.g.
# get_loc_encoder(name = "Space2Vec-grid", overrides = {"max_radius":7800, "min_radius":15, "spa_embed_dim":2048, "device": device})
loc_encoder_name = "Space2Vec-grid"
# "device" must be passed explicitly when running on CPU without a CUDA build: the
# encoder defaults to device="cuda", which otherwise raises
# "AssertionError: Torch not compiled with CUDA enabled".
loc_encoder_overrides = {"coord_dim": coord_dim, "spa_embed_dim": loc_dim, "device": device}
loc_encoder = get_loc_encoder(name=loc_encoder_name, overrides=loc_encoder_overrides)

In [41]:
# - model
# Decoder MLP: input_dim is the fused embedding width computed above, and
# category_count is num_classes (set to 1 for regression in the config cell).
decoder = premade_models.ThreeLayerMLP(input_dim = embed_dim, hidden_dim = 1024, category_count = num_classes).to(device)

In [None]:
# - Optimizer
# Only the encoder's `ffn` sub-module and the decoder are handed to Adam; any other
# encoder parameters are not updated by this optimizer.
trainable_params = [*loc_encoder.ffn.parameters(), *decoder.parameters()]
optimizer = Adam(params=trainable_params, lr=1e-3)

# - Criterion (cross-entropy for classification, mean squared error for regression)
if task == "Classification":
    criterion = nn.CrossEntropyLoss()
elif task == "Regression":
    criterion = nn.MSELoss()


In [None]:
# - train()
# Number of passes over train_loader; also stored in the checkpoint by the saving cell.
epochs = 10
# batch_count_print_avg_loss = 10 matches the recorded log, which prints a loss every 10 batches.
trainer.train(task=task,
        epochs = epochs, 
        batch_count_print_avg_loss = 10,
        loc_encoder = loc_encoder,
        dataloader = train_loader,
        decoder = decoder,
        criterion = criterion,
        optimizer = optimizer,
        device = device)

[epoch 1, batch    10] loss: 6.232
[epoch 1, batch    20] loss: 6.220
[epoch 1, batch    30] loss: 6.206
[epoch 1, batch    40] loss: 6.181
[epoch 1, batch    50] loss: 6.165
[epoch 1, batch    60] loss: 6.120
[epoch 1, batch    70] loss: 6.003
[epoch 1, batch    80] loss: 5.686
[epoch 1, batch    90] loss: 5.132
[epoch 1, batch   100] loss: 4.786
[epoch 1, batch   110] loss: 4.281
[epoch 1, batch   120] loss: 3.864
[epoch 1, batch   130] loss: 3.558
[epoch 1, batch   140] loss: 3.019
[epoch 1, batch   150] loss: 2.591
[epoch 1, batch   160] loss: 2.612
[epoch 1, batch   170] loss: 2.294
[epoch 1, batch   180] loss: 2.232
[epoch 1, batch   190] loss: 1.723
[epoch 1, batch   200] loss: 1.918
[epoch 1, batch   210] loss: 1.940
[epoch 1, batch   220] loss: 1.802
[epoch 1, batch   230] loss: 1.652
[epoch 1, batch   240] loss: 1.749
[epoch 1, batch   250] loss: 1.551
[epoch 1, batch   260] loss: 1.652
[epoch 1, batch   270] loss: 1.425
[epoch 1, batch   280] loss: 1.441
[epoch 1, batch   29

In [None]:
# - test
# Evaluate the trained location encoder + decoder on the held-out test split.
# Classification reports Top-1 / Top-3 accuracy and MRR; regression reports
# R^2, MAE and RMSE computed over the WHOLE test set.
loc_encoder.eval()
decoder.eval()

total = 0
# Classification accumulators
correct_top1 = 0
correct_top3 = 0
mrr_sum = 0
# Regression accumulators: every prediction/target is kept so the metrics are not
# just those of the last mini-batch.
reg_preds = []
reg_targets = []

with torch.no_grad():
    for img_b, loc_b, y_b in test_loader:
        img_b, loc_b, y_b = img_b.to(device), loc_b.to(device), y_b.to(device)

        loc_embedding = trainer.forward_with_np_array(batch_data=loc_b, model=loc_encoder)
        # The collated image features can arrive as float64 (see the preview batch
        # above); align them with the encoder output so the fused tensor has a single
        # dtype. This is a no-op when the dtypes already agree.
        img_embedding = img_b.to(loc_embedding.dtype)

        if task == "Classification":
            loc_img_interaction_embedding = torch.mul(loc_embedding, img_embedding)
            logits = decoder(loc_img_interaction_embedding)

            # Top-1 (labels are expected to be one-hot, hence the argmax)
            pred = logits.argmax(dim=1)
            y_b = y_b.argmax(dim=1)

            # Top-3 accuracy
            top3_idx = logits.topk(3, dim=1).indices                    # [B, 3]
            correct_top3 += (top3_idx == y_b.unsqueeze(1)).any(dim=1).sum().item()

            # MRR (full ranking over all classes)
            ranking = logits.argsort(dim=1, descending=True)             # [B, C]
            positions = ranking.argsort(dim=1)                           # [B, C] where positions[b, c] = rank index (0-based)
            true_pos0 = positions.gather(1, y_b.view(-1, 1)).squeeze(1)  # [B]
            mrr_sum += (1.0 / (true_pos0.float() + 1.0)).sum().item()

            total += y_b.size(0)
            correct_top1 += (pred == y_b).sum().item()

        elif task == "Regression":
            loc_img_concat_embedding = torch.cat((loc_embedding, img_embedding), dim = 1)
            yhat = decoder(loc_img_concat_embedding)

            # Flatten both sides: (B, 1) predictions against (B,) targets would
            # otherwise broadcast into a (B, B) matrix. Accumulate in float64 on CPU.
            reg_preds.append(yhat.reshape(-1).double().cpu())
            reg_targets.append(y_b.reshape(-1).double().cpu())

            total += y_b.size(0)

if task == "Classification":
    top1_acc = 100.0 * correct_top1 / total if total else 0.0
    top3_acc = 100.0 * correct_top3 / total if total else 0.0
    mrr = mrr_sum / total if total else 0.0

    print(f"Top-1 Accuracy on {total} test images: {top1_acc:.2f}%")
    print(f"Top-3 Accuracy on {total} test images: {top3_acc:.2f}%")
    print(f"MRR on {total} test images: {mrr:.4f}")
elif task == "Regression":
    if total:
        all_preds = torch.cat(reg_preds)
        all_targets = torch.cat(reg_targets)
        residuals = all_preds - all_targets

        # R^2 = 1 - SS_res / SS_tot, with the target mean taken over all test samples.
        ss_res = torch.sum(residuals ** 2)
        ss_tot = torch.sum((all_targets - all_targets.mean()) ** 2)
        r2 = (1.0 - ss_res / ss_tot).item() if ss_tot > 0 else float("nan")

        mae = residuals.abs().mean().item()
        rmse = torch.sqrt(torch.mean(residuals ** 2)).item()
    else:
        # Empty test loader: nothing to measure.
        r2 = mae = rmse = float("nan")

    # These are not percentages: R^2 is unitless, MAE/RMSE are in label units.
    print(f"r-square on {total} test images: {r2:.4f}")
    print(f"MAE of pred on {total} test images: {mae:.2f}")
    print(f"RMSE of pred on {total} test images: {rmse:.2f}")

Top-1 Accuracy on 3827 test images: 90.04%
Top-3 Accuracy on 3827 test images: 98.77%
MRR on 3827 test images: 0.9430


**Model saving**

In [None]:
from pathlib import Path
def save_model(loc_encoder, decoder, optimizer, epoch, path):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys ``"epoch"``, ``"loc_encoder"``,
    ``"decoder"`` and ``"optimizer"``; the last three hold the respective
    ``state_dict()`` of the objects passed in. Missing parent directories of
    ``path`` are created.

    Args:
        loc_encoder: location encoder module whose weights are saved.
        decoder: decoder module whose weights are saved.
        optimizer: optimizer whose state is saved.
        epoch: value stored under ``"epoch"`` (stored as given).
        path: destination file, as a string or ``Path``.
    """
    checkpoint_file = Path(path)
    # Make sure the target folder exists before torch.save opens the file.
    checkpoint_file.parent.mkdir(parents=True, exist_ok=True)

    checkpoint = dict(
        epoch=epoch,
        loc_encoder=loc_encoder.state_dict(),
        decoder=decoder.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(checkpoint, checkpoint_file)

In [None]:
# Persist the final weights and optimizer state. `epoch` records the number of
# epochs configured for training; the path is relative to the notebook's working directory.
save_model(
    loc_encoder=loc_encoder,
    decoder=decoder,
    optimizer=optimizer,
    epoch=epochs,
    path="TorchSpatial/checkpoints/final.pt"
)

**Use Saved Models**

In [None]:
# Rebuild the same architecture that was trained so the checkpoint's state dicts fit.
# NOTE: the encoder name is hard-coded here; keep it identical to the one used for training.
# "device" must be passed explicitly when running on CPU without a CUDA build: the
# encoder defaults to device="cuda", which otherwise raises
# "AssertionError: Torch not compiled with CUDA enabled".
restore_overrides = {"coord_dim": coord_dim, "spa_embed_dim": loc_dim, "device": device}
loc_encoder = get_loc_encoder(name="Space2Vec-grid", overrides=restore_overrides)
decoder = premade_models.ThreeLayerMLP(
    input_dim=embed_dim,
    hidden_dim=1024,
    category_count=num_classes,
).to(device)
# Same parameter selection as for training: the encoder's `ffn` plus the decoder.
restored_params = [*loc_encoder.ffn.parameters(), *decoder.parameters()]
optimizer = Adam(params=restored_params, lr=1e-3)

In [None]:
# Read the checkpoint dict written by save_model; map_location places the stored tensors on `device`.
ckpt = torch.load("TorchSpatial/checkpoints/final.pt", map_location=device)

In [None]:
# Restore the trained weights into the freshly built modules. The last call's
# return value (the key-matching report) is echoed as the cell output.
loc_encoder.load_state_dict(ckpt["loc_encoder"])
decoder.load_state_dict(ckpt["decoder"])

<All keys matched successfully>

In [None]:
# Restore the optimizer state so that training can be resumed from the checkpoint.
optimizer.load_state_dict(ckpt["optimizer"])
# The checkpoint was saved with epoch=epochs, so a resumed run continues at the next epoch.
start_epoch = ckpt["epoch"] + 1

In [10]:
# - test
# Evaluate the restored location encoder + decoder on the test split. The branch
# taken follows `task`, so the fusion (element-wise product vs. concatenation)
# matches the decoder input width chosen for that task.
loc_encoder.eval()
decoder.eval()

total = 0
# Classification accumulators
correct_top1 = 0
correct_top3 = 0
mrr_sum = 0
# Regression accumulators (all predictions/targets, for whole-test-set metrics)
reg_preds = []
reg_targets = []

with torch.no_grad():
    for img_b, loc_b, y_b in test_loader:
        img_b, loc_b, y_b = img_b.to(device), loc_b.to(device), y_b.to(device)

        loc_embedding = trainer.forward_with_np_array(batch_data=loc_b, model=loc_encoder)
        # Align the image features' dtype with the encoder output (they can be
        # collated as float64); a no-op when the dtypes already agree.
        img_embedding = img_b.to(loc_embedding.dtype)

        if task == "Classification":
            loc_img_interaction_embedding = torch.mul(loc_embedding, img_embedding)
            logits = decoder(loc_img_interaction_embedding)

            # Top-1 (labels are expected to be one-hot, hence the argmax)
            pred = logits.argmax(dim=1)
            y_b = y_b.argmax(dim=1)

            # Top-3 accuracy
            top3_idx = logits.topk(3, dim=1).indices                    # [B, 3]
            correct_top3 += (top3_idx == y_b.unsqueeze(1)).any(dim=1).sum().item()

            # MRR (full ranking over all classes)
            ranking = logits.argsort(dim=1, descending=True)             # [B, C]
            positions = ranking.argsort(dim=1)                           # [B, C] where positions[b, c] = rank index (0-based)
            true_pos0 = positions.gather(1, y_b.view(-1, 1)).squeeze(1)  # [B]
            mrr_sum += (1.0 / (true_pos0.float() + 1.0)).sum().item()

            total += y_b.size(0)
            correct_top1 += (pred == y_b).sum().item()

        elif task == "Regression":
            loc_img_concat_embedding = torch.cat((loc_embedding, img_embedding), dim = 1)
            yhat = decoder(loc_img_concat_embedding)

            # Flatten both sides so (B, 1) predictions and (B,) targets line up
            # one-to-one instead of broadcasting; accumulate in float64 on CPU.
            reg_preds.append(yhat.reshape(-1).double().cpu())
            reg_targets.append(y_b.reshape(-1).double().cpu())

            total += y_b.size(0)

if task == "Classification":
    top1_acc = 100.0 * correct_top1 / total if total else 0.0
    top3_acc = 100.0 * correct_top3 / total if total else 0.0
    mrr = mrr_sum / total if total else 0.0

    print(f"Top-1 Accuracy on {total} test images: {top1_acc:.2f}%")
    print(f"Top-3 Accuracy on {total} test images: {top3_acc:.2f}%")
    print(f"MRR on {total} test images: {mrr:.4f}")
elif task == "Regression":
    if total:
        all_preds = torch.cat(reg_preds)
        all_targets = torch.cat(reg_targets)
        residuals = all_preds - all_targets

        # R^2 = 1 - SS_res / SS_tot, with the target mean taken over all test samples.
        ss_res = torch.sum(residuals ** 2)
        ss_tot = torch.sum((all_targets - all_targets.mean()) ** 2)
        r2 = (1.0 - ss_res / ss_tot).item() if ss_tot > 0 else float("nan")

        mae = residuals.abs().mean().item()
        rmse = torch.sqrt(torch.mean(residuals ** 2)).item()
    else:
        # Empty test loader: nothing to measure.
        r2 = mae = rmse = float("nan")

    # R^2 is unitless; MAE/RMSE are in label units (no percent sign).
    print(f"r-square on {total} test images: {r2:.4f}")
    print(f"MAE of pred on {total} test images: {mae:.2f}")
    print(f"RMSE of pred on {total} test images: {rmse:.2f}")

# Results below match the final model pre-saving

NameError: name 'loc_encoder' is not defined