In [24]:
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from einops import rearrange
from copy import deepcopy


class FeatureEmbedder(nn.Module):
    """Projects input features to ``d_model`` dimensions.

    The projection is a single linear layer whose output is multiplied by
    ``sqrt(d_model)`` and then passed through a ReLU.
    """

    def __init__(self, d_feat, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedder = nn.Linear(d_feat, d_model)
        self.activation = nn.ReLU()

    def forward(self, x):
        # only the last dimension changes: d_feat -> d_model
        projected = self.embedder(x)
        # scaling happens before the non-linearity
        scaled = projected * np.sqrt(self.d_model)
        return self.activation(scaled)

class PositionalEncoder(nn.Module):
    """Adds a fixed sinusoidal position table to the input and applies dropout.

    The table has shape ``(1, seq_len, d_model)`` and is filled as
    ``sin(pos / 10000 ** (i / d_model))`` for feature indices ``i = 0, 2, 4, ...``
    and ``cos(pos / 10000 ** (i / d_model))`` for ``i = 1, 3, 5, ...``.
    Note that the cosine columns use their own (odd) index ``i`` in the exponent.

    Args:
        d_model: size of the last dimension of the inputs to ``forward``.
        dout_p: dropout probability applied after the table is added.
        seq_len: number of positions in the precomputed table.
    """

    def __init__(self, d_model, dout_p, seq_len=3660):
        super(PositionalEncoder, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dout_p)

        # feature indices that receive the sine / the cosine, respectively
        sin_idx = np.arange(0, d_model, 2)
        cos_idx = np.arange(1, d_model, 2)
        # column vector of positions, so it broadcasts against the feature indices
        positions = np.arange(seq_len).reshape(-1, 1)

        # the whole table is filled with two broadcast operations instead of a
        # per-position Python loop; the values are the same
        pos_enc_mat = np.zeros((seq_len, d_model))
        pos_enc_mat[:, sin_idx] = np.sin(positions / (10000 ** (sin_idx / d_model)))
        pos_enc_mat[:, cos_idx] = np.cos(positions / (10000 ** (cos_idx / d_model)))

        # (1, seq_len, d_model), float64; cast to the input's type inside forward
        self.pos_enc_mat = torch.from_numpy(pos_enc_mat).unsqueeze(0)

    def forward(self, x):
        """Return ``dropout(x + table[:, :S])`` for a 3-D input ``x`` of shape ``(B, S, d_model)``.

        Only the first ``seq_len`` positions exist in the table, so ``S`` is expected
        not to exceed the ``seq_len`` given to the constructor.
        """
        # unpacking also enforces that x is 3-dimensional
        _, S, _ = x.shape
        x = x + self.pos_enc_mat[:, :S, :].type_as(x)
        x = self.dropout(x)
        # same shape as the input
        return x


class LayerStack(nn.Module):
    """Applies ``N`` copies of ``layer`` one after another.

    Each layer is called as ``layer(x, masks)``; its return value becomes the
    ``x`` of the next layer, while ``masks`` is passed to every layer unchanged.
    """

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clone(layer, N)

    def forward(self, x, masks):
        out = x
        for sub_layer in self.layers:
            out = sub_layer(out, masks)
        return out

def clone(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``."""
    copies = []
    for _ in range(N):
        copies.append(deepcopy(module))
    return nn.ModuleList(copies)


In [25]:
class MHA(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values may have different feature sizes; each is projected
    to ``d_model`` features, split into ``n_heads`` heads of ``d_model // n_heads``
    features, attended per head, merged and projected back to ``d_model_Q``.

    Args:
        d_model_Q: feature size of the queries (also the output feature size).
        d_model_K: feature size of the keys.
        d_model_V: feature size of the values.
        dout_p: dropout probability applied to the attention weights.
        d_model: internal feature size; ``None`` falls back to ``d_model_Q``.
        n_heads: number of attention heads; must divide ``d_model``.

    Raises:
        AssertionError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model_Q, d_model_K, d_model_V, dout_p=0.0, d_model=256, n_heads=8):
        super(MHA, self).__init__()
        self.d_model_Q = d_model_Q
        self.d_model_K = d_model_K
        self.d_model_V = d_model_V
        self.n_heads = n_heads
        self.d_model = d_model
        self.dout_p = dout_p

        if self.d_model is None:
            print(f'd_model: is None')
            self.d_model = self.d_model_Q

        # checked up front (and against n_heads, not an undefined global) so that a
        # bad configuration fails before any layer is built
        assert self.d_model % self.n_heads == 0, (
            f'd_model ({self.d_model}) must be divisible by n_heads ({self.n_heads})'
        )

        self.d_k = self.d_model // self.n_heads

        self.fc_q = nn.Linear(self.d_model_Q, self.d_model)
        self.fc_k = nn.Linear(self.d_model_K, self.d_model)
        self.fc_v = nn.Linear(self.d_model_V, self.d_model)
        self.fc_o = nn.Linear(self.d_model, self.d_model_Q)

        self.dropout = nn.Dropout(self.dout_p)
        # sqrt of the per-head feature size
        self.scale = torch.sqrt(torch.tensor(self.d_model / self.n_heads))

    def _split_heads(self, x):
        # (B, H, S, d_k) <- (B, S, H * d_k)
        B, S, _ = x.shape
        return x.reshape(B, S, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, Q, K, V, mask):
        '''
            Q, K, V: (B, Sq, Dq), (B, Sk, Dk), (B, Sv, Dv)
            mask: (B, 1, Sk) or (B, Sq, Sk), or None; positions where the mask equals 0
                  are excluded from the softmax
            Sk = Sv,
            Dk != self.d_k
            Also: m1 is the target modality (queries); m2 is the source modality (keys, values)

            Returns:
                (B, Sq, Dq) attended queries and (B, H, Sq, Sk) attention weights
                (the weights are returned before dropout is applied to them)
        '''

        B, Sq, _ = Q.shape
        # (B, H, Sm, d_k) <- (B, Sm, D) <- (B, Sm, Dm)
        Q = self._split_heads(self.fc_q(Q))
        K = self._split_heads(self.fc_k(K))
        V = self._split_heads(self.fc_v(V))

        # (B, H, Sq, Sk)
        attention_score = Q @ K.transpose(-2, -1) / self.scale

        if mask is not None:
            # add a head dimension so the mask broadcasts over all heads
            mask = mask.unsqueeze(1)
            attention_score = attention_score.masked_fill(mask == 0, float('-inf'))

        attention_weights = torch.softmax(attention_score, dim=-1)

        # dropout is an identity in eval mode or when dout_p == 0
        attention = self.dropout(attention_weights) @ V

        # (B, Sq, H * d_k) <- (B, H, Sq, d_k)
        attention = attention.transpose(1, 2).reshape(B, Sq, self.d_model)
        # (B, Sq, Dq)
        Q = self.fc_o(attention)

        return Q, attention_weights


In [26]:

class BiModalDecoderLayer(nn.Module):
    """One decoder layer: caption self-attention, attention over both encoded
    modalities, a bridge that fuses the two attended streams, and a feed-forward block.

    Args:
        d_model_A, d_model_V, d_model_C: feature sizes of the audio, visual and
            caption streams.
        d_model: internal feature size of every attention module.
        dout_p: dropout probability.
        H: number of attention heads.
        d_ff_C: size passed to the caption feed-forward block.
    """

    def __init__(self, d_model_A, d_model_V, d_model_C, d_model, dout_p, H, d_ff_C):
        super(BiModalDecoderLayer, self).__init__()
        # MHA's optional arguments are ordered (dout_p, d_model, n_heads); they are
        # passed by keyword so that H, dout_p and d_model cannot end up in the wrong slot.
        # self attention
        self.res_layer_self_att = ResidualConnection(d_model_C, dout_p)
        self.self_att = MHA(
            d_model_C, d_model_C, d_model_C, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        # encoder attention
        self.res_layer_enc_att_A = ResidualConnection(d_model_C, dout_p)
        self.res_layer_enc_att_V = ResidualConnection(d_model_C, dout_p)
        self.enc_att_A = MHA(
            d_model_C, d_model_A, d_model_A, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        self.enc_att_V = MHA(
            d_model_C, d_model_V, d_model_V, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        # bridge
        self.bridge = BridgeConnection(2*d_model_C, d_model_C, dout_p)
        # feed forward residual
        self.res_layer_ff = ResidualConnection(d_model_C, dout_p)
        self.feed_forward = PositionwiseFeedForward(d_model_C, d_ff_C, dout_p)

    def forward(self, x, masks):
        '''
        Inputs:
            x (C, memory): C: (B, Sc, Dc) 
                           memory: (Av: (B, Sa, Da), Va: (B, Sv, Dv))
            masks (V_mask: (B, 1, Sv); A_mask: (B, 1, Sa); C_mask (B, Sc, Sc))
        Outputs:
            x (C, memory): C: (B, Sc, Dc) 
                           memory: (Av: (B, Sa, Da), Va: (B, Sv, Dv))
        '''
        C, memory = x
        Av, Va = memory

        # Define sublayers
        # Each residual layer is called with (input, sublayer), so every sublayer is
        # wrapped as a one-argument callable with its other inputs bound here.
        # MHA returns (output, attention_weights); only the output is forwarded, so the
        # attention sublayers return a tensor just like the feed-forward sublayer.
        # NOTE(review): assumes ResidualConnection expects a tensor from the sublayer -- confirm.
        def sublayer_self_att(C): return self.self_att(C, C, C, masks['C_mask'])[0]
        def sublayer_enc_att_A(C): return self.enc_att_A(C, Av, Av, masks['A_mask'])[0]
        def sublayer_enc_att_V(C): return self.enc_att_V(C, Va, Va, masks['V_mask'])[0]
        sublayer_feed_forward = self.feed_forward

        # 1. Self Attention
        # (B, Sc, Dc)
        C = self.res_layer_self_att(C, sublayer_self_att)

        # 2. Encoder-Decoder Attention
        # (B, Sc, Dc) each
        Ca = self.res_layer_enc_att_A(C, sublayer_enc_att_A)
        Cv = self.res_layer_enc_att_V(C, sublayer_enc_att_V)
        # (B, Sc, 2*Dc)
        C = torch.cat([Ca, Cv], dim=-1)
        # bridge: (B, Sc, Dc) <- (B, Sc, 2*Dc)
        C = self.bridge(C)

        # 3. Feed-Forward
        # (B, Sc, Dc) <- (B, Sc, Dc)
        C = self.res_layer_ff(C, sublayer_feed_forward)

        # memory is passed through untouched so the next layer receives the same (C, memory) structure
        return C, memory



class BiModelDecoder(nn.Module):
    """A stack of ``N`` copies of ``BiModalDecoderLayer``; returns only the decoded captions."""

    def __init__(self, d_model_A, d_model_V, d_model_C, d_model, dout_p, H, d_ff_C, N):
        super().__init__()
        prototype = BiModalDecoderLayer(d_model_A, d_model_V, d_model_C, d_model, dout_p, H, d_ff_C)
        self.decoder = LayerStack(prototype, N)

    def forward(self, x, masks):
        '''
        Inputs:
            x (C, memory): C: (B, Sc, Dc)
                           memory: (Av: (B, Sa, Da), Va: (B, Sv, Dv))
            masks (V_mask: (B, 1, Sv); A_mask: (B, 1, Sa); C_mask (B, Sc, Sc))
        Outputs:
            C: (B, Sc, Dc) -- the memory part of the stack's output is dropped
        '''
        # the stack consumes and produces a (C, memory) pair
        decoded = self.decoder(x, masks)
        C, _memory = decoded
        return C


In [27]:

class BiModalEncoderLayer(nn.Module):
    """One encoder layer over two modalities M1 and M2.

    Each modality goes through self-attention, then attention over the other
    modality, then a feed-forward block; each of the three steps is wrapped in its
    own residual layer.

    Args:
        d_model_M1, d_model_M2: feature sizes of the two modalities.
        d_model: internal feature size of every attention module.
        dout_p: dropout probability.
        H: number of attention heads.
        d_ff_M1, d_ff_M2: sizes passed to the two feed-forward blocks.
    """

    def __init__(self, d_model_M1, d_model_M2, d_model, dout_p, H, d_ff_M1, d_ff_M2):
        super(BiModalEncoderLayer, self).__init__()
        # MHA's optional arguments are ordered (dout_p, d_model, n_heads); they are
        # passed by keyword so that H, dout_p and d_model cannot end up in the wrong slot.
        self.self_att_M1 = MHA(
            d_model_M1, d_model_M1, d_model_M1, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        self.self_att_M2 = MHA(
            d_model_M2, d_model_M2, d_model_M2, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        self.bi_modal_att_M1 = MHA(
            d_model_M1, d_model_M2, d_model_M2, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        self.bi_modal_att_M2 = MHA(
            d_model_M2, d_model_M1, d_model_M1, dout_p=dout_p, d_model=d_model, n_heads=H
        )
        self.feed_forward_M1 = PositionwiseFeedForward(d_model_M1, d_ff_M1, dout_p)
        self.feed_forward_M2 = PositionwiseFeedForward(d_model_M2, d_ff_M2, dout_p)
        # three residual layers per modality: self-att, bi-modal att, feed-forward
        self.res_layers_M1 = clone(ResidualConnection(d_model_M1, dout_p), 3)
        self.res_layers_M2 = clone(ResidualConnection(d_model_M2, dout_p), 3)

    def forward(self, x, masks):
        '''
        Inputs:
            x (M1, M2): (B, Sm, Dm)
            masks (M1, M2): (B, 1, Sm)
        Output:
            M1m2 (B, Sm1, Dm1), M2m1 (B, Sm2, Dm2),
        '''
        M1, M2 = x
        M1_mask, M2_mask = masks

        # Each residual layer is called with (input, sublayer), so every sublayer is
        # wrapped as a one-argument callable with its other inputs bound here.
        # MHA returns (output, attention_weights); only the output is forwarded, so the
        # attention sublayers return a tensor just like the feed-forward sublayers.
        # NOTE(review): assumes ResidualConnection expects a tensor from the sublayer -- confirm.
        def sublayer_self_att_M1(M1): return self.self_att_M1(M1, M1, M1, M1_mask)[0]
        def sublayer_self_att_M2(M2): return self.self_att_M2(M2, M2, M2, M2_mask)[0]
        # cross-modal attention: keys/values (and the mask) come from the other modality
        def sublayer_att_M1(M1): return self.bi_modal_att_M1(M1, M2, M2, M2_mask)[0]
        def sublayer_att_M2(M2): return self.bi_modal_att_M2(M2, M1, M1, M1_mask)[0]
        sublayer_ff_M1 = self.feed_forward_M1
        sublayer_ff_M2 = self.feed_forward_M2

        # 1. Self-Attention
        # both (B, Sm*, Dm*)
        M1 = self.res_layers_M1[0](M1, sublayer_self_att_M1)
        M2 = self.res_layers_M2[0](M2, sublayer_self_att_M2)

        # 2. Multimodal Attention (var names: M* is the target modality; m* is the source modality)
        # both calls see the self-attended M1 and M2 computed in step 1
        # (B, Sm1, Dm1)
        M1m2 = self.res_layers_M1[1](M1, sublayer_att_M1)
        # (B, Sm2, Dm2)
        M2m1 = self.res_layers_M2[1](M2, sublayer_att_M2)

        # 3. Feed-forward (var names: M* is the target modality; m* is the source modality)
        # (B, Sm1, Dm1)
        M1m2 = self.res_layers_M1[2](M1m2, sublayer_ff_M1)
        # (B, Sm2, Dm2)
        M2m1 = self.res_layers_M2[2](M2m1, sublayer_ff_M2)

        return M1m2, M2m1


class BiModalEncoder(nn.Module):
    """A stack of ``N`` copies of ``BiModalEncoderLayer`` applied to an (audio, visual) pair."""

    def __init__(self, d_model_A, d_model_V, d_model, dout_p, H, d_ff_A, d_ff_V, N):
        super().__init__()
        layer_AV = BiModalEncoderLayer(d_model_A, d_model_V, d_model, dout_p, H, d_ff_A, d_ff_V)
        self.encoder_AV = LayerStack(layer_AV, N)

    def forward(self, x, masks: dict):
        """Encode ``x = (A, V)``; ``masks`` must provide the keys ``'A_mask'`` and ``'V_mask'``."""
        A, V = x
        # the layers take the masks as a pair, in the same (audio, visual) order as the inputs
        mask_pair = (masks['A_mask'], masks['V_mask'])
        Av, Va = self.encoder_AV((A, V), mask_pair)
        return (Av, Va)


In [28]:


class Generator(nn.Module):
    """Linear projection from ``d_model`` to ``voc_size`` followed by a log-softmax."""

    def __init__(self, d_model, voc_size):
        super().__init__()
        self.linear = nn.Linear(d_model, voc_size)

    def forward(self, x):
        logits = self.linear(x)
        # normalised over the last dimension (the voc_size outputs of the linear layer)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

class BiModalTransformer(nn.Module):
    """Audio-visual encoder / caption decoder model.

    Audio and visual features are embedded and position-encoded, jointly encoded by
    ``BiModalEncoder``, and the embedded target tokens are decoded against both
    encoded streams by ``BiModelDecoder``; ``Generator`` turns the decoder output
    into log-probabilities over the target vocabulary.

    Args:
        d_aud, d_vid: feature sizes of the raw audio / visual inputs.
        d_model_audio, d_model_video, d_model_caps: feature sizes of the three
            streams inside the model.
        train_dataset: must expose ``trg_voc_size`` and ``train_vocab.vectors``.
        d_model, dout_p, H, N, d_ff_audio, d_ff_video, d_ff_caps, unfreeze_word_emb:
            remaining hyper-parameters. Each one left as ``None`` is looked up as a
            module-level (notebook) variable of the same name, which is where this
            constructor previously read them from implicitly.

    Raises:
        NameError: if a hyper-parameter is neither passed nor defined at module level.
    """

    def __init__(self, d_aud, d_vid, d_model_audio, d_model_video, d_model_caps, train_dataset,
                 d_model=None, dout_p=None, H=None, N=None,
                 d_ff_audio=None, d_ff_video=None, d_ff_caps=None, unfreeze_word_emb=None):
        super(BiModalTransformer, self).__init__()

        # explicit arguments win; anything omitted falls back to the notebook-level variable
        d_model = self._hparam('d_model', d_model)
        dout_p = self._hparam('dout_p', dout_p)
        H = self._hparam('H', H)
        N = self._hparam('N', N)
        d_ff_audio = self._hparam('d_ff_audio', d_ff_audio)
        d_ff_video = self._hparam('d_ff_video', d_ff_video)
        d_ff_caps = self._hparam('d_ff_caps', d_ff_caps)
        unfreeze_word_emb = self._hparam('unfreeze_word_emb', unfreeze_word_emb)

        self.emb_A = FeatureEmbedder(d_aud, d_model_audio)
        self.emb_V = FeatureEmbedder(d_vid, d_model_video)

        # the stream sizes come from this constructor's own parameters (not from a
        # separate config object), so the embedders and positional encoders always agree
        self.emb_C = VocabularyEmbedder(train_dataset.trg_voc_size, d_model_caps)

        self.pos_enc_A = PositionalEncoder(d_model_audio, dout_p)
        self.pos_enc_V = PositionalEncoder(d_model_video, dout_p)
        self.pos_enc_C = PositionalEncoder(d_model_caps, dout_p)

        self.encoder = BiModalEncoder(
            d_model_audio, d_model_video, d_model, dout_p, H,
            d_ff_audio, d_ff_video, N
        )

        self.decoder = BiModelDecoder(
            d_model_audio, d_model_video, d_model_caps, d_model, dout_p,
            H, d_ff_caps, N
        )

        self.generator = Generator(d_model_caps, train_dataset.trg_voc_size)

        print('initialization: xavier')
        # only parameters with more than one dimension are re-initialised
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # initialize embedding after, so it will replace the weights
        # of the prev. initialization
        self.emb_C.init_word_embeddings(train_dataset.train_vocab.vectors, unfreeze_word_emb)

    @staticmethod
    def _hparam(name, value):
        """Return ``value``, or the module-level variable called ``name`` when ``value`` is None."""
        if value is not None:
            return value
        module_vars = globals()
        if name not in module_vars:
            raise NameError(
                f"hyper-parameter '{name}' was not passed to BiModalTransformer "
                f"and no module-level variable '{name}' is defined"
            )
        return module_vars[name]

    def forward(self, src: dict, trg, masks: dict):
        """Return the generator output for target tokens ``trg`` given the source features.

        ``src`` must provide the keys ``'rgb'``, ``'flow'`` and ``'audio'``; the rgb and
        flow entries are combined with ``+`` to form the visual stream.
        ``masks`` is passed unchanged to both the encoder and the decoder.
        """
        V, A = src['rgb'] + src['flow'], src['audio']
        C = trg

        # (B, Sm, Dm) <- (B, Sm, Dm), m in [a, v]; 
        A = self.emb_A(A)
        V = self.emb_V(V)
        # (B, Sc, Dc) <- (B, Sc)
        C = self.emb_C(C)

        A = self.pos_enc_A(A)
        V = self.pos_enc_V(V)
        C = self.pos_enc_C(C)

        # notation: M1m2m2 (B, Sm1, Dm1), M1 is the target modality, m2 is the source modality
        Av, Va = self.encoder((A, V), masks)

        # (B, Sc, Dc)
        C = self.decoder((C, (Av, Va)), masks)

        # (B, Sc, Vc) <- (B, Sc, Dc) 
        C = self.generator(C)

        return C
