In [9]:
import torch
import torch.nn as nn
import time
from transformers import CLIPModel, AutoModelForCausalLM, AutoTokenizer
from torch.profiler import profile, ProfilerActivity
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Vision Encoder (Using CLIP)
class VisionEncoder(nn.Module):
    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super(VisionEncoder, self).__init__()
        self.clip_model = CLIPModel.from_pretrained(clip_model_name)

    def forward(self, images):
        vision_outputs = self.clip_model.get_image_features(images)
        return vision_outputs

# Language Decoder (Using a Pretrained Language Model)
class LanguageDecoder(nn.Module):
    def __init__(self, language_model_name="gpt2"):
        super(LanguageDecoder, self).__init__()
        self.language_model = AutoModelForCausalLM.from_pretrained(language_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(language_model_name)
        
        # Add a padding token
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
            self.language_model.resize_token_embeddings(len(self.tokenizer))

    def forward(self, text_inputs, attention_mask=None):
        outputs = self.language_model(input_ids=text_inputs, attention_mask=attention_mask, output_hidden_states=True)
        hidden_states = outputs.hidden_states[-1]  # Shape: (batch_size, sequence_length, hidden_size)
        return hidden_states

# Additional Layers
class IntermediateLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(IntermediateLayer, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.linear(x))

# Combined Vision-Language Model
class LlavaModel(nn.Module):
    def __init__(self, vision_encoder, language_decoder, intermediate_layers=3):
        super(LlavaModel, self).__init__()
        self.vision_encoder = vision_encoder
        self.language_decoder = language_decoder

        vision_output_size = self.vision_encoder.clip_model.config.projection_dim
        language_output_size = self.language_decoder.language_model.config.n_embd

        self.intermediate_layers = nn.ModuleList([
            IntermediateLayer(vision_output_size + language_output_size, vision_output_size + language_output_size)
            for _ in range(intermediate_layers)
        ])

        self.fusion_layer = nn.Linear(vision_output_size + language_output_size, language_output_size)
        self.classification_layer = nn.Linear(language_output_size, 1)  # Final output for binary classification

    def forward(self, images, text_inputs):
        vision_features = self.vision_encoder(images)
        text_features = self.language_decoder(text_inputs).mean(dim=1)
        combined_features = torch.cat([vision_features, text_features], dim=-1)

        for layer in self.intermediate_layers:
            combined_features = layer(combined_features)

        fused_output = self.fusion_layer(combined_features)
        classification_output = self.classification_layer(fused_output)  # Shape: (batch_size, 1)
        return classification_output

    def calculate_flops_per_layer(self, batch_size):
        flops = []
        vision_output_size = self.vision_encoder.clip_model.config.projection_dim
        language_output_size = self.language_decoder.language_model.config.n_embd

        # Calculate FLOPs for each intermediate layer
        for layer in self.intermediate_layers:
            input_size = vision_output_size + language_output_size
            layer_flops = batch_size * input_size * input_size
            flops.append(layer_flops)

        return flops

    def prune_max_flop_layer(self, batch_size):
        flops = self.calculate_flops_per_layer(batch_size)
        if flops:
            max_flop_index = flops.index(max(flops))
            del self.intermediate_layers[max_flop_index]

# Function for Profiling
def profile_model(model, inputs):
    device = "CUDA" if torch.cuda.is_available() else "CPU"
    print(f"Profiling model on {device}...")
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], on_trace_ready=torch.profiler.tensorboard_trace_handler("./log")) as prof:
        model(*inputs)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

# Function to calculate FLOPs
def calculate_flops(model, inputs):
    flops = 0
    batch_size = inputs[0].size(0)

    # Vision FLOPs
    vision_output_size = model.vision_encoder.clip_model.config.projection_dim
    flops += batch_size * vision_output_size * inputs[0].size(2) * inputs[0].size(3)

    # Language FLOPs
    seq_len = inputs[1].size(1)
    language_output_size = model.language_decoder.language_model.config.n_embd
    flops += batch_size * seq_len * language_output_size * 2

    # Intermediate Layers FLOPs
    for layer in model.intermediate_layers:
        input_size = vision_output_size + language_output_size
        flops += batch_size * input_size * input_size

    # Fusion Layer FLOPs
    fusion_input_size = vision_output_size + language_output_size
    flops += batch_size * fusion_input_size * language_output_size

    # Classification Layer FLOPs
    flops += batch_size * language_output_size * 1

    return flops

# Toy Task for Lossy-ness Measurement
def train_on_toy_task(model, dataloader, criterion, optimizer, epochs=2):
    model.train()
    for epoch in range(epochs):
        for images, labels in dataloader:
            optimizer.zero_grad()
            images = images.cuda() if torch.cuda.is_available() else images
            labels = (labels % 2).float().unsqueeze(1).cuda() if torch.cuda.is_available() else (labels % 2).float().unsqueeze(1)  # Convert to binary and match output shape
            dummy_text = ["Dummy text input"] * images.size(0)
            encoded_text = model.language_decoder.tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True).input_ids
            encoded_text = encoded_text.cuda() if torch.cuda.is_available() else encoded_text

            outputs = model(images, encoded_text)  # Output shape: (batch_size, 1)
            loss = criterion(outputs, labels)  # Loss expects shape (batch_size, 1)
            loss.backward()
            optimizer.step()
    return model

# Function to measure accuracy
def calculate_accuracy(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.cuda() if torch.cuda.is_available() else images
            labels = (labels % 2).float().unsqueeze(1).cuda() if torch.cuda.is_available() else (labels % 2).float().unsqueeze(1)  # Convert to binary
            dummy_text = ["Dummy text input"] * images.size(0)
            encoded_text = model.language_decoder.tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True).input_ids
            encoded_text = encoded_text.cuda() if torch.cuda.is_available() else encoded_text

            outputs = torch.sigmoid(model(images, encoded_text))  # Sigmoid for binary classification
            predictions = (outputs > 0.5).float()
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total

if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA is not available. Falling back to CPU.")
    else:
        print(f"CUDA is available. Using device: {torch.cuda.get_device_name(0)}")

    vision_encoder = VisionEncoder()
    language_decoder = LanguageDecoder()
    llava_model = LlavaModel(vision_encoder, language_decoder, intermediate_layers=3)
    llava_model = llava_model.cuda() if torch.cuda.is_available() else llava_model

    dummy_images = torch.rand((2, 3, 224, 224)).cuda() if torch.cuda.is_available() else torch.rand((2, 3, 224, 224))
    dummy_text = ["What is in this image?", "Describe the objects."]
    tokenizer = language_decoder.tokenizer
    encoded_text = tokenizer(dummy_text, return_tensors="pt", padding=True, truncation=True).input_ids
    encoded_text = encoded_text.cuda() if torch.cuda.is_available() else encoded_text

    print("Profiling before pruning:")
    profile_model(llava_model, (dummy_images, encoded_text))
    print(f"FLOPs before pruning: {calculate_flops(llava_model, (dummy_images, encoded_text))}")

    # Load toy dataset (MNIST) with transformation for RGB
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # Convert grayscale (1 channel) to RGB (3 channels)
    ])
    mnist_data = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    dataloader = DataLoader(mnist_data, batch_size=16, shuffle=True)

    # Define criterion and optimizer
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(llava_model.parameters(), lr=1e-3)

    # Calculate accuracy before pruning
    accuracy_before = calculate_accuracy(llava_model, dataloader)
    print(f"Accuracy before pruning: {accuracy_before:.2f}")

    # Train on toy dataset
    trained_model = train_on_toy_task(llava_model, dataloader, criterion, optimizer)
    print("Model trained on toy task.")

    llava_model.prune_max_flop_layer(batch_size=16)

    print("Profiling after pruning:")
    profile_model(llava_model, (dummy_images, encoded_text))
    print(f"FLOPs after pruning: {calculate_flops(llava_model, (dummy_images, encoded_text))}")

    # Calculate accuracy after pruning
    accuracy_after = calculate_accuracy(llava_model, dataloader)
    print(f"Accuracy after pruning: {accuracy_after:.2f}")


CUDA is available. Using device: NVIDIA RTX 4000 Ada Generation


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Profiling before pruning:
Profiling model on CUDA...
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   aten::to         0.01%       1.000us         0.01%       1.000us       0.027us            37  
               aten::conv2d         0.03%       4.000us         1.10%     176.000us     176.000us             1  
          aten::convolution         0.16%      25.000us         1.08%     172.000us     172.000us             1  
         aten::_convolution         0.08%      12.000us         0.92%     147.000us     147.000us             1  
    aten::cudnn_convolution         0.84%     135.000us         0.84%     135.000us     135.000us             1  
              aten::flatten        

STAGE:2024-12-08 18:00:52 1793:1793 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-12-08 18:00:52 1793:1793 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-12-08 18:00:52 1793:1793 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


Accuracy before pruning: 0.49
Model trained on toy task.
Profiling after pruning:
Profiling model on CUDA...


STAGE:2024-12-08 18:19:41 1793:1793 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-12-08 18:19:41 1793:1793 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-12-08 18:19:41 1793:1793 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                   aten::to         0.01%       2.000us         0.01%       2.000us       0.054us            37  
               aten::conv2d         0.03%       5.000us         0.97%     173.000us     173.000us             1  
          aten::convolution         0.17%      30.000us         0.94%     168.000us     168.000us             1  
         aten::_convolution         0.06%      11.000us         0.77%     138.000us     138.000us             1  
    aten::cudnn_convolution         0.71%     127.000us         0.71%     127.000us     127.000us             1  
              aten::flatten         0.03%       6.000us         0.11%      20.000us     