**Put and run this notebook in the directory which contains TorchSpatial, because TorchSpatial will be used as a package. Relative imports are used within the package**

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

from sklearn.model_selection import train_test_split

In [164]:
import TorchSpatial.modules.trainer as trainer
from TorchSpatial.modules.encoder_selector import get_loc_encoder
import TorchSpatial.modules.model as premade_models

In [165]:
import importlib
print(trainer.__file__)


/Users/bolongtang/Downloads/TorchSpatial/modules/trainer.py


In [166]:
# For easy reloading, bypassing cache in case trainer.py gets edited
importlib.reload(trainer) 

importlib.reload(premade_models)

<module 'TorchSpatial.modules.model' from '/Users/bolongtang/Downloads/TorchSpatial/modules/model.py'>

In [167]:
device = "cpu"
num_classes = 500 # birdsnap class count
img_dim = loc_dim = embed_dim = 784 # Assumed, can change.

In [168]:
coord_dim = 2
# - fake dataset: each row = (img_emb[784], latlon[2], class_index)
N = 2048
img = torch.randn(N, img_dim) * 10                         # [N,784]

lat = torch.rand(N)*180 - 90 # - 90 to 90
lon = torch.rand(N)*360 # 0 to 360
loc = torch.stack([lat, lon], dim=1) 

y = torch.randint(0, num_classes, (N,), dtype=torch.long)

In [169]:
# Structured relationship: class depends on latitude and longitude bands
num_lat_bands = 10
num_lon_bands = 10
lat_band = ((lat + 90) // (180 / num_lat_bands)).long()
lon_band = (lon // (360 / num_lon_bands)).long()
y = (lat_band * num_lon_bands + lon_band) % num_classes

In [170]:
print(img.shape, loc.shape, y.shape)

torch.Size([2048, 784]) torch.Size([2048, 2]) torch.Size([2048])


In [171]:
Ximg_tr, Ximg_te, Xloc_tr, Xloc_te, y_tr, y_te = train_test_split(
img, loc, y, test_size=0.2, random_state=42, shuffle=True
)


In [172]:
train_data = list(zip(Ximg_tr, Xloc_tr, y_tr))
test_data  = list(zip(Ximg_te, Xloc_te, y_te))

In [173]:
# - Dataloader (loads image embeddings)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
train_loader2 = DataLoader(train_data, batch_size=32, shuffle=True) # For demonstration only; used below to not iterate through the actual train_loader, which is an iterator and each row only be accessed once per refill
test_loader  = DataLoader(test_data, batch_size=32, shuffle=False)

In [174]:
for img_b, loc_b, y_b in train_loader2:
    print(img_b.shape, type(img_b), img_b) # Should be random floats
    print(loc_b.shape, type(loc_b), loc_b) # lat [-90, 90] and lon [0, 360]
    print(y_b.shape, type(y_b), y_b) # random ints [0, 500]
    break

torch.Size([32, 784]) <class 'torch.Tensor'> tensor([[ -1.3605, -16.6912,  -5.3204,  ...,  -5.2019,  -7.5185,  -5.2076],
        [  9.1958,  10.0931,   0.6847,  ...,  -8.6666,  14.4780,  14.1578],
        [ -3.4650,  -7.0869,   1.6418,  ...,  -1.9308, -10.9980, -13.4757],
        ...,
        [  7.8177,   2.1599,   8.6928,  ..., -15.9909,   2.0767, -17.8155],
        [ -5.1476,   6.0597,  -9.5395,  ..., -28.3844,  -7.6159,  -1.2148],
        [  3.1484,  -8.5589,  16.1132,  ..., -18.9965,   4.7044,  -3.6413]])
torch.Size([32, 2]) <class 'torch.Tensor'> tensor([[ 26.2840, 101.5449],
        [-58.9376, 274.1528],
        [-12.5994,  46.0138],
        [-50.0492,   7.4704],
        [ -4.6779, 163.9617],
        [  4.8002, 207.1658],
        [-37.9944,   0.7866],
        [-49.5741,  67.7224],
        [-57.6029, 200.6449],
        [-63.0124, 128.2206],
        [-86.2934, 245.2606],
        [-22.7024,  47.9350],
        [-63.4830, 304.7218],
        [-32.0722, 263.2834],
        [  6.6238, 263

In [None]:
# - location encoder
# Allowed: Space2Vec-grid, Space2Vec-theory, xyz, NeRF, Sphere2Vec-sphereC, Sphere2Vec-sphereC+, Sphere2Vec-sphereM, Sphere2Vec-sphereM+, Sphere2Vec-dfs, rbf, rff, wrap, wrap_ffn, tile_ffn
# overrides is a dictionary that allows overriding specific params. 
# ex. loc_encoder = get_loc_encoder(name = "Space2Vec-grid", overrides = {"max_radius":7800, "min_radius":15, "spa_embed_dim":784, "device": device})
loc_encoder = get_loc_encoder(name = "Space2Vec-grid", overrides = {"coord_dim": coord_dim, "spa_embed_dim": loc_dim, "device": device}) # "device": device is needed if you defined device = 'cpu' above and don't have cuda setup to prevent "AssertionError: Torch not compiled with CUDA enabled", because the default is device="cuda"

In [176]:
# - model
decoder = premade_models.ThreeLayerMLP(input_dim = embed_dim, hidden_dim = 1024, category_count = num_classes).to(device)

In [177]:
# - Criterion
criterion = nn.CrossEntropyLoss()
# - Optimizer
optimizer = Adam(params = list(loc_encoder.ffn.parameters()) + list(decoder.parameters()), lr = 1e-3)


In [178]:
# - train() 
trainer.train(epochs = 2, 
        batch_count_print_avg_loss = 15,
        loc_encoder = loc_encoder,
        dataloader = train_loader,
        decoder = decoder,
        criterion = criterion,
        optimizer = optimizer,
        device = device)

[epoch 1, batch    15] loss: 5.731
[epoch 1, batch    30] loss: 4.951
[epoch 1, batch    45] loss: 4.811
[epoch 2, batch    15] loss: 4.612
[epoch 2, batch    30] loss: 4.594
[epoch 2, batch    45] loss: 4.582
Training Completed.


In [179]:
# - test
decoder.eval()
total = correct = 0
with torch.no_grad():
    for img_b, loc_b, y_b in test_loader:
        img_b, loc_b, y_b = img_b.to(device), loc_b.to(device), y_b.to(device)
        img_embedding = img_b
        loc_embedding = trainer.forward_with_np_array(batch_data = loc_b, model = model)

        loc_img_interaction_embedding = torch.mul(loc_embedding, img_embedding)
        outputs = torch.nn.Softmax(dim=1)(loc_img_interaction_embedding)
        pred = outputs.argmax(dim=1)
        total += y_b.size(0)
        correct += (pred == y_b).sum().item()

print(f"Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%")

Accuracy of the network on the 410 test images: 1.46%
