In [1]:
from functools import partial
from torch.utils.data import DataLoader

import problems as p
from src.data import OfflineDataset, OnlineDataset, custom_collate_fn
from src.train import DPTSolver
from utils import load_config

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
config = load_config('config.yaml')
print(f'problem: {config["problem"]}({config["problem_params"]})')
print(f'seq_len: {config["model_params"]["seq_len"]}')

problem_class = getattr(p, config["problem"])
problems = [problem_class(**config["problem_params"], seed=seed) for seed in range(10)]

problem: ContinuousQuadratic({'d': 1, 'n': 1})
seq_len: 3


In [10]:
device = "cpu"
# device = "cuda"
model = DPTSolver(config).to(device)

### Offline Dataset

In [12]:
seq_len = 3
batch_size = 2

dataset = OfflineDataset(problems, seq_len=seq_len)
collate_fn = partial(custom_collate_fn, problem_class=problem_class)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
batch = next(iter(dataloader))
batch

{'x': tensor([[[0.5488],
          [0.7152],
          [0.6028]],
 
         [[0.5488],
          [0.7152],
          [0.6028]]]),
 'y': tensor([[0.4171, 0.4739, 0.4002],
         [0.9162, 1.9833, 1.2314]]),
 'x_min': tensor([[6.0276e-01],
         [1.1437e-04]]),
 'problem': [<problems.quadratic.ContinuousQuadratic at 0x7bb9e699abd0>,
  <problems.quadratic.ContinuousQuadratic at 0x7bb9e550f0e0>]}

In [13]:
print('x', batch["x"].shape, batch["x"].dtype)
print('y', batch["y"].shape, batch["y"].dtype)
print('x_min', batch["x_min"].shape, batch["x_min"].dtype)
print('problem', batch["problem"], type(batch["problem"]))

x torch.Size([2, 3, 1]) torch.float32
y torch.Size([2, 3]) torch.float32
x_min torch.Size([2, 1]) torch.float32
problem [<problems.quadratic.ContinuousQuadratic object at 0x7bb9e699abd0>, <problems.quadratic.ContinuousQuadratic object at 0x7bb9e550f0e0>] <class 'list'>


In [16]:
outputs = model._offline_step(batch)
outputs

{'outputs': tensor([[[ 0.2753],
          [ 0.2370],
          [ 0.2271],
          [ 0.2481]],
 
         [[ 0.2753],
          [ 0.1574],
          [-0.0752],
          [ 0.0806]]], grad_fn=<ViewBackward0>),
 'predictions': tensor([[[ 0.2753],
          [ 0.2370],
          [ 0.2271],
          [ 0.2481]],
 
         [[ 0.2753],
          [ 0.1574],
          [-0.0752],
          [ 0.0806]]], grad_fn=<ViewBackward0>),
 'targets': tensor([[6.0276e-01],
         [1.1437e-04]])}

In [17]:
print('outputs', outputs["outputs"].shape, outputs["outputs"].dtype)
print('predictions', outputs["predictions"].shape, outputs["predictions"].dtype)
print('targets', outputs["targets"].shape, outputs["targets"].dtype)

outputs torch.Size([2, 4, 1]) torch.float32
predictions torch.Size([2, 4, 1]) torch.float32
targets torch.Size([2, 1]) torch.float32


In [19]:
model.training_step(batch, 0)

{'loss': tensor(0.0661, grad_fn=<MseLossBackward0>),
 'accuracy': tensor(0.),
 'mae': tensor(0.2175, grad_fn=<MeanBackward0>)}

In [20]:
model.eval()
model.validation_step(batch, 0)

{'loss': tensor(0.0661, grad_fn=<MseLossBackward0>),
 'accuracy': tensor(0.),
 'mae': tensor(0.2175, grad_fn=<MeanBackward0>)}

### Online Dataset

In [22]:
dataset = OnlineDataset(problems)
collate_fn = partial(custom_collate_fn, problem_class=problem_class)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
batch = next(iter(dataloader))
batch

{'x_min': tensor([[6.0276e-01],
         [1.1437e-04]]),
 'problem': [<problems.quadratic.ContinuousQuadratic at 0x7bb9e699abd0>,
  <problems.quadratic.ContinuousQuadratic at 0x7bb9e550f0e0>]}

In [23]:
print('x_min', batch["x_min"].shape, batch["x_min"].dtype)
print('problem', batch["problem"], type(batch["problem"]))

x_min torch.Size([2, 1]) torch.float32
problem [<problems.quadratic.ContinuousQuadratic object at 0x7bb9e699abd0>, <problems.quadratic.ContinuousQuadratic object at 0x7bb9e550f0e0>] <class 'list'>


In [26]:
outputs = model._online_step(batch)
outputs

{'outputs': tensor([[[0.2753],
          [0.1444],
          [0.0495]],
 
         [[0.2753],
          [0.3290],
          [0.3033]]], grad_fn=<StackBackward0>),
 'predictions': tensor([[[0.2753],
          [0.1444],
          [0.0495]],
 
         [[0.2753],
          [0.3290],
          [0.3033]]], grad_fn=<StackBackward0>),
 'targets': tensor([[6.0276e-01],
         [1.1437e-04]])}

In [27]:
print('outputs', outputs["outputs"].shape, outputs["outputs"].dtype)
print('predictions', outputs["predictions"].shape, outputs["predictions"].dtype)
print('targets', outputs["targets"].shape, outputs["targets"].dtype)

outputs torch.Size([2, 3, 1]) torch.float32
predictions torch.Size([2, 3, 1]) torch.float32
targets torch.Size([2, 1]) torch.float32


In [30]:
model.eval()
model.validation_step(batch, 0)

{'loss': tensor(0.1990, grad_fn=<MseLossBackward0>),
 'accuracy': tensor(0.),
 'mae': tensor(0.4282, grad_fn=<MeanBackward0>)}