In [39]:
from typing import Any

import torch
from torch import nn

In [69]:
class FeedbackModule(nn.Module):
    """
    Word-level feedback head: scores every word against the regional image features.

    The value returned by ``forward`` has already gone through a softmax, so
    each row is a probability distribution over the words, not raw logits.
    """

    def __init__(self) -> None:
        """
        Instantiate the module with a softmax over dim 1 of the per-word scores
        (channels are summed out before the softmax, so dim 1 is the word axis)
        """
        super().__init__()
        self.softmax = nn.Softmax(dim=1)

    def forward(
            self, visual_features: torch.Tensor, textual_features: torch.Tensor
    ) -> torch.Tensor:
        """
        Fuse visual and textual features into one normalised score per word
        :param torch.Tensor visual_features:
            Image features produced by the discriminator, either already
            flattened to (B, C, H*W) or as raw (B, C, H, W) maps; any axes
            after the second are merged into a single region axis
        :param torch.Tensor textual_features:
            Result of text encoder, laid out as (B, C, L) with the same C as
            the visual features
        :return:
            Softmax-normalised score for each word, shape (B, L); every row
            sums to 1
        :rtype: torch.Tensor
        """
        # NOTE(review): the output is post-softmax; if the downstream loss
        # expects raw logits this would normalise twice - confirm with the caller.
        if visual_features.dim() > 3:
            # raw feature maps: collapse the spatial axes into one region axis;
            # inputs of rank <= 3 are passed through untouched
            visual_features = visual_features.flatten(start_dim=2)

        # (B, C, L) -> (B, L, C) so that words index the rows of the product
        words_first = torch.transpose(textual_features, 1, 2)
        # (B, L, C) @ (B, C, H*W) -> (B, L, H*W): word/region correlations
        word_region_correlations = words_first @ visual_features

        # L2-normalise across the L dimension, then across the H*W dimension
        norm_over_words = nn.functional.normalize(word_region_correlations, dim=1)
        norm_over_regions = nn.functional.normalize(norm_over_words, dim=2)

        # (B, C, H*W) @ (B, H*W, L) -> (B, C, L): region features weighted per word
        region_weights = torch.transpose(norm_over_regions, 1, 2)
        weighted_img_feats = visual_features @ region_weights

        # collapse the channel axis, leaving one scalar per word: (B, L)
        word_scores = torch.sum(weighted_img_feats, dim=1)
        return self.softmax(word_scores)


class Discriminator(nn.Module):
    """Small convolutional discriminator with a word-level feedback head"""

    def __init__(self, img_chans: int) -> None:
        """
        Assemble the convolutional feature extractor and the feedback head
        :param int img_chans:
            Number of channels of the incoming images; the extractor also
            ends on this many channels, which is the channel size the text
            embeddings are expected to have
        """
        super().__init__()
        # channel widths along the stack: img_chans -> 5 -> 7 -> 9 -> 7 -> 5 -> img_chans
        widths = [img_chans, 5, 7, 9, 7, 5, img_chans]
        stages = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stages.append(self.conv_block(n_in, n_out))
        self.convs = nn.Sequential(*stages)
        # keep the batch and channel axes, merge everything after them
        self.flat = nn.Flatten(start_dim=2)
        self.logits = FeedbackModule()

    @staticmethod
    def conv_block(in_chans: int, out_chans: int) -> nn.Sequential:
        """
        Build one unpadded 3x3 convolution followed by LeakyReLU(0.2)
        :param int in_chans: Number of input channels for conv layer
        :param int out_chans: Number of output channels for conv layer
        :return: Convolution + activation pair wrapped in a single module
        :rtype: nn.Sequential
        """
        # 3x3 kernel, stride 1, no padding: each spatial dim shrinks by 2
        conv = nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=1, padding=0)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, activation)

    def forward(self, images: torch.Tensor, textual_info: torch.Tensor) -> Any:
        """
        Extract regional image features and hand them, together with the
        text features, to the feedback head
        :param images: Images to be analyzed
        :param textual_info: Output of RNN (text encoder)
        :return: Whatever the feedback head produces for this image/text pair
        :rtype: Any
        """
        regional = self.flat(self.convs(images))
        return self.logits(regional, textual_info)

# Smoke test: run one random batch through the discriminator and show the output size.
B = 64      # first axis of both random inputs (batch)
C = 2       # channel count shared by the images and the text features
H = W = 15  # spatial size of the random images
L = 7       # last axis of the text features (presumably words per sample)
v = torch.rand(B, C, H, W)
w = torch.rand(B, C, L)
D = Discriminator(img_chans=C)
y = D(v, w)
print(y.shape)

torch.Size([64, 7])


In [8]:
# Scratch check: L2-normalise a random tensor along dim 1, then along dim 2,
# and print the result.
a = torch.rand(3, 4, 5)
a_1 = nn.functional.normalize(
    nn.functional.normalize(a, dim=1),
    dim=2,
)
print(a_1)

tensor([[[0.2319, 0.5504, 0.6496, 0.2657, 0.3881],
         [0.5068, 0.0193, 0.1362, 0.5512, 0.6484],
         [0.5681, 0.5599, 0.3909, 0.3808, 0.2570],
         [0.4262, 0.1694, 0.2609, 0.6978, 0.4844]],

        [[0.1334, 0.5566, 0.4525, 0.6715, 0.1295],
         [0.7147, 0.3396, 0.1811, 0.3193, 0.4891],
         [0.5128, 0.5032, 0.1271, 0.1696, 0.6625],
         [0.3064, 0.3843, 0.6418, 0.4885, 0.3285]],

        [[0.5637, 0.4500, 0.6596, 0.1522, 0.1466],
         [0.4026, 0.4088, 0.4392, 0.1636, 0.6717],
         [0.5613, 0.6017, 0.0755, 0.1579, 0.5406],
         [0.3635, 0.3978, 0.4272, 0.6756, 0.2659]]])
