In [None]:
# default_exp data

In [None]:
%matplotlib inline
# reload edited library modules automatically before executing each cell
%reload_ext autoreload
%autoreload 2

# Data

> API details: datasets, weighted data loaders, and generation of the paper's benchmark data.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# export
from torch.utils.data import DataLoader
import torch
from torch import Tensor

from typing import Tuple, Union, Optional
from fastcore.all import store_attr, delegates
import numpy as np

In [None]:
# export
class DataSet():
    r'''Minimal map-style dataset over in-memory arrays of inputs, and optional targets and weights.

    Each item is an `(x, y, w)` tuple of float32 tensors, where `y` and/or `w` are `None` when not supplied.

    Arguments:
        x: array of inputs, indexed along its first axis
        y: optional array of targets, indexed in parallel with `x`
        w: optional array of weights, indexed in parallel with `x`
    '''
    def __init__(self, x:np.ndarray, y:Optional[np.ndarray]=None, w:Optional[np.ndarray]=None): store_attr()
    def __len__(self) -> int: return len(self.x)

    @staticmethod
    def _to_tensor(a) -> Tensor:
        r'''Convert one indexed element (array or scalar) to a float32 tensor'''
        # The legacy `torch.Tensor(...)` constructor cannot convert 0-d items: a numpy float scalar is rejected and an
        # integer can be read as a *size* (uninitialised memory), so 1-D target/weight arrays broke indexing.
        # `torch.tensor` converts any array-like, including scalars (-> 0-d tensor), and always copies.
        return torch.tensor(np.asarray(a), dtype=torch.float32)

    def __getitem__(self, i:int) -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        return (self._to_tensor(self.x[i]),
                None if self.y is None else self._to_tensor(self.y[i]),
                None if self.w is None else self._to_tensor(self.w[i]))

In [None]:
# export
class WeightedDataLoader(DataLoader):
    r'''`DataLoader` yielding `(x, y, w)` batch tuples in which `y` and/or `w` may be `None`.

    PyTorch's default collation does not handle `None` entries, so a custom `collate_fn` is always installed;
    it is therefore excluded from the keyword arguments that can be passed through to `DataLoader`.
    '''
    @delegates(DataLoader, but=['collate_fn'])
    def __init__(self, dataset, **kwargs):
        super().__init__(dataset, collate_fn=self.collate_fn, **kwargs)

    @staticmethod
    def collate_fn(b:Tuple[Tensor,Optional[Tensor],Optional[Tensor]]) \
            -> Tuple[Tensor,Optional[Tensor],Optional[Tensor]]:
        r'''Stack a sequence of `(x, y, w)` samples into one batch tuple.

        The batch `y` (resp. `w`) is `None` when the first sample's `y` (resp. `w`) is `None`.
        '''
        def _stack_or_none(items): return None if items[0] is None else torch.stack(items)
        xs, ys, ws = zip(*b)
        return torch.stack(xs), _stack_or_none(ys), _stack_or_none(ws)

In [None]:
# export
class DataPair():
    r'''Container pairing a training loader with a validation loader.

    Arguments:
        trn_dl: loader for the training data
        val_dl: loader for the validation data
    '''
    def __init__(self, trn_dl:WeightedDataLoader, val_dl:WeightedDataLoader): store_attr()

    @property
    def trn_ds(self):
        r'''Dataset wrapped by the training loader'''
        loader = self.trn_dl
        return loader.dataset

    @property
    def val_ds(self):
        r'''Dataset wrapped by the validation loader'''
        loader = self.val_dl
        return loader.dataset

In [None]:
from pytorch_inferno.pseudodata import paper_sig, PseudoData

In [None]:
# small demo: independent training and validation samples drawn from the signal distribution (label 1)
n = 105
trn = PseudoData(paper_sig, 1).sample(n)
val = PseudoData(paper_sig, 1).sample(n)

In [None]:
# raw sample: a tuple of numpy arrays, inputs first, then targets
trn

(array([[-5.36622405e-01, -8.00279737e-01,  1.66521740e+00],
        [ 1.29522771e-01,  1.03510880e+00,  7.73968887e+00],
        [-1.26598942e+00,  1.11336589e+00,  4.36785030e+00],
        [-1.22300100e+00, -2.07645327e-01,  1.63283777e+00],
        [ 2.83285409e-01, -8.44387412e-01,  3.75807852e-01],
        [-1.29649282e+00,  1.36262372e-01,  1.15353024e+00],
        [-1.01295817e+00,  1.75543439e+00,  3.43467951e+00],
        [-4.14334744e-01, -3.74529064e-01,  1.38090372e+00],
        [-1.82133555e-01,  8.47890317e-01,  2.39968944e+00],
        [ 1.26101637e+00, -1.55749714e+00,  1.47763407e+00],
        [-2.85741389e-02, -1.84726965e+00,  9.59398091e-01],
        [-3.84712428e-01, -5.29842675e-01,  7.26075649e-01],
        [-1.75030923e+00, -9.09457028e-01,  1.23770487e+00],
        [-1.29760787e-01,  9.80832338e-01,  7.43583024e-01],
        [-1.37364542e+00, -3.53814304e-01,  1.86923885e+00],
        [-1.04974858e-01, -1.57842517e+00,  2.42790341e+00],
        [-6.45336747e-01

In [None]:
# wrap the raw array tuples as datasets
trn_ds,val_ds = DataSet(*trn),DataSet(*val)

In [None]:
# both datasets should hold every sampled point
assert len(trn_ds) == len(val_ds) == n

In [None]:
# each item is an (x, y, w) tuple of tensors; w is None because no weights were supplied
trn_ds[1]

(tensor([0.1295, 1.0351, 7.7397]), tensor([1.]), None)

In [None]:
# training batches are shuffled and the final incomplete batch is dropped;
# validation batches keep dataset order and include every sample
trn_dl = WeightedDataLoader(trn_ds, batch_size=10, shuffle=True, drop_last=True)
val_dl = WeightedDataLoader(val_ds, batch_size=10, shuffle=False)

In [None]:
# peek at the first (shuffled) training batch: (inputs, targets, weights=None)
next(iter(trn_dl))

(tensor([[ 0.1295,  1.0351,  7.7397],
         [-1.2230, -0.2076,  1.6328],
         [-0.1298,  0.9808,  0.7436],
         [ 0.3323, -0.2103,  1.2120],
         [-0.3847, -0.5298,  0.7261],
         [-1.0825, -0.3173,  0.0813],
         [-0.2026,  1.3114,  0.1936],
         [-0.6453, -0.0725,  0.6041],
         [ 1.0653,  0.2381,  1.2565],
         [ 2.1612, -0.8341,  0.9830]]), tensor([[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]]), None)

In [None]:
# drop_last=True discards the final incomplete batch, so every training batch holds exactly batch_size samples;
# print shapes rather than dumping every tensor
for i, (xb,yb,wb) in enumerate(trn_dl):
    assert len(xb) == len(yb) == trn_dl.batch_size
    print(i, tuple(xb.shape), tuple(yb.shape), wb)
assert len(trn_dl) == n//trn_dl.batch_size

0 tensor([[ 2.8479e+00,  6.7420e-01,  4.0809e-01],
        [ 2.1612e+00, -8.3412e-01,  9.8298e-01],
        [ 8.0340e-01, -9.9053e-01,  1.1920e-02],
        [ 1.9983e-01, -1.0094e+00,  2.7794e+00],
        [-6.5076e-02, -1.6369e-02,  1.2823e+01],
        [-5.5348e-01,  3.9098e-01,  5.8651e-01],
        [-2.0261e-01,  1.3114e+00,  1.9355e-01],
        [ 1.0938e+00,  7.3156e-01,  8.0546e-01],
        [ 3.2827e-01, -7.0866e-01,  2.2664e+00],
        [-1.0060e+00,  1.7753e+00,  7.0853e-01]]) tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]])
1 tensor([[-0.8116, -0.0217,  1.9153],
        [-0.2339,  0.3706,  2.1346],
        [-1.0825, -0.3173,  0.0813],
        [-0.7710,  0.8428,  1.4172],
        [ 1.2610, -1.5575,  1.4776],
        [ 1.4453, -0.8730,  0.4820],
        [-1.3896, -0.4150,  0.4603],
        [-0.0696, -0.1305,  4.5888],
        [-1.7009,  0.6936,  1.7088],
        [-0.1141,  1.5137,  1.0

In [None]:
# bundle the training and validation loaders together
data = DataPair(trn_dl, val_dl)

In [None]:
# the DataPair exposes the dataset behind its training loader
data.trn_ds

<__main__.DataSet at 0x7f8438309358>

# Paper data

In [None]:
# export
from pytorch_inferno.pseudodata import *  # noqa F304

In [None]:
# export
def _sample_sig_bkg(bkg_fn, n:int) -> DataSet:
    r'''Sample `n` signal (label 1) and `n` background (label 0) points and stack them, signal first, into a `DataSet`'''
    sig = PseudoData(paper_sig, 1).sample(n)
    bkg = PseudoData(bkg_fn, 0).sample(n)
    return DataSet(np.vstack((sig[0],bkg[0])), np.vstack((sig[1],bkg[1])))

def get_paper_data(n:int, bm:int, bs:int=2000, n_test:int=0) -> Union[DataPair,Tuple[DataPair,WeightedDataLoader]]:
    r'''Build training and validation (and optionally testing) loaders for the paper's signal-vs-background problem.

    Each set holds equal numbers of signal points (from `paper_sig`, label 1) and background points
    (from the selected background model, label 0); odd sizes are rounded down to an even number.

    Arguments:
        n: number of training points, and also of validation points
        bm: index 0-4 of the background model `paper_bkg_bm0`-`paper_bkg_bm4`;
            any non-integer object is passed to `PseudoData` unchanged
        bs: training batch size; the validation and testing loaders use `2*bs`
        n_test: number of testing points; if `n_test < 2` no testing loader is built

    Returns:
        `DataPair` of training and validation loaders, or `(DataPair, testing WeightedDataLoader)` when testing data is built

    Raises:
        ValueError: if `bm` is an integer outside 0-4
    '''
    bkgs = {0:paper_bkg_bm0, 1:paper_bkg_bm1, 2:paper_bkg_bm2, 3:paper_bkg_bm3, 4:paper_bkg_bm4}
    if bm in bkgs: bkg_fn = bkgs[bm]
    elif isinstance(bm, (int, np.integer)):
        # previously an unknown index silently fell through to PseudoData(int, 0)
        raise ValueError(f'Background model {bm} not recognised; expected one of {list(bkgs)}')
    else: bkg_fn = bm

    n,n_test = n//2,n_test//2  # half signal, half background
    # sampling order (trn sig, trn bkg, val sig, val bkg, then test) is preserved from the original implementation
    trn_dl = WeightedDataLoader(_sample_sig_bkg(bkg_fn, n), batch_size=bs, shuffle=True, drop_last=True)
    val_dl = WeightedDataLoader(_sample_sig_bkg(bkg_fn, n), batch_size=2*bs)
    data = DataPair(trn_dl, val_dl)
    if n_test <= 0: return data

    tst_dl = WeightedDataLoader(_sample_sig_bkg(bkg_fn, n_test), batch_size=2*bs)
    return data, tst_dl

In [None]:
n = 10
data = get_paper_data(n,0)  # background model 0, no testing set

In [None]:
# n//2 signal + n//2 background points in each set
assert len(data.trn_ds) == len(data.val_ds) == n

In [None]:
# with n_test > 0 a testing loader is returned alongside the DataPair
data, test = get_paper_data(n,0,n_test=2*n)

In [None]:
# integer comparisons against n rather than a float ratio and a literal
assert len(data.trn_ds) == len(data.val_ds) == n
assert len(test.dataset) == 2*n