Defining custom kernel is causing errors #69

Closed
AHsu98 opened this issue Feb 17, 2024 · 1 comment

AHsu98 commented Feb 17, 2024

I'm following the example notebook for defining a custom kernel, but when I run the code, I get an error when it checks for requires_grad, as X2 and v seem to have turned into None objects. I copied the code from here pretty much exactly.
Here is the code that causes the error, for reproducibility.

class BasicLinearKernel(Kernel):
    def __init__(self, lengthscale, options):
        # The base class takes as inputs a name for the kernel, and
        # an instance of `FalkonOptions`.
        super().__init__("basic_linear", options)
        self.lengthscale = lengthscale

    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool) -> torch.Tensor:
        # To support different devices/data types, you must make sure
        # the lengthscale is compatible with the data.
        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)

        scaled_X1 = X1 * lengthscale

        if diag:
            out.copy_(torch.sum(scaled_X1 * X2, dim=-1))
        else:
            # The dot-product row-by-row on `X1` and `X2` can be computed
            # on many rows at a time with matrix multiplication.
            out = torch.matmul(scaled_X1, X2.T, out=out)

        return out

    def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:
        raise NotImplementedError("Sparse not implemented")

k.mmv(torch.randn(4, 3).contiguous(), torch.randn(4, 3), v=torch.randn(4, 1))
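
A minimal setup that makes the call above runnable might look like the following; the import paths and the lengthscale value are placeholders, since the snippet does not show how k was actually constructed:

    import torch
    from falkon import FalkonOptions
    from falkon.kernels import Kernel  # assumed import path for the Kernel base class

    # Placeholder construction of k; the actual lengthscale and options are not
    # shown in the snippet above, and their values should not matter for the error.
    k = BasicLinearKernel(lengthscale=torch.tensor(1.0), options=FalkonOptions())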

Here is my error traceback, from where the problem seems to be:

File ~/miniconda3/envs/onemod/lib/python3.11/site-packages/falkon/mmv_ops/fmmv.py:973, in fmmv(X1, X2, v, kernel, out, opt, kwargs_m1, kwargs_m2)
971 return KernelMmvFnFull.apply(kernel, opt, kwargs_m1, kwargs_m2, out, X1, X2, v, *kernel.diff_params.values())
972 else:
--> 973 return KernelMmvFnFull.apply(kernel, opt, out, X1, X2, v, kwargs_m1, kwargs_m2)

File ~/miniconda3/envs/onemod/lib/python3.11/site-packages/torch/autograd/function.py:553, in Function.apply(cls, *args, **kwargs)
550 if not torch._C._are_functorch_transforms_active():
551 # See NOTE: [functorch vjp and autograd interaction]
552 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 553 return super().apply(*args, **kwargs) # type: ignore[misc]
555 if not is_setup_ctx_defined:
556 raise RuntimeError(
557 "In order to use an autograd.Function with functorch transforms "
558 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
559 "staticmethod. For more details, please see "
560 "https://pytorch.org/docs/master/notes/extending.func.html"
561 )

File ~/miniconda3/envs/onemod/lib/python3.11/site-packages/falkon/mmv_ops/fmmv.py:904, in KernelMmvFnFull.forward(ctx, kernel, opt, kwargs_m1, kwargs_m2, out, X1, X2, v, *kernel_params)
902 else:
903 _check_contiguity((X1, "X1"), (X2, "X2"), (v, "v"), (out, "out"))
--> 904 differentiable = any(t.requires_grad for t in [X1, X2, v] + [*kernel_params])
905 data_devs = (X1.device, X2.device, v.device)
906 comp_dev_type = "cpu" if opt.use_cpu or not torch.cuda.is_available() else "cuda"

File ~/miniconda3/envs/onemod/lib/python3.11/site-packages/falkon/mmv_ops/fmmv.py:904, in <genexpr>(.0)
902 else:
903 _check_contiguity((X1, "X1"), (X2, "X2"), (v, "v"), (out, "out"))
--> 904 differentiable = any(t.requires_grad for t in [X1, X2, v] + [*kernel_params])
905 data_devs = (X1.device, X2.device, v.device)
906 comp_dev_type = "cpu" if opt.use_cpu or not torch.cuda.is_available() else "cuda"

AttributeError: 'NoneType' object has no attribute 'requires_grad'

My real use case is that I'm trying to define a discrete kernel based on the number of matching values. In theory, what I'm trying to do could be done with one-hot encoded sparse arrays and an exponential kernel, but that just seems much more complex than needed.
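
What I have in mind is roughly the following sketch (the MatchingKernel name and the code are only illustrative, assuming the features are stored as integer category codes):

    class MatchingKernel(Kernel):
        # Illustrative sketch: k(x1, x2) counts the coordinates on which the two
        # rows hold the same (integer-coded) categorical value.
        def __init__(self, options):
            super().__init__("matching", options)

        def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool) -> torch.Tensor:
            if diag:
                # Compare row i of X1 with row i of X2.
                out.copy_(torch.sum(X1 == X2, dim=-1).to(out.dtype))
            else:
                # Count matches between every pair of rows via broadcasting.
                matches = (X1.unsqueeze(1) == X2.unsqueeze(0)).sum(dim=-1)
                out.copy_(matches.to(out.dtype))
            return out

        def compute_sparse(self, X1, X2, out, diag, **kwargs) -> torch.Tensor:
            raise NotImplementedError("Sparse not implemented")

Each per-feature match indicator is the inner product of one-hot encodings, so the sum is still a valid positive semi-definite kernel; it is the same kernel the one-hot/sparse approach would give, just computed directly.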

Giodiro commented Feb 17, 2024

Hi!
Thanks for the report, I hope it's fixed now. Let me know if you manage to test it on your code, and if you encounter any more issues.

btw I noticed you're implementing the diag code path; I added a note in the docs because it's not really necessary unless you're using the automatic hyperparameter optimization code. Otherwise diag will always be False, and you can either assert that or just ignore it :)
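
To illustrate that simpler shape (a sketch based on the BasicLinearKernel above, not code from the Falkon docs), a kernel that never runs under the hyperparameter optimization code could guard the diagonal path like this:

    def compute(self, X1: torch.Tensor, X2: torch.Tensor, out: torch.Tensor, diag: bool) -> torch.Tensor:
        # diag is only ever True inside the automatic hyperparameter optimization code,
        # so a kernel that is not used there can simply rule that path out.
        assert not diag, "diagonal computation is not supported by this kernel"
        lengthscale = self.lengthscale.to(device=X1.device, dtype=X1.dtype)
        return torch.matmul(X1 * lengthscale, X2.T, out=out)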

Giodiro closed this as completed Feb 17, 2024