Multi-modal networks usually refer to architectures that can handle multiple inputs simultaneously, such as combining images and text, or processing multiple feature sets. This notebook illustrates the concept with a few simple examples.

These networks usually take two or more different inputs (possibly of different types or shapes) and combine them somewhere in the architecture. They’re widely used in:
* Vision + Language tasks (e.g., image captioning, visual question answering)
* Tabular + Text or Image + Metadata fusion
* Siamese or Triplet Networks for similarity comparison
* Multi-modal learning

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

#### Architectures

<b>Example 1: Tabular + Image Input</b>

Such a network takes both an image and a tabular feature vector (e.g., age, temperature, other metadata) as input, then fuses them to perform a task such as classification.

```text
[Image] -----> [CNN] ---------\
                               +--> [Fusion] --> [FC] --> [Output]
[Tabular] --> [Dense Layer] --/
```

```python
class ImageTabularNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Image branch/encoder (simple CNN)
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((8, 8)),
            nn.Flatten()
        )
        
        # Tabular branch/encoder
        self.tabular_branch = nn.Sequential(
            nn.Linear(5, 32),  # Assume 5 tabular features
            nn.ReLU()
        )
        
        # Fusion + classifier
        self.classifier = nn.Sequential(
            nn.Linear(16 * 8 * 8 + 32, 128),
            nn.ReLU(),
            nn.Linear(128, 3)  # 3 classes
        )

    def forward(self, image, tabular):
        image_feat = self.cnn_branch(image)
        tabular_feat = self.tabular_branch(tabular)
        combined = torch.cat((image_feat, tabular_feat), dim=1)
        return self.classifier(combined)
```

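In the network above, the expected input shapes are `[batch_size, 3, 64, 64]` for the image and `[batch_size, 5]` for the tabular features, and the output is `[batch_size, 3]` (one logit per class). Because `nn.AdaptiveAvgPool2d((8, 8))` always produces an 8x8 grid, the image branch outputs `16 * 8 * 8` features regardless of the input resolution.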
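To check the fusion dimensions, the sketch below rebuilds the two encoder branches of `ImageTabularNet` on their own (copied from the class so the snippet runs standalone) and traces the shapes. The batch size of 4 and the extra 128x128 test resolution are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

# Copies of the two branches from ImageTabularNet, so the snippet runs standalone
cnn_branch = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((8, 8)),
    nn.Flatten(),
)
tabular_branch = nn.Sequential(nn.Linear(5, 32), nn.ReLU())

batch_size = 4  # arbitrary, for illustration
tabular = torch.randn(batch_size, 5)

with torch.no_grad():
    # Adaptive pooling fixes the spatial grid at 8x8, so both resolutions give 16*8*8 = 1024 features
    feat_64 = cnn_branch(torch.randn(batch_size, 3, 64, 64))
    feat_128 = cnn_branch(torch.randn(batch_size, 3, 128, 128))
    tab_feat = tabular_branch(tabular)
    # Fusion by concatenation along the feature dimension
    fused = torch.cat((feat_64, tab_feat), dim=1)

print(feat_64.shape, feat_128.shape)  # torch.Size([4, 1024]) torch.Size([4, 1024])
print(fused.shape)                    # torch.Size([4, 1056]), matching nn.Linear(16 * 8 * 8 + 32, 128)
```

Concatenation is the simplest fusion strategy; it only requires the batch dimensions to agree, which is why the classifier's input width is simply the sum of the two branch widths.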

<b>Example 2: Text + Image Input (e.g., CLIP-style)</b>

This pattern encodes an image and a text input with separate encoders, projects both into a shared embedding space, and compares the resulting embeddings (e.g., with cosine similarity).

```text
[Image] -----> [CNN/ViT] -----------\
                                     +--> [Similarity / Logits]
[Text] -----> [LSTM/BERT] ----------/
```

```python
# A simplified, CLIP-style model (not the official CLIP implementation)
class ImageTextModel(nn.Module):
    def __init__(self, image_encoder, text_encoder, image_dim, text_dim, embed_dim=512):
        super().__init__()
        # Any modules mapping images -> [n, image_dim] and token ids -> [n, text_dim],
        # e.g. a CNN/ViT and an LSTM/Transformer whose output is already pooled (e.g. at the EOS token)
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder

        # Linear projections from each encoder's representation into the shared multi-modal embedding space
        self.image_projection_layer = nn.Linear(image_dim, embed_dim, bias=False)
        self.text_projection_layer = nn.Linear(text_dim, embed_dim, bias=False)

        # Learnable temperature, stored as a log value and initialised to 1/0.07 as in the CLIP paper
        self.logit_scale = nn.Parameter(torch.tensor(1 / 0.07).log())

    def encode_images(self, images):
        return self.image_encoder(images)

    def encode_text(self, text_tokens):
        return self.text_encoder(text_tokens)

    def forward(self, batch):
        '''Forward pass for training; `batch` is a dict with 'image' and 'text_tokens' entries.'''
        # extract the features
        image_features = self.encode_images(batch['image'])        # [n, image_dim]
        text_features = self.encode_text(batch['text_tokens'])     # [n, text_dim]

        # project both modalities into the shared embedding space
        image_embeddings = self.image_projection_layer(image_features)  # [n, embed_dim]
        text_embeddings = self.text_projection_layer(text_features)     # [n, embed_dim]

        # L2-normalise so that dot products are cosine similarities
        image_embeddings = F.normalize(image_embeddings, dim=-1)   # [n, embed_dim]
        text_embeddings = F.normalize(text_embeddings, dim=-1)     # [n, embed_dim]

        # temperature; CLIP clips it so that logits are scaled by at most 100
        logit_scale = self.logit_scale.exp().clamp(max=100)

        # scaled cosine similarities as logits; entry [i, j] compares image i with text j
        logits_per_image = logit_scale * image_embeddings @ text_embeddings.t()  # [n, n]
        logits_per_text = logits_per_image.t()                                     # [n, n]

        return {
            "image_embeddings": image_embeddings,
            "text_embeddings": text_embeddings,
            "logit_scale": logit_scale,
            "logits_per_image": logits_per_image,
            "logits_per_text": logits_per_text,
        }
```
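The model only returns logits. CLIP-style training typically pairs them with a symmetric contrastive (InfoNCE-style) loss: in a batch of `n` matched pairs, image `i` and text `i` are the only positive pair in row `i` and column `i`, so the targets are the diagonal indices. A minimal sketch operating directly on the `logits_per_image` / `logits_per_text` tensors; the function name `clip_contrastive_loss` and the random toy embeddings are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

def clip_contrastive_loss(logits_per_image, logits_per_text):
    # Image i belongs with text i, so the correct "class" for row i is i
    n = logits_per_image.shape[0]
    targets = torch.arange(n, device=logits_per_image.device)
    loss_image = F.cross_entropy(logits_per_image, targets)  # each image must pick its text
    loss_text = F.cross_entropy(logits_per_text, targets)    # each text must pick its image
    return (loss_image + loss_text) / 2

# Toy embeddings, for illustration only
torch.manual_seed(0)
n, d = 8, 16
logit_scale = torch.tensor(1 / 0.07)
image_embeddings = F.normalize(torch.randn(n, d), dim=-1)
text_embeddings = F.normalize(torch.randn(n, d), dim=-1)

# Unrelated pairs vs. perfectly aligned pairs (each text embedding equals its image embedding)
random_logits = logit_scale * image_embeddings @ text_embeddings.t()
aligned_logits = logit_scale * image_embeddings @ image_embeddings.t()

loss_random = clip_contrastive_loss(random_logits, random_logits.t())
loss_aligned = clip_contrastive_loss(aligned_logits, aligned_logits.t())
print(loss_random.item(), loss_aligned.item())  # the aligned loss is much lower
```

With the model above, the call would be `clip_contrastive_loss(output["logits_per_image"], output["logits_per_text"])`.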

<b>Example 3: Siamese Network with 2 Inputs</b>

Two inputs of the same kind are passed through the same encoder (shared weights), and the distance between their embeddings measures how similar they are (e.g., face verification).

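```python
class SiameseNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # A single encoder, so both inputs share the same weights
        self.encoder = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(),
            nn.Linear(64, 32)
        )

    def forward(self, input1, input2):
        emb1 = self.encoder(input1)
        emb2 = self.encoder(input2)
        # Euclidean (L2) distance between the embeddings: smaller means more similar.
        # Use F.pairwise_distance(emb1, emb2, p=1) for L1, or F.cosine_similarity(emb1, emb2) for cosine similarity.
        return F.pairwise_distance(emb1, emb2)
```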
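The network only returns distances. A common way to train it is a margin-based contrastive loss, which pulls similar pairs together and pushes dissimilar pairs at least `margin` apart. A minimal sketch, assuming label 1 marks a similar pair and 0 a dissimilar one (conventions differ between implementations); the function name, the `margin=1.0` default, and the sample distances are illustrative.

```python
import torch

def contrastive_loss(distance, label, margin=1.0):
    # label == 1: similar pair, penalise any distance (squared)
    similar_term = label * distance.pow(2)
    # label == 0: dissimilar pair, penalise only distances below the margin
    dissimilar_term = (1 - label) * torch.clamp(margin - distance, min=0).pow(2)
    return (similar_term + dissimilar_term).mean()

# Distances as returned by SiameseNetwork.forward, plus pair labels (toy values)
distance = torch.tensor([0.0, 0.2, 1.5, 0.5])
label = torch.tensor([1.0, 1.0, 0.0, 0.0])
loss = contrastive_loss(distance, label)
print(loss.item())  # (0 + 0.04 + 0 + 0.25) / 4 = 0.0725
```

The dissimilar pair at distance 1.5 contributes nothing because it is already beyond the margin; only the pair at 0.5 is pushed further apart.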

<b>Example 4: Multi-Input with nn.ModuleDict or ModuleList</b>

We can use `nn.ModuleDict` to register one branch per input type and look each branch up by name:

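```python
class MultiBranchNet(nn.Module):
    def __init__(self):
        super().__init__()
        # One encoder per input type, registered by name
        self.branches = nn.ModuleDict({
            'img': nn.Sequential(nn.Conv2d(1, 8, 3), nn.Flatten()),  # [b, 1, 28, 28] -> [b, 8*26*26]
            'text': nn.Sequential(nn.Linear(300, 128), nn.ReLU()),   # [b, 300] -> [b, 128]
            'meta': nn.Sequential(nn.Linear(10, 32))                 # [b, 10] -> [b, 32]
        })
        self.classifier = nn.Linear(8*26*26 + 128 + 32, 2)

    def forward(self, inputs):
        # inputs: dict with keys 'img', 'text', 'meta'
        # Iterate over the registered branches (not the input dict) so the concatenation
        # order always matches the layout the classifier expects
        features = [branch(inputs[name]) for name, branch in self.branches.items()]
        combined = torch.cat(features, dim=1)
        return self.classifier(combined)
```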
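Hard-coding `8*26*26` ties the classifier to a single image size: 28x28, since a 3x3 convolution without padding trims one pixel from each border. A sketch of inferring the flattened width with a dummy forward pass instead; the 28x28 input size is the assumption implied by the original constant. PyTorch's `nn.LazyLinear`, which infers `in_features` on its first call, is another option.

```python
import torch
import torch.nn as nn

img_branch = nn.Sequential(nn.Conv2d(1, 8, 3), nn.Flatten())

# Push one dummy image through the branch to read off its output width
# (28x28 is the size implied by 8*26*26 in MultiBranchNet)
with torch.no_grad():
    img_dim = img_branch(torch.zeros(1, 1, 28, 28)).shape[1]

text_dim, meta_dim = 128, 32  # output widths of the 'text' and 'meta' branches
classifier = nn.Linear(img_dim + text_dim + meta_dim, 2)
print(img_dim, classifier.in_features)  # 5408 5568
```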

#### Data Handling

The training loop and dataloader usually need to be adapted so that each batch carries both modalities. In the loop below, each batch provides the image tensor, the tabular features, and the labels, which are moved to the device and passed to the model.

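```python
# Assumes `model`, `dataloader`, `criterion`, `optimizer`, and `device` are already defined
for batch in dataloader:
    img = batch['image'].to(device)       # [batch_size, 3, 64, 64]
    tab = batch['tabular'].to(device)     # [batch_size, 5]
    labels = batch['label'].to(device)    # [batch_size]
    optimizer.zero_grad()
    outputs = model(img, tab)             # runs ImageTabularNet.forward(image, tabular)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
```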

Both modalities are passed to the model together: calling `model(img, tab)` runs the model's `forward(image, tabular)` method, which encodes each input in its own branch before fusing them.
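The loop above assumes a `dataloader` whose batches are dictionaries keyed by modality. A minimal sketch of a dataset that produces them, using random tensors as stand-in data; the class name `ImageTabularDataset` is hypothetical. Because `__getitem__` returns a dict, PyTorch's default collate function batches each key separately.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ImageTabularDataset(Dataset):
    """Hypothetical dataset pairing each image with a tabular feature vector and a label."""
    def __init__(self, images, tabular, labels):
        assert len(images) == len(tabular) == len(labels)
        self.images = images    # [N, 3, 64, 64]
        self.tabular = tabular  # [N, 5]
        self.labels = labels    # [N]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The default collate_fn stacks each dict entry into its own batch tensor
        return {
            "image": self.images[idx],
            "tabular": self.tabular[idx],
            "label": self.labels[idx],
        }

# Random stand-in data, for illustration only
N = 10
dataset = ImageTabularDataset(
    torch.randn(N, 3, 64, 64), torch.randn(N, 5), torch.randint(0, 3, (N,))
)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(dataloader))
print(batch["image"].shape, batch["tabular"].shape, batch["label"].shape)
# torch.Size([4, 3, 64, 64]) torch.Size([4, 5]) torch.Size([4])
```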