In [None]:
# | default_exp nets/symswin_3d

### Inspiration

This architecture is only an idea. It builds on SwinV2; however, instead of applying windowed attention within a window of consecutive tokens, it applies attention between symmetrically opposite tokens (mirrored along the x-axis). It was inspired by medical radiology imaging, particularly head CTs, whose interpretation relies heavily on the symmetry between the left and right hemispheres to detect abnormalities.

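### Logic

The forward rearrangement mirrors the right half of the tokens along the x-axis and interleaves it with the left half, so that every token is placed next to its mirror image (for 8 tokens along x: `0 7 1 6 2 5 3 4`). The backward rearrangement restores the original order.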

# Imports

In [None]:
# | export


import torch
from einops import rearrange

In [None]:
# | export


def symmetry_attention_rearrange_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    """Place every patch next to its mirror image along the x axis.

    The right half of the x axis is reversed and interleaved with the left half, so
    patch ``i`` is immediately followed by patch ``num_patches_x - 1 - i``.
    For 8 patches the order ``0 1 2 3 4 5 6 7`` becomes ``0 7 1 6 2 5 3 4``.

    Args:
        hidden_states: Tensor of shape ``(..., num_patches_x, dim)``, typically
            ``(b, num_patches_z, num_patches_y, num_patches_x, dim)``. Any number of
            leading dimensions is accepted. ``num_patches_x`` must be even.

    Returns:
        A new contiguous tensor with the same shape as ``hidden_states`` whose x axis
        is in the interleaved order described above.

    Raises:
        ValueError: If ``hidden_states`` has fewer than 2 dimensions or if
            ``num_patches_x`` is odd (the patches cannot be split into mirror pairs).
    """
    if hidden_states.ndim < 2:
        raise ValueError(
            f"Expected hidden_states of shape (..., num_patches_x, dim), got {tuple(hidden_states.shape)}"
        )

    num_patches_x = hidden_states.shape[-2]
    if num_patches_x % 2 != 0:
        raise ValueError(f"num_patches_x must be even to form symmetric pairs, got {num_patches_x}")
    half_num_patches_x = num_patches_x // 2

    # (..., half_num_patches_x, dim): left half in its natural order
    left_half = hidden_states[..., :half_num_patches_x, :]
    # (..., half_num_patches_x, dim): right half reversed, so index i holds the mirror of left_half[i]
    mirrored_right_half = hidden_states[..., half_num_patches_x:, :].flip(-2)

    # Stacking on a new axis right after the x axis gives (..., half_num_patches_x, 2, dim);
    # merging those two axes interleaves the halves: left[0], right[-1], left[1], right[-2], ...
    # torch.stack allocates a fresh contiguous tensor, so flatten is a free view and the
    # result is contiguous without an extra copy.
    return torch.stack([left_half, mirrored_right_half], dim=-2).flatten(-3, -2)


def symmetry_attention_rearrange_backward(hidden_states: torch.Tensor) -> torch.Tensor:
    """Undo the mirror-pair interleaving along the x axis.

    Expects the x axis in the interleaved order ``left[0], right[-1], left[1],
    right[-2], ...`` and restores the natural left-to-right order.
    For 8 patches the order ``0 7 1 6 2 5 3 4`` becomes ``0 1 2 3 4 5 6 7``.

    Args:
        hidden_states: Tensor of shape ``(..., rearranged_num_patches_x, dim)``,
            typically ``(b, num_patches_z, num_patches_y, rearranged_num_patches_x, dim)``.
            Any number of leading dimensions is accepted. The x-axis length must be even.

    Returns:
        A new contiguous tensor with the same shape as ``hidden_states`` whose x axis
        is back in its natural order.

    Raises:
        ValueError: If ``hidden_states`` has fewer than 2 dimensions or if the x-axis
            length is odd (it cannot consist of mirror pairs).
    """
    if hidden_states.ndim < 2:
        raise ValueError(
            f"Expected hidden_states of shape (..., num_patches_x, dim), got {tuple(hidden_states.shape)}"
        )

    num_patches_x = hidden_states.shape[-2]
    if num_patches_x % 2 != 0:
        raise ValueError(f"num_patches_x must be even to form symmetric pairs, got {num_patches_x}")

    # Even positions hold the left half in its natural order
    left_half = hidden_states[..., 0::2, :]
    # Odd positions hold the right half in mirrored (reversed) order
    mirrored_right_half = hidden_states[..., 1::2, :]

    # Un-mirror the right half and append it to the left half. torch.cat always allocates
    # a fresh contiguous result, so no explicit .contiguous() is needed.
    return torch.cat([left_half, mirrored_right_half.flip(-2)], dim=-2)

In [None]:
# Small grid for a visual and programmatic check: 8 rows (y) x 8 columns (x), one channel
test_arr = torch.arange(64).reshape(1, 1, 8, 8, 1)

forward_test_arr = symmetry_attention_rearrange_forward(test_arr)
backward_forward_test_arr = symmetry_attention_rearrange_backward(forward_test_arr)

# Forward must put every token next to its mirror image along x;
# a round-trip check alone would also pass for a pair of identity functions.
assert forward_test_arr[0, 0, 0, :, 0].tolist() == [0, 7, 1, 6, 2, 5, 3, 4]
assert forward_test_arr.shape == test_arr.shape

# Backward must undo forward exactly: a pure permutation warrants an exact
# shape-and-value comparison rather than a tolerance-based one.
assert torch.equal(test_arr, backward_forward_test_arr)

test_arr.squeeze(0, 1, -1), forward_test_arr.squeeze(0, 1, -1), backward_forward_test_arr.squeeze(0, 1, -1)


[1m([0m
    [1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m [1;36m0[0m,  [1;36m1[0m,  [1;36m2[0m,  [1;36m3[0m,  [1;36m4[0m,  [1;36m5[0m,  [1;36m6[0m,  [1;36m7[0m[1m][0m,
        [1m[[0m [1;36m8[0m,  [1;36m9[0m, [1;36m10[0m, [1;36m11[0m, [1;36m12[0m, [1;36m13[0m, [1;36m14[0m, [1;36m15[0m[1m][0m,
        [1m[[0m[1;36m16[0m, [1;36m17[0m, [1;36m18[0m, [1;36m19[0m, [1;36m20[0m, [1;36m21[0m, [1;36m22[0m, [1;36m23[0m[1m][0m,
        [1m[[0m[1;36m24[0m, [1;36m25[0m, [1;36m26[0m, [1;36m27[0m, [1;36m28[0m, [1;36m29[0m, [1;36m30[0m, [1;36m31[0m[1m][0m,
        [1m[[0m[1;36m32[0m, [1;36m33[0m, [1;36m34[0m, [1;36m35[0m, [1;36m36[0m, [1;36m37[0m, [1;36m38[0m, [1;36m39[0m[1m][0m,
        [1m[[0m[1;36m40[0m, [1;36m41[0m, [1;36m42[0m, [1;36m43[0m, [1;36m44[0m, [1;36m45[0m, [1;36m46[0m, [1;36m47[0m[1m][0m,
        [1m[[0m[1;36m48[0m, [1;36m49[0m, [1;36m50[0m, [1;36m51[0m, [1;3

# nbdev

In [None]:
# Run nbdev's export command; presumably this writes the cells tagged `# | export`
# to the module named by the `default_exp` directive at the top of this notebook
# (standard nbdev behaviour; confirm against the project's nbdev settings).
!nbdev_export