In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusionModule(nn.Module):
    """Fuse visual and audio features into one power-normalised tensor.

    Each modality is passed through its own ``Conv1d`` (``kernel_size=3``,
    no padding, so the last axis becomes two steps shorter) and a ReLU. The
    two activations are multiplied element-wise and the product is compressed
    with a signed square root, ``sign(x) * sqrt(|x|)``.

    Args:
        visual_dim: Channel count of the visual input (kept by the conv).
        audio_dim: Channel count of the audio input (kept by the conv).
        eps: Lower bound applied to ``|x|`` before the square root. It only
            exists to keep the backward pass finite; see ``forward``. The
            default is far below float32 resolution of typical activations
            but underflows in float16, so pass a larger value there.
    """

    def __init__(self, visual_dim, audio_dim, eps=1e-12):
        super().__init__()
        self.conv1d_v = nn.Conv1d(visual_dim, visual_dim, kernel_size=3)
        self.conv1d_a = nn.Conv1d(audio_dim, audio_dim, kernel_size=3)
        self.eps = eps

    def forward(self, visual_features, audio_features):
        """Return the fused tensor for a visual/audio pair.

        Args:
            visual_features: Tensor accepted by ``conv1d_v``, e.g.
                ``(batch, visual_dim, time)``.
            audio_features: Tensor accepted by ``conv1d_a``, e.g.
                ``(batch, audio_dim, time)``.

        Returns:
            A single tensor with the broadcast shape of the two conv outputs
            (for equal channel counts: ``(batch, dim, time - 2)``).
        """
        visual_output = F.relu(self.conv1d_v(visual_features))
        audio_output = F.relu(self.conv1d_a(audio_features))

        # Hadamard product of the two projected modalities. The shapes must
        # be broadcastable, which normally means visual_dim == audio_dim.
        multimodal_features = visual_output * audio_output

        # Signed square-root (power) normalisation. The derivative of
        # sqrt(|x|) is infinite at x == 0, and zeros are frequent here because
        # both factors come out of a ReLU; the resulting 0 * inf turned the
        # gradients into NaN. Flooring the magnitude keeps the backward pass
        # finite, while multiplying by sign(x) still maps exact zeros to 0.
        magnitude = torch.clamp(torch.abs(multimodal_features), min=self.eps)
        fused_features = torch.sqrt(magnitude) * torch.sign(multimodal_features)

        return fused_features

class CoAttention(nn.Module):
    """Cross-modal attention between a visual and an audio sequence.

    Both inputs are projected with a point-wise ``Conv1d`` + ReLU, an
    affinity map between every visual step and every audio step is built
    from the projections, and each modality receives the other modality's
    attended features as a residual.

    Args:
        visual_dim: Channel count of the visual input.
        audio_dim: Channel count of the audio input. The affinity product
            and the residual additions require it to equal ``visual_dim``.
    """

    def __init__(self, visual_dim, audio_dim):
        super().__init__()
        self.conv1d_c1 = nn.Conv1d(visual_dim, visual_dim, kernel_size=1)
        self.conv1d_c2 = nn.Conv1d(audio_dim, audio_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, visual_features, audio_features):
        """Return ``(visual_output, audio_output)``, each shaped like its input.

        Args:
            visual_features: ``(batch, visual_dim, visual_time)`` tensor.
            audio_features: ``(batch, audio_dim, audio_time)`` tensor.
        """
        visual_mapped = F.relu(self.conv1d_c1(visual_features))
        audio_mapped = F.relu(self.conv1d_c2(audio_features))

        # Affinity between time steps: (B, Tv, C) @ (B, C, Ta) -> (B, Tv, Ta).
        attention_map = torch.matmul(visual_mapped.permute(0, 2, 1), audio_mapped)
        # Normalised over the visual-time axis (dim=1).
        # NOTE(review): the visual branch below therefore uses weights that
        # do not sum to one per visual step - confirm this is intended.
        attention_map = self.softmax(attention_map)

        # Audio summarised for every visual step: (B, Tv, Ta) @ (B, Ta, C).
        visual_attention = torch.matmul(attention_map, audio_mapped.permute(0, 2, 1))
        # Visual summarised for every audio step: (B, Ta, Tv) @ (B, Tv, C).
        # The visual projection has to be time-major here; multiplying by the
        # channel-major tensor only type-checked when time == channels.
        audio_attention = torch.matmul(
            attention_map.permute(0, 2, 1), visual_mapped.permute(0, 2, 1)
        )

        # Residual connection: add the attended features (moved back to
        # channel-major layout) to the untouched inputs.
        visual_output = visual_features + visual_attention.permute(0, 2, 1)
        audio_output = audio_features + audio_attention.permute(0, 2, 1)

        return visual_output, audio_output

class SelfAttention(nn.Module):
    """Channel-wise self-attention over a ``(batch, channels, time)`` tensor.

    The input is projected with a point-wise ``Conv1d`` and a ReLU. A
    ``(batch, channels, channels)`` affinity matrix is formed from the
    projection, soft-maxed along ``dim=1``, and used to re-mix the projected
    channels.

    Args:
        input_dim: Channel count of the input; the projection keeps it.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.conv1d_s = nn.Conv1d(input_dim, input_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, input_features):
        """Return the attended projection, shaped like ``input_features``."""
        projected = self.conv1d_s(input_features).relu()

        # (B, C, T) @ (B, T, C) -> (B, C, C): similarity between channels.
        affinity = projected @ projected.permute(0, 2, 1)
        weights = self.softmax(affinity)

        # (B, C, C) @ (B, C, T) -> (B, C, T): weighted mix of the channels.
        return weights @ projected

class AttentionModule(nn.Module):
    """Fusion -> co-attention -> per-modality self-attention pipeline.

    The attended features of each modality are concatenated with that
    modality's raw input along the channel axis, so both returned tensors
    carry twice the input channel count.

    Args:
        visual_dim: Channel count of the visual input.
        audio_dim: Channel count of the audio input.
    """

    def __init__(self, visual_dim, audio_dim):
        super().__init__()
        self.fusion_module = FusionModule(visual_dim, audio_dim)
        self.co_attention = CoAttention(visual_dim, audio_dim)
        self.visual_self_attention = SelfAttention(visual_dim)
        self.audio_self_attention = SelfAttention(audio_dim)

    @staticmethod
    def _match_length(features, target_length):
        """Zero-pad the last axis of ``features`` up to ``target_length``."""
        missing = target_length - features.shape[-1]
        if missing <= 0:
            return features
        # Split the padding between both ends so that the outputs of an
        # unpadded, odd-kernel convolution stay centred on their inputs.
        left = missing // 2
        return F.pad(features, (left, missing - left))

    def forward(self, visual_features, audio_features):
        """Return ``(visual_concat, audio_concat)``.

        Args:
            visual_features: ``(batch, visual_dim, time)`` tensor.
            audio_features: ``(batch, audio_dim, time)`` tensor.

        Returns:
            Two tensors of shape ``(batch, 2 * dim, time)``: the raw input
            stacked on top of its attended counterpart.
        """
        fused_features = self.fusion_module(visual_features, audio_features)

        # FusionModule returns ONE tensor. Star-unpacking it iterated over the
        # batch axis, handing CoAttention the wrong number (and rank) of
        # arguments. The fused representation now feeds both streams.
        # NOTE(review): confirm this matches the intended architecture.
        visual_co, audio_co = self.co_attention(fused_features, fused_features)
        visual_self = self.visual_self_attention(visual_co)
        audio_self = self.audio_self_attention(audio_co)

        # The fusion convolutions are unpadded (kernel_size=3), so the
        # attended sequences are shorter than the raw inputs; torch.cat along
        # dim=1 needs every other axis to agree.
        visual_self = self._match_length(visual_self, visual_features.shape[-1])
        audio_self = self._match_length(audio_self, audio_features.shape[-1])

        # Concatenate raw and attended features along the channel axis.
        visual_concat = torch.cat((visual_features, visual_self), dim=1)
        audio_concat = torch.cat((audio_features, audio_self), dim=1)

        return visual_concat, audio_concat

class MyNetwork(nn.Module):
    """Audio-visual classifier: attention features -> temporal pooling -> linear.

    Args:
        visual_dim: Channel count of the visual input.
        audio_dim: Channel count of the audio input.
        num_classes: Size of the output layer (defaults to the original
            example value of 10).
    """

    def __init__(self, visual_dim, audio_dim, num_classes=10):
        super().__init__()
        self.attention_module = AttentionModule(visual_dim, audio_dim)
        # AttentionModule concatenates each raw input with its attended
        # features along dim=1, so every stream arrives with twice its input
        # channel count. Sizing the layer as visual_dim + audio_dim did not
        # match the tensor it was applied to.
        # NOTE: this changes the shape of fc.weight relative to earlier
        # checkpoints of this class.
        self.fc = nn.Linear(2 * (visual_dim + audio_dim), num_classes)

    def forward(self, visual_input, audio_input):
        """Return logits of shape ``(batch, num_classes)``.

        Args:
            visual_input: ``(batch, visual_dim, time)`` tensor.
            audio_input: ``(batch, audio_dim, time)`` tensor.
        """
        visual_output, audio_output = self.attention_module(visual_input, audio_input)
        # Stack both modalities along the channel axis.
        combined_features = torch.cat((visual_output, audio_output), dim=1)
        # nn.Linear acts on the LAST axis, which is time at this point.
        # Average over time so the layer sees one channel vector per sample.
        pooled_features = combined_features.mean(dim=-1)
        output = self.fc(pooled_features)
        return output

# Demo: push one random visual/audio pair through the network.
# Both tensors are laid out as [batch_size, channels, time_steps].
visual_features, audio_features = torch.randn(1, 16, 64), torch.randn(1, 16, 64)

model = MyNetwork(visual_dim=16, audio_dim=16)
output = model(visual_input=visual_features, audio_input=audio_features)
print(output.shape)  # Shape of the tensor returned by the network


Extracting visual features:   0%|          | 0/30 [00:00<?, ?image/s]


NameError: name 'Image' is not defined