Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: GLiftingKernelSE2 ignores param mask=False #4

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
faa26c4
fix: GLiftingKernelSE2 ignores param mask=False
dgcnz May 13, 2024
0e4543e
feat: add support for poetry package manager
dgcnz May 13, 2024
1162c9f
feat: add relaxed lifting kernel and relaxed lifting convolution for …
dgcnz May 14, 2024
5ffc56d
fix: set padding to zero for pointwise subgroup convolutions
dgcnz May 18, 2024
43ed995
feat: add relaxed 3dgconv
dgcnz May 18, 2024
c134b9b
fix: standarize test folder structure and fix imports
dgcnz May 20, 2024
f93d740
fix: generalize shape assertions for 3d and 2d
dgcnz May 20, 2024
220c2a4
feat: implement octahedral lifting and separable kernels
dgcnz May 21, 2024
0fafbad
feat: add relaxed lifting kernel seq vs vec test
dgcnz May 21, 2024
622b70a
docs: update README
dgcnz May 21, 2024
4879da6
Update README.md
dgcnz May 21, 2024
19125bd
Added transposed conv to selection
MeneerTS May 21, 2024
98c0bfc
fmt: run formatter and improve docs
dgcnz May 22, 2024
0da754b
Merge pull request #3 from dgcnz/upconv
dgcnz May 22, 2024
8df2af4
Added output_padding to GSeparableConvNd
MeneerTS May 22, 2024
160a57f
Updated conv_transpose to use padding
MeneerTS May 22, 2024
275bca3
Fixed important typo
MeneerTS May 22, 2024
79f6dc1
Same typo
MeneerTS May 22, 2024
8cc4228
Small bug
MeneerTS May 22, 2024
b60e7fe
Updated transposed to fit all group sizes
MeneerTS May 22, 2024
c5d1d28
removed print statements
MeneerTS May 22, 2024
855b645
fix: made seperable convolution work with transposed
MeneerTS May 23, 2024
a3fe094
fix: Allowed more than just upconv
MeneerTS May 23, 2024
755e32e
add 3d grid sample using an if statement that looks at the dim of the…
Nesta-gitU Jun 9, 2024
be4d543
docs: update version
dgcnz Jun 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -387,4 +387,5 @@ $RECYCLE.BIN/
# Windows shortcuts
*.lnk

# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,jupyternotebooks
# End of https://www.toptal.com/developers/gitignore/api/python,linux,macos,windows,jupyternotebooks
.vscode
35 changes: 9 additions & 26 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
# Continuous Regular Group Convolutions (WIP 👷‍♀️👷‍♂️)
# Regular Group Convolutions

This package implements a Pytorch framework for group convolutions that are easy to use and implement in existing Pytorch modules. The package offers premade modules for E3 and SE3 convolutions, as well as basic operations such as pooling and normalization for $\mathbb{R}^n \rtimes H$ input. The method is explained in the paper [Regular SE(3) Group Convolutions for Volumetric Medical Image Analysis](https://arxiv.org/abs/2306.13960), accepted at MICCAI 2023 (see reference below).
This package extends [Thijs Kuipers' gconv](https://github.com/ThijsKuipers1995/gconv) to add support for Approximate/Relaxed Group Equivariant kernels as described in [Wang et al. 2022](https://arxiv.org/abs/2201.11969) and [Wang et al. 2023](https://openreview.net/forum?id=B8EpSHEp9j)

## Installation from Source

Download `gconv` and save to a directory. Then from that directory run the following command:
## Installation

With pip:
```sh
pip install git+https://github.com/dgcnz/gconv.git
```
pip install -e gconv
With poetry:
```sh
poetry add git+https://github.com/dgcnz/gconv.git
```

## Getting Started
Expand All @@ -34,23 +37,3 @@ y = pool(x3, H2)
In line 5, a random batch of three-channel $\mathbb{R}^3$ volumes is created. In line 6, the $\mathbb{R}^3$ is lifted to $\text{SE}(3) = \mathbb{R}^3 \rtimes \text{SO}(3)$. In line 7, an $\text{SE}(3)$ convolution is performed. In line 14, a global pooling is performed, resulting in $\text{SE}(3)$ invariant features.

Furthermore, `gconv` offers all the necessary tools to build fully custom group convolutions. All that is required is implementing 5 (or less, depending on the type of convolution) group ops! For more details on how to implement custom group convolutions, see `gconv_tutorial.ipynb`.

## Requirements:
```
python >= 3.10
torch
tqdm
```

## Reference:
Paper accepted at MICCAI 2023.
```
@misc{kuipers2023regular,
title={Regular SE(3) Group Convolutions for Volumetric Medical Image Analysis},
author={Thijs P. Kuipers and Erik J. Bekkers},
year={2023},
eprint={2306.13960},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
121 changes: 120 additions & 1 deletion gconv/gnn/functional/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,126 @@
from torch import Tensor


from torch.nn.functional import grid_sample
from torch.nn.functional import grid_sample as grid_sample_from_torch

# Replace torch grid sample with this
def grid_sample(input, grid, *args, **kwargs):
if len(input.shape) == 4:
return grid_sample_from_torch(input, grid, *args, **kwargs)
else:
return grid_sample_3d(input, grid)


def grid_sample_3d(image, optical):
N, C, ID, IH, IW = image.shape
_, D, H, W, _ = optical.shape

ix = optical[..., 0]
iy = optical[..., 1]
iz = optical[..., 2]

ix = ((ix + 1) / 2) * (IW - 1);
iy = ((iy + 1) / 2) * (IH - 1);
iz = ((iz + 1) / 2) * (ID - 1);
with torch.no_grad():

ix_tnw = torch.floor(ix);
iy_tnw = torch.floor(iy);
iz_tnw = torch.floor(iz);

ix_tne = ix_tnw + 1;
iy_tne = iy_tnw;
iz_tne = iz_tnw;

ix_tsw = ix_tnw;
iy_tsw = iy_tnw + 1;
iz_tsw = iz_tnw;

ix_tse = ix_tnw + 1;
iy_tse = iy_tnw + 1;
iz_tse = iz_tnw;

ix_bnw = ix_tnw;
iy_bnw = iy_tnw;
iz_bnw = iz_tnw + 1;

ix_bne = ix_tnw + 1;
iy_bne = iy_tnw;
iz_bne = iz_tnw + 1;

ix_bsw = ix_tnw;
iy_bsw = iy_tnw + 1;
iz_bsw = iz_tnw + 1;

ix_bse = ix_tnw + 1;
iy_bse = iy_tnw + 1;
iz_bse = iz_tnw + 1;

tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);


with torch.no_grad():

torch.clamp(ix_tnw, 0, IW - 1, out=ix_tnw)
torch.clamp(iy_tnw, 0, IH - 1, out=iy_tnw)
torch.clamp(iz_tnw, 0, ID - 1, out=iz_tnw)

torch.clamp(ix_tne, 0, IW - 1, out=ix_tne)
torch.clamp(iy_tne, 0, IH - 1, out=iy_tne)
torch.clamp(iz_tne, 0, ID - 1, out=iz_tne)

torch.clamp(ix_tsw, 0, IW - 1, out=ix_tsw)
torch.clamp(iy_tsw, 0, IH - 1, out=iy_tsw)
torch.clamp(iz_tsw, 0, ID - 1, out=iz_tsw)

torch.clamp(ix_tse, 0, IW - 1, out=ix_tse)
torch.clamp(iy_tse, 0, IH - 1, out=iy_tse)
torch.clamp(iz_tse, 0, ID - 1, out=iz_tse)

torch.clamp(ix_bnw, 0, IW - 1, out=ix_bnw)
torch.clamp(iy_bnw, 0, IH - 1, out=iy_bnw)
torch.clamp(iz_bnw, 0, ID - 1, out=iz_bnw)

torch.clamp(ix_bne, 0, IW - 1, out=ix_bne)
torch.clamp(iy_bne, 0, IH - 1, out=iy_bne)
torch.clamp(iz_bne, 0, ID - 1, out=iz_bne)

torch.clamp(ix_bsw, 0, IW - 1, out=ix_bsw)
torch.clamp(iy_bsw, 0, IH - 1, out=iy_bsw)
torch.clamp(iz_bsw, 0, ID - 1, out=iz_bsw)

torch.clamp(ix_bse, 0, IW - 1, out=ix_bse)
torch.clamp(iy_bse, 0, IH - 1, out=iy_bse)
torch.clamp(iz_bse, 0, ID - 1, out=iz_bse)

image = image.view(N, C, ID * IH * IW)

tnw_val = torch.gather(image, 2, (iz_tnw * IW * IH + iy_tnw * IW + ix_tnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
tne_val = torch.gather(image, 2, (iz_tne * IW * IH + iy_tne * IW + ix_tne).long().view(N, 1, D * H * W).repeat(1, C, 1))
tsw_val = torch.gather(image, 2, (iz_tsw * IW * IH + iy_tsw * IW + ix_tsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
tse_val = torch.gather(image, 2, (iz_tse * IW * IH + iy_tse * IW + ix_tse).long().view(N, 1, D * H * W).repeat(1, C, 1))
bnw_val = torch.gather(image, 2, (iz_bnw * IW * IH + iy_bnw * IW + ix_bnw).long().view(N, 1, D * H * W).repeat(1, C, 1))
bne_val = torch.gather(image, 2, (iz_bne * IW * IH + iy_bne * IW + ix_bne).long().view(N, 1, D * H * W).repeat(1, C, 1))
bsw_val = torch.gather(image, 2, (iz_bsw * IW * IH + iy_bsw * IW + ix_bsw).long().view(N, 1, D * H * W).repeat(1, C, 1))
bse_val = torch.gather(image, 2, (iz_bse * IW * IH + iy_bse * IW + ix_bse).long().view(N, 1, D * H * W).repeat(1, C, 1))

out_val = (tnw_val.view(N, C, D, H, W) * tnw.view(N, 1, D, H, W) +
tne_val.view(N, C, D, H, W) * tne.view(N, 1, D, H, W) +
tsw_val.view(N, C, D, H, W) * tsw.view(N, 1, D, H, W) +
tse_val.view(N, C, D, H, W) * tse.view(N, 1, D, H, W) +
bnw_val.view(N, C, D, H, W) * bnw.view(N, 1, D, H, W) +
bne_val.view(N, C, D, H, W) * bne.view(N, 1, D, H, W) +
bsw_val.view(N, C, D, H, W) * bsw.view(N, 1, D, H, W) +
bse_val.view(N, C, D, H, W) * bse.view(N, 1, D, H, W))

return out_val


def create_grid_R3(size: int, device: str | None = None) -> Tensor:
Expand Down
2 changes: 2 additions & 0 deletions gconv/gnn/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
GSeparableKernel,
GLiftingKernel,
GSubgroupKernel,
RGLiftingKernel,
RGSeparableKernel
)
from .kernel_sen import *
from .kernel_en import *
Loading