In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
import re
import json
from sklearn.model_selection import train_test_split
import jiwer
import matplotlib.pyplot as plt

# Function to load idx2word and convert it to word2idx
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

# Load vocabulary
idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Encoding and decoding functions
def encode_khmer_word(word, word2idx):
    indices = []
    for character in word:
        index = word2idx.get(character)
        if index is None:
            return None, f"Character '{character}' not found in vocabulary!"
        indices.append(index)
    return indices, None

def decode_indices(indices, idx2word):
    characters = []
    for index in indices:
        character = idx2word.get(str(index))
        if character is None:
            return None, f"Index '{index}' not found in idx2word!"
        characters.append(character)
    return ''.join(characters), None

# Model Definitions (EncoderCNN and DecoderRNN)
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        # for param in resnet.parameters():
        #     param.requires_grad = False
        for name, param in resnet.named_parameters():
            if 'layer4' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        # self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True, dropout=0.3)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(hidden_size, hidden_size)  # Initialize LSTM hidden state
        self.init_c = nn.Linear(hidden_size, hidden_size)  # Initialize LSTM cell state

    def forward(self, features, captions, sampling_probability=1.0):
        batch_size, seq_len = captions.size()
        embeddings = self.embed(captions)
    
        h = self.init_h(features).unsqueeze(0).repeat(self.num_layers, 1, 1)
        c = self.init_c(features).unsqueeze(0).repeat(self.num_layers, 1, 1)
    
        inputs = embeddings[:, 0].unsqueeze(1)  # Embed <START>
        outputs = []

        for t in range(1, seq_len):
            lstm_out, (h, c) = self.lstm(inputs, (h, c))
            output = self.linear(lstm_out.squeeze(1))
            outputs.append(output)

        # Decide whether to use teacher forcing or model prediction
            teacher_force = torch.rand(1).item() > sampling_probability
            top1 = output.argmax(1)

            next_input = captions[:, t] if teacher_force else top1
            inputs = self.embed(next_input).unsqueeze(1)

        return torch.stack(outputs, dim=1)


# Image Captioning Dataset
class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, img_labels, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        caption = self.img_labels.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        indices, error = encode_khmer_word(caption, self.vocab)
        if error:
            print(f"Error encoding caption: {error}")
            indices = [self.vocab['<UNK>']] * self.max_length
        tokens = [self.vocab['<START>']] + indices + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

# # Define transformations
# transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
# ]) 

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])


# Load dataset
annotations_file = '/home/vitoupro/code/image_captioning/data/raw/annotation.txt'
img_dir = '/home/vitoupro/code/image_captioning/data/raw/animals'
all_images = pd.read_csv(annotations_file, delimiter=' ', names=['image', 'caption'])

# Split dataset
train_images, eval_images, train_captions, eval_captions = train_test_split(
    all_images['image'].tolist(), all_images['caption'].tolist(), test_size=0.2, random_state=42
)

train_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': train_images, 'caption': train_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

eval_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': eval_images, 'caption': eval_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
encoder = EncoderCNN(embed_size=512).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(word2idx), num_layers=1).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.embed.parameters())
# optimizer = torch.optim.Adam(params, lr=0.001)
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-5)


def custom_transform(text):
    # Lowercase the text
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Remove multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    # Return as list of words
    return text.split()
    

def calculate_wer(gt, pred, epoch, file_path='metric.txt'):
    
    with open(file_path, 'a') as file:  # Open file in append mode
        file.write(f"Epoch {epoch}\n")
        
        file.write("===========================\n")
        match_pred = re.search(r"^(.*?)<END>", pred)
        if match_pred:
            content_pred = match_pred.group(1)
        file.write(f"pred: {content_pred}\n")
        match_ground_true = re.search(r"<START>(.*?)<END>", gt)
        if match_ground_true:
            content_ground_true = match_ground_true.group(1)
        file.write(f"true: {content_ground_true}\n")
        file.write("===========================\n")
    
    if not content_ground_true:  # Ensure non-empty
        content_ground_true = ['']
    if not content_pred:  # Ensure non-empty
        content_pred = ['']
    wer_score = jiwer.wer(content_ground_true, content_pred)
    
    return wer_score

def calculate_cer(gt, pred):
    match_pred = re.search(r"^(.*?)<END>", pred)
    if match_pred:
        content_pred = match_pred.group(1)
    
    match_ground_true = re.search(r"<START>(.*?)<END>", gt)
    if match_ground_true:
        content_ground_true = match_ground_true.group(1)

    return jiwer.cer(content_ground_true, content_pred)

def evaluate_model(encoder, decoder, dataloader, device, epoch):
    encoder.eval()
    decoder.eval()
    total_wer, total_cer, num_samples = 0, 0, 0
    with torch.no_grad():
        for images, captions in dataloader:
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions[:, :-1])
            predicted_captions = outputs.argmax(-1)
            
            for i in range(len(captions)):
                gt_caption = decode_indices(captions[i].tolist(), idx2word)[0]
                pred_caption = decode_indices(predicted_captions[i].tolist(), idx2word)[0]
                
                wer = calculate_wer(gt_caption, pred_caption, epoch)
                cer = calculate_cer(gt_caption, pred_caption)
                total_wer += wer
                total_cer += cer
                num_samples += 1

    avg_wer = total_wer / num_samples if num_samples > 0 else 0
    avg_cer = total_cer / num_samples if num_samples > 0 else 0
    print(f"Average WER: {avg_wer:.2f}, Average CER: {avg_cer:.2f}")
    return avg_wer, avg_cer




In [2]:
# Training Loop
num_epochs = 15
best_wer = float('inf')

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0
    sampling_prob = max(0.1, 1.0 - epoch * 0.05)

    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)
        features = encoder(images)

        outputs = decoder(features, captions, sampling_probability=sampling_prob)

        loss = criterion(outputs.view(-1, len(word2idx)), captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}: Train Loss: {total_loss/len(train_loader):.4f}, Sampling Prob: {sampling_prob:.2f}')
    _, wer = evaluate_model(encoder, decoder, eval_loader, device, epoch)

    if wer < best_wer:
        best_wer = wer



Epoch 1: Train Loss: 2.8530, Sampling Prob: 1.00
Average WER: 1.00, Average CER: 1.00
Epoch 2: Train Loss: 2.4235, Sampling Prob: 0.95
Average WER: 1.00, Average CER: 0.87
Epoch 3: Train Loss: 2.2942, Sampling Prob: 0.90
Average WER: 1.00, Average CER: 1.06
Epoch 4: Train Loss: 2.0832, Sampling Prob: 0.85
Average WER: 1.00, Average CER: 1.00
Epoch 5: Train Loss: 1.8994, Sampling Prob: 0.80
Average WER: 0.86, Average CER: 0.88
Epoch 6: Train Loss: 1.2342, Sampling Prob: 0.75
Average WER: 0.56, Average CER: 0.77
Epoch 7: Train Loss: 0.8641, Sampling Prob: 0.70
Average WER: 0.40, Average CER: 0.53
Epoch 8: Train Loss: 0.6688, Sampling Prob: 0.65
Average WER: 0.37, Average CER: 0.50
Epoch 9: Train Loss: 0.5361, Sampling Prob: 0.60
Average WER: 0.30, Average CER: 0.34
Epoch 10: Train Loss: 0.5329, Sampling Prob: 0.55
Average WER: 0.32, Average CER: 0.40
Epoch 11: Train Loss: 0.3871, Sampling Prob: 0.50
Average WER: 0.30, Average CER: 0.42
Epoch 12: Train Loss: 0.3462, Sampling Prob: 0.45
Av

In [14]:
if wer < best_wer:
    best_wer = wer
    torch.save(encoder.state_dict(), "encoder_best.pth")
    torch.save(decoder.state_dict(), "decoder_best.pth")
    print("✅ Saved best model!")


In [36]:
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20):
    encoder.eval()
    decoder.eval()

    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension

    # Encode the image
    features = encoder(image)

    # Start generation with <START> token
    input_idx = torch.tensor([word2idx['<START>']], dtype=torch.long).to(device)
    predictions = []
    h, c = None, None

    for _ in range(max_length):
        input_idx = input_idx.unsqueeze(0)  # (1, 1)

        # Initialize hidden state on first step
        if h is None and c is None:
            h = decoder.init_h(features).unsqueeze(0).repeat(decoder.num_layers, 1, 1)
            c = decoder.init_c(features).unsqueeze(0).repeat(decoder.num_layers, 1, 1)

        embedded = decoder.embed(input_idx)
        output, (h, c) = decoder.lstm(embedded, (h, c))
        output = decoder.linear(output.squeeze(1))

        predicted_index = output.argmax(-1).item()
        if predicted_index == word2idx['<END>']:
            break

        predictions.append(idx2word[str(predicted_index)])
        input_idx = torch.tensor([predicted_index], dtype=torch.long).to(device)

    predicted_caption = ''.join(predictions)  # Khmer: no need for space
    return predicted_caption

image_path = '/home/vitoupro/code/image_captioning/data/processed/image.png'
caption = predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx)
print("Predicted Caption:", caption)



Predicted Caption: ខ្លាឃ្មុំ
