In [5]:
import torch
import torch.nn as nn
import time
from transformers import CLIPModel, AutoModelForCausalLM, AutoTokenizer

# Vision Encoder (Using CLIP)
class VisionEncoder(nn.Module):
    """Wrap a pretrained CLIP model and expose only its image-feature path.

    Args:
        clip_model_name: Checkpoint identifier handed to
            ``CLIPModel.from_pretrained``.
    """

    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)

    def forward(self, images):
        """Return the result of ``clip_model.get_image_features`` for ``images``."""
        return self.clip_model.get_image_features(pixel_values=images)

# Language Decoder (Using a Pretrained Language Model)
class LanguageDecoder(nn.Module):
    """Pretrained causal language model used as a per-token feature extractor.

    The matching tokenizer is stored on ``self.tokenizer``. If that tokenizer
    has no padding token, a ``[PAD]`` token is registered and the model's
    embedding table is resized to the new vocabulary length.

    Args:
        language_model_name: Checkpoint identifier handed to both
            ``AutoModelForCausalLM.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, language_model_name="gpt2"):
        super().__init__()
        self.language_model = AutoModelForCausalLM.from_pretrained(language_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_name)

        # Batched tokenization with padding=True needs a pad token; when one is
        # added, the embedding matrix must grow to cover the new token id.
        tokenizer_lacks_pad = self.tokenizer.pad_token is None
        if tokenizer_lacks_pad:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.language_model.resize_token_embeddings(len(self.tokenizer))

    def forward(self, text_inputs, attention_mask=None):
        """Return the final entry of the model's ``hidden_states`` output.

        Args:
            text_inputs: Token ids, passed to the model as ``input_ids``.
            attention_mask: Optional mask forwarded unchanged to the model.

        Returns:
            The last element of ``hidden_states``; the original code documents
            its shape as (batch_size, sequence_length, hidden_size).
        """
        model_output = self.language_model(
            input_ids=text_inputs,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return model_output.hidden_states[-1]

# Additional Layers
class IntermediateLayer(nn.Module):
    """Single affine projection followed by a ReLU non-linearity.

    Args:
        input_dim: Size of the last dimension of the incoming tensor.
        output_dim: Size of the last dimension of the produced tensor.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        """Apply ``self.linear`` and then ``self.activation`` to ``x``."""
        projected = self.linear(x)
        return self.activation(projected)

# Combined Vision-Language Model
class LlavaModel(nn.Module):
    """Fuse one image embedding and one pooled text embedding per example.

    The output of ``vision_encoder`` is concatenated with the mean of the
    per-token hidden states from ``language_decoder``, passed through a stack
    of equally sized ``IntermediateLayer`` blocks, and finally projected to the
    language model's hidden size.

    Args:
        vision_encoder: Module exposing ``clip_model.config.projection_dim``
            whose forward pass returns one feature vector per image.
        language_decoder: Module exposing ``language_model.config.n_embd``
            whose forward pass returns per-token hidden states.
        intermediate_layers: Number of ``IntermediateLayer`` blocks placed
            between the concatenation and the final projection.
    """

    def __init__(self, vision_encoder, language_decoder, intermediate_layers=3):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.language_decoder = language_decoder

        # Read the feature widths from the wrapped models' configs so the
        # fusion stack matches whatever checkpoints were loaded.
        vision_output_size = self.vision_encoder.clip_model.config.projection_dim
        language_output_size = self.language_decoder.language_model.config.n_embd
        combined_size = vision_output_size + language_output_size

        # Every intermediate layer is square (combined_size -> combined_size),
        # which is what makes removing any one of them shape-safe.
        self.intermediate_layers = nn.ModuleList([
            IntermediateLayer(combined_size, combined_size)
            for _ in range(intermediate_layers)
        ])

        self.fusion_layer = nn.Linear(combined_size, language_output_size)

    def forward(self, images, text_inputs, attention_mask=None):
        """Compute the fused representation for a batch.

        Args:
            images: Image batch passed to ``vision_encoder``.
            text_inputs: Token ids passed to ``language_decoder``.
            attention_mask: Optional (batch_size, sequence_length) mask with
                non-zero entries for real tokens and zeros for padding. When
                given, it is forwarded to the language decoder and padding
                positions are excluded from the text mean. When omitted, the
                behaviour is the plain mean over all positions.

        Returns:
            Tensor of shape (batch_size, language_output_size).
        """
        # Shape: (batch_size, vision_output_size)
        vision_features = self.vision_encoder(images)

        # Shape: (batch_size, sequence_length, language_output_size)
        if attention_mask is None:
            text_features = self.language_decoder(text_inputs)
            text_mean = text_features.mean(dim=1)
        else:
            text_features = self.language_decoder(text_inputs, attention_mask=attention_mask)
            # Masked mean: padding tokens must not dilute the pooled vector.
            mask = attention_mask.unsqueeze(-1).to(
                device=text_features.device, dtype=text_features.dtype
            )
            # clamp guards against division by zero for an all-padding row.
            token_counts = mask.sum(dim=1).clamp(min=1)
            text_mean = (text_features * mask).sum(dim=1) / token_counts

        # Shape: (batch_size, vision_output_size + language_output_size)
        combined_features = torch.cat([vision_features, text_mean], dim=-1)

        for layer in self.intermediate_layers:
            combined_features = layer(combined_features)

        # Shape: (batch_size, language_output_size)
        return self.fusion_layer(combined_features)

    def prune_largest_flop_layer(self):
        """Remove the intermediate layer with the largest per-sample linear cost.

        Each layer's cost is ``in_features * out_features`` of its own
        ``linear`` module. On ties the earliest layer is removed. Calling this
        when no intermediate layers remain is a no-op.
        """
        if len(self.intermediate_layers) == 0:
            return

        layer_flops = [
            layer.linear.in_features * layer.linear.out_features
            for layer in self.intermediate_layers
        ]
        # list.index returns the first maximum, so ties resolve to the
        # earliest layer.
        max_layer_index = layer_flops.index(max(layer_flops))

        self.intermediate_layers = nn.ModuleList([
            layer
            for i, layer in enumerate(self.intermediate_layers)
            if i != max_layer_index
        ])

# Function to calculate time
def calculate_time(model, inputs):
    """Time a single gradient-free call of ``model(*inputs)``.

    Uses ``time.perf_counter`` rather than ``time.time``: it is monotonic and
    has the highest available resolution, so the reading cannot be distorted
    by system clock adjustments during the measurement.

    Args:
        model: Callable to invoke (typically an ``nn.Module``).
        inputs: Iterable of positional arguments unpacked into the call.

    Returns:
        Elapsed seconds as a float. The call's output is discarded.
    """
    start_time = time.perf_counter()
    with torch.no_grad():
        model(*inputs)
    return time.perf_counter() - start_time

# Function to calculate FLOPs manually
def calculate_flops(model, inputs):
    """Return a rough, hand-written FLOP estimate for one call of ``model``.

    The figures are coarse heuristics, not measured operation counts.

    Args:
        model: A ``VisionEncoder``, ``LanguageDecoder`` or ``LlavaModel``.
        inputs: Tuple of input tensors in the order the model's forward pass
            takes them. ``inputs[0].shape[0]`` is used as the batch size.

    Returns:
        Integer estimate; 0 when ``model`` is none of the recognised types.
    """
    batch_size = inputs[0].shape[0]

    def _vision_estimate(encoder, images):
        # Rough estimate: projection width times the image's last two dims.
        projection_dim = encoder.clip_model.config.projection_dim
        return batch_size * projection_dim * images.shape[2] * images.shape[3]

    def _language_estimate(decoder, token_ids):
        # Factor 2 stands in for the self-attention and feed-forward parts.
        sequence_length = token_ids.shape[1]
        embed_dim = decoder.language_model.config.n_embd
        return batch_size * sequence_length * embed_dim * 2

    if isinstance(model, VisionEncoder):
        return _vision_estimate(model, inputs[0])

    if isinstance(model, LanguageDecoder):
        return _language_estimate(model, inputs[0])

    if isinstance(model, LlavaModel):
        projection_dim = model.vision_encoder.clip_model.config.projection_dim
        embed_dim = model.language_decoder.language_model.config.n_embd
        fused_width = projection_dim + embed_dim

        # One square linear map per intermediate layer still present.
        stack_cost = sum(
            batch_size * fused_width * fused_width
            for _ in model.intermediate_layers
        )
        # Final projection from the fused width down to the language width.
        projection_cost = batch_size * fused_width * embed_dim

        return (
            _vision_estimate(model.vision_encoder, inputs[0])
            + _language_estimate(model.language_decoder, inputs[1])
            + stack_cost
            + projection_cost
        )

    return 0

# Instantiate, Simulate, and Evaluate
if __name__ == "__main__":
    # Build the two pretrained components, then wrap them in the fused model
    # with three intermediate layers.
    vision_encoder = VisionEncoder()
    language_decoder = LanguageDecoder()
    llava_model = LlavaModel(vision_encoder, language_decoder, intermediate_layers=3)

    # Small simulated batch: 2 random RGB images at 224x224 plus two prompts.
    dummy_images = torch.rand((2, 3, 224, 224))
    dummy_text = ["Describe this image.", "What do you see in the picture?"]

    # Larger simulated test batch: 10 random images, 10 short captions and
    # random binary labels.
    test_images = torch.rand((10, 3, 224, 224))
    test_text = ["Test image {}".format(i) for i in range(10)]
    test_labels = torch.randint(0, 2, (10,))

    # Tokenize both text batches with the decoder's own tokenizer.
    tokenizer = language_decoder.tokenizer
    encoded_text = tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True)
    encoded_test_text = tokenizer(test_text, return_tensors="pt", padding=True, truncation=True)

    # Argument tuples shared by the timing and FLOP helpers.
    vision_inputs = (test_images,)
    language_inputs = (encoded_test_text.input_ids,)
    full_inputs = (test_images, encoded_test_text.input_ids)

    # Vision encoder on its own.
    vision_time = calculate_time(vision_encoder, vision_inputs)
    vision_flops = calculate_flops(vision_encoder, vision_inputs)
    print(f"Vision Encoder: Time = {vision_time:.4f}s, FLOPs = {vision_flops}")

    # Language decoder on its own.
    lang_time = calculate_time(language_decoder, language_inputs)
    lang_flops = calculate_flops(language_decoder, language_inputs)
    print(f"Language Decoder: Time = {lang_time:.4f}s, FLOPs = {lang_flops}")

    # Full fused model before pruning.
    full_time = calculate_time(llava_model, full_inputs)
    full_flops = calculate_flops(llava_model, full_inputs)
    print(f"Full Model: Time = {full_time:.4f}s, FLOPs = {full_flops}")

    # Drop one intermediate layer in place, then measure the same model again.
    llava_model.prune_largest_flop_layer()

    pruned_time = calculate_time(llava_model, full_inputs)
    pruned_flops = calculate_flops(llava_model, full_inputs)
    print(f"Pruned Model: Time = {pruned_time:.4f}s, FLOPs = {pruned_flops}")


Vision Encoder: Time = 0.5615s, FLOPs = 256901120
Language Decoder: Time = 0.0687s, FLOPs = 46080
Full Model: Time = 0.6542s, FLOPs = 315929600
Pruned Model: Time = 0.5701s, FLOPs = 299545600
