In [None]:
# default_exp data

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Data

A collection of classes for handling the passing of data to the network.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
from pytorch_inferno.pseudodata import paper_sig, paper_bkg, PseudoData

from torch.utils.data import DataLoader
import torch
from torch import Tensor

from typing import Tuple, Union, Optional
from fastcore.all import store_attr, delegates
import numpy as np

In [None]:
# export
class DataSet():
    r'''Indexable container of inputs with optional targets and weights; indexing returns tensors'''
    def __init__(self, x:np.ndarray, y:Optional[np.ndarray]=None, w:Optional[np.ndarray]=None): store_attr()

    def __len__(self) -> int:
        r'''Length of the input data'''
        return len(self.x)

    def __getitem__(self, i:int) -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        r'''Return the (input, target, weight) entries at index `i`; target and weight are None when not stored'''
        inputs = Tensor(self.x[i])
        targets = None if self.y is None else Tensor(self.y[i])
        weights = None if self.w is None else Tensor(self.w[i])
        return inputs, targets, weights

In [None]:
# export
class WeightedDataLoader(DataLoader):
    r'''PyTorch DataLoader whose batches are (inputs, targets, weights) tuples, in which targets and weights may be None'''
    @delegates(DataLoader, but=['collate_fn'])
    def __init__(self, dataset, **kwargs):
        # Batches are always assembled by this class's own collate function
        super().__init__(dataset, collate_fn=self.collate_fn, **kwargs)

    @staticmethod
    def collate_fn(b:Tuple[Tensor,Optional[Tensor],Optional[Tensor]]) \
            -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        r'''Stack per-sample entries into batch tensors; a field is returned as None if its first sample is None'''
        inputs,targets,weights = zip(*b)
        batch_x = torch.stack(inputs)
        # Presence of targets and weights is judged from the first sample only
        batch_y = None if targets[0] is None else torch.stack(targets)
        batch_w = None if weights[0] is None else torch.stack(weights)
        return batch_x, batch_y, batch_w

In [None]:
# export
class DataPair():
    r'''Holds a training and a validation dataloader together, so both can be passed as one object for model training'''
    def __init__(self, trn_dl:WeightedDataLoader, val_dl:WeightedDataLoader): store_attr()

    @property
    def trn_ds(self):
        r'''Dataset held by the training dataloader'''
        return self.trn_dl.dataset

    @property
    def val_ds(self):
        r'''Dataset held by the validation dataloader'''
        return self.val_dl.dataset

In [None]:
# 105 is not a multiple of the batch size of 10 used for the example loaders, so the final batch is incomplete
n = 105
# Two separate draws of n samples from the paper_sig pseudodata
trn = PseudoData(paper_sig, 1).sample(n)
val = PseudoData(paper_sig, 1).sample(n)

In [None]:
# Inspect the raw sample: a tuple of NumPy arrays
trn

(array([[-1.9198835 , -1.0762373 ,  2.0561838 ],
        [ 0.7635232 , -0.00798484,  0.6720786 ],
        [-1.4688606 ,  0.14419074,  0.08437859],
        [-0.9007238 ,  1.5351713 ,  0.39791936],
        [-0.11944689, -0.81342494,  0.06138625],
        [ 0.23471548, -0.88983023,  0.386645  ],
        [ 0.5056333 ,  0.9097138 ,  1.2382956 ],
        [-0.11091813,  0.5777042 ,  0.0621113 ],
        [ 0.47080657,  0.84146565,  0.79525316],
        [ 0.53776467, -0.41766354,  1.1415823 ],
        [-0.39915642,  0.05010221,  0.0497605 ],
        [ 0.28401792,  1.5436524 ,  0.75299543],
        [ 0.4004169 ,  2.4372673 ,  0.02545393],
        [-1.3024784 , -0.5359826 ,  0.03548376],
        [-0.11707021, -0.68494505,  0.00765719],
        [ 0.08257045,  1.5388554 ,  0.37543258],
        [ 1.0025002 ,  0.4524089 ,  0.26787472],
        [-0.20209394,  0.7995246 ,  0.03257589],
        [ 0.7019785 , -0.91270643,  1.0107836 ],
        [ 0.9701127 ,  2.1222665 ,  0.13343458],
        [ 0.08021289

In [None]:
# Wrap each raw sample tuple in an indexable dataset
trn_ds = DataSet(*trn)
val_ds = DataSet(*val)

In [None]:
# The dataset length must equal the number of samples drawn
assert len(trn_ds) == n

In [None]:
# A single item is an (input, target, weight) tuple; the weight entry is None for this data
trn_ds[1]

(tensor([ 0.7635, -0.0080,  0.6721]), tensor([1.]), None)

In [None]:
# Training loader: shuffled, with the final incomplete batch dropped
trn_dl = WeightedDataLoader(trn_ds, batch_size=10, shuffle=True, drop_last=True)
# Validation loader: fixed sample order, every sample kept
val_dl = WeightedDataLoader(val_ds, batch_size=10, shuffle=False)

In [None]:
# Pull one batch to check the collated (inputs, targets, weights) output
next(iter(trn_dl))

(tensor([[-1.3087, -1.4191,  0.2791],
         [-2.0775,  0.0210,  0.1024],
         [-0.1194, -0.8134,  0.0614],
         [-0.1599, -0.9689,  0.6335],
         [-1.0200, -1.0619,  0.4841],
         [ 0.3111,  1.9276,  0.3580],
         [ 0.9712,  0.9846,  0.1717],
         [-0.7197,  0.6359,  0.1693],
         [ 1.8839, -3.0788,  0.2016],
         [ 0.7020, -0.9127,  1.0108]]), tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]), None)

In [None]:
# Loop over every training batch, printing its index, inputs and targets
for i, (xb,yb,wb) in enumerate(trn_dl):
    print(i, xb, yb)

0 tensor([[-0.5731,  0.5025,  0.1391],
        [ 1.0025,  0.4524,  0.2679],
        [-2.4189, -1.0408,  0.3012],
        [-0.7197,  0.6359,  0.1693],
        [-0.3745, -0.6151,  2.1637],
        [-1.0269,  0.8644,  0.7618],
        [ 0.6537, -0.9352,  0.2990],
        [-0.3762, -0.2583,  0.0107],
        [ 0.5056,  0.9097,  1.2383],
        [-1.1815, -0.2015,  0.2611]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
1 tensor([[ 0.2444,  0.1605,  0.5945],
        [-1.3087, -1.4191,  0.2791],
        [ 1.9576, -0.7189,  0.7281],
        [ 2.0018,  0.8234,  1.9770],
        [-0.1194, -0.8134,  0.0614],
        [ 0.5350, -1.7760,  0.3908],
        [-1.4818,  0.7652,  0.2362],
        [ 0.8324,  0.2588,  1.1819],
        [ 0.9038,  1.6524,  2.7248],
        [-0.6543, -0.1281,  0.0462]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
 

In [None]:
# Bundle the training and validation loaders into a single object
data = DataPair(trn_dl, val_dl)

In [None]:
# The property exposes the dataset held by the training loader
data.trn_ds

<__main__.DataSet at 0x7fb15c5fb320>

# Paper data

In [None]:
# export
def get_paper_data(n:int, bs=2000, n_test:int=0) -> Union[DataPair,Tuple[DataPair,WeightedDataLoader]]:
    r'''Build training and validation dataloaders, and optionally a testing dataloader, from the pseudodata used in the INFERNO paper

    Arguments:
        n: total sample count for each of the training and validation sets; `n//2` are drawn from `paper_sig` and `n//2` from `paper_bkg`
        bs: batch size of the training dataloader; the validation and testing dataloaders use `2*bs`
        n_test: total sample count for the testing set, halved in the same way as `n`; if the halved value is not positive, no testing data are produced

    Returns:
        `DataPair` of training and validation dataloaders, or a tuple of that `DataPair` and the testing dataloader if testing data were requested
    '''
    def _sample_xy(n_each:int) -> Tuple[np.ndarray,np.ndarray]:
        # Draw from paper_sig first, then paper_bkg, and stack inputs and targets in that order
        sig = PseudoData(paper_sig, 1).sample(n_each)
        bkg = PseudoData(paper_bkg, 0).sample(n_each)
        return np.vstack((sig[0],bkg[0])), np.vstack((sig[1],bkg[1]))

    n_half,n_test_half = n//2,n_test//2
    trn_xy = _sample_xy(n_half)
    val_xy = _sample_xy(n_half)

    # Only the training loader drops its final incomplete batch
    data = DataPair(WeightedDataLoader(DataSet(*trn_xy), batch_size=bs, shuffle=True, drop_last=True),
                    WeightedDataLoader(DataSet(*val_xy), batch_size=2*bs, shuffle=True))
    if n_test_half <= 0: return data

    tst_dl = WeightedDataLoader(DataSet(*_sample_xy(n_test_half)), batch_size=2*bs)
    return data, tst_dl

In [None]:
# With n_test left at its default of 0, only the DataPair is returned
n = 10
data = get_paper_data(n)

In [None]:
# Training and validation sets must each hold n samples in total
assert len(data.trn_ds) == len(data.val_ds) == n

In [None]:
# Requesting testing data makes the function return the testing dataloader as well
data, test = get_paper_data(n,n_test=2*n)

In [None]:
# Training and validation sets hold n samples each; the testing set holds the 2*n samples requested.
# Compared against n with integer arithmetic rather than a hard-coded 10 and a float product.
assert len(data.trn_ds) == len(data.val_ds) == n
assert len(test.dataset) == 2*n