In [None]:
# default_exp data

In [None]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Data

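Collection of classes for handling the data that is passed to the network: an indexable `DataSet`, a `WeightedDataLoader` that batches optional targets and weights, and a `DataPair` bundling training and validation loaders.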

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from fastcore.all import store_attr, delegates
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_inferno.pseudodata import paper_sig, paper_bkg, PseudoData

In [None]:
# export
class DataSet():
    r'''Class holding indexable input, target and weight data.

    Arguments:
        x: inputs, indexed along their first axis
        y: optional targets, indexed in parallel with `x`
        w: optional weights, indexed in parallel with `x`

    Indexing returns an `(x, y, w)` tuple of tensors in torch's default floating-point dtype, with `None` in place
    of targets/weights that were not supplied. Indexed elements may be rows (e.g. from 2D arrays) or scalars
    (from 1D arrays).
    '''
    def __init__(self, x:np.ndarray, y:Optional[np.ndarray]=None, w:Optional[np.ndarray]=None): store_attr()
    def __len__(self) -> int: return len(self.x)

    @staticmethod
    def _to_tensor(a) -> Tensor:
        r'''Copy a single indexed element (row or scalar) into a tensor of the default floating-point dtype'''
        # `Tensor(a)` misreads numpy scalars, which is what indexing a 1D array yields: integer scalars can be
        # interpreted as a *size* (giving an uninitialised tensor) and float scalars are rejected.
        # `torch.tensor` converts both correctly and gives the same result as `Tensor` for array rows.
        return torch.tensor(np.asarray(a), dtype=torch.get_default_dtype())

    def __getitem__(self, i:int) -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        return (self._to_tensor(self.x[i]),
                None if self.y is None else self._to_tensor(self.y[i]),
                None if self.w is None else self._to_tensor(self.w[i]))

In [None]:
# export
class WeightedDataLoader(DataLoader):
    r'''PyTorch DataLoader with support for optional weights and targets.

    Samples are expected to be `(x, y, w)` tuples, e.g. as returned by `DataSet.__getitem__`, where `y` and/or `w`
    may be `None`. The loader always uses its own `collate_fn`, so that argument cannot be passed; every other
    keyword argument is forwarded unchanged to `torch.utils.data.DataLoader`.
    '''
    @delegates(DataLoader, but=['collate_fn'])
    def __init__(self, dataset, **kwargs): super().__init__(dataset, collate_fn=self.collate_fn, **kwargs)

    @staticmethod
    def collate_fn(b:List[Tuple[Tensor,Optional[Tensor],Optional[Tensor]]]) \
            -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        r'''Stack a list of `(x, y, w)` samples into a single batched `(x, y, w)` tuple.

        Arguments:
            b: list of per-sample `(x, y, w)` tuples making up one batch

        Returns:
            `(x, y, w)`, where each entry is the `torch.stack` of the corresponding sample tensors along a new
            first dimension, or `None` when the first sample's entry is `None`
        '''
        x,y,w = zip(*b)
        # Only the first sample is inspected: all samples are assumed to agree on whether
        # targets/weights are present (true for `DataSet`, which decides this per dataset, not per item)
        return (torch.stack(x),
                torch.stack(y) if y[0] is not None else None,
                torch.stack(w) if w[0] is not None else None)

In [None]:
# export
class DataPair():
    r'''Pairs a training and a validation `WeightedDataLoader` so both can be handed to model training as one object'''
    def __init__(self, trn_dl:WeightedDataLoader, val_dl:WeightedDataLoader):
        store_attr()  # keeps the loaders as `self.trn_dl` and `self.val_dl`

    @property
    def trn_ds(self):
        r'''Dataset underlying the training loader'''
        loader = self.trn_dl
        return loader.dataset

    @property
    def val_ds(self):
        r'''Dataset underlying the validation loader'''
        loader = self.val_dl
        return loader.dataset

In [None]:
# Toy example: sample n signal events for each of a training and a validation set
n = 105
trn = PseudoData(paper_sig, 1).sample(n)
val = PseudoData(paper_sig, 1).sample(n)

In [None]:
trn  # the sampled arrays, later unpacked into `DataSet(*trn)`

(array([[ 0.2709233 ,  1.7907513 ,  1.674339  ],
        [ 0.01646089, -0.7460501 ,  0.6675601 ],
        [-0.23959269,  0.22303306,  0.8689174 ],
        [ 1.2815254 ,  0.7106047 ,  0.08947958],
        [-0.8260263 ,  0.22419322,  0.16161819],
        [ 0.73751366, -0.5809242 ,  0.7030808 ],
        [-0.6777991 ,  0.03431615,  1.2666872 ],
        [-1.0142534 , -1.6662276 ,  0.0513282 ],
        [ 0.51317775,  0.14159556,  0.12035868],
        [-1.1800369 ,  0.8506802 ,  0.97652584],
        [ 1.1993155 , -1.072847  ,  0.61093956],
        [ 0.25319907,  0.77248424,  0.45021617],
        [-0.46734175, -1.9731585 ,  0.01964188],
        [ 1.8196396 ,  1.4788404 ,  0.33221117],
        [-0.20271742, -0.11154626,  1.9630086 ],
        [-1.2942913 ,  0.4507109 ,  0.33215162],
        [-1.0124866 ,  0.11795827,  0.18895386],
        [ 0.67416126,  0.72892964,  0.14242482],
        [ 2.2467654 ,  2.0556061 ,  0.42249516],
        [ 0.00452289,  1.5189322 ,  1.3089397 ],
        [ 0.08481363

In [None]:
# Wrap the sampled arrays in indexable datasets
trn_ds = DataSet(*trn)
val_ds = DataSet(*val)

In [None]:
assert len(trn_ds) == n  # one item per sampled event

In [None]:
trn_ds[1]  # a single (x, y, w) item; w is None because no weights were supplied

(tensor([ 0.0165, -0.7461,  0.6676]), tensor([1.]), None)

In [None]:
# Training batches are shuffled and the final incomplete batch is dropped; validation order stays fixed
trn_dl = WeightedDataLoader(trn_ds, batch_size=10, shuffle=True, drop_last=True)
val_dl = WeightedDataLoader(val_ds, batch_size=10, shuffle=False)

In [None]:
next(iter(trn_dl))  # first collated batch: (x, y, w), with w passed through as None

(tensor([[-0.8530, -1.0867,  0.4375],
         [-0.4673, -1.9732,  0.0196],
         [-0.2549, -0.1746,  0.8408],
         [ 1.0284, -0.2849,  0.2366],
         [ 0.2660,  2.4576,  0.2343],
         [-0.4060,  1.0713,  0.0498],
         [ 0.2417,  0.2304,  0.8558],
         [ 0.1521,  1.3913,  1.0406],
         [-0.2991,  0.0788,  0.5765],
         [ 1.8196,  1.4788,  0.3322]]), tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]), None)

In [None]:
# One pass over the training loader, printing batch shapes rather than dumping every tensor.
# With `drop_last=True` the trailing partial batch is discarded, so only full batches are yielded.
for i, (xb,yb,wb) in enumerate(trn_dl): print(i, xb.shape, yb.shape, wb)
assert i+1 == len(trn_dl) == n//trn_dl.batch_size

0 tensor([[ 0.1521,  1.3913,  1.0406],
        [ 1.5703,  0.3769,  0.1019],
        [ 1.3524,  1.0247,  0.7360],
        [ 0.1877,  0.6605,  0.7572],
        [-1.1800,  0.8507,  0.9765],
        [ 1.1993, -1.0728,  0.6109],
        [-0.6141,  0.1932,  0.4817],
        [-2.0834, -1.7749,  0.1606],
        [-0.3424,  0.0516,  0.5307],
        [-0.4308, -1.4312,  0.3799]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
1 tensor([[-0.1978, -0.2431,  1.1055],
        [ 1.7234,  0.0033,  0.0039],
        [-1.1767, -0.0089,  0.3481],
        [ 0.3324,  0.8675,  0.1356],
        [-0.0617, -0.2325,  0.0073],
        [-0.0114,  1.7407,  0.8935],
        [-1.0622, -1.2542,  0.2110],
        [-1.1627,  0.8648,  0.2335],
        [ 0.2580,  0.8338,  0.0739],
        [-1.1326, -0.4091,  0.0282]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
 

In [None]:
data = DataPair(trn_dl, val_dl)  # bundle both loaders into a single object for training

In [None]:
data.trn_ds  # shortcut to the training loader's dataset

<__main__.DataSet at 0x7ff67d7418d0>

# Paper data

In [None]:
# export
def _sample_paper_data(n:int) -> Tuple[np.ndarray,np.ndarray]:
    r'''Sample `n` signal and `n` background events from the INFERNO-paper pseudodata and join them into
    `(x, y)` arrays, signal events first'''
    sig = PseudoData(paper_sig, 1).sample(n)
    bkg = PseudoData(paper_bkg, 0).sample(n)
    # Join along the event axis; identical to `np.vstack` for 2D arrays, but also keeps 1D arrays one-dimensional
    return np.concatenate((sig[0],bkg[0])),np.concatenate((sig[1],bkg[1]))

def get_paper_data(n:int, bs:int=2000, n_test:int=0) -> Union[DataPair,Tuple[DataPair,WeightedDataLoader]]:
    r'''Function returning training, validation and testing data according to pseudodata used in INFERNO paper.

    Every set holds equal numbers of signal (`PseudoData(paper_sig, 1)`) and background (`PseudoData(paper_bkg, 0)`)
    events.

    Arguments:
        n: total number of events in each of the training and validation sets; it is split equally between the
            two classes, so an odd value is rounded down to the even number below it
        bs: training batch size; the validation and testing loaders use batches of `2*bs`
        n_test: total number of testing events, split equally between the classes in the same way

    Returns:
        A `DataPair` of the training loader (shuffled, final incomplete batch dropped) and the validation loader
        (shuffled). If `n_test//2 > 0`, a `(DataPair, test loader)` tuple is returned instead; the test loader is
        not shuffled.
    '''
    n,n_test = n//2,n_test//2  # events per class
    # Training and validation events are sampled before any loader is built; testing events are sampled last
    trn = _sample_paper_data(n)
    val = _sample_paper_data(n)

    trn_dl = WeightedDataLoader(DataSet(*trn), batch_size=bs, shuffle=True, drop_last=True)
    val_dl = WeightedDataLoader(DataSet(*val), batch_size=2*bs, shuffle=True)
    data = DataPair(trn_dl, val_dl)
    if n_test <= 0: return data

    tst_dl = WeightedDataLoader(DataSet(*_sample_paper_data(n_test)), batch_size=2*bs)
    return data, tst_dl

In [None]:
# Without n_test, get_paper_data returns only the DataPair
n = 10
data = get_paper_data(n)

In [None]:
assert len(data.trn_ds) == len(data.val_ds) == n  # n//2 signal + n//2 background events per set

In [None]:
data, test = get_paper_data(n,n_test=2*n)  # with n_test >= 2 a (DataPair, test loader) tuple is returned

In [None]:
# The test set was requested with n_test=2*n events, i.e. twice the size of the training/validation sets
assert len(data.trn_ds) == len(data.val_ds) == n
assert len(test.dataset) == 2*n