In [1]:
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import pandas as pd
import matplotlib.pyplot as plt
from me0.data.datasets.index import ME0IndexDataset
from me0.modules.transformer.model import ME0Transformer
from me0.losses import ME0BCELoss

In [39]:
# Architecture hyper-parameters. These must match the network the saved state
# dict was trained with, otherwise `load_state_dict` will reject the weights.
model = ME0Transformer(
    dim_input=5,  # per-hit features: [ieta, layer, strip, cls, bx]
    num_layers=4,
    dim_model=128,
    num_heads=8,
    dim_feedforward=512,
    layer_norm_eps=1.0e-5,
    dropout=0.0,
    activation="gelu",
    norm_first=True,
    bias=False,
)
# Wrap with the project's tensordict adapter and switch to eval mode for inference.
model = model.to_tensor_dict_module().eval()



In [40]:
# Load the trained weights onto the CPU.
# `weights_only=True` restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code. It also avoids the
# FutureWarning that torch >= 2.4 emits when the flag is omitted.
state_dict = torch.load(
    "../data/ME0Transformer_state_dict.pt",
    map_location="cpu",
    weights_only=True,
)

# strict=True (the default) raises on any missing or unexpected key; the
# returned key report is displayed as the cell output.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [18]:
# Number of trailing events used for evaluation, served as exactly one batch.
# The slice length and the batch size share this constant so they cannot drift
# apart. With drop_last=True, a mismatch would silently discard the last
# partial batch.
N_EVAL_EVENTS = 128

dataset = ME0IndexDataset(
    file='../data/step4_0.h5',
    # Per-feature min/max ranges; presumably used for normalization — TODO confirm
    features={
        'cls': {'min': 0, 'max': 383},
        'bx': {'min': -3, 'max': 3},
    },
)

eval_events = dataset[-N_EVAL_EVENTS:]
# Guard: a smaller dataset would make drop_last=True yield zero batches, and
# `next(iter(test_loader))` would fail later with an opaque StopIteration.
assert len(eval_events) == N_EVAL_EVENTS, (
    f"expected {N_EVAL_EVENTS} evaluation events, got {len(eval_events)}"
)

test_loader = DataLoader(
    eval_events,
    collate_fn=dataset.collate,
    shuffle=False,
    drop_last=True,
    batch_size=N_EVAL_EVENTS,
)

processing 35430 events:   0%|                                                                                                                           | 1/35430 [00:01<17:57:53,  1.83s/it]


In [41]:
# Take the single evaluation batch from the loader.
batch = next(iter(test_loader))

# Weighted BCE loss. 74.77 is presumably the negative/positive class ratio of
# the training data — TODO confirm. The meaning of 'mean1' is defined by
# ME0BCELoss.
loss_fn = (
    ME0BCELoss(pos_weight=74.77, reduction='mean1')
    .to_tensor_dict_module()
)

# Forward pass and loss with autograd tracking disabled.
with torch.inference_mode():
    output = model(batch)
    loss = loss_fn(output)['loss']

print(loss)

tensor(0.0115)
