In [1]:
import torch
import json
from torchvision import transforms, models
from PIL import Image
import os
import torch.nn as nn

# Load vocabulary
def load_vocab(vocab_path):
    with open(vocab_path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocab(idx2word_path)

# Image transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Define the Encoder CNN
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-1]  # Remove the last layer
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

# Define the Decoder RNN
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        h0 = features.unsqueeze(0).repeat(self.num_layers, 1, 1)
        c0 = torch.zeros_like(h0)
        lstm_out, _ = self.lstm(embeddings, (h0, c0))
        outputs = self.linear(lstm_out)
        return outputs

# Load the models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = EncoderCNN(embed_size=512).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(word2idx), num_layers=1).to(device)

# Load saved model weights
checkpoint = torch.load('/home/vitoupro/code/image_captioning/notebook/captioning_model_2.pth')
encoder.load_state_dict(checkpoint['encoder_state_dict'])
decoder.load_state_dict(checkpoint['decoder_state_dict'])

# Function to generate caption
def generate_caption(image_path, encoder, decoder, word2idx, idx2word, max_length=50):
    encoder.eval()
    decoder.eval()
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        features = encoder(image)
        captions = torch.tensor([word2idx['<START>']]).unsqueeze(0).to(device)
        result_caption = []
        
        for _ in range(max_length):
            outputs = decoder(features, captions)
            outputs = outputs[:, -1, :]
            _, predicted = outputs.max(1)
            predicted_word = idx2word[str(predicted.item())]
            result_caption.append(predicted_word)
            
            if predicted_word == '<END>':
                break
            
            captions = torch.cat((captions, predicted.unsqueeze(0)), dim=1)
    
    return ' '.join(result_caption)

# Example usage
image_path = '/home/vitoupro/code/image_captioning/data/raw/animals/fly/43daa62bc9.jpg'
caption = generate_caption(image_path, encoder, decoder, word2idx, idx2word)
print("Generated Caption:", caption.replace(" ", ""))




Generated Caption: ត្រីមាស<END>
