# Benchmarking TinyLlama for Inference

In [10]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F  # For cosine similarity
from torchvision import models

device = "cpu"

In [12]:
# TinyLLAVA model class
class TinyLLAVA(nn.Module):
    def __init__(self, vision_encoder, projection_head, text_decoder):
        super(TinyLLAVA, self).__init__()
        self.vision_encoder = vision_encoder
        self.projection_head = projection_head
        self.text_decoder = text_decoder

    def forward(self, image, input_ids, attention_mask):
        visual_features = self.vision_encoder(image)
        projected_features = self.projection_head(visual_features)
        outputs = self.text_decoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits


In [15]:
vision_encoder = models.mobilenet_v3_small()
vision_encoder.classifier[-1] = torch.nn.Linear(vision_encoder.classifier[-1].in_features, 768)

vision_encoder.eval()

for param in vision_encoder.parameters():
    param.requires_grad = False

print("Vision Encoder Ready")


Vision Encoder Ready


In [16]:
# Dummy projection head
projection_head = nn.Linear(768, 768)

# Load the text decoder and tokenizer
text_decoder = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

# Initialize TinyLLAVA
tiny_llava = TinyLLAVA(vision_encoder, projection_head, text_decoder).to(device)

# Dummy inputs for testing
dummy_image = torch.randn(1, 3, 224, 224).to(device)  # Single RGB image
dummy_input_ids = torch.randint(0, 50257, (1, 10)).to(device)  # Random token IDs
dummy_attention_mask = torch.ones_like(dummy_input_ids).to(device)


In [17]:
# Function to measure inference time
def measure_time(model, image, input_ids, attention_mask, iterations=10):
    model.eval()
    total_time = 0
    with torch.no_grad():
        for _ in range(iterations):
            start_time = time.time()
            _ = model(image, input_ids, attention_mask)
            total_time += time.time() - start_time
    avg_time = total_time / iterations
    return avg_time



In [18]:
# Function to evaluate cosine similarity
def evaluate_cosine_similarity(unquantized_model, quantized_model, iterations=10):
    unquantized_model.eval()
    quantized_model.eval()

    similarities = []

    with torch.no_grad():
        for i in range(iterations):
            # Randomize inputs for each iteration
            random_image = torch.randn(1, 3, 224, 224).to(device)
            random_input_ids = torch.randint(0, 50257, (1, 10)).to(device)
            random_attention_mask = torch.ones_like(random_input_ids).to(device)

            # Get model outputs
            unquantized_output = unquantized_model(random_image, random_input_ids, random_attention_mask)
            quantized_output = quantized_model(random_image, random_input_ids, random_attention_mask)

            # Flatten outputs for cosine similarity
            unquantized_output_flat = unquantized_output.view(-1)
            quantized_output_flat = quantized_output.view(-1)

            # Compute cosine similarity
            similarity = F.cosine_similarity(
                unquantized_output_flat.unsqueeze(0),
                quantized_output_flat.unsqueeze(0),
                dim=1
            ).item()

            similarities.append(similarity)
            print(f"Iteration {i + 1}: Cosine Similarity = {similarity:.4f}")

    avg_cosine_similarity = sum(similarities) / len(similarities)
    print(f"Average Cosine Similarity: {avg_cosine_similarity:.4f}")
    return avg_cosine_similarity


In [19]:
def apply_dynamic_quantization(model):
    return torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)


In [20]:
# Measure unquantized, uncompiled
print("Measuring unquantized, uncompiled model...")
unquantized_time = measure_time(tiny_llava, dummy_image, dummy_input_ids, dummy_attention_mask)
print(f"Unquantized, Uncompiled Avg Time: {unquantized_time:.4f} seconds")


Measuring unquantized, uncompiled model...
Unquantized, Uncompiled Avg Time: 0.5712 seconds


In [21]:
# Measure quantized, uncompiled
print("Measuring quantized, uncompiled model...")
quantized_tiny_llava = apply_dynamic_quantization(tiny_llava)
quantized_time = measure_time(quantized_tiny_llava, dummy_image, dummy_input_ids, dummy_attention_mask)
print(f"Quantized, Uncompiled Avg Time: {quantized_time:.4f} seconds")


Measuring quantized, uncompiled model...
Quantized, Uncompiled Avg Time: 0.1436 seconds


In [22]:
# Evaluate cosine similarity between unquantized and quantized models
print("\nEvaluating cosine similarity between unquantized and quantized models...")
avg_cosine_similarity = evaluate_cosine_similarity(
    unquantized_model=tiny_llava,
    quantized_model=quantized_tiny_llava,
    iterations=10
)
print(f"Average Cosine Similarity: {avg_cosine_similarity:.4f}")


Evaluating cosine similarity between unquantized and quantized models...
Iteration 1: Cosine Similarity = 0.9998
Iteration 2: Cosine Similarity = 0.9998
Iteration 3: Cosine Similarity = 0.9998
Iteration 4: Cosine Similarity = 0.9997
Iteration 5: Cosine Similarity = 0.9998
Iteration 6: Cosine Similarity = 0.9997
Iteration 7: Cosine Similarity = 0.9998
Iteration 8: Cosine Similarity = 0.9998
Iteration 9: Cosine Similarity = 0.9998
Iteration 10: Cosine Similarity = 0.9998
Average Cosine Similarity: 0.9998
Average Cosine Similarity: 0.9998


In [23]:
import torch

def trace_model(model, example_inputs):
    # Trace the model using torch.jit.trace
    traced_model = torch.jit.trace(model, example_inputs)
    return traced_model

# Example inputs: Adjust according to your model's input format
# Example inputs: Adjusting to your model's expected input
example_inputs = (dummy_image, dummy_input_ids, dummy_attention_mask)

# Trace the TinyLLaVA model
traced_tiny_llava = trace_model(tiny_llava, example_inputs)

# Measure time for the traced model (unquantized)
traced_unquantized_time = measure_time(traced_tiny_llava, dummy_image, dummy_input_ids, dummy_attention_mask)
print(f"Unquantized, Traced Avg Time: {traced_unquantized_time:.4f} seconds")

`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
  if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
  if past_key_values_length > 0:


Unquantized, Traced Avg Time: 0.1193 seconds


In [24]:
# Trace the TinyLLaVA model
traced_quantized_tiny_llava = trace_model(quantized_tiny_llava, example_inputs)

# Measure time for the traced model (quantized)
traced_quantized_time = measure_time(traced_quantized_tiny_llava, dummy_image, dummy_input_ids, dummy_attention_mask)
print(f"Quantized, Traced Avg Time: {traced_quantized_time:.4f} seconds")

Quantized, Traced Avg Time: 0.0735 seconds


In [32]:
def get_quantized_model_info(model, quantized_layers=[]):
    total_params = 0
    trainable_params = 0
    param_memory = 0
    
    for name, param in model.named_parameters():
        param_count = param.numel()
        total_params += param_count
        if param.requires_grad:
            trainable_params += param_count
        
        # Check if the layer is quantized
        if any(layer in name for layer in quantized_layers):
            # Assume int8 for quantized layers (1 byte per parameter)
            layer_memory = param_count * 1 / (1024 ** 2)
        else:
            # Default to float32 for other layers (4 bytes per parameter)
            layer_memory = param_count * 4 / (1024 ** 2)
        
        param_memory += layer_memory
    
    print(f"Model Total Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")
    print(f"Estimated Memory for Parameters: {param_memory:.2f} MB")
    
    return total_params, trainable_params, param_memory


In [34]:
quantized_layers = ["linear"]  # Adjust this based on your layer naming
get_quantized_model_info(tiny_llava, quantized_layers)

Model Total Parameters: 84,808,224
Trainable Parameters: 82,503,168
Estimated Memory for Parameters: 323.52 MB


(84808224, 82503168, 323.5177001953125)