# Package Imports

In [None]:
import torch
import torch.nn as nn
from typing import List
from typing import Dict
import torch.nn.functional as F
import pandas as pd
import numpy as np
import itertools

ModuleNotFoundError: No module named 'torch'

# Data fusion

## 1. Cross Attention

### 1. Dynamic Sequential Cross Attention


This is a single-head sequential cross-attention module that accepts a variable number of input modalities. Each input modality must have shape (batch_size, sequence_length, d_model); sequence lengths may differ between modalities. The main modality (the first argument) always serves as the query and attends to each of the other modalities in turn as key/value, with residual connections and layer normalization preserving the original features.

In [2]:
class DynamicSequentialCrossAttention(nn.Module):
    """Single-head sequential cross attention over a variable number of modalities.

    The first modality is the query. It attends to each auxiliary modality in
    turn (A->B, then ->C, ...). After every stage the attention output is added
    back to the query as a residual and layer-normalized, and the result is the
    query for the next stage.

    Args:
        d_model: Feature size shared by every modality (last tensor dimension).
        total_modalities: Number of tensors passed to ``forward``, including
            the main modality.
    """

    def __init__(self, d_model: int, total_modalities: int):
        super().__init__()
        self.d_model = d_model
        self.total_modalities = total_modalities  # Total number of modalities

        n_stages = total_modalities - 1
        # One single-head attention layer per auxiliary modality (A->B, A->C, ...).
        self.cross_attentions = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=d_model, num_heads=1, batch_first=True)
            for _ in range(n_stages)
        ])

        # One LayerNorm per stage, applied after the residual connection.
        self.norms = nn.ModuleList([
            nn.LayerNorm(d_model)
            for _ in range(n_stages)
        ])

    def forward(self, *modalities: torch.Tensor) -> torch.Tensor:
        """Fuse the auxiliary modalities into the main one, one at a time.

        Args:
            modalities: Tensors passed positionally, where
                - modalities[0]: main modality A [B, N, d_model] (query)
                - modalities[1:]: auxiliary modalities [B, M_i, d_model] (keys/values)
        Returns:
            a_enhanced: [B, N, d_model], the main modality after all stages.
        """
        assert len(modalities) == self.total_modalities, \
            f"Expected {self.total_modalities} modalities (including A), got {len(modalities)}"

        a = modalities[0]  # Main modality A

        stages = zip(self.cross_attentions, self.norms, modalities[1:])
        for attention, norm, current_modality in stages:
            # Cross-Attention: A as query, current modality as key/value.
            attn_output, _ = attention(
                query=a,
                key=current_modality,
                value=current_modality,
            )
            # Residual connection + LayerNorm. Without this, the attention
            # output was discarded and A was returned unchanged.
            a = norm(a + attn_output)

        return a

NameError: name 'nn' is not defined

Example

In [None]:
# Three modalities: b and c are fused, in that order, into the main modality a.
model = DynamicSequentialCrossAttention(d_model=256, total_modalities=3)

# Argument order matters: the main modality comes first, then the auxiliaries.
a = torch.randn(32, 10, 256)  # main modality (query); intended to hold the clinical data
b = torch.randn(32, 20, 256)  # first auxiliary modality (key/value)
c = torch.randn(32, 15, 256)  # second auxiliary modality (key/value)

output = model(a, b, c)  # a enhanced by b, then by c
print(output.shape)

torch.Size([32, 10, 256])


### 2. Multihead Cross Attention

In [None]:

# device = torch.device("cuda")
# torch.cuda.set_device(0)
# torch.cuda.synchronize()
class FowardNetwork(nn.Module):
    """Two-layer position-wise feed-forward network with SiLU activations.

    Both linear layers map ``embed_dim -> embed_dim`` and each one is followed
    by a SiLU, so the output has the same shape as the input.
    """

    def __init__(self, embed_dim):
        super().__init__()
        # Attribute names are kept as-is so existing state_dict keys still load.
        self.Fc1 = nn.Linear(embed_dim, embed_dim, bias=True)
        self.Fc2 = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, x):
        """Return ``silu(Fc2(silu(Fc1(x))))`` computed over the last dimension."""
        return F.silu(self.Fc2(F.silu(self.Fc1(x))))


class CrossAttention(nn.Module):
    """Multi-head cross attention with thresholded attention weights.

    The block is: attention -> output projection -> dropout -> residual +
    LayerNorm -> feed-forward -> dropout -> residual + LayerNorm.

    Args:
        embed_dim: Feature size of query, key and value (last dimension).
        num_heads: Number of attention heads; must divide ``embed_dim``.
        batch_size: Kept for backward compatibility only. The batch size is
            read from the inputs in ``forward``, so any batch size works.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, batch_size):
        super().__init__()
        if embed_dim % num_heads != 0:
            # Otherwise split_heads can silently reshape features into the
            # wrong heads instead of failing.
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dropout = 0.2
        self.batch_size = batch_size
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.W_q = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.W_v = nn.Linear(embed_dim, embed_dim, bias=False)
        self.O_layer = nn.Linear(embed_dim, embed_dim, bias=False)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.drop1 = nn.Dropout(self.dropout)
        self.drop2 = nn.Dropout(self.dropout)
        # Attention weights below alpha are zeroed. The `>=` comparison that
        # uses it is not differentiable, so alpha receives no gradient.
        self.alpha = nn.Parameter(torch.tensor(0.2))
        # NOTE(review): belta is not used anywhere in this class; it is kept so
        # that existing checkpoints still load.
        self.belta = nn.Parameter(torch.ones(num_heads))
        self.fowNet = FowardNetwork(self.embed_dim)

    def split_heads(self, x):
        """Reshape ``(B, L, embed_dim)`` into ``(B, num_heads, L, head_dim)``."""
        x = x.view(x.size(0), -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def scaled_dot_product_attention(self, Q, K, V):
        """Thresholded scaled dot-product attention per head.

        Args:
            Q: ``(B, H, Lq, head_dim)``.
            K, V: ``(B, H, Lk, head_dim)``.
        Returns:
            attn_output: ``(B, H, Lq, head_dim)``.
            attn_weights: ``(B, H, Lq, Lk)`` softmax weights with every entry
                below ``alpha`` set to zero (before the second softmax).
        """
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)
        # Zero every weight below alpha. The previous scatter with an identity
        # index reduced to exactly this mask. Building it directly keeps it on
        # the weights' device and dtype, so it works on GPU and in half precision.
        keep_mask = (attn_weights >= self.alpha).to(attn_weights.dtype)
        attn_weights = attn_weights * keep_mask
        # NOTE(review): a softmax over masked *probabilities* does not give the
        # masked positions zero weight (exp(0) = 1), so the result is close to
        # uniform. If true sparsification is intended, confirm the design
        # (e.g. renormalise by the row sum, or mask the scores with -inf).
        attn_weights_adjusted = F.softmax(attn_weights, dim=-1)

        attn_output = torch.matmul(attn_weights_adjusted, V)
        return attn_output, attn_weights

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: ``(B, Lq, embed_dim)``.
            key, value: ``(B, Lk, embed_dim)``. Lk may differ from Lq.
        Returns:
            final_output: ``(B, Lq, embed_dim)``.
            atten_maps: ``(B, num_heads, Lq, Lk)`` thresholded attention weights.
        """
        batch, q_len = query.size(0), query.size(1)
        Q = self.split_heads(self.W_q(query))
        K = self.split_heads(self.W_k(key))
        V = self.split_heads(self.W_v(value))
        attn_output, atten_maps = self.scaled_dot_product_attention(Q, K, V)
        # (B, H, Lq, head_dim) -> (B, Lq, H, head_dim) -> (B, Lq, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch, q_len, self.embed_dim)
        attn_output = self.O_layer(attn_output)
        attn_output = self.norm1(query + self.drop1(attn_output))
        inter_output = self.fowNet(attn_output)
        final_output = self.norm2(attn_output + self.drop2(inter_output))

        return final_output, atten_maps

In [None]:
# Example
embed_dim = 256
num_heads = 8
batch_size = 32
model = CrossAttention(embed_dim, num_heads, batch_size)
a = torch.randn(32, 12, 256)  # A: main modality, used as the query
b = torch.randn(32, 12, 256)  # B: used as the key
c = torch.randn(32, 12, 256)  # C: used as the value (must have the same length as the key)
# Call the module itself rather than .forward() so registered hooks run.
integated_data, attn_weights = model(a, b, c)

### 3. Dynamic Sequential Multihead Cross Attention

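This combines the sequential fusion scheme of section 1 with the multi-head `CrossAttention` module of section 2.

The only difference from `DynamicSequentialCrossAttention` is the cross-attention step in `forward()`. Each stage uses a `CrossAttention` block, which applies its own residual connections, layer normalization and feed-forward network, instead of a single-head `nn.MultiheadAttention`.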

In [None]:
class DynamicSequentialMultiheadCrossAttention(nn.Module):
    """Sequentially fuse auxiliary modalities into a main one with ``CrossAttention`` stages.

    The main modality is the query of the first stage. Each stage's output is
    the query of the next stage, and each auxiliary modality serves as both key
    and value for its stage.

    Args:
        d_model: Stored on the module but not used to size any layer; the
            layers are sized by ``embed_dim``.
        total_modalities: Number of tensors passed to ``forward``, including
            the main modality.
        embed_dim: Feature size of every modality (last tensor dimension).
        num_heads: Number of attention heads per stage.
        batch_size: Forwarded to each ``CrossAttention`` stage.
    """

    def __init__(self, d_model: int, total_modalities: int, embed_dim: int, num_heads: int, batch_size: int):
        super().__init__()
        self.d_model = d_model
        self.total_modalities = total_modalities  # Total number of modalities
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.batch_size = batch_size

        # One multi-head cross-attention stage per auxiliary modality.
        self.cross_attentions = nn.ModuleList([
            CrossAttention(embed_dim=self.embed_dim, num_heads=self.num_heads, batch_size=self.batch_size)
            for _ in range(self.total_modalities - 1)
        ])

    def forward(self, *modalities: torch.Tensor):
        """Fuse the auxiliary modalities into the main one, one stage at a time.

        Args:
            modalities: Tensors passed positionally, where
                - modalities[0]: main modality A [B, N, embed_dim] (query)
                - modalities[1:]: auxiliary modalities [B, M_i, embed_dim] (keys/values)
        Returns:
            a_enhanced: [B, N, embed_dim], the main modality after all stages.
            attn_weights: attention maps of the LAST stage only, or ``None``
                when there are no auxiliary modalities.
        """
        assert len(modalities) == self.total_modalities, \
            f"Expected {self.total_modalities} modalities (including A), got {len(modalities)}"

        a = modalities[0]  # Main modality A
        # Bound up front so a single-modality call (no stages) returns None
        # instead of raising UnboundLocalError.
        attn_weights = None

        for cross_attention, current_modality in zip(self.cross_attentions, modalities[1:]):
            # Cross-Attention: A as query, current modality as key/value.
            a, attn_weights = cross_attention(a, current_modality, current_modality)

        return a, attn_weights

Example

In [None]:
model = DynamicSequentialMultiheadCrossAttention(
    d_model=256,
    total_modalities=3,
    embed_dim=256,
    num_heads=8,
    batch_size=32,
)

a = torch.randn(32, 10, 256)  # main modality A (query)
b = torch.randn(32, 8, 256)   # auxiliary modality B (fused first)
c = torch.randn(32, 12, 256)  # auxiliary modality C (fused second)

output, attn_weights = model(a, b, c)
print(output.shape)

torch.Size([32, 10, 256])
