In [None]:
#==============blip2 with SLAKE================

In [None]:
# Mount Google Drive at /content/drive; the dataset and output paths used by
# main() live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

import io
import sys
import os
import json
import torch
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
import traceback

from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Module-level cache filled by load_models() with the 'processor', 'model'
# and 'device' entries and read by process_image_with_blip2().
global_models = {}

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def load_models():
    """Load the BLIP-2 processor and model once and cache them in ``global_models``.

    On the first call the ``Salesforce/blip2-opt-2.7b`` checkpoint is loaded
    (float16 when CUDA is available, float32 otherwise) and stored under the
    keys ``'processor'``, ``'model'`` and ``'device'``. Later calls return
    immediately without touching the cache.

    Returns:
        None. The loaded objects are published through ``global_models``.
    """
    # Guard on the keys that are actually written below. Checking a key that
    # is never stored would make every call reload the whole checkpoint.
    if 'model' in global_models and 'processor' in global_models:
        return

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    print("Loading BLIP-2 model...")

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        torch_dtype=torch.float16 if use_cuda else torch.float32,
        device_map="auto"
    )

    global_models['processor'] = processor
    global_models['model'] = model
    global_models['device'] = device

    print("BLIP-2 model loaded successfully")

def process_image_with_blip2(image_path, prompt=""):
    """Generate a BLIP-2 caption for the image stored at ``image_path``.

    The processor and model are read from ``global_models`` (entries
    ``'processor'`` and ``'model'``), so ``load_models()`` must have run first.

    Args:
        image_path: Path of the image file to caption.
        prompt: Optional text that is prepended to the generated caption,
            separated by a single space. It is not fed to the model.

    Returns:
        The caption string (with ``prompt`` prepended when given). Failures
        never raise: a missing file yields ``"Error: Image file not found"``
        and any other exception yields ``"Error generating caption: <message>"``.
    """
    try:
        print(f"Loading image from: {image_path}")

        if not os.path.exists(image_path):
            print(f"ERROR: Image file not found: {image_path}")
            return "Error: Image file not found"

        # Image.open is lazy and keeps the file open; convert() returns an
        # independent copy, so the handle can be released right away instead
        # of accumulating open files over a long captioning loop.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        print(f"Image successfully loaded. Size: {image.size}")

        print("Processing image with BLIP-2...")
        processor = global_models['processor']
        model = global_models['model']

        # Move the inputs to the model's device AND cast floating-point
        # tensors to the model's dtype (float16 on GPU), so pixel values
        # match the weights they are multiplied with.
        inputs = processor(images=image, return_tensors="pt").to(model.device, model.dtype)

        print("Generating description...")
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=100,
                num_beams=5,
                min_length=20,
                top_p=0.9,
                repetition_penalty=1.5,
                length_penalty=1.0,
                do_sample=True
            )

        caption = processor.decode(generated_ids[0], skip_special_tokens=True)

        final_caption = f"{prompt} {caption}" if prompt else caption
        print(f"Generated caption: {final_caption}")
        return final_caption

    except Exception as e:
        # Best-effort by design: callers receive an error string, not an exception.
        print(f"Error processing image: {e}")
        print(traceback.format_exc())
        return f"Error generating caption: {str(e)}"

def _caption_prefix_for(modality):
    """Return the sentence opener used for captions of the given imaging modality."""
    prefixes = {
        "CT": "This CT scan shows",
        "MRI": "This MRI scan reveals",
        "X-ray": "This X-ray image displays",
    }
    return prefixes.get(modality, "This medical image shows")


def _resolve_image_path(base_dir, img_name):
    """Return the first existing path for ``img_name`` under ``base_dir``.

    Looks in ``imgs/`` first, then ``img/`` and ``images/``. When none exists
    the ``imgs/`` path is returned so the caller reports it as missing.
    """
    primary = os.path.join(base_dir, 'imgs', img_name)
    if os.path.exists(primary):
        return primary

    print(f"WARNING: Image file not found at {primary}")
    for folder in ('img', 'images'):
        candidate = os.path.join(base_dir, folder, img_name)
        if os.path.exists(candidate):
            print(f"Found image at alternative path: {candidate}")
            return candidate
    return primary


def _is_error_caption(caption):
    """Tell whether ``caption`` is an error string returned by process_image_with_blip2."""
    return caption.startswith(("Error:", "Error generating caption:"))


def _save_results(results, path):
    """Write ``results`` to ``path`` as indented JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)


def main(slake_base_dir='/content/drive/MyDrive/PhD/Research1/slakedataset/Slake1.0',
         output_dir='/content/drive/MyDrive/PhD/Research1/output',
         checkpoint_every=10):
    """Caption every English SLAKE test sample with BLIP-2 and save the results as JSON.

    Reads ``test.json`` from ``slake_base_dir``, keeps the entries whose
    ``q_lang`` is ``'en'``, captions each sample's image and writes one record
    per sample (``id``, ``img_id``, ``img_name``, ``question``,
    ``original_answer``, ``modality``, ``blip2_caption``) to
    ``slake_blip2_results.json`` inside ``output_dir``.

    Args:
        slake_base_dir: Dataset root containing ``test.json`` and the image folder.
        output_dir: Folder for the result files; created if missing.
        checkpoint_every: Write ``slake_blip2_results_temp.json`` after every
            this many samples. A value of 0 or less disables checkpoints.

    Returns:
        None.
    """
    load_models()

    json_path = os.path.join(slake_base_dir, 'test.json')
    output_json = os.path.join(output_dir, 'slake_blip2_results.json')
    temp_output_json = os.path.join(output_dir, 'slake_blip2_results_temp.json')

    os.makedirs(output_dir, exist_ok=True)

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    english_samples = [item for item in data if item.get('q_lang') == 'en']
    print(f"Loaded {len(english_samples)} English samples from test.json")

    modality_counts = {}
    for item in english_samples:
        modality = item.get('modality', 'Unknown')
        modality_counts[modality] = modality_counts.get(modality, 0) + 1

    print("Modality distribution:")
    for modality, count in modality_counts.items():
        print(f"  {modality}: {count} samples")

    results = []

    for idx, sample in tqdm(enumerate(english_samples), total=len(english_samples), desc="Processing images"):
        img_name = sample.get('img_name', '')
        img_id = sample.get('img_id', '')
        question = sample.get('question', '')
        answer = sample.get('answer', '')
        modality = sample.get('modality', '')

        img_path = os.path.join(slake_base_dir, 'imgs', img_name)

        print(f"\nProcessing image {idx} (ID: {img_id})")
        print(f"Image path: {img_path}")
        print(f"Question: {question}")
        print(f"Modality: {modality}")

        try:
            img_path = _resolve_image_path(slake_base_dir, img_name)
            blip2_caption = process_image_with_blip2(img_path)

            # Error strings must stay as-is: prefixing them would turn
            # "Error: Image file not found" into a sentence that looks like a
            # real caption ("This CT scan shows Error: ...").
            if not _is_error_caption(blip2_caption) and not blip2_caption.lower().startswith("this"):
                blip2_caption = f"{_caption_prefix_for(modality)} {blip2_caption}"
        except Exception as e:
            # Keep one record per sample even on failure so ids stay aligned.
            print(f"Error processing sample {idx}, image {img_name}: {e}")
            print(traceback.format_exc())
            blip2_caption = f"Error generating caption: {str(e)}"

        results.append({
            "id": idx,
            "img_id": img_id,
            "img_name": img_name,
            "question": question,
            "original_answer": answer,
            "modality": modality,
            "blip2_caption": blip2_caption
        })

        # Periodic checkpoint so a disconnected session does not lose everything.
        if checkpoint_every > 0 and (idx + 1) % checkpoint_every == 0:
            try:
                _save_results(results, temp_output_json)
                print(f"Temporary results saved to {temp_output_json} after processing {idx+1} samples")
            except Exception as save_error:
                print(f"Error saving temporary results: {save_error}")

        print("-" * 50)

    try:
        _save_results(results, output_json)
        print(f"Processing complete. Results saved to {output_json}")
    except Exception as save_error:
        print(f"Error saving final results: {save_error}")

        # Fall back to the local Colab disk when the Drive write fails.
        backup_output = os.path.join('/content', 'slake_blip2_results_backup.json')
        _save_results(results, backup_output)
        print(f"Results saved to backup location: {backup_output}")


# Run the full captioning pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    main()