## A. Environment Setup and Imports

In [None]:
import os, random, torch, re
import numpy as np
from collections import defaultdict
from torch.utils.data import Dataset
from datasets import load_dataset, Dataset as HFDataset
from PIL import Image
from IPython.display import display
import spacy


from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    T5Tokenizer, T5ForConditionalGeneration,
    CLIPProcessor, CLIPModel,
    TrainingArguments, Trainer, TrainerCallback, EvalPrediction,
    default_data_collator, pipeline
)
from sklearn.metrics.pairwise import cosine_similarity
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.translate.meteor_score import meteor_score
from rouge import Rouge

# Point the Hugging Face caches at a dedicated drive.
# NOTE(review): these variables are assigned after `transformers` and
# `datasets` were imported above; confirm the libraries still honor them at
# this point (they may need to be set before those imports to take effect).
os.environ["HF_HOME"] = "D:/NLP_Cache"
os.environ["HF_DATASETS_CACHE"] = "D:/NLP_Cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "D:/NLP_Cache/models"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
rouge_evaluator = Rouge()
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
# The rocket emoji is written as a single code point: a UTF-16 surrogate-pair
# escape yields two lone surrogates that cannot be encoded when printed.
print("🚀 Using device:", device)


## B. Dataset Loading

In [None]:
# Load the ESA/Hubble dataset (cached under D:/NLP_Cache/datasets) and keep
# only its "train" split.
_hubble_splits = load_dataset(
    "Supermaxman/esa-hubble",
    cache_dir="D:/NLP_Cache/datasets",
)
dataset = _hubble_splits["train"]
print("Total records - ", len(dataset))

## C. Model and Tokenizer Initialization

In [None]:
from transformers import BlipForConditionalGeneration
# BLIP captioning model, loaded once and moved to the selected device
# (a second identical load only cost time and memory).
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base"
).to(device)

# Matching processor: image preprocessing plus the tokenizer used for captions.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# T5 Refinement Model
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-base").to(device)
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-base")

# CLIP for similarity
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

## D. Dataset Wrapper

In [None]:
# Imported next to the class so it always subclasses the PyTorch `Dataset`,
# even though a later cell rebinds the name `Dataset` to `datasets.Dataset`.
from torch.utils.data import Dataset
import torch


class HubbleBLIPDataset(Dataset):
    """Map-style dataset that turns Hubble records into BLIP training inputs.

    Each item is a dict of un-batched tensors produced by ``processor`` for the
    record's image and a fixed text prompt, plus a ``labels`` tensor holding
    the token ids of the record's ``description``.

    Args:
        dataset: Source dataset. It must support ``len()``, integer indexing
            that returns a mapping with an ``"image"`` entry (and optionally
            ``"description"``), and ``select(indices)``.
        processor: Callable invoked as ``processor(images=..., text=..., ...)``
            that also exposes a ``tokenizer`` callable (e.g. a BLIP processor).
        max_samples: Optional cap on the number of leading records to keep.
            ``None`` or ``0`` keeps the whole dataset. A cap larger than the
            dataset is clamped to its length instead of raising.
    """

    # Fixed text prompt paired with every image.
    PROMPT = "Describe the astronomical image."
    # Both the prompt and the caption are padded/truncated to this many tokens.
    MAX_LENGTH = 64

    def __init__(self, dataset, processor, max_samples=None):
        if max_samples:
            # Clamp so that asking for more samples than exist does not fail.
            keep = min(max_samples, len(dataset))
            dataset = dataset.select(range(keep))
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of records available after the optional cap."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the model inputs and labels for record ``idx``."""
        sample = self.dataset[idx]

        image = sample["image"]
        if isinstance(image, list):
            # Some records carry a list of images; only the first one is used.
            image = image[0]
        image = image.convert("RGB")

        # A missing or None description becomes an empty caption rather than
        # the literal string "None".
        caption = str(sample.get("description") or "")

        inputs = self.processor(
            images=image,
            text=self.PROMPT,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
        )

        # NOTE(review): padded label positions keep the tokenizer's pad id; if
        # the loss is meant to skip padding they would need the ignore index
        # instead - confirm against the model's loss and the metric decoding.
        labels = self.processor.tokenizer(
            caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.MAX_LENGTH,
        )["input_ids"].squeeze(0)

        # The processor returns batched tensors; drop the batch dimension so
        # the data collator can stack items itself.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        inputs["labels"] = labels

        return inputs


## E. Data Splitting

In [None]:
from sklearn.model_selection import train_test_split
from datasets import Dataset

# Sequential 90/10 split of the Hugging Face dataset: the leading 90% of the
# records go to training, the remaining records to evaluation.
_total_records = len(dataset)
_split_point = int(_total_records * 0.9)
train_data = dataset.select(range(_split_point))
eval_data = dataset.select(range(_split_point, _total_records))

# Wrap both splits for BLIP, capped at 450 and 50 samples respectively.
train_dataset = HubbleBLIPDataset(train_data, blip_processor, max_samples=450)
eval_dataset = HubbleBLIPDataset(eval_data, blip_processor, max_samples=50)


## F. Named Entity Recognition and Object ID Utilities

In [None]:
import spacy
# spaCy pipeline used for named-entity recognition.
spacy_model = spacy.load("en_core_web_sm")


def extract_named_entities(text):
    """Group the named entities found in ``text`` by their entity label.

    Returns:
        A plain ``dict`` mapping each entity label to the list of entity texts
        carrying that label, in the order the pipeline yields them.
    """
    grouped = {}
    for entity in spacy_model(text).ents:
        grouped.setdefault(entity.label_, []).append(entity.text)
    return grouped

def extract_astronomical_entities(text):
    """Find catalogue designations such as ``NGC 1300`` or ``Messier 31``.

    Matching is case-insensitive because ``text`` is upper-cased first, so the
    returned designations are upper-case as well.

    Args:
        text: Free text to scan.

    Returns:
        List of matched designations (prefix plus 1-5 digits), in order of
        appearance.
    """
    # Every alternative must be upper-case: the pattern runs on text.upper(),
    # so a mixed-case "Messier" alternative could never match.
    pattern = r"\b(?:NGC|UGC|IC|M|MESSIER)\s?\d{1,5}\b"
    return re.findall(pattern, text.upper())

def extract_object_ids(text):
    """Return the full catalogue IDs (e.g. ``NGC 1300``, ``M87``) in ``text``.

    ``text`` is upper-cased before matching, so the IDs come back upper-case.
    An ID is one of the prefixes NGC, UGC, IC or M followed by 2-5 digits.

    Args:
        text: Free text to scan.

    Returns:
        List of matched IDs in order of appearance.
    """
    # The prefix group is non-capturing on purpose: with a capturing group
    # re.findall returns only the group ("NGC") instead of the whole ID.
    return re.findall(r'\b(?:NGC|UGC|IC|M)\s?\d{2,5}\b', text.upper())

def link_object_ids(text, known_ids):
    """Return the entries of ``known_ids`` that occur in ``text``.

    The comparison is a case-insensitive substring test; matching entries are
    returned unchanged and in the order they appear in ``known_ids``.

    Args:
        text: Text to search.
        known_ids: Iterable of object identifier strings.

    Returns:
        List of the identifiers found in ``text``.
    """
    # Lower-case the text once instead of once per candidate identifier.
    lowered_text = text.lower()
    return [obj for obj in known_ids if obj.lower() in lowered_text]

def summarize_text(text, max_length=60, min_length=20):
    """Summarize ``text`` with the module-level ``summarizer`` pipeline.

    Args:
        text: Text to summarize.
        max_length: Forwarded to the pipeline as ``max_length``.
        min_length: Forwarded to the pipeline as ``min_length``.

    Returns:
        The ``"summary_text"`` of the first pipeline result, or the placeholder
        ``"[Summary unavailable]"`` when anything goes wrong (best effort: the
        error is printed, not raised).
    """
    options = {"max_length": max_length, "min_length": min_length, "do_sample": False}
    try:
        results = summarizer(text, **options)
        summary_text = results[0]["summary_text"]
    except Exception as e:
        print(f"⚠️ Summarization failed: {e}")
        return "[Summary unavailable]"
    return summary_text

## G. Caption Generation and Refinement

In [None]:
def generate_blip_caption(image):
    """Caption ``image`` with the BLIP model, without a text prompt.

    The image is converted to RGB and resized to 224x224 before it is handed
    to the BLIP processor.

    Args:
        image: Image object exposing ``convert`` and ``resize`` (e.g. a PIL image).

    Returns:
        The decoded caption with special tokens removed.
    """
    prepared = image.convert("RGB").resize((224, 224))
    model_inputs = blip_processor(images=prepared, return_tensors="pt").to(device)
    generation_settings = {
        "max_length": 50,
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 3,
    }
    token_ids = blip_model.generate(**model_inputs, **generation_settings)
    first_sequence = token_ids[0]
    return blip_processor.decode(first_sequence, skip_special_tokens=True)

def refine_caption(blip_caption: str, beams: int = 4):
    """Rewrite a BLIP caption with the T5 model using beam search.

    Args:
        blip_caption: Caption to polish.
        beams: Number of beams passed to ``generate``.

    Returns:
        The rewritten caption, stripped of surrounding whitespace, or
        ``"No caption generated."`` when the input is empty or whitespace only.
    """
    cleaned = blip_caption.strip()
    if not cleaned:
        return "No caption generated."

    prompt = f"Rewrite the following astronomy image caption so it is concise, scientifically correct, and fluent:\n{cleaned}"
    encoded = t5_tokenizer(prompt, return_tensors="pt").to(device)
    generated = t5_model.generate(
        **encoded,
        num_beams=beams,
        early_stopping=True,
        max_length=40,
        length_penalty=0.8,
        no_repeat_ngram_size=3,
        repetition_penalty=1.2,
    )
    decoded = t5_tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.strip()

def _build_metadata_prompt(metadata):
    """Compose the BLIP text prompt from whichever metadata fields are present."""
    templates = (
        ("title", "of '{}'"),
        ("Constellation", "located in the constellation {}"),
        ("Category", "which belongs to the category {}"),
        ("Distance", "and is {} away"),
    )
    parts = []
    for key, template in templates:
        value = metadata.get(key)
        # Tolerate missing, None and non-string values instead of failing on .strip().
        text = "" if value is None else str(value).strip()
        if text:
            parts.append(template.format(text))
    if not parts:
        # No usable metadata: use the plain prompt rather than emitting
        # "Describe the astronomical image ." with a stray space.
        return "Describe the astronomical image."
    return "Describe the astronomical image " + ", ".join(parts) + "."


def generate_caption_with_metadata(image, metadata):
    """Caption ``image`` with BLIP, conditioning on a metadata-derived prompt.

    Args:
        image: Image object exposing ``convert`` (e.g. a PIL image).
        metadata: Mapping that may contain the keys ``title``,
            ``Constellation``, ``Category`` and ``Distance``; empty, missing or
            ``None`` entries are skipped.

    Returns:
        The decoded generation output with special tokens removed.
    """
    prompt = _build_metadata_prompt(metadata)
    # Convert to RGB, as generate_blip_caption does before calling the processor.
    image = image.convert("RGB")
    inputs = blip_processor(images=image, text=prompt, return_tensors="pt").to(device)
    outputs = blip_model.generate(**inputs, max_length=50)
    return blip_processor.tokenizer.decode(outputs[0], skip_special_tokens=True)

def refine_caption_with_metadata(caption, metadata):
    """Ask the T5 model to refine ``caption`` given metadata context.

    Args:
        caption: Caption text to refine.
        metadata: Mapping read for the keys ``title``, ``Constellation`` and
            ``Category``; missing keys are rendered as empty strings.

    Returns:
        The decoded T5 output with special tokens removed.
    """
    title = metadata.get('title', '')
    constellation = metadata.get('Constellation', '')
    category = metadata.get('Category', '')
    prompt = (f"Refine this caption using astronomy terms. Context: title: {title}, "
              f"constellation: {constellation}, "
              f"category: {category}. Caption: {caption}")

    encoded = t5_tokenizer(prompt, return_tensors="pt").to(device)
    generated = t5_model.generate(**encoded, max_length=60)
    first_sequence = generated[0]
    return t5_tokenizer.decode(first_sequence, skip_special_tokens=True)


## H. CLIP-Based Caption Reranking

In [None]:
def get_clip_score(image, caption):
    """Score how well ``caption`` matches ``image`` using CLIP embeddings.

    Args:
        image: Image passed to the CLIP processor.
        caption: Candidate caption text.

    Returns:
        The cosine similarity between the CLIP image embedding and the CLIP
        text embedding, taken from element ``[0][0]`` of the similarity matrix.
    """
    batch = clip_processor(
        text=[caption],
        images=image,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    batch = batch.to(device)
    # Inference only: no gradients are needed for scoring.
    with torch.no_grad():
        clip_out = clip_model(**batch)
    image_vec = clip_out.image_embeds.cpu().numpy()
    text_vec = clip_out.text_embeds.cpu().numpy()
    similarity = cosine_similarity(image_vec, text_vec)
    return similarity[0][0]

def clip_rerank(image, captions):
    """Pick the caption that CLIP rates as most similar to ``image``.

    Each caption is scored with ``get_clip_score`` so that reranking and the
    standalone score share a single scoring implementation.

    Args:
        image: Image to compare against.
        captions: Iterable of candidate caption strings.

    Returns:
        Tuple ``(best_caption, best_score)``. When ``captions`` is empty the
        result is ``("", -1)``. Ties keep the earliest caption.
    """
    best_score = -1
    best_caption = ""
    for caption in captions:
        score = get_clip_score(image, caption)
        if score > best_score:
            best_score, best_caption = score, caption
    return best_caption, best_score


def compute_clip_similarity(image_emb, text_emb):
    """Return the cosine similarity between an image and a text embedding.

    Both tensors are moved to the CPU, detached and converted to NumPy before
    the similarity matrix is computed; element ``[0][0]`` is returned.
    """
    image_array = image_emb.cpu().detach().numpy()
    text_array = text_emb.cpu().detach().numpy()
    similarity_matrix = cosine_similarity(image_array, text_array)
    return similarity_matrix[0][0]

def prepare_clip_guided_rerank_options(captions, image, clip_processor, clip_model, device="cuda"):
    """Select the caption with the highest CLIP similarity to ``image``.

    Unlike ``clip_rerank``, the CLIP processor, model and device are passed in
    explicitly instead of being read from module globals.

    Args:
        captions: Sequence of candidate captions (indexed by position).
        image: Image passed to ``clip_processor``.
        clip_processor: Processor used to encode each caption/image pair.
        clip_model: Model producing ``image_embeds`` and ``text_embeds``.
        device: Device the encoded inputs are moved to.

    Returns:
        Tuple ``(best_caption, best_similarity)``.
    """

    def similarity_for(caption):
        encoded = clip_processor(
            text=[caption],
            images=image,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(device)
        with torch.no_grad():
            result = clip_model(**encoded)
        return compute_clip_similarity(result.image_embeds, result.text_embeds)

    similarities = [similarity_for(caption) for caption in captions]
    best_index = int(torch.tensor(similarities).argmax())
    return captions[best_index], similarities[best_index]


## I. Caption Evaluation Metrics

In [None]:
def compute_caption_metrics(eval_pred: EvalPrediction):
    """Compute mean BLEU, ROUGE-L and METEOR for a batch of evaluation outputs.

    Predictions are taken as the arg-max over the last axis of the model
    output and decoded with the BLIP tokenizer, as are the labels.

    Args:
        eval_pred: Pair ``(predictions, labels)``; when ``predictions`` is a
            tuple only its first element is used.

    Returns:
        Dict with the keys ``"BLEU"``, ``"ROUGE-L"`` and ``"METEOR"``, each
        the mean score over all prediction/reference pairs.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    pred_ids = np.argmax(predictions, axis=-1)

    tokenizer = blip_processor.tokenizer
    # Negative label ids (e.g. an ignore index such as -100) are not real
    # tokens; map them to the pad id so they decode to nothing.
    labels = np.asarray(labels)
    labels = np.where(labels < 0, tokenizer.pad_token_id, labels)

    decoded_preds = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    pairs = list(zip(decoded_preds, decoded_labels))
    smoothie = SmoothingFunction().method4

    def rouge_l(pred, ref):
        # An empty side is scored as 0.0 instead of letting the scorer raise
        # and abort the whole evaluation.
        if not pred.strip() or not ref.strip():
            return 0.0
        try:
            return rouge_evaluator.get_scores(pred, ref)[0]["rouge-l"]["f"]
        except ValueError:
            # Degenerate text (e.g. punctuation only) that the scorer rejects.
            return 0.0

    scores = {
        "BLEU": np.mean([sentence_bleu([ref.split()], pred.split(), smoothing_function=smoothie)
                         for pred, ref in pairs]),
        "ROUGE-L": np.mean([rouge_l(pred, ref) for pred, ref in pairs]),
        "METEOR": np.mean([meteor_score([ref.split()], pred.split())
                           for pred, ref in pairs]),
    }
    return scores

def evaluate_captions(reference, generated):
    """Score one generated caption against its reference.

    Args:
        reference: Ground-truth description; ``None`` is treated as empty.
        generated: Caption to evaluate; ``None`` is treated as empty.

    Returns:
        Dict with the keys ``"BLEU"``, ``"ROUGE-L"`` and ``"METEOR"``.
        ROUGE-L is 0.0 when either text is empty or rejected by the scorer.
    """
    reference = reference or ""
    generated = generated or ""
    ref_tokens = reference.split()
    gen_tokens = generated.split()

    smoothie = SmoothingFunction().method4
    bleu = sentence_bleu([ref_tokens], gen_tokens, smoothing_function=smoothie)

    # Only call the ROUGE scorer with non-empty text on both sides; an empty
    # side is scored as 0.0 instead of raising.
    rouge = 0.0
    if ref_tokens and gen_tokens:
        try:
            rouge = rouge_evaluator.get_scores(generated, reference)[0]["rouge-l"]["f"]
        except ValueError:
            rouge = 0.0

    meteor = meteor_score([ref_tokens], gen_tokens)
    return {"BLEU": bleu, "ROUGE-L": rouge, "METEOR": meteor}


## J. Training Loop and Callbacks

In [None]:
class PrintStepCallback(TrainerCallback):
    """Trainer callback that echoes every non-empty log record to stdout."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Print the current global step together with the logged values."""
        if not logs:
            return
        print(f"Step {state.global_step} - Logs: {logs}")

class PrintEpochCallback(TrainerCallback):
    """Trainer callback that announces the end of each training epoch."""

    def on_epoch_end(self, args, state, control, **kwargs):
        """Print the integer part of ``state.epoch`` when an epoch finishes."""
        finished_epoch = int(state.epoch)
        print(f"📢 Finished Epoch {finished_epoch}")

# Fine-tuning configuration for BLIP.
# NOTE(review): warmup_steps=200 is large relative to a run of 4 epochs over
# at most 450 samples with an effective batch of 2 x 4 - confirm it is intended.
training_args = TrainingArguments(
    # Where checkpoints are written and how many are kept.
    output_dir="./blip_finetuned_light",
    save_strategy="epoch",
    save_total_limit=2,
    # Batch size and schedule length.
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    num_train_epochs=4,
    # Optimizer schedule.
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_steps=200,
    # Evaluation and model selection.
    evaluation_strategy="epoch",
    metric_for_best_model="eval_METEOR",
    greater_is_better=True,
    # Logging.
    logging_strategy="steps",
    logging_steps=10,
    report_to="none",
    # Half precision is requested only when CUDA is available.
    fp16=torch.cuda.is_available(),
    # Keep every key produced by the dataset wrapper.
    remove_unused_columns=False,
)

# Progress printers attached to the training run.
_progress_callbacks = [PrintEpochCallback(), PrintStepCallback()]

# Assemble the trainer for BLIP fine-tuning and start training.
trainer = Trainer(
    model=blip_model,
    args=training_args,
    data_collator=default_data_collator,
    tokenizer=blip_processor.tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_caption_metrics,
    callbacks=_progress_callbacks,
)

trainer.train()

In [None]:
# Qualitative check: caption a few randomly chosen records and print every
# intermediate result next to the reference description.
# NOTE(review): indices are drawn from the full dataset, so records used for
# training can appear here - confirm that is acceptable for this check.
# Never ask for more samples than the dataset holds.
indices = random.sample(range(len(dataset)), min(5, len(dataset)))

for idx in indices:
    # Fetch the record once instead of once per field.
    sample = dataset[idx]
    image = sample["image"]
    if image is None:
        continue

    reference = sample["description"]
    blip_caption = generate_blip_caption(image)
    refined = refine_caption(blip_caption)
    reranked, _ = clip_rerank(image, [blip_caption, refined])
    # CLIP similarity of the refined caption and of the reference text.
    clip_blip = get_clip_score(image, refined)
    clip_ref = get_clip_score(image, reference)
    metrics = evaluate_captions(reference, refined)
    entities = extract_named_entities(reference)
    summary = summarize_text(reference)

    print(f"\n Image Index - {idx}")
    display(image)
    print(f"BLIP Caption - {blip_caption}")
    print(f"Refined Caption - {refined}")
    print(f"CLIP-Reranked Caption - {reranked}")
    print(f"Reference - {reference}")
    # The summarizer pipeline is facebook/bart-large-cnn, not T5.
    print(f"Summary (BART) - {summary}")
    print(f"Named Entities - {entities}")
    print(f"CLIP Refined - {clip_blip:.4f} | CLIP (Reference): {clip_ref:.4f}")
    print("NLP Metrics -", metrics)
