In [10]:
# inference4.py
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import json

# Model Definitions (ensure these definitions are exactly as they were during training)
# Attention Module
class Attention(nn.Module):
    """Additive (concatenation-based) attention over a set of encoder positions.

    ``forward(encoder_out, hidden)``:
        encoder_out: (B, N, encoder_dim) encoder features.
        hidden:      (B, decoder_dim) current decoder state.
    Returns ``(context, alpha)`` where ``alpha`` is (B, N) and sums to 1 over N,
    and ``context`` is the alpha-weighted sum of ``encoder_out``: (B, encoder_dim).

    NOTE(review): this module is not referenced elsewhere in this cell.
    """

    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        # Scores a [encoder feature ; decoder state] pair in attention space ...
        self.attn = nn.Linear(encoder_dim + decoder_dim, attention_dim)
        # ... and collapses it to one scalar energy per position.
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        num_positions = encoder_out.size(1)
        # Tile the decoder state so it can be paired with every encoder position.
        tiled_hidden = hidden.unsqueeze(1).expand(-1, num_positions, -1)
        paired = torch.cat([encoder_out, tiled_hidden], dim=2)

        scores = self.v(torch.tanh(self.attn(paired))).squeeze(2)  # (B, N)
        weights = scores.softmax(dim=1)

        context = torch.sum(encoder_out * weights.unsqueeze(-1), dim=1)
        return context, weights

class EncoderCNN(nn.Module):
    """Encode an image batch into a single embedding "token" per image.

    A ResNet-50 backbone with its final classification layer removed yields a
    globally pooled 2048-d feature vector per image. A linear layer projects
    that vector to ``embed_size``. Only the backbone's ``layer4`` parameters
    are left trainable.

    Args:
        embed_size: Output embedding dimension.
        pretrained: If True (the default, matching the original behaviour),
            initialise the backbone from torchvision's ImageNet weights. Pass
            False when a full checkpoint will be loaded with
            ``load_state_dict`` anyway, to skip the weight download.
    """

    def __init__(self, embed_size, pretrained=True):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=pretrained)
        # Freeze every backbone parameter except the last residual stage.
        for name, param in resnet.named_parameters():
            param.requires_grad = 'layer4' in name
        # Drop only the final fc layer; keep everything through global avg-pooling.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])  # output shape: (B, 2048, 1, 1)
        self.linear = nn.Linear(2048, embed_size)

    def forward(self, images):
        """Map an image batch (B, 3, H, W) to embeddings of shape (B, 1, embed_size)."""
        features = self.resnet(images)          # (B, 2048, 1, 1)
        features = torch.flatten(features, 1)   # (B, 2048)
        features = self.linear(features)        # (B, embed_size)
        return features.unsqueeze(1)            # (B, 1, embed_size)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature columns hold ``sin(pos * freq)`` and odd columns hold
    ``cos(pos * freq)``, as in "Attention Is All You Need".

    Args:
        d_model: Feature dimension of the inputs (odd values are supported).
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Registered as a buffer so `.to(device)` moves it with the module.
        # It is non-persistent so it stays out of state_dict, which keeps
        # checkpoints saved when `pe` was a plain attribute loadable.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` (B, T, d_model) plus the encodings for positions 0..T-1."""
        # `.to` is a no-op once the module lives on x's device; kept as a safety net.
        return x + self.pe[:, :x.size(1), :].to(x.device)


class TransformerDecoder(nn.Module):
    """Caption decoder built from a stack of ``nn.TransformerDecoderLayer``.

    The public API is batch-first. Internally, tensors are permuted to the
    sequence-first layout used by the transformer layers, which are
    constructed without ``batch_first``.
    """

    def __init__(self, vocab_size, embed_size, num_heads=8, num_layers=3, ff_dim=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        layer_template = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=num_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
        )
        # nn.TransformerDecoder clones the template `num_layers` times.
        self.transformer_decoder = nn.TransformerDecoder(layer_template, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def generate_square_subsequent_mask(self, sz):
        """Return an (sz, sz) float causal mask: 0 on/below the diagonal, -inf above."""
        return torch.full((sz, sz), float('-inf')).triu(diagonal=1)

    def forward(self, features, captions):
        """Compute next-token logits.

        Args:
            features: encoder output used as memory, (B, S, E); (B, 1, E) for EncoderCNN.
            captions: token ids, (B, T).

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        embedded = self.pos_encoding(self.embedding(captions))  # (B, T, E)

        # Switch both streams to sequence-first layout.
        tgt_seq_first = embedded.permute(1, 0, 2)     # (T, B, E)
        memory_seq_first = features.permute(1, 0, 2)  # (S, B, E)

        steps = tgt_seq_first.size(0)
        causal_mask = self.generate_square_subsequent_mask(steps).to(tgt_seq_first.device)

        decoded = self.transformer_decoder(tgt_seq_first, memory_seq_first, tgt_mask=causal_mask)
        logits_seq_first = self.fc_out(decoded)       # (T, B, vocab_size)
        return logits_seq_first.permute(1, 0, 2)      # (B, T, vocab_size)
    

# Load vocabulary
def load_vocabulary(path):
    """Load an index->word JSON vocabulary and build its inverse mapping.

    Args:
        path: Path to a JSON object whose keys are stringified token ids and
            whose values are the corresponding words.

    Returns:
        ``(idx2word, word2idx)``. ``idx2word`` is the loaded dict with its
        original *string* keys. ``word2idx`` maps each word to its integer id;
        if a word appears more than once, the id listed last in the file wins.
    """
    # Tokens are non-ASCII (this notebook produces Khmer captions), so the
    # platform's default locale encoding must not be relied upon.
    with open(path, 'r', encoding='utf-8') as file:
        idx2word = json.load(file)
    word2idx = {word: int(idx) for idx, word in idx2word.items()}
    return idx2word, word2idx

idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models; hyper-parameters must match those used to save the checkpoints.
# NOTE(review): vocab_size comes from len(word2idx), which is smaller than
# len(idx2word) if the vocabulary contains duplicate words. It must equal the
# size used at training time; TODO confirm.
encoder = EncoderCNN(embed_size=256).to(device)
decoder = TransformerDecoder(
    vocab_size=len(word2idx),
    embed_size=256,
    num_heads=8,
    num_layers=3,
).to(device)

# Load the trained model weights. map_location remaps tensors saved on a GPU
# so the checkpoints also load on a CPU-only machine.
encoder.load_state_dict(torch.load('/home/vitoupro/code/image_captioning/notebook/encodertransf.pth', map_location=device))
decoder.load_state_dict(torch.load('/home/vitoupro/code/image_captioning/notebook/decodertransf.pth', map_location=device))

# Define the transformation
# NOTE(review): there is no transforms.Normalize step. ResNet-50 ImageNet
# weights are normally paired with ImageNet mean/std normalisation. Only add
# it if the training pipeline used it too; TODO confirm against the training transform.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Prediction function
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20):
    """Greedily decode a caption for a single image.

    Args:
        image_path: Path of the image file to caption.
        encoder: Image encoder, called on a (1, C, H, W) batch. Its output is
            passed to the decoder as memory (EncoderCNN yields (1, 1, E)).
        decoder: Model callable as ``decoder(features, captions)`` that returns
            (B, T, vocab_size) logits, as TransformerDecoder does.
        transform: Callable turning a PIL RGB image into a (C, H, W) tensor.
        device: Device on which inference runs.
        idx2word: Mapping from *stringified* token id to token.
        word2idx: Mapping from token to id; must contain '<START>' and '<END>'.
        max_length: Maximum number of tokens to generate.

    Returns:
        The generated tokens, excluding '<START>'/'<END>', concatenated with
        no separator.

    Side effects:
        Leaves ``encoder`` and ``decoder`` in eval mode.
    """
    encoder.eval()
    decoder.eval()

    # The context manager closes the file handle that a lazy Image.open keeps open.
    with Image.open(image_path) as img:
        image = transform(img.convert('RGB')).unsqueeze(0).to(device)

    end_idx = word2idx['<END>']
    input_indices = [word2idx['<START>']]

    # Pure inference: skip autograd bookkeeping (the encoder leaves layer4 trainable).
    with torch.no_grad():
        features = encoder(image)  # (1, 1, embed_size)
        for _ in range(max_length):
            tgt = torch.tensor([input_indices], dtype=torch.long, device=device)  # (1, T)
            # The decoder's own forward handles embedding, positions, masking and
            # layout, so memory is permuted exactly once per step.
            logits = decoder(features, tgt)  # (1, T, vocab_size)
            predicted_index = logits[0, -1].argmax(-1).item()
            if predicted_index == end_idx:
                break
            input_indices.append(predicted_index)

    predicted_tokens = [idx2word[str(idx)] for idx in input_indices[1:]]  # skip <START>
    return ''.join(predicted_tokens)

# Example usage: caption a single image and print the result with spaces removed.
image_path = '/home/vitoupro/code/image_captioning/data/2.png'
predicted_caption = predict_caption(
    image_path,
    encoder,
    decoder,
    transform,
    device,
    idx2word,
    word2idx,
)
print(
    "Predicted Caption:",
    predicted_caption.replace(" ", ""),
)

Predicted Caption: សត្វតោ
