In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Toy Data

In [2]:
x = torch.Tensor([[ 8.7541382e+01,  7.4368027e+01,  6.8198807e+01,  5.5106983e+01,
         3.8173203e+01,  2.3811323e+01,  1.8535612e+01,  1.8095873e+01],
       [-5.1816605e-02,  4.1964740e-01,  4.4769207e-01, -6.5377988e-02,
         6.1878175e-01,  6.2997532e-01, -8.1161506e-02, -7.9315454e-02],
       [ 2.4468634e+00,  2.2728169e+00,  2.2340939e+00,  2.6091626e+00,
         1.8117365e+00,  1.7758595e+00,  2.3642292e+00,  2.5458102e+00],
       [ 8.7658936e+01,  8.1014442e+01,  7.5148209e+01,  5.5224800e+01,
         4.5717682e+01,  2.8694658e+01,  1.8596695e+01,  1.8153358e+01]])

x_batch = torch.Tensor([[[ 8.7541382e+01,  7.4368027e+01,  6.8198807e+01,  5.5106983e+01,
          3.8173203e+01,  2.3811323e+01,  1.8535612e+01,  1.8095873e+01],
        [-5.1816605e-02,  4.1964740e-01,  4.4769207e-01, -6.5377988e-02,
          6.1878175e-01,  6.2997532e-01, -8.1161506e-02, -7.9315454e-02],
        [ 2.4468634e+00,  2.2728169e+00,  2.2340939e+00,  2.6091626e+00,
          1.8117365e+00,  1.7758595e+00,  2.3642292e+00,  2.5458102e+00],
        [ 8.7658936e+01,  8.1014442e+01,  7.5148209e+01,  5.5224800e+01,
          4.5717682e+01,  2.8694658e+01,  1.8596695e+01,  1.8153358e+01]],

       [[ 8.2484528e+01,  5.2682617e+01,  5.1243843e+01,  3.6217686e+01,
          2.8948278e+01,  2.6579512e+01,  2.1946012e+01,  2.1011120e+01],
        [-4.3566185e-01, -8.7309110e-01, -4.4896263e-01, -6.0569459e-01,
         -4.8134822e-01, -7.0045888e-01, -6.0671657e-01, -5.7662535e-01],
        [-1.9739739e+00, -2.4504409e+00, -1.9982951e+00, -1.4225215e+00,
         -1.9399333e+00, -2.3558097e+00, -1.4185165e+00, -1.4236869e+00],
        [ 9.0437065e+01,  7.4070679e+01,  5.6495895e+01,  4.3069641e+01,
          3.2367134e+01,  3.3371326e+01,  2.6115334e+01,  2.4602446e+01]],

       [[ 8.6492935e+01,  7.0192978e+01,  5.8423912e+01,  5.6638733e+01,
          4.9270725e+01,  4.1237038e+01,  3.6133625e+01,  3.5519596e+01],
        [ 1.4010678e-01,  2.7912292e-01,  1.4376265e-01,  3.4672296e-01,
          3.4966472e-01,  1.0524009e-01,  1.2958543e-01,  3.3264065e-01],
        [ 1.9334941e+00,  1.6967584e+00,  1.9219695e+00,  1.6735281e+00,
          1.6587850e+00,  1.8386338e+00,  1.9120301e+00,  1.6680365e+00],
        [ 8.7343246e+01,  7.2945129e+01,  5.9030762e+01,  6.0084766e+01,
          5.2313778e+01,  4.1465607e+01,  3.6440781e+01,  3.7506149e+01]]])

In [3]:
x = x.unsqueeze(0) # batch dimension
x.shape, x_batch.shape

(torch.Size([1, 4, 8]), torch.Size([3, 4, 8]))

## Toy interaction matrix

In [4]:
class InteractionEncoder(nn.Module):
    """
    ParT interaction-feature encoder.

    Args
    ----
    n_heads per mhsa: output channels d′
    hidden_channels : list[int] for intermediate 1×1 conv layers
    eps             : numerical guard for log
    """

    def __init__(self,
                 n_heads: int = 8,
                 hidden_channels: list[int] = (64, 64, 64),
                 eps: float = 1e-8):
        super().__init__()
        self.eps = eps

        layers: list[nn.Module] = []
        in_ch = 4                               # lnΔ, ln kT, ln z, ln m²
        for h in hidden_channels:
            layers += [
                nn.Conv2d(in_ch, h, 1, bias=False),
                nn.BatchNorm2d(h),
                nn.GELU()
            ]
            in_ch = h
        layers.append(nn.Conv2d(in_ch, n_heads, 1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x : (B, 4, N)  where the 4 dims are (E, px, py, pz)
        returns
        ------
        U : (B, n_heads, N, N)  interaction embedding
        """
        B, four, N = x.shape
        assert four == 4, "input must have 4 features: E, px, py, pz"

        # Split components
        E, px, py, pz = x.unbind(dim=1)         # each (B, N)

        # Basic kinematics ------------------------------------------------
        pT = torch.sqrt(px**2 + py**2) + self.eps
        phi = torch.atan2(py, px)               # (−π, π]
        num = (E + pz).clamp(min=self.eps)  #need to avoid negative numbers
        den = (E - pz).clamp(min=self.eps)
        y   = 0.5 * torch.log(num / den)

        # Expand to (B, N, N)
        y_a, y_b = y.unsqueeze(2), y.unsqueeze(1)          # (B,N,1),(B,1,N)
        phi_a, phi_b = phi.unsqueeze(2), phi.unsqueeze(1)
        pT_a, pT_b = pT.unsqueeze(2), pT.unsqueeze(1)
        E_a, E_b = E.unsqueeze(2), E.unsqueeze(1)
        px_a, px_b = px.unsqueeze(2), px.unsqueeze(1)
        py_a, py_b = py.unsqueeze(2), py.unsqueeze(1)
        pz_a, pz_b = pz.unsqueeze(2), pz.unsqueeze(1)

        # ΔR, kT, z
        delta = torch.sqrt((y_a - y_b) ** 2 + (phi_a - phi_b) ** 2) + self.eps
        kT = torch.minimum(pT_a, pT_b) * delta
        z = torch.minimum(pT_a, pT_b) / (pT_a + pT_b + self.eps)

        # m² of pair
        E_sum = E_a + E_b
        px_sum = px_a + px_b
        py_sum = py_a + py_b
        pz_sum = pz_a + pz_b
        m2 = E_sum**2 - (px_sum**2 + py_sum**2 + pz_sum**2) + self.eps
        m2 = torch.clamp(m2, min=self.eps)      # avoid negatives

        # Stack → (B, 4, N, N)
        feats = torch.stack([
            torch.log(delta),
            torch.log(kT),
            torch.log(z),
            torch.log(m2)
        ], dim=1)

        # conv
        U = self.net(feats)                     # (B, n_heads, N, N)
        return U



B, _, N = x.shape
n_heads = 2          # d′
enc = InteractionEncoder(n_heads=n_heads)
U = enc(x)
print("U.shape:", U.shape)

U.shape: torch.Size([1, 2, 8, 8])


## Particle Transformer

In [5]:
class ParticleTokenizer(nn.Module):
    def __init__(self, in_dim=4, out_dim=6):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)

    def forward(self, x):
        """
        x: tensor of shape (B, n_particles, in_dim)
        returns: (B, n_particles, out_dim)
        """
        x = x.transpose(1, 2)  # Input shape: (B, n_particles, in_dim) → (B, in_dim, n_particles)
        return self.proj(x)

tokenizer = ParticleTokenizer(4, 10)
output = tokenizer(x)
output.shape

torch.Size([1, 8, 10])

In [6]:
class MLP(nn.Module):
    def __init__(self, dim, expansion=1, dropout=0.):
        super().__init__()
        hidden = dim * expansion
        self.net = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, dim), nn.Dropout(dropout)
        )
    def forward(self, x): return self.net(x)

mlp = MLP(10, expansion=1, dropout=0.1)
output = mlp(output)
output.shape

torch.Size([1, 8, 10])

In [7]:
class ParticleMHA(nn.Module):
    """
    Multi-head self-attention with additive interaction bias U.

    Input
    -----
    x        : (B, N, d)        token / particle embeddings
    U        : (broadcast → B, H, N, N) or None

    Returns
    -------
    out      : (B, N, d)        attention output
    attn_map : (B, H, N, N)     attention weights (returned if
                                 return_attn=True)
    """
    def __init__(self, d: int, heads: int = 8,
                 dropout: float = 0.1, return_attn: bool = False):
        super().__init__()
        assert d % heads == 0, "`d` must be divisible by `heads`"

        self.d          = d
        self.h          = heads
        self.d_head     = d // heads
        self.scale      = 1 / math.sqrt(self.d_head)
        self.return_attn = return_attn

        # Projections
        self.q = nn.Linear(d, d, bias=False)
        self.k = nn.Linear(d, d, bias=False)
        self.v = nn.Linear(d, d, bias=False)
        self.o = nn.Linear(d, d, bias=False)

        self.drop = nn.Dropout(dropout)

    def _split(self, t: torch.Tensor):
        # (B, N, d) -> (B, H, N, d_head)
        B, N, _ = t.shape
        return (
            t.view(B, N, self.h, self.d_head)       # (B, N, H, d_head)
             .transpose(1, 2)                       # (B, H, N, d_head)
        )

    def forward(self, x: torch.Tensor,
                U: torch.Tensor | None = None):
        B, N, _ = x.shape

        Q = self._split(self.q(x))
        K = self._split(self.k(x))
        V = self._split(self.v(x))

        logits = (Q @ K.transpose(-2, -1)) * self.scale  # (B, H, N, N)

        if U is not None:
            logits = logits + U

        attn = F.softmax(logits, dim=-1)
        attn = self.drop(attn)

        context = attn @ V                               # (B, H, N, d_h)

        context = (
            context.transpose(1, 2)                      # (B, N, H, d_h)
                   .contiguous()
                   .view(B, N, self.d)                   # (B, N, d)
        )
        out = self.o(context)

        if self.return_attn:
            return out, attn        # (B, N, d), (B, H, N, N)
        else:
            return out

B, N, d = output.shape
U = torch.randn(1, 2, N, N)  # broadcast to (B, H, N, N)

pmha = ParticleMHA(d=d, heads=2, dropout=0.1, return_attn=True)
output, A = pmha(output, U)          # out: (B, N, d)  A: (B, n_heads, N, N)

print(output.shape, A.shape)

torch.Size([1, 8, 10]) torch.Size([1, 2, 8, 8])


## transformer

In [11]:
class MHA(nn.Module):
    """
    Multi-head attention (batch_first) implemented explicitly.

    Args
    ----
    d_model : int          embedding dim
    n_heads : int
    dropout: float
    bias   : bool          use bias in projections
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0., bias: bool = False):
        super().__init__()
        assert d_model % n_heads == 0, "`d_model` must be divisible by `n_heads`"
        self.d_model  = d_model
        self.h        = n_heads
        self.d_head   = d_model // n_heads
        self.scale    = self.d_head ** -0.5

        self.q_proj = nn.Linear(d_model, d_model, bias=bias)
        self.k_proj = nn.Linear(d_model, d_model, bias=bias)
        self.v_proj = nn.Linear(d_model, d_model, bias=bias)
        self.o_proj = nn.Linear(d_model, d_model, bias=bias)

        self.drop = nn.Dropout(dropout)

    def _split_heads(self, x: torch.Tensor):
        # (B, L, d_model) -> (B, h, L, d_head)
        B, L, _ = x.shape
        return x.view(B, L, self.h, self.d_head).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor):
        # (B, h, L, d_head) -> (B, L, d_model)
        B, H, L, Dh = x.shape
        return x.transpose(1, 2).contiguous().view(B, L, H * Dh)

    def forward(
        self,
        q: torch.Tensor,          # (B, Lq, d_model)
        k: torch.Tensor,          # (B, Lk, d_model)
        v: torch.Tensor,          # (B, Lk, d_model)
        need_weights: bool = False
    ):
        B, Lq, _ = q.shape
        _, Lk, _ = k.shape

        Q = self._split_heads(self.q_proj(q))    # (B,h,Lq,d_h)
        K = self._split_heads(self.k_proj(k))    # (B,h,Lk,d_h)
        V = self._split_heads(self.v_proj(v))    # (B,h,Lk,d_h)

        logits = torch.matmul(Q, K.transpose(-2, -1)) * self.scale   # (B,h,Lq,Lk)

        attn = F.softmax(logits, dim=-1)
        attn = self.drop(attn)

        context = torch.matmul(attn, V)          # (B,h,Lq,d_h)

        # merge heads + output proj
        out = self.o_proj(self._merge_heads(context))  # (B,Lq,d_model)

        if need_weights:
            avg_weights = attn.mean(dim=1)            # (B,Lq,Lk)
            return out, avg_weights
        return out, None

In [15]:
# Particle attention block  (NormFormer style + U-bias)
class ParticleAttentionBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = ParticleMHA(dim, heads, dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, dropout)
    def forward(self, x, U):
        x = x + self.attn(self.ln1(x), U)    # bias-aware MHSA
        x = x + self.mlp(self.ln2(x))        # feed-forward
        return x

# Class attention block  (CaiT style, no U)
class ClassAttentionBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=4, dropout=0.):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = MHA(dim, heads, dropout)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio, dropout)
    def forward(self, tokens, cls):          # tokens: (B,N,d), cls: (B,1,d)
        z   = torch.cat([cls, tokens], dim=1)   # (B,1+N,d)
        q   = self.ln1(cls)
        kv  = self.ln1(z)
        cls = cls + self.attn(q, kv, kv, need_weights=False)[0]
        cls = cls + self.mlp(self.ln2(cls))
        return cls                             # (B,1,d)

# Complete Particle Transformer
class ParT(nn.Module):
    def __init__(self,
                 in_dim=4,          # (E,px,py,pz)
                 embed_dim=10,
                 n_heads=2,
                 depth=2,           # particle blocks
                 class_depth=2,     # class-attention blocks
                 mlp_ratio=4,
                 num_classes=10,
                 dropout=0.1):
        super().__init__()

        self.tokenizer = ParticleTokenizer(in_dim, embed_dim)
        self.U_encoder = InteractionEncoder(n_heads=n_heads)

        self.blocks = nn.ModuleList([
            ParticleAttentionBlock(embed_dim, n_heads, mlp_ratio, dropout)
            for _ in range(depth)
        ])

        self.class_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.cls_blocks = nn.ModuleList([
            ClassAttentionBlock(embed_dim, n_heads, mlp_ratio, 0.0)
            for _ in range(class_depth)
        ])

        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.class_token, std=0.02)
        nn.init.trunc_normal_(self.head.weight,  std=0.02)
        nn.init.zeros_(self.head.bias)

    def forward(self, x):               # x: (B,4,N)
        B, _, N = x.shape

        tokens = self.tokenizer(x)                  # (B,N,d)
        U      = self.U_encoder(x)                  # (B,H,N,N)

        for blk in self.blocks:
            tokens = blk(tokens, U)                # (B,N,d)

        cls = self.class_token.expand(B, -1, -1)    # (B,1,d)
        for blk in self.cls_blocks:
            cls = blk(tokens, cls)                 # (B,1,d)

        logits = self.head(cls.squeeze(1))          # (B,10)
        return logits

In [16]:
B, _, N = x_batch.shape          # (3,4,8)
model = ParT(in_dim=4,
             embed_dim=10,
             n_heads=2,
             depth=2,
             class_depth=2,
             num_classes=10)

logits = model(x_batch)          # forward pass
print("logits:", logits.shape)   # -> torch.Size([3, 10])


logits: torch.Size([3, 10])


In [18]:
x_train = x_batch                    # (3, 4, 8)
y_train = torch.tensor([0, 1, 2])    # dummy class labels for testing

model = ParT(in_dim=4,
             embed_dim=10,
             n_heads=2,
             depth=2,
             class_depth=2,
             num_classes=10)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

n_epochs = 250
for epoch in range(n_epochs):
    model.train()
    logits = model(x_train)          # (3, 10)
    loss = criterion(logits, y_train)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # print every 50 epochs
    if (epoch+1) % 50 == 0 or epoch == 0:
        preds = logits.argmax(1)
        acc   = (preds == y_train).float().mean().item()
        print(f"epoch {epoch+1:3d}  loss {loss.item():.4f}  acc {acc:.3f}")

model.eval()
with torch.no_grad():
    probs = torch.softmax(model(x_train), dim=1)
print("softmax-probs\n", probs)

epoch   1  loss 2.3060  acc 0.000
epoch  50  loss 1.1657  acc 0.333
epoch 100  loss 1.0977  acc 0.667
epoch 150  loss 0.8846  acc 1.000
epoch 200  loss 0.0219  acc 1.000
epoch 250  loss 0.0040  acc 1.000
softmax-probs
 tensor([[9.9536e-01, 2.0938e-04, 3.2167e-03, 1.9185e-04, 1.4622e-04, 1.9616e-04,
         1.6393e-04, 1.4288e-04, 2.2374e-04, 1.5239e-04],
        [1.3350e-04, 9.9691e-01, 2.1701e-03, 1.5398e-04, 1.0111e-04, 1.5174e-04,
         7.6620e-05, 1.2269e-04, 7.1940e-05, 1.1269e-04],
        [1.4062e-03, 1.9185e-03, 9.9667e-01, 1.0295e-06, 7.4998e-07, 8.4284e-07,
         6.6928e-07, 5.7153e-07, 5.3092e-07, 3.6395e-07]])


## Load official Data

In [19]:
# 1) Clone repo
!git clone https://github.com/jet-universe/particle_transformer.git
!cd particle_transformer
!cd /content/particle_transformer
!touch env.sh
!chmod +x env.sh

Cloning into 'particle_transformer'...
remote: Enumerating objects: 101, done.[K
remote: Counting objects: 100% (52/52), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 101 (delta 38), reused 27 (delta 27), pack-reused 49 (from 1)[K
Receiving objects: 100% (101/101), 28.08 MiB | 12.26 MiB/s, done.
Resolving deltas: 100% (46/46), done.


In [20]:
!/content/particle_transformer/get_datasets.py JetClass -d ./datasets
!source env.sh
import os, glob, tarfile
os.environ['DATADIR_JetClass'] = os.path.abspath('./datasets/JetClass')
data_dir  = os.environ['DATADIR_JetClass']
!pip install awkward uproot vector
from particle_transformer.dataloader import read_file

#  Path to the one and only thing downloaded
tar_path = "/content/datasets/JetClass/JetClass_Pythia_val_5M.tar"
extract_dir = "/content/datasets/JetClass/JetClass_Pythia_val_5M"
os.makedirs(extract_dir, exist_ok=True)

Downloading data from https://zenodo.org/record/6619768/files/JetClass_Pythia_val_5M.tar to ./datasets/JetClass/JetClass_Pythia_val_5M.tar
./datasets/JetClass/JetClass_Pythia_val_5M.tar: 100% 7.07G/7.07G [06:36<00:00, 19.1MiB/s]
Updated dataset path in env.sh to "DATADIR_JetClass=./datasets/JetClass".
Collecting awkward
  Downloading awkward-2.8.5-py3-none-any.whl.metadata (6.9 kB)
Collecting uproot
  Downloading uproot-5.6.3-py3-none-any.whl.metadata (33 kB)
Collecting vector
  Downloading vector-1.6.3-py3-none-any.whl.metadata (16 kB)
Collecting awkward-cpp==47 (from awkward)
  Downloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (2.1 kB)
Downloading awkward-2.8.5-py3-none-any.whl (886 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m886.8/886.8 kB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading awkward_cpp-47-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (638 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━

In [21]:
if not any(fname.endswith(".root") for fname in os.listdir(extract_dir)):
    print("⏬ extracting test-set…")
    with tarfile.open(tar_path) as tar:
        tar.extractall(path=extract_dir)

#  Point glob at the real ROOT files
pattern = os.path.join(extract_dir, 'val_5M', "*.root")
files   = sorted(glob.glob(pattern))
print(f"Found {len(files)} ROOT files")

⏬ extracting test-set…
Found 50 ROOT files


In [22]:
from torch.utils.data import Dataset
all_x_parts = []
all_ys = []

num_file = 1
for file in files:
    num_file += 1
    if num_file % 5 == 0:
        x_part, x_jets, y = read_file(
            file,
            max_num_particles=8,
            particle_features=['part_pt', 'part_eta', 'part_phi', 'part_energy'],
            jet_features=['jet_pt', 'jet_eta', 'jet_phi', 'jet_energy'],
            labels=[
                'label_QCD', 'label_Hbb', 'label_Hcc', 'label_Hgg', 'label_H4q',
                'label_Hqql', 'label_Zqq', 'label_Wqq', 'label_Tbqq', 'label_Tbl',
            ]
        )
        all_x_parts.append(torch.tensor(x_part, dtype=torch.float32)[:100,:,:])
        all_ys.append(torch.tensor(y, dtype=torch.float32)[:100,:])

x_all = torch.cat(all_x_parts, dim=0)
y_all = torch.cat(all_ys, dim=0)
print(x_all.shape, y_all.shape)

class JetDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

dataset = JetDataset(x_all, y_all)

from torch.utils.data import DataLoader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

torch.Size([1000, 4, 8]) torch.Size([1000, 10])


In [28]:
model = ParT(
    in_dim=4,           # part_pt, eta, phi, energy
    embed_dim=10,
    n_heads=2,
    depth=2,
    class_depth=2,
    num_classes=10
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [29]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [30]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for batch_idx, (x, y) in enumerate(dataloader):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        outputs = model(x)  # shape [batch, 10]

        loss = loss_fn(outputs, y)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")

Epoch 1/100, Loss: 2.3047
Epoch 2/100, Loss: 2.3033
Epoch 3/100, Loss: 2.3028
Epoch 4/100, Loss: 2.3024
Epoch 5/100, Loss: 2.3011
Epoch 6/100, Loss: 2.2968
Epoch 7/100, Loss: 2.2857
Epoch 8/100, Loss: 2.2636
Epoch 9/100, Loss: 2.2503
Epoch 10/100, Loss: 2.2394
Epoch 11/100, Loss: 2.2311
Epoch 12/100, Loss: 2.2308
Epoch 13/100, Loss: 2.2296
Epoch 14/100, Loss: 2.2269
Epoch 15/100, Loss: 2.2260
Epoch 16/100, Loss: 2.2261
Epoch 17/100, Loss: 2.2195
Epoch 18/100, Loss: 2.2221
Epoch 19/100, Loss: 2.2174
Epoch 20/100, Loss: 2.2180
Epoch 21/100, Loss: 2.2088
Epoch 22/100, Loss: 2.2136
Epoch 23/100, Loss: 2.2127
Epoch 24/100, Loss: 2.2058
Epoch 25/100, Loss: 2.2002
Epoch 26/100, Loss: 2.1939
Epoch 27/100, Loss: 2.1861
Epoch 28/100, Loss: 2.1853
Epoch 29/100, Loss: 2.1678
Epoch 30/100, Loss: 2.1570
Epoch 31/100, Loss: 2.1857
Epoch 32/100, Loss: 2.1693
Epoch 33/100, Loss: 2.1595
Epoch 34/100, Loss: 2.1459
Epoch 35/100, Loss: 2.1359
Epoch 36/100, Loss: 2.1423
Epoch 37/100, Loss: 2.1434
Epoch 38/1

In [31]:
from torch.nn.functional import sigmoid, softmax

model.eval()
correct = 0
total = 0

with torch.no_grad():
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        labels = torch.argmax(y, dim=1)          # convert one-hot to class id
        preds = torch.argmax(outputs, dim=1)     # predicted class
        correct += (preds == labels).sum().item()
        total += y.size(0)
accuracy = correct / total
print(f"Accuracy on full dataset: {accuracy:.4f}")


Accuracy on full dataset: 0.2280


In [32]:
from sklearn.metrics import roc_auc_score
import numpy as np

model.eval()
all_outputs = []
all_targets = []

with torch.no_grad():
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        outputs = model(x)
        all_outputs.append(outputs.cpu())
        all_targets.append(y.cpu())

# Concatenate batches
all_outputs = torch.cat(all_outputs, dim=0)
all_targets = torch.cat(all_targets, dim=0)

# Apply sigmoid (if using BCEWithLogitsLoss)
probs = sigmoid(all_outputs).numpy()  # shape: (N, C)
true = all_targets.numpy()            # shape: (N, C)

# Compute AUC for each class and average
try:
    auc_macro = roc_auc_score(true, probs, average='macro', multi_class='ovr')
    print(f"Macro-Averaged AUC: {auc_macro:.4f}")
except ValueError as e:
    print("AUC could not be computed:", e)


Macro-Averaged AUC: 0.6726
