**Place and run this notebook in the directory that contains the TorchSpatial folder (i.e. TorchSpatial's parent directory). TorchSpatial is imported as a package, and relative imports are used within the package.**

In [474]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam

from sklearn.model_selection import train_test_split

import numpy as np

In [471]:
import TorchSpatial.modules.trainer as trainer
from TorchSpatial.modules.encoder_selector import get_loc_encoder
from TorchSpatial.modules.model import ThreeLayerMLP

In [472]:
import importlib
# Show the path of the trainer module that was actually imported
print(trainer.__file__)


/Users/bolongtang/Downloads/TorchSpatial/modules/trainer.py


In [None]:
# For easy reloading bypassing cache in case trainer.py gets edited.
# Kept as the last expression of the cell so the reloaded module is displayed.
importlib.reload(trainer)

<module 'TorchSpatial.modules.trainer' from '/Users/bolongtang/Downloads/TorchSpatial/modules/trainer.py'>

In [None]:
# Notebook-wide settings used by the cells below.
device = "cpu"
num_classes = 500  # birdsnap class count

# Image, location and joint embeddings all share one width (assumed value, can change).
embed_dim = 784
img_dim = embed_dim
loc_dim = embed_dim

In [475]:
# - fake dataset: each row = (img_emb[img_dim], latlon[coord_dim], class_index)
# Seed torch's global RNG so that a fresh "Restart & Run All" regenerates the same
# synthetic data (and the same results from anything downstream that draws from this RNG).
torch.manual_seed(42)

coord_dim = 2
N = 2048
img = torch.randn(N, img_dim)                          # [N, img_dim] standard-normal "image embeddings"

lat = torch.rand(N)*180 - 90                           # latitude, uniform in [-90, 90)
lon = torch.rand(N)*360                                # longitude, uniform in [0, 360)
# NOTE(review): columns are stacked as (lat, lon) with lon in [0, 360) — confirm that the
# chosen location encoder expects this column order and longitude range.
loc = torch.stack([lat, lon], dim=1)                   # [N, coord_dim]

y = torch.randint(0, num_classes, (N,), dtype=torch.long)  # labels in [0, num_classes)

In [476]:
# Sanity check: shapes of the image embeddings, locations and labels, in that order
print(*(tensor.shape for tensor in (img, loc, y)))

torch.Size([2048, 784]) torch.Size([2048, 2]) torch.Size([2048])


In [477]:
# Shuffled 80/20 train/test split of the three tensors in a single call (fixed random_state).
# The result comes back as [train, test] pairs in the same order as the inputs.
split_parts = train_test_split(img, loc, y, test_size=0.2, random_state=42, shuffle=True)
Ximg_tr, Ximg_te, Xloc_tr, Xloc_te, y_tr, y_te = split_parts


In [478]:
# Row-wise samples: every element is one (image embedding, location, label) triple.
train_data = [sample for sample in zip(Ximg_tr, Xloc_tr, y_tr)]
test_data = [sample for sample in zip(Ximg_te, Xloc_te, y_te)]

In [None]:
# - Dataloaders (yield batches of image embeddings, locations and labels)
batch_size = 32  # shared by all loaders below

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# For demonstration only; used by the batch-inspection cell.
# A DataLoader is a re-iterable, not a one-shot iterator: every `for` loop over it starts
# a fresh pass over the dataset, so inspecting a batch never "uses up" rows. The separate
# loader simply keeps the inspection cell independent of the loader used for training.
train_loader2 = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
# Peek at the first batch only. Printed in order:
#   image embeddings (should be random floats),
#   locations (lat in [-90, 90), lon in [0, 360)),
#   labels (random ints in [0, num_classes)).
for img_b, loc_b, y_b in train_loader2:
    for batch_part in (img_b, loc_b, y_b):
        print(batch_part.shape, type(batch_part), batch_part)
    break

torch.Size([32, 784]) <class 'torch.Tensor'> tensor([[-2.7539,  0.3787,  1.5690,  ..., -2.0325, -3.0357, -0.7114],
        [-1.2814, -1.3806,  0.2291,  ..., -0.6111, -0.2695, -0.5909],
        [-1.4826, -1.0383, -1.0919,  ...,  1.4022,  0.5440,  0.7114],
        ...,
        [ 0.5787,  1.5012, -0.6779,  ...,  0.7032, -1.2775, -0.9090],
        [-0.5163, -0.2988,  0.1979,  ...,  0.9552,  1.4006,  2.7824],
        [ 0.9412,  1.0194, -0.7546,  ...,  0.0653, -0.0789, -0.2332]])
torch.Size([32, 2]) <class 'torch.Tensor'> tensor([[-73.8534,  81.5765],
        [-32.5084, 330.2889],
        [-61.6616, 151.3372],
        [-84.5772, 205.7255],
        [-89.0486, 320.8131],
        [ 16.0544, 340.5121],
        [ 33.8640, 130.5281],
        [ 49.6667, 351.4696],
        [-20.7889,  63.1285],
        [-39.4998, 224.6970],
        [ 68.6975, 234.0212],
        [ 15.4468, 298.8843],
        [ 30.1908,  52.1262],
        [ 82.4278, 256.9445],
        [  3.9641, 332.7629],
        [-57.6065, 103.2340]

In [None]:
# - location encoder
# Allowed names: Space2Vec-grid, Space2Vec-theory, xyz, NeRF, Sphere2Vec-sphereC, Sphere2Vec-sphereC+,
#   Sphere2Vec-sphereM, Sphere2Vec-sphereM+, Sphere2Vec-dfs, rbf, rff, wrap, wrap_ffn, tile_ffn
# `overrides` is a dictionary that allows overriding specific params.
# ex. loc_encoder = get_loc_encoder(name = "Space2Vec-grid", overrides = {"max_radius":7800, "min_radius":15, "spa_embed_dim":784, "device": device})
loc_encoder_overrides = {
    "coord_dim": coord_dim,
    "spa_embed_dim": loc_dim,
    # "device": device is needed to prevent AssertionError: Torch not compiled with CUDA enabled
    "device": device,
}
loc_encoder = get_loc_encoder(name="Space2Vec-grid", overrides=loc_encoder_overrides)

In [505]:
# - model
# The decoder below is left disabled, so the location encoder itself is used as the model.
# decoder = ThreeLayerMLP(input_dim = embed_dim, hidden_dim = 1024, category_count = num_classes)
# NOTE(review): assuming loc_encoder is an nn.Module, .to() returns the module itself, so
# `model` and `loc_encoder` would be the same object — confirm.
model = loc_encoder.to(device)

In [503]:
# - Criterion: cross-entropy loss
criterion = nn.CrossEntropyLoss()

# - Optimizer: Adam over the parameters `model` exposes at the time this cell runs.
#   Re-run this cell whenever `model` is re-created so the optimizer is built from the
#   current parameters.
learning_rate = 1e-3
optimizer = Adam(model.parameters(), lr=learning_rate)


In [506]:
# - train()
# Gather the training arguments first, then hand them to the trainer in a single call.
train_kwargs = dict(
    epochs=2,
    batch_count_print_avg_loss=15,
    loc_encoder=loc_encoder,
    dataloader=train_loader,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
)
trainer.train(**train_kwargs)



[epoch 1, batch    15] loss: 6.665
[epoch 1, batch    30] loss: 6.664
[epoch 1, batch    45] loss: 6.664
[epoch 2, batch    15] loss: 6.664
[epoch 2, batch    30] loss: 6.665
[epoch 2, batch    45] loss: 6.665
Training Completed.


In [507]:
# - test
# Top-1 accuracy on the held-out split, computed without gradient tracking.
# NOTE(review): the prediction is an arg-max over the embedding dimension (embed_dim, set to
# 784 in this notebook) while labels lie in [0, num_classes). With the ThreeLayerMLP decoder
# disabled, nothing maps the embedding to class logits — confirm this is intended for the demo.
model.eval()
total = correct = 0
with torch.no_grad():
    for img_b, loc_b, y_b in test_loader:
        img_b, loc_b, y_b = img_b.to(device), loc_b.to(device), y_b.to(device)
        img_embedding = img_b  # image embeddings are used as-is
        loc_embedding = trainer.forward_with_np_array(batch_data = loc_b, model = model)

        # Element-wise interaction between location and image embeddings
        loc_img_interaction_embedding = torch.mul(loc_embedding, img_embedding)
        # Functional softmax: same values as nn.Softmax(dim=1), without building a module per batch
        outputs = torch.softmax(loc_img_interaction_embedding, dim=1)
        pred = outputs.argmax(dim=1)
        total += y_b.size(0)
        correct += (pred == y_b).sum().item()

# Guard the percentage: an empty test loader would otherwise raise ZeroDivisionError.
if total:
    print(f"Accuracy of the network on the {total} test images: {100 * correct / total:.2f}%")
else:
    print("Test loader yielded no samples; accuracy is undefined.")

Accuracy of the network on the 410 test images: 0.24%
