<a href="https://colab.research.google.com/github/ajaysuseel/MiniProject_AD/blob/main/contrastive_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import json
import requests
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
from peft import LoraConfig, get_peft_model
import torch.nn.functional as F
from tqdm import tqdm

#CONFIGURABLE VARIABLES

In [None]:
# Base URL of the raw repository content; JSON_FILE and IMAGES_FOLDER are
# appended to it with plain string concatenation, so the trailing "/" matters.
GITHUB_REPO = "https://raw.githubusercontent.com/ajaysuseel/MiniProject_AD/main/ajay/"
# Caption file fetched from GITHUB_REPO by load_dataset().
JSON_FILE = "contrastive_captions.json"
# Sub-path under GITHUB_REPO that image filenames are appended to.
IMAGES_FOLDER = "images/"

#Checking model modules

In [None]:
# Load the BLIP checkpoint whose module names we want to inspect
model_name = "Salesforce/blip-image-captioning-base"  # or your chosen variant
processor = BlipProcessor.from_pretrained(model_name)
model = BlipForConditionalGeneration.from_pretrained(model_name)

# Print every sub-module whose name mentions attention or projection layers;
# these names are the candidates for LoRA `target_modules`.
keywords = ("attn", "proj")
for name, _module in model.named_modules():
    lowered = name.lower()
    if any(keyword in lowered for keyword in keywords):
        print(name)


In [None]:
# List the attribute and method names available on the loaded model object.
dir(model)


#Load BLIP Model with LoRA

In [None]:
def load_blip_with_lora():
    """Load the base BLIP captioning checkpoint and wrap it with LoRA adapters.

    LoRA is attached to the ``self_attn.qkv`` module of vision-encoder
    layers 0 through 11.

    Returns:
        tuple: ``(model, processor)`` where ``model`` is the PEFT-wrapped
        BLIP model and ``processor`` is the matching ``BlipProcessor``.
    """
    checkpoint = "Salesforce/blip-image-captioning-base"
    print("Loading BLIP-1 model with LoRA...")
    blip_processor = BlipProcessor.from_pretrained(checkpoint)
    base_model = BlipForConditionalGeneration.from_pretrained(checkpoint)

    # One target entry per vision-encoder layer.
    qkv_modules = []
    for layer_idx in range(12):
        qkv_modules.append(f"vision_model.encoder.layers.{layer_idx}.self_attn.qkv")

    # Apply LoRA (PEFT)
    adapter_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=qkv_modules,
    )
    return get_peft_model(base_model, adapter_config), blip_processor

#Load Dataset

In [None]:
def load_dataset():
    """Download the caption JSON from ``GITHUB_REPO + JSON_FILE``.

    Returns:
        The decoded JSON payload on success, or an empty list when the
        request fails, returns an HTTP error status, or the body is not
        valid JSON.
    """
    json_url = GITHUB_REPO + JSON_FILE
    try:
        # A timeout keeps a stalled connection from hanging the notebook forever.
        response = requests.get(json_url, timeout=30)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decode failures on requests versions where
        # the decode error is not a RequestException subclass.
        print(f"Error loading dataset: {e}")
        return []

    print(f"Loaded {len(data)} image-caption pairs.")
    return data

#Custom Dataset Class

In [None]:
import torch
from torch.utils.data import Dataset
import requests
from PIL import Image

class ContrastiveCaptionDataset(Dataset):
    """Dataset of (image, positive caption, negative caption) triples.

    Each entry of ``data`` must provide the keys ``"filename"``,
    ``"pos_caption"`` and ``"neg_caption"``. Images are downloaded on access
    from ``images_dir + filename``.
    """

    # Seconds to wait for an image download before the sample is skipped.
    REQUEST_TIMEOUT = 30

    def __init__(self, data, processor, images_dir):
        """Store the sample list, the BLIP processor and the image base URL."""
        self.data = data
        self.processor = processor
        self.images_dir = images_dir

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def _encode(self, caption, image):
        """Run the processor on one caption/image pair and drop the batch dim."""
        encoding = self.processor(
            text=caption,
            images=image,
            return_tensors="pt",
            padding="max_length",
            truncation=True
        )
        return {key: val.squeeze(0) for key, val in encoding.items()}

    def __getitem__(self, idx):
        """Return the encoded sample at ``idx``, or ``None`` if its image fails.

        The returned dict holds the processor outputs for the positive
        caption plus ``"pos_labels"`` and ``"neg_labels"`` (token ids of the
        positive and negative captions).
        """
        from io import BytesIO

        item = self.data[idx]
        image_url = self.images_dir + item["filename"]

        try:
            response = requests.get(image_url, timeout=self.REQUEST_TIMEOUT)
            # Fail on 4xx/5xx instead of trying to decode an error page as an image.
            response.raise_for_status()
            # Decode from the fully downloaded body rather than the raw socket stream.
            image = Image.open(BytesIO(response.content)).convert("RGB")
        except Exception as e:
            # Best effort: a broken sample is reported and skipped by the collate function.
            print(f"Error loading image {item['filename']}: {e}")
            return None

        # Encode the image together with each caption.
        pos_encoding = self._encode(item["pos_caption"], image)
        neg_encoding = self._encode(item["neg_caption"], image)

        # Token ids consumed by the contrastive training loop.
        pos_encoding["pos_labels"] = pos_encoding["input_ids"]
        pos_encoding["neg_labels"] = neg_encoding["input_ids"]

        return pos_encoding


#Create DataLoader

In [None]:
def create_dataloader(data, processor, batch_size=4):
    """Build a shuffling ``DataLoader`` over ``ContrastiveCaptionDataset``.

    Samples that the dataset returned as ``None`` are dropped during
    collation; a batch in which every sample is ``None`` collates to ``None``.
    """

    def collate_fn(batch):
        valid = [sample for sample in batch if sample is not None]
        if not valid:
            return None
        stacked = {}
        for key in valid[0].keys():
            stacked[key] = torch.stack([sample[key] for sample in valid])
        return stacked

    caption_dataset = ContrastiveCaptionDataset(
        data, processor, GITHUB_REPO + IMAGES_FOLDER
    )
    return DataLoader(
        caption_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
    )

#Contrastive Loss Function

In [None]:
def contrastive_loss(image_embeds, pos_text_embeds, neg_text_embeds, temperature=0.07):
    """Two-way contrastive loss between an image and a positive/negative caption.

    Computes ``-log(exp(p/t) / (exp(p/t) + exp(n/t)))`` averaged over the
    leading dimensions, where ``p`` and ``n`` are the cosine similarities of
    the image embedding with the positive and negative text embeddings.

    Args:
        image_embeds: Image embeddings; similarity is taken over the last dim.
        pos_text_embeds: Embeddings of the matching captions.
        neg_text_embeds: Embeddings of the non-matching captions.
        temperature: Softmax temperature dividing the similarities.

    Returns:
        Scalar tensor with the mean loss.
    """
    sim_pos = torch.cosine_similarity(image_embeds, pos_text_embeds, dim=-1)
    sim_neg = torch.cosine_similarity(image_embeds, neg_text_embeds, dim=-1)
    # -log(e^p / (e^p + e^n)) == log(1 + e^(n - p)) == softplus(n - p).
    # This form never exponentiates a large positive number, so it stays
    # finite for small temperatures and in half precision.
    loss = torch.nn.functional.softplus((sim_neg - sim_pos) / temperature)
    return loss.mean()

#Fine-Tune BLIP with Contrastive Loss

In [None]:
def train_blip_contrastive(model, dataloader, num_epochs=3, learning_rate=5e-5, processor=None):
    """Fine-tune ``model`` with the image/caption contrastive loss and save it.

    Args:
        model: BLIP model exposing ``vision_model`` and ``text_decoder.bert``.
        dataloader: Yields dicts with ``"pixel_values"``, ``"pos_labels"`` and
            ``"neg_labels"`` tensors; ``None`` batches are skipped.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: AdamW learning rate.
        processor: Processor to save next to the model. When omitted, the
            module-level ``processor`` variable is used if it exists.

    Returns:
        The fine-tuned model (also written to ``./models/finetuned_blip1``).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    print(f"🚀 Starting fine-tuning on {device} for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for batch in progress_bar:
            # The collate function returns None when no sample in the batch
            # could be loaded; there is nothing to train on in that case.
            if batch is None:
                continue

            optimizer.zero_grad()

            pixel_values = batch["pixel_values"].to(device)
            pos_input_ids = batch["pos_labels"].to(device)
            neg_input_ids = batch["neg_labels"].to(device)

            # Mean-pool the vision encoder output into one vector per image.
            image_embeds = model.vision_model(pixel_values).last_hidden_state.mean(dim=1)

            # Embed the caption token ids and run them through the text model.
            # NOTE(review): the mean below also averages over padding
            # positions, because no attention mask is applied — confirm
            # this is intended.
            text_model = model.text_decoder.bert
            pos_outputs = text_model(inputs_embeds=text_model.embeddings(pos_input_ids))
            neg_outputs = text_model(inputs_embeds=text_model.embeddings(neg_input_ids))

            pos_text_embeds = pos_outputs.last_hidden_state.mean(dim=1)
            neg_text_embeds = neg_outputs.last_hidden_state.mean(dim=1)

            loss = contrastive_loss(image_embeds, pos_text_embeds, neg_text_embeds)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix(loss=loss.item())

        # Report the mean over processed batches rather than the raw sum.
        avg_loss = epoch_loss / num_batches if num_batches else float("nan")
        print(f"✅ Epoch {epoch+1} completed | Average Loss: {avg_loss:.4f}")

    model.save_pretrained("./models/finetuned_blip1")

    # Keep supporting callers that rely on a module-level `processor`.
    if processor is None:
        processor = globals().get("processor")
    if processor is not None:
        processor.save_pretrained("./models/finetuned_blip1")
    else:
        print("Warning: no processor available; only the model was saved.")

    print("🎯 Fine-tuning complete and model saved!")
    return model

#Main Execution

In [None]:
if __name__ == "__main__":
    # Build the LoRA-wrapped model, fetch the caption data, then train for 20 epochs.
    model, processor = load_blip_with_lora()
    data = load_dataset()

    if data:
        dataloader = create_dataloader(data, processor)
        if dataloader is not None:
            train_blip_contrastive(model, dataloader, num_epochs=20)
        else:
            print("Error: No valid data samples found. Exiting.")
    else:
        print("No data found. Exiting.")

#Evaluation

In [None]:
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu

In [None]:
!git clone https://github.com/ajaysuseel/MiniProject_AD.git

#CONFIGURABLE VARIABLES

In [None]:
# Local directory (inside the cloned repository) holding the evaluation images.
IMAGE_FOLDER = "/content/MiniProject_AD/ajay/images"
# Local JSON file with the reference captions used for BLEU scoring.
GROUND_TRUTH_JSON = "/content/MiniProject_AD/ajay/captions.json"
# Same directory the training step writes the model and processor to.
MODEL_PATH = "./models/finetuned_blip1"  # Path to your fine-tuned model

#Loading dataset and model

In [None]:
import json
from PIL import Image

def load_ground_truth(local_json_path):
    """
    Load ground truth captions from a local JSON file.
    The JSON should be a dictionary mapping image filenames to captions.

    Parameters:
      local_json_path (str): Local file path to the JSON file.

    Returns:
      dict: A dictionary with keys as image filenames and values as captions,
      or an empty dict when the file cannot be read or parsed.
    """
    try:
        # Read as UTF-8 explicitly so non-ASCII captions do not depend on
        # the platform's default encoding.
        with open(local_json_path, "r", encoding="utf-8") as f:
            gt_data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing or unreadable file; ValueError: invalid JSON or bad encoding.
        print(f"Error loading ground truth captions: {e}")
        return {}

    print(f"Loaded {len(gt_data)} ground truth captions from {local_json_path}.")
    return gt_data

def load_image(image_path):
    """
    Read a local image file and return it converted to RGB.

    Parameters:
      image_path (str): The local file path to the image.

    Returns:
      PIL.Image or None: The RGB image, or None when opening or converting
      the file fails (the error is printed).
    """
    try:
        return Image.open(image_path).convert("RGB")
    except Exception as err:
        print(f"Error loading image {image_path}: {err}")
    return None
# Load Fine-Tuned BLIP Model and Processor

def load_model_and_processor(model_path):
    """
    Load the BLIP processor and model stored at ``model_path`` for inference.

    The model is moved to CUDA when available (CPU otherwise) and switched
    to eval mode.

    Returns:
      tuple: ``(model, processor, device)``, or ``(None, None, None)`` when
      any loading step raises (the error is printed).
    """
    try:
        loaded_processor = BlipProcessor.from_pretrained(model_path)
        loaded_model = BlipForConditionalGeneration.from_pretrained(model_path)
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(device_name)
        loaded_model.to(device)
        loaded_model.eval()
        print(f"Model loaded on {device}.")
    except Exception as err:
        print(f"Error loading model: {err}")
        return None, None, None
    return loaded_model, loaded_processor, device

#Generate Captions

In [None]:
def generate_caption(model, processor, device, image):
    """
    Produce a caption for ``image`` with the given model and processor.

    The processor output is moved to ``device`` and generation runs without
    gradient tracking. Returns the decoded caption, or an empty string when
    any step raises (the error is printed).
    """
    try:
        encoded = processor(images=image, return_tensors="pt")
        on_device = {}
        for key, tensor in encoded.items():
            on_device[key] = tensor.to(device)
        with torch.no_grad():
            token_ids = model.generate(**on_device)
        return processor.decode(token_ids[0], skip_special_tokens=True)
    except Exception as err:
        print(f"Error generating caption: {err}")
        return ""

In [None]:
def display_image_with_captions(image_path, gt_caption, generated_caption, bleu_score):
    """Show an image with its reference caption, generated caption and BLEU score.

    Parameters:
      image_path (str): Local path of the image to display.
      gt_caption (str): Ground truth caption.
      generated_caption (str): Caption produced by the model.
      bleu_score (float): Sentence-level BLEU score, shown with 4 decimals.

    Prints a message and returns without plotting when the image cannot be
    loaded.
    """
    # `plt` is not imported anywhere else in this file, so bind it here to
    # avoid a NameError on the first call.
    import matplotlib.pyplot as plt

    image = load_image(image_path)
    if image is None:
        print(f"Cannot display image: {image_path}")
        return
    plt.figure(figsize=(8, 6))
    plt.imshow(image)
    plt.axis('off')
    title_text = f"GT: {gt_caption}\nGen: {generated_caption}\nBLEU: {bleu_score:.4f}"
    plt.title(title_text, fontsize=10)
    plt.show()

#Evaluating the finetuned model

In [None]:
import os
def evaluate_model(image_folder, gt_json_path, model_path):
    """Caption every image in a folder and score the captions with BLEU.

    Parameters:
      image_folder (str): Directory containing .png/.jpg/.jpeg images.
      gt_json_path (str): JSON file mapping image filenames to reference captions.
      model_path (str): Directory of the fine-tuned model and processor.

    Prints per-image sentence BLEU and the corpus BLEU, and displays each
    scored image. Returns early (with a message) when the ground truth, the
    image folder or the model is unavailable, or when no image could be
    scored.
    """
    gt_captions = load_ground_truth(gt_json_path)
    if not gt_captions:
        print("No ground truth data available. Exiting evaluation.")
        return

    # Check the folder before the expensive model load so a bad path fails fast.
    if not os.path.isdir(image_folder):
        print(f"Image folder not found: {image_folder}. Exiting evaluation.")
        return

    model, processor, device = load_model_and_processor(model_path)
    if model is None:
        print("Model loading failed. Exiting evaluation.")
        return

    generated_captions = {}
    image_files = [f for f in os.listdir(image_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    print(f"Found {len(image_files)} images in {image_folder}.")

    for filename in tqdm(image_files, desc="Evaluating images"):
        image_path = os.path.join(image_folder, filename)
        image = load_image(image_path)
        if image is None:
            continue
        caption = generate_caption(model, processor, device, image)
        generated_captions[filename] = caption

    individual_scores = {}
    references = []
    hypotheses = []

    for filename, gt_caption in gt_captions.items():
        if filename in generated_captions:
            gen_caption = generated_captions[filename]
            hypothesis = gen_caption.split()
            reference = [gt_caption.split()]  # BLEU expects a list of references
            score = sentence_bleu(reference, hypothesis)
            individual_scores[filename] = score
            references.append(reference)
            hypotheses.append(hypothesis)
            image_path = os.path.join(image_folder, filename)
            display_image_with_captions(image_path, gt_caption, gen_caption, score)
        else:
            print(f"Warning: No generated caption for {filename}")

    # Corpus BLEU is undefined for an empty corpus; stop before calling it.
    if not hypotheses:
        print("No generated captions matched the ground truth; nothing to score.")
        return

    avg_bleu = corpus_bleu(references, hypotheses)
    print("\n--- Evaluation Summary ---")
    for filename, score in individual_scores.items():
        print(f"{filename}: BLEU Score = {score:.4f}")
    print(f"\nAverage Corpus BLEU Score: {avg_bleu:.4f}")


In [None]:
if __name__ == "__main__":
    # Clone your repository if needed:
    # !git clone https://github.com/ajaysuseel/MiniProject_AD.git

    # Fine-tuning Phase:
    # GROUND_TRUTH_JSON already points at captions.json inside the cloned
    # repository; the previously referenced LOCAL_REPO_PATH is not defined
    # anywhere in this notebook.
    with open(GROUND_TRUTH_JSON, "r") as f:
        fine_tuning_data = json.load(f)
    print(f"Loaded {len(fine_tuning_data)} fine-tuning samples.")

    # Load BLIP model with LoRA and processor for fine-tuning
    # (load_blip_with_lora is the loader this notebook defines).
    model, processor = load_blip_with_lora()
    dataloader = create_dataloader(fine_tuning_data, processor, batch_size=2)

    # # Fine-tune the model using contrastive fine-tuning
    # model = train_blip_contrastive(model, dataloader, num_epochs=3, learning_rate=5e-5)

    # Evaluate the fine-tuned model saved at MODEL_PATH
    evaluate_model(IMAGE_FOLDER, GROUND_TRUTH_JSON, MODEL_PATH)