In [1]:
from functools import partial
from torch.utils.data import DataLoader

from problems import ContinuousQuadratic, DiscreteQuadratic
from src.data import OfflineDataset, OnlineDataset, custom_collate_fn

In [2]:
# Re-import edited project modules (e.g. `problems`, `src.data`) automatically
# before each cell executes; mode 2 reloads all modules except those excluded
# via %aimport.
%load_ext autoreload
%autoreload 2

In [4]:
# Problem configuration. `n` and `d` are forwarded unchanged to the problem
# constructor; `n == 1` selects ContinuousQuadratic, any other value selects
# DiscreteQuadratic.
# NOTE(review): `d` appears to be the input dimension (x_min has shape (1,)
# in the outputs below with d = 1) — confirm against the problem classes.
n = 1
d = 1
N_PROBLEMS = 10  # one problem instance per seed in range(N_PROBLEMS)

problem_class = ContinuousQuadratic if n == 1 else DiscreteQuadratic
problems = [problem_class(d=d, n=n, seed=seed) for seed in range(N_PROBLEMS)]

### Offline Dataset

In [6]:
# Length of the `x` / `y` sequences stored in each offline sample. In the
# recorded run (d = 1) this yields `x` of shape (SEQ_LEN, 1) and `y` of
# shape (SEQ_LEN,).
SEQ_LEN = 3

dataset = OfflineDataset(problems, seq_len=SEQ_LEN)
sample = dataset[0]
sample

{'x': tensor([[0.5488],
         [0.7152],
         [0.6028]]),
 'y': tensor([0.4171, 0.4739, 0.4002]),
 'x_min': tensor([0.6028]),
 'problem': <problems.quadratic.ContinuousQuadratic at 0x7a78186b2660>}

In [7]:
# Shape and dtype of every tensor field, then the attached problem object.
for key in ("x", "y", "x_min"):
    print(key, sample[key].shape, sample[key].dtype)
print("problem", sample["problem"], type(sample["problem"]))

x torch.Size([3, 1]) torch.float32
y torch.Size([3]) torch.float32
x_min torch.Size([1]) torch.float32
problem <problems.quadratic.ContinuousQuadratic object at 0x7a78186b2660> <class 'problems.quadratic.ContinuousQuadratic'>


In [11]:
# custom_collate_fn takes the problem class as a keyword argument; bind it once.
collate_fn = partial(custom_collate_fn, problem_class=problem_class)
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=2,
)

# With the default (non-shuffled) sampler, the first batch is items 0 and 1.
# NOTE(review): in the recorded output both items share identical `x` rows
# while `y` and `x_min` differ — confirm OfflineDataset is meant to reuse the
# same query points for every problem (e.g. via a fixed sampling seed).
batch = next(iter(dataloader))
batch

{'x': tensor([[[0.5488],
          [0.7152],
          [0.6028]],
 
         [[0.5488],
          [0.7152],
          [0.6028]]]),
 'y': tensor([[0.4171, 0.4739, 0.4002],
         [0.9162, 1.9833, 1.2314]]),
 'x_min': tensor([[6.0276e-01],
         [1.1437e-04]]),
 'problem': [<problems.quadratic.ContinuousQuadratic at 0x7a78186b2660>,
  <problems.quadratic.ContinuousQuadratic at 0x7a77f4b17b30>]}

In [12]:
# Batched tensor fields gain a leading batch dimension; `problem` stays a list.
for key in ("x", "y", "x_min"):
    print(key, batch[key].shape, batch[key].dtype)
print("problem", batch["problem"], type(batch["problem"]))

x torch.Size([2, 3, 1]) torch.float32
y torch.Size([2, 3]) torch.float32
x_min torch.Size([2, 1]) torch.float32
problem [<problems.quadratic.ContinuousQuadratic object at 0x7a78186b2660>, <problems.quadratic.ContinuousQuadratic object at 0x7a77f4b17b30>] <class 'list'>


### Online Dataset

In [14]:
# Online variant: per the recorded output, each item holds only `x_min` and
# the `problem` instance, with no pre-sampled x/y sequence.
dataset = OnlineDataset(problems)

sample = dataset[0]
sample

{'x_min': tensor([0.6028]),
 'problem': <problems.quadratic.ContinuousQuadratic at 0x7a78186b2660>}

In [15]:
# An online item has a single tensor field; report it, then the problem object.
for key in ("x_min",):
    print(key, sample[key].shape, sample[key].dtype)
print("problem", sample["problem"], type(sample["problem"]))

x_min torch.Size([1]) torch.float32
problem <problems.quadratic.ContinuousQuadratic object at 0x7a78186b2660> <class 'problems.quadratic.ContinuousQuadratic'>


In [17]:
# Same collate setup as the offline section, now wrapping the online dataset.
collate_fn = partial(custom_collate_fn, problem_class=problem_class)
dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=2,
)

# First (non-shuffled) batch only, for inspection.
batch = next(iter(dataloader))
batch

{'x_min': tensor([[6.0276e-01],
         [1.1437e-04]]),
 'problem': [<problems.quadratic.ContinuousQuadratic at 0x7a78186b2660>,
  <problems.quadratic.ContinuousQuadratic at 0x7a77f4b17b30>]}

In [18]:
# Batched online item: one stacked tensor field plus a list of problems.
for key in ("x_min",):
    print(key, batch[key].shape, batch[key].dtype)
print("problem", batch["problem"], type(batch["problem"]))

x_min torch.Size([2, 1]) torch.float32
problem [<problems.quadratic.ContinuousQuadratic object at 0x7a78186b2660>, <problems.quadratic.ContinuousQuadratic object at 0x7a77f4b17b30>] <class 'list'>
