# SIAT: Social Interaction-Aware Transformer — Colab Notebook

This notebook is a self-contained, Colab-ready version of the SIAT codebase. It inlines every module so the whole pipeline runs end-to-end: setup, preprocessing (ETH/UCY -> NPZ), dataset, model (GCN + Transformer), training, and evaluation with visualizations.

In [None]:
# Setup: Install deps and configure paths (Colab-friendly)
import sys, os, subprocess
from pathlib import Path

# Detect whether this kernel is a Google Colab runtime.
IN_COLAB = 'google.colab' in sys.modules
print('Running in Colab:', IN_COLAB)

# Pick the project root. On Colab, prefer Google Drive so data/checkpoints
# survive runtime resets; if mounting fails, fall back to the ephemeral
# /content disk. Elsewhere, use the current working directory.
if not IN_COLAB:
    BASE_DIR = Path.cwd()
else:
    try:
        from google.colab import drive
        drive.mount('/content/drive', force_remount=False)
    except Exception:
        BASE_DIR = Path('/content/SIAT')
    else:
        BASE_DIR = Path('/content/drive/MyDrive/SIAT')

# Create the root and its standard sub-directories (idempotent).
BASE_DIR.mkdir(parents=True, exist_ok=True)
DATA_DIR = BASE_DIR / 'data_npz'
CKPT_DIR = BASE_DIR / 'checkpoints'
RESULTS_DIR = BASE_DIR / 'results'
DATASETS_DIR = BASE_DIR / 'datasets'
for d in (DATA_DIR, CKPT_DIR, RESULTS_DIR, DATASETS_DIR):
    d.mkdir(parents=True, exist_ok=True)

# pip-install any package that fails to import (each entry is used both as the
# import name and as the pip distribution name).
reqs = [
    'numpy', 'pandas', 'torch', 'torchvision', 'torchaudio', 'matplotlib', 'tqdm'
]
for pkg in reqs:
    try:
        __import__(pkg)
    except Exception:
        print('Installing', pkg)
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg])

for _label, _path in (('BASE_DIR', BASE_DIR), ('DATA_DIR', DATA_DIR),
                      ('CKPT_DIR', CKPT_DIR), ('RESULTS_DIR', RESULTS_DIR)):
    print(_label + ' =', _path)

In [None]:
# Auto-download preprocessed data (data_npz) if missing
import shutil, zipfile

# Optional: set a ZIP archive URL containing data_npz (if you have one)
# When set, auto_download_data() tries this archive first. DATA_DIR is globbed
# non-recursively for '*.npz' afterwards, so the .npz files must sit at the
# archive's top level to be detected.
DATA_ARCHIVE_URL = None  # e.g., 'https://example.com/SIAT_data_npz.zip'

# GitHub raw fallback (public repo assumed)
# Each file is fetched from
#   https://raw.githubusercontent.com/<OWNER>/<REPO>/<BRANCH>/<FOLDER>/<name>
GITHUB_OWNER = 'VidushaSanidu'
GITHUB_REPO = 'SIAT'
GITHUB_BRANCH = 'main'
GITHUB_FOLDER = 'data_npz'

# Known NPZ filenames (from repository)
# NOTE(review): the univ_* group looks incomplete next to the eth_*/hotel_*
# groups (only two train files, no val files) — verify against the repo listing.
NPZ_FILES = [
    'eth_test_biwi_eth.npz',
    'eth_train_biwi_hotel_train.npz',
    'eth_train_crowds_zara01_train.npz',
    'eth_train_crowds_zara02_train.npz',
    'eth_train_crowds_zara03_train.npz',
    'eth_train_students001_train.npz',
    'eth_train_students003_train.npz',
    'eth_train_uni_examples_train.npz',
    'eth_val_biwi_hotel_val.npz',
    'eth_val_crowds_zara01_val.npz',
    'eth_val_crowds_zara02_val.npz',
    'eth_val_crowds_zara03_val.npz',
    'eth_val_students001_val.npz',
    'eth_val_students003_val.npz',
    'eth_val_uni_examples_val.npz',
    'hotel_test_biwi_hotel.npz',
    'hotel_train_biwi_eth_train.npz',
    'hotel_train_crowds_zara01_train.npz',
    'hotel_train_crowds_zara02_train.npz',
    'hotel_train_crowds_zara03_train.npz',
    'hotel_train_students001_train.npz',
    'hotel_train_students003_train.npz',
    'hotel_train_uni_examples_train.npz',
    'hotel_val_biwi_eth_val.npz',
    'hotel_val_crowds_zara01_val.npz',
    'hotel_val_crowds_zara02_val.npz',
    'hotel_val_crowds_zara03_val.npz',
    'hotel_val_students001_val.npz',
    'hotel_val_students003_val.npz',
    'hotel_val_uni_examples_val.npz',
    'raw_all_data_biwi_eth.npz',
    'raw_all_data_biwi_hotel.npz',
    'raw_all_data_crowds_zara01.npz',
    'raw_all_data_crowds_zara02.npz',
    'raw_all_data_crowds_zara03.npz',
    'raw_all_data_students001.npz',
    'raw_all_data_students003.npz',
    'raw_all_data_uni_examples.npz',
    'raw_train_biwi_eth_train.npz',
    'raw_train_biwi_hotel_train.npz',
    'raw_train_crowds_zara01_train.npz',
    'raw_train_crowds_zara02_train.npz',
    'raw_train_crowds_zara03_train.npz',
    'raw_train_students001_train.npz',
    'raw_train_students003_train.npz',
    'raw_train_uni_examples_train.npz',
    'raw_val_biwi_eth_val.npz',
    'raw_val_biwi_hotel_val.npz',
    'raw_val_crowds_zara01_val.npz',
    'raw_val_crowds_zara02_val.npz',
    'raw_val_crowds_zara03_val.npz',
    'raw_val_students001_val.npz',
    'raw_val_students003_val.npz',
    'raw_val_uni_examples_val.npz',
    'univ_test_students001.npz',
    'univ_test_students003.npz',
    'univ_train_biwi_eth_train.npz',
    'univ_train_biwi_hotel_train.npz'
]

# Ensure requests is available
# Installed on demand with pip if the import fails; download_file() uses it for
# streamed HTTP GETs.
try:
    import requests
except Exception:
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'requests'])
    import requests


def download_file(url: str, dest: Path, chunk_size: int = 1 << 20) -> bool:
    """Stream ``url`` into ``dest``; return True only for a complete HTTP-200 download.

    The body is first written to a sibling ``<name>.part`` file and renamed onto
    ``dest`` only after the whole stream has been consumed. A failed or
    interrupted download therefore never leaves a truncated file under the
    final name. This matters because callers treat an existing ``dest`` as
    already downloaded, and a ``*.npz`` glob would pick up the partial file.

    Args:
        url: HTTP(S) URL to fetch.
        dest: Final destination path.
        chunk_size: Bytes per streamed chunk (default 1 MiB).

    Returns:
        True if ``dest`` now holds the full response body, else False. This is
        best-effort: network and I/O errors are printed and reported as False
        rather than raised.
    """
    dest = Path(dest)
    tmp = dest.with_name(dest.name + '.part')
    try:
        with requests.get(url, stream=True, timeout=60) as r:
            if r.status_code != 200:
                return False
            with open(tmp, 'wb') as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Publish the file only once it is complete.
        tmp.replace(dest)
        return True
    except Exception as e:
        print('Download failed:', url, '-', e)
        return False
    finally:
        # Remove any leftover partial file (no-op after a successful replace).
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            pass


def download_zip_and_extract(url: str, out_dir: Path) -> bool:
    """Fetch a ZIP archive from ``url`` and unpack it into ``out_dir``.

    The archive is staged as ``data_npz_tmp.zip`` next to ``out_dir``. Once a
    download has succeeded, the staged archive is deleted whether or not
    extraction works.

    Returns:
        True if both the download and the extraction succeeded, else False.
    """
    archive = out_dir.parent / 'data_npz_tmp.zip'
    print('Downloading zip:', url)
    if download_file(url, archive):
        try:
            with zipfile.ZipFile(archive, 'r') as bundle:
                bundle.extractall(out_dir)
        except Exception as err:
            print('Zip extract failed:', err)
            return False
        else:
            print('Extracted to', out_dir)
            return True
        finally:
            try:
                archive.unlink(missing_ok=True)
            except Exception:
                pass
    print('Zip download failed')
    return False


def auto_download_data():
    """Make sure DATA_DIR holds ``.npz`` files, downloading them if necessary.

    Strategy:
      1. If any ``*.npz`` already exists in DATA_DIR, do nothing.
      2. If DATA_ARCHIVE_URL is set, try that ZIP archive first.
      3. Otherwise (or if the ZIP route yields no ``.npz``), fetch every name in
         NPZ_FILES from GitHub raw, counting files already on disk as done.

    Returns:
        True if data was already present or at least one file is available
        afterwards; False (after printing manual options) otherwise.
    """
    if any(DATA_DIR.glob('*.npz')):
        print('Data already present in', DATA_DIR)
        return True

    if DATA_ARCHIVE_URL:
        if download_zip_and_extract(DATA_ARCHIVE_URL, DATA_DIR) and any(DATA_DIR.glob('*.npz')):
            print('Data downloaded via ZIP archive.')
            return True
        print('ZIP archive route failed; falling back to GitHub raw files...')

    raw_root = f'https://raw.githubusercontent.com/{GITHUB_OWNER}/{GITHUB_REPO}/{GITHUB_BRANCH}/{GITHUB_FOLDER}'
    fetched = 0
    for fname in NPZ_FILES:
        target = DATA_DIR / fname
        if target.exists() or download_file(f'{raw_root}/{fname}', target):
            fetched += 1
            continue
        # Best-effort cleanup of a partially written file.
        try:
            if target.exists():
                target.unlink()
        except Exception:
            pass
    print(f'Downloaded {fetched}/{len(NPZ_FILES)} files to', DATA_DIR)
    if fetched:
        return True
    print('No files downloaded. You can:')
    print('  - Set DATA_ARCHIVE_URL to a valid ZIP containing .npz files')
    print('  - Upload your .txt files into', DATASETS_DIR, 'and run preprocessing')
    print('  - Or upload .npz files directly into', DATA_DIR)
    return False

# Trigger auto-download if data directory is empty
_ = auto_download_data()

In [None]:
# Config
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelConfig:
    """Architecture hyper-parameters for SIAT (field names mirror its constructor)."""
    obs_len: int = 8        # observed timesteps per agent
    pred_len: int = 12      # future timesteps to predict
    in_size: int = 2        # coordinates per timestep
    embed_size: int = 64
    enc_layers: int = 2     # Transformer encoder depth
    dec_layers: int = 1     # Transformer decoder depth
    nhead: int = 4          # attention heads (must divide embed_size)
    gcn_hidden: int = 64
    gcn_layers: int = 2
    dropout: float = 0.1


@dataclass
class TrainingConfig:
    """Optimisation settings consumed by the training loop."""
    epochs: int = 20
    batch_size: int = 32
    learning_rate: float = 1e-3
    weight_decay: float = 0.0
    grad_clip: Optional[float] = 1.0   # max grad-norm; None disables clipping
    device: str = 'auto'  # auto, cuda, mps, cpu


@dataclass
class DataConfig:
    """Location of the preprocessed .npz files and the window lengths to cut."""
    data_dir: str = str(DATA_DIR)
    obs_len: int = 8
    pred_len: int = 12


@dataclass
class Config:
    """Top-level config bundling model, training and data settings.

    ``data`` is authoritative for sequence lengths: on construction the model's
    ``obs_len``/``pred_len`` are overwritten with the data values. Later edits
    to ``data`` are not propagated automatically.
    """
    model: ModelConfig = field(default_factory=ModelConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    data: DataConfig = field(default_factory=DataConfig)

    def __post_init__(self):
        for attr in ('obs_len', 'pred_len'):
            setattr(self.model, attr, getattr(self.data, attr))


CFG = Config()
CFG

In [None]:
# Dataset and collate
import numpy as np
import torch
from torch.utils.data import Dataset

author_note = 'Target agent is at index 0 in each window.'


class TrajectoryDataset(Dataset):
    """In-memory dataset of (observation, future, scene-window) samples.

    Two ``.npz`` layouts are accepted:

    * ``observations`` / ``futures`` / ``windows`` arrays (as written by the
      preprocessing step): one sample per index, taken as-is.
    * ``trajectories`` of shape ``(N, T, 2)``: every agent ``i`` and every start
      offset yields one sample. Its window is
      ``trajs[:, start:start + obs_len + pred_len]``, reordered so that agent
      ``i`` is row 0 and the other agents keep their original order.

    Items are dicts of float32 tensors ``obs``, ``fut`` and ``window``. For the
    ``trajectories`` layout their shapes are (obs_len, 2), (pred_len, 2) and
    (N, obs_len + pred_len, 2). For the preprocessed layout the shapes are
    whatever is stored in the file.
    """

    def __init__(self, npz_files: list, obs_len: int = 8, pred_len: int = 12, transform=None):
        self.samples = []
        self.obs_len = obs_len
        self.pred_len = pred_len
        self.transform = transform
        for f in npz_files:
            # SECURITY NOTE: allow_pickle=True is needed for the object array of
            # ragged windows, but unpickling can execute arbitrary code. Only
            # load .npz files from sources you trust.
            # The NpzFile is used as a context manager so its file handle is
            # closed once the arrays have been copied out.
            with np.load(f, allow_pickle=True) as data:
                if 'observations' in data and 'futures' in data and 'windows' in data:
                    self._add_preprocessed(data['observations'], data['futures'], data['windows'])
                elif 'trajectories' in data:
                    self._add_from_trajectories(data['trajectories'])
                else:
                    raise ValueError(f'Unsupported .npz format in {f}.')

    def _add_preprocessed(self, observations, futures, windows):
        """Append samples from parallel observation/future/window arrays."""
        if not (len(observations) == len(futures) == len(windows)):
            raise ValueError(
                'observations/futures/windows length mismatch: '
                f'{len(observations)}/{len(futures)}/{len(windows)}')
        for obs, fut, window in zip(observations, futures, windows):
            self.samples.append((obs.astype(np.float32), fut.astype(np.float32), window.astype(np.float32)))

    def _add_from_trajectories(self, trajs):
        """Slide a window over aligned (N, T, 2) trajectories, one sample per agent/offset."""
        N, T, _ = trajs.shape
        total = self.obs_len + self.pred_len
        for i in range(N):
            # Agent i first, the others in their original order.
            order = [i] + [j for j in range(N) if j != i]
            for start in range(T - total + 1):
                obs = trajs[i, start:start + self.obs_len]
                fut = trajs[i, start + self.obs_len:start + total]
                window = trajs[order, start:start + total]
                self.samples.append((obs.astype(np.float32), fut.astype(np.float32), window.astype(np.float32)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        obs, fut, window = self.samples[idx]
        if self.transform:
            obs, fut, window = self.transform(obs, fut, window)
        return {
            'obs': torch.from_numpy(obs),
            'fut': torch.from_numpy(fut),
            'window': torch.from_numpy(window)
        }


def collate_fn(batch):
    """Stack a list of dataset items into one padded batch.

    ``obs`` and ``fut`` are stacked directly, so they must share a shape across
    items. ``window`` tensors may hold different numbers of agents. They are
    zero-padded along the agent axis (dim 0) to the largest count in the batch,
    keeping their trailing (time, feature) dims instead of assuming 2-D
    coordinates, and cast to torch's default float dtype.

    Returns:
        dict with ``obs`` (B, ...), ``fut`` (B, ...), ``window``
        (B, max_agents, T, F) and ``agent_mask`` (B, max_agents) bool, where
        ``agent_mask[b, a]`` is True for real (non-padded) agents.
    """
    from torch.nn.utils.rnn import pad_sequence

    obs_batch = torch.stack([item['obs'] for item in batch])
    fut_batch = torch.stack([item['fut'] for item in batch])
    windows = [item['window'] for item in batch]
    # pad_sequence pads dim 0 (agents); trailing dims must match across items.
    window_batch = pad_sequence(windows, batch_first=True, padding_value=0.0).to(torch.get_default_dtype())
    counts = torch.tensor([w.size(0) for w in windows])
    agent_masks = torch.arange(window_batch.size(1)).unsqueeze(0) < counts.unsqueeze(1)
    return {'obs': obs_batch, 'fut': fut_batch, 'window': window_batch, 'agent_mask': agent_masks}

In [None]:
# Model: GCN layer
import torch
import torch.nn as nn
import torch.nn.functional as F

class GCNLayer(nn.Module):
    def __init__(self, in_feats: int, out_feats: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(in_feats, out_feats, bias=bias)
    def forward(self, X: torch.Tensor, A_norm: torch.Tensor) -> torch.Tensor:
        H = torch.bmm(A_norm, X)
        H = self.linear(H)
        return F.relu(H)

# Model: SIAT
class SIAT(nn.Module):
    def __init__(self, obs_len=8, pred_len=12, in_size=2, embed_size=64, enc_layers=2, dec_layers=1,
                 nhead=4, gcn_hidden=64, gcn_layers=2, dropout=0.1):
        super().__init__()
        self.obs_len = obs_len
        self.pred_len = pred_len
        self.in_size = in_size
        self.embed_size = embed_size
        seq_len = obs_len + pred_len
        self.flatten_dim = seq_len * in_size
        self.embedding = nn.Linear(self.flatten_dim, embed_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_size, nhead=nhead,
                                                   dim_feedforward=embed_size*2, dropout=dropout, activation='relu')
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=enc_layers)
        gcn_modules = []
        gcn_in = embed_size
        for _ in range(gcn_layers):
            gcn_modules.append(GCNLayer(gcn_in, gcn_hidden))
            gcn_in = gcn_hidden
        self.gcn = nn.ModuleList(gcn_modules)
        self.fuse_trans = nn.Linear(embed_size, embed_size)
        self.fuse_gcn = nn.Linear(gcn_hidden, embed_size)
        self.lambda1 = nn.Parameter(torch.tensor(0.5))
        self.lambda2 = nn.Parameter(torch.tensor(0.5))
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=nhead,
                                                   dim_feedforward=embed_size*2, dropout=dropout, activation='relu')
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=dec_layers)
        self.reg_head = nn.Linear(embed_size, pred_len*2)
        self._reset_parameters()
    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def compute_adjacency(self, positions: torch.Tensor, sigma: float = 1.5, eps: float = 1e-6) -> torch.Tensor:
        B, N, _ = positions.shape
        pos_expand1 = positions.unsqueeze(2)
        pos_expand2 = positions.unsqueeze(1)
        diff = pos_expand1 - pos_expand2
        dist2 = (diff**2).sum(-1)
        A = torch.exp(-dist2 / (sigma**2 + eps))
        deg = A.sum(-1)
        deg_inv_sqrt = (deg + eps).pow(-0.5)
        D_inv_sqrt = deg_inv_sqrt.unsqueeze(-1) * deg_inv_sqrt.unsqueeze(-2)
        A_norm = A * D_inv_sqrt
        return A_norm
    def forward(self, obs: torch.Tensor, full_window: torch.Tensor, agent_mask: torch.Tensor | None = None) -> torch.Tensor:
        B = obs.size(0)
        N = full_window.size(1)
        agent_flat = full_window.view(B, N, -1)
        agent_emb = self.embedding(agent_flat)
        if agent_mask is not None:
            agent_emb = agent_emb * agent_mask.unsqueeze(-1).float()
        agent_emb_t = agent_emb.permute(1, 0, 2)
        src_key_padding_mask = (~agent_mask) if agent_mask is not None else None
        trans_enc_out = self.transformer_encoder(agent_emb_t, src_key_padding_mask=src_key_padding_mask)
        trans_enc_out = trans_enc_out.permute(1, 0, 2)
        last_pos = full_window[:, :, self.obs_len - 1, :]
        A_norm = self.compute_adjacency(last_pos)
        if agent_mask is not None:
            mask_matrix = agent_mask.unsqueeze(-1) & agent_mask.unsqueeze(-2)
            A_norm = A_norm * mask_matrix.float()
        H = agent_emb
        for layer in self.gcn:
            H = layer(H, A_norm)
        H_trans_proj = self.fuse_trans(trans_enc_out)
        H_gcn_proj = self.fuse_gcn(H)
        H_fused = self.lambda1 * H_trans_proj + self.lambda2 * H_gcn_proj
        target_feat = H_fused[:, 0, :].unsqueeze(0)
        query = target_feat.repeat(self.pred_len, 1, 1)
        memory = H_fused.permute(1, 0, 2)
        memory_key_padding_mask = (~agent_mask) if agent_mask is not None else None
        dec_out = self.transformer_decoder(tgt=query, memory=memory, memory_key_padding_mask=memory_key_padding_mask)
        dec_out_mean = dec_out.permute(1, 0, 2).mean(dim=1)
        reg = self.reg_head(dec_out_mean)
        pred = reg.view(B, self.pred_len, 2)
        return pred

In [None]:
# Metrics and training utils
from typing import Tuple
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm


def ade_fde(pred: torch.Tensor, gt: torch.Tensor) -> Tuple[float, float]:
    """Average and final displacement errors between predicted and true tracks.

    Per-step errors are L2 distances over the last dim. FDE uses the last
    entry along dim 1, so inputs are assumed to be shaped like
    (batch, pred_len, 2).

    Returns:
        (ADE, FDE) as Python floats. ADE is the mean error over all
        (sample, step) pairs; FDE is the mean error at the final step.
    """
    dists = (pred - gt).norm(dim=-1)
    final_step = dists[:, -1]
    return dists.mean().item(), final_step.mean().item()


def train_one_epoch(model, optimizer, loader: DataLoader, device: torch.device, clip: float | None = 1.0) -> float:
    """Run one optimisation pass over ``loader`` using an MSE objective.

    Gradients are clipped to norm ``clip`` unless ``clip`` is None.

    Returns:
        The sample-weighted mean training loss over ``loader.dataset``.
    """
    model.train()
    running = 0.0
    progress = tqdm(loader, desc='Training', leave=False)
    for batch in progress:
        moved = {key: batch[key].to(device) for key in ('obs', 'fut', 'window', 'agent_mask')}
        optimizer.zero_grad()
        prediction = model(moved['obs'], moved['window'], moved['agent_mask'])
        loss = F.mse_loss(prediction, moved['fut'])
        loss.backward()
        if clip is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        step_loss = loss.item()
        running += step_loss * moved['obs'].size(0)
        progress.set_postfix({'loss': f'{step_loss:.6f}'})
    return running / len(loader.dataset)


def evaluate(model, loader: DataLoader, device: torch.device) -> tuple[float, float]:
    """Compute sample-weighted mean ADE/FDE of ``model`` over ``loader``.

    Runs in eval mode without gradients.

    Returns:
        (ADE, FDE). If ``loader.dataset`` is empty, returns (nan, nan) instead
        of raising ZeroDivisionError. An empty set happens, for example, when
        an 80/20 split of a single sample leaves no validation data.
    """
    model.eval()
    n = len(loader.dataset)
    if n == 0:
        return float('nan'), float('nan')
    total_ade = 0.0
    total_fde = 0.0
    with torch.no_grad():
        pbar = tqdm(loader, desc='Evaluating', leave=False)
        for batch in pbar:
            obs = batch['obs'].to(device)
            fut = batch['fut'].to(device)
            window = batch['window'].to(device)
            agent_mask = batch['agent_mask'].to(device)
            pred = model(obs, window, agent_mask)
            ade, fde = ade_fde(pred, fut)
            # ade_fde returns batch means; re-weight by batch size.
            total_ade += ade * obs.size(0)
            total_fde += fde * obs.size(0)
            pbar.set_postfix({'ADE': f'{ade:.4f}', 'FDE': f'{fde:.4f}'})
    return total_ade / n, total_fde / n

In [None]:
# Preprocessing: ETH/UCY -> NPZ
import pandas as pd
from pathlib import Path


def load_eth_ucy_file(file_path: str | Path):
    data = pd.read_csv(file_path, sep='\t', header=None, names=['frame', 'ped_id', 'x', 'y'])
    trajectories = {}
    for ped_id, group in data.groupby('ped_id'):
        group_sorted = group.sort_values('frame')
        coords = group_sorted[['x', 'y']].values
        trajectories[int(ped_id)] = coords.astype(np.float32)
    return trajectories


def create_sliding_windows(trajectories, obs_len=8, pred_len=12, min_agents=2):
    """Cut per-pedestrian tracks into target-centred scene windows.

    Only tracks with at least ``obs_len + pred_len`` samples take part. A
    window spans ``[s, s + obs_len + pred_len)``. For every offset ``s`` where
    at least ``min_agents`` of those tracks have a full window, one sample is
    emitted per participating track, with that track moved to row 0 (the
    others keep their relative order).

    NOTE(review): ``s`` indexes each track's own sample sequence, not an
    absolute frame. If the input arrays start at different frames, agents
    grouped into one window were not necessarily present at the same time.
    Confirm this is intended; aligning by frame would need frame numbers,
    which are not passed in.

    Returns:
        list of ``(target_obs, target_fut, window)`` tuples. For (T_i, D)
        inputs their shapes are (obs_len, D), (pred_len, D) and
        (n_agents, obs_len + pred_len, D).
    """
    span = obs_len + pred_len
    if len(trajectories) < min_agents:
        return []
    eligible = [traj for traj in trajectories.values() if len(traj) >= span]
    if len(eligible) < min_agents:
        return []
    samples = []
    longest = max(len(traj) for traj in eligible)
    for s in range(longest - span + 1):
        stop = s + span
        clips = [traj[s:stop] for traj in eligible if len(traj) >= stop]
        if len(clips) < min_agents:
            continue
        scene = np.array(clips)
        for k in range(len(clips)):
            window = np.concatenate([scene[k:k + 1], scene[:k], scene[k + 1:]], axis=0)
            samples.append((scene[k, :obs_len], scene[k, obs_len:], window))
    return samples


def preprocess_scene(input_file: str | Path, output_file: str | Path, obs_len=8, pred_len=12):
    """Convert one ETH/UCY text file into a compressed ``.npz`` of samples.

    ``observations`` and ``futures`` are written as stacked arrays. ``windows``
    is written as an object array because agent counts may differ between
    windows. The output directory is created if needed.

    Returns:
        True if a file was written; False if the input had no pedestrians or
        produced no windows.
    """
    trajectories = load_eth_ucy_file(input_file)
    if not trajectories:
        return False
    samples = create_sliding_windows(trajectories, obs_len, pred_len)
    if not samples:
        return False
    observations = [sample[0] for sample in samples]
    futures = [sample[1] for sample in samples]
    windows = [sample[2] for sample in samples]
    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        destination,
        observations=np.array(observations),
        futures=np.array(futures),
        windows=np.array(windows, dtype=object),
    )
    return True

In [None]:
# Training/Evaluation runners
import glob, time
from torch.utils.data import DataLoader, random_split


def setup_device(arg='auto'):
    """Resolve a device spec string to a ``torch.device``.

    For ``'auto'``, picks CUDA if available, then Apple MPS if this PyTorch
    build exposes it and reports it available, otherwise CPU. Any other value
    is passed straight to ``torch.device``.
    """
    if arg != 'auto':
        return torch.device(arg)
    if torch.cuda.is_available():
        return torch.device('cuda')
    try:
        mps_ok = hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
        mps_device = torch.device('mps') if mps_ok else None
    except Exception:
        mps_device = None
    return mps_device if mps_device is not None else torch.device('cpu')


def create_datasets(data_dir: Path, obs_len: int, pred_len: int, batch_size: int, num_workers: int = 2):
    """Build seeded 80/20 train/validation DataLoaders from every ``.npz`` in ``data_dir``.

    Files are loaded in sorted order so that the seeded ``random_split``
    (seed 42) yields the same split on every run; raw glob order depends on
    the filesystem. ``data_dir`` may be a ``Path`` or a ``str``.

    NOTE(review): all ``*.npz`` files are pooled and split at the sample level,
    regardless of any train/val/test naming in the file names. Confirm this is
    intended.

    Raises:
        RuntimeError: if ``data_dir`` has no ``.npz`` files, or they yield no samples.
    """
    data_dir = Path(data_dir)
    npz_files = sorted(glob.glob(str(data_dir / '*.npz')))
    if len(npz_files) == 0:
        raise RuntimeError(f'No .npz files found in {data_dir}. Place data or run preprocessing.')
    full_dataset = TrajectoryDataset(npz_files, obs_len=obs_len, pred_len=pred_len)
    total = len(full_dataset)
    if total == 0:
        raise RuntimeError(f'The .npz files in {data_dir} produced no samples.')
    train_size = max(1, int(0.8 * total))
    val_size = total - train_size
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    use_pin = torch.cuda.is_available()
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn, pin_memory=use_pin)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn, pin_memory=use_pin)
    return train_loader, val_loader


def train_model(cfg: Config):
    """Train SIAT with ``cfg`` and keep the checkpoint with the best validation ADE.

    Every ``cfg.model`` field is forwarded to the ``SIAT`` constructor (the
    field names match its arguments one-to-one). The same dict is stored in
    the checkpoint under ``'model_config'`` so the exact architecture can be
    rebuilt at evaluation time. The legacy ``obs_len``/``pred_len``/
    ``embed_size`` keys are kept for older readers.

    Returns:
        ``CKPT_DIR / 'best_model.pth'`` if any epoch improved on the best ADE,
        else None.
    """
    from dataclasses import asdict

    device = setup_device(cfg.training.device)
    print('Device:', device)
    train_loader, val_loader = create_datasets(Path(cfg.data.data_dir), cfg.model.obs_len, cfg.model.pred_len, cfg.training.batch_size)
    model_kwargs = asdict(cfg.model)
    model = SIAT(**model_kwargs).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.learning_rate, weight_decay=cfg.training.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
    best_ade = float('inf')
    best_path = None
    for epoch in range(1, cfg.training.epochs + 1):
        t0 = time.time()
        train_loss = train_one_epoch(model, optimizer, train_loader, device, cfg.training.grad_clip)
        val_ade, val_fde = evaluate(model, val_loader, device)
        scheduler.step(val_ade)
        dt = time.time() - t0
        print(f'Epoch {epoch:03d}/{cfg.training.epochs} | Loss {train_loss:.6f} | ADE {val_ade:.4f} | FDE {val_fde:.4f} | time {dt:.1f}s')
        if val_ade < best_ade:
            best_ade = val_ade
            best_path = CKPT_DIR / 'best_model.pth'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'best_ade': best_ade,
                'obs_len': cfg.model.obs_len,
                'pred_len': cfg.model.pred_len,
                'embed_size': cfg.model.embed_size,
                'model_config': model_kwargs,
            }, best_path)
            print('Saved best checkpoint:', best_path)
    return best_path


def evaluate_model(checkpoint_path: Path, data_dir: Path, batch_size: int = 32, device_arg: str = 'auto', visualize: bool = True, num_samples: int = 10):
    """Load a SIAT checkpoint and report ADE/FDE on every ``.npz`` in ``data_dir``.

    The architecture is rebuilt from the checkpoint's ``'model_config'`` entry
    when present, so non-default layer counts, heads, dropout, etc. round-trip.
    Checkpoints without it fall back to their ``obs_len``/``pred_len``/
    ``embed_size`` keys, with SIAT defaults for everything else. Files are read
    in sorted order for a deterministic sample order.

    When ``visualize`` is True, up to ``num_samples`` random trajectories are
    plotted to ``RESULTS_DIR``. Plotting errors are printed and skipped.

    NOTE(review): no split is applied here. If ``data_dir`` is the directory
    used for training, the metrics include training samples.

    Returns:
        dict with ``'ade'``, ``'fde'`` (floats) and ``'num_samples'`` (int).
    """
    device = setup_device(device_arg)
    ckpt = torch.load(checkpoint_path, map_location=device)
    model_cfg = {
        'obs_len': int(ckpt.get('obs_len', 8)),
        'pred_len': int(ckpt.get('pred_len', 12)),
        'embed_size': int(ckpt.get('embed_size', 64)),
    }
    # Newer checkpoints store every SIAT constructor argument; prefer those.
    model_cfg.update(ckpt.get('model_config') or {})
    obs_len = int(model_cfg['obs_len'])
    pred_len = int(model_cfg['pred_len'])
    model = SIAT(**model_cfg).to(device)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    npz_files = sorted(glob.glob(str(Path(data_dir) / '*.npz')))
    dataset = TrajectoryDataset(npz_files, obs_len=obs_len, pred_len=pred_len)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=2)
    all_pred, all_gt, all_obs = [], [], []
    total_ade = 0.0
    total_fde = 0.0
    n = 0
    with torch.no_grad():
        for batch in loader:
            obs = batch['obs'].to(device)
            fut = batch['fut'].to(device)
            window = batch['window'].to(device)
            agent_mask = batch['agent_mask'].to(device)
            pred = model(obs, window, agent_mask)
            ade, fde = ade_fde(pred, fut)
            b = obs.size(0)
            total_ade += ade * b
            total_fde += fde * b
            n += b
            all_pred.append(pred.cpu().numpy())
            all_gt.append(fut.cpu().numpy())
            all_obs.append(obs.cpu().numpy())
    final_ade = total_ade / max(1, n)
    final_fde = total_fde / max(1, n)
    print('Evaluation: ADE=', final_ade, 'FDE=', final_fde, 'Samples=', n)
    preds = np.concatenate(all_pred, axis=0) if all_pred else np.empty((0, pred_len, 2))
    gts = np.concatenate(all_gt, axis=0) if all_gt else np.empty((0, pred_len, 2))
    obss = np.concatenate(all_obs, axis=0) if all_obs else np.empty((0, obs_len, 2))
    if visualize and len(preds) > 0:
        try:
            import matplotlib.pyplot as plt
            os.makedirs(RESULTS_DIR, exist_ok=True)
            idxs = np.random.choice(len(preds), min(num_samples, len(preds)), replace=False)
            for i, idx in enumerate(idxs):
                plt.figure(figsize=(7,6))
                obs_i, pred_i, gt_i = obss[idx], preds[idx], gts[idx]
                plt.plot(obs_i[:,0], obs_i[:,1], 'b-o', label='Observed')
                plt.plot(pred_i[:,0], pred_i[:,1], 'r-s', label='Predicted')
                plt.plot(gt_i[:,0], gt_i[:,1], 'g-^', label='Ground Truth')
                plt.legend(); plt.grid(True, alpha=0.3); plt.axis('equal')
                plt.title(f'Trajectory sample {idx}')
                out = RESULTS_DIR / f'trajectory_{idx}.png'
                plt.savefig(out, dpi=150, bbox_inches='tight'); plt.close()
            print('Saved visualizations to', RESULTS_DIR)
        except Exception as e:
            print('Visualization skipped:', e)
    return {'ade': float(final_ade), 'fde': float(final_fde), 'num_samples': int(n)}

In [None]:
# Helper: Preprocess all ETH/UCY .txt under DATASETS_DIR
from typing import List

def preprocess_all(input_dir: Path = DATASETS_DIR, output_dir: Path = DATA_DIR, obs_len=8, pred_len=12) -> int:
    """Convert every ``.txt`` under ``input_dir`` (recursively) into an ``.npz``.

    Output names flatten the relative path with underscores, e.g.
    ``eth/train/biwi_hotel_train.txt`` -> ``eth_train_biwi_hotel_train.npz``.
    Per-file exceptions are printed and that file is skipped.

    Returns:
        Number of files for which ``preprocess_scene`` reported success.
    """
    txt_files: List[Path] = list(input_dir.rglob('*.txt'))
    print('Found', len(txt_files), '.txt files')
    converted = 0
    for src in txt_files:
        flat_name = src.relative_to(input_dir).as_posix().replace('/', '_')
        dst = output_dir / f'{Path(flat_name).stem}.npz'
        try:
            written = preprocess_scene(src, dst, obs_len, pred_len)
        except Exception as e:
            print('Failed:', src, e)
            continue
        if written:
            converted += 1
    print('Preprocessed OK:', converted)
    return converted

# Smoke-test fallback: if DATA_DIR has no .npz at all, synthesize a tiny
# random-walk scene (5 agents x obs_len+pred_len steps) in the 'trajectories'
# layout read by TrajectoryDataset.
# NOTE(review): the file stays in DATA_DIR (on Drive when mounted). Later runs
# will therefore see an .npz and skip auto-download, and the file is pooled
# with any real data. Delete dummy_data.npz once real data is available.
if not any(DATA_DIR.glob('*.npz')):
    print('No .npz found, creating dummy_data.npz for a smoke test...')
    N = 5
    T = CFG.model.obs_len + CFG.model.pred_len
    steps = np.random.randn(N, T, 2).astype(np.float32) * 0.2
    trajs = np.cumsum(steps, axis=1)
    dummy_path = DATA_DIR / 'dummy_data.npz'
    np.savez(dummy_path, trajectories=trajs)
    print('Dummy data saved to', dummy_path)

In [None]:
# Train (set epochs higher on Colab GPU)
# 3 epochs is only a quick smoke run; raise this to 30+ for real training.
CFG.training.epochs = 3
best_ckpt = train_model(CFG)
best_ckpt  # cell output: the best checkpoint path, or None if no epoch improved

In [None]:
# Evaluate best checkpoint
if best_ckpt is not None:
    results = evaluate_model(best_ckpt, DATA_DIR, batch_size=32, device_arg=CFG.training.device, visualize=True, num_samples=6)
    # Jupyter echoes only a cell's final top-level expression, so a bare
    # `results` inside this `if` would display nothing; print it instead.
    print(results)
else:
    print('No checkpoint to evaluate.')