# In [None]:  (Jupyter cell prompt left over from the notebook export)
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List

class MultimodalFusionLayer(nn.Module):
    """Fuse one image vector and one text vector per sample with self-attention.

    Both modalities are linearly projected to ``fusion_dim``, stacked into a
    two-token sequence (image token first, text token second), passed through
    multi-head self-attention with a residual connection and layer
    normalisation, and finally averaged over the two tokens.

    Args:
        image_dim: Size of the incoming image feature vector.
        text_dim: Size of the incoming text feature vector.
        fusion_dim: Size of the shared fusion space. The attention module is
            built with 8 heads, so this must be divisible by 8.
    """

    def __init__(self, image_dim=768, text_dim=512, fusion_dim=512):
        super().__init__()
        # Construction order is kept stable so seeded parameter
        # initialisation stays reproducible.
        self.image_projection = nn.Linear(in_features=image_dim, out_features=fusion_dim)
        self.text_projection = nn.Linear(in_features=text_dim, out_features=fusion_dim)
        self.fusion = nn.MultiheadAttention(fusion_dim, 8)
        self.norm = nn.LayerNorm(fusion_dim)

    def forward(self, image_features, text_features):
        """Return the fused representation of the two modalities.

        Args:
            image_features: Tensor of shape ``[batch_size, image_dim]``.
            text_features: Tensor of shape ``[batch_size, text_dim]``.

        Returns:
            Tensor of shape ``[batch_size, fusion_dim]``.
        """
        # Token axis comes first: [2, batch_size, fusion_dim].
        tokens = torch.stack(
            (
                self.image_projection(image_features),
                self.text_projection(text_features),
            ),
            dim=0,
        )

        # Self-attention across the two modality tokens; the attention
        # weights (second return value) are not needed.
        attended = self.fusion(tokens, tokens, tokens)[0]

        # Residual connection followed by layer normalisation.
        normalized = self.norm(tokens + attended)

        # Collapse the token axis -> [batch_size, fusion_dim].
        return normalized.mean(dim=0)

class TaskRoutingLayer(nn.Module):
    """Map fused features to a probability distribution over downstream tasks.

    Args:
        input_dim: Size of the fused feature vector.
        num_tasks: Number of tasks, i.e. the width of the output distribution.
        task_names: Optional labels, one per task. When omitted, the built-in
            names are used: trimmed to ``num_tasks`` if fewer tasks are
            requested, or padded with ``"Task_<index>"`` placeholders if more
            are requested, so that ``task_names`` always lines up with the
            output columns.

    Raises:
        ValueError: If ``task_names`` is given and its length differs from
            ``num_tasks``.
    """

    DEFAULT_TASK_NAMES = (
        "Summarization",
        "Question_Answering",
        "Code_Generation",
        "Translation",
        "Paraphrasing",
        "Sentiment_Analysis",
        "Grammar_Correction",
    )

    def __init__(self, input_dim=512, num_tasks=7, task_names=None):
        super().__init__()
        self.task_projection = nn.Linear(input_dim, num_tasks)

        if task_names is None:
            # Keep the labels the same length as the projection output;
            # otherwise a later zip(names, probs) silently drops entries.
            names = list(self.DEFAULT_TASK_NAMES[:num_tasks])
            names.extend(f"Task_{index}" for index in range(len(names), num_tasks))
        else:
            names = list(task_names)
            if len(names) != num_tasks:
                raise ValueError(
                    f"task_names has {len(names)} entries but num_tasks is {num_tasks}"
                )
        self.task_names = names

    def forward(self, fused_features):
        """Return task probabilities for ``fused_features``.

        Args:
            fused_features: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``num_tasks``; values along the last dimension sum to 1.
        """
        task_scores = self.task_projection(fused_features)
        return F.softmax(task_scores, dim=-1)

class MultimodalProcessor:
    """Inference wrapper that fuses image/text features and scores tasks.

    Builds a default ``MultimodalFusionLayer`` and ``TaskRoutingLayer`` on the
    chosen device and runs them without gradient tracking.

    Args:
        device: Device on which the layers and all inputs are placed.
            Defaults to CUDA when available, otherwise CPU.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        # This class is only used for inference, so keep the layers in eval mode.
        self.fusion_layer = MultimodalFusionLayer().to(device).eval()
        self.routing_layer = TaskRoutingLayer().to(device).eval()

    def _as_batch(self, values):
        """Return ``values`` as a floating-point ``[batch, dim]`` tensor on ``self.device``."""
        tensor = torch.as_tensor(
            values, dtype=torch.get_default_dtype(), device=self.device
        )
        # A single feature vector becomes a batch of one. Without the batch
        # axis the routing output is 1-D and indexing row 0 yields a scalar.
        if tensor.dim() == 1:
            tensor = tensor.unsqueeze(0)
        return tensor

    @torch.no_grad()
    def process(self, image_output: Dict, text_output: Dict) -> Dict:
        """Fuse upstream image and text outputs and compute task routing.

        Args:
            image_output: Must contain ``'classification_features'``: a
                feature vector (or batch of vectors) as a tensor or nested
                sequence of numbers.
            text_output: Must contain ``'ocr_processed'`` and
                ``'caption_processed'``, each a dict with an ``'embedding'``
                entry of matching shape.

        Returns:
            Dict with ``'fused_features'`` (NumPy array of shape
            ``[batch, fusion_dim]``; batch is 1 for a single input vector) and
            ``'task_routing'`` (task name -> probability for the first sample
            in the batch).

        Raises:
            KeyError: If a required key is missing from either input dict.
        """
        # Bring every input to the same dtype, device and [batch, dim] layout.
        image_features = self._as_batch(image_output['classification_features'])
        ocr_embedding = self._as_batch(text_output['ocr_processed']['embedding'])
        caption_embedding = self._as_batch(text_output['caption_processed']['embedding'])

        # Combine the two text embeddings by simple averaging.
        text_features = (ocr_embedding + caption_embedding) / 2

        fused_features = self.fusion_layer(image_features, text_features)
        task_probabilities = self.routing_layer(fused_features)

        # Report routing for the first sample; the batch axis is guaranteed
        # by _as_batch, so row 0 is a vector of per-task probabilities.
        first_sample_probs = task_probabilities[0].tolist()
        task_routing = dict(zip(self.routing_layer.task_names, first_sample_probs))

        return {
            'fused_features': fused_features.cpu().numpy(),
            'task_routing': task_routing,
        }

# Usage
def main():
    """Run the processor once on simulated upstream outputs and print the result."""

    def simulated_text_record(preprocessed_text):
        # Shape of one record from the text pipeline, with a random
        # 512-dimensional embedding standing in for a real one.
        return {
            'preprocessed_text': preprocessed_text,
            'language': 'en',
            'embedding': torch.randn(512),
        }

    # Stand-in for the image pipeline's output.
    image_output = dict(
        object_detection=[[0, 0, 100, 100, 0.9, 1]],
        classification=5,
        classification_features=[0.1] * 768,  # Simulated feature vector
        ocr=['Hello', 'World'],
        caption='A computer screen displaying text',
    )

    # Stand-in for the text pipeline's output (OCR record first, then caption).
    text_output = {
        'ocr_processed': simulated_text_record('hello world'),
        'caption_processed': simulated_text_record('computer screen displaying text'),
    }

    result = MultimodalProcessor().process(image_output, text_output)

    fused_shape = result['fused_features'].shape
    print("Fused Features Shape:", fused_shape)
    print("\nTask Routing Probabilities:")
    for task_name, probability in result['task_routing'].items():
        print(f"{task_name}: {probability:.4f}")

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()