# Multimodal Transformer (MulT)

---

 **Title**: Multimodal Transformer for Unaligned Multimodal Language Sequences

 **Authors**: Tsai, Yao-Hung Hubert and Bai, Shaojie and Liang, Paul Pu and Kolter, J. Zico and Morency, Louis-Philippe and Salakhutdinov, Ruslan

[MulT GitHub repo](https://github.com/yaohungt/Multimodal-Transformer) | [Paper](https://arxiv.org/pdf/1906.00295.pdf)

# Gated Multimodal Units (GMU)

---

**Title**: Gated Multimodal Units for Information Fusion

 **Authors**: Arevalo, John; Montes-y-Gomez, Manuel; Solorio, Thamar and Gonzalez, Fabio A.

 [Paper](https://arxiv.org/pdf/1702.01992.pdf)

# Multimodal Transformer GMU (MulT-GMU)

---

 **Title**: Multimodal Weighted Fusion of Transformers for Movie Genre Classification

 **Authors**: Isaac Rodríguez-Bribiesca, A. Pastor López-Monroy and Manuel Montes-y-Gómez

[MulT-GMU GitHub repo](https://github.com/IsaacRodgz/multimodal-transformers-movies) | [Paper](https://aclanthology.org/2021.maiworkshop-1.1.pdf)

# Translating Multimodal Transformer GMU (TMulT)

**This work**: Diego Moreno, A. Pastor López-Monroy and Luis Carlos González Gurrola

In [1]:
%%capture
!pip install transformers

In [2]:
# Python tools
from sklearn.metrics import average_precision_score
from collections import Counter
from argparse import Namespace
from tqdm.notebook import tqdm
!pip install jsonlines
import numpy as np
import jsonlines
import functools
import pickle
import shutil
import json
import math
import os

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# HuggingFace
from transformers import BertTokenizer, BertModel

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting jsonlines
  Downloading jsonlines-3.0.0-py3-none-any.whl (8.5 kB)
Installing collected packages: jsonlines
Successfully installed jsonlines-3.0.0


## Vocabulary Class

In [3]:
class Vocab(object):
    def __init__(self, emptyInit=False):
        if emptyInit:
            self.stoi, self.itos, self.vocab_sz = {}, [], 0
        else:
            self.stoi = {
                w: i
                for i, w in enumerate(["[PAD]", "[UNK]", "[CLS]", "[SEP]"])
            }
            self.itos = [w for w in self.stoi]
            self.vocab_sz = len(self.itos)

    def add(self, words):
        cnt = len(self.itos)
        for w in words:
            if w in self.stoi:
                continue
            self.stoi[w] = cnt
            self.itos.append(w)
            cnt += 1
        self.vocab_sz = len(self.itos)
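
A minimal sketch of how the `stoi`/`itos` pair is typically used, assuming a trimmed-down stand-in for the class above. `TinyVocab` and `encode` are hypothetical names introduced only for this illustration.

```python
class TinyVocab:
    """Hypothetical stand-in for Vocab: same stoi/itos layout, special tokens first."""

    def __init__(self):
        self.itos = ["[PAD]", "[UNK]", "[CLS]", "[SEP]"]
        self.stoi = {w: i for i, w in enumerate(self.itos)}

    def add(self, words):
        # Append unseen words; existing entries keep their index.
        for w in words:
            if w not in self.stoi:
                self.stoi[w] = len(self.itos)
                self.itos.append(w)


def encode(vocab, tokens):
    # Map tokens to ids, falling back to [UNK] for out-of-vocabulary tokens.
    unk = vocab.stoi["[UNK]"]
    return [vocab.stoi.get(t, unk) for t in tokens]


vocab = TinyVocab()
vocab.add(["a", "heist", "movie", "heist"])
ids = encode(vocab, ["[CLS]", "a", "heist", "thriller"])
print(ids)  # [2, 4, 5, 1]
```

The `[UNK]` fallback is the same rule `JsonlDataset.__getitem__` applies when it converts tokens to indices.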

## Dataset

In [4]:
class JsonlDataset(Dataset):
    def __init__(self, data_path, tokenizer, vocab, args, data_dict=None):
        if data_dict is not None:
            self.data = data_dict
        else:
            # Close the file as soon as all records are parsed
            with open(data_path) as f:
                self.data = [json.loads(l) for l in f]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.args = args
        self.vocab = vocab
        self.n_classes = len(args.labels)
        self.text_start_token = ["[CLS]"]

        self.max_seq_len = args.max_seq_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sentence = segment = None
        
        # Process plot text
        sentence = (
            self.text_start_token
            + self.tokenizer(self.data[index]["synopsis"])[:(self.args.max_seq_len - 1)]
        )
        
        segment = torch.zeros(len(sentence))
        sentence = torch.LongTensor(
            [
                self.vocab.stoi[w] if w in self.vocab.stoi else self.vocab.stoi["[UNK]"]
                for w in sentence
            ]
        )
        
        # Process labels
        label = torch.zeros(self.n_classes)
        label[
            [self.args.labels.index(tgt) for tgt in self.data[index]["label"]]
        ] = 1

        # Load visual features
        image = None            
        if self.args.model in ["mmtrvpa", "tmmtrvpa", "mmtrv"]:
            with open(os.path.join(self.data_dir, 'video_frames', f'{self.data[index]["id"]}.pt'), 'rb') as file:
                image = torch.load(file).squeeze(0)

        # Load audio spectrograms
        audio = None
        if self.args.model in ["mmtrvpa", "tmmtrvpa", "mmtra"]:
            with open(os.path.join(self.data_dir, 'spectrograms', f'{self.data[index]["id"]}.pt'), 'rb') as file:
                audio = torch.load(file).squeeze(0)

        return sentence, segment, image, label, audio
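
The two non-tensor steps of `__getitem__` (prepending `[CLS]` with truncation, and multi-hot label encoding) can be sketched in plain Python. The helper names and the toy inputs below are assumptions made for illustration only.

```python
def multi_hot(sample_labels, all_labels):
    # One slot per known label; 1.0 where the sample carries that label.
    vec = [0.0] * len(all_labels)
    for name in sample_labels:
        vec[all_labels.index(name)] = 1.0
    return vec


def build_token_sequence(tokens, max_seq_len):
    # Prepend [CLS] and truncate so the total length never exceeds max_seq_len.
    return ["[CLS]"] + tokens[: max_seq_len - 1]


all_labels = ["Mystery", "Thriller", "Comedy", "Action"]
target = multi_hot(["Action", "Thriller"], all_labels)
sequence = build_token_sequence(
    ["a", "quiet", "town", "hides", "a", "secret"], max_seq_len=4
)
print(target)    # [0.0, 1.0, 0.0, 1.0]
print(sequence)  # ['[CLS]', 'a', 'quiet', 'town']
```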

## Collate function to process batch

In [5]:
def collate_fn(batch, args):
    lens = [len(row[0]) for row in batch]
    bsz, max_seq_len = len(batch), max(lens)

    mask_tensor = torch.zeros(bsz, max_seq_len).long()
    text_tensor = torch.zeros(bsz, max_seq_len).long()
    segment_tensor = torch.zeros(bsz, max_seq_len).long()

    video_tensor = None
    if batch[0][2] is not None:
        video_tensor = torch.stack([row[2] for row in batch])

    audio_tensor = None
    if batch[0][4] is not None:
        audio_lens = [row[4].shape[1] for row in batch]
        audio_min_len = min(audio_lens)
        audio_tensor = torch.stack([row[4][..., :audio_min_len] for row in batch])

    tgt_tensor = torch.stack([row[3] for row in batch])

    for i_batch, (input_row, length) in enumerate(zip(batch, lens)):
        tokens, segment = input_row[:2]
        text_tensor[i_batch, :length] = tokens
        segment_tensor[i_batch, :length] = segment
        mask_tensor[i_batch, :length] = 1

    return text_tensor, segment_tensor, mask_tensor, video_tensor, tgt_tensor, audio_tensor
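
A list-based sketch of what the collate step does: text is right-padded to the longest sequence in the batch and gets a 1/0 attention mask, while audio is cropped to the shortest clip. `pad_batch` and `crop_to_shortest` are hypothetical helpers, not part of the notebook.

```python
def pad_batch(sequences, pad_id=0):
    # Right-pad every sequence to the longest one and build a 1/0 attention mask.
    max_len = max(len(seq) for seq in sequences)
    padded, mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        padded.append(list(seq) + [pad_id] * n_pad)
        mask.append([1] * len(seq) + [0] * n_pad)
    return padded, mask


def crop_to_shortest(clips):
    # Mirror of the audio branch: cut every clip to the shortest length in the batch.
    min_len = min(len(clip) for clip in clips)
    return [clip[:min_len] for clip in clips]


padded, mask = pad_batch([[101, 7, 8], [101, 9]])
cropped = crop_to_shortest([[0.1, 0.2, 0.3], [0.4, 0.5]])
print(padded)   # [[101, 7, 8], [101, 9, 0]]
print(mask)     # [[1, 1, 1], [1, 1, 0]]
print(cropped)  # [[0.1, 0.2], [0.4, 0.5]]
```

Cropping rather than padding the audio means longer spectrograms lose their tail whenever `batch_sz` is greater than 1.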

## Load dataset partitions (train, validation, test)

In [7]:
import io

def load_vectors(fname):
    fin = io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore')
    n, d = map(int, fin.readline().split())
    data = {}
    for line in fin:
        tokens = line.rstrip().split(' ')
        data[tokens[0]] = list(map(float, tokens[1:]))  # materialise: a bare map() can only be iterated once
    return data
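
For reference, a `.vec` file is plain text: the first line holds the vocabulary size and the vector dimension, and every following line holds a word followed by its values. The sketch below parses a tiny in-memory sample; `parse_vec` and the sample contents are assumptions used only to show the layout.

```python
import io


def parse_vec(stream):
    # First line holds "<vocab size> <dimension>"; every other line is "word v1 ... vd".
    n_words, dim = map(int, stream.readline().split())
    vectors = {}
    for line in stream:
        parts = line.rstrip().split(" ")
        # Store the values as a list so they can be read more than once.
        vectors[parts[0]] = [float(v) for v in parts[1:]]
    return n_words, dim, vectors


sample = io.StringIO("2 3\nfilm 0.1 0.2 0.3\nplot -0.5 0.0 0.5\n")
n_words, dim, vectors = parse_vec(sample)
print(n_words, dim, vectors["plot"])  # 2 3 [-0.5, 0.0, 0.5]
```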

In [None]:
vec = load_vectors("/content/drive/MyDrive/_MASTER/Tesis/wiki-news-300d-1M-subword.vec")

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


Create a `Namespace` object for storing all model, training and data parameters

In [None]:
args = Namespace()
args.data_path = os.path.join(os.getcwd(), "moviescope")

Download spectrograms, video frame features and the data partition (labels and plots)

In [None]:
# Spectrograms
!gdown https://drive.google.com/uc?id=1wNr8qf2yMouuZ4FjpdQO4SvGaExUYLAz

# Video frames
!gdown https://drive.google.com/uc?id=1zKvC9U1q0n-9_nSDcnTIEeecHWD9GbUA

# Labels and plots
!gdown https://drive.google.com/uc?id=1999p8LLBR6imXLCJyIeEU8owuR-QHbcH

Downloading...
From: https://drive.google.com/uc?id=1wNr8qf2yMouuZ4FjpdQO4SvGaExUYLAz
To: /content/sub_spectrograms.pt
100% 2.29G/2.29G [00:24<00:00, 92.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1zKvC9U1q0n-9_nSDcnTIEeecHWD9GbUA
To: /content/sub_video_frames.pt
100% 1.64G/1.64G [00:35<00:00, 46.8MB/s]
Downloading...
From: https://drive.google.com/uc?id=1999p8LLBR6imXLCJyIeEU8owuR-QHbcH
To: /content/train.jsonl
100% 1.73M/1.73M [00:00<00:00, 55.4MB/s]


Unpack each observation in a separate file (all spectrograms and video frame features are downloaded as a single dictionary object)

In [None]:
!mkdir moviescope
!mkdir moviescope/spectrograms
!mkdir moviescope/video_frames
!mv -t moviescope train.jsonl

spectrograms = torch.load('sub_spectrograms.pt')
for id, tensor in spectrograms.items():
    torch.save(tensor, f'moviescope/spectrograms/{id}.pt')

video_frames = torch.load('sub_video_frames.pt')
for id, tensor in video_frames.items():
    torch.save(tensor, f'moviescope/video_frames/{id}.pt')

Generate train, validation and test splits

In [None]:
train_labels = [json.loads(line) for line in open('moviescope/train.jsonl')]

# Validation
with jsonlines.open(os.path.join(os.getcwd(), 'moviescope/dev.jsonl'), 'w') as writer:
    writer.write_all(train_labels[300:400])

# Test
with jsonlines.open(os.path.join(os.getcwd(), 'moviescope/test.jsonl'), 'w') as writer:
    writer.write_all(train_labels[400:])

# Train
with jsonlines.open(os.path.join(os.getcwd(), 'moviescope/train.jsonl'), 'w') as writer:
    writer.write_all(train_labels[:300])

Data structure:

* moviescope/
    * spectrograms/
    * video_frames/
    * train.jsonl
    * dev.jsonl
    * test.jsonl

In [None]:
del spectrograms, video_frames
!rm sub_spectrograms.pt
!rm sub_video_frames.pt

* From the train partition, load the distinct labels of the classification task (movie genres) into `args.labels`
* Count the frequency of each label into `args.label_freqs`

In [None]:
label_freqs = Counter()
data_labels = [json.loads(line)["label"] for line in open(os.path.join(args.data_path, "train.jsonl"))]

if isinstance(data_labels[0], list):
    for label_row in data_labels:
        label_freqs.update(label_row)
else:
    label_freqs.update(data_labels)

args.labels = list(label_freqs.keys())
args.label_freqs = label_freqs
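
A small sketch of the counting step on made-up rows (the three records below are invented for illustration). Because `Counter` preserves insertion order, the order in which genres are first seen fixes the class index used later for the multi-hot targets.

```python
from collections import Counter

# Toy records standing in for lines of train.jsonl
rows = [
    {"label": ["Drama", "Romance"]},
    {"label": ["Drama"]},
    {"label": ["Comedy", "Romance"]},
]

label_freqs = Counter()
for row in rows:
    label_freqs.update(row["label"])  # each genre in the list is counted once

labels = list(label_freqs.keys())  # first-seen order defines the class index
print(labels)                # ['Drama', 'Romance', 'Comedy']
print(label_freqs["Drama"])  # 2
```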

In [None]:
print(f"Movie genres (labels): {args.labels}")
print("Training labels distribution: ")
for label, count in args.label_freqs.items():
    print(f'    {label}: {count}')

Movie genres (labels): ['Mystery', 'Thriller', 'Comedy', 'Action', 'Crime', 'Drama', 'Family', 'Horror', 'Biography', 'Romance', 'Sci-Fi', 'Fantasy', 'Animation']
Training labels distribution: 
    Mystery: 25
    Thriller: 87
    Comedy: 110
    Action: 76
    Crime: 59
    Drama: 164
    Family: 21
    Horror: 33
    Biography: 15
    Romance: 73
    Sci-Fi: 30
    Fantasy: 33
    Animation: 6


## Create vocabulary and tokenizer

A pre-trained BERT is used for extracting text features, so `BertTokenizer` is used for preprocessing text



In [None]:
# Load BERT tokenizer for processing text with BERT model
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

vocab = Vocab()
vocab.stoi = tokenizer.vocab # mapping from subword to index
vocab.itos = tokenizer.ids_to_tokens # Reverse mapping from index to subword
vocab.vocab_sz = len(vocab.itos) # Vocabulary size

args.vocab = vocab
args.vocab_sz = vocab.vocab_sz
args.n_classes = len(args.labels)
plot_tokenizer = tokenizer.tokenize # tokenize method does not add any special tokens

## Create Datasets and Dataloaders and select model

In [None]:
# Data parameters
args.max_seq_len = 512
args.batch_sz = 1 # For Multimodal Transformer model, reduce batch_sz to 1
args.n_workers = 0
collate = functools.partial(collate_fn, args=args)

# Training
train = JsonlDataset(
    os.path.join(args.data_path, "train.jsonl"),
    plot_tokenizer,
    vocab,
    args,
)

train_loader = DataLoader(
    train,
    batch_size=args.batch_sz,
    shuffle=True,
    num_workers=args.n_workers,
    collate_fn=collate,
    drop_last=True,
)

args.train_data_len = len(train)


# Validation
dev = JsonlDataset(
    os.path.join(args.data_path, "dev.jsonl"),
    plot_tokenizer,
    vocab,
    args,
)

val_loader = DataLoader(
    dev,
    batch_size=args.batch_sz,
    shuffle=False,
    num_workers=args.n_workers,
    collate_fn=collate,
)

# Testing
test_set = JsonlDataset(
    os.path.join(args.data_path, "test.jsonl"),
    plot_tokenizer,
    vocab,
    args,
)

test_loader = DataLoader(
    test_set,
    batch_size=args.batch_sz,
    shuffle=False,
    num_workers=args.n_workers,
    collate_fn=collate,
)

## Models

### Text Encoder (BERT base pretrained)

In [None]:
class BertEncoder(nn.Module):
    def __init__(self, args):
        super(BertEncoder, self).__init__()
        self.args = args
        self.bert = BertModel.from_pretrained(args.bert_model)

    def forward(self, txt, mask, segment):
        encoded_layers, out = self.bert(
            input_ids=txt,
            token_type_ids=segment,
            attention_mask=mask,
            return_dict=False,
        )
        return encoded_layers

### Text classifier

In [None]:
class BertClf(nn.Module):
    def __init__(self, args):
        super(BertClf, self).__init__()
        self.args = args
        self.bert = BertModel.from_pretrained(args.bert_model)
        self.clf = nn.Linear(args.hidden_sz, args.n_classes)

    def forward(self, txt, mask, segment):
        _, x = self.bert(
            input_ids=txt,
            token_type_ids=segment,
            attention_mask=mask,
            return_dict=False,
        )
        return self.clf(x)

### Audio Encoder (CNN)

In [None]:
class AudioEncoder(nn.Module):
    def __init__(self, args):
        super(AudioEncoder, self).__init__()
        self.args = args
        
        conv_layers = []

        conv_layers.append(nn.Conv1d(96, 96, 128, stride=2))
        conv_layers.append(nn.Conv1d(96, 96, 128, stride=2))
        conv_layers.append(nn.AdaptiveAvgPool1d(200))
        self.conv_layers = nn.ModuleList(conv_layers)

    def forward(self, x):
        for layer in self.conv_layers:
            x = layer(x)
        return x

### GMU Fusion

In [None]:
class TextShifting3Layer(nn.Module):
    """ Layer inspired by 'Gated Multimodal Units for Information Fusion', Arevalo et al. (https://arxiv.org/abs/1702.01992) """
    def __init__(self, size_in1, size_in2, size_in3, size_out):
        super(TextShifting3Layer, self).__init__()
        self.size_in1, self.size_in2, self.size_in3, self.size_out = size_in1, size_in2, size_in3, size_out
        
        self.hidden1 = nn.Linear(size_in1, size_out, bias=False)
        self.hidden2 = nn.Linear(size_in2, size_out, bias=False)
        self.hidden3 = nn.Linear(size_in3, size_out, bias=False)
        self.x1_gate = nn.Linear(size_in1+size_in2+size_in3, size_out, bias=False)
        self.x2_gate = nn.Linear(size_in1+size_in2+size_in3, size_out, bias=False)
        self.x3_gate = nn.Linear(size_in1+size_in2+size_in3, size_out, bias=False)

    def forward(self, x1, x2, x3):
        h1 = torch.tanh(self.hidden1(x1))
        h2 = torch.tanh(self.hidden2(x2))
        h3 = torch.tanh(self.hidden3(x3))
        x_cat = torch.cat((x1, x2, x3), dim=1)
        z1 = torch.sigmoid(self.x1_gate(x_cat))
        z2 = torch.sigmoid(self.x2_gate(x_cat))
        z3 = torch.sigmoid(self.x3_gate(x_cat))

        return z1*h1 + z2*h2 + z3*h3, torch.cat((z1, z2, z3), dim=1)
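
Read off the `forward` method above, the three-input fusion can be written as follows, with $[\cdot;\cdot;\cdot]$ denoting concatenation along the feature dimension. The symbols $W_i$ and $W_{z_i}$ are notation introduced here for the bias-free `hidden*` and `x*_gate` linear layers.

$$
\begin{aligned}
h_i &= \tanh(W_i x_i), \qquad i \in \{1, 2, 3\} \\
z_i &= \sigma\left(W_{z_i} [x_1; x_2; x_3]\right) \\
h &= z_1 \odot h_1 + z_2 \odot h_2 + z_3 \odot h_3
\end{aligned}
$$

Each gate $z_i$ is an independent sigmoid, so the three gates are not constrained to sum to one.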

### Transformer Encoder Layer

In [None]:
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.
    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. This implementation hard-codes the
    tensor2tensor (pre-norm) approach by setting *self.normalize_before*
    to ``True``.
    Args:
        embed_dim: Embedding dimension
    """

    def __init__(self, embed_dim, num_heads=4, attn_dropout=0.1, relu_dropout=0.1, res_dropout=0.1,
                 attn_mask=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        
        self.self_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            dropout=attn_dropout
        )
        self.attn_mask = attn_mask

        self.relu_dropout = relu_dropout
        self.res_dropout = res_dropout
        self.normalize_before = True

        self.fc1 = nn.Linear(self.embed_dim, 4*self.embed_dim)   # Position-wise feed-forward: expand to 4x embed_dim
        self.fc2 = nn.Linear(4*self.embed_dim, self.embed_dim)
        self.layer_norms = nn.ModuleList([nn.LayerNorm(self.embed_dim) for _ in range(2)])

    def forward(self, x, x_k=None, x_v=None):
        """
        Args:
            x (Tensor): query input of shape `(seq_len, batch, embed_dim)`
            x_k (Tensor): optional key input of shape `(src_len, batch, embed_dim)`;
                when omitted together with x_v, the layer performs self-attention on x
            x_v (Tensor): optional value input of shape `(src_len, batch, embed_dim)`
        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        residual = x
        x = self.maybe_layer_norm(0, x, before=True)
        mask = buffered_future_mask(x, x_k) if self.attn_mask else None
        if x_k is None and x_v is None:
            x, _ = self.self_attn(query=x, key=x, value=x, attn_mask=mask)
        else:
            x_k = self.maybe_layer_norm(0, x_k, before=True)
            x_v = self.maybe_layer_norm(0, x_v, before=True) 
            x, _ = self.self_attn(query=x, key=x_k, value=x_v, attn_mask=mask)
        x = F.dropout(x, p=self.res_dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(0, x, after=True)

        residual = x
        x = self.maybe_layer_norm(1, x, before=True)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        x = F.dropout(x, p=self.res_dropout, training=self.training)
        x = residual + x
        x = self.maybe_layer_norm(1, x, after=True)
        return x

    def maybe_layer_norm(self, i, x, before=False, after=False):
        assert before ^ after
        if after ^ self.normalize_before:
            return self.layer_norms[i](x)
        else:
            return x

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)


def buffered_future_mask(tensor, tensor2=None):
    dim1 = dim2 = tensor.size(0)
    if tensor2 is not None:
        dim2 = tensor2.size(0)
    future_mask = torch.triu(fill_with_neg_inf(torch.ones(dim1, dim2)), 1+abs(dim2-dim1))
    if tensor.is_cuda:
        future_mask = future_mask.cuda()
    return future_mask[:dim1, :dim2]
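
To make the mask layout concrete, here is a pure-Python sketch that reproduces the pattern built by `buffered_future_mask` without PyTorch. `future_mask` is a hypothetical helper; it assumes the additive convention in which `0.0` keeps an attention score and `-inf` removes it before the softmax.

```python
NEG_INF = float("-inf")


def future_mask(q_len, k_len):
    # Entry (i, j) is blocked when j - i >= 1 + |k_len - q_len|, the same
    # diagonal offset that buffered_future_mask passes to torch.triu.
    offset = 1 + abs(k_len - q_len)
    return [
        [NEG_INF if j - i >= offset else 0.0 for j in range(k_len)]
        for i in range(q_len)
    ]


square = future_mask(3, 3)
for row in square:
    print(row)
# [0.0, -inf, -inf]
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]

rect = future_mask(2, 4)  # query shorter than key: the diagonal shifts right
print(rect[0])  # [0.0, 0.0, 0.0, -inf]
```

The mask is only applied when the layer is constructed with `attn_mask=True`.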

### Positional Embeddings

In [None]:
# Code adapted from the fairseq repo.

def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.
    Position numbers begin at padding_idx+1.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    device = tensor.get_device()
    buf_name = f'range_buf_{device}'
    if not hasattr(make_positions, buf_name):
        setattr(make_positions, buf_name, tensor.new())
    setattr(make_positions, buf_name, getattr(make_positions, buf_name).type_as(tensor))
    if getattr(make_positions, buf_name).numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=getattr(make_positions, buf_name))
    mask = tensor.ne(padding_idx)
    positions = getattr(make_positions, buf_name)[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    new_tensor = tensor.clone()
    return new_tensor.masked_scatter_(mask, positions[mask]).long()


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.
    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, embedding_dim, padding_idx=0, left_pad=0, init_size=128):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = dict()   # device --> actual weight; due to nn.DataParallel :-(
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        device = input.get_device()
        if device not in self.weights or max_pos > self.weights[device].size(0):
            # recompute/expand embeddings if needed
            self.weights[device] = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights[device] = self.weights[device].type_as(self._float_tensor)
        positions = make_positions(input, self.padding_idx, self.left_pad)
        return self.weights[device].index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
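
The table built by `get_embedding` can be reproduced with the standard library alone. This is a sketch under the assumption that `dim >= 4` (so that `half - 1` is non-zero); `sinusoidal_table` is a hypothetical name. Note that the sine and cosine halves are concatenated rather than interleaved, which is the difference from the paper that the docstring mentions.

```python
import math


def sinusoidal_table(num_positions, dim, padding_idx=None):
    # Row p is [sin(p*f_0), ..., sin(p*f_{h-1}), cos(p*f_0), ..., cos(p*f_{h-1})]
    # with h = dim // 2 and f_k = exp(-k * ln(10000) / (h - 1)).
    half = dim // 2
    step = math.log(10000) / (half - 1)
    freqs = [math.exp(-k * step) for k in range(half)]
    table = []
    for pos in range(num_positions):
        row = [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]
        if dim % 2 == 1:
            row.append(0.0)  # zero-pad odd dimensions
        table.append(row)
    if padding_idx is not None:
        table[padding_idx] = [0.0] * dim  # the padding position carries no signal
    return table


table = sinusoidal_table(num_positions=4, dim=8, padding_idx=0)
print(len(table), len(table[0]))  # 4 8
```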

### Transformer Encoder

In [None]:
class TransformerEncoder(nn.Module):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.
    Args:
        embed_dim (int): embedding dimension
        num_heads (int): number of heads
        layers (int): number of layers
        attn_dropout (float): dropout applied on the attention weights
        relu_dropout (float): dropout applied on the first layer of the residual block
        res_dropout (float): dropout applied on the residual block
        embed_dropout (float): dropout applied on the scaled, position-encoded input
        attn_mask (bool): whether to apply mask on the attention weights
    """

    def __init__(self, embed_dim, num_heads, layers, attn_dropout=0.0, relu_dropout=0.0, res_dropout=0.0,
                 embed_dropout=0.0, attn_mask=False):
        super().__init__()
        self.dropout = embed_dropout
        self.attn_dropout = attn_dropout
        self.embed_dim = embed_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = SinusoidalPositionalEmbedding(embed_dim)
        
        self.attn_mask = attn_mask

        self.layers = nn.ModuleList([])
        for layer in range(layers):
            new_layer = TransformerEncoderLayer(embed_dim,
                                                num_heads=num_heads,
                                                attn_dropout=attn_dropout,
                                                relu_dropout=relu_dropout,
                                                res_dropout=res_dropout,
                                                attn_mask=attn_mask)
            self.layers.append(new_layer)

        self.normalize = True
        if self.normalize:
            self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x_in, x_in_k = None, x_in_v = None):
        """
        Args:
            x_in (FloatTensor): embedded query input of shape `(tgt_len, batch, embed_dim)`
            x_in_k (FloatTensor): optional embedded key input of shape `(src_len, batch, embed_dim)`
            x_in_v (FloatTensor): optional embedded value input of shape `(src_len, batch, embed_dim)`
        Returns:
            Tensor: the layer-normalized output of the last encoder layer, of
            shape `(tgt_len, batch, embed_dim)`
        """
        # embed tokens and positions
        x = self.embed_scale * x_in
        if self.embed_positions is not None:
            x += self.embed_positions(x_in.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
        x = F.dropout(x, p=self.dropout, training=self.training)

        if x_in_k is not None and x_in_v is not None:
            # embed tokens and positions    
            x_k = self.embed_scale * x_in_k
            x_v = self.embed_scale * x_in_v
            if self.embed_positions is not None:
                x_k += self.embed_positions(x_in_k.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
                x_v += self.embed_positions(x_in_v.transpose(0, 1)[:, :, 0]).transpose(0, 1) # Add positional embedding
            x_k = F.dropout(x_k, p=self.dropout, training=self.training)
            x_v = F.dropout(x_v, p=self.dropout, training=self.training)
        
        # encoder layers
        intermediates = [x]
        for layer in self.layers:
            if x_in_k is not None and x_in_v is not None:
                x = layer(x, x_k, x_v)
            else:
                x = layer(x)
            intermediates.append(x)

        if self.normalize:
            x = self.layer_norm(x)

        return x
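
When `x_in_k` and `x_in_v` are given, the encoder performs crossmodal attention: queries come from the target modality, keys and values from the source modality, and the output keeps the target's length. The sketch below shows that idea as single-head scaled dot-product attention in plain Python. It is a simplification, not the notebook's layer: the real one uses `nn.MultiheadAttention` with learned projections and several heads, and the toy vectors are invented.

```python
import math


def softmax(scores):
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cross_attention(queries, keys, values):
    # queries: target-modality steps; keys/values: source-modality steps.
    dim = len(queries[0])
    output = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) for k in keys]
        weights = softmax(scores)
        # Weighted average of the source values for this target step.
        mixed = [
            sum(w * v[d] for w, v in zip(weights, values))
            for d in range(len(values[0]))
        ]
        output.append(mixed)
    return output


text = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 target steps
audio = [[1.0, 0.0], [0.0, 1.0]]             # 2 source steps
fused = cross_attention(text, audio, audio)
print(len(fused), len(fused[0]))  # 3 2
```

Each text step ends up as a mixture of audio steps, which is how unaligned sequences of different lengths can be combined.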

### Multimodal Transformer (Text + Video + Audio)

In [None]:
class MMTransformerGMUClf(nn.Module):
    def __init__(self, args):
        """
        Construct a MulT model for Text, Video frames and Audio spectrogram with GMU late fusion.
        """
        super(MMTransformerGMUClf, self).__init__()
        self.args = args
        self.orig_d_l, self.orig_d_v, self.orig_d_a = args.orig_d_l, args.orig_d_v, args.orig_d_a
        self.d_l, self.d_a, self.d_v = 768, 768, 768
        self.v_len = args.v_len
        self.l_len = args.l_len
        self.a_len = args.a_len
        self.vonly = args.vonly
        self.lonly = args.lonly
        self.aonly = args.aonly
        self.num_heads = args.num_heads
        self.layers = args.layers
        self.attn_dropout = args.attn_dropout
        self.attn_dropout_v = args.attn_dropout_v
        self.attn_dropout_a = args.attn_dropout_a
        self.relu_dropout = args.relu_dropout
        self.res_dropout = args.res_dropout
        self.out_dropout = args.out_dropout
        self.embed_dropout = args.embed_dropout
        self.attn_mask = args.attn_mask
        
        self.enc = BertEncoder(args)
        self.audio_enc = AudioEncoder(args)

        combined_dim = self.d_l + self.d_a + self.d_v
        
        self.partial_mode = self.lonly + self.aonly + self.vonly
        if self.partial_mode == 1:
            combined_dim = 2*self.d_l   # assuming d_l == d_a == d_v
        else:
            combined_dim = 2*(self.d_l + self.d_a + self.d_v)
        combined_dim = 768 # For GMU
        
        output_dim = args.n_classes        # This is actually not a hyperparameter :-)

        # 1. Temporal convolutional layers
        self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=1, padding=0, bias=False)
        self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=1, padding=0, bias=False)
        self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=1, padding=0, bias=False)

        # 2. Crossmodal Attentions
        if self.lonly:
            self.trans_l_with_a = self.get_network(self_type='la', seq_len=self.l_len, seq_len_kv=self.a_len)
            self.trans_l_with_v = self.get_network(self_type='lv', seq_len=self.l_len, seq_len_kv=self.v_len)
        if self.vonly:
            self.trans_v_with_l = self.get_network(self_type='vl', seq_len=self.v_len, seq_len_kv=self.l_len)
            self.trans_v_with_a = self.get_network(self_type='va', seq_len=self.v_len, seq_len_kv=self.a_len)
        if self.aonly:
            self.trans_a_with_l = self.get_network(self_type='al', seq_len=self.a_len, seq_len_kv=self.l_len)
            self.trans_a_with_v = self.get_network(self_type='av', seq_len=self.a_len, seq_len_kv=self.v_len)
        
        # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.)
        #    [e.g., self.trans_x_mem = nn.LSTM(self.d_x, self.d_x, 1)]
        self.trans_l_mem = self.get_network(self_type='l_mem', layers=3, seq_len=self.l_len)
        self.trans_v_mem = self.get_network(self_type='v_mem', layers=3, seq_len=self.v_len)
        self.trans_a_mem = self.get_network(self_type='a_mem', layers=3, seq_len=self.a_len)
       
        # Projection layers
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)
        
        # GMU layer for fusing text and image and audio information
        self.gmu = TextShifting3Layer(self.d_l*2, self.d_v*2, self.d_a*2, self.d_l)

    def get_network(self, self_type='l', layers=-1, seq_len=512, seq_len_kv=None):
        if self_type in ['l', 'al', 'vl']:
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type in ['a', 'la', 'va']:
            embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
        elif self_type in ['v', 'lv', 'av']:
            embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
        elif self_type == 'l_mem':
            embed_dim, attn_dropout = 2*self.d_l, self.attn_dropout
        elif self_type == 'a_mem':
            embed_dim, attn_dropout = 2*self.d_a, self.attn_dropout
        elif self_type == 'v_mem':
            embed_dim, attn_dropout = 2*self.d_v, self.attn_dropout
        else:
            raise ValueError("Unknown network type")
        
        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)
            
    def forward(self, txt, mask, segment, img, audio):
        """
        txt, mask and segment are [batch_size, seq_len] token-level tensors for BERT;
        img is [batch_size, n_frames, n_features]; audio is [batch_size, n_channels, n_steps]
        """
        x_l = self.enc(txt, mask, segment)
        x_l = F.dropout(x_l.transpose(1, 2), p=self.embed_dropout, training=self.training)
        x_v = img.transpose(1, 2)
        x_a = self.audio_enc(audio)

        # Project the textual/visual/audio features
        proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
        proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
        proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)
        proj_x_l = proj_x_l.permute(2, 0, 1)
        proj_x_a = proj_x_a.permute(2, 0, 1)
        proj_x_v = proj_x_v.permute(2, 0, 1)

        if self.lonly:
            # (V,A) --> L
            h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a)    # Dimension (L, N, d_l)
            h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v)    # Dimension (L, N, d_l)
            h_ls = torch.cat([h_l_with_as, h_l_with_vs], dim=2)
            h_ls = self.trans_l_mem(h_ls)
            if type(h_ls) == tuple:
                h_ls = h_ls[0]
            last_h_l = last_hs = h_ls[-1]   # Take the last output for prediction

        if self.aonly:
            # (L,V) --> A
            h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l)
            h_a_with_vs = self.trans_a_with_v(proj_x_a, proj_x_v, proj_x_v)
            h_as = torch.cat([h_a_with_ls, h_a_with_vs], dim=2)
            h_as = self.trans_a_mem(h_as)
            if type(h_as) == tuple:
                h_as = h_as[0]
            last_h_a = last_hs = h_as[-1]

        if self.vonly:
            # (L,A) --> V
            h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l)
            h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a)
            h_vs = torch.cat([h_v_with_ls, h_v_with_as], dim=2)
            h_vs = self.trans_v_mem(h_vs)
            if type(h_vs) == tuple:
                h_vs = h_vs[0]
            last_h_v = last_hs = h_vs[-1]
        
        last_hs, z = self.gmu(last_h_l, last_h_v, last_h_a)
        
        # A residual block
        last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training))
        last_hs_proj += last_hs
        
        output = self.out_layer(last_hs_proj)
        return output

### Translating Multimodal Transformer

#### GMU Modification

In [None]:
class GatedMultimodalLayerFeatures(nn.Module):
    """ Gated Multimodal Layer based on 'Gated Multimodal Units for Information Fusion', Arevalo et al. (https://arxiv.org/abs/1702.01992) """
    def __init__(self, size_in1, size_in2, size_out):
        super(GatedMultimodalLayerFeatures, self).__init__()
        self.size_in1, self.size_in2, self.size_out = size_in1, size_in2, size_out
        
        self.hidden1 = nn.Linear(size_in1, size_out, bias=False)
        self.hidden2 = nn.Linear(size_in2, size_out, bias=False)
        self.x_gate = nn.Linear(size_in1+size_in2, size_out, bias=False)

    def forward(self, x1, x2):
        x1_dim, x2_dim = int(x1.size(0)), int(x2.size(0))
        '''
        if x1_dim != x2_dim:
            max_dim = max(x1_dim, x2_dim)
            y  = torch.zeros((max_dim-x1.size(0), x1.size(1), x1.size(2))).cuda()
            x1 = torch.cat((x1, y), 0)
            y  = torch.zeros((max_dim-x2.size(0), x2.size(1), x2.size(2))).cuda()
            x2 = torch.cat((x2, y), 0)
        '''
        h1 = torch.tanh(self.hidden1(x1))
        h2 = torch.tanh(self.hidden2(x2))
        x_cat = torch.cat((x1, x2), dim=2)
        z = torch.sigmoid(self.x_gate(x_cat))
        
        return z*h1*x1 + (1-z)*h2*x2, torch.cat((z, (1-z)), dim=2)
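
The gating in the layer above can be checked on small tensors. The following is a minimal sketch, not the notebook's class: `TinyGatedFusion` is a hypothetical stand-in that mirrors its forward pass, and the toy sizes (sequence length 5, batch 2, feature size 8) are assumptions chosen for illustration. Because each activation is multiplied by its own input (`z*h1*x1`), both inputs and the output must share the same feature size and sequence length, which differs from the original GMU formulation `z*h1 + (1-z)*h2`.

```python
import torch
import torch.nn as nn


class TinyGatedFusion(nn.Module):
    """Hypothetical stand-in mirroring the two-input gated layer above."""

    def __init__(self, size):
        super().__init__()
        self.hidden1 = nn.Linear(size, size, bias=False)
        self.hidden2 = nn.Linear(size, size, bias=False)
        self.x_gate = nn.Linear(2 * size, size, bias=False)

    def forward(self, x1, x2):
        h1 = torch.tanh(self.hidden1(x1))
        h2 = torch.tanh(self.hidden2(x2))
        # The gate z is computed from the concatenated raw inputs
        z = torch.sigmoid(self.x_gate(torch.cat((x1, x2), dim=2)))
        # Each modality is re-weighted by its own activation before mixing
        fused = z * h1 * x1 + (1 - z) * h2 * x2
        return fused, torch.cat((z, 1 - z), dim=2)


torch.manual_seed(0)
seq_len, batch, size = 5, 2, 8   # toy sizes, not the notebook's 768 features
x1 = torch.randn(seq_len, batch, size)
x2 = torch.randn(seq_len, batch, size)
fused, gates = TinyGatedFusion(size)(x1, x2)
print(fused.shape, gates.shape)
```

The second return value stacks `z` and `1 - z` along the feature axis, so its last dimension is twice the feature size and the two halves sum to one element-wise.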

In [None]:
class TranslatingMMTransformerGMUClf(nn.Module):
    def __init__(self, args):
        """
        Construct a MulT model for Text, Video frames and Audio spectrogram with GMU late fusion.
        """
        super(TranslatingMMTransformerGMUClf, self).__init__()
        self.args = args
        self.orig_d_l, self.orig_d_v, self.orig_d_a = args.orig_d_l, args.orig_d_v, args.orig_d_a
        self.d_l, self.d_a, self.d_v = 768, 768, 768
        self.v_len = args.v_len
        self.l_len = args.l_len
        self.a_len = args.a_len
        self.vonly = args.vonly
        self.lonly = args.lonly
        self.aonly = args.aonly
        self.num_heads = args.num_heads
        self.layers = args.layers
        self.attn_dropout = args.attn_dropout
        self.attn_dropout_v = args.attn_dropout_v
        self.attn_dropout_a = args.attn_dropout_a
        self.relu_dropout = args.relu_dropout
        self.res_dropout = args.res_dropout
        self.out_dropout = args.out_dropout
        self.embed_dropout = args.embed_dropout
        self.attn_mask = args.attn_mask
        
        self.enc = BertEncoder(args)
        self.audio_enc = AudioEncoder(args)

        combined_dim = self.d_l + self.d_a + self.d_v
        
        self.partial_mode = self.lonly + self.aonly + self.vonly
        if self.partial_mode == 1:
            combined_dim = 2*self.d_l   # assuming d_l == d_a == d_v
        else:
            combined_dim = 2*(self.d_l + self.d_a + self.d_v)
        combined_dim = 768 # For GMU
        
        output_dim = args.n_classes        # This is actually not a hyperparameter :-)

        # GMU blocks
        #------ GMU Middle
        self.gmu_l_m = GatedMultimodalLayerFeatures(self.d_v, self.d_a, self.d_l)
        self.gmu_v_m = GatedMultimodalLayerFeatures(self.d_l, self.d_a, self.d_v)
        self.gmu_a_m = GatedMultimodalLayerFeatures(self.d_l, self.d_v, self.d_a)
        #------ GMU Top
        self.gmu_l = GatedMultimodalLayerFeatures(self.d_l, self.d_l, self.d_l)
        self.gmu_v = GatedMultimodalLayerFeatures(self.d_v, self.d_v, self.d_l)
        self.gmu_a = GatedMultimodalLayerFeatures(self.d_a, self.d_a, self.d_l)

        # 1. Temporal convolutional layers
        self.proj_l = nn.Conv1d(self.orig_d_l, self.d_l, kernel_size=1, padding=0, bias=False)
        self.proj_v = nn.Conv1d(self.orig_d_v, self.d_v, kernel_size=1, padding=0, bias=False)
        self.proj_a = nn.Conv1d(self.orig_d_a, self.d_a, kernel_size=1, padding=0, bias=False)

        # 2. Crossmodal Attentions
        if self.lonly:
            self.trans_l_with_a = self.get_network(self_type='la')
            self.trans_l_with_v = self.get_network(self_type='lv')
            self.trans_l_with_v2a = self.get_network(self_type='lv2a')
            self.trans_l_with_a2v = self.get_network(self_type='la2v')
        if self.vonly:
            self.trans_v_with_l = self.get_network(self_type='vl')
            self.trans_v_with_a = self.get_network(self_type='va')
            self.trans_v_with_l2a = self.get_network(self_type='vl2a')
            self.trans_v_with_a2l = self.get_network(self_type='va2l')
        if self.aonly:
            self.trans_a_with_l = self.get_network(self_type='al')
            self.trans_a_with_v = self.get_network(self_type='av')
            self.trans_a_with_v2l = self.get_network(self_type='av2l')
            self.trans_a_with_l2v = self.get_network(self_type='al2v')
        
        # 3. Self Attentions (Could be replaced by LSTMs, GRUs, etc.)
        #    [e.g., self.trans_x_mem = nn.LSTM(self.d_x, self.d_x, 1)]
       
        # Projection layers
        self.proj1 = nn.Linear(combined_dim, combined_dim)
        self.proj2 = nn.Linear(combined_dim, combined_dim)
        self.out_layer = nn.Linear(combined_dim, output_dim)
        
        # GMU layer for fusing text and image and audio information
        self.gmu = TextShifting3Layer(self.d_l, self.d_v, self.d_a, self.d_l)

        self.transfm_a2l = nn.Linear(200, 512)
        self.transfm_v2l = nn.Linear(200, 512)
        self.transfm_l2a = nn.Linear(512, 200)
        self.transfm_l2v = nn.Linear(512, 200)
        self.transfm_v2a = nn.Linear(200, 200)
        self.transfm_a2v = nn.Linear(200, 200)

    def get_network(self, self_type='l', layers=-1, seq_len=512, seq_len_kv=None):
        if self_type in ['l', 'al', 'vl', 'av2l', 'va2l']:
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type in ['a', 'la', 'va', 'lv2a', 'vl2a']:
            embed_dim, attn_dropout = self.d_a, self.attn_dropout_a
        elif self_type in ['v', 'lv', 'av', 'la2v', 'al2v']:
            embed_dim, attn_dropout = self.d_v, self.attn_dropout_v
        elif self_type == 'l_mem':
            embed_dim, attn_dropout = self.d_l, self.attn_dropout
        elif self_type == 'a_mem':
            embed_dim, attn_dropout = self.d_a, self.attn_dropout
        elif self_type == 'v_mem':
            embed_dim, attn_dropout = self.d_v, self.attn_dropout
        else:
            raise ValueError("Unknown network type")
        
        return TransformerEncoder(embed_dim=embed_dim,
                                  num_heads=self.num_heads,
                                  layers=max(self.layers, layers),
                                  attn_dropout=attn_dropout,
                                  relu_dropout=self.relu_dropout,
                                  res_dropout=self.res_dropout,
                                  embed_dropout=self.embed_dropout,
                                  attn_mask=self.attn_mask)
        
    def transfm_2dim(self, x_t, dim, out_dim):
        # Zero-pad x_t along `dim` up to `out_dim` entries (no-op when it already has that size)
        if x_t.size(dim) != out_dim:
            pad_shape = list(x_t.size())
            pad_shape[dim] = out_dim - x_t.size(dim)
            # Allocate the padding on x_t's device and dtype so the model also runs without a GPU
            y = torch.zeros(pad_shape, dtype=x_t.dtype, device=x_t.device)
            x_t = torch.cat((x_t, y), dim)
        
        return x_t
            
    def forward(self, txt, mask, segment, img, audio):
        """
        text, audio, and vision should have dimension [batch_size, seq_len, n_features]
        """
        x_l = self.enc(txt, mask, segment)
        x_l = F.dropout(x_l.transpose(1, 2), p=self.embed_dropout, training=self.training)
        x_v = img.transpose(1, 2)
        x_a = self.audio_enc(audio)

        # Project the textual/visual/audio features
        proj_x_l = x_l if self.orig_d_l == self.d_l else self.proj_l(x_l)
        proj_x_a = x_a if self.orig_d_a == self.d_a else self.proj_a(x_a)
        proj_x_v = x_v if self.orig_d_v == self.d_v else self.proj_v(x_v)
        proj_x_l = proj_x_l.permute(2, 0, 1)
        proj_x_a = proj_x_a.permute(2, 0, 1)
        proj_x_v = proj_x_v.permute(2, 0, 1)

        if proj_x_l.size(0) != 512:
            proj_x_l = self.transfm_2dim(proj_x_l, 0, 512)
        if proj_x_a.size(0) != 200:
            proj_x_a = self.transfm_2dim(proj_x_a, 0, 200)
        if proj_x_v.size(0) != 200:
            proj_x_v = self.transfm_2dim(proj_x_v, 0, 200)

        if self.lonly:
            # (V,A) --> L
            h_v_with_as = self.trans_v_with_a(proj_x_v, proj_x_a, proj_x_a)    # Dimension (L, N, d_v)
            h_a_with_vs = self.trans_a_with_v(proj_x_a, proj_x_v, proj_x_v)    # Dimension (L, N, d_a)
            
            # Feature Dimension Transformation
            t_h_a_with_vs = self.transfm_a2l(h_a_with_vs.permute(2, 1, 0)).permute(2, 1, 0)
            t_h_v_with_as = self.transfm_v2l(h_v_with_as.permute(2, 1, 0)).permute(2, 1, 0)
            # GMU Middle --------
            h_l_gmu, z1_l = self.gmu_l_m(t_h_v_with_as, t_h_a_with_vs)
            #h_ls = self.trans_l_with_v2a(proj_x_l, h_l_gmu, h_l_gmu)
            # GMU Top ---------
            h_l_with_v2a = self.trans_l_with_v2a(proj_x_l, h_a_with_vs, h_a_with_vs)    # Dimension (L, N, d_l)
            h_l_with_a2v = self.trans_l_with_a2v(proj_x_l, h_v_with_as, h_v_with_as)    # Dimension (L, N, d_l)
            # Residual connection
            h_l_with_v2a += t_h_a_with_vs
            h_l_with_a2v += t_h_v_with_as
            # Option 1 ---------
            h_ls_gmu, z2_l = self.gmu_l(h_l_with_a2v, h_l_with_v2a)
            h_ls = h_ls_gmu + h_l_gmu   

            last_h_l = last_hs = h_ls[-1] # Take the last output for prediction

        if self.aonly:
            # (L,V) --> A
            h_v_with_ls = self.trans_v_with_l(proj_x_v, proj_x_l, proj_x_l)
            h_l_with_vs = self.trans_l_with_v(proj_x_l, proj_x_v, proj_x_v)
            
            # Feature Dimension Transformation
            t_h_l_with_vs = self.transfm_l2a(h_l_with_vs.permute(2, 1, 0)).permute(2, 1, 0)
            # GMU Middle --------
            h_a_gmu, z1_a = self.gmu_a_m(t_h_l_with_vs, h_v_with_ls)
            #h_as = self.trans_a_with_l2v(proj_x_a, h_a_gmu, h_a_gmu)
            # GMU Top --------
            h_a_with_v2l = self.trans_a_with_v2l(proj_x_a, h_l_with_vs, h_l_with_vs)
            h_a_with_l2v = self.trans_a_with_l2v(proj_x_a, h_v_with_ls, h_v_with_ls)
            # Residual connection
            h_a_with_v2l += t_h_l_with_vs
            h_a_with_l2v += h_v_with_ls
            # Option 1 ---------
            h_as_gmu, z2_a = self.gmu_a(h_a_with_v2l, h_a_with_l2v)
            h_as = h_as_gmu + h_a_gmu

            last_h_a = last_hs = h_as[-1]

        if self.vonly:
            # (L,A) --> V
            h_a_with_ls = self.trans_a_with_l(proj_x_a, proj_x_l, proj_x_l)
            h_l_with_as = self.trans_l_with_a(proj_x_l, proj_x_a, proj_x_a)
            
            # Feature Dimension Transformation
            t_h_l_with_as = self.transfm_l2v(h_l_with_as.permute(2, 1, 0)).permute(2, 1, 0)
            # GMU Middle --------
            h_v_gmu, z1_v = self.gmu_v_m(t_h_l_with_as, h_a_with_ls)
            #h_vs = self.trans_v_with_l2a(proj_x_v, h_v_gmu, h_v_gmu)
            # GMU Top --------
            h_v_with_a2l = self.trans_v_with_a2l(proj_x_v, h_l_with_as, h_l_with_as)
            h_v_with_l2a = self.trans_v_with_l2a(proj_x_v, h_a_with_ls, h_a_with_ls)
            # Residual connection
            h_v_with_a2l += t_h_l_with_as
            h_v_with_l2a += h_a_with_ls
            # Option 1 ---------
            h_vs_gmu, z2_v = self.gmu_v(h_v_with_a2l, h_v_with_l2a)
            h_vs = h_vs_gmu + h_v_gmu

            last_h_v = last_hs = h_vs[-1]
        
        last_hs, z = self.gmu(last_h_l, last_h_v, last_h_a)
        
        # A residual block
        last_hs_proj = self.proj2(F.dropout(F.relu(self.proj1(last_hs)), p=self.out_dropout, training=self.training))
        last_hs_proj += last_hs
        
        output = self.out_layer(last_hs_proj)
        return output
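
Two shape manipulations in the forward pass above are easy to misread: zero-padding a sequence to a fixed length, and the `permute` / `nn.Linear` / `permute` pattern that maps one sequence length onto another (for example 200 audio steps onto 512 text steps). The sketch below illustrates both on toy tensors. `pad_seq_len` is a hypothetical helper, and the feature size of 16 and the input length of 150 are assumptions made for the example.

```python
import torch
import torch.nn as nn


def pad_seq_len(x, target_len):
    """Zero-pad a (seq_len, batch, features) tensor along dim 0 up to target_len."""
    missing = target_len - x.size(0)
    if missing <= 0:
        return x
    pad = torch.zeros(missing, x.size(1), x.size(2), dtype=x.dtype, device=x.device)
    return torch.cat((x, pad), dim=0)


torch.manual_seed(0)
batch, feat = 2, 16                      # toy sizes; the notebook uses 768 features
audio = torch.randn(150, batch, feat)    # shorter than the fixed length of 200
audio = pad_seq_len(audio, 200)

# Translate the sequence axis from 200 to 512 steps:
# move it to the last position, apply the linear map, then move it back.
to_text_len = nn.Linear(200, 512)
translated = to_text_len(audio.permute(2, 1, 0)).permute(2, 1, 0)
print(audio.shape, translated.shape)
```

The linear layer here mixes information across time steps rather than across features, which is what allows a residual sum between tensors that started with different sequence lengths.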

## Models dictionary

In [None]:
MODELS = {
    "bert": BertClf,
    "tmmtrvpa": TranslatingMMTransformerGMUClf,
    "mmtrvpa": MMTransformerGMUClf
}

## Training helper functions

In [None]:
def save_checkpoint(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    filename = os.path.join(checkpoint_path, filename)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, os.path.join(checkpoint_path, "model_best.pt"))

In [None]:
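def load_checkpoint(model, path):
    # Load onto the CPU first; load_state_dict then copies the weights to the model's own device
    best_checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(best_checkpoint["state_dict"])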
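
The save/load pair above can be exercised end to end with a throwaway model. This is a minimal sketch under stated assumptions: `save_ckpt` is a hypothetical copy of the saving logic, the model is a small `nn.Linear`, and a temporary directory stands in for `args.savedir`.

```python
import os
import shutil
import tempfile

import torch
import torch.nn as nn


def save_ckpt(state, is_best, checkpoint_path, filename="checkpoint.pt"):
    """Write the latest checkpoint and, if it is the best so far, copy it to model_best.pt."""
    path = os.path.join(checkpoint_path, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(checkpoint_path, "model_best.pt"))


model = nn.Linear(4, 2)
restored = nn.Linear(4, 2)   # freshly initialised, so its weights differ before loading

with tempfile.TemporaryDirectory() as tmp:
    save_ckpt({"epoch": 1, "state_dict": model.state_dict()}, True, tmp)
    saved_files = sorted(os.listdir(tmp))
    ckpt = torch.load(os.path.join(tmp, "model_best.pt"), map_location="cpu")
    restored.load_state_dict(ckpt["state_dict"])

weights_match = torch.equal(model.weight, restored.weight)
print(saved_files, weights_match)
```

Keeping `checkpoint.pt` as the latest state and `model_best.pt` as a copy of the best one means training can stop on any epoch while evaluation always reads the best weights.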

In [None]:
def model_eval(data, model, args, criterion):
    with torch.no_grad():
        losses, preds, tgts = [], [], []
        raw_preds = []
        for batch in data:

            loss, out, tgt = model_forward(model, args, criterion, batch)
            losses.append(loss.item())

            # Predictions: sigmoid probabilities and their 0.5-thresholded labels
            probs = torch.sigmoid(out).cpu().numpy()
            raw_preds.append(probs)
            preds.append(probs > 0.5)
            tgts.append(tgt.cpu().numpy())

    # Get metrics
    metrics = {"loss": np.mean(losses)}
    #print(tgts[0].shape, len(tgts))
    tgts = np.vstack(tgts)
    preds = np.vstack(preds)
    raw_preds = np.vstack(raw_preds)
    metrics["auc_pr_macro"] = average_precision_score(tgts, raw_preds, average="macro")
    metrics["auc_pr_micro"] = average_precision_score(tgts, raw_preds, average="micro")

    return metrics
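
The two metrics returned above differ in how they aggregate over labels: macro averaging computes average precision per label and then takes the mean, while micro averaging pools every (sample, label) decision into a single ranking. A small sketch on made-up data shows the difference. The targets and scores below are illustrative values, not results from this notebook.

```python
import numpy as np
from sklearn.metrics import average_precision_score

# Toy multilabel problem: 4 samples, 3 labels (made-up values).
targets = np.array([[1, 0, 0],
                    [0, 1, 0],
                    [1, 1, 0],
                    [0, 0, 1]])
# Labels 0 and 1 are ranked perfectly; the single positive of label 2 is ranked second.
scores = np.array([[0.9, 0.2, 0.10],
                   [0.3, 0.8, 0.20],
                   [0.7, 0.6, 0.30],
                   [0.1, 0.4, 0.25]])

macro = average_precision_score(targets, scores, average="macro")
micro = average_precision_score(targets, scores, average="micro")
print(f"macro={macro:.4f} micro={micro:.4f}")
```

In this example the rare, poorly ranked label pulls the macro score down more than the micro score, which is the usual pattern when infrequent labels are harder to predict.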

In [None]:
def model_forward(model, args, criterion, batch):

    txt, segment, mask, img, tgt, audio = batch

    if args.model == "bert":
        if args.use_gpu:
            txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda()
        out = model(txt, mask, segment)

    else:
        assert args.model == "mmtrvpa" or args.model == "tmmtrvpa"
        if args.use_gpu:
            txt, mask, segment = txt.cuda(), mask.cuda(), segment.cuda()
            img, audio = img.cuda(), audio.cuda()
        out = model(txt, mask, segment, img, audio)
    
    if args.use_gpu:
        tgt = tgt.cuda()
    loss = criterion(out, tgt)
    
    return loss, out, tgt

In [None]:
def train(args):

    # Model
    args.use_gpu = torch.cuda.is_available()
    model = MODELS[args.model](args)
    if args.use_gpu:
        model.cuda()

    # Loss function
    freqs = [args.label_freqs[l] for l in args.labels]
    label_weights = (torch.FloatTensor(freqs) / args.train_data_len) ** -1
    criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights.cuda() if args.use_gpu else label_weights)

    # Optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    # Scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, "max", patience=args.lr_patience, verbose=True, factor=args.lr_factor
        )
    
    # Training
    args.savedir = os.path.join('model_save', args.model)
    os.makedirs(args.savedir, exist_ok=True)
    torch.save(args, os.path.join(args.savedir, "args.pt"))
    start_epoch, global_step, n_no_improve, best_metric = 0, 0, 0, -np.inf

    for i_epoch in range(start_epoch, args.max_epochs):
        train_losses = []
        model.train()
        optimizer.zero_grad()

        for batch in tqdm(train_loader, total=len(train_loader)):

            loss, out, tgt = model_forward(model, args, criterion, batch)

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            train_losses.append(loss.item())

            # Optimizer step
            global_step += 1
            if global_step % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        model.eval()
        metrics = model_eval(val_loader, model, args, criterion)
        print("Epoch {} | Train Loss: {:.4f}".format(i_epoch+1, np.mean(train_losses)))
        print("Val auc_pr_micro: {:.4f} | Val auc_pr_macro: {:.4f}".format(metrics["auc_pr_micro"], metrics["auc_pr_macro"]))

        tuning_metric = metrics["auc_pr_micro"]

        scheduler.step(tuning_metric)
        is_improvement = tuning_metric > best_metric
        if is_improvement:
            best_metric = tuning_metric
            n_no_improve = 0
        else:
            n_no_improve += 1

        save_checkpoint(
            {
                "epoch": i_epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "n_no_improve": n_no_improve,
                "best_metric": best_metric,
            },
            is_improvement,
            args.savedir,
        )

        if n_no_improve >= args.patience:
            print("\nNo improvement. Breaking out of loop.")
            break

    torch.cuda.empty_cache()
    load_checkpoint(model, os.path.join(args.savedir, "model_best.pt"))
    model.eval()

    test_metrics = model_eval(test_loader, model, args, criterion)
    print("-"*55)
    print("\nTest auc_pr_micro: {:.4f} | Test auc_pr_macro: {:.4f}".format(test_metrics["auc_pr_micro"], test_metrics["auc_pr_macro"]))
    print("-"*55)
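
The gradient accumulation pattern used in `train` can be isolated in a few lines. The following sketch assumes a toy linear model and eight random mini-batches in place of `train_loader`; only the accumulation logic is meant to match.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Linear(10, 3)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
accumulation_steps = 4   # the value assigned to args.gradient_accumulation_steps in this notebook

# Eight toy mini-batches of random data stand in for train_loader.
batches = [(torch.randn(2, 10), torch.randint(0, 2, (2, 3)).float()) for _ in range(8)]

optimizer_steps = 0
optimizer.zero_grad()
for step, (x, y) in enumerate(batches, start=1):
    # Scale the loss so the summed gradients equal the average over the accumulated batches
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()   # gradients add up in .grad until zero_grad() is called
    if step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        optimizer_steps += 1

print(optimizer_steps)
```

Note that `train` appends the already scaled loss to `train_losses`, so with four accumulation steps the printed Train Loss is a quarter of the mean per-batch loss.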

In [None]:
def test(args):

    #_, _, test_loader = get_data_loaders(args)
    
    #if args.trained_model_dir: # load in fine-tuned (with cloze-style LM objective) model
     #   args.previous_state_dict_dir = os.path.join(args.trained_model_dir, WEIGHTS_NAME)

    args.use_gpu = torch.cuda.is_available()
    model = MODELS[args.model](args)
    if args.use_gpu:
        model.cuda()

    #Criterion
    freqs = [args.label_freqs[l] for l in args.labels]
    label_weights = (torch.FloatTensor(freqs) / args.train_data_len) ** -1
    criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights.cuda() if args.use_gpu else label_weights)

    # Evaluation only: locate the checkpoint directory written by train()
    args.savedir = os.path.join('model_save', args.model)

    load_checkpoint(model, os.path.join(args.savedir, "model_best.pt"))
    model.eval()

    test_metrics = model_eval(test_loader, model, args, criterion)
    print("-"*55)
    print("\nTest auc_pr_micro: {:.4f} | Test auc_pr_macro: {:.4f}".format(test_metrics["auc_pr_micro"], test_metrics["auc_pr_macro"]))
    print("-"*55)
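
Both `train` and `test` build the loss with per-label positive weights equal to the inverse label frequency, `(freq / N) ** -1 = N / freq`. The sketch below shows the resulting weights for hypothetical counts. The label names, counts, and dataset size are assumptions for illustration only.

```python
import torch
import torch.nn as nn

# Hypothetical label counts for a 1,000-example training set.
label_freqs = {"drama": 500, "comedy": 250, "horror": 50}
labels = ["drama", "comedy", "horror"]
train_data_len = 1000

freqs = torch.tensor([label_freqs[l] for l in labels], dtype=torch.float)
pos_weight = (freqs / train_data_len) ** -1   # N / count, so rarer labels get larger weights
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# With all-zero logits every element contributes log(2), multiplied by pos_weight on positives.
logits = torch.zeros(2, 3)
targets = torch.tensor([[1.0, 0.0, 0.0],
                        [0.0, 0.0, 1.0]])
loss = criterion(logits, targets)
print(pos_weight.tolist(), loss.item())
```

The PyTorch documentation gives the ratio of negative to positive examples as one way to set `pos_weight`; the inverse-frequency form used here is always larger by exactly one, which matters most for frequent labels.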

## Training Text model (BERT)

In [None]:
# Optimizer
args.lr = 5e-5
args.patience = 5
args.max_epochs = 30
args.gradient_accumulation_steps = 4

# Scheduler
args.lr_patience = 2
args.lr_factor = 0.5

# Models parameters
args.model = "bert"
args.bert_model = "bert-base-uncased"
args.hidden_sz = 768
args.attn_dropout = 0.1
args.relu_dropout = 0.1 # relu dropout
args.embed_dropout = 0.25 # embedding dropout
args.res_dropout = 0.1 # residual block dropout
args.out_dropout = 0.0 # output layer dropout

# Train
train(args)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 1 | Train Loss: 0.3159
Val auc_pr_micro: 0.3388 | Val auc_pr_macro: 0.2752
Epoch 2 | Train Loss: 0.3156
Val auc_pr_micro: 0.4131 | Val auc_pr_macro: 0.3705
Epoch 3 | Train Loss: 0.3123
Val auc_pr_micro: 0.4220 | Val auc_pr_macro: 0.3566
Epoch 4 | Train Loss: 0.3136
Val auc_pr_micro: 0.3631 | Val auc_pr_macro: 0.2794
Epoch 5 | Train Loss: 0.3127
Val auc_pr_micro: 0.3717 | Val auc_pr_macro: 0.3156
Epoch 6 | Train Loss: 0.3096
Val auc_pr_micro: 0.2455 | Val auc_pr_macro: 0.2608
Epoch     6: reducing learning rate of group 0 to 2.5000e-05.
Epoch 7 | Train Loss: 0.3155
Val auc_pr_micro: 0.3258 | Val auc_pr_macro: 0.2512
Epoch 8 | Train Loss: 0.3139
Val auc_pr_micro: 0.4181 | Val auc_pr_macro: 0.3183

No improvement. Breaking out of loop.
-------------------------------------------------------

Test auc_pr_micro: 0.4550 | Test auc_pr_macro: 0.3618
-------------------------------------------------------


## Training multimodal model (MulT GMU)

In [None]:
# Optimizer
args.lr = 5e-5
args.patience = 5
args.max_epochs = 30
args.gradient_accumulation_steps = 4

# Scheduler
args.lr_patience = 2
args.lr_factor = 0.5

# Models parameters
args.model = "mmtrvpa"
args.bert_model = "bert-base-uncased"
args.hidden_sz = 768
args.vonly = True
args.lonly = True
args.aonly = True
args.orig_d_v = 4096
args.orig_d_l = 768
args.orig_d_a = 96
args.v_len = 200
args.l_len = 512
args.a_len = 200
args.attn_dropout = 0.1
args.attn_dropout_v = 0.0 # attention dropout (for visual)
args.attn_dropout_a = 0.0 # attention dropout (for audio)
args.relu_dropout = 0.1 # relu dropout
args.embed_dropout = 0.25 # embedding dropout
args.res_dropout = 0.1 # residual block dropout
args.out_dropout = 0.0 # output layer dropout
args.nlevels = 1 # number of layers in the Transformer Encoder
args.layers = 1
args.num_heads = 1 # number of heads for the transformer Encoder
args.attn_mask = True # use attention mask for Transformer Encoder

# Train
train(args)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 1 | Train Loss: 0.3286
Val auc_pr_micro: 0.2274 | Val auc_pr_macro: 0.2926
Epoch 2 | Train Loss: 0.2868
Val auc_pr_micro: 0.3920 | Val auc_pr_macro: 0.3058
Epoch 3 | Train Loss: 0.2482
Val auc_pr_micro: 0.2609 | Val auc_pr_macro: 0.3231
Epoch 4 | Train Loss: 0.2025
Val auc_pr_micro: 0.3469 | Val auc_pr_macro: 0.3237
Epoch 5 | Train Loss: 0.1536
Val auc_pr_micro: 0.4452 | Val auc_pr_macro: 0.3418
Epoch 6 | Train Loss: 0.1260
Val auc_pr_micro: 0.3620 | Val auc_pr_macro: 0.3144
Epoch 7 | Train Loss: 0.0885
Val auc_pr_micro: 0.4434 | Val auc_pr_macro: 0.3365
Epoch 8 | Train Loss: 0.0689
Val auc_pr_micro: 0.4295 | Val auc_pr_macro: 0.3337
Epoch     8: reducing learning rate of group 0 to 2.5000e-05.
Epoch 9 | Train Loss: 0.0491
Val auc_pr_micro: 0.3843 | Val auc_pr_macro: 0.3242
Epoch 10 | Train Loss: 0.0367
Val auc_pr_micro: 0.4277 | Val auc_pr_macro: 0.3251

No improvement. Breaking out of loop.


RuntimeError: ignored

In [None]:
test(args)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


-------------------------------------------------------

Test auc_pr_micro: 0.3645 | Test auc_pr_macro: 0.3119
-------------------------------------------------------


## Training Translating Multimodal Transformer

In [None]:
# Optimizer
args.lr = 5e-5
args.patience = 5
args.max_epochs = 30
args.gradient_accumulation_steps = 4

# Scheduler
args.lr_patience = 2
args.lr_factor = 0.5

# Models parameters
args.model = "tmmtrvpa"
args.bert_model = "bert-base-uncased"
args.hidden_sz = 768
args.vonly = True
args.lonly = True
args.aonly = True
args.orig_d_v = 4096
args.orig_d_l = 768
args.orig_d_a = 96
args.v_len = 200
args.l_len = 512
args.a_len = 200
args.attn_dropout = 0.1
args.attn_dropout_v = 0.0 # attention dropout (for visual)
args.attn_dropout_a = 0.0 # attention dropout (for audio)
args.relu_dropout = 0.1 # relu dropout
args.embed_dropout = 0.25 # embedding dropout
args.res_dropout = 0.1 # residual block dropout
args.out_dropout = 0.0 # output layer dropout
args.nlevels = 1 # number of layers in the Transformer Encoder
args.layers = 1
args.num_heads = 1 # number of heads for the transformer Encoder
args.attn_mask = True # use attention mask for Transformer Encoder

# Train
train(args)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Epoch 1 | Train Loss: 0.3196
Val auc_pr_micro: 0.3375 | Val auc_pr_macro: 0.3138
Epoch 2 | Train Loss: 0.2796
Val auc_pr_micro: 0.3310 | Val auc_pr_macro: 0.3207
Epoch 3 | Train Loss: 0.2426
Val auc_pr_micro: 0.4320 | Val auc_pr_macro: 0.3465
Epoch 4 | Train Loss: 0.1915
Val auc_pr_micro: 0.4266 | Val auc_pr_macro: 0.3419
Epoch 5 | Train Loss: 0.1417
Val auc_pr_micro: 0.4246 | Val auc_pr_macro: 0.3558
Epoch 6 | Train Loss: 0.0929
Val auc_pr_micro: 0.4216 | Val auc_pr_macro: 0.3776
Epoch     6: reducing learning rate of group 0 to 2.5000e-05.
Epoch 7 | Train Loss: 0.0565
Val auc_pr_micro: 0.4896 | Val auc_pr_macro: 0.4069
Epoch 8 | Train Loss: 0.0367
Val auc_pr_micro: 0.5178 | Val auc_pr_macro: 0.4155
Epoch 9 | Train Loss: 0.0272
Val auc_pr_micro: 0.5256 | Val auc_pr_macro: 0.4160
Epoch 10 | Train Loss: 0.0192
Val auc_pr_micro: 0.5266 | Val auc_pr_macro: 0.4143
Epoch 11 | Train Loss: 0.0138
Val auc_pr_micro: 0.5243 | Val auc_pr_macro: 0.4196
Epoch 12 | Train Loss: 0.0102
Val auc_pr_micro: 0.5355 | Val auc_pr_macro: 0.4176
Epoch 13 | Train Loss: 0.0071
Val auc_pr_micro: 0.5441 | Val auc_pr_macro: 0.4406
Epoch 14 | Train Loss: 0.0049
Val auc_pr_micro: 0.5526 | Val auc_pr_macro: 0.4452
Epoch 15 | Train Loss: 0.0042
Val auc_pr_micro: 0.5548 | Val auc_pr_macro: 0.4421
Epoch 16 | Train Loss: 0.0035
Val auc_pr_micro: 0.5458 | Val auc_pr_macro: 0.4417
Epoch 17 | Train Loss: 0.0029
Val auc_pr_micro: 0.5595 | Val auc_pr_macro: 0.4529
Epoch 18 | Train Loss: 0.0021
Val auc_pr_micro: 0.5604 | Val auc_pr_macro: 0.4547
Epoch 19 | Train Loss: 0.0016
Val auc_pr_micro: 0.5646 | Val auc_pr_macro: 0.4528
Epoch 20 | Train Loss: 0.0014
Val auc_pr_micro: 0.5699 | Val auc_pr_macro: 0.4542
Epoch 21 | Train Loss: 0.0012
Val auc_pr_micro: 0.5703 | Val auc_pr_macro: 0.4595
Epoch 22 | Train Loss: 0.0010
Val auc_pr_micro: 0.5745 | Val auc_pr_macro: 0.4626
Epoch 23 | Train Loss: 0.0008
Val auc_pr_micro: 0.5745 | Val auc_pr_macro: 0.4656
Epoch 24 | Train Loss: 0.0008
Val auc_pr_micro: 0.5750 | Val auc_pr_macro: 0.4585
Epoch 25 | Train Loss: 0.0007
Val auc_pr_micro: 0.5763 | Val auc_pr_macro: 0.4637
Epoch 26 | Train Loss: 0.0006
Val auc_pr_micro: 0.5796 | Val auc_pr_macro: 0.4662
Epoch 27 | Train Loss: 0.0006
Val auc_pr_micro: 0.5815 | Val auc_pr_macro: 0.4671
Epoch 28 | Train Loss: 0.0005
Val auc_pr_micro: 0.5804 | Val auc_pr_macro: 0.4687
Epoch 29 | Train Loss: 0.0004
Val auc_pr_micro: 0.5851 | Val auc_pr_macro: 0.4677
Epoch 30 | Train Loss: 0.0004
Val auc_pr_micro: 0.5792 | Val auc_pr_macro: 0.4702
-------------------------------------------------------

Test auc_pr_micro: 0.5073 | Test auc_pr_macro: 0.4306
-------------------------------------------------------
