In [1]:
import torch
from PIL import Image
import torchvision.transforms as transforms
import pickle
import sys
import os
from model import EncoderCNN, DecoderRNN

In [2]:
def load_models(checkpoint_dir, vocab_path, embed_size=256, hidden_size=512, device="cuda"):
    """Load the trained encoder, decoder, and vocabulary for inference.

    Args:
        checkpoint_dir: Directory containing ``best_encoder.pth`` and
            ``best_decoder.pth``.
        vocab_path: Path to the pickled vocabulary object.
        embed_size: Embedding size passed to ``EncoderCNN`` and ``DecoderRNN``.
        hidden_size: Hidden size passed to ``DecoderRNN``.
        device: Device the models are moved to and onto which the checkpoint
            tensors are mapped.

    Returns:
        Tuple ``(encoder, decoder, vocab)``; both models are in eval mode.

    Raises:
        FileNotFoundError: If the vocabulary file or either checkpoint is missing.
    """
    encoder_path = os.path.join(checkpoint_dir, "best_encoder.pth")
    decoder_path = os.path.join(checkpoint_dir, "best_decoder.pth")

    # Fail fast, in a fixed order, before any expensive loading starts.
    required_files = (
        ("Vocabulary file", vocab_path),
        ("Encoder checkpoint", encoder_path),
        ("Decoder checkpoint", decoder_path),
    )
    for label, path in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"❌ {label} not found: {path}")

    print(f"📚 Loading vocabulary from: {vocab_path}")
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load vocabulary files that were produced by a trusted training run.
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)

    vocab_size = len(vocab)
    print(f"🧠 Vocabulary size: {vocab_size}")
    print("🏗️ Initializing models...")

    encoder = EncoderCNN(embed_size).to(device)
    decoder = DecoderRNN(embed_size, hidden_size, vocab_size).to(device)

    print("⚡ Loading model weights...")
    encoder.load_state_dict(torch.load(encoder_path, map_location=device))

    # The decoder checkpoint may carry keys the current DecoderRNN does not
    # define; drop them, but say so, since silently ignoring them can hide an
    # architecture mismatch. The expected key set is computed once rather than
    # rebuilding decoder.state_dict() for every checkpoint entry.
    checkpoint_state = torch.load(decoder_path, map_location=device)
    expected_keys = set(decoder.state_dict().keys())
    unexpected_keys = [key for key in checkpoint_state if key not in expected_keys]
    if unexpected_keys:
        print(
            f"⚠️ Ignoring {len(unexpected_keys)} unexpected decoder checkpoint "
            f"key(s): {unexpected_keys}"
        )
    filtered_state = {
        key: value for key, value in checkpoint_state.items() if key in expected_keys
    }
    # load_state_dict is called with its default arguments, as before.
    decoder.load_state_dict(filtered_state)

    # Switch both models to evaluation mode for inference.
    encoder.eval()
    decoder.eval()

    print("✅ Models loaded successfully!")
    return encoder, decoder, vocab

In [3]:
def preprocess_image(image_path):
    """Load an image from disk and turn it into a normalized model input.

    The image is converted to RGB, resized to 224x224, converted to a tensor,
    normalized per channel, and given a leading batch dimension, so the
    result has shape ``(1, 3, 224, 224)``.

    Args:
        image_path: Path to the image file.

    Returns:
        The preprocessed image tensor with a batch dimension of 1.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        RuntimeError: If the image cannot be opened or transformed; the
            original exception is attached as ``__cause__``.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"❌ Image file not found: {image_path}")

    # Mean/std are the usual ImageNet channel statistics — presumably matching
    # the pretrained backbone inside EncoderCNN (confirm against model.py).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    try:
        # Use a context manager so the underlying file handle is released
        # deterministically instead of whenever the Image is garbage-collected.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
            processed = transform(image).unsqueeze(0)  # add batch dimension
        print(f"🖼️ Image preprocessed: {image.size} -> {processed.shape}")
        return processed
    except Exception as e:
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise RuntimeError(f"❌ Error processing image {image_path}: {e}") from e

In [4]:
def generate_caption(image_path, encoder, decoder, vocab, device="cuda", max_length=50, use_beam=True):
    """Generate a caption for a single image.

    Args:
        image_path: Path to the image file to caption.
        encoder: Model called on the preprocessed image to produce features.
        decoder: Model providing ``greedy_sample`` and, optionally, ``sample``.
        vocab: Vocabulary object forwarded to the decoder's sampling method.
        device: Device the image tensor is moved to before encoding.
        max_length: Forwarded to the sampling method as ``max_len``.
        use_beam: When true and the decoder has a ``sample`` method, that
            method is used (presumably beam search — confirm in model.py);
            otherwise ``greedy_sample`` is used.

    Returns:
        The first element of the sampling method's result, i.e. the caption
        for the single image in the batch.
    """
    image = preprocess_image(image_path).to(device)

    # Pure inference: keep the encoder pass AND the decoding step inside
    # no_grad so autograd does not record the decoding computation.
    with torch.no_grad():
        features = encoder(image)
        print(f"🔍 Image features shape: {features.shape}")

        if use_beam and hasattr(decoder, 'sample'):
            sampler = decoder.sample
        else:
            sampler = decoder.greedy_sample
        captions = sampler(features, vocab, max_len=max_length)

    # The batch holds exactly one image, so take the first caption.
    return captions[0]

In [None]:
def main(image_path="/home/sahil_duwal/Projects/ImageCap/flickr8k/images/3744832122_2f4febdff6.jpg",
         checkpoint_dir="checkpoints"):
    """Run the end-to-end captioning demo and print the generated caption.

    Args:
        image_path: Image to caption. The default is the original hard-coded
            sample path, so ``main()`` behaves as before; pass another path to
            caption a different image without editing this function.
        checkpoint_dir: Directory holding ``vocab.pkl`` and the model
            checkpoints read by ``load_models``.

    On failure a diagnostic is printed and the process exits with status 1
    via ``sys.exit`` (which raises ``SystemExit``).
    """
    import traceback

    print("🚀 Image Caption Generator")
    print("=" * 50)

    vocab_path = os.path.join(checkpoint_dir, "vocab.pkl")

    # Fall back to CPU when no CUDA device is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🖥️ Using device: {device}")

    try:
        encoder, decoder, vocab = load_models(checkpoint_dir, vocab_path, device=device)
        print(f"\n🔄 Generating caption for: {image_path}")
        caption = generate_caption(image_path, encoder, decoder, vocab, device=device, max_length=50)

        separator = "=" * 50
        print("\n" + separator)
        print(f"🖼️ Image: {os.path.basename(image_path)}")
        print(f"📝 Generated Caption: {caption}")
        print(separator)

    except FileNotFoundError as e:
        # Missing vocab/checkpoint/image: show the message plus setup hints.
        print(e)
        print("\n💡 Make sure you have:")
        print(" 1. Trained your model (run train.py)")
        print(" 2. Correct checkpoint directory path")
        print(" 3. Valid image file path")
        sys.exit(1)
    except Exception as e:
        # Any other failure: print the full traceback to aid debugging.
        print(f"❌ Error during inference: {e}")
        traceback.print_exc()
        sys.exit(1)

In [6]:
# Entry point: run the captioning demo when this code is executed as the
# top-level module.
if __name__ == "__main__":
    main()

🚀 Image Caption Generator
🖥️ Using device: cuda
📚 Loading vocabulary from: checkpoints/vocab.pkl
🧠 Vocabulary size: 2994
🏗️ Initializing models...




⚡ Loading model weights...
✅ Models loaded successfully!

🔄 Generating caption for: /home/sahil_duwal/Projects/ImageCap/flickr8k/images/10815824_2997e03d76.jpg
🖼️ Image preprocessed: (500, 333) -> torch.Size([1, 3, 224, 224])
🔍 Image features shape: torch.Size([1, 256])

🖼️ Image: 10815824_2997e03d76.jpg
📝 Generated Caption: many people sit on a bench .
