# Training code for the Particle Transformer (ParT) model

In [5]:
import sys
from pathlib import Path

# Put the current working directory at the front of the import search path so
# modules located there are found first, then echo the entry as confirmation.
cwd_entry = str(Path.cwd())
sys.path.insert(0, cwd_entry)
print(sys.path[0])

/home/z.ling.865/smartpixel-brevitas/particle_transformer


In [1]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from tqdm import tqdm
from model import ParTModel



ImportError: attempted relative import with no known parent package

In [None]:
# Dataset definition

class ParticleDataset(Dataset):
    """Synthetic placeholder dataset that fabricates random samples on demand.

    Nothing is stored: every lookup draws fresh random tensors, so reading the
    same index twice returns different values.

    Args:
        n_samples: Number of samples reported by ``len()``.
        T: Number of rows of each feature tensor; also the side length of the
            square ``U`` tensor and the length of the mask.
        in_features: Number of feature columns per row.
        num_classes: Labels are drawn uniformly from ``[0, num_classes)``.
    """

    def __init__(self, n_samples=1000, T=128, in_features=8, num_classes=10):
        super().__init__()
        self.n = n_samples
        self.T = T
        self.in_features = in_features
        self.num_classes = num_classes

    def __len__(self):
        """Return the nominal number of samples."""
        return self.n

    def __getitem__(self, idx):
        """Generate one random sample.

        Args:
            idx: Sample index in ``[-len(self), len(self))``. It only gates the
                bounds check; the generated values do not depend on it.

        Returns:
            Tuple ``(x, U, mask, y)``:
            ``x`` is a ``(T, in_features)`` standard-normal float tensor,
            ``U`` is a ``(T, T)`` all-zero tensor (placeholder interaction mask),
            ``mask`` is a length-``T`` all-True bool tensor,
            ``y`` is a Python ``int`` class label.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # The bounds check matters: without an IndexError, Python's fallback
        # sequence-iteration protocol (``for s in ds`` / ``list(ds)``) would
        # keep requesting indices forever.
        if not -self.n <= idx < self.n:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.n}"
            )

        x = torch.randn(self.T, self.in_features)
        U = torch.zeros(self.T, self.T)          # placeholder interaction mask
        mask = torch.ones(self.T, dtype=torch.bool)
        y = torch.randint(0, self.num_classes, (1,)).item()
        return x, U, mask, y
    
# Training and evaluation functions
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean loss.

    Args:
        model: Module called as ``model(x, U, mask)``; its output is fed to
            ``F.cross_entropy`` as logits.
        loader: Iterable yielding ``(x, U, mask, y)`` batches of tensors.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The per-sample average of the cross-entropy loss over all samples that
        were actually iterated, or ``nan`` if the loader yielded no batches.
    """
    model.train()
    total_loss = 0.0
    samples_seen = 0

    for x, U, mask, y in tqdm(loader, leave=False):
        x = x.to(device)
        U = U.to(device)
        mask = mask.to(device)
        y = y.to(device)

        optimizer.zero_grad()

        logits = model(x, U, mask)        # (B, num_classes)
        loss = F.cross_entropy(logits, y) # softmax inside loss

        loss.backward()
        optimizer.step()

        # cross_entropy averages over the batch by default, so weight it by the
        # batch size to accumulate a per-sample sum (the last batch may be
        # smaller than the others).
        batch_len = x.size(0)
        total_loss += loss.item() * batch_len
        samples_seen += batch_len

    # Divide by what was really processed rather than len(loader.dataset):
    # with drop_last=True or a sub-sampling sampler those two numbers differ,
    # and some datasets have no length at all.
    if samples_seen == 0:
        return float("nan")
    return total_loss / samples_seen


@torch.no_grad()
def evaluate(model, loader, device):
    """Measure top-1 accuracy of ``model`` over every batch in ``loader``.

    Args:
        model: Module called as ``model(x, U, mask)``; the arg-max over the
            last dimension of its output is taken as the predicted class.
        loader: Iterable yielding ``(x, U, mask, y)`` batches of tensors.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        Fraction of labels matched by the prediction, as a Python float.

    Raises:
        ZeroDivisionError: If the loader yields no labels.
    """
    model.eval()
    hits, seen = 0, 0

    for batch in loader:
        # Move the four batch members to the target device in their original order.
        feats, pairwise, valid, labels = (t.to(device) for t in batch)

        predicted = model(feats, pairwise, valid).argmax(dim=-1)

        hits += int(predicted.eq(labels).sum())
        seen += labels.numel()

    return hits / seen



In [None]:
# Main training cell
#
# Builds the synthetic train/validation loaders, instantiates ParTModel,
# trains it for `num_epochs` epochs and writes the final weights to disk.


# Hyperparameters
batch_size = 32
num_epochs = 20
learning_rate = 3e-4
num_classes = 10
in_features = 8  # one value shared by dataset and model so the two cannot drift apart

# Destination of the final weights
checkpoint_path = Path("./trained_models/part_model.pth")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_ds = ParticleDataset(n_samples=2000, in_features=in_features, num_classes=num_classes)
val_ds   = ParticleDataset(n_samples=500,  in_features=in_features, num_classes=num_classes)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size)

# NOTE(review): w_bit_width / a_bit_width are presumably weight / activation
# quantisation bit widths, and pab_num / cab_num presumably block counts --
# confirm against the ParTModel definition.
model = ParTModel(
    in_features=in_features,
    d_model=128,
    num_heads=8,
    num_classes=num_classes,
    w_bit_width=8,
    a_bit_width=8,
    pab_num=8,
    cab_num=2,
).to(device)

# Optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=1e-4)


# Training loop

for epoch in range(1, num_epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    val_acc = evaluate(model, val_loader, device)

    print(f"Epoch {epoch:02d} | Train loss: {train_loss:.4f} | Val acc: {val_acc:.4f}")

# Save the trained model
# Create the output directory first: torch.save does not create missing parent
# directories, and failing here would discard the whole training run.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# Test
#
# Smoke test: push a single training sample through the trained model and
# print the resulting class probabilities.

model.eval()
x, U, mask, _ = train_ds[0]

# Give each tensor a leading batch dimension of size 1 and move it to the
# model's device, keeping the (x, U, mask) argument order.
batched_inputs = tuple(t.unsqueeze(0).to(device) for t in (x, U, mask))

with torch.no_grad():
    logits = model(*batched_inputs)
    probs = logits.softmax(dim=-1)

print("Predicted probabilities:", probs.cpu())