In [None]:
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch
from PIL import Image
import os
import pandas as pd

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face hub id of the BLIP-2 checkpoint (OPT-2.7B language model).
MODEL_ID = "Salesforce/blip2-opt-2.7b"

# float16 roughly halves GPU memory for this model, but half-precision
# inference on CPU is poorly supported (some ops lack Half kernels, others are
# very slow), so fall back to float32 when no GPU is available.
model_dtype = torch.float16 if device == "cuda" else torch.float32

processor = Blip2Processor.from_pretrained(MODEL_ID)
model = Blip2ForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=model_dtype
).to(device)

In [None]:
from tqdm import tqdm
import pandas as pd
import os
from PIL import Image

# Caption every histopathology image with BLIP-2, conditioning the prompt on
# the image's ground-truth class so the model "explains" the known label.
#
# Uses `processor`, `model` and `device` from the model-loading cell.
# Produces `captions`: a list of dicts with keys
#   image_path, label (0 = normal, 1 = OSCC), generated_caption.
# A snapshot is written to captions_progress.csv at least every
# CHECKPOINT_EVERY captions so a crash mid-run does not lose all work.

captions = []
batch_size = 4
# NOTE(review): absolute, machine-specific path; consider making it configurable.
base_dir = '/home/apatil2/blue_link/Data'
magnifications = ['100x', '400x']
sets = ['First Set', 'Second Set']

CHECKPOINT_EVERY = 100
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
# Folder-name suffixes; the list index is used as the numeric label.
CLASS_NAMES = [
    'Normal Oral Cavity Histopathological Images',
    'OSCC Histopathological Images',
]
# Label-conditioned prompts that ask BLIP-2 to reason about the known class.
PROMPTS = {
    0: (
        "Question: Explain why this histopathological image is classified as normal oral tissue. "
        "Answer:"
    ),
    1: (
        "Question: Explain why this histopathological image is classified as oral squamous cell carcinoma. "
        "Answer:"
    ),
}

last_checkpoint = 0  # len(captions) at the most recent progress save

for s in sets:
    udir = os.path.join(base_dir, s)
    for mag in magnifications:
        for label, class_name in enumerate(CLASS_NAMES):

            bdir = os.path.join(udir, f"{mag} {class_name}")
            if not os.path.exists(bdir):
                print(f"Skipping missing folder: {bdir}")
                continue

            # Sort so processing order (and CSV row order) is reproducible;
            # os.listdir returns entries in arbitrary order.
            img_paths = [
                os.path.join(bdir, f)
                for f in sorted(os.listdir(bdir))
                if f.lower().endswith(IMAGE_EXTENSIONS)
            ]

            # The prompt depends only on the label, so build it once per folder.
            prompt = PROMPTS[label]

            for i in tqdm(range(0, len(img_paths), batch_size), desc=f"{mag} {class_name}"):
                batch_paths = img_paths[i:i+batch_size]

                # convert() returns an independent, fully loaded copy, so the
                # source file can be closed immediately.
                images = []
                for p in batch_paths:
                    with Image.open(p) as img:
                        images.append(img.convert('RGB'))

                # Cast floating-point inputs (pixel_values) to the model's dtype
                # as in the HF BLIP-2 examples; BatchFeature.to only casts
                # floating-point tensors, so input_ids stay integer.
                inputs = processor(
                    images=images, text=[prompt] * len(images), return_tensors="pt"
                ).to(device, model.dtype)

                with torch.no_grad():
                    outputs = model.generate(**inputs, max_new_tokens=80)

                decoded = processor.batch_decode(outputs, skip_special_tokens=True)

                for path, caption in zip(batch_paths, decoded):
                    # Depending on the transformers version the decoded text may
                    # echo the prompt; keep only the generated answer.
                    caption = caption.strip()
                    if caption.startswith(prompt):
                        caption = caption[len(prompt):].strip()
                    captions.append({
                        'image_path': path,
                        'label': label,
                        'generated_caption': caption
                    })

                # Checkpoint by distance from the last save instead of
                # `len(captions) % 100 == 0`: a partial final batch in any folder
                # shifts the count off multiples of 100, after which the modulo
                # test would never fire again.
                if len(captions) - last_checkpoint >= CHECKPOINT_EVERY:
                    pd.DataFrame(captions).to_csv("captions_progress.csv", index=False)
                    last_checkpoint = len(captions)


In [None]:
import pandas as pd

# Write the complete set of captions produced by the captioning loop.
# Explicit columns keep the header and column order stable even when no
# images were found, so downstream readers always see the expected schema.
df = pd.DataFrame(captions, columns=['image_path', 'label', 'generated_caption'])
df.to_csv("final_captions.csv", index=False)
print(f"Saved {len(df)} captions to final_captions.csv")
