In [23]:
import torch
import lightning.pytorch as pl
import matplotlib.pyplot as plt
from hydra.utils import instantiate

from icicl.utils.experiment_utils import extract_config
from icicl.data.on_off_grid import OOTGBatch, SyntheticOOTGGenerator
from icicl.data.gp import RandomScaleGPGenerator

# Trying out OOTG (on/off-the-grid) batches

In [47]:
# Off-the-grid context/target points come from a GP generator whose lengthscale
# is presumably sampled between the two log10 bounds below; the outer generator
# adds on-grid context over `grid_range` (see the xc_on_grid / yc_on_grid fields).
gp_off_grid_generator = RandomScaleGPGenerator(
    dim=1,
    kernel_type="eq",
    # log10 bounds of -/+0.602 correspond to lengthscales of roughly 0.25 and 4.
    min_log10_lengthscale=-0.602,
    max_log10_lengthscale=0.602,
    noise_std=0.2,
    min_num_ctx=1,
    max_num_ctx=64,
    min_num_trg=128,
    max_num_trg=128,
    context_range=[[-2.0, 2.0]],
    target_range=[[-3.0, 3.0]],
    samples_per_epoch=16384,
    batch_size=16,
)
OOTGGen = SyntheticOOTGGenerator(
    off_grid_generator=gp_off_grid_generator,
    grid_range=[[-3.0, 3.0]],
    points_per_unit=4,
    # NOTE(review): differs from the inner generator's 16384 — confirm which is intended.
    samples_per_epoch=16000,
    batch_size=16,
)

In [51]:
# Print the shape of each tensor field of the first batch, then stop iterating.
for batch in OOTGGen:
    for field in ("xc_off_grid", "xc_on_grid", "yc_on_grid", "xt"):
        print(getattr(batch, field).shape)
    break

torch.Size([16, 5, 1])
torch.Size([16, 33, 1])
torch.Size([16, 33, 1])
torch.Size([16, 128, 1])


In [9]:
CONFIG_PATH = "experiments/configs/thesis/on_off_grid_gp_synthetic.yml"

config, config_dict = extract_config(CONFIG_PATH, [])

# Seed before instantiating the experiment from the config, then seed again
# afterwards — presumably so the RNG state after this cell does not depend on
# how many random draws instantiation consumed.
pl.seed_everything(config.misc.seed)
experiment = instantiate(config)
pl.seed_everything(experiment.misc.seed)

Global seed set to 0
Global seed set to 0


0

# Can a transformer replace the CNN (`conv_net`) inside a ConvCNP?

In [44]:
import torch
import torch.nn as nn
import math

from icicl.models.convcnp import ConvCNP, ConvCNPEncoder, ConvCNPDecoder
from icicl.networks.setconv import SetConvDecoder, SetConvEncoder
from icicl.networks.mlp import MLP
from icicl.likelihoods.gaussian import HeteroscedasticNormalLikelihood

In [4]:
class DummyFormer(nn.Module):
    """Identity stand-in for a ConvCNP ``conv_net`` that prints the shape it receives.

    Used to inspect the gridded tensor a transformer would be given if it were
    placed inside the ConvCNP encoder.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Print ``x.shape`` and return ``x`` unchanged.

        Args:
            x: Gridded encoder representation — presumably z_grid, i.e. the context
                mapped onto the grid by the set-conv encoder and then resized.
                Observed in this notebook as (batch, 2 * num_points + 1, num_channels),
                e.g. (16, 513, 16): num_channels comes from the encoder resizer, and
                the middle dimension depends on points_per_unit, the input range and
                the set-conv margin.

        Returns:
            The same tensor object, unmodified.
        """
        print(x.shape)
        return x

In [27]:
num_channels, points_per_unit = 16, 64
input_dim = 1

# Encoder resizer maps 2 features up to `num_channels`; decoder resizer maps back to 2.
encoder_resizer = MLP(in_dim=2, out_dim=num_channels, num_layers=2, width=num_channels)
decoder_resizer = MLP(in_dim=num_channels, out_dim=2, num_layers=2, width=num_channels)

setconv_enc = SetConvEncoder(
    dim=input_dim,
    points_per_unit=points_per_unit,
    init_lengthscale=2 * 1 / points_per_unit,  # two grid spacings (spacing = 1 / points_per_unit)
    margin=0.1,
    train_lengthscale=True,
)
setconv_dec = SetConvDecoder(
    dim=input_dim,
    init_lengthscale=0.1,
    scaling_factor=points_per_unit ** input_dim,
)

# The identity DummyFormer takes the place of the CNN so the grid tensor can be inspected.
dummy = DummyFormer()
convcnp_encoder = ConvCNPEncoder(
    conv_net=dummy,
    setconv_encoder=setconv_enc,
    resizer=encoder_resizer,
)
convcnp_decoder = ConvCNPDecoder(setconv_decoder=setconv_dec, resizer=decoder_resizer)

convCNP = ConvCNP(
    encoder=convcnp_encoder,
    decoder=convcnp_decoder,
    likelihood=HeteroscedasticNormalLikelihood(),
)

In [43]:
# Recompute the expected grid length by hand for one training batch and compare
# it with the shape DummyFormer prints from inside the ConvCNP forward pass.
# NOTE(review): only xt is used to determine the input range here — confirm
# whether the set-conv encoder also accounts for xc.
for batch in experiment.generators.train:
    print(batch.xc.shape, batch.xt.shape)
    xmin = torch.min(batch.xt, dim=-2)[0]  # (batch, dim)
    xmax = torch.max(batch.xt, dim=-2)[0]
    # Half-width of the range plus the 0.1 margin given to setconv_enc, in grid points.
    num_points = torch.ceil((0.5 * (xmax - xmin) + 0.1) * points_per_unit)
    num_points = torch.max(num_points, dim=0)[0]  # largest over the batch, e.g. 135
    # Round the computed value up to the next power of two (e.g. 135 -> 256).
    num_points = 2 ** torch.ceil(torch.log(num_points) / math.log(2.0))
    print(num_points * 2 + 1)  # points on both sides of zero plus the middle point
    convCNP(xc=batch.xc, yc=batch.yc, xt=batch.xt)
    break

torch.Size([16, 11, 1]) torch.Size([16, 128, 1])
tensor(513.)
torch.Size([16, 513, 16])


  torch.range(-num_points[i], num_points[i], dtype=xmin.dtype)


# ConvSet TNP Encoder Tests

In [21]:
from icicl.utils.conv import make_grid, flatten_grid
import torch
import math

# One [min, max] row per input dimension.
grid_range = torch.as_tensor([[-3.0, 3.0], [-4.0, 4.0]])
print(grid_range)

batch_shape = torch.Size((2,))

lower_bounds = grid_range[:, 0]
upper_bounds = grid_range[:, 1]
print(lower_bounds.unsqueeze(0))

# Tile the per-dimension bounds across the batch before building the grid.
grid = make_grid(
    xmin=lower_bounds.repeat(*batch_shape, 1),
    xmax=upper_bounds.repeat(*batch_shape, 1),
    points_per_unit=4,
    margin=0,
)

flatten_grid(grid).shape

tensor([[-3.,  3.],
        [-4.,  4.]])
tensor([[-3., -4.]])


torch.Size([2, 1089, 2])