# ToMnet  

In this notebook we are going to develop the codebase for ToMNet to work on our simulation data.

In [55]:
import numpy as np
import pymc
import networkx as nx
import matplotlib.pyplot as plt
import random
import json

In [56]:
import sys
import os
import osmnx as ox

# Adjust this path as needed to point to your project root
sys.path.append(os.path.abspath(".."))

In [57]:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"

In [58]:
import multiprocessing as mp
mp.set_start_method('fork', force=True)

In [59]:
from real_world_src.environment.campus_env import CampusEnvironment
from real_world_src.agents.agent_factory import AgentFactory
from real_world_src.agents.agent_species import ShortestPathAgent
from real_world_src.simulation.simulator import Simulator
#from real_world_src.simulation.experiment_1 import Simulator

from real_world_src.utils.run_manager import RunManager
from real_world_src.utils.config import VISUAL_CONFIG
from real_world_src.utils.config import get_agent_color

## Step 1: Loading the Data

In [60]:
# Create a run manager
# run_manager = RunManager('visuals')
# run_dir = run_manager.start_new_run()

# Initialize campus environment
campus = CampusEnvironment()

Loading map data for University of California, San Diego, La Jolla, CA, USA...


Environment loaded with 3136 nodes and 8704 edges


In [61]:
# Need to establish the set of common goals (just choose the landmark nodes)
goals = [469084068, 49150691, 768264666, 1926666015, 1926673385, 49309735,
         273627682, 445989107, 445992528, 446128310, 1772230346, 1926673336, 
         2872424923, 3139419286, 4037576308]

In [62]:
import pickle
# if you used dill, just replace pickle with dill

with open('data/agents.pkl', 'rb') as f:
    agents = pickle.load(f)

In [63]:
with open("./data/path_data.json", 'r') as file:
    path_data = json.load(file)

with open("./data/goal_data.json", 'r') as file:
    goal_data = json.load(file)

In [64]:
def convert_keys_to_int(data):
    if isinstance(data, dict):
        return {int(k) if isinstance(k, str) and k.isdigit() else k: convert_keys_to_int(v) for k, v in data.items()}
    elif isinstance(data, list):
        return [convert_keys_to_int(item) for item in data]
    else:
        return data

In [65]:
goal_data = convert_keys_to_int(goal_data)
path_data = convert_keys_to_int(path_data)

## Step 2: Defining ToMnet

In [66]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [67]:
class CharacterNet(nn.Module):
    """
    Encode K full trajectories per agent into a single 'character' vector c ∈ R^h.
    Input shape: (B, K, T_sup) of node indices (long).
    CharacterNet using precomputed node2vec embeddings.
    """
    def __init__(self,
                 node_embeddings: np.ndarray,
                 h_lstm: int = 64,
                 T_sup: int = 50,
                 K: int = 10):
        super().__init__()
        num_nodes, d_emb = node_embeddings.shape
        # Use precomputed embeddings, freeze them
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(node_embeddings), freeze=True, padding_idx=0
        )
        self.lstm = nn.LSTM(d_emb, h_lstm, batch_first=True)
        self.K = K
        self.T_sup = T_sup

    def forward(self, support_trajs):
        # support_trajs: LongTensor[B, K, T_sup]
        B, K, T = support_trajs.size()
        assert K == self.K and T == self.T_sup

        flat = support_trajs.view(B * K, T)
        emb = self.embedding(flat)
        _, (h_n, _) = self.lstm(emb)
        h_n = h_n.squeeze(0)
        chars = h_n.view(B, K, -1).mean(dim=1)
        return chars

In [None]:
class MentalNet(nn.Module):
    """
    Encode the query prefix into a 'mental' vector m ∈ R^h'.
    Inputs:
      - prefix     : LongTensor of shape [B, T_q] (node indices, padded with 0)
      - prefix_len : LongTensor of shape [B]   (true lengths in 1..T_q)
    Outputs:
      - m          : FloatTensor of shape [B, h_lstm]
    """
    def __init__(self,
                 node_embeddings: np.ndarray,
                 h_lstm:int = 64,
                 T_q:int    = 20,
                 dropout:float = 0.1,
                 use_attention: bool = True):
        super().__init__()
        num_nodes, d_emb = node_embeddings.shape
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(node_embeddings), freeze=False, padding_idx=0
        )
        # self.embedding = nn.Embedding(num_nodes, d_emb, padding_idx=0)
        self.lstm      = nn.LSTM(d_emb, h_lstm, batch_first=True, bidirectional=True)
        self.layer_norm = nn.LayerNorm(2 * h_lstm)
        
        self.dropout = nn.Dropout(0.1)
        self.T_q       = T_q
        self.use_attention = use_attention
        if use_attention:
            self.attn = nn.Linear(h_lstm * 2, 1)
        
    def forward(self, prefix: torch.LongTensor, prefix_len: torch.LongTensor):
        B, T = prefix.size()
        assert T == self.T_q, f"Expected T_q={self.T_q}, got {T}"

        # embed all time-steps
        emb = self.embedding(prefix)  # [B, T_q, d_emb]

        # pack by actual lengths
        packed = nn.utils.rnn.pack_padded_sequence(
            emb,
            lengths=prefix_len.cpu(),
            batch_first=True,
            enforce_sorted=False
        )
        
        packed_out, (h_n, _) = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=self.T_q)
        out = self.layer_norm(self.dropout(out))  # [B, T_q, 2*h_lstm]

        if self.use_attention:
            mask = torch.arange(self.T_q, device=prefix.device)[None, :] < prefix_len[:, None]
            attn_scores = self.attn(out).squeeze(-1)  # [B, T_q]
            attn_scores[~mask] = float('-inf')
            attn_weights = torch.softmax(attn_scores, dim=1).unsqueeze(-1)  # [B, T_q, 1]
            m = (out * attn_weights).sum(dim=1)  # [B, 2*h_lstm]
        else:
            # Use last valid hidden state for each sequence
            idx = (prefix_len - 1).clamp(min=0)
            m = out[torch.arange(B), idx]  # [B, 2*h_lstm]

        return m

        # # run through LSTM
        # _, (h_n, _) = self.lstm(packed)
        # # h_n: [1, B, h_lstm]
        # m = h_n.squeeze(0)            # [B, h_lstm]

        # return m

In [69]:
class ToMNet(nn.Module):
    """
    Full ToMNet: CharacterNet + MentalNet + fusion MLP + prediction heads.
    """
    def __init__(self,
                 node_embeddings: np.ndarray,
                 num_nodes:int,
                 num_goals:int,
                 K:int=10,
                 T_sup:int=50,
                 T_q:int=20,
                 h_char:int=64,
                 h_ment:int=64,
                 z_dim:int=32):
        super().__init__()
        # submodules
        num_nodes, d_emb = node_embeddings.shape
        self.char_net   = CharacterNet(node_embeddings, h_char, T_sup, K)
        self.mental_net = MentalNet(node_embeddings, h_ment, T_q)
        # embedding to get last‐step token embedding
        # self.embedding  = nn.Embedding(num_nodes, d_emb, padding_idx=0)
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(node_embeddings), freeze=False, padding_idx=0
        )

        # a small MLP to fuse [h_char + h_ment + d_emb] → z_dim
        fusion_dim = h_char + 2*h_ment + d_emb
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim, 128),
            nn.ReLU(),
            nn.Linear(128, z_dim),
            nn.ReLU()
        )
        # final prediction heads
        self.goal_head = nn.Linear(z_dim, num_goals)
        self.next_head = nn.Linear(z_dim, num_nodes)

    def forward(self,
                sup: torch.LongTensor,       # [B, K, T_sup]
                prefix: torch.LongTensor,    # [B, T_q]
                prefix_len: torch.LongTensor # [B]
               ):
        B, K, T_sup = sup.shape
        _, T_q     = prefix.shape

        # 1) character‐level features from the K support trajectories
        #    --> sup_feat: [B, h_char]
        sup_feat = self.char_net(sup)

        # 2) mental‐net encoding of the current prefix
        #    --> ment_feat: [B, h_ment]
        ment_feat = self.mental_net(prefix, prefix_len)

        # 3) take the *last non‐padded* token in each prefix, embed it
        #    prefix_len is in [1..T_q], so subtract 1 for zero‐based index
        last_indices = (prefix_len - 1).clamp(min=0)          # [B]
        # gather the node index at that last step
        last_nodes   = prefix[torch.arange(B), last_indices] # [B]
        # embed it
        last_emb     = self.embedding(last_nodes)            # [B, d_emb]

        # 4) fuse all three representations
        #    concat → [B, h_char + h_ment + d_emb]
        fusion_input = torch.cat([sup_feat, ment_feat, last_emb], dim=1)
        z            = self.fusion(fusion_input)             # [B, z_dim]

        # 5) heads
        next_logits = self.next_head(z)  # [B, num_nodes]
        goal_logits = self.goal_head(z)  # [B, num_goals]

        return next_logits, goal_logits

## Step 3: Prepare the Datasets

In [70]:
# build node2idx so that every node in campus.G_undirected maps to 0…V−1
all_nodes = set()
for episode in path_data.values():
    for path in episode.values():
        all_nodes.update(path)
all_nodes.update(campus.G_undirected.nodes())
all_nodes = list(all_nodes)

node2idx = {n: i for i, n in enumerate(all_nodes)}
print(f"Number of nodes in node2idx: {len(node2idx)}")

# all_nodes = list(campus.G_undirected.nodes())
# node2idx  = {n:i for i,n in enumerate(all_nodes)}
V = len(all_nodes)

# build goal2idx likewise for your goals list
goal2idx = {g:i for i,g in enumerate(goals)}
G = len(goals)

Number of nodes in node2idx: 3170


In [71]:
""" #Executed once - Find the node2vec embeddings in the data directory.

from node2vec import Node2Vec

G = campus.G_undirected

# Fitting node2vec model
node2vec = Node2Vec(G, dimensions=64, walk_length=100, num_walks=200, workers=16)
n2v_model = node2vec.fit(window=10, min_count=1, batch_words=8)

# Building a node_id -> embedding matrix
embedding_dim = n2v_model.wv.vector_size
node_embeddings = np.zeros((len(node2idx), embedding_dim), dtype=np.float32)

for node, idx in node2idx.items():
    key = str(node)
    if key in n2v_model.wv:
        node_embeddings[idx] = n2v_model.wv[key]
    else:
        pass

np.save("data/node2vec_embeddings.npy", node_embeddings)
print("Node2Vec embeddings saved to data/node2vec_embeddings.npy") 
"""

node_embeddings = np.load("data/node2vec_embeddings.npy")

In [72]:
train_agent_ids = list(range(0, 70))
test_agent_ids = list(range(70, 100))

In [73]:
# hyper‐params
K     = 10    # number of support trajectories per agent
T_sup = 75    # max length (pad/truncate) of each support trajectory
T_q   = 20    # prefix length for query trajectories

all_episodes    = list(path_data.keys())
examples_train  = []
examples_test   = []

for agent in agents:
    a_id = agent.id

    # choose which list to append into
    if a_id in train_agent_ids:
        target = examples_train
    elif a_id in test_agent_ids:
        target = examples_test
    else:
        # silently skip any id outside 0–99
        continue

    for ep in all_episodes:
        # ——— 1) build the K‐shot “support set” for this (agent, ep) ———
        other_eps   = [e for e in all_episodes if e != ep]
        support_eps = random.sample(other_eps, K)

        sup_tensor = torch.zeros(K, T_sup, dtype=torch.long)
        for k, se in enumerate(support_eps):
            raw_sup  = path_data[se][a_id]           # e.g. [n0, n1, n2, …]
            idxs_sup = [node2idx[n] for n in raw_sup]
            L        = min(len(idxs_sup), T_sup)
            sup_tensor[k, :L] = torch.tensor(idxs_sup[:L], dtype=torch.long)

        # ——— 2) unroll *this* episode’s path into (prefix→next) queries ———
        raw_q        = path_data[ep][a_id]
        idxs_q       = [node2idx[n] for n in raw_q]
        true_goal_idx = goal2idx[goal_data[ep][a_id]]

        for t in range(1, len(idxs_q)):
            prefix_idxs = idxs_q[:t]     # length t (we’ll pad later)
            next_idx    = idxs_q[t]      # ground‐truth “next node”

            target.append((
                sup_tensor.clone(),      # [K×T_sup] LongTensor
                prefix_idxs,             # Python list of length t
                next_idx,                # int
                true_goal_idx            # int
            ))

print(f"# train examples: {len(examples_train)}")
print(f"# test  examples: {len(examples_test)}")

# train examples: 289118
# test  examples: 125153


In [74]:
from torch.utils.data import Dataset, DataLoader

class ToMNetDataset(Dataset):
    def __init__(self, examples, T_q, pad_value=0):
        """
        examples: list of tuples
            (sup_tensor, prefix_idxs, next_idx, true_goal_idx)
        T_q: int
            length that we will pad/truncate every prefix to
        pad_value: int
            index to use for padding prefixes
        """
        self.examples = examples
        self.T_q       = T_q
        self.pad_value = pad_value

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        sup_tensor, prefix_list, next_idx, true_goal_idx = self.examples[idx]
        # sup_tensor: Tensor[K, T_sup]
        # prefix_list: Python list, length <= T_q (un‐padded)
        # next_idx: int
        # true_goal_idx: int
        return sup_tensor, torch.tensor(prefix_list, dtype=torch.long), next_idx, true_goal_idx

In [75]:
def tomnet_collate(batch, T_q, pad_value=0):
    """
    batch: list of tuples from __getitem__()
      sup_tensor:  K×T_sup
      prefix:     [t] (list of ints)
      next_idx:   scalar int
      goal_idx:   scalar int

    Returns:
      sup_batch:  (B, K, T_sup)
      prefix_batch: (B, T_q)
      next_batch:   (B,)
      goal_batch:   (B,)
      prefix_lens:  (B,)  # optional if you need to mask
    """
    sup_list, prefix_list, next_list, goal_list = zip(*batch)
    B = len(batch)

    # stack support tensors
    sup_batch = torch.stack(sup_list, dim=0)    # (B, K, T_sup)

    # pad prefixes to length T_q
    prefix_batch = torch.full((B, T_q), pad_value, dtype=torch.long)
    prefix_lens  = torch.zeros(B, dtype=torch.long)
    for i, p in enumerate(prefix_list):
        L = min(len(p), T_q)
        prefix_batch[i, :L] = p[:L]
        prefix_lens[i]      = L

    next_batch = torch.tensor(next_list, dtype=torch.long)     # (B,)
    goal_batch = torch.tensor(goal_list, dtype=torch.long)     # (B,)

    return sup_batch, prefix_batch, next_batch, goal_batch, prefix_lens

def tomnet_collate_fn(batch):
    # use your existing tomnet_collate, but wrap it
    return tomnet_collate(batch, T_q=T_q, pad_value=0)

In [76]:
batch_size = 128

In [77]:
train_ds = ToMNetDataset(examples_train, T_q=T_q, pad_value=0)
test_ds  = ToMNetDataset(examples_test,  T_q=T_q, pad_value=0)

test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
                          num_workers=6, collate_fn=tomnet_collate_fn)

In [78]:
# 3) (optional) split into train/val
n_total = len(train_ds)
n_val   = int(0.1 * n_total)
n_train = n_total - n_val
# train_ds, val_ds = random_split((train_ds), [n_train, n_val])
train_loader = torch.utils.data.DataLoader(train_ds, batch_size, shuffle=True,
                                           collate_fn=tomnet_collate_fn,
                                           num_workers=16)
val_loader   = torch.utils.data.DataLoader(test_ds,   batch_size, shuffle=False,
                                           collate_fn=tomnet_collate_fn,
                                           num_workers=16)

## Step 4: Model Training

In [79]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")
device

Using GPU: NVIDIA GeForce RTX 4090


device(type='cuda')

In [80]:
from torch import nn, optim
from torch.utils.data import random_split

# 1) hyper‐parameters
lr         = 1e-3
weight_decay = 1e-5
num_epochs = 30

# 2) model, losses, optimizer
model = ToMNet(
    node_embeddings = node_embeddings,
    num_nodes   = len(node2idx),
    num_goals   = len(goal2idx),
    T_sup=75    # … etc …
).to(device)

loss_next = nn.CrossEntropyLoss()
loss_goal = nn.CrossEntropyLoss()
opt       = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)


In [81]:
from tqdm import tqdm

best_val_loss = float('inf')
best_state    = None

for epoch in range(1, num_epochs+1):
    # —————— Training ——————
    model.train()
    total_train_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]", leave=False)
    for sup, prefix, next_idx, goal_idx, pre_len in train_bar:
        sup      = sup.to(device)       # [B, K, T_sup]
        prefix   = prefix.to(device)    # [B, T_q]
        pre_len  = pre_len.to(device)   # [B]
        next_idx = next_idx.to(device)  # [B]
        goal_idx = goal_idx.to(device)  # [B]

        opt.zero_grad()
        # forward
        pred_next_logits, pred_goal_logits = model(sup, prefix, pre_len)

        # compute losses
        L_next = loss_next(pred_next_logits, next_idx)
        L_goal = loss_goal(pred_goal_logits,   goal_idx)
        loss   = L_next + L_goal

        # backward + step
        loss.backward()
        opt.step()

        total_train_loss += loss.item() * prefix.size(0)

        # update tqdm bar with current batch loss
        train_bar.set_postfix(train_loss=loss.item())

    avg_train_loss = total_train_loss / n_train

    # —————— Validation ——————
    model.eval()
    total_val_loss = 0.0
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]  ", leave=False)
    with torch.no_grad():
        for sup, prefix, next_idx, goal_idx, pre_len in val_bar:
            sup     = sup.to(device)       # [B,K,T_sup]
            prefix  = prefix.to(device)    # [B,T_q]
            pre_len = pre_len.to(device)   # [B]
            next_idx= next_idx.to(device)
            goal_idx= goal_idx.to(device)
    
            # forward
            p_next, p_goal = model(sup, prefix, pre_len)
            L_next        = loss_next(p_next, next_idx)
            L_goal        = loss_goal(p_goal,   goal_idx)
            batch_loss    = (L_next + L_goal).item()
    
            total_val_loss += batch_loss * prefix.size(0)
            val_bar.set_postfix(val_loss=batch_loss)
    
    avg_val_loss = total_val_loss / n_val

    # print a summary line
    print(f"Epoch {epoch}/{num_epochs}  "
          f"train_loss={avg_train_loss:.4f}  val_loss={avg_val_loss:.4f}")

    # save best
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_state    = model.state_dict()

# finally, load best
model.load_state_dict(best_state)
print("Training complete. Best val loss:", best_val_loss)

                                                                                        

Epoch 1/30  train_loss=6.9122  val_loss=22.7074


                                                                                        

Epoch 2/30  train_loss=4.9487  val_loss=22.0016


                                                                                        

Epoch 3/30  train_loss=4.3502  val_loss=23.7592


                                                                                        

Epoch 4/30  train_loss=3.8895  val_loss=27.0902


                                                                                        

Epoch 5/30  train_loss=3.5232  val_loss=29.9555


                                                                                        

Epoch 6/30  train_loss=3.2446  val_loss=32.3303


                                                                                        

Epoch 7/30  train_loss=3.0546  val_loss=34.8195


                                                                                        

Epoch 8/30  train_loss=2.9160  val_loss=37.3059


                                                                                        

Epoch 9/30  train_loss=2.8179  val_loss=40.2299


                                                                                         

Epoch 10/30  train_loss=2.7311  val_loss=40.8868


                                                                                         

Epoch 11/30  train_loss=2.6703  val_loss=42.4022


                                                                                         

Epoch 12/30  train_loss=2.6125  val_loss=43.0564


                                                                                         

Epoch 13/30  train_loss=2.5812  val_loss=44.7923


                                                                                         

Epoch 14/30  train_loss=2.5355  val_loss=45.6033


                                                                                         

Epoch 15/30  train_loss=2.5034  val_loss=47.8446


                                                                                         

Epoch 16/30  train_loss=2.4782  val_loss=47.7170


                                                                                         

Epoch 17/30  train_loss=2.4624  val_loss=48.5024


                                                                                         

Epoch 18/30  train_loss=2.4417  val_loss=50.0424


                                                                                         

Epoch 19/30  train_loss=2.4259  val_loss=50.6000


                                                                                         

Epoch 20/30  train_loss=2.4163  val_loss=51.2067


                                                                                         

Epoch 21/30  train_loss=2.3990  val_loss=50.9581


                                                                                         

Epoch 22/30  train_loss=2.3840  val_loss=51.7084


                                                                                         

Epoch 23/30  train_loss=2.3737  val_loss=53.2495


                                                                                         

Epoch 24/30  train_loss=2.3743  val_loss=53.6829


                                                                                         

Epoch 25/30  train_loss=2.3616  val_loss=55.0527


                                                                                         

Epoch 26/30  train_loss=2.3641  val_loss=54.3776


                                                                                         

Epoch 27/30  train_loss=2.3473  val_loss=55.0723


                                                                                         

Epoch 28/30  train_loss=2.3426  val_loss=54.6573


                                                                                         

Epoch 29/30  train_loss=2.3359  val_loss=55.0202


                                                                                         

Epoch 30/30  train_loss=2.3266  val_loss=54.4623
Training complete. Best val loss: 22.00158353744319




In [82]:
torch.save(model.state_dict(), "data/All_en_tomnet_cuda.pth", _use_new_zipfile_serialization=False)

## Step 5: Testing and Evaluation with ToMnet

In [83]:
from real_world_src.utils.metrics import brier_along_path, accuracy_along_path

In [84]:
import torch.nn.functional as F

def make_support_tensor(agent_id, episode_id, path_data, node2idx, K, T_sup):
    # all eps for this agent
    all_eps = [ep for ep in path_data.keys() if ep != episode_id]
    # pick K random others:
    support_eps = random.sample(all_eps, K)
    sup_tensor = torch.zeros(K, T_sup, dtype=torch.long)
    for k, ep in enumerate(support_eps):
        raw = path_data[ep][agent_id]            # list of node‐ids
        idxs = [node2idx[n] for n in raw]
        L = min(len(idxs), T_sup)
        sup_tensor[k, :L] = torch.tensor(idxs[:L], dtype=torch.long)
    return sup_tensor  # (K×T_sup)

In [85]:
def infer_goal_dists(
    model, agent_id, test_ep,
    path_data, node2idx, goal2idx,
    K, T_sup, T_q,
    device='cuda'
):
    model.eval()
    # 1) build support once
    sup     = make_support_tensor(agent_id, test_ep, path_data, node2idx, K, T_sup)
    sup     = sup.to(device).unsqueeze(0)     # add batch‐dim → [1,K,T_sup]

    raw_seq = path_data[test_ep][agent_id]
    idxs    = [node2idx[n] for n in raw_seq]
    N       = len(idxs)

    goal_dists = []   # will be list of length N each [num_goals]
    with torch.no_grad():
        for t in range(1, N):
            # build prefix up to t (we treat t=0 as “no steps seen”)
            prefix_len = min(t, T_q)
            # pad prefix to T_q
            prefix = torch.zeros(T_q, dtype=torch.long)
            if prefix_len>0:
                prefix[:prefix_len] = torch.tensor(idxs[:prefix_len], dtype=torch.long)
            # move to device and batch‐dim
            prefix     = prefix.to(device).unsqueeze(0)       # [1,T_q]
            prefix_len = torch.tensor([prefix_len], dtype=torch.long, device=device)

            # forward through ToMNet
            _, goal_logits = model(sup, prefix, prefix_len)   # [1, num_goals]
            p_goal = F.softmax(goal_logits, dim=-1)[0]        # remove batch‐dim → [num_goals]

            goal_dists.append(p_goal.cpu().numpy())

    return goal_dists   # shape (N × num_goals) array

In [86]:
agent_id=2
test_ep=2

In [87]:
dists = infer_goal_dists(
    model, agent_id, test_ep,
    path_data, node2idx, goal2idx,
    K=10, T_sup=75, T_q=20,
    device='mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
)

In [88]:
dists[0]

array([9.6353358e-01, 3.3210710e-02, 9.7256631e-08, 4.3268625e-07,
       3.7784666e-08, 6.7487068e-04, 1.3993939e-09, 4.4296911e-16,
       4.3171546e-15, 7.7847410e-15, 9.6222379e-13, 2.4198443e-03,
       4.8614932e-12, 1.6025700e-04, 2.9114938e-07], dtype=float32)

In [89]:
len(path_data[test_ep][agent_id])

65

In [90]:
# goal2idx: { goal_node_id → index }
idx2goal = { idx: goal for goal, idx in goal2idx.items() }

In [91]:
goal_posteriors = [
    { idx2goal[i]: float(p) for i, p in enumerate(prob_row) }
    for prob_row in dists
]

In [92]:
scores = brier_along_path(path_data[test_ep][agent_id], 
                                  goal_data[test_ep][agent_id], 
                                  goal_posteriors, 
                                  goals)

In [93]:
scores

[0.9333333333333338,
 1.73920083467773e-05,
 4.4062945250520044e-06,
 0.004934948360281463,
 0.018079484875425628,
 0.018079484875425628,
 0.018079484875425628,
 0.018079484875425628,
 0.018079484875425628,
 0.018079484875425628,
 0.018079484875425628]

In [94]:
goal_posteriors

[{469084068: 0.9635335803031921,
  49150691: 0.03321070969104767,
  768264666: 9.725663119297678e-08,
  1926666015: 4.3268624949632795e-07,
  1926673385: 3.778466606263464e-08,
  49309735: 0.0006748706800863147,
  273627682: 1.399393934065074e-09,
  445989107: 4.429691146564372e-16,
  445992528: 4.3171545609504026e-15,
  446128310: 7.784741005035373e-15,
  1772230346: 9.622237902295883e-13,
  1926673336: 0.0024198442697525024,
  2872424923: 4.861493152485963e-12,
  3139419286: 0.0001602569973329082,
  4037576308: 2.9114937660779105e-07},
 {469084068: 0.9891675710678101,
  49150691: 0.008286788128316402,
  768264666: 1.3653706787408737e-07,
  1926666015: 5.198809276407701e-07,
  1926673385: 2.1881435330328713e-08,
  49309735: 0.0005325276870280504,
  273627682: 1.147552630698101e-08,
  445989107: 8.841138791631663e-16,
  445992528: 2.960048113738456e-13,
  446128310: 4.675619327744896e-15,
  1772230346: 5.994546764358233e-13,
  1926673336: 0.00018444479792378843,
  2872424923: 3.2565267

### DVIB LOSS Included - Modified Char and Mental Nets

In [97]:
class CharacterNet(nn.Module):
    """
    Variational CharacterNet: outputs z_char, mu_char, logvar_char.
    """
    def __init__(self, 
                 node_embeddings: np.ndarray, 
                 h_lstm: int = 64, 
                 T_sup: int = 50, 
                 K: int = 10, 
                 z_dim: int = 64):
        super().__init__()
        num_nodes, d_emb = node_embeddings.shape
        
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(node_embeddings), freeze=True, padding_idx=0)
        self.lstm = nn.LSTM(d_emb, h_lstm, batch_first=True)
        
        self.K = K
        self.T_sup = T_sup
        self.z_dim = z_dim
        
        self.fc_mu = nn.Linear(h_lstm, z_dim)
        self.fc_logvar = nn.Linear(h_lstm, z_dim)

    def forward(self, support_trajs):
        
        B, K, T = support_trajs.size()
        assert K == self.K and T == self.T_sup
        
        flat = support_trajs.view(B * K, T)
        emb = self.embedding(flat)
        _, (h_n, _) = self.lstm(emb)
        h_n = h_n.squeeze(0)
        
        chars = h_n.view(B, K, -1).mean(dim=1)  # [B, h_lstm]
        mu = self.fc_mu(chars)
        logvar = self.fc_logvar(chars)
        
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        
        return z, mu, logvar

In [98]:
class MentalNet(nn.Module):
    """
    Variational MentalNet: outputs z_mental, mu_mental, logvar_mental.
    """
    def __init__(self, 
                 node_embeddings: np.ndarray, 
                 h_lstm: int = 64, 
                 T_q: int = 20, 
                 dropout: float = 0.1, 
                 use_attention: bool = True, 
                 z_dim: int = 64):
        super().__init__()
        num_nodes, d_emb = node_embeddings.shape
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(node_embeddings), freeze=False, padding_idx=0)
        self.lstm = nn.LSTM(d_emb, h_lstm, batch_first=True, bidirectional=True)
        self.layer_norm = nn.LayerNorm(2 * h_lstm)
        self.dropout = nn.Dropout(dropout)
        
        self.T_q = T_q
        self.z_dim = z_dim
        
        self.use_attention = use_attention
        
        if use_attention:
            self.attn = nn.Linear(h_lstm * 2, 1)
        
        self.fc_mu = nn.Linear(2 * h_lstm, z_dim)
        self.fc_logvar = nn.Linear(2 * h_lstm, z_dim)

    def forward(self, prefix: torch.LongTensor, prefix_len: torch.LongTensor):
        B, T = prefix.size()
        assert T == self.T_q, f"Expected T_q={self.T_q}, got {T}"
        
        emb = self.embedding(prefix)
        packed = nn.utils.rnn.pack_padded_sequence(emb, lengths=prefix_len.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, (h_n, _) = self.lstm(packed)
        
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True, total_length=self.T_q)
        out = self.layer_norm(self.dropout(out))  # [B, T_q, 2*h_lstm]
        
        if self.use_attention:
            mask = torch.arange(self.T_q, device=prefix.device)[None, :] < prefix_len[:, None]
            attn_scores = self.attn(out).squeeze(-1)
            attn_scores[~mask] = float('-inf')
            attn_weights = torch.softmax(attn_scores, dim=1).unsqueeze(-1)
            feat = (out * attn_weights).sum(dim=1)
        else:
            idx = (prefix_len - 1).clamp(min=0)
            feat = out[torch.arange(B), idx]
        
        mu = self.fc_mu(feat)
        logvar = self.fc_logvar(feat)
        
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std
        
        return z, mu, logvar

In [99]:
class ToMNet(nn.Module):
    """
    ToMNet with DVIB: CharacterNet + MentalNet + fusion MLP + prediction heads.
    """
    def __init__(self, 
                 node_embeddings: np.ndarray, 
                 num_nodes: int, 
                 num_goals: int, 
                 K: int = 10, 
                 T_sup: int = 50, 
                 T_q: int = 20, 
                 h_char: int = 64, 
                 h_ment: int = 64, 
                 z_dim: int = 32, 
                 dvib_z_dim: int = 64):
        super().__init__()
        num_nodes, d_emb = node_embeddings.shape
        
        self.char_net = CharacterNet(node_embeddings, h_char, T_sup, K, z_dim=dvib_z_dim)
        self.mental_net = MentalNet(node_embeddings, h_ment, T_q, z_dim=dvib_z_dim)
        self.embedding = nn.Embedding.from_pretrained(
            torch.tensor(node_embeddings), freeze=False, padding_idx=0)
        
        fusion_dim = dvib_z_dim + dvib_z_dim + d_emb
        self.fusion = nn.Sequential(
            nn.Linear(fusion_dim, 128),
            nn.ReLU(),
            nn.Linear(128, z_dim),
            nn.ReLU()
        )
        self.goal_head = nn.Linear(z_dim, num_goals)
        self.next_head = nn.Linear(z_dim, num_nodes)

    def forward(self, sup, prefix, prefix_len):
        B, K, T_sup = sup.shape
        _, T_q = prefix.shape
        
        z_char, mu_char, logvar_char = self.char_net(sup)
        z_mental, mu_mental, logvar_mental = self.mental_net(prefix, prefix_len)
        
        last_indices = (prefix_len - 1).clamp(min=0)
        last_nodes = prefix[torch.arange(B), last_indices]
        last_emb = self.embedding(last_nodes)
        
        fusion_input = torch.cat([z_char, z_mental, last_emb], dim=1)
        z = self.fusion(fusion_input)
        
        next_logits = self.next_head(z)
        goal_logits = self.goal_head(z)
        
        return next_logits, goal_logits, mu_char, logvar_char, mu_mental, logvar_mental

### Training with the New Loss

In [102]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name()}")
else:
    device = torch.device("cpu")
device

Using GPU: NVIDIA GeForce RTX 4090


device(type='cuda')

In [None]:
from torch import nn, optim
from torch.utils.data import random_split

# TODO: Need to play with the hyperparameters
# 1) hyper‐parameters 
lr         = 1e-3
weight_decay = 1e-5
num_epochs = 30

# 2) model, losses, optimizer
model = ToMNet(
    node_embeddings = node_embeddings,
    num_nodes   = len(node2idx),
    num_goals   = len(goal2idx),
    T_sup=75    # … etc …
).to(device)

loss_next = nn.CrossEntropyLoss()
loss_goal = nn.CrossEntropyLoss()
opt       = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)


In [105]:
beta = 1e-3  # TODO: Test out different values of annealing

best_val_loss = float('inf')
best_state    = None

for epoch in range(1, num_epochs+1):
    
    model.train()
    total_train_loss = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs} [Train]", leave=False)
    
    for sup, prefix, next_idx, goal_idx, pre_len in train_bar:
        sup = sup.to(device)
        prefix = prefix.to(device)
        pre_len = pre_len.to(device)
        next_idx = next_idx.to(device)
        goal_idx = goal_idx.to(device)

        opt.zero_grad()
        pred_next_logits, pred_goal_logits, mu_char, logvar_char, mu_mental, logvar_mental = model(sup, prefix, pre_len)

        # KL divergence for both character and mental
        kl_char = -0.5 * torch.sum(1 + logvar_char - mu_char.pow(2) - logvar_char.exp(), dim=1).mean()
        kl_mental = -0.5 * torch.sum(1 + logvar_mental - mu_mental.pow(2) - logvar_mental.exp(), dim=1).mean()
        loss_dvib = beta * (kl_char + kl_mental)

        L_next = loss_next(pred_next_logits, next_idx)
        L_goal = loss_goal(pred_goal_logits, goal_idx)
        loss = L_next + L_goal + loss_dvib

        loss.backward()
        opt.step()

        total_train_loss += loss.item() * prefix.size(0)
        train_bar.set_postfix(train_loss=loss.item())
        
    avg_train_loss = total_train_loss / n_train

    # —————— Validation ——————
    model.eval()
    total_val_loss = 0.0
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch}/{num_epochs} [Val]  ", leave=False)
    with torch.no_grad():
        for sup, prefix, next_idx, goal_idx, pre_len in val_bar:
            sup     = sup.to(device)       # [B,K,T_sup]
            prefix  = prefix.to(device)    # [B,T_q]
            pre_len = pre_len.to(device)   # [B]
            next_idx= next_idx.to(device)
            goal_idx= goal_idx.to(device)
    
            # forward
            next_logits, goal_logits, mu_char, logvar_char, mu_mental, logvar_mental = model(sup, prefix, pre_len)
            L_next        = loss_next(next_logits, next_idx)
            L_goal        = loss_goal(goal_logits,   goal_idx)
            batch_loss    = (L_next + L_goal).item()
    
            total_val_loss += batch_loss * prefix.size(0)
            val_bar.set_postfix(val_loss=batch_loss)
    
    avg_val_loss = total_val_loss / n_val

    # print a summary line
    print(f"Epoch {epoch}/{num_epochs}  "
          f"train_loss={avg_train_loss:.4f}  val_loss={avg_val_loss:.4f}")

    # save best
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_state    = model.state_dict()

# finally, load best
model.load_state_dict(best_state)
print("Training complete. Best val loss:", best_val_loss)

                                                                                        

Epoch 1/30  train_loss=5.3406  val_loss=22.1776


                                                                                        

Epoch 2/30  train_loss=4.7434  val_loss=23.8398


                                                                                        

Epoch 3/30  train_loss=4.3669  val_loss=25.3899


                                                                                        

Epoch 4/30  train_loss=4.0965  val_loss=28.0717


                                                                                        

Epoch 5/30  train_loss=3.8553  val_loss=29.2518


                                                                                        

Epoch 6/30  train_loss=3.6571  val_loss=31.4405


                                                                                        

Epoch 7/30  train_loss=3.4994  val_loss=32.5525


                                                                                        

Epoch 8/30  train_loss=3.3719  val_loss=34.3932


                                                                                        

Epoch 9/30  train_loss=3.2582  val_loss=35.6223


                                                                                         

Epoch 10/30  train_loss=3.1493  val_loss=37.2275


                                                                                         

Epoch 11/30  train_loss=3.0581  val_loss=39.7338


                                                                                         

Epoch 12/30  train_loss=2.9763  val_loss=40.8826


                                                                                         

Epoch 13/30  train_loss=2.9053  val_loss=41.8911


                                                                                         

Epoch 14/30  train_loss=2.8401  val_loss=44.4048


                                                                                         

Epoch 15/30  train_loss=2.7886  val_loss=45.3735


                                                                                         

Epoch 16/30  train_loss=2.7456  val_loss=47.1154


                                                                                         

Epoch 17/30  train_loss=2.7048  val_loss=47.8506


                                                                                         

Epoch 18/30  train_loss=2.6747  val_loss=47.2186


                                                                                         

Epoch 19/30  train_loss=2.6544  val_loss=48.7228


                                                                                         

Epoch 20/30  train_loss=2.6252  val_loss=50.5351


                                                                                         

Epoch 21/30  train_loss=2.6049  val_loss=51.4390


                                                                                         

Epoch 22/30  train_loss=2.5894  val_loss=53.0016


                                                                                         

Epoch 23/30  train_loss=2.5756  val_loss=53.0672


                                                                                         

Epoch 24/30  train_loss=2.5688  val_loss=53.9286


                                                                                         

Epoch 25/30  train_loss=2.5441  val_loss=54.2290


                                                                                         

Epoch 26/30  train_loss=2.5394  val_loss=55.5172


                                                                                         

Epoch 27/30  train_loss=2.5254  val_loss=55.8927


                                                                                         

Epoch 28/30  train_loss=2.5123  val_loss=54.2455


                                                                                         

Epoch 29/30  train_loss=2.5098  val_loss=57.7980


                                                                                         

Epoch 30/30  train_loss=2.4947  val_loss=57.0495
Training complete. Best val loss: 22.177631315054054


