
Non-stationary GP GPyTorch support #2458

Open
sanaamouzahir opened this issue Dec 13, 2023 · 1 comment

sanaamouzahir commented Dec 13, 2023

🐛 Bug

To reproduce

**Code snippet to reproduce**

import torch
import gpytorch
from gpytorch.constraints import Interval

class Temperal_Kernel(gpytorch.kernels.Kernel):
    def __init__(self, G_prior=None, G_constraint=None, **kwargs):
        super().__init__(**kwargs)
        
        # Register the raw parameter for G
        self.register_parameter(name='raw_G', parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1)))
        
        # Set the constraint for G to be in the interval (0, 1]
        if G_constraint is None:
            G_constraint = Interval(0, 1)  # Ensuring 0 < G <= 1
        self.register_constraint("raw_G", G_constraint)

        # Register the prior for G if provided
        if G_prior is not None:
            self.register_prior("G_prior", G_prior, lambda m: m.G, lambda m, v: m._set_G(v))

    @property
    def G(self):
        # Apply the constraint transform when accessing G
        return self.raw_G_constraint.transform(self.raw_G)

    @G.setter
    def G(self, value):
        self._set_G(value)

    def _set_G(self, value):
        if not torch.is_tensor(value):
            value = torch.as_tensor(value).to(self.raw_G)
        self.initialize(raw_G=self.raw_G_constraint.inverse_transform(value))

    def forward(self, t1, t2):
        # Compute the absolute difference between t1 and t2
        diff = torch.abs(t1 - t2)
        # Scale the difference by G
        return self.G * diff

class MyGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MyGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=num_tasks  # num_tasks is not defined in this snippet (presumably set elsewhere in the notebook)
        )
        self.base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.temporal_covar_module = Temperal_Kernel()

    def forward(self, x, t):
        mean_x = self.mean_module(x)
        covar_x = self.base_covar_module(x)
        covar_t = self.temporal_covar_module(t.unsqueeze(-1), t.unsqueeze(-1))
        covar = covar_x * covar_t
        return gpytorch.distributions.MultivariateNormal(mean_x, covar)

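For context, the train() step that triggers the error can be roughly reconstructed from the stack trace below; only the lines visible in the trace are taken from the report, while the optimizer/mll setup and the spatial_data extraction are assumptions added so the sketch is self-contained:

# Hypothetical reconstruction of the failing training step (not from the report).
model.train()
likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)           # assumed setup
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)  # assumed setup

optimizer.zero_grad()
spatial_data = train_x[:, :, 0]    # assumed: spatial column of each sample
temporal_data = train_x[:, 0, 1]   # from the trace: one time value per sample
output = model(spatial_data, temporal_data)   # raises the ValueError shown below
loss = -mll(output, train_y)
loss.backward()
optimizer.step()
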
**Stack trace/error message**

0%|          | 0/2000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File <timed eval>:1

Cell In[68], line 29, in train()
     26 temporal_data = train_x[:, 0, 1]  # Extracting one time value per sample and reshaping
     28 # Pass both spatial and temporal data to the model
---> 29 output = model(spatial_data, temporal_data)
     31 loss = -mll(output, train_y)
     32 loss.backward()

File /software/Anaconda3/envs/MLGPyTorchJup_env/lib/python3.8/site-packages/gpytorch/models/exact_gp.py:265, in ExactGP.__call__(self, *args, **kwargs)
    259     raise RuntimeError(
    260         "train_inputs, train_targets cannot be None in training mode. "
    261         "Call .eval() for prior predictions, or call .set_train_data() to add training data."
    262     )
    263 if settings.debug.on():
    264     if not all(
--> 265         torch.equal(train_input, input) for train_input, input in length_safe_zip(train_inputs, inputs)
    266     ):
    267         raise RuntimeError("You must train on the training inputs!")
    268 res = super().__call__(*inputs, **kwargs)

File /software/Anaconda3/envs/MLGPyTorchJup_env/lib/python3.8/site-packages/gpytorch/utils/generic.py:12, in length_safe_zip(*args)
     10 args = [a if hasattr(a, "__len__") else list(a) for a in args]
     11 if len({len(a) for a in args}) > 1:
---> 12     raise ValueError(
     13         "Expected the lengths of all arguments to be equal. Got lengths "
     14         f"{[len(a) for a in args]} for args {args}. Did you pass in "
     15         "fewer inputs than expected?"
     16     )
     17 return zip(*args)

ValueError: Expected the lengths of all arguments to be equal. Got lengths [1, 2] for args [[tensor([[[ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00]],

        [[ 0.0000e+00,  6.6667e-03],
         [ 2.6537e-02,  6.6667e-03],
         [ 5.5855e-03,  6.6667e-03],
         ...,
         [ 2.4196e-01,  6.6667e-03],
         [ 3.4021e-01,  6.6667e-03],
         [ 3.6739e-14,  6.6667e-03]],

        [[ 0.0000e+00,  1.3333e-02],
         [ 2.6280e-02,  1.3333e-02],
         [ 5.3350e-03,  1.3333e-02],
         ...,
         [ 2.4148e-01,  1.3333e-02],
         [ 3.3784e-01,  1.3333e-02],
         [ 0.0000e+00,  1.3333e-02]],

        ...,

        [[ 0.0000e+00,  7.8000e-01],
         [ 4.8140e-03,  7.8000e-01],
         [-1.4121e-02,  7.8000e-01],
         ...,
         [ 1.7260e-01,  7.8000e-01],
         [ 1.6853e-01,  7.8000e-01],
         [ 0.0000e+00,  7.8000e-01]],

        [[ 0.0000e+00,  7.8667e-01],
         [ 4.6857e-03,  7.8667e-01],
         [-1.4223e-02,  7.8667e-01],
         ...,
         [ 1.7202e-01,  7.8667e-01],
         [ 1.6764e-01,  7.8667e-01],
         [ 0.0000e+00,  7.8667e-01]],

        [[ 0.0000e+00,  7.9333e-01],
         [ 4.5583e-03,  7.9333e-01],
         [-1.4324e-02,  7.9333e-01],
         ...,
         [ 1.7144e-01,  7.9333e-01],
         [ 1.6676e-01,  7.9333e-01],
         [ 0.0000e+00,  7.9333e-01]]])], [tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 0.0000e+00,  2.6537e-02,  5.5855e-03,  ...,  2.4196e-01,
          3.4021e-01,  3.6739e-14],
        [ 0.0000e+00,  2.6280e-02,  5.3350e-03,  ...,  2.4148e-01,
          3.3784e-01,  0.0000e+00],
        ...,
        [ 0.0000e+00,  4.8140e-03, -1.4121e-02,  ...,  1.7260e-01,
          1.6853e-01,  0.0000e+00],
        [ 0.0000e+00,  4.6857e-03, -1.4223e-02,  ...,  1.7202e-01,
          1.6764e-01,  0.0000e+00],
        [ 0.0000e+00,  4.5583e-03, -1.4324e-02,  ...,  1.7144e-01,
          1.6676e-01,  0.0000e+00]]), tensor([[0.0000],
        [0.0067],
        [0.0133],
        [0.0200],
        [0.0267],
        [0.0333],
        [0.0400],
        [0.0467],
        [0.0533],
        [0.0600],
        [0.0667],
        [0.0733],
        [0.0800],
        [0.0867],
        [0.0933],
        [0.1000],
        [0.1067],
        [0.1133],
        [0.1200],
        [0.1267],
        [0.1333],
        [0.1400],
        [0.1467],
        [0.1533],
        [0.1600],
        [0.1667],
        [0.1733],
        [0.1800],
        [0.1867],
        [0.1933],
        [0.2000],
        [0.2067],
        [0.2133],
        [0.2200],
        [0.2267],
        [0.2333],
        [0.2400],
        [0.2467],
        [0.2533],
        [0.2600],
        [0.2667],
        [0.2733],
        [0.2800],
        [0.2867],
        [0.2933],
        [0.3000],
        [0.3067],
        [0.3133],
        [0.3200],
        [0.3267],
        [0.3333],
        [0.3400],
        [0.3467],
        [0.3533],
        [0.3600],
        [0.3667],
        [0.3733],
        [0.3800],
        [0.3867],
        [0.3933],
        [0.4000],
        [0.4067],
        [0.4133],
        [0.4200],
        [0.4267],
        [0.4333],
        [0.4400],
        [0.4467],
        [0.4533],
        [0.4600],
        [0.4667],
        [0.4733],
        [0.4800],
        [0.4867],
        [0.4933],
        [0.5000],
        [0.5067],
        [0.5133],
        [0.5200],
        [0.5267],
        [0.5333],
        [0.5400],
        [0.5467],
        [0.5533],
        [0.5600],
        [0.5667],
        [0.5733],
        [0.5800],
        [0.5867],
        [0.5933],
        [0.6000],
        [0.6067],
        [0.6133],
        [0.6200],
        [0.6267],
        [0.6333],
        [0.6400],
        [0.6467],
        [0.6533],
        [0.6600],
        [0.6667],
        [0.6733],
        [0.6800],
        [0.6867],
        [0.6933],
        [0.7000],
        [0.7067],
        [0.7133],
        [0.7200],
        [0.7267],
        [0.7333],
        [0.7400],
        [0.7467],
        [0.7533],
        [0.7600],
        [0.7667],
        [0.7733],
        [0.7800],
        [0.7867],
        [0.7933]])]]. Did you pass in fewer inputs than expected?


Expected Behavior

System information

Please complete the following information:

  • GPyTorch version: 1.11
  • PyTorch version: 1.13.1
  • Computer OS: Linux 3.10.0-1160.62.1.el7.x86_64

Additional context

I am trying to perform GP regression for a non-stationary problem, where the function I am sampling from changes with time. To do this, I have written a custom kernel that is multiplied by the base (stationary) kernel.
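
For reference, here is a minimal sketch of an alternative way to express such a spatio-temporal product kernel over a single combined input; the column layout, kernel choices, and variable names are illustrative assumptions rather than the code above:

import gpytorch

# Assume each training row is [spatial_feature, time], i.e. one 2-column input tensor.
spatial_kernel = gpytorch.kernels.ScaleKernel(
    gpytorch.kernels.RBFKernel(active_dims=(0,))   # acts only on the spatial column
)
temporal_kernel = gpytorch.kernels.RBFKernel(active_dims=(1,))  # acts only on the time column

# Multiplying two GPyTorch kernels yields a ProductKernel evaluated on the same input,
# so the model's forward takes a single argument and ExactGP's train-input check passes.
covar_module = spatial_kernel * temporal_kernel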


naefjo commented Jan 14, 2024

It seems to me that your code fails during training because your inputs are not of equal length. That is, you call model(x, t), and internally GPyTorch checks the call arguments (two of them: one for x and one for t) against your stored train_x (which, I assume, only contains your training locations in x and not t). I'm not 100% sure whether this will fix it, but have you tried passing a list containing train_x and train_t to the GP constructor? I.e.

MyGPModel([train_x, train_t], train_y, your_likelihood)
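
Expanding on that, a rough (untested) sketch of how the two-input setup could be wired so that the call arguments match what ExactGP stored as train_inputs; the likelihood and train_t names are assumptions:

likelihood = gpytorch.likelihoods.GaussianLikelihood()

# Construct the GP with both inputs so ExactGP registers two train_inputs...
model = MyGPModel([train_x, train_t], train_y, likelihood)

# ...then call the model with the same two inputs during training,
# so the lengths compared in length_safe_zip match (2 vs. 2).
model.train()
likelihood.train()
output = model(train_x, train_t)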
