In [1]:
import torch
from torch import nn
import random
import os
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFont
import json



# Import the model definition 
from model import Captioner


# Load vocabulary files.
# 'vocab.json' and 'vocab_rev.json' must be in the working directory (or
# replace the names with the correct paths). Each file is opened in a `with`
# block so its handle is closed as soon as the JSON has been parsed.
with open('vocab.json', 'r') as vocab_file:
    vocab = json.load(vocab_file)
with open('vocab_rev.json', 'r') as rev_vocab_file:
    rev_vocab = json.load(rev_vocab_file)

# Vocabulary size = number of entries in `vocab`; passed to the model below.
VSIZE = len(vocab)


# 1. Device Configuration
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")

# 2. Model Hyperparameters (keep consistent with the values used in training)
# Path to the trained model weights that are loaded below.
model_path = "captioner_epoch_10.pth"
# One image is captioned at a time.
BATCH_SIZE = 1
# The values below are passed straight to the Captioner constructor.
# NOTE(review): presumably the maximum caption length, the RNN hidden size and
# the embedding size respectively -- confirm against model.Captioner.
MAX_LENGTH = 80
RNN_HIDDEN = 512
W_ESIZE = 124

# 3. Model Initialization
# Build the captioner from the hyperparameters and vocabularies defined above,
# then place it on the selected device.
# NOTE(review): the meaning of the literal 512 (fifth positional argument) is
# not visible from here -- confirm against model.Captioner.
model = Captioner(
    VSIZE,
    RNN_HIDDEN,
    BATCH_SIZE,
    W_ESIZE,
    512,
    vocab,
    MAX_LENGTH,
    rev_vocab,
    device,
)
model = model.to(device)

# 4. Load Trained Model Weights
# map_location=device lets a checkpoint saved on another device be loaded here.
model.load_state_dict(
    torch.load(model_path, map_location=device)
)
# Switch the model to evaluation mode before running inference.
model.eval()

# 5. Image Transformation for Inference
# Every input image is resized to 256x256 and then converted to a tensor,
# in that order.
inference_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
]
transform_infer = transforms.Compose(inference_steps)

# 6. Load Image (a random image from this folder is chosen for demonstration)
# Set the CAPTION_IMAGE_DIR environment variable to point at your own image
# folder; when it is unset, the original hard-coded location is used.
image_folder = os.environ.get("CAPTION_IMAGE_DIR", "D:/ankit/caption_data/train2017")

Using cache found in C:\Users\Ankit Kumar/.cache\torch\hub\pytorch_vision_v0.10.0


In [2]:

# Pick one image at random from `image_folder`.
# Only regular files with a known image extension are candidates, so a stray
# sub-directory or metadata file cannot be selected and break Image.open.
# The listing is sorted because os.listdir returns names in arbitrary order;
# with a fixed `random` seed the same image is then chosen on every run.
image_extensions = (".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp")
candidate_files = sorted(
    name
    for name in os.listdir(image_folder)
    if name.lower().endswith(image_extensions)
    and os.path.isfile(os.path.join(image_folder, name))
)
if not candidate_files:
    # Fail with a clear message instead of random.choice's IndexError.
    raise FileNotFoundError(f"No image files found in {image_folder!r}")
image_file = os.path.join(image_folder, random.choice(candidate_files))

# convert() returns a new image, so the file handle can be closed right away.
with Image.open(image_file) as opened_image:
    original_image = opened_image.convert('RGB')

# Apply the inference transform, add a leading dimension of size 1 and move
# the tensor to the same device as the model.
image_tensor = transform_infer(original_image).unsqueeze(0).to(device)

# 7. Generate Caption
# Gradients are not needed for inference, so generation runs under no_grad.
with torch.no_grad():
    generated_caption_tokens = model.generate_caption(image_tensor)

# 8. Post-process Caption
# Drop the first and last tokens (the <START> and <END> markers), then
# concatenate the remaining pieces into a single string.
generated_caption_chars = generated_caption_tokens[1:-1]
generated_caption = str.join("", generated_caption_chars)

# 9. Display Results (Image and Caption)
print(f"Generated Caption: {generated_caption}")



Generated Caption: A large jetliner flying through a blue sky with a large clouds.


In [3]:
# --- Optional: Annotate and Display Image (PIL drawing) ---
# Draw on a copy so `original_image` itself is left untouched.
annotated_image = original_image.copy()
draw = ImageDraw.Draw(annotated_image)

# Default PIL font; a specific font file can be loaded here instead if needed.
font = ImageFont.load_default()
# Text anchor in pixels; adjust as needed.
text_position = (10, 10)

# Write the caption in white at `text_position`, then show the result.
caption_text = f"Caption: {generated_caption}"
draw.text(text_position, caption_text, fill=(255, 255, 255), font=font)
annotated_image.show(title="Image Captioning")