In [67]:
import torch
from torch.utils.data import DataLoader
import sys
sys.path.append('../')
import matplotlib.pyplot as plt

In [68]:
from src.datasets import UEADataset
from exp.format import dict_collate_fn, to_gpytorch_format, get_input_transform, get_collate_fn

In [69]:
from src.datasets.subsampling import MissingAtRandomSubsampler, LabelBasedSubsampler

In [70]:
from src.imputation import ImputationStrategy

In [71]:
data_format='GP'

In [72]:
def get_subsampler(subsampler_name, subsampler_parameters):
    import src.datasets.subsampling

    subsampling_cls = getattr(src.datasets.subsampling, subsampler_name)
    instance = subsampling_cls(**subsampler_parameters)
    return instance

def get_imputation_scheme(imputation_scheme):
    from src.imputation import ImputationStrategy

    instance = ImputationStrategy(imputation_scheme)
    return instance

In [89]:
transforms = [
        get_subsampler('MissingAtRandomSubsampler', subsampler_parameters={'probability':0.1}),
        get_imputation_scheme(data_format),
        get_input_transform(data_format, grid_spacing=1.0)
    ]

In [56]:
loss_fn = torch.nn.CrossEntropyLoss()

In [61]:
'CrossEntropyLoss' in str(loss_fn)


True

tensor([[0],
        [9],
        [5],
        [6],
        [3],
        [6],
        [7],
        [1],
        [7],
        [7],
        [5],
        [5],
        [6],
        [1],
        [8],
        [0],
        [5],
        [9],
        [7],
        [7],
        [6],
        [2],
        [7],
        [7],
        [1],
        [6],
        [0],
        [2],
        [0],
        [0],
        [5],
        [6]])

In [90]:
dataset = UEADataset('PenDigits', 'training', transform=transforms)

Loading stratified training/validation split.


In [75]:
dataset[0].keys()

dict_keys(['values', 'label', 'inputs', 'indices', 'n_tasks', 'test_inputs', 'test_indices', 'data_format'])

In [91]:
dataset[0]['values']

array([-1.0006285 ,  1.0124474 ,  0.47434118,  1.3820243 ,  0.8418734 ,
       -0.0566479 ,  0.15957755,  0.7988345 , -0.15314138,  1.4773206 ,
       -0.8638662 ,  0.29734483, -1.460875  , -1.4726188 , -1.4040171 ],
      dtype=float32)

In [25]:
dataset[0]['label']

array([3.], dtype=float32)

In [92]:
n_input_dims = dataset.measurement_dims
collate_fn = get_collate_fn(data_format, n_input_dims)

In [93]:
train_loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn, shuffle=True,
                                                pin_memory=True, num_workers=1)

In [78]:
batch = next(iter(train_loader))


In [80]:
batch.keys()

dict_keys(['values', 'label', 'inputs', 'indices', 'test_inputs', 'test_indices', 'valid_lengths'])

In [97]:
batch['values'].shape

torch.Size([11, 15])

In [94]:
maximum = 0 
for batch in train_loader:
    curr = batch['valid_lengths'].max()
    if curr > maximum:
        maximum = curr
        print(maximum)

tensor(15)


In [29]:
fn = to_gpytorch_format

In [32]:
'gpytorch' in fn.__name__

True

In [41]:
MAR = MissingAtRandomSubsampler()

In [45]:
str(fn)

'<function to_gpytorch_format at 0x10dcd9b00>'