In [86]:
import torch
import joblib
import numpy as np
from torch import nn
import torch.optim as optim
from pathlib import Path
from data.data_loader import DroneGraphDataset
from models.pretrained_model_loader import load_pretrained_traj_model, extract_context_embeddings
from models.tgn import DroneRelationModel

In [87]:
dataset = DroneGraphDataset(
    trajectory_csv='data/drone_states.csv',
    relationship_csv='data/drone_relations.csv',
    lookback=50,
    device='cuda'
)

sample = dataset[0]
print(sample['context_window'].shape)    # [50, num_drones, 4]
print(sample['current_features'].shape)  # [num_drones, 4]
print(sample['relationships'].shape)     # [num_pairs, 2]
print(sample['labels'].shape)            # [num_pairs]


torch.Size([50, 6, 4])
torch.Size([6, 4])
torch.Size([9, 2])
torch.Size([9])


In [88]:
# Example usage
experiment_dir = Path("experiments/20251015_134311")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model + config
model, config = load_pretrained_traj_model(experiment_dir, device)

# Load scalers
scaler_X = joblib.load(experiment_dir / "scaler_X.pkl")

# Dummy trajectory data (replace with your actual drone trajectory segment)
dummy_data = np.random.rand(100, 6 * 3).astype(np.float32)  # 100 timesteps, 6 drones, xyz

# Extract embeddings
context_emb = extract_context_embeddings(
    model,
    traj_data=dummy_data,
    scaler_X=scaler_X,
    lookback=config["LOOK_BACK"],
    features_per_agent=3,
    device=device,
)

print("Context embeddings shape:", context_emb.shape)

Context embeddings shape: torch.Size([6, 128])


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


In [89]:
from torch.utils.data import Subset

flight_ids = dataset.flights
num_train = int(0.8 * len(flight_ids))
train_flights = flight_ids[:num_train]
test_flights = flight_ids[num_train:]

train_indices = [i for i, (fid, _) in enumerate(dataset.valid_indices) if fid in train_flights]
test_indices = [i for i, (fid, _) in enumerate(dataset.valid_indices) if fid in test_flights]

train_ds = Subset(dataset, train_indices)
test_ds = Subset(dataset, test_indices)

In [None]:
def train_epoch(model, loader, optimizer, criterion, scaler_X, pretrained_model, config):
    model.train()
    total_loss = 0
    for batch in loader:
        context_window = batch["context_window"].squeeze(0).cpu().numpy()  # [50, num_drones, 4]
        current_features = batch["current_features"].squeeze(0).to(model.device)
        print(f"current_features_before_squeeze: {current_features.shape}")
        relationships = batch["relationships"].squeeze(0).to(model.device)
        labels = batch["labels"].squeeze(0).float().to(model.device)

        # Extract context embeddings from pretrained model
        context_emb = extract_context_embeddings(
            pretrained_model,
            traj_data=context_window[:, :, :3].reshape(50, -1),
            scaler_X=scaler_X,
            lookback=config["LOOK_BACK"],
            features_per_agent=3,
            device=model.device,
        )

        preds = model(current_features, context_emb, relationships)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(loader)

In [None]:
@torch.no_grad()
def evaluate(model, loader, scaler_X, pretrained_model, config):
    model.eval()
    preds_all, labels_all = [], []
    for batch in loader:
        context_window = batch["context_window"].squeeze(0).cpu().numpy()
        current_features = batch["current_features"].squeeze(0).to(model.device)
        relationships = batch["relationships"].squeeze(0).to(model.device)
        labels = batch["labels"].squeeze(0).float().to(model.device)

        context_emb = extract_context_embeddings(
            pretrained_model,
            traj_data=context_window[:, :, :3].reshape(50, -1),
            scaler_X=scaler_X,
            lookback=config["LOOK_BACK"],
            features_per_agent=3,
            device=model.device,
        )

        preds = model(current_features, context_emb, relationships)
        preds_all.append(preds)
        labels_all.append(labels)

    preds_all = torch.cat(preds_all)
    labels_all = torch.cat(labels_all)
    acc = ((preds_all > 0.5) == labels_all).float().mean().item()
    return acc

In [92]:
# Load pretrained trajectory model
model_traj, config = load_pretrained_traj_model(experiment_dir, device)
scaler_X = joblib.load(experiment_dir / "scaler_X.pkl")

# Initialize new relation model
relation_model = DroneRelationModel(context_dim=model_traj.enc_hidden_size * 2, device=device).to(device)

optimizer = optim.Adam(relation_model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

# Training loop
for epoch in range(10):
    loss = train_epoch(relation_model, train_ds, optimizer, criterion, scaler_X, model_traj, config)
    acc = evaluate(relation_model, test_ds, scaler_X, model_traj, config)
    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Test Acc={acc:.3f}")


current_features_before_squeeze: torch.Size([6, 4])


  model.load_state_dict(torch.load(MODEL_PATH, map_location=device))


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.