In [1]:
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer

import torch.nn as nn
import torch
import torch.nn as nn
from torchvision import models

import os
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
class TinyLLAVA(torch.nn.Module):
    """Minimal LLaVA-style model: frozen vision encoder + projection head + causal text decoder.

    The image is encoded into one feature vector, projected into the text
    decoder's token-embedding space, and prepended to the embedded prompt as a
    single extra "visual token" before the decoder runs.

    Args:
        vision_encoder: image model whose output feeds ``projection_head``.
            It is frozen (``requires_grad=False``) and kept in eval mode.
        projection_head: module mapping vision features to the decoder's
            embedding width.
        text_decoder: causal LM exposing ``transformer.wte`` (GPT-2 style) and
            accepting ``inputs_embeds`` / ``attention_mask`` keyword arguments.
        tokenizer: stored for callers; ``forward`` does not use it.
        max_seq_length: stored for reference; ``forward`` does not currently
            truncate sequences to this length.
        device: device all sub-modules are moved to and activations are placed on.
    """

    def __init__(self, vision_encoder, projection_head, text_decoder, tokenizer, max_seq_length=4096, device="cuda"):
        super().__init__()
        self.vision_encoder = vision_encoder
        self.projection_head = projection_head
        self.text_decoder = text_decoder
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        self.device = device

        self.vision_encoder.to(device)
        self.projection_head.to(device)
        self.text_decoder.to(device)

        # Freeze the vision encoder: no gradients, and eval mode so its
        # BatchNorm statistics / dropout behave as at inference time.
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        self.vision_encoder.eval()

    def train(self, mode=True):
        """Set train/eval mode, but always keep the frozen vision encoder in eval mode.

        ``torch.no_grad`` in ``forward`` stops gradients, but BatchNorm layers in
        training mode still update their running statistics (and dropout stays
        active), so the encoder must be pinned to eval mode to be truly frozen.

        Returns:
            ``self``, like ``torch.nn.Module.train``.
        """
        super().train(mode)
        self.vision_encoder.eval()
        return self

    def forward(self, image, input_ids, attention_mask):
        """Run the decoder on ``[visual token] + embedded input_ids``.

        Args:
            image: batch accepted by ``vision_encoder``.
            input_ids: token ids, shape ``(batch_size, seq_len)``.
            attention_mask: mask for ``input_ids``, shape ``(batch_size, seq_len)``.

        Returns:
            The text decoder's output object, with ``logits`` trimmed to
            ``(batch_size, seq_len, vocab_size)`` by dropping the visual-token position.
        """
        # Extract visual features (encoder is frozen).
        with torch.no_grad():
            visual_features = self.vision_encoder(image)  # Shape: (batch_size, vision_feature_dim)

        # Project visual features into the text embedding space.
        projected_features = self.projection_head(visual_features).to(self.device)

        # Embed input tokens with the decoder's own token-embedding table.
        token_embeddings = self.text_decoder.transformer.wte(input_ids).to(self.device)  # (batch_size, seq_len, embedding_dim)

        # Prepend the projected image as one extra token.
        combined_embeddings = torch.cat(
            [projected_features.unsqueeze(1), token_embeddings], dim=1
        ).to(self.device)

        # Mask column for the visual token, created with the text mask's
        # dtype/device so the concatenation does not mix float and int tensors.
        visual_mask = attention_mask.new_ones((attention_mask.size(0), 1))
        extended_attention_mask = torch.cat(
            [visual_mask, attention_mask], dim=1
        ).to(self.device)

        outputs = self.text_decoder(
            inputs_embeds=combined_embeddings,
            attention_mask=extended_attention_mask
        )
        # Drop the logits at the prepended visual-token position so the remaining
        # positions line up one-to-one with ``input_ids``.
        outputs.logits = outputs.logits[:, 1:, :]
        return outputs



In [3]:
device = "cpu"

# --- Vision encoder: MobileNetV3-Small student with a 768-d output layer ---
vision_encoder = models.mobilenet_v3_small()
vision_encoder.classifier[-1] = torch.nn.Linear(vision_encoder.classifier[-1].in_features, 768)

vision_encoder.load_state_dict(torch.load('./mobilenetv3_student_model.pth', map_location=torch.device('cpu'), weights_only=True))
vision_encoder.eval()

# Frozen: the encoder is only used for feature extraction.
for param in vision_encoder.parameters():
    param.requires_grad = False

print("Vision Encoder Ready")


# --- Text decoder: distilgpt2 with a 32000-way output head (LLaVA tokenizer vocabulary) ---
llm = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
# The head is attached after the `.to(device)` above, so move it explicitly.
llm.lm_head = nn.Linear(in_features=768, out_features=32000).to(device)  # llava out features

# Extend the positional-embedding table to 4096 positions.
llm.config.max_position_embeddings = 4096

old_embeddings = llm.transformer.wpe.weight.data  # (num_positions, hidden)
new_seq_length = llm.config.max_position_embeddings  # Desired sequence length

# Linearly interpolate along the position axis: (1, hidden, old_len) -> (1, hidden, new_len).
new_embeddings = torch.nn.functional.interpolate(
    old_embeddings.unsqueeze(0).transpose(1, 2),
    size=new_seq_length,
    mode="linear",
    align_corners=False,
).squeeze(0).transpose(1, 0)  # back to (new_len, hidden)

# Replace the module rather than only `weight.data`, so `num_embeddings`
# matches the new table size and the weight is contiguous. The state-dict key
# (`transformer.wpe.weight`) is unchanged, so checkpoints still load.
# NOTE(review): depending on the transformers version, GPT-2 attention may also
# hold a causal-mask buffer sized to the original n_positions; confirm that
# sequences longer than the original limit actually work.
llm.transformer.wpe = nn.Embedding.from_pretrained(new_embeddings.contiguous(), freeze=False).to(device)

# Verify changes
print(f"Updated max_position_embeddings: {llm.config.max_position_embeddings}")
print(f"Positional embeddings shape: {llm.transformer.wpe.weight.shape}")

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

print("LLM and Tokenizer Ready")

# --- Projection head: maps 768-d vision features into the decoder's embedding space ---
projection_head = nn.Linear(768, 768).to(device)
print("Projection Head Ready")

  return self.fget.__get__(instance, owner)()


Vision Encoder Ready




Updated max_position_embeddings: 4096
Positional embeddings shape: torch.Size([4096, 768])
LLM and Tokenizer Ready
Projection Head Ready


In [4]:
from transformers import AutoTokenizer

load_path = "./exported_llava_tokenizer"  # directory the LLaVA tokenizer was exported to
llava_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=load_path,
)
print("Tokenizer loaded successfully")

Tokenizer loaded successfully


In [5]:
# Assemble the full model from the components built above.
tiny_llava = TinyLLAVA(
    vision_encoder,
    projection_head,
    llm,
    llava_tokenizer,
    device=device,
)
tiny_llava = tiny_llava.to(device)

In [6]:
# Directory holding per-epoch training checkpoints; created if missing.
# `exist_ok=True` avoids the check-then-create race of an `os.path.exists` guard.
save_dir = "./LLAVA_KD_RESULTS"
os.makedirs(save_dir, exist_ok=True)

def find_latest_checkpoint(directory, prefix="tiny_llava_epoch_", suffix=".pth"):
    """Return the path of the highest-epoch checkpoint in ``directory``, or ``None``.

    Only files named ``<prefix><N><suffix>`` with a purely decimal ``N`` count.
    Stray files such as ``tiny_llava_epoch_best.pth`` are therefore ignored
    instead of making the epoch parse raise ``ValueError``.

    Args:
        directory: directory to scan (non-recursively).
        prefix: file-name prefix that precedes the epoch number.
        suffix: file-name suffix that follows the epoch number.

    Returns:
        ``os.path.join(directory, name)`` for the largest epoch found, or
        ``None`` when no file matches.
    """
    epochs = {}
    for name in os.listdir(directory):
        if not (name.startswith(prefix) and name.endswith(suffix)):
            continue
        epoch_str = name[len(prefix):len(name) - len(suffix)]
        if epoch_str.isdecimal():
            epochs[name] = int(epoch_str)
    if not epochs:
        return None
    return os.path.join(directory, max(epochs, key=epochs.get))

# Resume from the newest epoch checkpoint in `save_dir`, if one exists.
latest_checkpoint = find_latest_checkpoint(save_dir)
start_epoch = 0
if latest_checkpoint:
    print(f"Loading from checkpoint: {latest_checkpoint}")
    # The file is passed straight to `load_state_dict`, i.e. it is a mapping of
    # tensors, so `weights_only=True` suffices and avoids unpickling arbitrary
    # objects (matching the vision-encoder load earlier in the notebook).
    state_dict = torch.load(latest_checkpoint, map_location=torch.device('cpu'), weights_only=True)
    tiny_llava.load_state_dict(state_dict)
    # Names look like `tiny_llava_epoch_<N>.pth`; parse N from the base name only.
    start_epoch = int(os.path.basename(latest_checkpoint).split("_")[-1].split(".")[0])
    print(f"Resuming from epoch {start_epoch + 1}")

Loading from checkpoint: ./LLAVA_KD_RESULTS/tiny_llava_epoch_20.pth
Resuming from epoch 21


In [7]:
def vqa_pipeline(model, image_path, question, tokenizer, device="cpu"):
    """Run one image + question through ``model`` and decode its token predictions.

    This is a single teacher-forced forward pass, not autoregressive
    generation. The returned string is the argmax prediction at every position
    of the tokenized ``question``, decoded together.

    Args:
        model: module called as ``model(image=..., input_ids=..., attention_mask=...)``
            whose output has a ``logits`` attribute.
        image_path: path to an image file readable by PIL.
        question: prompt text.
        tokenizer: tokenizer used both to encode ``question`` and to decode predictions.
        device: device the image and token tensors are moved to.

    Returns:
        The decoded prediction string, with special tokens skipped.
    """
    transform = transforms.Compose([
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # `with` closes the underlying file; `.convert` returns an independent image.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)  # (1, 3, 336, 336)

    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Inference must not run with dropout / BatchNorm updates active, whatever
    # mode the caller left the model in; restore that mode afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model(image=image_tensor, input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits
    finally:
        model.train(was_training)

    predicted_ids = torch.argmax(logits, dim=-1)
    return tokenizer.decode(predicted_ids[0], skip_special_tokens=True)


In [8]:
# Example: ask the model a question about a sample image.
image_path = "./car.jpg"
question = "What is the object?"

answer = vqa_pipeline(tiny_llava, image_path, question, tiny_llava.tokenizer, device=device)

print(f"Question: {question}\nAnswer: {answer}")


Question: What is the object?
Answer: sierp sierp Rights meaning in



# Baseline CPU timing: CLIP ViT-B/32 and LLaMA-7B

In [9]:
import torch
import time
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Step 1: Load CLIP model and processor (timed with a monotonic high-resolution clock).
print("Loading CLIP model...")
start_time = time.perf_counter()
model_name = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(model_name).to('cpu')
processor = CLIPProcessor.from_pretrained(model_name)
load_time = time.perf_counter() - start_time
print(f"Model loaded in {load_time:.2f} seconds.")

# Step 2: Prepare a sample image/text pair for inference.
image_path = "./car.jpg"
image = Image.open(image_path).convert("RGB")
text = "a photo of a cat"

# Preprocess inputs
inputs = processor(text=[text], images=[image], return_tensors="pt", padding=True)
inputs = {key: val.to('cpu') for key, val in inputs.items()}

# Step 3: Measure average forward-pass latency. Gradients are disabled so
# autograd bookkeeping is not included in the measured time.
print("Running CLIP inference benchmark...")
num_runs = 100
total_time = 0.0
model.eval()

with torch.no_grad():
    for _ in range(num_runs):
        start_time = time.perf_counter()
        outputs = model(**inputs)
        total_time += time.perf_counter() - start_time

average_inference_time = total_time / num_runs
print(f"Average CLIP inference time over {num_runs} runs: {average_inference_time:.4f} seconds.")


Loading CLIP model...
Model loaded in 3.80 seconds.
Running CLIP inference benchmark...
Average CLIP inference time over 100 runs: 0.0690 seconds.


In [11]:
import torch
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

print("Loading LLaMA-7B model...")
start_time = time.perf_counter()
model_name = "huggyllama/llama-7b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu")  # Force CPU usage
load_time = time.perf_counter() - start_time
print(f"Model loaded in {load_time:.2f} seconds.")


text = "Once upon a time in a faraway land, there lived a wise old owl."  # Example input
inputs = tokenizer(text, return_tensors="pt").to("cpu")

# One untimed warm-up generation so first-call overhead is excluded from the
# timed runs. This replaces warming up by executing the cell twice.
model.generate(**inputs, max_length=50)

print("Running inference...")
num_runs = 10
inference_times = []

for _ in range(num_runs):
    start_time = time.perf_counter()
    # `max_length` counts the prompt tokens too, so at most
    # 50 - prompt_length new tokens are generated per run.
    outputs = model.generate(**inputs, max_length=50)
    inference_times.append(time.perf_counter() - start_time)

average_inference_time = sum(inference_times) / len(inference_times)
print(f"Average inference time over {num_runs} runs: {average_inference_time:.2f} seconds.")

generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated text:")
print(generated_text)


Loading LLaMA-7B model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Model loaded in 2.62 seconds.
Running inference...
Average inference time over 10 runs: 14.75 seconds.
Generated text:
Once upon a time in a faraway land, there lived a wise old owl. He was the wisest of all the owls in the land. He was also the oldest. He had lived a long time and had seen many things
