In [1]:
import os
from pathlib import Path

import numpy as np
import uproot
import awkward as ak
import matplotlib.pyplot as plt

from torch.utils.data import Dataset
from torch import Tensor
from torch import save, load
import torch.nn as nn

In [30]:
class SequenceDataset(Dataset):
    """Map-style dataset of fixed-length BMTF stub sequences read from a ROOT file.

    Each item is a float array of shape
    ``(sequence_length, max_stubs_per_bx, ncolumns_to_keep)``: one slot per
    bunch-crossing group of the sequence (in ascending ``bunchCrossing`` order),
    up to ``max_stubs_per_bx`` stubs per slot, one column per kept feature.
    Unused slots are zero-padded *before* normalisation, so they come out as
    ``-mean / std``.

    Parameters
    ----------
    sequenceTree
        Opened ROOT file (e.g. from ``uproot.open``) containing the
        ``L1BMTFStubSequences`` and ``L1BMTFStubSequencesMeta`` trees.
    max_stubs_per_bx
        Maximum number of stubs kept per bunch crossing; extra stubs are dropped.
    per_feature
        If False (default, original behaviour) a single mean/std is computed over
        all kept feature values pooled together. If True, each feature column is
        standardised with its own mean/std, so features whose range is small
        compared with the pooled spread are not squashed to near-constant values.
    """

    # Bookkeeping columns of L1BMTFStubSequences that are not model features.
    _NON_FEATURE_COLUMNS = ("nL1BMTFStub", "orbitNumber", "bunchCrossing", "sequenceIndex")

    def __init__(self, sequenceTree, max_stubs_per_bx=3, per_feature=False):
        super().__init__()
        self.max_stubs_per_bx = max_stubs_per_bx

        meta = sequenceTree["L1BMTFStubSequencesMeta"].arrays()
        # .item() requires the meta branch to hold exactly one value.
        self.sequence_length = meta["sequenceLength"].to_numpy().item()

        arrays = sequenceTree["L1BMTFStubSequences"].arrays()
        # Feature columns, in the tree's field order.
        self.columns_to_keep = [
            field for field in arrays.fields if field not in self._NON_FEATURE_COLUMNS
        ]
        self.ncolumns_to_keep = len(self.columns_to_keep)

        if per_feature:
            # One statistic per feature; broadcasts over the last axis in __getitem__.
            self.mean = np.array(
                [ak.mean(ak.ravel(arrays[col])) for col in self.columns_to_keep],
                dtype=np.float64,
            )
            self.std = np.array(
                [ak.std(ak.ravel(arrays[col])) for col in self.columns_to_keep],
                dtype=np.float64,
            )
            # A constant feature would otherwise divide by zero.
            self.std[self.std == 0] = 1.0
        else:
            pooled = ak.ravel(arrays[self.columns_to_keep])
            self.mean = ak.mean(pooled)
            self.std = ak.std(pooled)
            if self.std == 0:
                self.std = 1.0

        df = ak.to_dataframe(arrays)
        df = df.drop(columns=["nL1BMTFStub", "orbitNumber"])
        # NOTE(review): rows are stubs, so bunch crossings with no stubs presumably
        # yield no rows and are absent from the per-sequence groups; later groups
        # then shift into earlier slots in __getitem__ — confirm this is intended.
        self.sequences = [group for _, group in df.groupby("sequenceIndex")]

    def __getitem__(self, index):
        """Return sequence ``index`` as a normalised float array.

        Bunch-crossing groups fill slots in ascending ``bunchCrossing`` order
        (pandas ``groupby`` sorts keys by default); stubs beyond
        ``max_stubs_per_bx`` in a group are dropped.

        Raises
        ------
        IndexError
            If ``index`` is out of range (this also ends plain ``for`` iteration).
        ValueError
            If the sequence has more bunch-crossing groups than ``sequence_length``.
        """
        sequence = self.sequences[index]

        padded = np.zeros(
            (self.sequence_length, self.max_stubs_per_bx, self.ncolumns_to_keep),
            dtype=np.int32,
        )

        bx_groups = sequence.groupby("bunchCrossing")
        if bx_groups.ngroups > self.sequence_length:
            # Fail loudly: an IndexError escaping __getitem__ would be read by the
            # legacy iteration protocol as "end of data" and silently stop a for-loop.
            raise ValueError(
                f"sequence {index} spans {bx_groups.ngroups} bunch crossings, "
                f"but sequence_length is {self.sequence_length}"
            )

        for slot, (_, bx_group) in enumerate(bx_groups):
            # Select features by name so the column order always matches
            # columns_to_keep (and per-feature statistics), whatever the frame layout.
            stubs = bx_group[self.columns_to_keep].to_numpy()
            n_stubs = min(stubs.shape[0], self.max_stubs_per_bx)
            padded[slot, :n_stubs, :] = stubs[:n_stubs, :]

        return (padded - self.mean) / self.std

    def __len__(self):
        """Number of sequences (distinct ``sequenceIndex`` values)."""
        return len(self.sequences)

In [36]:
# Input ROOT file. Set the BXAD_DATA_DIR environment variable to point at another
# data directory instead of editing a machine-specific absolute path here.
DATA_DIR = Path(os.environ.get("BXAD_DATA_DIR", "/home/gizago/dev/bxad-test/data"))
filepath = str(DATA_DIR / "output_1000_seq3s.root")
# `tree` is the opened ROOT file; it holds the sequence and meta trees used below.
tree = uproot.open(filepath)
tree["L1BMTFStubSequences"].arrays().fields

['nL1BMTFStub',
 'L1BMTFStub_hwQual',
 'L1BMTFStub_hwPhi',
 'L1BMTFStub_hwPhiB',
 'L1BMTFStub_hwEta',
 'L1BMTFStub_hwQEta',
 'L1BMTFStub_wheel',
 'L1BMTFStub_sector',
 'L1BMTFStub_station',
 'orbitNumber',
 'bunchCrossing',
 'sequenceIndex']

In [31]:
# Build the dataset from the opened ROOT file, keeping up to 3 stubs per bunch crossing.
dataset = SequenceDataset(sequenceTree=tree, max_stubs_per_bx=3)

In [34]:
# Inspect one arbitrary sequence; the guard gives a clear message on smaller input files.
SAMPLE_INDEX = 45
assert SAMPLE_INDEX < len(dataset), f"dataset has only {len(dataset)} sequences"
dataset[SAMPLE_INDEX]

array([[[-0.07631915, -2.20948038, -0.77034167, -0.08909257,
         -0.08909257, -0.09335038, -0.05077231, -0.08057696],
        [-0.08909257, -0.08909257, -0.08909257, -0.08909257,
         -0.08909257, -0.08909257, -0.08909257, -0.08909257],
        [-0.08909257, -0.08909257, -0.08909257, -0.08909257,
         -0.08909257, -0.08909257, -0.08909257, -0.08909257]],

       [[-0.08057696,  5.35664239, -0.5446779 , -0.08909257,
         -0.08909257, -0.08483476, -0.07206134, -0.08483476],
        [-0.08909257, -0.08909257, -0.08909257, -0.08909257,
         -0.08909257, -0.08909257, -0.08909257, -0.08909257],
        [-0.08909257, -0.08909257, -0.08909257, -0.08909257,
         -0.08909257, -0.08909257, -0.08909257, -0.08909257]],

       [[-0.07631915,  3.90473025,  1.16270264, -0.08909257,
         -0.08909257, -0.09335038, -0.08057696, -0.08483476],
        [-0.08909257, -0.08909257, -0.08909257, -0.08909257,
         -0.08909257, -0.08909257, -0.08909257, -0.08909257],
        [-0.