# In [None]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from typing import Dict
import numpy as np

class SummarizationModule:
    """Abstractive summarizer built on a Hugging Face T5 checkpoint.

    Loads a ``T5Tokenizer`` / ``T5ForConditionalGeneration`` pair from
    ``model_name``. It exposes a routing-aware :meth:`process` entry point that
    only runs generation when the router's ``'Summarization'`` probability
    exceeds a threshold.
    """

    def __init__(self, model_name='t5-small', device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Load the tokenizer and model and put the model in inference mode.

        Args:
            model_name: Model id or local path passed to ``from_pretrained``.
            device: Torch device string that the model and the input ids are
                moved to. The default is evaluated once, when the class body
                runs.
        """
        self.device = device
        self.tokenizer = T5Tokenizer.from_pretrained(model_name)
        self.model = T5ForConditionalGeneration.from_pretrained(model_name).to(device)
        # eval() disables dropout. Gradient tracking is disabled per call by
        # @torch.no_grad() on generate_summary.
        self.model.eval()
        # NOTE: the model is deliberately NOT wrapped in torch.jit.script.
        # Hugging Face models are not TorchScript-scriptable. A ScriptModule
        # also would not expose GenerationMixin.generate(), which
        # generate_summary depends on.

    @torch.no_grad()
    def generate_summary(
        self,
        text: str,
        max_length: int = 150,
        *,
        num_beams: int = 4,
        no_repeat_ngram_size: int = 2,
        max_input_length: int = 512,
    ) -> str:
        """Generate an abstractive summary of ``text`` with beam search.

        Args:
            text: Source text. It is prefixed with the T5 task tag
                ``"summarize: "`` and truncated to ``max_input_length`` tokens.
            max_length: Maximum length of the generated sequence, in tokens.
            num_beams: Beam-search width.
            no_repeat_ngram_size: Forbid repeating any n-gram of this size in
                the output.
            max_input_length: Token limit applied to the encoded input.

        Returns:
            The decoded summary with special tokens removed.
        """
        input_ids = self.tokenizer.encode(
            "summarize: " + text,
            return_tensors="pt",
            max_length=max_input_length,
            truncation=True,
        ).to(self.device)

        summary_ids = self.model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            no_repeat_ngram_size=no_repeat_ngram_size,
            early_stopping=True,
        )

        # Only one input sequence was encoded, so take the first output row.
        return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def process(
        self,
        fused_features: np.ndarray,
        task_routing: Dict[str, float],
        original_text: str,
        *,
        threshold: float = 0.5,
    ) -> Dict:
        """Summarize ``original_text`` if routing favours summarization.

        Args:
            fused_features: Fused multimodal features. Accepted for interface
                compatibility with other task modules; not used here.
            task_routing: Mapping of task name to routing probability. A
                missing ``'Summarization'`` entry is treated as 0.
            original_text: Text to summarize.
            threshold: Summarization runs only when the probability is
                strictly greater than this value.

        Returns:
            ``{'summary': str | None, 'confidence': probability}``. The
            ``summary`` is ``None`` when summarization was skipped.
        """
        summarization_prob = task_routing.get('Summarization', 0)
        summary = (
            self.generate_summary(original_text)
            if summarization_prob > threshold
            else None
        )
        return {
            'summary': summary,
            'confidence': summarization_prob,
        }

class MultimodalProcessor:
    """Combine image- and text-pipeline outputs and dispatch to task modules.

    Holds a fusion layer, a task-routing layer and a SummarizationModule, all
    placed on ``device``.

    NOTE(review): ``MultimodalFusionLayer`` and ``TaskRoutingLayer`` are not
    defined in this cell. Presumably they come from an earlier notebook cell —
    confirm before running this file standalone.
    """
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """Build the fusion, routing and summarization components.

        Args:
            device: Torch device string shared by all components. The default
                is evaluated once, when the class body runs.
        """
        self.device = device
        self.fusion_layer = MultimodalFusionLayer().to(device)
        self.routing_layer = TaskRoutingLayer().to(device)
        self.summarization_module = SummarizationModule(device=device)

    @torch.no_grad()
    def process(self, image_output: Dict, text_output: Dict) -> Dict:
        """Fuse the pipeline outputs, route tasks and optionally summarize.

        Args:
            image_output: Output of the image pipeline. It is not read by the
                code visible here; presumably the elided fusion step consumes it.
            text_output: Output of the text pipeline. It must contain
                ``['ocr_processed']['preprocessed_text']`` and
                ``['caption_processed']['preprocessed_text']``.

        Returns:
            Dict with ``'fused_features'`` (``fused_features`` moved to CPU and
            converted with ``.numpy()``), ``'task_routing'`` (passed through
            unchanged) and ``'summarization'`` (the result of
            ``SummarizationModule.process``).
        """
        # ... (previous fusion and routing code) ...
        # NOTE(review): as written, `fused_features` and `task_routing` are never
        # assigned in this method, so calling it raises NameError. The elided
        # code above must define both. `fused_features` is presumably a
        # torch.Tensor (`.cpu()` is called on it). `task_routing` must support
        # `.get('Summarization', 0)`. Confirm against the fusion/routing layers.

        # Process with Summarization module
        # Summarization input: OCR text and caption text joined by one space.
        original_text = text_output['ocr_processed']['preprocessed_text'] + ' ' + text_output['caption_processed']['preprocessed_text']
        summarization_result = self.summarization_module.process(fused_features.cpu().numpy(), task_routing, original_text)

        return {
            'fused_features': fused_features.cpu().numpy(),
            'task_routing': task_routing,
            'summarization': summarization_result
        }

# Usage
def main():
    """Run the processor on hard-coded sample inputs and print the results.

    The dictionaries below stand in for the outputs of the upstream image and
    text pipelines. The text embeddings come from ``torch.randn``, so the
    printed features and routing probabilities may differ between runs.
    """
    # Simulated outputs from previous pipelines
    image_output = {
        'object_detection': [[0, 0, 100, 100, 0.9, 1]],
        'classification': 5,
        'classification_features': [0.1] * 768,  # Simulated feature vector
        'ocr': ['Hello', 'World'],
        'caption': 'A computer screen displaying text'
    }

    text_output = {
        'ocr_processed': {
            'preprocessed_text': 'hello world',
            'language': 'en',
            'embedding': torch.randn(512)
        },
        'caption_processed': {
            'preprocessed_text': 'computer screen displaying text',
            'language': 'en',
            'embedding': torch.randn(512)
        }
    }

    processor = MultimodalProcessor()
    result = processor.process(image_output, text_output)

    print("Fused Features Shape:", result['fused_features'].shape)
    print("\nTask Routing Probabilities:")
    for task, prob in result['task_routing'].items():
        print(f"{task}: {prob:.4f}")

    summarization = result['summarization']
    print("\nSummarization Result:")
    print(f"Confidence: {summarization['confidence']:.4f}")
    # Compare against None explicitly. SummarizationModule.process sets
    # summary=None only when summarization was skipped; a generated summary
    # can be an empty string and must not be reported as "below threshold".
    if summarization['summary'] is not None:
        print(f"Summary: {summarization['summary']}")
    else:
        print("No summary generated (confidence below threshold)")

if __name__ == "__main__":
    main()