In [1]:
# Image Captioning with Attention-Enhanced LSTM

import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
import re
import json
from sklearn.model_selection import train_test_split
import jiwer
import matplotlib.pyplot as plt

# Function to load idx2word and convert it to word2idx
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

# Load vocabulary
idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Encoding and decoding functions
def encode_khmer_word(word, word2idx):
    indices = []
    for character in word:
        index = word2idx.get(character)
        if index is None:
            return None, f"Character '{character}' not found in vocabulary!"
        indices.append(index)
    return indices, None

def decode_indices(indices, idx2word):
    characters = []
    for index in indices:
        character = idx2word.get(str(index))
        if character is None:
            return None, f"Index '{index}' not found in idx2word!"
        characters.append(character)
    return ''.join(characters), None

# Attention Module
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(encoder_dim + decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        hidden = hidden.unsqueeze(1).repeat(1, encoder_out.size(1), 1)
        attn_input = torch.cat((encoder_out, hidden), dim=2)
        energy = torch.tanh(self.attn(attn_input))
        attention = self.v(energy).squeeze(2)
        alpha = torch.softmax(attention, dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha

# EncoderCNN with spatial features
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for name, param in resnet.named_parameters():
            if 'layer4' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        self.resnet = nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d((14, 14))
        self.embed = nn.Linear(2048, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = self.avgpool(features)
        features = features.view(features.size(0), 2048, -1).permute(0, 2, 1)
        features = self.embed(features)
        return features

# Decoder with Attention
class DecoderWithAttention(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, attention_dim=256, num_layers=1):
        super(DecoderWithAttention, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(embed_size, hidden_size, attention_dim)
        self.lstm = nn.LSTM(embed_size + embed_size, hidden_size, num_layers, batch_first=True, dropout=0.3)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(embed_size, hidden_size)
        self.init_c = nn.Linear(embed_size, hidden_size)

    def forward(self, encoder_out, captions, sampling_probability=1.0):
        batch_size, seq_len = captions.size()
        embedded = self.embedding(captions)
        h = self.init_h(encoder_out.mean(1)).unsqueeze(0)
        c = self.init_c(encoder_out.mean(1)).unsqueeze(0)

        inputs = embedded[:, 0, :].unsqueeze(1)
        outputs = []

        for t in range(1, seq_len):
            context, _ = self.attention(encoder_out, h[-1])
            lstm_input = torch.cat((inputs.squeeze(1), context), dim=1).unsqueeze(1)
            output, (h, c) = self.lstm(lstm_input, (h, c))
            output = self.linear(output.squeeze(1))
            outputs.append(output)

            teacher_force = torch.rand(1).item() > sampling_probability
            top1 = output.argmax(1)
            inputs = embedded[:, t, :].unsqueeze(1) if teacher_force else self.embedding(top1).unsqueeze(1)

        return torch.stack(outputs, dim=1)

# The rest of the code remains the same...
# You can now instantiate DecoderRNN with attention in your training setup.
# Image Captioning Dataset
class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, img_labels, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        caption = self.img_labels.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        indices, error = encode_khmer_word(caption, self.vocab)
        if error:
            print(f"Error encoding caption: {error}")
            indices = [self.vocab['<UNK>']] * self.max_length
        tokens = [self.vocab['<START>']] + indices + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

# # Define transformations
# transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
# ]) 

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])


# Load dataset
annotations_file = '/home/vitoupro/code/image_captioning/data/raw/annotation.txt'
img_dir = '/home/vitoupro/code/image_captioning/data/raw/animals'
all_images = pd.read_csv(annotations_file, delimiter=' ', names=['image', 'caption'])

# Split dataset
train_images, eval_images, train_captions, eval_captions = train_test_split(
    all_images['image'].tolist(), all_images['caption'].tolist(), test_size=0.2, random_state=42
)

train_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': train_images, 'caption': train_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

eval_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': eval_images, 'caption': eval_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Unified embed size
embed_size = 256

encoder = EncoderCNN(embed_size=embed_size).to(device)
decoder = DecoderWithAttention(
    embed_size=embed_size,
    hidden_size=512,
    vocab_size=len(word2idx),
    num_layers=1
).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.embed.parameters())
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-5)



def custom_transform(text):
    # Lowercase the text
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Remove multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    # Return as list of words
    return text.split()
    

def calculate_wer(gt, pred, epoch, file_path='metric.txt'):
    # Default empty strings to avoid UnboundLocalError
    content_pred = ""
    content_ground_true = ""

    # Extract prediction (remove after <END>)
    match_pred = re.search(r"^(.*?)<END>", pred)
    if match_pred:
        content_pred = match_pred.group(1)

    # Extract ground truth (between <START> and <END>)
    match_ground_true = re.search(r"<START>(.*?)<END>", gt)
    if match_ground_true:
        content_ground_true = match_ground_true.group(1)

    # Write to log file
    with open(file_path, 'a') as file:
        file.write(f"Epoch {epoch}\n")
        file.write("===========================\n")
        file.write(f"pred: {content_pred}\n")
        file.write(f"true: {content_ground_true}\n")
        file.write("===========================\n")

    # Ensure non-empty values for WER
    content_pred = content_pred.strip() or ""
    content_ground_true = content_ground_true.strip() or ""

    wer_score = jiwer.wer(content_ground_true, content_pred)
    return wer_score



def calculate_cer(gt, pred):
    content_pred = ""
    content_ground_true = ""

    match_pred = re.search(r"^(.*?)<END>", pred)
    if match_pred:
        content_pred = match_pred.group(1)

    match_ground_true = re.search(r"<START>(.*?)<END>", gt)
    if match_ground_true:
        content_ground_true = match_ground_true.group(1)

    content_pred = content_pred.strip() or ""
    content_ground_true = content_ground_true.strip() or ""

    return jiwer.cer(content_ground_true, content_pred)


def evaluate_model(encoder, decoder, dataloader, device, epoch):
    encoder.eval()
    decoder.eval()
    total_wer, total_cer, num_samples = 0, 0, 0
    with torch.no_grad():
        for images, captions in dataloader:
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions[:, :-1])
            predicted_captions = outputs.argmax(-1)
            
            for i in range(len(captions)):
                gt_caption = decode_indices(captions[i].tolist(), idx2word)[0]
                pred_caption = decode_indices(predicted_captions[i].tolist(), idx2word)[0]
                
                wer = calculate_wer(gt_caption, pred_caption, epoch)
                cer = calculate_cer(gt_caption, pred_caption)
                total_wer += wer
                total_cer += cer
                num_samples += 1

    avg_wer = total_wer / num_samples if num_samples > 0 else 0
    avg_cer = total_cer / num_samples if num_samples > 0 else 0
    print(f"Average WER: {avg_wer:.2f}, Average CER: {avg_cer:.2f}")
    return avg_wer, avg_cer



In [2]:
# Training Loop with attention
num_epochs = 15
best_wer = float('inf')

train_losses, wer_scores, cer_scores = [], [], []

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0
    sampling_prob = max(0.1, 1.0 - epoch * 0.05)

    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)

        features = encoder(images)
        outputs = decoder(features, captions, sampling_probability=sampling_prob)

        loss = criterion(outputs.view(-1, len(word2idx)), captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    print(f'Epoch {epoch+1}: Train Loss: {avg_loss:.4f}, Sampling Prob: {sampling_prob:.2f}')

    avg_wer, avg_cer = evaluate_model(encoder, decoder, eval_loader, device, epoch)
    wer_scores.append(avg_wer)
    cer_scores.append(avg_cer)

Epoch 1: Train Loss: 2.9380, Sampling Prob: 1.00
Average WER: 1.00, Average CER: 0.87
Epoch 2: Train Loss: 2.4834, Sampling Prob: 0.95
Average WER: 1.00, Average CER: 0.77
Epoch 3: Train Loss: 1.9386, Sampling Prob: 0.90
Average WER: 0.86, Average CER: 0.51
Epoch 4: Train Loss: 1.1467, Sampling Prob: 0.85
Average WER: 0.45, Average CER: 0.38
Epoch 5: Train Loss: 0.7784, Sampling Prob: 0.80
Average WER: 0.29, Average CER: 0.31
Epoch 6: Train Loss: 0.5544, Sampling Prob: 0.75
Average WER: 0.22, Average CER: 0.26
Epoch 7: Train Loss: 0.4697, Sampling Prob: 0.70
Average WER: 0.26, Average CER: 0.32
Epoch 8: Train Loss: 0.3930, Sampling Prob: 0.65
Average WER: 0.27, Average CER: 0.28
Epoch 9: Train Loss: 0.3468, Sampling Prob: 0.60
Average WER: 0.24, Average CER: 0.28
Epoch 10: Train Loss: 0.3358, Sampling Prob: 0.55
Average WER: 0.27, Average CER: 0.30
Epoch 11: Train Loss: 0.3185, Sampling Prob: 0.50
Average WER: 0.21, Average CER: 0.25
Epoch 12: Train Loss: 0.2318, Sampling Prob: 0.45
Av

In [4]:
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20):
    encoder.eval()
    decoder.eval()

    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)  # (1, 3, 224, 224)

    # Extract spatial features from encoder: (1, 196, embed_size)
    encoder_out = encoder(image)

    # Initialize LSTM hidden and cell state from mean-pooled encoder output
    h = decoder.init_h(encoder_out.mean(1)).unsqueeze(0)
    c = decoder.init_c(encoder_out.mean(1)).unsqueeze(0)

    input_idx = torch.tensor([word2idx['<START>']], dtype=torch.long).to(device)
    predictions = []

    for _ in range(max_length):
        embedded = decoder.embedding(input_idx).unsqueeze(1)  # (1, 1, embed_size)

        # Apply attention to get context vector
        context, _ = decoder.attention(encoder_out, h[-1])  # context: (1, embed_size)

        # Concatenate embedded word and context
        lstm_input = torch.cat((embedded.squeeze(1), context), dim=1).unsqueeze(1)  # (1, 1, 2*embed_size)

        # LSTM forward
        output, (h, c) = decoder.lstm(lstm_input, (h, c))
        output = decoder.linear(output.squeeze(1))  # (1, vocab_size)

        # Predict next token
        predicted_index = output.argmax(-1).item()
        if predicted_index == word2idx['<END>']:
            break

        predictions.append(idx2word[str(predicted_index)])
        input_idx = torch.tensor([predicted_index], dtype=torch.long).to(device)

    predicted_caption = ''.join(predictions)  # Khmer: character-based
    return predicted_caption

image_path = '/home/vitoupro/code/image_captioning/data/00000001_020.jpg'
caption = predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx)
print("Predicted Caption:", caption)


Predicted Caption: ហាមស្ទ័រ


In [10]:
if avg_wer < best_wer:
    best_wer = avg_wer
    torch.save(encoder.state_dict(), "encoderattdec.pth")
    torch.save(decoder.state_dict(), "decoderattdec.pth")
    print("✅ Saved best model!")

In [11]:
# Save the model state dictionary
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'captioning_model_attdec.pth')
