In [None]:
import csv
import os
import re
from textwrap import wrap

import evaluate
import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import torch
from nltk.tokenize import word_tokenize
from peft import get_peft_model, LoraConfig
from PIL import Image
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import AutoProcessor, Blip2ForConditionalGeneration

In [None]:
# Flickr8k as mounted on Kaggle: the folder holding the image files.
image_path = '/kaggle/input/flickr8k/Images'
# captions.txt is parsed as CSV; later cells read its `image` and `caption` columns.
df = pd.read_csv('/kaggle/input/flickr8k/captions.txt')

In [None]:
# Preview the first ten rows of the captions table.
df.head(10)

In [None]:
def readImage(path,img_size=224):
    """Load the image at ``path`` as an RGB array with values divided by 255.

    The image is resized to a square of ``img_size`` x ``img_size`` pixels
    while loading.
    """
    side = (img_size, img_size)
    pixels = img_to_array(load_img(path, color_mode='rgb', target_size=side))
    return pixels / 255.
    
def display_images(temp_df):
    """Plot every row of ``temp_df`` as an image titled with its caption.

    Args:
        temp_df: DataFrame with an ``image`` column (file names relative to
            the module-level ``image_path``) and a ``caption`` column.

    The grid is five columns wide. Up to 25 images use the same 5x5 layout
    on a 20x20 figure as before; larger inputs add rows (and figure height)
    instead of failing on a subplot index outside the fixed grid.
    """
    n_cols = 5
    n_images = temp_df.shape[0]
    # Ceiling division without importing math; never fewer than 5 rows so
    # small samples keep the original layout.
    n_rows = max(5, -(-n_images // n_cols))

    plt.figure(figsize=(20, 4 * n_rows))
    # Iterating the columns directly makes the row index irrelevant.
    pairs = zip(temp_df["image"], temp_df["caption"])
    for position, (file_name, caption) in enumerate(pairs, start=1):
        plt.subplot(n_rows, n_cols, position)
        plt.imshow(readImage(f"{image_path}/{file_name}"))
        # Wrap long captions at 20 characters so titles do not overlap.
        plt.title("\n".join(wrap(caption, 20)))
        plt.axis("off")
    # Spacing is a figure-level setting, so apply it once instead of per image.
    plt.subplots_adjust(hspace=0.7, wspace=0.3)

In [None]:
# Show 15 randomly drawn caption rows; the sample is unseeded, so it changes on every run.
display_images(df.sample(15))

In [None]:
def preprocess_caption(caption):
    """Normalise a raw caption and wrap it in ``<start>`` / ``<end>`` markers.

    The text is lowercased, every character other than ``a-z``, ``0-9`` and
    whitespace is removed (so all punctuation goes), runs of whitespace are
    collapsed to a single space, and the result is trimmed before the two
    markers are attached.
    """
    lowered = caption.lower()
    alnum_only = re.sub(r"[^a-z0-9\s]", "", lowered)
    tidy = re.sub(r"\s+", " ", alnum_only).strip()
    return "<start> " + tidy + " <end>"

In [None]:
# Keep the untouched text in `caption_raw` so this cell is idempotent: re-running it
# always cleans the original captions instead of re-processing captions that already
# carry the <start>/<end> markers.
if 'caption_raw' not in df.columns:
    df['caption_raw'] = df['caption']
df['caption'] = df['caption_raw'].apply(preprocess_caption)

In [None]:
# Image preprocessing
# Resizes to 224x224 and converts the image to a tensor.
# NOTE(review): `image_transform` is not referenced by any other cell visible here —
# confirm whether it is still needed.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [None]:
# Dataset class
class Flickr8kDataset(Dataset):
    """Map-style dataset yielding processed images paired with their caption text.

    Args:
        dataframe: table with an ``image`` column (file name) and a
            ``caption`` column; one row per sample.
        image_folder: directory the image file names are resolved against.
        processor: callable invoked as
            ``processor(images=..., return_tensors="pt", padding="max_length")``
            that returns a mapping of tensors with a leading batch dimension.
    """

    def __init__(self, dataframe, image_folder, processor):
        self.df = dataframe
        self.image_folder = image_folder
        self.processor = processor

    def __len__(self):
        """Return the number of rows (image/caption pairs)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the processor outputs for row ``idx`` plus the caption under ``"text"``.

        The caption is returned as a plain string; tokenisation is left to
        the collate function so padding can be decided per batch.
        """
        row = self.df.iloc[idx]
        image_file = os.path.join(self.image_folder, row["image"])
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released right away instead of lingering until GC.
        with Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        proc_out = self.processor(
            images=image,
            return_tensors="pt",
            padding="max_length"
        )

        # Drop only the leading batch axis; a bare squeeze() would also remove
        # any other dimension that happens to have size 1.
        item = {k: v.squeeze(0) for k, v in proc_out.items()}
        item["text"] = row["caption"]
        return item

In [None]:
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Every key except ``"text"`` is stacked along a new leading dimension.
    The ``"text"`` entries are tokenised together with the module-level
    ``processor.tokenizer`` (padded to the longest caption in the batch) and
    stored as ``"input_ids"`` and ``"attention_mask"``.
    """
    collated = {}
    for field in batch[0]:
        if field == "text":
            captions = [sample["text"] for sample in batch]
            encoded = processor.tokenizer(captions, padding=True, return_tensors="pt")
            collated["input_ids"] = encoded["input_ids"]
            collated["attention_mask"] = encoded["attention_mask"]
            continue
        collated[field] = torch.stack([sample[field] for sample in batch])
    return collated

In [None]:
# quant_config = BitsAndBytesConfig(load_in_8bit=True)

# The processor is taken from the Salesforce checkpoint while the weights come from a
# separate "fp16-sharded" repository; device_map="auto" leaves device placement to the loader.
# NOTE(review): quantization is disabled (see the commented-out lines) and no dtype argument
# is passed, so the weights' dtype is whatever from_pretrained defaults to — confirm it is
# the one you intend to train in.
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "ybelkada/blip2-opt-2.7b-fp16-sharded", 
    device_map="auto", 
    # quantization_config=quant_config
)

In [None]:
# Let's define the LoraConfig
# Rank-16 adapters (alpha 32, dropout 0.05, no bias terms) targeting the modules
# named "q_proj" and "k_proj".
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj"]
)

In [None]:
# Wrap the base model with the LoRA adapters and report the trainable parameter count.
# NOTE: this rebinds `model`, so re-running the cell passes the already wrapped model
# to get_peft_model again; reload the base model first if the cell must be re-run.
model = get_peft_model(model, config)
model.print_trainable_parameters()

In [None]:
# Split by image rather than by caption row. Each image has several captions, so a
# row-level random split puts captions of the same image on both sides and the
# validation set stops being unseen data. The shuffle is seeded for reproducibility.
SPLIT_SEED = 42
image_names = df["image"].drop_duplicates().sample(frac=1.0, random_state=SPLIT_SEED)
n_train_images = int(0.8 * len(image_names))
train_image_set = set(image_names.iloc[:n_train_images])

is_train_row = df["image"].isin(train_image_set)
train_df = df[is_train_row].reset_index(drop=True)
val_df = df[~is_train_row].reset_index(drop=True)

# `image_path` is the same folder that was previously repeated here as a literal.
dataset = Flickr8kDataset(df, image_path, processor)
train_dataset = Flickr8kDataset(train_df, image_path, processor)
val_dataset = Flickr8kDataset(val_df, image_path, processor)
train_size = len(train_dataset)

In [None]:
# Training batches: shuffled, 5 samples each, assembled by `collate_fn`.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=5, collate_fn=collate_fn)

In [None]:
# Adam over everything returned by model.parameters(), learning rate 5e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Device the batches are moved to during training: GPU when available, otherwise CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Put the model in training mode. As the last expression of the cell this also
# echoes the module's repr in the notebook output.
model.train()

In [None]:
# Number of full passes over the training dataloader.
EPOCHS = 1

In [None]:
# Cast images to the dtype the model weights actually have instead of a hard-coded
# float16, so the loop follows however the model was loaded.
model_dtype = next(model.parameters()).dtype

for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")

    progress_bar = tqdm(train_dataloader, total=len(train_dataloader), desc="Training", leave=False)

    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        pixel_values = batch["pixel_values"].to(device, model_dtype)

        # Captions are padded to the longest one in the batch. Padding positions
        # must not contribute to the loss, so they are set to -100, the index the
        # cross-entropy loss ignores. The mask is also passed to the model so the
        # padding is not attended to.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        pixel_values=pixel_values,
                        labels=labels)

        loss = outputs.loss

        # Update progress bar description with current loss
        progress_bar.set_postfix({"loss": loss.item()})

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Directory the fine-tuned model and processor are saved to and later reloaded from.
MODEL_DIR = "blip2-opt2.7b-finetuned-lora"

In [None]:
# Persist the trained model and the processor side by side in MODEL_DIR.
model.save_pretrained(MODEL_DIR)
processor.save_pretrained(MODEL_DIR)

In [None]:
# Number of rows sampled for the qualitative analysis.
N_QUAL = 20
# Device used for inference: GPU when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Decoding strategies as per assignment
# Each entry is (name, keyword arguments forwarded to model.generate).
# The library's default generation config applies a top-k filter (top_k=50) whenever
# sampling is on, so "top_p" and "temperature" pass top_k=0 to switch it off; otherwise
# both would silently be combined with top-k and not measure what their names say.
DECODING_STRATEGIES = [
    ("beam", {"num_beams": 5, "max_new_tokens": 30}),
    ("top_k", {"do_sample": True, "top_k": 50, "max_new_tokens": 30}),
    ("top_p", {"do_sample": True, "top_p": 0.9, "top_k": 0, "max_new_tokens": 30}),
    ("temperature", {"do_sample": True, "temperature": 0.7, "top_k": 0, "max_new_tokens": 30}),
]

In [None]:
# Load model and processor
def load_model_and_processor():
    """Reload the processor and the model saved under ``MODEL_DIR`` for inference.

    The model is loaded in float16, moved to ``DEVICE`` and switched to
    evaluation mode.

    Returns:
        tuple: ``(processor, model)``.

    Raises:
        RuntimeError: if either artifact cannot be loaded. The original
            exception is attached as ``__cause__`` so the real failure stays
            visible in the traceback.
    """
    try:
        processor = AutoProcessor.from_pretrained(MODEL_DIR)
        model = Blip2ForConditionalGeneration.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)
        model = model.to(DEVICE)
        model.eval()
        return processor, model
    except Exception as e:
        raise RuntimeError(f"Failed to load model or processor: {e}") from e

# Rebinds the notebook-level `processor` and `model` to the reloaded inference copies.
processor, model = load_model_and_processor()

In [None]:
# Generative decoding function
def generate_caption(image, strategy, params):
    """Caption ``image`` using the generation keyword arguments in ``params``.

    ``strategy`` is only used to label the error message. Any exception raised
    while encoding, generating or decoding is reported with ``print`` and an
    empty string is returned instead, so a single bad sample does not abort a
    whole evaluation run.
    """
    try:
        encoded = processor(images=image, return_tensors="pt")
        encoded = encoded.to(DEVICE, torch.float16)
        token_ids = model.generate(**encoded, **params)
        # Only the first returned sequence is decoded.
        text = processor.tokenizer.decode(token_ids[0], skip_special_tokens=True)
        return text.strip()
    except Exception as e:
        print(f"Error generating caption for strategy {strategy}: {e}")
        return ""

In [None]:
# Automatic metrics setup
# Metric objects are loaded by name through `evaluate.load` and shared by the
# scoring helpers as module-level globals.
bleu_metric = evaluate.load("bleu")
meteor_metric = evaluate.load("meteor")
rouge_metric = evaluate.load("rouge")

In [None]:
def compute_self_bleu(captions):
    """Compute Self-BLEU to measure diversity (lower is more diverse).

    Each caption is scored with BLEU (up to 4-grams) against all the other
    captions as references, and the scores are averaged.

    Args:
        captions: sequence of caption strings.

    Returns:
        float: mean Self-BLEU, or ``0.0`` when fewer than two captions are
        given. A caption needs at least one *other* caption to be compared
        with; without this guard the metric is called with an empty
        reference list and fails.
    """
    captions = list(captions)
    if len(captions) < 2:
        return 0.0

    scores = []
    for i, hyp in enumerate(captions):
        # Every caption except the i-th acts as a reference.
        refs = captions[:i] + captions[i + 1:]
        score = bleu_metric.compute(predictions=[hyp], references=[refs], max_order=4)["bleu"]
        scores.append(score)
    return float(np.mean(scores))

In [None]:
def compute_distinct_n(captions, n=2):
    """Return the Distinct-n diversity score (higher means more diverse).

    Captions are lowercased and tokenised; the score is the number of unique
    n-grams divided by the total number of n-grams across all captions, or
    ``0.0`` when no n-gram could be formed.
    """
    seen = set()
    count = 0
    for text in captions:
        words = word_tokenize(text.lower())
        windows = [tuple(words[start:start + n]) for start in range(len(words) - n + 1)]
        seen.update(windows)
        count += len(windows)
    if count > 0:
        return len(seen) / count
    return 0.0

In [None]:
def compute_metrics(refs, hyps):
    """Score hypotheses against references with the automatic metrics.

    BLEU-4 uses every reference of each sample; METEOR and ROUGE-L use only
    the first one. Self-BLEU and Distinct-2 are computed over the hypotheses
    alone. SPICE is a fixed ``0.0`` placeholder, not a measured value.

    Returns a dict keyed by metric name, or an empty dict (after printing the
    error) if any computation fails.
    """
    try:
        scores = {}
        scores["BLEU-4"] = bleu_metric.compute(predictions=hyps, references=refs, max_order=4)["bleu"]

        # Flickr8k provides several captions per image; only the first is used here.
        first_refs = [r[0] for r in refs]
        scores["METEOR"] = meteor_metric.compute(predictions=hyps, references=first_refs)["meteor"]
        scores["ROUGE-L"] = rouge_metric.compute(predictions=hyps, references=first_refs)["rougeL"]

        scores["Self-BLEU"] = compute_self_bleu(hyps)
        scores["Distinct-2"] = compute_distinct_n(hyps, n=2)

        # SPICE needs an external setup (e.g. pycocoevalcap); placeholder only.
        scores["SPICE"] = 0.0
        return scores
    except Exception as e:
        print(f"Error computing metrics: {e}")
        return {}

In [None]:
# Run automatic evaluation on validation set
def evaluate_dataset(df, output_file="metrics_results.csv", image_dir=None):
    """Evaluate all decoding strategies on ``df`` and save the metrics as CSV.

    Args:
        df: DataFrame with ``image`` (file name) and ``caption`` columns; an
            image may appear on several rows, one per reference caption.
        output_file: path of the CSV file to write, one row per strategy.
        image_dir: folder holding the image files. Defaults to the
            module-level ``image_path``.

    Returns:
        list[dict]: one metrics dict per strategy, each with a ``"strategy"`` key.

    Each distinct image is captioned once per strategy and scored against all
    of its captions. Iterating caption rows instead would caption the same
    image once per reference, wasting generation time and filling the
    hypothesis list with duplicates, which distorts Self-BLEU and Distinct-2.
    """
    if image_dir is None:
        image_dir = image_path

    # Resolve the (image file, references) pairs once; they are identical for
    # every strategy. groupby also replaces a per-row filter over the frame.
    samples = []
    for image_name, group in df.groupby("image", sort=False):
        img_path = os.path.join(image_dir, image_name)
        if not os.path.exists(img_path):
            print(f"Image not found: {img_path}")
            continue
        samples.append((img_path, group["caption"].tolist()))

    results = []
    for strat, params in DECODING_STRATEGIES:
        all_refs, all_hyps = [], []
        for img_path, refs in tqdm(samples, total=len(samples), desc=f"Evaluating {strat}"):
            # convert() yields a loaded copy, so the file can be closed immediately.
            with Image.open(img_path) as raw_image:
                img = raw_image.convert("RGB")
            all_refs.append(refs)
            all_hyps.append(generate_caption(img, strat, params))
        metrics = compute_metrics(all_refs, all_hyps)
        metrics["strategy"] = strat
        results.append(metrics)
        print(f"Metrics for {strat}: {metrics}")

    # Save metrics to CSV; metrics missing from a row (failed computation) stay blank.
    with open(output_file, "w", newline="", encoding="utf-8") as f:
        fieldnames = ["strategy", "BLEU-4", "METEOR", "ROUGE-L", "Self-BLEU", "Distinct-2", "SPICE"]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
    print(f"Metrics saved to {output_file}")
    return results

In [None]:
# Qualitative and error analysis
def qualitative_analysis(df, output_csv="qualitative_analysis.csv", image_dir=None):
    """Sample rows, caption them with every strategy, flag likely errors, save to CSV.

    Args:
        df: DataFrame with ``image`` and ``caption`` columns.
        output_csv: path of the CSV file to write, one row per
            (sampled row, strategy) pair.
        image_dir: folder holding the image files. Defaults to the
            module-level ``image_path``.

    Returns:
        list[dict]: the rows written to the CSV.

    At most ``N_QUAL`` rows are sampled (fewer when ``df`` is smaller). The
    heuristics are token based:

    * hallucination - a generated token that is neither a stop word nor
      present in any reference caption of the image;
    * repetition - any token occurring more than twice;
    * omission - fewer tokens than half the average reference length.

    The printed percentages are relative to the captions actually analysed,
    so samples skipped because their image file is missing do not dilute them.
    """
    if image_dir is None:
        image_dir = image_path

    # Never request more rows than exist; sample() would raise otherwise.
    df_sample = df.sample(min(N_QUAL, len(df)), random_state=42).reset_index(drop=True)
    # The stop-word list does not depend on the sample: build it once.
    stop_words = set(nltk.corpus.stopwords.words('english'))

    rows = []
    key_elements_count = 0
    hallucination_free_count = 0

    for _, row in df_sample.iterrows():
        img_path = os.path.join(image_dir, row["image"])
        if not os.path.exists(img_path):
            print(f"Image not found: {img_path}")
            continue
        # convert() yields a loaded copy, so the file can be closed immediately.
        with Image.open(img_path) as raw_image:
            img = raw_image.convert("RGB")

        # Reference statistics are per image, not per strategy.
        refs = df[df["image"] == row["image"]]["caption"].tolist()
        ref_tokens = set(word_tokenize(" ".join(refs).lower()))
        avg_ref_len = np.mean([len(word_tokenize(r)) for r in refs])

        for strat, params in DECODING_STRATEGIES:
            gen = generate_caption(img, strat, params)
            gen_tokens = word_tokenize(gen.lower())
            issue = []

            # Hallucination: tokens not in any reference (excluding common words)
            if any(tok not in ref_tokens and tok not in stop_words for tok in gen_tokens):
                issue.append("hallucination")
            else:
                hallucination_free_count += 1

            # Repetition: some token appears more than twice
            if any(gen_tokens.count(tok) > 2 for tok in set(gen_tokens)):
                issue.append("repetition")

            # Omission: significantly shorter than average reference length
            if len(gen_tokens) < avg_ref_len / 2:
                issue.append("omission")

            # Key elements: number of noun tokens (repeats included) as a proxy
            pos_tags = nltk.pos_tag(gen_tokens)
            nouns = len([t for t, pos in pos_tags if pos.startswith('NN')])
            if nouns >= 3:
                key_elements_count += 1

            rows.append({
                "image": row["image"],
                "strategy": strat,
                "reference": refs[0],
                "generated": gen,
                "issue": "; ".join(issue) if issue else "OK",
                "noun_count": nouns
            })

    # Save to CSV
    with open(output_csv, "w", newline="", encoding="utf-8") as f:
        fieldnames = ["image", "strategy", "reference", "generated", "issue", "noun_count"]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)

    # Success criteria check, guarded against an empty analysis.
    total_samples = len(rows)
    key_pct = key_elements_count / total_samples * 100 if total_samples else 0.0
    clean_pct = hallucination_free_count / total_samples * 100 if total_samples else 0.0
    print(f"Qualitative analysis saved to {output_csv}")
    print(f"Descriptions with ≥3 key elements: {key_pct:.2f}%")
    print(f"Hallucination-free descriptions: {clean_pct:.2f}%")
    return rows

In [None]:
# Example usage (assumes val_df is defined as a DataFrame with 'image' and 'caption' columns)
# metrics = evaluate_dataset(val_df)
# qualitative_rows = qualitative_analysis(val_df)