# Importing Libraries 

In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from transformers import T5Tokenizer, T5ForConditionalGeneration
import timm
from torch.utils.data import Dataset, DataLoader

In [None]:
import cv2
import torch
import numpy as np
import random 
random.seed(1337)

import warnings 
warnings.filterwarnings('ignore')

import torchvision.transforms.functional as TF

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

import torch.nn as nn
import torch.optim as optim

from PIL import Image

import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity



from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Defining paths and variables

In [None]:
# Data locations, relative to the current working directory.
image_folder = "image_folder"
caption_file = "captions.txt"
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Resizing images (letterbox to a fixed square)

In [None]:
# Letterbox helper: resize so the longer side equals `size` (224 by default),
# then pad the shorter side symmetrically with black.
def resize_with_pad(image, size=224):
    """Letterbox a PIL image into a ``size`` x ``size`` square.

    The image is resized with bicubic interpolation so that its longer side
    equals ``size`` while the aspect ratio is kept. The shorter side is then
    padded on both sides with black (fill value 0). When the padding amount
    is odd, the extra pixel goes to the right/bottom.

    Args:
        image: PIL image; ``image.size`` is read as ``(width, height)``.
        size: Side length of the square output. Defaults to 224, so existing
            callers get the original behavior.

    Returns:
        The resized and padded image, ``size`` x ``size``, of the same type
        returned by ``torchvision.transforms.functional.pad``.
    """
    w, h = image.size
    scale = size / max(w, h)
    # round() instead of int(): w * scale can land just below an integer
    # (e.g. 223.99999...), and truncation would leave the longer side one
    # pixel short. max(1, ...) keeps very elongated images from collapsing
    # to a zero-sized side, which resize cannot handle.
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))

    # TF.resize takes the target size as (height, width).
    resized = TF.resize(image, (new_h, new_w), interpolation=transforms.InterpolationMode.BICUBIC)

    # Split the padding as evenly as possible; an odd remainder goes right/bottom.
    pad_w = size - new_w
    pad_h = size - new_h
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)  # (left, top, right, bottom)
    return TF.pad(resized, padding, fill=0)  # fill=0 -> black

# Loading SmolLM2 Model (Language Model)

In [None]:
# NOTE(review): SmolLM2 checkpoints are, as far as I know, decoder-only (causal)
# models, while AutoModelForSeq2SeqLM only accepts encoder-decoder configs.
# Confirm that "Intel/smollm2" really resolves to a seq2seq architecture.
tokenizer = AutoTokenizer.from_pretrained("Intel/smollm2")
language_model = AutoModelForSeq2SeqLM.from_pretrained("Intel/smollm2").to(device)
# Freeze every language-model weight so no gradients are computed for it.
language_model.requires_grad_(False)

# Loading the ViT (Vision Transformer) image encoder

In [None]:
# Pretrained ViT-B/16 backbone; num_classes=0 asks timm for a model without a
# classification head, so it outputs features instead of class logits.
vision_encoder = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
# Move to the target device and switch to inference mode (Module.to/eval return self).
vision_encoder = vision_encoder.to(device).eval()
# Frozen: no parameter of the encoder receives gradients.
vision_encoder.requires_grad_(False)

# Setting up the Q-Former (trainable bridge between the vision encoder and the language model)

In [None]:
# Learnable query embeddings: 32 queries of width 768, matching the ViT-Base
# embedding width. The tensor is moved to `device` BEFORE being wrapped in
# nn.Parameter. Calling `.to(cuda)` on an existing Parameter returns a non-leaf
# copy; torch.optim rejects it ("can't optimize a non-leaf Tensor"), and it is
# not the tensor gradients would accumulate on. Drawing the values on the CPU
# first keeps the CPU RNG stream unchanged.
query_tokens = nn.Parameter(torch.randn(1, 32, 768).to(device))

# Six self-attention blocks over the 768-wide token sequence.
qformer_blocks = nn.ModuleList([
    nn.TransformerEncoderLayer(d_model=768, nhead=8, batch_first=True)
    for _ in range(6)
]).to(device)

# Project Q-Former outputs into the language model's embedding width.
# NOTE(review): `d_model` is the T5-style config field; decoder-only configs
# (e.g. Llama-style) expose `hidden_size` instead. Confirm for the loaded checkpoint.
qformer_proj = nn.Linear(768, language_model.config.d_model).to(device)

# Only the Q-Former pieces are optimized; the ViT and the language model stay frozen.
optimizer = optim.AdamW(
    list(qformer_blocks.parameters()) + list(qformer_proj.parameters()) + [query_tokens],
    lr=1e-4,
)

# Loading Training Data

In [None]:
def parse_qa_file(path):
    """Parse a question/answer annotation file into flat training samples.

    Each line is stripped of surrounding whitespace and then classified:

    * ``Q: <text>`` appends a question to the current image section.
    * ``A: <text>`` appends an answer to the current image section.
    * Any other line ending in ``.`` starts a new image section. The image id
      is that line with only its trailing ``.`` removed (``"1234."`` -> ``"1234"``).
    * All remaining lines (blank lines, unrecognized text) are ignored.

    Within a section, the i-th question is paired with the i-th answer.
    Unmatched trailing questions or answers are dropped. Q/A lines that appear
    before the first image header are discarded.

    Args:
        path: Path to the annotation text file (read as UTF-8).

    Returns:
        A list of ``(img_id, question, answer)`` string tuples in file order.
    """
    samples = []

    def _flush(section_id, questions, answers):
        # Emit the finished section. zip() pairs positionally and stops at the
        # shorter list, instead of raising IndexError when counts differ.
        if section_id is not None:
            samples.extend((section_id, q, a) for q, a in zip(questions, answers))

    img_id = None
    questions = []
    answers = []
    with open(path, "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            # Check the Q:/A: prefixes first: answers (and some questions) end
            # with a period and must not be mistaken for an image-id header.
            if line.startswith("Q:"):
                questions.append(line[2:].strip())
            elif line.startswith("A:"):
                answers.append(line[2:].strip())
            elif line.endswith("."):
                _flush(img_id, questions, answers)
                # Drop only the trailing period so ids containing dots survive.
                img_id = line[:-1].strip()
                questions = []
                answers = []
    _flush(img_id, questions, answers)
    return samples

# Parse the Q/A annotation file into a flat list of (img_id, question, answer) tuples.
samples = parse_qa_file(caption_file)

In [None]:
def preprocess_sample(img_id, question, answer):
    """Load one image and tokenize its question/answer pair.

    Args:
        img_id: Image identifier; the file ``<image_folder>/<img_id>.jpg`` is loaded.
        question: Question text.
        answer: Answer text.

    Returns:
        Tuple ``(image, question_ids, question_mask, answer_ids)``. The first
        item is the RGB image letterboxed by ``resize_with_pad``. The other
        three are the question token ids, the question attention mask, and the
        answer token ids. Each is the first row of the tokenizer's batch
        output, padded/truncated to 64 tokens.
    """
    path = os.path.join(image_folder, f"{img_id}.jpg")
    # Release the file handle as soon as the RGB copy exists.
    with Image.open(path) as src:
        rgb = src.convert("RGB")
    picture = resize_with_pad(rgb)

    # Same fixed-length encoding settings for question and answer.
    # NOTE(review): answer pad tokens are returned unchanged; if these ids serve
    # as LM labels, pad positions usually need to become -100. Confirm downstream.
    text_kwargs = dict(return_tensors="pt", padding="max_length", truncation=True, max_length=64)
    q_batch = tokenizer(question, **text_kwargs)
    a_batch = tokenizer(answer, **text_kwargs)

    return picture, q_batch.input_ids[0], q_batch.attention_mask[0], a_batch.input_ids[0]