# In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from decoder import TransformerDecoder
import transformers
from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor
from datasets import load_dataset, load_from_disk
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import random
from sklearn.model_selection import train_test_split


# === Set up environment ===
# Prefer the GPU when one is visible to PyTorch; otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# === Load frozen CLIP backbone & matching tokenizer ===
# Model and tokenizer come from the same checkpoint so token ids line up with
# the text encoder's embedding table.
_CLIP_CHECKPOINT = 'openai/clip-vit-base-patch32'
CLIP = CLIPModel.from_pretrained(_CLIP_CHECKPOINT).to(device)
tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)

# CLIP is used purely as a feature extractor: eval mode, every weight frozen.
CLIP.eval()
CLIP.requires_grad_(False)

vocab = tokenizer.get_vocab()

# === Token embedding ===
# Input embedding table of CLIP's text encoder; this file later uses it to
# embed the decoder's token ids. Kept frozen like the rest of CLIP.
token_embedding = CLIP.text_model.embeddings.token_embedding.to(device)
token_embedding.weight.requires_grad = False

# Local Flickr30k copy; later code reads flickr['test'][i]['image'].
flickr = load_dataset("flickr30k_dataset/")


# In [None]:

# Pin both RNGs so anything random below (e.g. decoder initialisation)
# is reproducible across runs.
random.seed(42)
torch.manual_seed(42)


# === Build the decoder and restore its trained weights ===
model = TransformerDecoder().to(device)

# The loaded object is passed straight to load_state_dict, i.e. the file is
# expected to hold a plain state_dict; map_location keeps tensors on `device`.
checkpoint = torch.load("decoder_9.pth", map_location=device)
model.load_state_dict(checkpoint)

def get_test_image(image=None, index=5, show=True):
    """Encode an image with CLIP's vision tower and return its patch embeddings.

    Args:
        image: PIL image to encode. When None (the default), the image at
            ``flickr['test'][index]['image']`` is used instead.
        index: Row of the Flickr30k ``test`` split used when ``image`` is None.
        show: When True, draw the image with ``plt.imshow`` before encoding.

    Returns:
        ``CLIP.vision_model(...).last_hidden_state[:, 1:, :]`` for a batch of
        one image: every output token except the first (CLIP's class token),
        on the module-level ``device``.
    """
    if image is None:
        image = flickr['test'][index]['image']
    if show:
        plt.imshow(image)

    # NOTE(review): the mean/std are rounded CLIP statistics, and Resize
    # squashes to 224x224 instead of resize + center-crop. They are kept
    # unchanged on the assumption that the decoder was trained with this exact
    # pipeline; confirm against the training script before changing them.
    preprocess = T.Compose([
        T.Resize((224, 224)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.4815, 0.4578, 0.4082],
            std=[0.2686, 0.2613, 0.2758],
        ),
    ])

    # Normalize expects exactly 3 channels; grayscale or RGBA inputs would
    # otherwise fail there, so force RGB first.
    pixels = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    CLIP.eval()
    with torch.no_grad():
        vision_outputs = CLIP.vision_model(pixels)

    # Drop the leading class token; keep only the per-patch embeddings.
    return vision_outputs.last_hidden_state[:, 1:, :]


# === Inference Function ===

def inference(model, start_token=49406, end_token=49407, max_len=77, device=None, img=None):
    """Greedily decode a caption for one image, one token per step.

    Each step embeds the tokens generated so far with CLIP's frozen token
    embedding and feeds them to ``model`` together with the image embeddings.
    It then appends the arg-max of the last position's logits.

    Args:
        model: Decoder called as ``model(img, y_emb)``.
        start_token: Token id that seeds the sequence (CLIP's <|startoftext|>).
        end_token: Token id that stops decoding (CLIP's <|endoftext|>).
        max_len: Maximum number of decoding steps.
        device: Device the decoder inputs are moved to. Defaults to the device
            of ``model``'s parameters, so CPU-only machines work without
            passing anything.
        img: Precomputed image embeddings (as returned by ``get_test_image``).
            Defaults to ``get_test_image()``.

    Returns:
        List of generated token ids without the leading ``start_token``.
        ``end_token`` is included as the last element if it was produced.
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device

    # Token ids must live where the embedding table lives; its output is then
    # moved to `device` for the decoder.
    emb_device = token_embedding.weight.device

    with torch.no_grad():
        if img is None:
            img = get_test_image()
        img = img.to(device)

        generated = [start_token]
        for _ in range(max_len):
            y = torch.tensor(generated, dtype=torch.long, device=emb_device).unsqueeze(0)  # (1, seq_len)
            y_emb = token_embedding(y).to(device)

            logits = model(img, y_emb)  # (1, seq_len, vocab_size)
            next_token = torch.argmax(logits[0, -1]).item()

            generated.append(next_token)
            if next_token == end_token:
                break

    return generated[1:]


# === Run Inference ===
if __name__ == "__main__":
    image_path = "example.jpg"  # replace with your test image

    # Caption the supplied image. Fall back to the Flickr30k test sample only
    # when the file does not exist.
    if os.path.isfile(image_path):
        # `with` closes the underlying file; convert() returns a detached copy.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
    else:
        print(f"{image_path} not found; captioning Flickr30k test sample instead.")
        image = None

    caption = inference(model, img=get_test_image(image))

    # skip_special_tokens keeps <|endoftext|> out of the printed caption.
    print("\n📷 Generated caption:", tokenizer.decode(caption, skip_special_tokens=True))