# In [9]:
# inference4.py
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json

# Model Definitions (ensure these definitions are exactly as they were during training)
# Attention Module
class Attention(nn.Module):
    """Additive attention over the encoder's per-location features.

    Every encoder location is concatenated with the current decoder hidden
    state, passed through a tanh-activated linear layer and projected to a
    single score; the scores are softmax-normalised across locations.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Attribute names double as state_dict keys, so they must not change.
        self.attn = nn.Linear(encoder_dim + decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        """Compute the attention context for one decoding step.

        Args:
            encoder_out: tensor indexed as (batch, locations, encoder_dim).
            hidden: tensor indexed as (batch, decoder_dim).

        Returns:
            A ``(context, alpha)`` pair: ``context`` is the (batch,
            encoder_dim) attention-weighted sum of ``encoder_out`` and
            ``alpha`` the (batch, locations) weights, each row summing to 1.
        """
        num_locations = encoder_out.size(1)
        # Place a copy of the hidden state next to every encoder location.
        tiled_hidden = hidden.unsqueeze(1).repeat(1, num_locations, 1)
        joint = torch.cat((encoder_out, tiled_hidden), dim=2)
        scores = self.v(torch.tanh(self.attn(joint)))
        alpha = torch.softmax(scores.squeeze(2), dim=1)
        context = torch.sum(encoder_out * alpha.unsqueeze(2), dim=1)
        return context, alpha

# EncoderCNN with spatial features
class EncoderCNN(nn.Module):
    """ResNet-50 backbone that turns an image batch into embedded spatial features."""

    def __init__(self, embed_size):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Only parameters whose name contains 'layer4' remain trainable;
        # everything else in the backbone is frozen.
        for param_name, weight in backbone.named_parameters():
            weight.requires_grad = 'layer4' in param_name
        # Keep every child module except the last two (in torchvision's
        # ResNet these are the global pooling and classification layers),
        # so the output is still a spatial feature map.
        trunk = list(backbone.children())[:-2]
        self.resnet = nn.Sequential(*trunk)
        self.avgpool = nn.AdaptiveAvgPool2d((14, 14))
        self.embed = nn.Linear(2048, embed_size)

    def forward(self, images):
        """Encode ``images`` into a (batch, 14 * 14, embed_size) tensor.

        The backbone's feature map is resampled to a 14x14 grid, flattened
        into a sequence of 2048-channel locations and linearly projected to
        ``embed_size``.
        """
        feature_map = self.avgpool(self.resnet(images))
        batch = feature_map.size(0)
        # (batch, 2048, 14, 14) -> (batch, 2048, 196) -> (batch, 196, 2048)
        locations = feature_map.view(batch, 2048, -1).permute(0, 2, 1)
        return self.embed(locations)

# Decoder with Attention
class DecoderWithAttention(nn.Module):
    """LSTM caption decoder that attends over the encoder features at every step.

    Args:
        embed_size: size of the token embeddings; also used as the encoder
            feature size handed to the attention module.
        hidden_size: LSTM hidden state size.
        vocab_size: number of tokens in the vocabulary.
        attention_dim: width of the attention module's hidden layer.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, attention_dim=256, num_layers=1):
        super(DecoderWithAttention, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(embed_size, hidden_size, attention_dim)
        # nn.LSTM applies dropout only between stacked layers; requesting it
        # for a single layer has no effect and just triggers a UserWarning.
        lstm_dropout = 0.3 if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(embed_size + embed_size, hidden_size, num_layers, batch_first=True, dropout=lstm_dropout)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, encoder_out, captions, sampling_probability=1.0):
        """Decode ``captions`` step by step and return the per-step vocabulary scores.

        Args:
            encoder_out: encoder features indexed as (batch, locations, embed_size).
            captions: 2-D tensor of token indices, (batch, seq_len). It must
                contain at least two tokens per row, otherwise there is
                nothing to stack and ``torch.stack`` raises.
            sampling_probability: probability of feeding the model's own
                previous prediction as the next input. The ground-truth token
                is used only when a uniform draw from [0, 1) exceeds this
                value, so the default of 1.0 never teacher-forces.

        Returns:
            Tensor of shape (batch, seq_len - 1, vocab_size) holding the
            unnormalised scores for positions 1..seq_len-1.
        """
        _batch_size, seq_len = captions.size()
        embedded = self.embedding(captions)

        # Initialise the LSTM state from the mean encoder feature. nn.LSTM
        # expects (num_layers, batch, hidden), so the same initial state is
        # replicated for every layer (a no-op when num_layers == 1).
        mean_features = encoder_out.mean(1)
        h = self.init_h(mean_features).unsqueeze(0).repeat(self.num_layers, 1, 1)
        c = self.init_c(mean_features).unsqueeze(0).repeat(self.num_layers, 1, 1)

        step_input = embedded[:, 0, :]
        outputs = []

        for t in range(1, seq_len):
            # Attend using the top layer's hidden state.
            context, _ = self.attention(encoder_out, h[-1])
            lstm_input = torch.cat((step_input, context), dim=1).unsqueeze(1)
            output, (h, c) = self.lstm(lstm_input, (h, c))
            output = self.linear(output.squeeze(1))
            outputs.append(output)

            teacher_force = torch.rand(1).item() > sampling_probability
            if teacher_force:
                step_input = embedded[:, t, :]
            else:
                step_input = self.embedding(output.argmax(1))

        return torch.stack(outputs, dim=1)
    

# Load vocabulary
def load_vocabulary(path):
    """Load the index-to-word vocabulary from a JSON file and build its inverse.

    Args:
        path: path to a JSON object mapping stringified indices to tokens.

    Returns:
        A ``(idx2word, word2idx)`` pair. ``idx2word`` keeps the string keys
        exactly as stored in the file; ``word2idx`` maps each token to its
        integer index. If a token appears under several indices, the last
        one read wins.
    """
    # The vocabulary contains non-ASCII tokens, so decode explicitly as UTF-8
    # instead of relying on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as file:
        idx2word = json.load(file)
    word2idx = {word: int(index) for index, word in idx2word.items()}
    return idx2word, word2idx

idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderWithAttention(embed_size=256, hidden_size=512, vocab_size=len(word2idx), num_layers=1).to(device)

# Load the trained model weights.
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint saved from a CUDA run can still be opened on a CPU-only machine.
encoder.load_state_dict(torch.load('/home/vitoupro/code/image_captioning/notebook/encoderattdec.pth', map_location=device))
decoder.load_state_dict(torch.load('/home/vitoupro/code/image_captioning/notebook/decoderattdec.pth', map_location=device))

# Define the transformation
# NOTE(review): there is no transforms.Normalize step here - confirm this
# matches the preprocessing used during training.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Prediction function
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20):
    """Greedily decode a caption for the image stored at ``image_path``.

    Args:
        image_path: path of the image file to caption.
        encoder: image encoder module; called on a single-image batch.
        decoder: decoder exposing ``init_h``, ``init_c``, ``embedding``,
            ``attention``, ``lstm`` and ``linear``.
        transform: callable turning a PIL image into a tensor.
        device: device on which the input tensors are placed.
        idx2word: mapping from stringified token index to token.
        word2idx: mapping from token to integer index; must contain
            ``'<START>'`` and ``'<END>'``.
        max_length: maximum number of tokens to generate.

    Returns:
        The predicted tokens concatenated without a separator. Decoding stops
        early when ``'<END>'`` is predicted; that token is not included.
    """
    encoder.eval()
    decoder.eval()

    # Close the file handle deterministically; convert() returns a fully
    # loaded copy that stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    image = transform(image).unsqueeze(0).to(device)

    start_idx = word2idx['<START>']
    end_idx = word2idx['<END>']
    predictions = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        encoder_out = encoder(image)

        # nn.LSTM expects a (num_layers, batch, hidden) state, so replicate
        # the initial state per layer (a no-op for a single-layer LSTM).
        mean_features = encoder_out.mean(1)
        num_layers = decoder.lstm.num_layers
        h = decoder.init_h(mean_features).unsqueeze(0).repeat(num_layers, 1, 1)
        c = decoder.init_c(mean_features).unsqueeze(0).repeat(num_layers, 1, 1)

        input_idx = torch.tensor([start_idx], dtype=torch.long, device=device)

        for _ in range(max_length):
            embedded = decoder.embedding(input_idx)
            # Attend using the top layer's hidden state.
            context, _ = decoder.attention(encoder_out, h[-1])
            lstm_input = torch.cat((embedded, context), dim=1).unsqueeze(1)

            output, (h, c) = decoder.lstm(lstm_input, (h, c))
            logits = decoder.linear(output.squeeze(1))
            predicted_index = logits.argmax(-1).item()

            if predicted_index == end_idx:
                break

            predictions.append(idx2word[str(predicted_index)])
            input_idx = torch.tensor([predicted_index], dtype=torch.long, device=device)

    return ''.join(predictions)


# Example usage
image_path = '/home/vitoupro/code/image_captioning/data/2.png'
# Caption a single image with the module-level models and vocabulary.
predicted_caption = predict_caption(
    image_path=image_path,
    encoder=encoder,
    decoder=decoder,
    transform=transform,
    device=device,
    idx2word=idx2word,
    word2idx=word2idx,
)
# Strip any spaces from the caption before displaying it.
caption_without_spaces = predicted_caption.replace(" ", "")
print("Predicted Caption:", caption_without_spaces)

# Output: Predicted Caption: ឆ្កែ
