In [2]:
import torch
from torch import nn

In [14]:
class FeedbackModule(nn.Module):
    """Score each column of ``w`` against the feature map ``v``; softmax over those scores.

    Shapes below are inferred from the matmuls in ``forward``; this assumes
    3-D batched tensors — TODO confirm with callers:
        v: (B, C, N)   e.g. a flattened feature map
        w: (B, C, L)   dim 1 must match v's dim 1 for ``w^T @ v``
    Output: (B, L), a softmax over the L axis (each row sums to 1).
    """

    def __init__(self):
        super().__init__()
        # Kept as a registered submodule (rather than torch.softmax) so the
        # module structure/repr is unchanged.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, v, w):
        # Pairwise affinity between every column of w and every position of v:
        # (B, L, C) @ (B, C, N) -> (B, L, N).
        affinity = w.transpose(1, 2) @ v

        # L2-normalise over the L axis first, then over the N axis. The second
        # pass overrides the first: only the dim=2 rows end up unit-norm.
        affinity = nn.functional.normalize(affinity, dim=1)
        affinity = nn.functional.normalize(affinity, dim=2)

        # Re-project v onto the normalised affinities:
        # (B, C, N) @ (B, N, L) -> (B, C, L), then collapse the C axis -> (B, L).
        scores = (v @ affinity.transpose(1, 2)).sum(dim=1)
        return self.softmax(scores)

class Discriminator(nn.Module):
    """Six unpadded 3x3 convolutions, a spatial flatten, and a FeedbackModule head.

    Args:
        img_size: accepted but not used by the current implementation.
        n_blocks: accepted but not used; the conv stack always has six layers
            with channel schedule 3 -> 5 -> 7 -> 9 -> 7 -> 5 -> 3.

    NOTE(review): the convolutions are chained with no activation between
    them, so the whole stack is a composition of linear maps. Confirm this is
    intended; if not, ``conv_block`` is the natural place to add one.
    """

    def __init__(self, img_size=256, n_blocks=5):
        super().__init__()
        # Channel count at each stage; each consecutive pair is one conv layer.
        stages = (3, 5, 7, 9, 7, 5, 3)
        layers = []
        for c_in, c_out in zip(stages[:-1], stages[1:]):
            layers.append(self.conv_block(c_in, c_out))
        self.convs = nn.Sequential(*layers)
        # For batched (B, C, H', W') input: merge spatial dims -> (B, C, H'*W').
        self.flat = nn.Flatten(start_dim=2)
        self.logits = FeedbackModule()

    @staticmethod
    def conv_block(in_chans, out_chans):
        """Build a 3x3, stride-1, unpadded Conv2d; each one shrinks H and W by 2."""
        return nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=1, padding=0)

    def forward(self, v, w):
        """Run the conv stack on ``v`` and score the result against ``w``.

        Assumes ``v`` is a batched (B, 3, H, W) image tensor with H, W >= 13,
        because six valid 3x3 convolutions remove 12 pixels per spatial dim.
        ``w`` goes straight to FeedbackModule, whose ``w^T @ v`` product needs
        ``w.shape[1]`` to equal the 3 channels the stack outputs.
        """
        features = self.flat(self.convs(v))
        return self.logits(features, w)


# Smoke test: push a random batch through an untrained Discriminator and check
# the output shape / normalisation.
torch.manual_seed(0)  # reproducible inputs and conv weight init

# batch, image channels, height, width, number of columns in w.
# C must be 3: the conv stack takes 3 input channels, and FeedbackModule's
# w^T @ v needs w's dim 1 to match the 3 channels the stack outputs.
B, C, H, W, L = 64, 3, 64, 128, 16
v = torch.rand((B, C, H, W))
w = torch.rand((B, C, L))
D = Discriminator()
y = D(v, w)
assert y.shape == (B, L), y.shape
assert torch.allclose(y.sum(dim=1), torch.ones(B))  # softmax rows sum to 1
print(y.size())

torch.Size([64, 16])


In [8]:
# Experiment: what does L2-normalising along dim=1 and then dim=2 produce?
# (Same sequence as in FeedbackModule.forward.)
torch.manual_seed(0)  # make the printed tensor reproducible
a = torch.rand((3, 4, 5))
a_1 = nn.functional.normalize(a, dim=1)
a_1 = nn.functional.normalize(a_1, dim=2)
# The last normalisation wins: rows along dim=2 have unit L2 norm, while the
# dim=1 columns are generally no longer unit-norm.
assert torch.allclose(a_1.norm(dim=2), torch.ones(3, 4))
print(a_1)

tensor([[[0.2319, 0.5504, 0.6496, 0.2657, 0.3881],
         [0.5068, 0.0193, 0.1362, 0.5512, 0.6484],
         [0.5681, 0.5599, 0.3909, 0.3808, 0.2570],
         [0.4262, 0.1694, 0.2609, 0.6978, 0.4844]],

        [[0.1334, 0.5566, 0.4525, 0.6715, 0.1295],
         [0.7147, 0.3396, 0.1811, 0.3193, 0.4891],
         [0.5128, 0.5032, 0.1271, 0.1696, 0.6625],
         [0.3064, 0.3843, 0.6418, 0.4885, 0.3285]],

        [[0.5637, 0.4500, 0.6596, 0.1522, 0.1466],
         [0.4026, 0.4088, 0.4392, 0.1636, 0.6717],
         [0.5613, 0.6017, 0.0755, 0.1579, 0.5406],
         [0.3635, 0.3978, 0.4272, 0.6756, 0.2659]]])
