In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm

# Assuming your model code is in ../models/
import sys, os
sys.path.append("..")

from models.track_classifier import TrackClassifier


In [4]:
# Fix the random state before any stochastic step so that Restart & Run All
# reproduces the same shuffle order and model initialisation.
# NOTE(review): DummyTrackDataset is defined in a cell not shown here; these
# seeds only make its contents reproducible if it draws from NumPy's or
# torch's global RNG -- confirm.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Dataset of 500 tracks built by DummyTrackDataset.
train_dataset = DummyTrackDataset(n_tracks=500)

# Mini-batches of 32 tracks, reshuffled on every pass. A dedicated seeded
# generator makes the shuffle order independent of other torch RNG use.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

ERROR! Session/line number was not unique in database. History logging moved to new session 2


In [5]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hyperparameters forwarded verbatim to TrackClassifier; their exact meaning
# is defined in models/track_classifier.py, which is not visible from here.
classifier_kwargs = dict(
    hit_input_dim=8,
    track_feat_dim=4,
    latent_dim=16,
    pooling_type="sum",
    netA_hidden_dim=32,
    netA_hidden_layers=2,
    netB_hidden_dim=64,
    netB_hidden_layers=2,
)
model = TrackClassifier(**classifier_kwargs)
model = model.to(device)

# Binary cross-entropy on probabilities.
# NOTE(review): BCELoss requires inputs in [0, 1], so the model presumably
# ends in a sigmoid -- confirm against TrackClassifier.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [6]:
# Train for five passes over train_loader and report the per-sample mean loss
# after each pass.
model.train()
for epoch in range(5):
    total_loss = 0
    epoch_batches = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch in epoch_batches:
        # Each batch unpacks into four tensors; move all of them to `device`.
        hit_features, track_features, batch_indices, labels = (
            tensor.to(device) for tensor in batch
        )

        # One optimisation step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        preds = model(hit_features, track_features, batch_indices)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the number of labels so that dividing by
        # the dataset size below gives a per-sample average (assumes
        # `criterion` returns a batch mean, as nn.BCELoss() does by default).
        total_loss += loss.item() * len(labels)

    avg_loss = total_loss / len(train_dataset)
    print(f"Epoch {epoch+1}: loss = {avg_loss:.4f}")

Epoch 1: 100%|██████████| 16/16 [00:00<00:00, 30.97it/s]


Epoch 1: loss = 0.7183


Epoch 2: 100%|██████████| 16/16 [00:00<00:00, 155.81it/s]


Epoch 2: loss = 0.6792


Epoch 3: 100%|██████████| 16/16 [00:00<00:00, 152.84it/s]


Epoch 3: loss = 0.6604


Epoch 4: 100%|██████████| 16/16 [00:00<00:00, 165.28it/s]


Epoch 4: loss = 0.6484


Epoch 5: 100%|██████████| 16/16 [00:00<00:00, 155.14it/s]

Epoch 5: loss = 0.6277





In [7]:
# Quick sanity check: score one batch with gradients disabled and print the
# first ten predictions next to their labels.
# Note: the batch comes from train_loader, so this is training data, not a
# held-out set.
model.eval()
with torch.no_grad():
    sample_batch = next(iter(train_loader))
    hit_features, track_features, batch_indices, labels = sample_batch

    # Only the model inputs go to `device`; `labels` is left as loaded and is
    # converted straight to NumPy below.
    hit_features = hit_features.to(device)
    track_features = track_features.to(device)
    batch_indices = batch_indices.to(device)

    preds = model(hit_features, track_features, batch_indices)
    print("Predictions:", preds[:10].cpu().numpy())
    print("Labels:", labels[:10].numpy())

Predictions: [0.2575667  0.65084094 0.6044845  0.34727234 0.67349774 0.48047867
 0.30143496 0.51510495 0.4472016  0.54493636]
Labels: [0. 1. 1. 0. 1. 1. 0. 1. 0. 1.]


In [8]:
# Draw one fresh batch from the loader for the shape inspections in the
# following cells; no .to(device) is applied here.
inspection_batch = next(iter(train_loader))
hit_features, track_features, batch_indices, labels = inspection_batch

In [9]:
# Shape of the per-hit feature tensor. The second dimension (8 in the recorded
# output) matches hit_input_dim; the first presumably counts all hits in the
# batch and will vary from batch to batch -- TODO confirm against collate_fn.
hit_features.shape

torch.Size([310, 8])

In [10]:
# Shape of the per-track feature tensor; the recorded output (32, 4) matches
# batch_size=32 and track_feat_dim=4.
track_features.shape

torch.Size([32, 4])

In [11]:
# Shape of the index tensor; in the recorded output its length (310) equals
# the first dimension of hit_features, i.e. one index per hit row.
batch_indices.shape

torch.Size([310])

In [12]:
# Display the raw index tensor. In the recorded output the values are
# non-decreasing integers starting at 0; presumably each entry is the
# within-batch track position of the corresponding hit row -- TODO confirm
# against collate_fn.
batch_indices

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,  1,
         1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,  4,
         4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,
         6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  8,
         8,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12,
        12, 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 14,
        14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16,
        16, 16, 16, 16, 16, 16, 16, 16, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 19, 19, 20, 20,
        20, 20, 20, 20, 20, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22,
        22, 22, 22, 23, 23, 23, 23, 23, 