In [1]:
import torch
import torch.nn as nn
from torchvision import models

import os
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load Instruction Fine Tuning Data

In [2]:
import json

with open('./llava_instruct_150k.json', 'r') as f:
    data = json.load(f)

In [3]:
import os
import json
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

class LLaVADataset(Dataset):
    def __init__(self, json_path, image_dir, max_length=142, limit=None):
        with open(json_path, 'r') as f:
            self.data = json.load(f)
        
        # Limit the dataset to the specified number of samples
        if limit is not None:
            self.data = self.data[10000:10000+limit]
        
        self.image_dir = image_dir
        self.max_length = max_length
        self.image_transform = transforms.Compose([
            transforms.Resize((336, 336)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = os.path.join(self.image_dir, item['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.image_transform(image)

        # Combine all conversation texts into a single string
        conversation_text = ""
        for conv in item['conversations']:
            if conv['from'] == 'human':
                conversation_text += f"Human: {conv['value']}\n"
            elif conv['from'] == 'gpt':
                conversation_text += f"Assistant: {conv['value']}\n"

        return {
            'image': image,
            'conversation_text': conversation_text
        }

# Usage
json_path = './llava_instruct_150k.json'
image_dir = './coco/train2017/'
tiny_dataset = LLaVADataset(json_path, image_dir, limit=5000)
dataloader = DataLoader(tiny_dataset, batch_size=8, shuffle=False)

# Teacher Model

In [5]:
from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path
from llava.eval.run_llava import eval_model

model_path = "liuhaotian/llava-v1.5-7b"

llava_tokenizer, llava_model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    load_in_8bit=True
)
device = "cuda"
llava_model.model = llava_model.model.to(device)
llava_model.model = llava_model.model.to(torch.float16)

for param in llava_model.parameters():
    param.requires_grad = False

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Save the tokenizer to a directory
save_path = "./exported_llava_tokenizer"
llava_tokenizer.save_pretrained(save_path)

print(f"Tokenizer saved to {save_path}")

Tokenizer saved to ./exported_llava_tokenizer


In [7]:
import os
print(os.listdir(save_path))

['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']


# Student Model

In [8]:
class TinyLLAVA(nn.Module):
    def __init__(self, vision_encoder, projection_head, text_decoder, tokenizer, max_seq_length=4096, device="cuda"):
        super(TinyLLAVA, self).__init__()
        self.vision_encoder = vision_encoder
        self.projection_head = projection_head
        self.text_decoder = text_decoder
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

        self.device = device

        self.vision_encoder.to(device)
        self.projection_head.to(device)
        self.text_decoder.to(device)
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

    def forward(self, image, input_ids, attention_mask):
        # Extract visual features
        with torch.no_grad():
            visual_features = self.vision_encoder(image)  # Shape: (batch_size, vision_feature_dim)
    
        # Project visual features to text embedding space
        projected_features = self.projection_head(visual_features).to(self.device)  # Move to the same device
    
        # Embed input tokens
        token_embeddings = self.text_decoder.transformer.wte(input_ids).to(self.device)  # Shape: (batch_size, seq_len, embedding_dim)
    
        # Combine visual features with token embeddings
        combined_embeddings = torch.cat(
            [projected_features.unsqueeze(1), token_embeddings], dim=1
        ).to(self.device)
    
        # Adjust attention mask to include visual tokens
        _ones = torch.ones((attention_mask.size(0), 1)).to(self.device)
        extended_attention_mask = torch.cat(
            [_ones, attention_mask], dim=1
        ).to(self.device)
    
        # Truncate combined embeddings and attention mask to max_seq_length if needed
        # if combined_embeddings.size(1) > self.max_seq_length:
        #     # print("T")
        #     combined_embeddings = combined_embeddings[:, :self.max_seq_length]
        #     extended_attention_mask = extended_attention_mask[:, :self.max_seq_length]
    
        # Forward pass through the text decoder
        outputs = self.text_decoder(
            inputs_embeds=combined_embeddings,
            attention_mask=extended_attention_mask
        )
        outputs.logits = outputs.logits[:, 1:, :]  # Get rid of distillgpt input token
        return outputs



In [9]:
vision_encoder = models.mobilenet_v3_small()
vision_encoder.classifier[-1] = torch.nn.Linear(vision_encoder.classifier[-1].in_features, 768)

vision_encoder.load_state_dict(torch.load('./mobilenetv3_student_model.pth', map_location=torch.device('cpu'), weights_only=True))
vision_encoder.eval()

for param in vision_encoder.parameters():
    param.requires_grad = False

print("Vision Encoder Ready")


llm = AutoModelForCausalLM.from_pretrained("distilgpt2").to(device)
llm.lm_head = nn.Linear(in_features=768, out_features=32000) # llava out features

# Set the new max position embeddings
llm.config.max_position_embeddings = 4096  # Update the max position embeddings

# Resize positional embeddings (wpe) to match new max_position_embeddings
old_embeddings = llm.transformer.wpe.weight.data  # Original embeddings
new_seq_length = llm.config.max_position_embeddings  # Desired sequence length

# Interpolate to resize
new_embeddings = torch.nn.functional.interpolate(
    old_embeddings.unsqueeze(0).transpose(1, 2),  # Add batch dimension for interpolation
    size=new_seq_length,  # New sequence length
    mode="linear",
    align_corners=False,
).squeeze(0).transpose(1, 0)  # Remove batch dimension and revert dimensions

# Update the embeddings in the model
llm.transformer.wpe.weight.data = new_embeddings

# Verify changes
print(f"Updated max_position_embeddings: {llm.config.max_position_embeddings}")
print(f"Positional embeddings shape: {llm.transformer.wpe.weight.shape}")

tokenizer = AutoTokenizer.from_pretrained("distilgpt2")

print("LLM and Tokenizer Ready")

projection_head = nn.Linear(768, 768).to(device)
print("Projection Head Ready")

  return self.fget.__get__(instance, owner)()


Vision Encoder Ready
Updated max_position_embeddings: 4096
Positional embeddings shape: torch.Size([4096, 768])
LLM and Tokenizer Ready
Projection Head Ready


In [10]:
tiny_llava = TinyLLAVA(vision_encoder, projection_head, llm, llava_tokenizer).to("cuda")

In [11]:
for x in dataloader:
    print(x.keys())
    break

dict_keys(['image', 'conversation_text'])


# Performing Knowledge Distillation

In [12]:
import torch
from torch.nn import KLDivLoss
from torch.optim import Adam
from torch.nn.functional import log_softmax, softmax, cosine_similarity
from tqdm import tqdm

def train_one_epoch(tiny_llava, llava_model, dataloader, optimizer, temperature=1.0, device="cuda"):
    """
    Perform one epoch of knowledge distillation training.

    Args:
        tiny_llava: The student model.
        llava_model: The teacher model (frozen).
        dataloader: DataLoader for the training data.
        optimizer: Optimizer for the student model.
        temperature: Temperature for knowledge distillation.

    Returns:
        Average loss for the epoch.
    """
    tiny_llava.train()
    kl_div_loss = KLDivLoss(reduction="batchmean")
    total_loss = 0.0
    total_cosine_similarity = 0.0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training Batch", leave=True)
    for batch_idx, batch in enumerate(progress_bar):
        # Load batch data
        images = batch["image"].to(device)
        conversation_texts = batch["conversation_text"]

        # Tokenize conversation text
        inputs = tiny_llava.tokenizer(conversation_texts, return_tensors="pt", padding=True, truncation=True).to(tiny_llava.device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        
        batch_size = input_ids.shape[0]
        # Teacher model forward pass
        with torch.no_grad():
            teacher_outputs = llava_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images.to(torch.float16)
            ).logits

        # Student model forward pass
        student_outputs = tiny_llava(image=images, input_ids=input_ids, attention_mask=attention_mask).logits

        # Compute KD loss
        student_log_probs = log_softmax(student_outputs / temperature, dim=-1)
        teacher_probs = softmax(teacher_outputs / temperature, dim=-1)
        loss = kl_div_loss(student_log_probs, teacher_probs) * (temperature ** 2)
    
        # print(student_outputs.view(-1).shape, student_outputs.view(-1).shape)
        # Compute cosine similarity
        cosine_sim = cosine_similarity(
            student_outputs.reshape(batch_size,-1), teacher_outputs.reshape(batch_size,-1), dim=1
        ).mean().item()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update metrics
        total_loss += loss.item()
        total_cosine_similarity += cosine_sim

        # Update progress bar
        avg_loss = total_loss / (batch_idx + 1)
        avg_cosine_similarity = total_cosine_similarity / (batch_idx + 1)
        progress_bar.set_postfix({
            "Avg Loss": f"{avg_loss:.4f}",
            "Avg Cosine Sim": f"{avg_cosine_similarity:.4f}"
        })

    return total_loss / len(dataloader), total_cosine_similarity/len(dataloader)

In [13]:
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.nn.functional import cosine_similarity
from tqdm import tqdm

def train_one_epoch(tiny_llava, llava_model, dataloader, optimizer, temperature=1.0, device="cuda"):
    """
    Perform one epoch of knowledge distillation training.

    Args:
        tiny_llava: The student model.
        llava_model: The teacher model (frozen).
        dataloader: DataLoader for the training data.
        optimizer: Optimizer for the student model.
        temperature: Temperature for knowledge distillation.

    Returns:
        Average loss for the epoch.
    """
    tiny_llava.train()
    mse_loss_fn = MSELoss()
    total_mse_loss = 0.0
    total_cosine_similarity = 0.0

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training Batch", leave=True)
    for batch_idx, batch in enumerate(progress_bar):
        # Load batch data
        images = batch["image"].to(device)
        conversation_texts = batch["conversation_text"]

        # Tokenize conversation text
        inputs = tiny_llava.tokenizer(conversation_texts, return_tensors="pt", padding=True, truncation=True).to(tiny_llava.device)
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        
        batch_size = input_ids.shape[0]
        # Teacher model forward pass
        with torch.no_grad():
            teacher_outputs = llava_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                images=images.to(torch.float16)
            ).logits

        # Student model forward pass
        student_outputs = tiny_llava(image=images, input_ids=input_ids, attention_mask=attention_mask).logits

        # Compute MSE loss
        mse_loss = mse_loss_fn(student_outputs, teacher_outputs)

        # Compute cosine similarity
        cosine_sim = cosine_similarity(
            student_outputs.reshape(batch_size, -1), teacher_outputs.reshape(batch_size, -1), dim=1
        ).mean().item()

        # Backpropagation
        optimizer.zero_grad()
        mse_loss.backward()
        optimizer.step()

        # Update metrics
        total_mse_loss += mse_loss.item()
        total_cosine_similarity += cosine_sim

        # Update progress bar
        avg_mse_loss = total_mse_loss / (batch_idx + 1)
        avg_cosine_similarity = total_cosine_similarity / (batch_idx + 1)
        progress_bar.set_postfix({
            "Avg MSE Loss": f"{avg_mse_loss:.4f}",
            "Avg Cosine Sim": f"{avg_cosine_similarity:.4f}"
        })

    return total_mse_loss / len(dataloader), total_cosine_similarity / len(dataloader)

In [14]:
save_dir = "./LLAVA_KD_RESULTS"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

def find_latest_checkpoint(directory):
    checkpoints = [f for f in os.listdir(directory) if f.startswith("tiny_llava_epoch_") and f.endswith(".pth")]
    if not checkpoints:
        return None
    checkpoints.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
    return os.path.join(directory, checkpoints[-1])

# Load the latest checkpoint if it exists
latest_checkpoint = find_latest_checkpoint(save_dir)
start_epoch = 0
if latest_checkpoint:
    print(f"Loading from checkpoint: {latest_checkpoint}")
    tiny_llava.load_state_dict(torch.load(latest_checkpoint))
    start_epoch = int(latest_checkpoint.split("_")[-1].split(".")[0])
    print(f"Resuming from epoch {start_epoch + 1}")

Loading from checkpoint: ./LLAVA_KD_RESULTS/tiny_llava_epoch_7.pth
Resuming from epoch 8


In [15]:
import torch.optim.lr_scheduler
log_file = "llava_log.txt"
num_epochs = 20
initial_lr = 5e-4
optimizer = torch.optim.Adam(tiny_llava.parameters(), lr=initial_lr/2)

# Define LR scheduler: Decay LR by 0.5 every 2 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

# Loop through the specified number of epochs
for _ in range(num_epochs):
    epoch = _
    print(f"Epoch {epoch + 1}/{num_epochs}")

    # Train for one epoch and retrieve the average loss
    avg_loss, avg_cos = train_one_epoch(
        tiny_llava=tiny_llava,
        llava_model=llava_model,
        dataloader=dataloader,
        optimizer=optimizer,
        temperature=0.5,
        device="cuda"
    )
    with open(log_file, "a") as log:
        log.write(f"Epoch {epoch + start_epoch + 1} completed. Average Loss: {avg_loss:.4f} Average Cos: {avg_cos:.4f}\n")

    checkpoint_path = os.path.join(save_dir, f"tiny_llava_epoch_{epoch + start_epoch + 1}.pth")
    torch.save(tiny_llava.state_dict(), checkpoint_path)
    with open(log_file, "a") as log:
        log.write(f"Model checkpoint saved to {checkpoint_path}\n")

    # Step the LR scheduler
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]
    print(f"Updated learning rate: {current_lr:.6f}")
    with open(log_file, "a") as log:
        log.write(f"Updated learning rate: {current_lr:.6f}\n")


Epoch 1/20


Training Batch: 100%|██████████| 625/625 [10:48<00:00,  1.04s/it, Avg MSE Loss=1.4292, Avg Cosine Sim=0.8988]


Updated learning rate: 0.000250
Epoch 2/20


Training Batch: 100%|██████████| 625/625 [10:43<00:00,  1.03s/it, Avg MSE Loss=1.3369, Avg Cosine Sim=0.9058]


Updated learning rate: 0.000125
Epoch 3/20


Training Batch: 100%|██████████| 625/625 [10:43<00:00,  1.03s/it, Avg MSE Loss=1.2692, Avg Cosine Sim=0.9109]


Updated learning rate: 0.000125
Epoch 4/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.2343, Avg Cosine Sim=0.9136]


Updated learning rate: 0.000063
Epoch 5/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.2017, Avg Cosine Sim=0.9160]


Updated learning rate: 0.000063
Epoch 6/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.1859, Avg Cosine Sim=0.9172]


Updated learning rate: 0.000031
Epoch 7/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.1700, Avg Cosine Sim=0.9184]


Updated learning rate: 0.000031
Epoch 8/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.1627, Avg Cosine Sim=0.9190]


Updated learning rate: 0.000016
Epoch 9/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.1541, Avg Cosine Sim=0.9196]


Updated learning rate: 0.000016
Epoch 10/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.1506, Avg Cosine Sim=0.9199]


Updated learning rate: 0.000008
Epoch 11/20


Training Batch: 100%|██████████| 625/625 [10:42<00:00,  1.03s/it, Avg MSE Loss=1.1462, Avg Cosine Sim=0.9202]


Updated learning rate: 0.000008
Epoch 12/20


Training Batch: 100%|██████████| 625/625 [10:41<00:00,  1.03s/it, Avg MSE Loss=1.1442, Avg Cosine Sim=0.9204]


Updated learning rate: 0.000004
Epoch 13/20


Training Batch: 100%|██████████| 625/625 [10:41<00:00,  1.03s/it, Avg MSE Loss=1.1423, Avg Cosine Sim=0.9205]


Updated learning rate: 0.000004
Epoch 14/20


Training Batch:  63%|██████▎   | 393/625 [06:42<03:57,  1.02s/it, Avg MSE Loss=1.1474, Avg Cosine Sim=0.9203]


KeyboardInterrupt: 