<a href="https://colab.research.google.com/github/ajaysuseel/MiniProject_AD/blob/main/incremental_contrastive_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import json
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipConfig
from peft import LoraConfig, get_peft_model

In [None]:
import os
import subprocess

GIT_REPO_URL = "https://github.com/your-username/your-repo.git"
GIT_LOCAL_PATH = "/content/your-repo"  # Path where repo is cloned

# Clone on the first run, otherwise pull the latest changes. Both branches go through
# subprocess.run(check=True): the IPython `!git clone` magic does not raise on failure, so a
# failed clone would only surface later as a confusing missing-file error.
if not os.path.exists(GIT_LOCAL_PATH):
    subprocess.run(["git", "clone", GIT_REPO_URL, GIT_LOCAL_PATH], check=True)
else:
    subprocess.run(["git", "-C", GIT_LOCAL_PATH, "pull"], check=True)  # Pull latest changes


In [None]:
import json

# Every input for this session lives under <repo>/ajay/.
CAPTIONS_PATH = os.path.join(GIT_LOCAL_PATH, "ajay", "contrastive_captions.json")
USED_FILES_PATH = os.path.join(GIT_LOCAL_PATH, "ajay", "used_files.json")
IMAGE_FOLDER = os.path.join(GIT_LOCAL_PATH, "ajay", "images")

# Caption records; each one carries at least a "filename" key (used by the filter below).
with open(CAPTIONS_PATH, "r") as f:
    captions_data = json.load(f)

# Filenames already consumed by earlier fine-tuning runs (none on a first run).
if not os.path.exists(USED_FILES_PATH):
    used_files = set()
else:
    with open(USED_FILES_PATH, "r") as f:
        used_files = set(json.load(f))


In [None]:
# Keep only caption records whose image has not been used by a previous run.
new_data = [item for item in captions_data if item["filename"] not in used_files]

if not new_data:
    print("✅ No new images for fine-tuning. Exiting...")
    # The builtin exit() is a `site` helper meant for the interactive interpreter, not
    # for programs; raising SystemExit directly stops execution without relying on it.
    raise SystemExit(0)


In [None]:
# ---------------------------------------------------------
# Function: Load Used Files List
# ---------------------------------------------------------
def load_used_files(used_files_path):
    """
    Return the set of filenames consumed by earlier fine-tuning sessions.

    Reads the JSON list stored at ``used_files_path``. When that file does not
    exist, an empty set is returned so a first run starts from scratch.
    """
    if not os.path.exists(used_files_path):
        print("No used files record found; starting fresh.")
        return set()

    with open(used_files_path, "r") as handle:
        recorded = set(json.load(handle))
    print(f"Loaded {len(recorded)} used filenames from {used_files_path}.")
    return recorded

In [None]:
# ---------------------------------------------------------
# Function: Save Used Files List
# ---------------------------------------------------------
def save_used_files(used_files, used_files_path):
    """
    Save the set of filenames to ``used_files_path`` as a JSON list.

    The list is sorted so the file's content (and its git diff, since the file
    is committed) is stable between runs. Missing parent directories are created.
    """
    parent = os.path.dirname(used_files_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(used_files_path, "w") as f:
        json.dump(sorted(used_files), f)
    print(f"Saved {len(used_files)} used filenames to {used_files_path}.")

In [None]:
# ---------------------------------------------------------
# Function: Filter New Samples for Incremental Fine-Tuning
# ---------------------------------------------------------
def get_new_samples(new_data, used_files_path):
    """
    Separate the samples that have not been fine-tuned on yet.

    Args:
        new_data: list of sample dicts, each with a "filename" key.
        used_files_path: path of the JSON record of previously used filenames.

    Returns:
        A ``(new_samples, used_files)`` tuple. ``new_samples`` holds the samples
        whose filename is absent from the record, in their original order.
        ``used_files`` is the set of already-used filenames that was loaded.
    """
    already_used = load_used_files(used_files_path)
    fresh = []
    for sample in new_data:
        if sample["filename"] in already_used:
            continue
        fresh.append(sample)
    print(f"Found {len(fresh)} new samples for fine-tuning.")
    return fresh, already_used

In [None]:
# ---------------------------------------------------------
# Custom Dataset
# ---------------------------------------------------------
class CustomDataset(Dataset):
    """
    Map-style dataset of (image, positive caption, negative caption) samples.

    Each entry of ``data`` is read as a dict with "filename", "positive_caption"
    and "negative_caption" keys; the image is loaded from ``images_folder``.
    ``processor`` encodes the image and both captions, and its tokenizer
    produces the ``pos_labels`` / ``neg_labels`` token-id tensors.
    """

    def __init__(self, data, processor, images_folder):
        self.data = data
        self.processor = processor
        self.images_folder = images_folder

    def __len__(self):
        return len(self.data)

    def _tokenize(self, text):
        """Return the padded, truncated token ids of one caption as a 1-D tensor."""
        return self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
        ).input_ids.squeeze(0)

    def __getitem__(self, idx):
        item = self.data[idx]
        image_path = os.path.join(self.images_folder, item["filename"])
        # Close the file handle immediately; convert() returns a fully loaded image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        pos_text = item["positive_caption"]
        neg_text = item["negative_caption"]

        # truncation=True: with padding="max_length" alone, a caption longer than the
        # tokenizer's max length stays longer than its neighbours, and the default
        # collate function can then no longer stack the batch.
        encoding = self.processor(
            images=image,
            text=[pos_text, neg_text],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) drops a leading dim only where it has size 1 (the single image);
        # the two-caption text tensors keep their leading dim of 2.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # Token ids of each caption, used as its text input during training.
        encoding["pos_labels"] = self._tokenize(pos_text)
        encoding["neg_labels"] = self._tokenize(neg_text)

        return encoding

In [None]:
# ---------------------------------------------------------
# Function: Create DataLoader
# ---------------------------------------------------------
def create_dataloader(data, processor, images_folder, batch_size=2):
    """
    Build a shuffling DataLoader over ``data``.

    The samples are wrapped in ``CustomDataset``, which loads images from
    ``images_folder`` and encodes them with ``processor``. They are served in
    batches of ``batch_size``.
    """
    return DataLoader(
        CustomDataset(data, processor, images_folder),
        batch_size=batch_size,
        shuffle=True,
    )

In [None]:
# ---------------------------------------------------------
# Contrastive Loss Function
# ---------------------------------------------------------
def contrastive_loss(image_embeds, pos_text_embeds, neg_text_embeds, margin=1.0):
    """
    Margin (hinge) loss that ranks the positive caption above the negative one.

    For each row, the loss is ``max(0, margin - cos(img, pos) + cos(img, neg))``,
    with cosine similarity taken over the last dimension. The result is the
    mean over all rows.
    """
    sim_pos = torch.cosine_similarity(image_embeds, pos_text_embeds, dim=-1)
    sim_neg = torch.cosine_similarity(image_embeds, neg_text_embeds, dim=-1)
    per_sample = torch.relu(margin - sim_pos + sim_neg)
    return per_sample.mean()

In [None]:
# ---------------------------------------------------------
# Function: Incremental Fine-Tuning
# ---------------------------------------------------------
def _masked_mean(hidden_states, input_ids, pad_token_id):
    """Average ``hidden_states`` over the sequence axis, skipping padding positions."""
    # assumes hidden_states is (batch, seq, dim) and input_ids is (batch, seq) — TODO confirm
    if pad_token_id is None:
        return hidden_states.mean(dim=1)
    mask = (input_ids != pad_token_id).unsqueeze(-1).to(hidden_states.dtype)
    # clamp avoids 0/0 for a row that is entirely padding
    return (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)


def _text_embedding(model, input_ids, pad_token_id):
    """Return a gradient-free caption embedding: the text decoder's pooled last hidden layer."""
    input_ids = input_ids.long()
    with torch.no_grad():
        outputs = model.text_decoder(input_ids=input_ids, output_hidden_states=True)
    # Captions are padded to the tokenizer's max length, so a plain mean would be
    # dominated by pad tokens; pool over real tokens only.
    return _masked_mean(outputs.hidden_states[-1], input_ids, pad_token_id)


def _load_model_and_processor(model_save_path, base_model_name):
    """Return ``(model, processor)``, resuming from ``model_save_path`` when a previous run saved there."""
    if os.path.isfile(os.path.join(model_save_path, "adapter_config.json")):
        # save_pretrained() on a PEFT model writes only the LoRA adapter (no full model
        # config), so rebuild the base model and re-attach the adapter in trainable mode.
        # NOTE(review): base_model_name must be the same checkpoint the adapter was trained on.
        from peft import PeftModel

        print("Loading previously fine-tuned model...")
        processor = BlipProcessor.from_pretrained(model_save_path, ignore_mismatched_sizes=True)
        base_model = BlipForConditionalGeneration.from_pretrained(base_model_name, ignore_mismatched_sizes=True)
        model = PeftModel.from_pretrained(base_model, model_save_path, is_trainable=True)
        return model, processor

    if os.path.isfile(os.path.join(model_save_path, "config.json")):
        # A full (non-PEFT) model checkpoint.
        print("Loading previously fine-tuned model...")
        processor = BlipProcessor.from_pretrained(model_save_path, ignore_mismatched_sizes=True)
        model = BlipForConditionalGeneration.from_pretrained(model_save_path, ignore_mismatched_sizes=True)
        return model, processor

    print("No previously fine-tuned model found; loading base model...")
    config = BlipConfig.from_pretrained(base_model_name)
    processor = BlipProcessor.from_pretrained(base_model_name, ignore_mismatched_sizes=True)
    model = BlipForConditionalGeneration.from_pretrained(base_model_name, config=config, ignore_mismatched_sizes=True)
    # LoRA on the fused qkv projection of every vision encoder layer; the layer count
    # comes from the config instead of a hard-coded 12.
    num_layers = config.vision_config.num_hidden_layers
    target_modules = [f"vision_model.encoder.layers.{i}.self_attn.qkv" for i in range(num_layers)]
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        target_modules=target_modules
    )
    model = get_peft_model(model, lora_config)
    return model, processor


def incremental_finetuning(new_data_json, new_images_folder, used_files_path,
                           model_save_path, base_model_name, num_epochs=3, learning_rate=5e-5, batch_size=2):
    """
    Incrementally fine-tune the model on samples not used by earlier runs.

    This function:
      1. Loads the samples (a JSON list) from new_data_json.
      2. Drops samples whose filename is already recorded in used_files_path.
      3. Resumes the model saved in model_save_path (a LoRA adapter or a full
         checkpoint) if present; otherwise loads base_model_name and adds LoRA.
      4. Trains with a margin contrastive loss between pooled vision features and
         (gradient-free) pooled text-decoder features of positive/negative captions.
      5. Saves model and processor to model_save_path.
      6. Adds the newly used filenames to the record at used_files_path (the same
         file that was read in step 2).

    Parameters:
      new_data_json (str): Path to a JSON file containing new fine-tuning samples.
      new_images_folder (str): Local directory containing the new images.
      used_files_path (str): Path to a JSON file storing used filenames.
      model_save_path (str): Directory where the fine-tuned model is saved.
      base_model_name (str): The base model identifier (from Hugging Face).
      num_epochs (int): Passes over the new samples.
      learning_rate (float): AdamW learning rate.
      batch_size (int): DataLoader batch size.

    Returns:
      (model, processor) after training, or (None, None) when the data could not be
      loaded or there were no new samples (so callers can always unpack two values).
    """
    # Step 1: Load new fine-tuning data
    try:
        with open(new_data_json, "r") as f:
            new_data = json.load(f)
    except (OSError, ValueError) as e:  # missing/unreadable file or invalid JSON
        print(f"Error loading new fine-tuning data: {e}")
        return None, None
    print(f"Loaded {len(new_data)} new fine-tuning samples from {new_data_json}.")

    # Step 2: Filter new samples based on used files
    new_samples, used_files = get_new_samples(new_data, used_files_path)
    if not new_samples:
        print("No new samples to fine-tune on. Exiting incremental fine-tuning.")
        return None, None

    # Step 3: Resume the previous fine-tuned model if one was saved, else start from the base model.
    model, processor = _load_model_and_processor(model_save_path, base_model_name)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()

    # Step 4: Train on the new samples (new_samples is non-empty, so the loader is too).
    dataloader = create_dataloader(new_samples, processor, new_images_folder, batch_size=batch_size)
    # Only parameters that can receive gradients (e.g. the LoRA weights when wrapped with PEFT).
    optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=learning_rate)
    pad_token_id = getattr(processor.tokenizer, "pad_token_id", None)

    print(f"🚀 Starting incremental fine-tuning on {device} for {num_epochs} epochs...")
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            optimizer.zero_grad()
            pixel_values = batch["pixel_values"].to(device)
            image_embeds = model.vision_model(pixel_values).last_hidden_state.mean(dim=1)
            # Text features are computed under no_grad, so only parameters reached through
            # the vision tower are updated by this loss.
            pos_text_embeds = _text_embedding(model, batch["pos_labels"].to(device), pad_token_id)
            neg_text_embeds = _text_embedding(model, batch["neg_labels"].to(device), pad_token_id)
            loss = contrastive_loss(image_embeds, pos_text_embeds, neg_text_embeds)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"✅ Epoch {epoch+1} completed | Average Loss: {epoch_loss/len(dataloader):.4f}")

    # Step 5: Save the updated model and processor
    model.save_pretrained(model_save_path)
    processor.save_pretrained(model_save_path)
    print(f"Model saved to {model_save_path}")

    # Step 6: Update the used-files record at the caller's path -- the file read in step 2 --
    # so the next run actually skips these samples.
    used_files.update(item["filename"] for item in new_samples)
    used_dir = os.path.dirname(used_files_path)
    if used_dir:
        os.makedirs(used_dir, exist_ok=True)
    save_used_files(used_files, used_files_path)

    return model, processor

In [None]:
# Clone the GitHub repository if not already done
GIT_REPO_URL = "https://github.com/your-username/your-repo.git"
GIT_LOCAL_PATH = "/content/your-repo"  # Local path after cloning

if not os.path.exists(GIT_LOCAL_PATH):
    # subprocess with check=True so a failed clone stops here; the `!git clone` magic
    # does not raise on failure.
    subprocess.run(["git", "clone", GIT_REPO_URL, GIT_LOCAL_PATH], check=True)

# Paths from cloned GitHub repo
# NOTE(review): the __main__ cell below reassigns NEW_DATA_JSON and USED_FILES_PATH,
# so these values are overwritten before they are used for fine-tuning.
NEW_DATA_JSON = os.path.join(GIT_LOCAL_PATH, "data", "new_data.json")  # Update the path inside the repo
USED_FILES_PATH = os.path.join(GIT_LOCAL_PATH, "models", "used_files.json")


In [None]:
if __name__ == "__main__":
    # Paths for the new data for incremental fine-tuning:
    NEW_DATA_JSON = "/content/data/incremental/new_data.json"  # New fine-tuning data (list of samples)
    NEW_IMAGES_FOLDER = "/content/data/incremental/new_images"   # New images folder (e.g., file62 to file200)

    # Neither name was defined before this call (MODEL_SAVE_PATH was only assigned in a
    # later cell, BASE_MODEL_NAME nowhere), which raised NameError. Same save path as the
    # model-push cell.
    MODEL_SAVE_PATH = os.path.join(GIT_LOCAL_PATH, "ajay", "models", "finetuned_blip2")
    # NOTE(review): assumed checkpoint -- the code uses BLIP (not BLIP-2) classes; confirm
    # this is the intended base model.
    BASE_MODEL_NAME = "Salesforce/blip-image-captioning-base"

    # Keep the used-files record next to the saved model, inside the cloned repo: the later
    # cells write this path and `git add` it, which fails for files outside the repository.
    USED_FILES_PATH = os.path.join(MODEL_SAVE_PATH, "used_files.json")

    # Perform incremental fine-tuning
    result = incremental_finetuning(
        new_data_json=NEW_DATA_JSON,
        new_images_folder=NEW_IMAGES_FOLDER,
        used_files_path=USED_FILES_PATH,
        model_save_path=MODEL_SAVE_PATH,
        base_model_name=BASE_MODEL_NAME,
        num_epochs=3,
        learning_rate=5e-5,
        batch_size=2
    )
    # Tolerate a bare None return on the early-exit paths instead of crashing on unpacking.
    model, processor = result if result is not None else (None, None)
    if model is None:
        print("Incremental fine-tuning did not run; no model was produced.")

In [None]:
# Update used files list.
# Merge with the record already on disk (e.g. the one incremental_finetuning() wrote) rather
# than overwriting it with the in-memory set loaded at the start of the notebook.
if os.path.exists(USED_FILES_PATH):
    with open(USED_FILES_PATH, "r") as f:
        used_files.update(json.load(f))
used_files.update(item["filename"] for item in new_data)

# Save updated used_files.json (sorted, so the committed diff is stable between runs)
used_dir = os.path.dirname(USED_FILES_PATH)
if used_dir:
    os.makedirs(used_dir, exist_ok=True)
with open(USED_FILES_PATH, "w") as f:
    json.dump(sorted(used_files), f)

print(f"✅ Updated used_files.json with {len(used_files)} images.")


In [None]:
def push_to_github(repo_path, file_path, commit_message="Updated used files"):
    """
    Stage ``file_path`` in the git repository at ``repo_path``, then commit and push.

    Best effort: a failing git step is reported and swallowed so the notebook keeps
    running. If staging leaves nothing to commit, commit and push are skipped rather
    than reported as an error.
    """
    try:
        subprocess.run(["git", "-C", repo_path, "add", file_path], check=True)
        # `git diff --cached --quiet` exits 0 when the index matches HEAD (nothing staged).
        staged = subprocess.run(["git", "-C", repo_path, "diff", "--cached", "--quiet"])
        if staged.returncode == 0:
            print(f"ℹ️ No changes to commit for {file_path}; skipping push.")
            return
        subprocess.run(["git", "-C", repo_path, "commit", "-m", commit_message], check=True)
        subprocess.run(["git", "-C", repo_path, "push"], check=True)
        print(f"🚀 Successfully pushed {file_path} to GitHub.")
    except subprocess.CalledProcessError as e:
        print(f"⚠️ Git push error: {e}")

push_to_github(GIT_LOCAL_PATH, USED_FILES_PATH, "Updated used files after fine-tuning")


In [None]:
MODEL_SAVE_PATH = os.path.join(GIT_LOCAL_PATH, "ajay", "models", "finetuned_blip2")

if model is None:
    # Fine-tuning was skipped (no data or no new samples): nothing to save or push.
    print("No fine-tuned model in this session; nothing to save or push.")
else:
    model.save_pretrained(MODEL_SAVE_PATH)
    # Save the processor alongside the weights so a later run can resume from this directory.
    processor.save_pretrained(MODEL_SAVE_PATH)
    # NOTE(review): GitHub rejects files over 100 MB; a LoRA adapter is small, but a full
    # model checkpoint may need Git LFS.
    push_to_github(GIT_LOCAL_PATH, MODEL_SAVE_PATH, "Saved fine-tuned BLIP-2 model")
