In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Toy Data

In [2]:
x = torch.Tensor([[ 8.7541382e+01,  7.4368027e+01,  6.8198807e+01,  5.5106983e+01,
         3.8173203e+01,  2.3811323e+01,  1.8535612e+01,  1.8095873e+01],
       [-5.1816605e-02,  4.1964740e-01,  4.4769207e-01, -6.5377988e-02,
         6.1878175e-01,  6.2997532e-01, -8.1161506e-02, -7.9315454e-02],
       [ 2.4468634e+00,  2.2728169e+00,  2.2340939e+00,  2.6091626e+00,
         1.8117365e+00,  1.7758595e+00,  2.3642292e+00,  2.5458102e+00],
       [ 8.7658936e+01,  8.1014442e+01,  7.5148209e+01,  5.5224800e+01,
         4.5717682e+01,  2.8694658e+01,  1.8596695e+01,  1.8153358e+01]])

x_batch = torch.Tensor([[[ 8.7541382e+01,  7.4368027e+01,  6.8198807e+01,  5.5106983e+01,
          3.8173203e+01,  2.3811323e+01,  1.8535612e+01,  1.8095873e+01],
        [-5.1816605e-02,  4.1964740e-01,  4.4769207e-01, -6.5377988e-02,
          6.1878175e-01,  6.2997532e-01, -8.1161506e-02, -7.9315454e-02],
        [ 2.4468634e+00,  2.2728169e+00,  2.2340939e+00,  2.6091626e+00,
          1.8117365e+00,  1.7758595e+00,  2.3642292e+00,  2.5458102e+00],
        [ 8.7658936e+01,  8.1014442e+01,  7.5148209e+01,  5.5224800e+01,
          4.5717682e+01,  2.8694658e+01,  1.8596695e+01,  1.8153358e+01]],

       [[ 8.2484528e+01,  5.2682617e+01,  5.1243843e+01,  3.6217686e+01,
          2.8948278e+01,  2.6579512e+01,  2.1946012e+01,  2.1011120e+01],
        [-4.3566185e-01, -8.7309110e-01, -4.4896263e-01, -6.0569459e-01,
         -4.8134822e-01, -7.0045888e-01, -6.0671657e-01, -5.7662535e-01],
        [-1.9739739e+00, -2.4504409e+00, -1.9982951e+00, -1.4225215e+00,
         -1.9399333e+00, -2.3558097e+00, -1.4185165e+00, -1.4236869e+00],
        [ 9.0437065e+01,  7.4070679e+01,  5.6495895e+01,  4.3069641e+01,
          3.2367134e+01,  3.3371326e+01,  2.6115334e+01,  2.4602446e+01]],

       [[ 8.6492935e+01,  7.0192978e+01,  5.8423912e+01,  5.6638733e+01,
          4.9270725e+01,  4.1237038e+01,  3.6133625e+01,  3.5519596e+01],
        [ 1.4010678e-01,  2.7912292e-01,  1.4376265e-01,  3.4672296e-01,
          3.4966472e-01,  1.0524009e-01,  1.2958543e-01,  3.3264065e-01],
        [ 1.9334941e+00,  1.6967584e+00,  1.9219695e+00,  1.6735281e+00,
          1.6587850e+00,  1.8386338e+00,  1.9120301e+00,  1.6680365e+00],
        [ 8.7343246e+01,  7.2945129e+01,  5.9030762e+01,  6.0084766e+01,
          5.2313778e+01,  4.1465607e+01,  3.6440781e+01,  3.7506149e+01]]])

In [3]:
x = x.unsqueeze(0) # batch dimension
x.shape, x_batch.shape

(torch.Size([1, 4, 8]), torch.Size([3, 4, 8]))

## Quantum Function

In [4]:
import numpy as np
np.ComplexWarning = Warning

!pip install tensorcircuit
from typing import Callable

import tensorcircuit as tc
import jax.numpy as jnp
import flax.linen

import torch
import torch.nn as nn
import torch.nn.functional as F

import tensorcircuit as tc

K = tc.set_backend("jax")


def angle_embedding(c: tc.Circuit, inputs):
    num_qubits = inputs.shape[-1]

    for j in range(num_qubits):
        c.rx(j, theta=inputs[j])


def basic_vqc(c: tc.Circuit, inputs, weights):
    num_qubits = inputs.shape[-1]
    num_qlayers = weights.shape[-2]

    for i in range(num_qlayers):
        for j in range(num_qubits):
            c.rx(j, theta=weights[i, j])
        if num_qubits == 2:
            c.cnot(0, 1)
        elif num_qubits > 2:
            for j in range(num_qubits):
                c.cnot(j, (j + 1) % num_qubits)


def get_quantum_layer_circuit(inputs, weights,
                              embedding: Callable = angle_embedding, vqc: Callable = basic_vqc):
    """
    Equivalent to the following PennyLane circuit:
        def circuit(inputs, weights):
            qml.templates.AngleEmbedding(inputs, wires=range(num_qubits))
            qml.templates.BasicEntanglerLayers(weights, wires=range(num_qubits))
    """

    num_qubits = inputs.shape[-1]

    c = tc.Circuit(num_qubits)
    embedding(c, inputs)
    vqc(c, inputs, weights)

    return c


def get_circuit(embedding: Callable = angle_embedding, vqc: Callable = basic_vqc,
                torch_interface: bool = False):
    def qpred(inputs, weights):
        c = get_quantum_layer_circuit(inputs, weights, embedding, vqc)
        return K.real(jnp.array([c.expectation_ps(z=[i]) for i in range(weights.shape[1])]))

    qpred_batch = K.vmap(qpred, vectorized_argnums=0)
    if torch_interface:
        qpred_batch = tc.interfaces.torch_interface(qpred_batch, jit=True)

    return qpred_batch


class QuantumLayer(flax.linen.Module):
    circuit: Callable
    num_qubits: int
    w_shape: tuple = (1,)

    @flax.linen.compact
    def __call__(self, x):
        shape = x.shape
        x = jnp.reshape(x, (-1, shape[-1]))
        w = self.param('w', flax.linen.initializers.xavier_normal(), self.w_shape + (self.num_qubits,))
        x = self.circuit(x, w)
        x = jnp.concatenate(x, axis=-1)
        x = jnp.reshape(x, tuple(shape))
        return x



NUM_QUBITS     = 8
NUM_Q_LAYERS   = 1
torch_layer_fn = get_circuit(torch_interface=True)


class TCTorchLayer(nn.Module):
    """
    A thin PyTorch wrapper around the TensorCircuit/TC quantum layer.
    Stores the circuit's trainable parameters as an nn.Parameter so
    they appear in .parameters() and get updated by any torch optimizer.
    """
    def __init__(self, num_qubits=NUM_QUBITS, num_qlayers=NUM_Q_LAYERS):
        super().__init__()
        init_w = 0.01 * torch.randn(num_qlayers, num_qubits)
        self.w = nn.Parameter(init_w)
        self.num_qubits = num_qubits

    def forward(self, x):
        """
        x: (batch, num_qubits) – already pre-scaled into rotation angles.
        Returns expectation values ⟨Z_i⟩ for every qubit i, shape identical
        to the input (batch, num_qubits).
        """
        return torch_layer_fn(x, self.w)


class QuantumLinear(nn.Module):
    """
    Linear -> angle map -> TCTorchLayer -> Linear
    Works on tensors shaped (..., din) and returns (..., dout).
    """
    def __init__(self, din, dout, num_qubits):
        super().__init__()
        self.din  = din
        self.dout = dout
        self.nq   = num_qubits

        self.to_q   = nn.Linear(din,  self.nq, bias=False)
        self.from_q = nn.Linear(self.nq, dout, bias=False)
        self.q = TCTorchLayer(self.nq)

    @staticmethod
    def _to_angles(x):
        return torch.tanh(x) * math.pi

    def forward(self, x):
        # x: (..., din)
        *prefix, _ = x.shape
        x = x.reshape(-1, self.din)

        x = self.to_q(x)
        x = self._to_angles(x)
        x = self.q(x).float()
        x = self.from_q(x)

        x = x.reshape(*prefix, self.dout)
        return x

Collecting tensorcircuit
  Downloading tensorcircuit-0.12.0-py3-none-any.whl.metadata (29 kB)
Collecting tensornetwork-ng (from tensorcircuit)
  Downloading tensornetwork_ng-0.5.1-py3-none-any.whl.metadata (7.0 kB)
Downloading tensorcircuit-0.12.0-py3-none-any.whl (342 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m342.0/342.0 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading tensornetwork_ng-0.5.1-py3-none-any.whl (244 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m244.1/244.1 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensornetwork-ng, tensorcircuit
Successfully installed tensorcircuit-0.12.0 tensornetwork-ng-0.5.1




## Toy interaction matrix

In [5]:
class InteractionEncoder(nn.Module):
    """
    ParT interaction-feature encoder.

    Args
    ----
    n_heads per mhsa: output channels d′
    hidden_channels : list[int] for intermediate 1×1 conv layers
    eps             : numerical guard for log
    """

    def __init__(self,
                 n_heads: int = 8,
                 hidden_channels: list[int] = (64, 64, 64),
                 eps: float = 1e-8):
        super().__init__()
        self.eps = eps

        layers: list[nn.Module] = []
        in_ch = 4                               # lnΔ, ln kT, ln z, ln m²
        for h in hidden_channels:
            layers += [
                nn.Conv2d(in_ch, h, 1, bias=False),
                nn.BatchNorm2d(h),
                nn.GELU()
            ]
            in_ch = h
        layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x : (B, 4, N)  where the 4 dims are (E, px, py, pz)
        returns
        ------
        U : (B, n_heads, N, N)  interaction embedding
        """
        B, four, N = x.shape
        assert four == 4, "input must have 4 features: E, px, py, pz"

        # Split components
        E, px, py, pz = x.unbind(dim=1)         # each (B, N)

        # Basic kinematics ------------------------------------------------
        pT = torch.sqrt(px**2 + py**2) + self.eps
        phi = torch.atan2(py, px)               # (−π, π]
        num = (E + pz).clamp(min=self.eps)  #need to avoid negative numbers
        den = (E - pz).clamp(min=self.eps)
        y   = 0.5 * torch.log(num / den)

        # Expand to (B, N, N)
        y_a, y_b = y.unsqueeze(2), y.unsqueeze(1)          # (B,N,1),(B,1,N)
        phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1)
        pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1)
        E_a, E_b = E.unsqueeze(2), E.unsqueeze(1)
        px_a, px_b = px.unsqueeze(2), px.unsqueeze(1)
        py_a, py_b = py.unsqueeze(2), py.unsqueeze(1)
        pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1)

        # ΔR, kT, z
        delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps
        kT = torch.minimum(pT_a, pT_b) * delta
        z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps)

        # m² of pair
        E_sum = E_a + E_b
        px_sum = px_a + px_b
        py_sum = py_a + py_b
        pz_sum = pz_a + pz_b
        m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps
        m2 = torch.clamp(m2, min=self.eps)      # avoid negatives

        # Stack → (B, 4, N, N)
        feats = torch.stack([
            torch.log(delta),
            torch.log(kT),
            torch.log(z),
            torch.log(m2)
        ], dim=1)

        # conv
        U = self.net(feats)                     # (B, n_heads, N, N)
        return U



B, _, N = x.shape
n_heads = 2          # d′
enc = InteractionEncoder(n_heads=n_heads)
U = enc(x)
print("U.shape:", U.shape)

U.shape: torch.Size([1, 2, 8, 8])


## Particle Transformer

In [6]:
class ParticleTokenizer(nn.Module):
    def __init__(self, in_dim=4, out_dim=6):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """
        x: tensor of shape (B, n_particles, in_dim)
        returns: (B, n_particles, out_dim)
        """
        x = x.transpose(1, 2)  # Input shape: (B, n_particles, in_dim) → (B, in_dim, n_particles)
        return self.proj(x)

tokenizer = ParticleTokenizer(4, 10)
output = tokenizer(x)
output.shape

torch.Size([1, 8, 10])

In [7]:
class MLP(nn.Module):
    """
    Same interface as your tiny MLP, but nn.Linear -> QuantumLinear.
    Works for inputs shaped (..., dim).

    Args:
        dim         : feature size
        dropout     : dropout prob
        num_qubits  : qubits per QuantumLinear block (defaults to dim)
    """
    def __init__(self, dim, dropout=0., num_qubits=None):
        super().__init__()
        nq = num_qubits if num_qubits is not None else dim

        self.fc1 = QuantumLinear(dim, dim, nq)
        self.fc2 = QuantumLinear(dim, dim, nq)

        self.act  = nn.GELU()
        self.do1  = nn.Dropout(dropout)
        self.do2  = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.do1(x)

        x = self.fc2(x)
        x = self.do2(x)
        return x

# usage
mlp = MLP(10, dropout=0.1)
output = mlp(output)
output.shape

torch.Size([1, 8, 10])

In [8]:
class ParticleMHA(nn.Module):
    """
    Multi-head self-attention with quantum projections (q, k, v, o).

    Args
    ----
    d            : embedding dim
    heads        : number of attention heads
    dropout      : dropout prob on attn weights
    return_attn  : return attention maps?
    num_qubits   : qubits per quantum block (defaults to d)
    """
    def __init__(self, d: int, heads: int = 8,
                 dropout: float = 0.1, return_attn: bool = False,
                 num_qubits: int | None = None):
        super().__init__()
        assert d % heads == 0, "`d` must be divisible by `heads`"

        self.d           = d
        self.h           = heads
        self.d_head      = d // heads
        self.scale       = 1 / math.sqrt(self.d_head)
        self.return_attn = return_attn

        nq = num_qubits if num_qubits is not None else d

        # quantum projections
        self.q_proj = QuantumLinear(d, d, nq)
        self.k_proj = QuantumLinear(d, d, nq)
        self.v_proj = QuantumLinear(d, d, nq)
        self.o_proj = QuantumLinear(d, d, nq)

        self.drop = nn.Dropout(dropout)

    def _split(self, t: torch.Tensor):
        # (B, N, d) -> (B, H, N, d_head)
        B, N, _ = t.shape
        return t.view(B, N, self.h, self.d_head).transpose(1, 2)

    def forward(self, x: torch.Tensor, U: torch.Tensor | None = None):
        B, N, _ = x.shape

        Q = self._split(self.q_proj(x))
        K = self._split(self.k_proj(x))
        V = self._split(self.v_proj(x))

        logits = (Q @ K.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        if U is not None:
            logits = logits + U

        attn = F.softmax(logits, dim=-1)
        attn = self.drop(attn)

        context = attn @ V  # (B, H, N, d_head)

        context = (
            context.transpose(1, 2)   # (B, N, H, d_head)
                   .contiguous()
                   .view(B, N, self.d)
        )
        out = self.o_proj(context)

        if self.return_attn:
            return out, attn
        else:
            return out

B, N, d = output.shape

U = torch.randn(1, 2, N, N)  # broadcast to (B, H, N, N)

pmha = ParticleMHA(d=d, heads=2, dropout=0.1, return_attn=True)
output, A = pmha(output, U)          # out: (B, N, d)   A: (B, 8, N, N)

print(output.shape, A.shape)

torch.Size([1, 8, 10]) torch.Size([1, 2, 8, 8])


## transformer

In [9]:
class MHA(nn.Module):
    """
    Multi-head attention (batch_first) with QuantumLinear projections.

    Args
    ----
    d_model : int          embedding dim
    n_heads : int
    dropout: float
    bias   : bool          (ignored here, QuantumLinear has no bias)
    num_qubits : int|None  qubits per quantum block (defaults to d_model)
    """
    def __init__(self, d_model: int, n_heads: int,
                 dropout: float = 0., bias: bool = False,
                 num_qubits: int | None = None):
        super().__init__()
        assert d_model % n_heads == 0, "`d_model` must be divisible by `n_heads`"
        self.d_model = d_model
        self.h       = n_heads
        self.d_head  = d_model // n_heads
        self.scale   = self.d_head ** -0.5

        nq = num_qubits if num_qubits is not None else d_model

        # Quantum projections replace nn.Linear
        self.q_proj = QuantumLinear(d_model, d_model, nq)
        self.k_proj = QuantumLinear(d_model, d_model, nq)
        self.v_proj = QuantumLinear(d_model, d_model, nq)
        self.o_proj = QuantumLinear(d_model, d_model, nq)

        self.drop = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor):
        # (B, L, d_model) -> (B, h, L, d_head)
        B, L, _ = x.shape
        return x.view(B, L, self.h, self.d_head).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor):
        # (B, h, L, d_head) -> (B, L, d_model)
        B, H, L, Dh = x.shape
        return x.transpose(1, 2).contiguous().view(B, L, H * Dh)

    def forward(
        self,
        q: torch.Tensor,          # (B, Lq, d_model)
        k: torch.Tensor,          # (B, Lk, d_model)
        v: torch.Tensor,          # (B, Lk, d_model)
        attn_mask: torch.Tensor | None = None,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = False
    ):
        B, Lq, _ = q.shape
        _, Lk, _ = k.shape

        Q = self._split_heads(self.q_proj(q))  # (B,h,Lq,d_h)
        K = self._split_heads(self.k_proj(k))  # (B,h,Lk,d_h)
        V = self._split_heads(self.v_proj(v))  # (B,h,Lk,d_h)

        logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (B,h,Lq,Lk)

        attn = F.softmax(logits, dim=-1)
        attn = self.drop(attn)

        context = torch.matmul(attn, V)        # (B,h,Lq,d_h)

        out = self.o_proj(self._merge_heads(context))  # (B,Lq,d_model)

        if need_weights:
            return out, attn.mean(dim=1)  # (B,Lq,Lk)
        return out, None

In [10]:
# Particle attention block  (NormFormer style + U-bias)
class ParticleAttentionBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = ParticleMHA(dim, heads, dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dropout)
    def forward(self, x, U):
        x = x + self.attn(self.ln1(x), U)    # bias-aware MHSA
        x = x + self.mlp(self.ln2(x))        # feed-forward
        return x

# Class attention block  (CaiT style, no U)
class ClassAttentionBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4, dropout=0.):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = MHA(dim, heads, dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, dropout)
    def forward(self, tokens, cls):          # tokens: (B,N,d), cls: (B,1,d)
        z   = torch.cat([cls, tokens], dim=1)   # (B,1+N,d)
        q   = self.ln1(cls)
        kv  = self.ln1(z)
        cls = cls + self.attn(q, kv, kv, need_weights=False)[0]
        cls = cls + self.mlp(self.ln2(cls))
        return cls                             # (B,1,d)

# Complete Particle Transformer
class ParT(nn.Module):
    def __init__(self,
                 in_dim=4,          # (E,px,py,pz)
                 embed_dim=10,
                 n_heads=2,
                 depth=2,           # particle blocks
                 class_depth=2,     # class-attention blocks
                 mlp_ratio=4,
                 num_classes=10,
                 dropout=0.1):
        super().__init__()

        self.tokenizer = ParticleTokenizer(in_dim, embed_dim)
        self.U_encoder = InteractionEncoder(n_heads=n_heads)

        self.blocks = nn.ModuleList([
            ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.cls_blocks = nn.ModuleList([
            ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0)
            for _ in range(class_depth)
        ])

        self.head = nn.Linear(embed_dim, num_classes)

        # weight init
        nn.init.trunc_normal_(self.class_token, std=0.02)
        nn.init.trunc_normal_(self.head.weight,  std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):               # x: (B,4,N)
        B, _, N = x.shape

        tokens = self.tokenizer(x)                  # (B,N,d)
        U      = self.U_encoder(x)                  # (B,H,N,N)

        for blk in self.blocks:
            tokens = blk(tokens, U)                # (B,N,d)

        cls = self.class_token.expand(B, -1, -1)    # (B,1,d)
        for blk in self.cls_blocks:
            cls = blk(tokens, cls)                 # (B,1,d)

        logits = self.head(cls.squeeze(1))          # (B,10)
        return logits

In [11]:
B, _, N = x_batch.shape          # (3,4,8)
model = ParT(in_dim=4,
             embed_dim=10,
             n_heads=2,
             depth=2,
             class_depth=2,
             num_classes=10)

logits = model(x_batch)          # forward pass
print("logits:", logits.shape)   # torch.Size([3, 10])


logits: torch.Size([3, 10])


In [12]:
x_train = x_batch                    # (3, 4, 8)
y_train = torch.tensor([0, 1, 2])    # dummy class labels for testing

model = ParT(in_dim=4,
             embed_dim=10,
             n_heads=2,
             depth=2,
             class_depth=2,
             num_classes=10)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 250
for epoch in range(n_epochs):
    model.train()
    logits = model(x_train)          # (3, 10)
    loss = criterion(logits, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # print every 5 epochs
    if (epoch+1) % 5 == 0 or epoch == 0:
        preds = logits.argmax(1)
        acc   = (preds == y_train).float().mean().item()
        print(f"epoch {epoch+1:3d}  loss {loss.item():.4f}  acc {acc:.3f}")

model.eval()
with torch.no_grad():
    probs = torch.softmax(model(x_train), dim=1)
print("softmax-probs\n", probs)


epoch   1  loss 2.3297  acc 0.000
epoch   5  loss 2.2604  acc 0.000
epoch  10  loss 2.1752  acc 0.333
epoch  15  loss 2.0782  acc 0.333
epoch  20  loss 1.9541  acc 0.333
epoch  25  loss 1.8064  acc 0.333
epoch  30  loss 1.6569  acc 0.333
epoch  35  loss 1.5183  acc 0.333
epoch  40  loss 1.3999  acc 0.333
epoch  45  loss 1.3081  acc 0.333
epoch  50  loss 1.2424  acc 0.333
epoch  55  loss 1.1977  acc 0.333
epoch  60  loss 1.1684  acc 0.333
epoch  65  loss 1.1496  acc 0.333
epoch  70  loss 1.1375  acc 0.333
epoch  75  loss 1.1293  acc 0.333
epoch  80  loss 1.1235  acc 0.333
epoch  85  loss 1.1192  acc 0.333
epoch  90  loss 1.1162  acc 0.333
epoch  95  loss 1.1138  acc 0.333
epoch 100  loss 1.1119  acc 0.333
epoch 105  loss 1.1105  acc 0.333
epoch 110  loss 1.1093  acc 0.333
epoch 115  loss 1.1083  acc 0.333
epoch 120  loss 1.1075  acc 0.333
epoch 125  loss 1.1067  acc 0.333
epoch 130  loss 1.1061  acc 0.333
epoch 135  loss 1.1056  acc 0.333
epoch 140  loss 1.1051  acc 0.333
epoch 145  los

## Load official Data

In [12]:
!git clone https://github.com/jet-universe/particle_transformer.git
!cd particle_transformer
!cd /content/particle_transformer
!touch env.sh
!chmod +x env.sh

Cloning into 'particle_transformer'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (52/52), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 101 (delta 38), reused 27 (delta 27), pack-reused 49 (from 1)[K
Receiving objects: 100% (101/101), 28.08 MiB | 14.94 MiB/s, done.
Resolving deltas: 100% (46/46), done.


In [14]:
!/content/particle_transformer/get_datasets.py JetClass -d ./datasets
!source env.sh
import os, glob, tarfile
os.environ['DATADIR_JetClass'] = os.path.abspath('./datasets/JetClass')
data_dir  = os.environ['DATADIR_JetClass']
!pip install awkward uproot vector
from particle_transformer.dataloader import read_file

tar_path = "/content/datasets/JetClass/JetClass_Pythia_val_5M.tar"

extract_dir = "/content/datasets/JetClass/JetClass_Pythia_val_5M"
os.makedirs(extract_dir, exist_ok=True)

Downloading data from https://zenodo.org/record/6619768/files/JetClass_Pythia_val_5M.tar to ./datasets/JetClass/JetClass_Pythia_val_5M.tar
./datasets/JetClass/JetClass_Pythia_val_5M.tar: 100% 7.07G/7.07G [10:19<00:00, 12.3MiB/s]
Updated dataset path in env.sh to "DATADIR_JetClass=./datasets/JetClass".
Collecting awkward
  Downloading awkward-2.8.5-py3-none-any.whl.metadata (6.9 kB)
Collecting uproot
  Downloading uproot-5.6.3-py3-none-any.whl.metadata (33 kB)
Collecting vector
  Downloading vector-1.6.3-py3-none-any.whl.metadata (16 kB)
Collecting awkward-cpp==47 (from awkward)
  Downloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Downloading awkward-2.8.5-py3-none-any.whl (886 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m886.8/886.8 kB[0m [31m54.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (638 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [15]:
if not any(fname.endswith(".root") for fname in os.listdir(extract_dir)):
    print("⏬ extracting test-set…")
    with tarfile.open(tar_path) as tar:
        tar.extractall(path=extract_dir)
pattern = os.path.join(extract_dir, 'val_5M', "*.root")
files   = sorted(glob.glob(pattern))
print(f"Found {len(files)} ROOT files")

⏬ extracting test-set…
Found 50 ROOT files


In [17]:
import torch
from torch.utils.data import Dataset

all_x_parts = []
all_ys = []

num_file = 1
for file in files:
    num_file += 1
    if num_file % 5 == 0:
        x_part, x_jets, y = read_file(
            file,
            max_num_particles=8,
            particle_features=['part_pt', 'part_eta', 'part_phi', 'part_energy'],
            jet_features=['jet_pt', 'jet_eta', 'jet_phi', 'jet_energy'],
            labels=[
                'label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q',
                'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl',
            ]
        )
        all_x_parts.append(torch.tensor(x_part, dtype=torch.float32)[:100,:,:])
        all_ys.append(torch.tensor(y, dtype=torch.float32)[:100,:])

x_all = torch.cat(all_x_parts, dim=0)
y_all = torch.cat(all_ys, dim=0)
print(x_all.shape, y_all.shape)

class JetDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataset = JetDataset(x_all, y_all)

from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

torch.Size([1000, 4, 8]) torch.Size([1000, 10])


In [23]:
model = ParT(
    in_dim=4,           # part_pt, eta, phi, energy
    embed_dim=10,
    n_heads=2,
    depth=2,
    class_depth=2,
    num_classes=10
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [24]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [25]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        outputs = model(x)  # shape [batch, 10]

        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/100, Loss: 2.3056
Epoch 2/100, Loss: 2.3032
Epoch 3/100, Loss: 2.3018
Epoch 4/100, Loss: 2.2987
Epoch 5/100, Loss: 2.2935
Epoch 6/100, Loss: 2.2955
Epoch 7/100, Loss: 2.2838
Epoch 8/100, Loss: 2.2749
Epoch 9/100, Loss: 2.2692
Epoch 10/100, Loss: 2.2611
Epoch 11/100, Loss: 2.2566
Epoch 12/100, Loss: 2.2437
Epoch 13/100, Loss: 2.2414
Epoch 14/100, Loss: 2.2341
Epoch 15/100, Loss: 2.2255
Epoch 16/100, Loss: 2.2330
Epoch 17/100, Loss: 2.2221
Epoch 18/100, Loss: 2.2226
Epoch 19/100, Loss: 2.2205
Epoch 20/100, Loss: 2.2116
Epoch 21/100, Loss: 2.2099
Epoch 22/100, Loss: 2.2131
Epoch 23/100, Loss: 2.2063
Epoch 24/100, Loss: 2.2195
Epoch 25/100, Loss: 2.2166
Epoch 26/100, Loss: 2.2140
Epoch 27/100, Loss: 2.2045
Epoch 28/100, Loss: 2.2064
Epoch 29/100, Loss: 2.2087
Epoch 30/100, Loss: 2.2087
Epoch 31/100, Loss: 2.2096
Epoch 32/100, Loss: 2.2064
Epoch 33/100, Loss: 2.2051
Epoch 34/100, Loss: 2.2003
Epoch 35/100, Loss: 2.2138
Epoch 36/100, Loss: 2.2038
Epoch 37/100, Loss: 2.1994
Epoch 38/1

In [26]:
from torch.nn.functional import sigmoid, softmax

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        labels = torch.argmax(y, dim=1)          # convert one-hot to class id
        preds = torch.argmax(outputs, dim=1)     # predicted class
        correct += (preds == labels).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Accuracy on full dataset: {accuracy:.4f}")


Accuracy on full dataset: 0.2010


In [27]:
from sklearn.metrics import roc_auc_score
import numpy as np

model.eval()
all_outputs = []
all_targets = []

with torch.no_grad():
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        all_outputs.append(outputs.cpu())
        all_targets.append(y.cpu())

# Concatenate batches
all_outputs = torch.cat(all_outputs, dim=0)
all_targets = torch.cat(all_targets, dim=0)

# Apply sigmoid (if using BCEWithLogitsLoss)
probs = sigmoid(all_outputs).numpy()  # shape: (N, C)
true = all_targets.numpy()            # shape: (N, C)

# Compute AUC for each class and average
try:
    auc_macro = roc_auc_score(true, probs, average='macro', multi_class='ovr')
    print(f"Macro-Averaged AUC: {auc_macro:.4f}")
except ValueError as e:
    print("AUC could not be computed:", e)


Macro-Averaged AUC: 0.6697
