In [8]:
# inference4.py
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json

# Model Definitions (ensure these definitions are exactly as they were during training)
# Model Definitions (EncoderCNN and DecoderRNN)
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        # for param in resnet.parameters():
        #     param.requires_grad = False
        for name, param in resnet.named_parameters():
            if 'layer4' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        # self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=0.3)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(hidden_size, hidden_size)  # Initialize LSTM hidden state
        self.init_c = nn.Linear(hidden_size, hidden_size)  # Initialize LSTM cell state

    def forward(self, features, captions, sampling_probability=1.0):
        batch_size, seq_len = captions.size()
        embeddings = self.embed(captions)
    
        h = self.init_h(features).unsqueeze(0).repeat(self.num_layers, 1, 1)
        c = self.init_c(features).unsqueeze(0).repeat(self.num_layers, 1, 1)
    
        inputs = embeddings[:, 0].unsqueeze(1)  # Embed <START>
        outputs = []

        for t in range(1, seq_len):
            lstm_out, (h, c) = self.lstm(inputs, (h, c))
            output = self.linear(lstm_out.squeeze(1))
            outputs.append(output)

        # Decide whether to use teacher forcing or model prediction
            teacher_force = torch.rand(1).item() > sampling_probability
            top1 = output.argmax(1)

            next_input = captions[:, t] if teacher_force else top1
            inputs = self.embed(next_input).unsqueeze(1)

        return torch.stack(outputs, dim=1)
    

# Load vocabulary
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
encoder = EncoderCNN(embed_size=512).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(word2idx), num_layers=1).to(device)

# Load the trained model weights
encoder.load_state_dict(torch.load('/home/vitoupro/code/image_captioning/notebook/encoderssp.pth'))
decoder.load_state_dict(torch.load('/home/vitoupro/code/image_captioning/notebook/decoderssp.pth'))

# Define the transformation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Prediction function
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx):
    encoder.eval()
    decoder.eval()
    
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension and transfer to device
    
    # Generate features from the image using the encoder
    features = encoder(image)
    
    # Start the sequence with the <START> token
    predicted_indices = [word2idx['<START>']]
    predictions = []
    
    # Initial input to the LSTM is the <START> token
    input_idx = torch.tensor([predicted_indices[-1]], dtype=torch.long).to(device)
    
    # Initialize the LSTM state
    h, c = None, None
    
    # Generate words until the <END> token is predicted or the max length is reached
    for _ in range(20):  # Assuming max length of 20 for safety
        input_idx = input_idx.unsqueeze(0)  # Add batch dimension for single time-step prediction
        if h is None and c is None:
            # Generate initial hidden states from features
            h = decoder.init_h(features).unsqueeze(0).repeat(decoder.num_layers, 1, 1)
            c = decoder.init_c(features).unsqueeze(0).repeat(decoder.num_layers, 1, 1)
        
        outputs, (h, c) = decoder.lstm(decoder.embed(input_idx), (h, c))
        outputs = decoder.linear(outputs.squeeze(1))
        
        # Get the predicted word index
        predicted_index = outputs.argmax(-1).item()
        predicted_indices.append(predicted_index)
        predictions.append(idx2word[str(predicted_index)])  # Decode to word
        
        # Prepare the next input
        input_idx = torch.tensor([predicted_index], dtype=torch.long).to(device)
        
        # Stop if the <END> token is predicted
        if predicted_index == word2idx['<END>']:
            break
    
    predicted_caption = ' '.join(predictions)  # Join the predicted words
    
    return predicted_caption

# Example usage
image_path = '/home/vitoupro/code/image_captioning/data/2.png'
predicted_caption = predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx)
print("Predicted Caption:", predicted_caption.replace(" ", ""))

Predicted Caption: ដំរី<END>
