In [None]:
# default_exp data

In [None]:
# Notebook setup: inline plots and automatic reloading of edited modules
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Data

Collection of classes for handling the data that is passed to the network.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
from typing import Tuple, Union, Optional, Sequence

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from fastcore.all import store_attr, delegates

from pytorch_inferno.pseudodata import paper_sig, paper_bkg, PseudoData

In [None]:
# export
class DataSet():
    r'''Indexable container of inputs `x` with optional targets `y` and weights `w`.

    Indexing returns an `(x, y, w)` triple of `torch.Tensor`s built from row `i` of each array,
    with `None` standing in for any of `y`/`w` that was not supplied.'''
    def __init__(self, x:np.ndarray, y:Optional[np.ndarray]=None, w:Optional[np.ndarray]=None): store_attr()

    def __len__(self) -> int:
        # Length comes from the inputs only; y and w are presumably row-aligned with x
        return len(self.x)

    def __getitem__(self, i:int) -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        def _row(arr): return None if arr is None else Tensor(arr[i])
        x_row = Tensor(self.x[i])
        return x_row, _row(self.y), _row(self.w)

In [None]:
# export
class WeightedDataLoader(DataLoader):
    r'''PyTorch DataLoader with support for optional weights and targets

    Items drawn from `dataset` are expected to be `(x, y, w)` triples in which `y` and/or `w` may be `None`
    (as `DataSet` produces). Each batch is collated into an `(x, y, w)` triple of stacked tensors,
    keeping `None` for any component that is absent. All other keyword arguments go to `DataLoader`.'''
    @delegates(DataLoader, but=['collate_fn'])
    def __init__(self, dataset, **kwargs):
        # collate_fn is fixed so that None targets/weights survive batching
        super().__init__(dataset, collate_fn=self.collate_fn, **kwargs)

    @staticmethod
    def collate_fn(b:Sequence[Tuple[Tensor,Optional[Tensor],Optional[Tensor]]]) \
            -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        r'''Stack a sequence of `(x, y, w)` items into a single `(x, y, w)` batch.

        Whether `y`/`w` are present is decided from the first item only; a batch mixing `None`
        and tensors for the same component will make `torch.stack` fail.'''
        x,y,w = zip(*b)
        return (torch.stack(x),
                torch.stack(y) if y[0] is not None else None,
                torch.stack(w) if w[0] is not None else None)

In [None]:
# export
class DataPair():
    r'''Bundle of a training and a validation loader, so both can be handed to model training as one object.

    The datasets wrapped by the two loaders are exposed through the `trn_ds` and `val_ds` properties.'''
    def __init__(self, trn_dl:WeightedDataLoader, val_dl:WeightedDataLoader): store_attr()

    @property
    def trn_ds(self):
        r'''Dataset wrapped by the training loader'''
        loader = self.trn_dl
        return loader.dataset

    @property
    def val_ds(self):
        r'''Dataset wrapped by the validation loader'''
        loader = self.val_dl
        return loader.dataset

In [None]:
# Two independent signal-only samples of n points, used below as training and validation data
n = 105
trn = PseudoData(paper_sig, 1).sample(n)
val = PseudoData(paper_sig, 1).sample(n)

In [None]:
# The sample is a tuple of (inputs, targets) arrays
trn

(array([[ 0.6703831 , -1.975327  ,  0.07786307],
        [-1.607151  , -0.21165201,  0.533054  ],
        [ 0.2890597 , -1.2329735 ,  0.07855688],
        [ 0.3148843 , -0.8077682 ,  0.01684092],
        [ 0.9850615 ,  0.9820619 ,  0.3748669 ],
        [-0.45198023, -0.43057093,  0.2734322 ],
        [ 0.9407473 , -1.4106951 ,  0.11927112],
        [ 2.3425531 ,  0.37897772,  0.23066421],
        [ 0.89105105,  0.01307287,  0.7219162 ],
        [-0.74365467,  0.8589957 ,  0.968216  ],
        [ 0.74442315, -0.7271336 ,  0.41546497],
        [-1.3487953 ,  0.8914203 ,  0.4051963 ],
        [-0.96267956, -1.0753933 ,  0.5675918 ],
        [ 1.3598816 , -0.44744036,  0.06143827],
        [ 0.59028363,  1.7174244 ,  1.6332926 ],
        [ 0.24950022, -2.111771  ,  0.16660859],
        [ 0.56447875,  0.92321175,  0.14592573],
        [-1.3905183 , -0.37362397,  1.3647213 ],
        [-1.5956727 , -1.1259795 ,  0.22540851],
        [-1.1854737 ,  0.5137138 ,  0.9980676 ],
        [ 0.49355116

In [None]:
# Wrap the (inputs, targets) tuples; no weights are supplied, so w is None
trn_ds = DataSet(*trn)
val_ds = DataSet(*val)

In [None]:
# Both sets were sampled with n points each
assert len(trn_ds) == len(val_ds) == n

In [None]:
# Items are (x, y, w) triples; w is None because no weights were given
trn_ds[1]

(tensor([-1.6072, -0.2117,  0.5331]), tensor([1.]), None)

In [None]:
# Training batches are shuffled and the trailing partial batch (105 % 10 = 5 points) is dropped;
# validation keeps a fixed order and every point
trn_dl = WeightedDataLoader(trn_ds, batch_size=10, shuffle=True, drop_last=True)
val_dl = WeightedDataLoader(val_ds, batch_size=10, shuffle=False)

In [None]:
# One collated batch: stacked inputs (10, 3), targets (10, 1), and None for the absent weights
next(iter(trn_dl))

(tensor([[ 1.1435, -1.8216,  0.4232],
         [-0.8965,  1.4160,  0.0198],
         [ 0.4936, -1.2429,  0.3957],
         [ 0.8530,  0.9635,  0.1346],
         [-0.2174,  0.2243,  0.6259],
         [ 0.8911,  0.0131,  0.7219],
         [-0.6765,  0.8259,  0.5119],
         [ 0.3875, -0.0670,  0.4274],
         [-0.7145,  1.0463,  0.1436],
         [-0.9626, -0.6129,  0.0922]]), tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]), None)

In [None]:
# Iterate over one epoch, printing batch shapes rather than dumping every tensor;
# wb is None because no weights were passed to the DataSet
for i, (xb,yb,wb) in enumerate(trn_dl): print(i, tuple(xb.shape), tuple(yb.shape), wb)

0 tensor([[-0.7145,  1.0463,  0.1436],
        [-1.3040, -2.3205,  0.2772],
        [ 0.5903,  1.7174,  1.6333],
        [ 1.5045, -0.7733,  2.0942],
        [ 1.3496,  1.0398,  0.3462],
        [ 0.5523,  1.1335,  0.2602],
        [-0.4520, -0.4306,  0.2734],
        [-0.1087, -0.5042,  0.0755],
        [-0.9627, -1.0754,  0.5676],
        [-2.3328, -0.5531,  0.1182]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
1 tensor([[-1.1855,  0.5137,  0.9981],
        [-0.2180,  0.7253,  0.0458],
        [ 0.1980,  0.3830,  0.6345],
        [ 1.1435, -1.8216,  0.4232],
        [-1.1750, -1.0982,  0.0273],
        [ 0.0502, -1.4477,  0.1259],
        [-1.5957, -1.1260,  0.2254],
        [ 0.5645,  0.9232,  0.1459],
        [ 0.6693, -2.2713,  0.0491],
        [ 1.0091,  0.1724,  0.2241]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
 

In [None]:
# Bundle the two loaders so they can be passed to training together
data = DataPair(trn_dl, val_dl)

In [None]:
# The properties expose exactly the datasets wrapped by the loaders given to DataPair
assert data.trn_ds is trn_ds and data.val_ds is val_ds
data.trn_ds

<__main__.DataSet at 0x120e0b4a8>

# Paper data

In [None]:
# export
def get_paper_data(n:int, bs:int=2000, n_test:int=0) -> Union[DataPair,Tuple[DataPair,WeightedDataLoader]]:
    r'''Function returning training, validation and testing data according to pseudodata used in INFERNO paper

    Arguments:
        n: total number of points in each of the training and validation sets, split evenly between
            signal (`paper_sig`, labelled 1) and background (`paper_bkg`, labelled 0); an odd `n` is rounded down
        bs: batch size of the training loader; the validation and testing loaders use `2*bs`
        n_test: total number of testing points, split in the same way as `n`; if fewer than 2,
            no testing loader is built

    Returns:
        `DataPair` of training and validation loaders when no testing set is requested,
        otherwise a tuple of that `DataPair` and a testing `WeightedDataLoader`
    '''
    def _sample(n_per_class:int) -> Tuple[np.ndarray,np.ndarray]:
        # Draw signal then background points and stack them into (inputs, targets) arrays
        sig = PseudoData(paper_sig, 1).sample(n_per_class)
        bkg = PseudoData(paper_bkg, 0).sample(n_per_class)
        return np.vstack((sig[0],bkg[0])), np.vstack((sig[1],bkg[1]))

    n,n_test = n//2,n_test//2
    # Sets are drawn in the order training, validation, then (optionally) testing
    trn_ds = DataSet(*_sample(n))
    val_ds = DataSet(*_sample(n))

    # drop_last avoids a short trailing training batch, but if the whole training set is smaller than `bs`
    # it would leave the loader with no batches at all, so only drop when at least one full batch exists
    trn_dl = WeightedDataLoader(trn_ds, batch_size=bs, shuffle=True, drop_last=len(trn_ds) >= bs)
    val_dl = WeightedDataLoader(val_ds, batch_size=2*bs, shuffle=True)
    data = DataPair(trn_dl, val_dl)
    if n_test <= 0: return data

    tst_dl = WeightedDataLoader(DataSet(*_sample(n_test)), batch_size=2*bs)
    return data, tst_dl

In [None]:
# Small sample to exercise get_paper_data without a testing set
n = 10
data = get_paper_data(n)

In [None]:
# n is split evenly between signal and background, so each set holds n points in total
assert len(data.trn_ds) == len(data.val_ds) == n

In [None]:
# Requesting a testing set changes the return value to (DataPair, testing loader)
data, test = get_paper_data(n,n_test=2*n)

In [None]:
# The testing set was requested with 2*n points, twice the size of the training/validation sets
assert len(data.trn_ds) == len(data.val_ds) == n and len(test.dataset) == 2*n