In [3]:
import torch
from torchvision import transforms
from PIL import Image
import pandas as pd
from utils.vocabulary import build_vocabulary, load_vocabulary
from models.model import ImageCaptioningModel
import json

class CaptionInference:
    """Load a trained ``ImageCaptioningModel`` checkpoint and caption images.

    The constructor loads the vocabulary, builds the model, restores its
    weights from the checkpoint and switches it to evaluation mode.
    ``generate_caption`` then turns an image file into a caption string.
    """

    # Tokens that are dropped from the rendered caption text.
    _SKIPPED_TOKENS = ('<START>', '<PAD>', '<UNK>')

    def __init__(self, model_path, vocabulary_path, device='cuda' if torch.cuda.is_available() else 'cpu',
                 embed_size=256, hidden_size=512, num_layers=2):
        """Build the model, restore its weights and set up preprocessing.

        Args:
            model_path: Path of the checkpoint file; it must contain a
                ``'model_state_dict'`` entry.
            vocabulary_path: Path handed to ``load_vocabulary``.
            device: Device the model and the input images are moved to.
            embed_size: Embedding size passed to the model constructor.
            hidden_size: Hidden size passed to the model constructor.
            num_layers: Layer count passed to the model constructor.

        The three size arguments must match the values used for training,
        otherwise ``load_state_dict`` will reject the checkpoint.
        """
        self.device = device

        # Word -> index mapping (used through .items(), len() and indexing).
        self.vocab = load_vocabulary(vocabulary_path)

        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = len(self.vocab)

        self.model = ImageCaptioningModel(
            embed_size=self.embed_size,
            hidden_size=self.hidden_size,
            vocab_size=self.vocab_size,
            num_layers=self.num_layers,
        ).to(self.device)

        # weights_only=True restricts unpickling to tensors/plain containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Resize to 224x224, convert to a tensor and normalise per channel.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def _reverse_vocab(self):
        """Return the cached index -> word mapping, building it on first use.

        The mapping is built once from ``self.vocab``; later changes to the
        vocabulary are not picked up.
        """
        mapping = self.__dict__.get('_index_to_word')
        if mapping is None:
            mapping = {}
            for word, index in self.vocab.items():
                # setdefault keeps the first word when an index is duplicated,
                # which is what a linear scan over items() would return.
                mapping.setdefault(index, word)
            self._index_to_word = mapping
        return mapping

    def idx_to_word(self, idx):
        """Return the word stored under ``idx``, or ``'<UNK>'`` if there is none.

        Args:
            idx: A vocabulary index. Objects exposing ``.item()`` (for example
                a zero-dimensional tensor) are unwrapped to a Python scalar.

        Returns:
            The matching word as a string. Unknown indices yield the token
            string ``'<UNK>'`` (not its integer index), so the result can
            always be compared with and joined to other words.
        """
        key = idx.item() if hasattr(idx, 'item') else idx
        try:
            return self._reverse_vocab().get(key, '<UNK>')
        except TypeError:
            # Unhashable index: it cannot be in the vocabulary.
            return '<UNK>'

    def _indices_to_caption(self, caption_indices):
        """Join the words for ``caption_indices`` into one caption string.

        Decoding stops at the first ``'<END>'`` token; ``'<START>'``,
        ``'<PAD>'`` and ``'<UNK>'`` tokens are left out.
        """
        caption_words = []
        for idx in caption_indices:
            word = self.idx_to_word(idx)
            if word == '<END>':
                break
            if word not in self._SKIPPED_TOKENS:
                caption_words.append(word)
        return ' '.join(caption_words)

    def generate_caption(self, image_path, max_length=50):
        """Generate a caption for the image stored at ``image_path``.

        Args:
            image_path: Path of the image file to caption.
            max_length: Forwarded to the model's ``generate_caption``.

        Returns:
            The caption as a single space-separated string.
        """
        # The context manager closes the file; convert() returns an
        # independent, fully loaded RGB image that outlives it.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        # unsqueeze(0) adds the leading batch dimension.
        image = self.transform(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            caption_indices = self.model.generate_caption(
                image,
                self.vocab,
                max_length=max_length,
            )

        return self._indices_to_caption(caption_indices)

def test_inference():
    """Smoke-test ``CaptionInference`` on a single image.

    Loads the epoch-10 checkpoint and the vocabulary, prints the vocabulary
    size, the device and the indices of the special tokens, then captions
    one image from ``data/Flicker8k_Dataset``.

    Raises:
        Exception: Any error raised during caption generation is reported
            on stdout and then re-raised unchanged.
    """
    # Local import keeps this function self-contained. os.path.join uses the
    # platform's separator, so the relative paths also resolve outside Windows.
    import os

    inference = CaptionInference(
        model_path=os.path.join('checkpoints', 'model_epoch_10.pth'),
        vocabulary_path=os.path.join('data', 'vocabulary'),
    )

    # Debug information about the loaded vocabulary and the device in use.
    print(f"Vocabulary size: {len(inference.vocab)}")
    print(f"Device being used: {inference.device}")
    print("Special tokens in vocabulary:")
    for token in ['<START>', '<END>', '<PAD>', '<UNK>']:
        print(f"{token}: {inference.vocab.get(token, 'Not found')}")

    image_path = os.path.join('data', 'Flicker8k_Dataset', '667626_18933d713e.jpg')
    try:
        caption = inference.generate_caption(image_path)
    except Exception as e:
        print(f"\nError during caption generation: {str(e)}")
        raise
    else:
        print(f"\nSuccessfully generated caption: {caption}")

# Run the smoke test only when this code is executed directly, not on import.
if __name__ == '__main__':
    test_inference()

Vocabulary size: 3106
Device being used: cuda
Special tokens in vocabulary:
<START>: 1
<END>: 2
<PAD>: 0
<UNK>: 3

Successfully generated caption: four be in of
