In [None]:
from torch import nn
from torch.nn import functional as F
import torch
import torch as th

In [None]:
class Attention(nn.Module):
    """Score every position of a sequence and return softmax-normalised weights.

    ``hidden`` and ``encoder_outputs`` are concatenated along their last axis,
    projected by a linear layer to ``hidden_size`` features, passed through a
    ReLU and reduced to one scalar per position by a dot product with the
    learned vector ``v``.

    Args:
        hidden_size: Output width of the projection and length of ``v``. The
            last dimensions of the two ``forward`` inputs must add up to
            ``2 * hidden_size``.
    """

    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.hidden_size = hidden_size
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Parameter(torch.rand(hidden_size))

    def forward(self, encoder_outputs, hidden):
        """Compute attention weights over dimension 1 of the inputs.

        Args:
            encoder_outputs: 3-D tensor ``(batch, seq_len, features)``.
            hidden: 3-D tensor with the same batch and sequence sizes; it is
                concatenated in front of ``encoder_outputs`` on the last axis.

        Returns:
            Tensor of shape ``(batch, seq_len)`` whose rows sum to 1.
        """
        # (batch, seq_len, 2 * hidden_size) -> (batch, seq_len, hidden_size)
        energy = F.relu(self.attn(torch.cat((hidden, encoder_outputs), dim=2)))
        # Batched matrix-vector product: dot each position's energy with v.
        # This gives (batch, seq_len) without materialising a per-batch copy
        # of v, and without the debug prints the forward pass used to emit.
        attention_scores = torch.matmul(energy, self.v)
        return F.softmax(attention_scores, dim=1)

class SpectrogramAttentionModel(nn.Module):
    """Attend over the frequency axis of each time step, then project.

    For every ``(batch, time_step)`` slice the model computes attention
    weights over the frequency channels, takes the weighted sum of the
    per-channel feature vectors, and feeds the result through two linear
    layers (with a ReLU in between).

    Args:
        cnn_output_dim: Size of the last (feature) axis of the input.
        hidden_size: Width of the intermediate fully connected layer.
        output_dim: Size of the last axis of the returned tensor.
    """

    def __init__(self, cnn_output_dim, hidden_size, output_dim):
        super(SpectrogramAttentionModel, self).__init__()
        self.attention = Attention(cnn_output_dim)
        self.fc = nn.Linear(cnn_output_dim, hidden_size)
        self.output_layer = nn.Linear(hidden_size, output_dim)

    def forward(self, cnn_output):
        """Run the model on a 4-D feature tensor.

        Args:
            cnn_output: Tensor of shape
                ``(batch_size, time_steps, frequency_channels, features)``.

        Returns:
            Tensor of shape ``(batch_size, time_steps, output_dim)``.
        """
        batch_size, time_steps, frequency_channels, features = cnn_output.size()
        # Fold time into the batch axis so attention runs per time step.
        # reshape (not view) so non-contiguous inputs, e.g. a permuted CNN
        # output, are accepted instead of raising.
        cnn_output = cnn_output.reshape(batch_size * time_steps, frequency_channels, features)
        # The same tensor is used as both inputs of the attention module.
        attention_weights = self.attention(cnn_output, cnn_output)
        # (B*T, 1, F) @ (B*T, F, features) -> (B*T, features)
        attended_output = torch.bmm(attention_weights.unsqueeze(1), cnn_output).squeeze(1)
        fc_output = F.relu(self.fc(attended_output))
        output = self.output_layer(fc_output)
        return output.view(batch_size, time_steps, -1)

In [None]:
# Fix the RNG state so Restart & Run All reproduces the same demo batch
# (and the same initial weights for modules created after this cell).
torch.manual_seed(0)
x = torch.randn(2, 16, 128, 5)

In [None]:
# Feature size 5 matches the last axis of the demo input.
s = SpectrogramAttentionModel(cnn_output_dim=5, hidden_size=3, output_dim=5)

In [None]:
# Forward pass; the model unpacks x as
# (batch_size, time_steps, frequency_channels, features).
o = s(x)

In [None]:
# Display the output shape: (batch_size, time_steps, output_dim).
o.shape

In [None]:
class AggregateFrequencies(nn.Module):
    """Attention-weighted aggregation over the ``w`` axis of a 4-D tensor.

    The input is unpacked as ``(b, c, w, h)``. Three linear layers act on the
    ``c`` axis; the attention-weighted values are then summed over ``w``,
    giving one ``output_dim`` vector per ``(b, h)`` pair.

    Args:
        input_dim: Size of axis 1 (``c``) of the input.
        hidden_dim: Width of the two projections used to build the scores.
        output_dim: Width of the value projection and of the result.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()

        # Creation order is kept so parameter initialisation consumes the
        # random number generator in the same sequence.
        self.__to_input = nn.Linear(input_dim, hidden_dim)
        self.__to_key = nn.Linear(input_dim, hidden_dim)
        self.__to_value = nn.Linear(input_dim, output_dim)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Return a tensor of shape ``(b, h, output_dim)`` for input ``(b, c, w, h)``."""
        batch, channels, width, height = x.shape

        # (b, c, w, h) -> (b, h, w, c) -> (b * h, w, c)
        rows = x.permute(0, 3, 2, 1).contiguous()
        rows = rows.view(batch * height, width, channels)

        queries = self.__to_input(rows)
        keys = self.__to_key(rows)
        values = self.__to_value(rows)

        # scores[n, i, j] = <queries[n, i], keys[n, j]>
        scores = th.bmm(queries, keys.transpose(1, 2))
        # Normalise over dim 1, then swap the last two axes so that each row
        # of ``weight`` sums to 1.
        weight = F.softmax(scores, dim=1).transpose(1, 2)

        mixed = th.bmm(weight, values)
        # Collapse the w axis and unfold (b * h) back into (b, h).
        return mixed.sum(dim=1).view(batch, height, -1)

In [None]:
# input_dim must equal the size of axis 1 of the tensor passed to agg.
agg = AggregateFrequencies(input_dim=3, hidden_dim=6, output_dim=5)

In [None]:
# Fix the RNG state so Restart & Run All reproduces the same demo input.
torch.manual_seed(0)
x = torch.rand(2, 3, 16, 32)

In [None]:
# Forward pass; AggregateFrequencies.forward unpacks x as (b, c, w, h).
o = agg(x)

In [None]:
# Display the output shape: (b, h, output_dim).
o.shape