In [4]:
from typing import Any

import torch
from torch import nn

### Discriminator

In [5]:
class FeedbackModule(nn.Module):
    """Word-level feedback head that scores every word against regional image features"""

    def __init__(self) -> None:
        """
        Set up the module; its only layer is a softmax over dim 1, which is
        the word dimension of the tensor it is applied to in ``forward``
        """
        super().__init__()
        self.softmax = nn.Softmax(dim=1)

    def forward(
            self, visual_features: torch.Tensor, textual_features: torch.Tensor
    ) -> Any:
        """
        Correlate words with image regions and reduce the result to one score per word
        :param torch.Tensor visual_features:
            Flattened feature maps, batch x channels x regions (B, C, N)
        :param torch.Tensor textual_features:
            Word features, batch x channels x words (B, C, L);
            C has to match the channel count of ``visual_features``
        :return:
            Tensor of shape (B, L); softmax has already been applied,
            so each row sums to 1 (these are not raw logits)
        :rtype: Any
        """
        normalize = nn.functional.normalize
        # (B, L, C) @ (B, C, N) -> (B, L, N): one score per (word, region) pair
        correlations = torch.matmul(textual_features.transpose(1, 2), visual_features)
        # L2-normalize over the word dimension first, then over the region dimension
        attention = normalize(normalize(correlations, dim=1), dim=2)
        # (B, C, N) @ (B, N, L) -> (B, C, L): region features pooled per word
        per_word_feats = torch.matmul(visual_features, attention.transpose(1, 2))
        # collapse the channel dimension, then normalize the scores across words
        return self.softmax(per_word_feats.sum(dim=1))


class Discriminator(nn.Module):
    """Small convolutional discriminator that yields word-level feedback for an image"""

    def __init__(self, img_chans: int) -> None:
        """
        Build the convolutional feature extractor and the word-level feedback head
        :param int img_chans:
            Number of channels of the input images; the last convolution maps
            back to this many channels, so the text features passed to
            ``forward`` must have the same channel count
        """
        super().__init__()
        # channel widths along the stack: img_chans -> 5 -> 7 -> 9 -> 7 -> 5 -> img_chans
        channel_plan = [img_chans, 5, 7, 9, 7, 5, img_chans]
        blocks = []
        for in_chans, out_chans in zip(channel_plan[:-1], channel_plan[1:]):
            blocks.append(self.conv_block(in_chans, out_chans))
        self.convs = nn.Sequential(*blocks)
        # keep batch and channel dims, merge the spatial dims into one "regions" axis
        self.flat = nn.Flatten(start_dim=2)
        self.logits = FeedbackModule()

    @staticmethod
    def conv_block(in_chans: int, out_chans: int) -> nn.Sequential:
        """
        3x3 convolution (stride 1, no padding) followed by LeakyReLU with slope 0.2
        :param int in_chans: Number of input channels for conv layer
        :param int out_chans: Number of output channels for conv layer
        :return:
            Convolution + activation pair; without padding each block
            trims 2 pixels from both height and width
        :rtype: nn.Sequential
        """
        conv = nn.Conv2d(in_chans, out_chans, kernel_size=3, stride=1, padding=0)
        activation = nn.LeakyReLU(0.2)
        return nn.Sequential(conv, activation)

    def forward(self, images: torch.Tensor, textual_info: torch.Tensor) -> Any:
        """
        Extract regional features from the images and score every word against them
        :param images: Batch of images, (B, img_chans, H, W)
        :param textual_info: Word features from the text encoder, (B, img_chans, L)
        :return: Word-level feedback produced by the feedback head, one value per word
        :rtype: Any
        """
        regional_feats = self.flat(self.convs(images))
        return self.logits(regional_feats, textual_info)

# Smoke test: B images of size C x H x W and, per image, L words with C-dimensional features
B, C, H, W, L = 64, 2, 15, 15, 7
v = torch.rand(B, C, H, W)
w = torch.rand(B, C, L)
D = Discriminator(img_chans=C)
y = D(v, w)
# expect one value per word, i.e. a (B, L) tensor
print(y.shape)

torch.Size([64, 7])


In [6]:
# Scratch check of the back-to-back normalization used in FeedbackModule:
# L2-normalize along dim 1, then along dim 2 (only the second one still holds exactly afterwards)
a = torch.rand(3, 4, 5)
a_1 = nn.functional.normalize(nn.functional.normalize(a, dim=1), dim=2)
print(a_1)

tensor([[[0.1088, 0.7220, 0.1861, 0.3513, 0.5557],
         [0.0978, 0.1983, 0.6332, 0.1221, 0.7316],
         [0.7790, 0.4329, 0.3721, 0.2534, 0.0562],
         [0.5428, 0.1391, 0.5085, 0.6439, 0.1135]],

        [[0.6559, 0.4532, 0.3903, 0.0514, 0.4576],
         [0.3335, 0.3534, 0.6754, 0.5496, 0.0754],
         [0.0066, 0.3907, 0.0783, 0.5845, 0.7068],
         [0.6224, 0.7246, 0.0978, 0.1810, 0.2127]],

        [[0.0964, 0.1867, 0.5406, 0.4586, 0.6733],
         [0.3365, 0.6771, 0.0982, 0.6036, 0.2329],
         [0.5164, 0.4229, 0.4876, 0.2261, 0.5153],
         [0.5997, 0.1640, 0.5586, 0.4883, 0.2511]]])


### Text Encoder

In [17]:
class TextEncoder(nn.Module):
    """Embedding layer + bidirectional LSTM producing word-level and sentence-level features"""

    def __init__(self, vocab_size, emb_dim, hidden_dim):
        """
        Create the token embedding table and the bidirectional, batch-first LSTM
        :param vocab_size: Number of rows in the embedding table
        :param emb_dim: Size of each token embedding (LSTM input size)
        :param hidden_dim:
            LSTM hidden size per direction; outputs have 2 * hidden_dim channels
        """
        super().__init__()
        self.embs = nn.Embedding(num_embeddings=vocab_size, embedding_dim=emb_dim)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, tokens):
        """
        Encode a batch of token id sequences
        :param tokens: Integer token ids, (B, T)
        :return:
            Tuple ``(word_embs, sent_embs)``: per-token LSTM outputs moved to
            channel-first layout (B, 2 * hidden_dim, T), and the final hidden
            states concatenated as (B, 2 * hidden_dim)
        """
        # the final cell state is not needed
        per_token_states, (final_hidden, _) = self.lstm(self.embs(tokens))
        # (B, T, 2 * hidden) -> (B, 2 * hidden, T)
        word_embs = per_token_states.transpose(1, 2)
        # last entry of the hidden-state stack first, then the first entry;
        # for this single-layer bidirectional LSTM that is backward, then forward
        sent_embs = torch.cat([final_hidden[-1], final_hidden[0]], dim=1)
        return word_embs, sent_embs

# Smoke test: 64 sequences of 130 token ids, encoded with hidden size 150 per direction
encoder = TextEncoder(vocab_size=5000, emb_dim=200, hidden_dim=150)

x = torch.randint(low=0, high=100, size=(64, 130))
w, s = encoder(x)
# word features are channel-first, the sentence feature is one vector per sequence
print(w.shape, s.shape)

torch.Size([64, 300, 130]) torch.Size([64, 300])
