In [8]:
import os
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from utils import caption_image
from model import CNN_to_LSTM
from preprocess import get_loader

In [None]:
import torch
from huggingface_hub import hf_hub_download

# Download the model checkpoint from the Hugging Face Hub.
checkpoint_path = hf_hub_download(
    repo_id="sohumgautam/captioning-cnn-lstm",
    filename="pytorch_model.bin"
)

# Architecture hyperparameters must match those the checkpoint was trained with,
# otherwise load_state_dict fails on missing/mismatched parameters.
# NOTE(review): vocab_size=5240 is hard-coded; confirm it equals len(dataset.vocab)
# as rebuilt from the captions file used in the test cell below.
model = CNN_to_LSTM(embed_size=256, hidden_size=512, num_layers=2, vocab_size=5240)

# The checkpoint comes from a remote hub, so restrict unpickling to tensors and
# plain containers: weights_only=True prevents a tampered file from running code.
# Load onto CPU first; the test function moves the model to the target device.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode for inference; the trailing ';' suppresses the
# (very long) module repr that Jupyter would otherwise display.
model.eval();


In [None]:

def test_specific_images(image_folder, dataset_path, captions_file, device,
                         filenames=None, caption_model=None):
    """Caption a set of test images, print the captions and save annotated copies.

    Each image is loaded from ``image_folder``, preprocessed, captioned with
    ``caption_image``, printed, and saved as ``captioned_<filename>`` back into
    ``image_folder``.

    Args:
        image_folder: Directory containing the test images; annotated copies are
            written into the same directory.
        dataset_path: Image root passed to ``get_loader``. The dataset is built
            only to obtain its vocabulary.
        captions_file: Captions file passed to ``get_loader``.
        device: ``torch.device`` that the model and the input tensors are moved to.
        filenames: Image file names (relative to ``image_folder``) to caption.
            Defaults to the original hard-coded test set.
        caption_model: Model used for captioning. Defaults to the notebook-level
            ``model`` loaded in an earlier cell. Note that ``nn.Module.to`` moves
            the model in place, so that model ends up on ``device``.
    """
    if filenames is None:
        filenames = ['boy.png', 'boat.png', 'dog.jpg', 'horse.png', 'biker.jpg', 'man_bench.jpg']
    net = model if caption_model is None else caption_model

    # The checkpoint is loaded on CPU; the inputs go to `device`, so the model
    # must follow, or a CUDA run fails with a device-mismatch error.
    net.to(device)
    net.eval()

    # Image preprocessing.
    # NOTE(review): assumed to match the training-time transform; confirm.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Build the dataset only to access its vocabulary (batch size is irrelevant).
    _, dataset = get_loader(
        root_dir=dataset_path,
        captions_file=captions_file,
        transform=transform,
        batch_size=1,
    )
    vocab = dataset.vocab

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for filename in filenames:
            img_path = os.path.join(image_folder, filename)
            # Close the file handle; convert() returns a fully loaded copy.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = transform(image).unsqueeze(0).to(device)

            generated_caption = caption_image(net, image_tensor, vocab)
            caption_text = " ".join(generated_caption)

            print(f"Image: {filename}")
            print(f"Caption: {caption_text}")
            print("-" * 50)

            # Save the image titled with its caption; close the figure so a
            # loop over many images does not accumulate open figures.
            fig, ax = plt.subplots(figsize=(8, 8))
            ax.imshow(image)
            ax.set_title(caption_text)
            ax.axis('off')
            fig.savefig(os.path.join(image_folder, f"captioned_{filename}"))
            plt.close(fig)



In [None]:
# Parameters
test_image_folder = "test_images"  # Folder with the test images captioned by test_specific_images
dataset_path = "data/images/"  # Original training images path (used only to rebuild the vocabulary)
captions_file = "data/text.csv"  # Original captions file (used only to rebuild the vocabulary)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Run test: prints each caption and saves captioned_<name> copies into test_image_folder
test_specific_images(test_image_folder, dataset_path, captions_file, device)