In [None]:
import torch as th
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

from music_diffusion.networks.convolutions import ChannelProjBlock
from music_diffusion.networks.utils import View, Permute

In [None]:
# Seed the global RNG so that "Restart Kernel -> Run All" reproduces the same
# random input (and the same values for everything drawn after it).
SEED = 0
th.manual_seed(SEED)

# Random 4-D test input with sizes (1, 2, 32, 32).
x = th.randn(1, 2, 32, 32)

In [None]:
# Sliding-window extractor: 3x3 windows, step 1, one element of zero padding
# on each side so every input position gets its own window.
unfold = nn.Unfold(3, padding=1, stride=1)

In [None]:
# Unfold x, split the fused (channel * window) axis into (2, 9), restore the
# 32 x 32 position grid, then keep window slot 4 (the middle of the 9 slots)
# for every channel and position.
out = unfold(x).view(1, 2, 9, 32, 32).select(dim=2, index=4)

In [None]:
# Display the dimensions of the selected window slot.
out.shape

In [None]:
# Slot 4 is the centre of each 3x3 window; with stride 1 and padding 1 the
# centre of the window at a position is the input value at that position, so
# the slice must reproduce x exactly. Check it instead of comparing by eye.
assert th.equal(out, x), "centre tap of the unfolded windows should equal the input"
out

In [None]:
# Show the original input so it can be compared against `out` displayed above.
x

In [None]:
class BoxCrossAttention(nn.Module):
    def __init__(self, channels: int, tau_hidden_dim: int, kv_dim: int) -> None:
        super().__init__()
        
        self.__kernel_size = 3
        
        self.__query_conv = nn.Sequential(
            ChannelProjBlock(channels, channels),
            nn.Unfold(kernel_size=self.__kernel_size, stride=1, padding=1),
            View(channels, self.__kernel_size**2, -1),
            Permute(0, 3, 2, 1),
            nn.Flatten(0, 1)
        )

        self.__cross_att = nn.MultiheadAttention(
            channels,
            1,
            kdim=kv_dim,
            vdim=kv_dim,
            batch_first=True,
        )

        self.__to_key_value = nn.Sequential(
            weight_norm(nn.Linear(tau_hidden_dim, kv_dim * 2)),
            nn.Mish(),
            weight_norm(nn.Linear(kv_dim * 2, kv_dim * 2)),
        )
    
    def forward(self, x: th.Tensor, y: th.Tensor) -> th.Tensor:
        b, c, w, h = x.size()
        
        proj_query = self.__query_conv(x)
        
        proj_key, proj_value = (
            self.__to_key_value(y)
            .unsqueeze(1)
            .unsqueeze(1)
            .repeat(1, w * h, 1, 1)
            .flatten(0, 1)
            .chunk(dim=-1, chunks=2)
        )
        
        out = (
            self.__cross_att(proj_query, proj_key, proj_value)[0]
            .view(b, w * h, self.__kernel_size**2, c)
            .permute(0, 3, 2, 1)
            .sum(dim=2)
            .view(b, c, w, h)
        )
        
        return out

In [None]:
# Small instance for a smoke test: 2 channels to match `x`, a 16-wide
# conditioning input, and 4-dimensional keys/values.
box_cross_att = BoxCrossAttention(channels=2, tau_hidden_dim=16, kv_dim=4)

In [None]:
# Random conditioning input: one row of 16 values (16 is the tau_hidden_dim
# the module above was built with).
y = th.randn((1, 16))

In [None]:
# Run the module on the random input `x` with the conditioning input `y`
# and display the result.
box_cross_att(x=x, y=y)