In [1]:
import torch
from torchvision import transforms
from PIL import Image
from transformers import AutoTokenizer

import torch.nn as nn
import torch
import torch.nn as nn
from torchvision import models

import os
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
class TinyLLAVA(torch.nn.Module):
    def __init__(self, vision_encoder, projection_head, text_decoder, tokenizer, max_seq_length=4096, device="cuda"):
        super(TinyLLAVA, self).__init__()
        self.vision_encoder = vision_encoder
        self.projection_head = projection_head
        self.text_decoder = text_decoder
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        self.device = device

        self.vision_encoder.to(device)
        self.projection_head.to(device)
        self.text_decoder.to(device)
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def forward(self, image, input_ids, attention_mask):
        # Extract visual features
        with torch.no_grad():
            visual_features = self.vision_encoder(image)  # Shape: (batch_size, vision_feature_dim)
    
        # Project visual features to text embedding space
        projected_features = self.projection_head(visual_features).to(self.device)  # Move to the same device
    
        # Embed input tokens
        token_embeddings = self.text_decoder.transformer.wte(input_ids).to(self.device)  # Shape: (batch_size, seq_len, embedding_dim)
    
        # Combine visual features with token embeddings
        combined_embeddings = torch.cat(
            [projected_features.unsqueeze(1), token_embeddings], dim=1
        ).to(self.device)  # Shape: (batch_size, seq_len + 1, embedding_dim)
    
        # Adjust attention mask to include visual tokens
        _ones = torch.ones((attention_mask.size(0), 1), dtype=attention_mask.dtype).to(self.device)
        extended_attention_mask = torch.cat(
            [_ones, attention_mask], dim=1
        ).to(self.device)  # Shape: (batch_size, seq_len + 1)
    
        # Truncate combined embeddings and attention mask to max_seq_length if needed
        if combined_embeddings.size(1) > self.max_seq_length:
            combined_embeddings = combined_embeddings[:, :self.max_seq_length]
            extended_attention_mask = extended_attention_mask[:, :self.max_seq_length]
    
        # Forward pass through the text decoder
        outputs = self.text_decoder(
            inputs_embeds=combined_embeddings,
            attention_mask=extended_attention_mask
        )
    
        # Remove the first token (vision embedding) from logits during output processing
        outputs.logits = outputs.logits[:, 1:, :]  # Shape: (batch_size, seq_len, vocab_size)

        return outputs




In [3]:
device = "cuda"
vision_encoder = models.mobilenet_v3_small()
vision_encoder.classifier[-1] = torch.nn.Linear(vision_encoder.classifier[-1].in_features, 768)

vision_encoder.load_state_dict(torch.load('./mobilenetv3_student_model.pth', map_location=torch.device(device), weights_only=True))
vision_encoder.eval()

for param in vision_encoder.parameters():
    param.requires_grad = False

print("Vision Encoder Ready")


llm = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
llm.lm_head = nn.Linear(in_features=768, out_features=32000) # llava out features

# Set the new max position embeddings
llm.config.max_position_embeddings = 4096  # Update the max position embeddings

# Resize positional embeddings (wpe) to match new max_position_embeddings
old_embeddings = llm.transformer.wpe.weight.data  # Original embeddings
new_seq_length = llm.config.max_position_embeddings  # Desired sequence length

# Interpolate to resize
new_embeddings = torch.nn.functional.interpolate(
    old_embeddings.unsqueeze(0).transpose(1, 2),  # Add batch dimension for interpolation
    size=new_seq_length,  # New sequence length
    mode="linear",
    align_corners=False,
).squeeze(0).transpose(1, 0)  # Remove batch dimension and revert dimensions

# Update the embeddings in the model
llm.transformer.wpe.weight.data = new_embeddings

# Verify changes
print(f"Updated max_position_embeddings: {llm.config.max_position_embeddings}")
print(f"Positional embeddings shape: {llm.transformer.wpe.weight.shape}")

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

print("LLM and Tokenizer Ready")

projection_head = nn.Linear(768, 768).to(device)
print("Projection Head Ready")

  return self.fget.__get__(instance, owner)()


Vision Encoder Ready




Updated max_position_embeddings: 4096
Positional embeddings shape: torch.Size([4096, 768])
LLM and Tokenizer Ready
Projection Head Ready


In [4]:
from transformers import AutoTokenizer

# Path to the saved tokenizer
load_path = "./exported_llava_tokenizer"

# Load the tokenizer
llava_tokenizer = AutoTokenizer.from_pretrained(load_path)

print("Tokenizer loaded successfully")

Tokenizer loaded successfully


In [5]:
tiny_llava = TinyLLAVA(vision_encoder, projection_head, llm, llava_tokenizer, device=device).to(device)

In [6]:
save_dir = "./LLAVA_KD_RESULTS"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

def find_latest_checkpoint(directory):
    checkpoints = [f for f in os.listdir(directory) if f.startswith("tiny_llava_epoch_") and f.endswith(".pth")]
    if not checkpoints:
        return None
    checkpoints.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
    return os.path.join(directory, checkpoints[-1])

# Load the latest checkpoint if it exists
latest_checkpoint = find_latest_checkpoint(save_dir)
start_epoch = 0
if latest_checkpoint:
    print(f"Loading from checkpoint: {latest_checkpoint}")
    tiny_llava.load_state_dict(torch.load(latest_checkpoint, map_location=torch.device('cpu')))
    start_epoch = int(latest_checkpoint.split("_")[-1].split(".")[0])
    print(f"Resuming from epoch {start_epoch + 1}")

Loading from checkpoint: ./LLAVA_KD_RESULTS/tiny_llava_epoch_20.pth
Resuming from epoch 21


In [7]:
def vqa_pipeline(model, image_path, question, tokenizer, device=device):
    # Load and preprocess image
    transform = transforms.Compose([
        transforms.Resize((336, 336)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)
    
    # Tokenize the question
    inputs = tokenizer(question, return_tensors="pt", padding=True, truncation=True).to(device)
    input_ids = inputs["input_ids"]
    attention_mask = inputs["attention_mask"]

    # Forward pass
    with torch.no_grad():
        outputs = model(image=image_tensor, input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

    # Decode the output
    predicted_ids = torch.argmax(logits, dim=-1)
    answer = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)
    
    return answer


In [8]:
# Example usage
image_path = "./car.jpg"
question = "What is the object?"

answer = vqa_pipeline(
    model=tiny_llava,
    image_path=image_path,
    question=question,
    tokenizer=tiny_llava.tokenizer,
    device=device
)

print(f"Question: {question}")
print(f"Answer: {answer}")


Question: What is the object?
Answer: sierp sierp Rights meaning in



# Downstream

In [9]:
import torch
import os
from torch.utils.data import DataLoader, Dataset
from transformers import AutoProcessor
from PIL import Image
from torchvision import transforms
from datasets import load_dataset

print("Downloading COCO subset...")
coco_subset = load_dataset("phiyodr/coco2017", split="train")

transform = transforms.Compose([
            transforms.Resize((336, 336)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

class CocoSubsetDataset(Dataset):
    def __init__(self, dataset, transform=None, root="./coco", max_samples=None):
        self.dataset = dataset
        self.transform = transform
        self.root = root
        
        # Limit the number of samples if max_samples is specified
        if max_samples is not None:
            self.dataset = self.dataset.select(range(min(max_samples, len(dataset))))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        file = os.path.join(self.root, item["file_name"])
        image = Image.open(file).convert("RGB")  # Load image
        caption = item["captions"][0]  # Caption for the image
        if self.transform:
            image = self.transform(image)
        return image, "<image>\n"+caption

max_samples = 10000  # Change this to the number of samples you want
custom_dataset = CocoSubsetDataset(coco_subset, transform, max_samples=max_samples)

batch_size = 4
dataloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

for idx, (images, captions) in enumerate(dataloader):
    print(f"Batch {idx + 1}")
    print("Images shape:", images.shape)  # [batch_size, 3, 224, 224]
    print("Captions:", captions)  # List of captions
    break  # Test with one batch

Downloading COCO subset...
Batch 1
Images shape: torch.Size([4, 3, 336, 336])
Captions: ('<image>\nThe dog is swimming in the water with his Frisbee in his mouth. ', '<image>\nTHERE ARE MOTOR BIKES THAT ARE PARKED ON THE STREET', '<image>\nA black duck floating in a wavy pond.', '<image>\nA traffic light turns green on the corner of a city street.')


In [10]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tiny_llava.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = AdamW(tiny_llava.parameters(), lr=5e-5)

log_file = "training_log.txt"
with open(log_file, "w") as log:
    log.write("Training Started\n")

def train_one_epoch(model, dataloader, optimizer, loss_fn, epoch, device, log_file):
    model.train()
    epoch_loss = 0

    # Progress bar setup
    with tqdm(dataloader, desc=f"Epoch {epoch + 1}", unit="batch") as progress_bar:
        for batch_idx, (images, captions) in enumerate(progress_bar):
            # Preprocess inputs
            images = images.to(device)
            captions = list(captions)
            # print(images.shape)

            # Tokenize captions
            inputs = model.tokenizer(
                captions,
                return_tensors="pt",
                padding=True,
                truncation=True
            )
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            
            labels = input_ids.clone()
            
            # Forward pass
            outputs = model(images, input_ids, attention_mask)
            logits = outputs.logits
            
            # Shift labels and logits to predict next token
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            # print(logits.shape, labels.shape)
            # # Calculate loss
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
            epoch_loss += loss.item()

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            progress_bar.set_postfix(loss=loss.item())

            if (batch_idx + 1) % 10 == 0:
                with open(log_file, "a") as log:
                    log.write(f"Epoch {epoch + 1}, Batch {batch_idx + 1}/{len(dataloader)}, Loss: {loss.item():.4f}\n")

    avg_loss = epoch_loss / len(dataloader)
    with open(log_file, "a") as log:
        log.write(f"Epoch {epoch + 1} Average Loss: {avg_loss:.4f}\n")
    return avg_loss

epochs = 3
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    avg_loss = train_one_epoch(tiny_llava, dataloader, optimizer, loss_fn, epoch, device, log_file)

    # Save checkpoint
    checkpoint_path = os.path.join(checkpoint_dir, f"tiny_llava_epoch_{epoch + 1}.pt")
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": tiny_llava.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_loss,
    }, checkpoint_path)

    with open(log_file, "a") as log:
        log.write(f"Checkpoint saved at {checkpoint_path}\n")

Epoch 1: 100%|██████████| 2500/2500 [03:16<00:00, 12.71batch/s, loss=0.163]  
Epoch 2: 100%|██████████| 2500/2500 [02:55<00:00, 14.22batch/s, loss=0.00295] 
Epoch 3: 100%|██████████| 2500/2500 [02:56<00:00, 14.15batch/s, loss=0.000628]


In [11]:
save_dir = "./LLAVA_KD_RESULTS"
checkpoint_path = os.path.join(save_dir, f"downstream_tiny_llava_epoch_{epoch + start_epoch + 1}.pth")
torch.save(tiny_llava.state_dict(), checkpoint_path)
with open(log_file, "a") as log:
    log.write(f"Model checkpoint saved to {checkpoint_path}\n")


# Eval

In [12]:
import torch
import os
from torch.utils.data import DataLoader, Dataset
from transformers import AutoProcessor
from PIL import Image
from torchvision import transforms
from datasets import load_dataset

print("Downloading COCO subset...")
coco_subset = load_dataset("phiyodr/coco2017", split="train")

transform = transforms.Compose([
            transforms.Resize((336, 336)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

class CocoSubsetDataset(Dataset):
    def __init__(self, dataset, transform=None, root="./coco", max_samples=None):
        self.dataset = dataset
        self.transform = transform
        self.root = root
        
        if max_samples is not None:
            self.dataset = self.dataset.select(range(max_samples, max_samples*2))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        file = os.path.join(self.root, item["file_name"])
        image = Image.open(file).convert("RGB")  # Load image
        caption = item["captions"][0]  # Caption for the image
        if self.transform:
            image = self.transform(image)
        return image, "<image>\n"+caption

# Ensure next 10k samples are selected for evaluation
max_samples = 10000
custom_dataset = CocoSubsetDataset(coco_subset, transform, max_samples=max_samples)

batch_size = 4
dataloader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)

for idx, (images, captions) in enumerate(dataloader):
    print(f"Batch {idx + 1}")
    print("Images shape:", images.shape)  # [batch_size, 3, 224, 224]
    print("Captions:", captions)  # List of captions
    break  # Test with one batch

Downloading COCO subset...
Batch 1
Images shape: torch.Size([4, 3, 336, 336])
Captions: ('<image>\nA silver spoon sitting on top of an unknown surface.', '<image>\nA young boy and girl sitting and eating at a childs table with a dog nearby.', '<image>\nA woman with a tennis ball is next to a child', '<image>\nA street sign marking the intersection of Roberts and Cedar Streets')


In [13]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from tqdm import tqdm
import numpy as np
import json
from PIL import Image


def evaluate(model, dataloader, loss_fn, device, tokenizer, num_samples=5):
    model.eval()
    total_loss = 0
    all_predictions = []
    all_references = []
    
    with torch.no_grad():
        for images, captions in tqdm(dataloader, desc="Evaluating"):
            images = images.to(device)
            captions = list(captions)
            
            inputs = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
            input_ids = inputs["input_ids"].to(device)
            attention_mask = inputs["attention_mask"].to(device)
            
            outputs = model(images, input_ids, attention_mask)
            logits = outputs.logits
            
            # Calculate loss
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), input_ids.reshape(-1))
            total_loss += loss.item()
            
            # Generate predictions
            predictions = torch.argmax(logits, dim=-1)
            all_predictions.extend(tokenizer.batch_decode(predictions, skip_special_tokens=True))
            all_references.extend(captions)
    
    # Calculate perplexity
    avg_loss = total_loss / len(dataloader)
    perplexity = np.exp(avg_loss)
    
    # Sample outputs
    sample_outputs = list(zip(all_references, all_predictions))[:num_samples]
    
    return perplexity, sample_outputs


loss_fn = nn.CrossEntropyLoss()
perplexity, sample_outputs = evaluate(tiny_llava, dataloader, loss_fn, device, llava_tokenizer)

log_file = "evaluation_log.txt"
with open(log_file, "w") as log:
    log.write(f"Evaluation Results:\n")
    log.write(f"Perplexity: {perplexity:.4f}\n\n")
    log.write("Sample Outputs:\n")
    for reference, prediction in sample_outputs:
        log.write(f"Reference: {reference}\n")
        log.write(f"Prediction: {prediction}\n\n")
        print(f"Reference: {reference}\n")
        print(f"Prediction: {prediction}\n\n")

print(f"Evaluation complete. Results saved to {log_file}")

Evaluating: 100%|██████████| 2500/2500 [02:45<00:00, 15.14it/s]

Reference: <image>
A living room area that has a couch, table, and lots of photos on the wall.

Prediction: <image>
A living room area that has a couch, table, and lots of photos on the wall.


Reference: <image>
A group of young people walking across a snow covered field.

Prediction: <image>
A group of young people walking across a snow covered field.


Reference: <image>
A small wooden cutting board and knife with a cut apple.

Prediction: <image>
A small wooden cutting board and knife with a cut apple.


Reference: <image>
An old woman sits on a bench and raises her hand

Prediction: <image>
An old woman sits on a bench and raises her hand


Reference: <image>
A woman in a wet suit surfs a wave.

Prediction: <image>
A woman in a wet suit surfs a wave.


Evaluation complete. Results saved to evaluation_log.txt





In [14]:
# Evaluate
perplexity, sample_outputs = evaluate(tiny_llava, dataloader, loss_fn, device, llava_tokenizer)

# Log results
log_file = "evaluation_log.txt"
with open(log_file, "w") as log:
    log.write(f"Evaluation Results:\n")
    log.write(f"Perplexity: {perplexity:.4f}\n\n")
    log.write("Sample Outputs:\n")
    for reference, prediction in sample_outputs:
        log.write(f"Reference: {reference}\n")
        log.write(f"Prediction: {prediction}\n\n")
        print(f"Reference: {reference}\n")
        print(f"Prediction: {prediction}\n\n")


Evaluating: 100%|██████████| 2500/2500 [02:06<00:00, 19.82it/s]

Reference: <image>
A group of giraffe eating food from a tree.

Prediction: <image>
A group of giraffe eating food from a tree.


Reference: <image>
A bunk bed sits next to an open window. 

Prediction: <image>
A bunk bed sits next to an open window. 


Reference: <image>
A group of cars, riding down the street.

Prediction: <image>
A group of cars, riding down the street.


Reference: <image>
A zebra is standing away from two adult giraffe.

Prediction: <image>
A zebra is standing away from two adult giraffe.


Reference: <image>
A bed with a wooden headboard next to a window.

Prediction: <image>
A bed with a wooden headboard next to a window.





