In [1]:
import numpy as np
import uproot
import awkward as ak
import matplotlib.pyplot as plt

from torch.utils.data import Dataset
from torch import Tensor
from torch import save, load
import torch.nn as nn

In [2]:
# Read every branch of the L1BMTFStubSequences tree into memory. The `with`
# block closes the underlying ROOT file as soon as the arrays are loaded,
# instead of leaving the handle open for the lifetime of the kernel.
# NOTE(review): absolute, user-specific path -- consider a configurable DATA_DIR.
with uproot.open("/home/gizago/dev/bxad-test/data/output_1000_seq3s.root:L1BMTFStubSequences") as tree:
    events = tree.arrays()

# Tabular (pandas) view of the awkward array; later cells group it by
# "sequenceIndex" and "bunchCrossing".
df = ak.to_dataframe(events)

# Number of rows each bunch crossing is zero-padded to in the preview cell below.
max_stubs_per_bx = 4

In [3]:
# Show the field (branch) names available in the loaded array.
events.fields

['nL1BMTFStub',
 'L1BMTFStub_hwQual',
 'L1BMTFStub_hwPhi',
 'L1BMTFStub_hwPhiB',
 'L1BMTFStub_hwEta',
 'L1BMTFStub_hwQEta',
 'L1BMTFStub_wheel',
 'L1BMTFStub_sector',
 'L1BMTFStub_station',
 'orbitNumber',
 'bunchCrossing',
 'sequenceIndex']

In [4]:
# One DataFrame per distinct "sequenceIndex" value, in groupby (sorted-key) order.
grouped_by_sequence = df.groupby("sequenceIndex")
sequences = [frame for _seq_id, frame in grouped_by_sequence]

In [5]:
def nmax_group_stub(group):
    """Return the largest number of rows that share one "bunchCrossing" value.

    Parameters
    ----------
    group : pandas.DataFrame
        Table with a "bunchCrossing" column.

    Returns
    -------
    int
        Row count of the most populated bunch crossing.

    Raises
    ------
    ValueError
        If ``group`` yields no bunch-crossing groups (``max`` of an empty
        sequence).
    """
    rows_per_bx = (len(bx_rows) for _bx, bx_rows in group.groupby("bunchCrossing"))
    return max(rows_per_bx)

# Per-sequence maximum number of rows found in any single bunch crossing.
max_nstub_sequence = [nmax_group_stub(sequence) for sequence in sequences]

In [19]:
# Preview the zero-padded per-bunch-crossing matrices for one example sequence.
seq = sequences[2]

for _, sg in seq.groupby("bunchCrossing"):
    # Keep only the per-stub feature columns and convert once.
    sg = sg.drop(columns=["nL1BMTFStub", "bunchCrossing", "sequenceIndex", "orbitNumber"])
    # Clip to max_stubs_per_bx rows: without this, a bunch crossing with more
    # rows than the padded block makes the assignment below fail with a
    # shape-mismatch ValueError.
    stubs = sg.to_numpy()[:max_stubs_per_bx]
    nrow, ncol = stubs.shape
    zeros = np.zeros((max_stubs_per_bx, ncol), dtype=np.int32)
    zeros[:nrow, :ncol] = stubs
    print(zeros)

[[    6   861   220    16     0    -1     2     1]
 [    5 -1005   146     4     0    -1     3     2]
 [    2  -819    -1     0     0    -1     3     4]
 [    0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0]
 [    0     0     0     0     0     0     0     0]]
[[   6   28  275    8    8   -1    9    1]
 [   2  444 -209    0    0   -1    9    3]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]]
[[   5  466 -456    0    0   -2    4    2]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]
 [   0    0    0    0    0    0    0    0]]


In [None]:
class SequenceDataset(Dataset):
    """Map-style dataset of fixed-size, zero-padded integer blocks.

    The ``L1BMTFStubSequences`` tree of a ROOT file is loaded into a pandas
    DataFrame and split into one frame per ``sequenceIndex``. Each item is an
    ``int32`` array of shape ``(seq_length, max_stubs_per_bx, n_features)``
    where ``n_features`` is the number of tree fields minus the bookkeeping
    columns listed in ``columns_to_drop``.

    Parameters
    ----------
    file : str or path-like
        Path to the ROOT file (without the ``:tree`` suffix).
    seq_length : int, optional
        Number of bunch-crossing slots per item (default 3).
    max_stubs_per_bx : int, optional
        Number of rows kept per bunch crossing (default 3).
    """

    def __init__(self, file, seq_length=3, max_stubs_per_bx=3):
        super().__init__()
        self.seq_length = seq_length
        self.max_stubs_per_bx = max_stubs_per_bx
        # Bookkeeping columns that are not part of the per-row feature vector.
        self.columns_to_drop = ["nL1BMTFStub", "bunchCrossing", "sequenceIndex", "orbitNumber"]
        self.ncolumns_to_drop = len(self.columns_to_drop)

        # Load everything eagerly and close the file right away so the dataset
        # does not keep a file handle open.
        with uproot.open(f"{file}:L1BMTFStubSequences") as tree:
            events = tree.arrays()
        self.columns = events.fields
        self.ncolumns = len(self.columns)

        df = ak.to_dataframe(events)
        self.sequences = [g for _, g in df.groupby("sequenceIndex")]

    def __getitem__(self, index):
        """Return the zero-padded ``int32`` array for sequence ``index``.

        Bunch-crossing groups fill the first axis in groupby (ascending
        ``bunchCrossing``) order. Slots with no group, and rows beyond the
        rows present in a group, are zero. Groups beyond ``seq_length`` and
        rows beyond ``max_stubs_per_bx`` are dropped.

        Raises
        ------
        IndexError
            If ``index`` is outside ``self.sequences``.
        """
        sequence = self.sequences[index]
        n_features = self.ncolumns - self.ncolumns_to_drop

        # Start from zeros (not np.empty) so that a sequence with fewer than
        # seq_length bunch-crossing groups never exposes uninitialised memory.
        array = np.zeros(
            (self.seq_length, self.max_stubs_per_bx, n_features),
            dtype=np.int32,
        )

        # NOTE(review): groups are ordered by bunchCrossing value only; if a
        # sequence can span an orbit boundary this is not time order -- confirm.
        for idx, (_, sg) in enumerate(sequence.groupby("bunchCrossing")):
            if idx >= self.seq_length:
                # More groups than slots: keep the first seq_length only.
                break
            # Clip to the padded block height; otherwise the assignment fails
            # with a shape mismatch when a group has too many rows.
            stubs = sg.drop(columns=self.columns_to_drop).to_numpy()[: self.max_stubs_per_bx]
            array[idx, : stubs.shape[0], :] = stubs

        return array

    def __len__(self):
        """Return the number of sequences in the dataset."""
        return len(self.sequences)

In [31]:
# Build the dataset from the ROOT file; the constructor reads the whole tree eagerly.
dataset = SequenceDataset("/home/gizago/dev/bxad-test/data/output_1000_seq3s.root")

In [37]:
# Inspect one item: an int32 array indexed as (bunch-crossing slot, row, feature).
dataset[6]

array([[[   6,  392, -173,   16,   16,    1,    1,    1],
        [   6,  189, -122,    4,    4,    1,    1,    2],
        [   6,   94,  -34,    1,    1,    1,    1,    3],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0]],

       [[   3,  -10,    2,    0,    0,    0,    4,    1],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0]],

       [[   4, 1195,  200,    0,    0,    0,    4,    1],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0,    0,    0],
        [ 