In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm

# Assuming your model code is in ../models/
import sys, os
sys.path.append("..")

from models.track_classifier import TrackClassifier


In [None]:
class DummyTrackDataset(Dataset):
    """
    Generates random hit-level data grouped by track.

    Each track has between ``min_hits`` and ``max_hits`` hits (both inclusive,
    5–15 by default); each hit has ``hit_input_dim`` features. Each track also
    has ``track_feat_dim`` features and a random binary label.

    Args:
        n_tracks: number of tracks to generate.
        hit_input_dim: number of features per hit.
        track_feat_dim: number of track-level features.
        max_hits: maximum number of hits per track (inclusive).
        min_hits: minimum number of hits per track (inclusive).
        seed: optional seed for a private RNG. ``None`` draws from NumPy's
            global RNG, so ``np.random.seed`` still controls the output.

    Raises:
        ValueError: if ``min_hits < 0`` or ``max_hits < min_hits``.
    """
    def __init__(self, n_tracks=2000, hit_input_dim=8, track_feat_dim=4, max_hits=15,
                 min_hits=5, seed=None):
        super().__init__()
        if min_hits < 0 or max_hits < min_hits:
            raise ValueError(
                f"need 0 <= min_hits <= max_hits, got min_hits={min_hits}, max_hits={max_hits}"
            )
        self.hit_input_dim = hit_input_dim
        self.track_feat_dim = track_feat_dim

        # The np.random module and RandomState expose the same randint/randn API.
        rng = np.random if seed is None else np.random.RandomState(seed)

        self.tracks = []
        self.hits = []
        self.batch_idx = []
        self.labels = []

        for track_id in range(n_tracks):
            # randint's upper bound is exclusive; +1 makes max_hits reachable.
            n_hits = rng.randint(min_hits, max_hits + 1)
            hits = rng.randn(n_hits, hit_input_dim).astype(np.float32)
            track_feats = rng.randn(track_feat_dim).astype(np.float32)
            label = rng.randint(0, 2)

            self.hits.append(hits)
            self.tracks.append(track_feats)
            self.labels.append(label)
            self.batch_idx.append(np.full(n_hits, track_id, dtype=np.int64))

        # Flatten for convenience. Hits of track i occupy rows
        # offsets[i]:offsets[i + 1] of all_hits.
        n_per_track = np.array([len(h) for h in self.hits], dtype=np.int64)
        self.offsets = np.concatenate(([0], np.cumsum(n_per_track)))
        # Explicit empty arrays keep n_tracks=0 from crashing concatenate/stack.
        self.all_hits = (np.concatenate(self.hits, axis=0) if self.hits
                         else np.empty((0, hit_input_dim), dtype=np.float32))
        self.all_batch_idx = np.repeat(np.arange(n_tracks, dtype=np.int64), n_per_track)
        self.all_tracks = (np.stack(self.tracks, axis=0) if self.tracks
                           else np.empty((0, track_feat_dim), dtype=np.float32))
        self.all_labels = np.array(self.labels, dtype=np.float32)

    def __len__(self):
        return len(self.tracks)

    def __getitem__(self, idx):
        """Return ``(hits, track_features, idx, label)`` for one track.

        ``hits`` has shape ``(n_hits, hit_input_dim)``. The hits are sliced via
        precomputed offsets, which costs O(n_hits) rather than a mask over
        every hit in the dataset. Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"track index out of range for dataset of size {n}")
        start, end = self.offsets[idx], self.offsets[idx + 1]
        return (
            torch.tensor(self.all_hits[start:end]),
            torch.tensor(self.all_tracks[idx]),
            torch.tensor(idx),
            torch.tensor(self.all_labels[idx]),
        )

In [None]:
def collate_fn(batch):
    """Merge per-track samples into one flat, variable-length batch.

    Each sample is ``(hits, track_features, idx, label)``. All hits are
    concatenated along dim 0; ``batch_indices[k]`` is the position *within
    this batch* (not the dataset index) of the track that hit ``k`` belongs to.

    Returns:
        ``(hit_features, track_features, batch_indices, labels)`` with labels
        cast to float.
    """
    hit_list, track_list, _dataset_ids, label_list = zip(*batch)

    hits_per_track = torch.tensor([len(h) for h in hit_list], dtype=torch.long)
    batch_indices = torch.repeat_interleave(torch.arange(len(hit_list)), hits_per_track)

    return (
        torch.cat(hit_list, dim=0),
        torch.stack(track_list, dim=0),
        batch_indices,
        torch.stack(label_list).float(),
    )

In [None]:
# Seed every RNG this notebook relies on so "Restart & Run All" reproduces the
# generated data, the DataLoader shuffling order and, since later cells run
# after this one, the model's weight initialisation.
SEED = 0
N_TRACKS = 500
BATCH_SIZE = 32

np.random.seed(SEED)     # DummyTrackDataset draws from NumPy's global RNG
torch.manual_seed(SEED)  # DataLoader shuffling and torch parameter init

train_dataset = DummyTrackDataset(n_tracks=N_TRACKS)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Read the input widths from the dataset so the model always matches the
# generated data instead of repeating the dataset's defaults as literals.
model = TrackClassifier(
    hit_input_dim=train_dataset.hit_input_dim,
    track_feat_dim=train_dataset.track_feat_dim,
    latent_dim=16,
    pooling_type="softmax",
    netA_hidden_dim=32,
    netA_hidden_layers=2,
    netB_hidden_dim=64,
    netB_hidden_layers=2
).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]. If TrackClassifier
# returns raw logits, switch to nn.BCEWithLogitsLoss (numerically more
# stable) — TODO confirm against models/track_classifier.py.
criterion = nn.BCELoss()

In [None]:
# Train for a fixed number of epochs and report the sample-weighted mean loss.
model.train()
for epoch in range(5):
    running_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    for batch in progress:
        # Move every tensor of the collated batch onto the training device.
        hit_features, track_features, batch_indices, labels = (t.to(device) for t in batch)

        optimizer.zero_grad()
        preds = model(hit_features, track_features, batch_indices)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        # criterion uses the default mean reduction, so weight by the batch
        # size to keep the epoch average exact when the last batch is smaller.
        running_loss += loss.item() * labels.shape[0]

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch + 1}: loss = {epoch_loss:.4f}")

In [None]:
# Quick sanity check on one shuffled *training* batch (not held-out data):
# print the first few predictions next to their labels, without gradients.
model.eval()
with torch.no_grad():
    *model_inputs, labels = next(iter(train_loader))
    model_inputs = [t.to(device) for t in model_inputs]
    preds = model(*model_inputs)
    print("Predictions:", preds[:10].cpu().numpy())
    # labels were never moved off the CPU, so .numpy() works directly.
    print("Labels:", labels[:10].numpy())