In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
import re
import json
from sklearn.model_selection import train_test_split
import jiwer
import matplotlib.pyplot as plt

# Function to load idx2word and convert it to word2idx
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

# Load vocabulary
idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Encoding and decoding functions
def encode_khmer_word(word, word2idx):
    indices = []
    for character in word:
        index = word2idx.get(character)
        if index is None:
            return None, f"Character '{character}' not found in vocabulary!"
        indices.append(index)
    return indices, None

def decode_indices(indices, idx2word):
    characters = []
    for index in indices:
        character = idx2word.get(str(index))
        if character is None:
            return None, f"Index '{index}' not found in idx2word!"
        characters.append(character)
    return ''.join(characters), None

# Attention Module
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(encoder_dim + decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        hidden = hidden.unsqueeze(1).repeat(1, encoder_out.size(1), 1)
        attn_input = torch.cat((encoder_out, hidden), dim=2)
        energy = torch.tanh(self.attn(attn_input))
        attention = self.v(energy).squeeze(2)
        alpha = torch.softmax(attention, dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha

class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for name, param in resnet.named_parameters():
            if 'layer4' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])  # output shape: (B, 2048, 1, 1)
        self.linear = nn.Linear(2048, embed_size)

    def forward(self, images):
        features = self.resnet(images)             # (B, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # (B, 2048) — safe reshape
        features = self.linear(features)                # (B, embed_size)
        return features.unsqueeze(1)                    # (B, 1, embed_size)
           # (B, 1, embed_size)

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)


class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads=8, num_layers=3, ff_dim=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def forward(self, features, captions):
        tgt = self.embedding(captions)  # (B, T, E)
        tgt = self.pos_encoding(tgt)

        tgt = tgt.permute(1, 0, 2)      # (T, B, E)
        memory = features.permute(1, 0, 2)  # (1, B, E)

        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)

        out = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        out = self.fc_out(out)  # (T, B, vocab_size)
        return out.permute(1, 0, 2)  # (B, T, vocab_size)


# The rest of the code remains the same...
# You can now instantiate DecoderRNN with attention in your training setup.
# Image Captioning Dataset
class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, img_labels, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        caption = self.img_labels.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        indices, error = encode_khmer_word(caption, self.vocab)
        if error:
            print(f"Error encoding caption: {error}")
            indices = [self.vocab['<UNK>']] * self.max_length
        tokens = [self.vocab['<START>']] + indices + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

# # Define transformations
# transform = transforms.Compose([
#     transforms.Resize(256),
#     transforms.CenterCrop(224),
#     transforms.ToTensor(),
# ]) 

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])


# Load dataset
annotations_file = '/home/vitoupro/code/image_captioning/data/raw/annotation.txt'
img_dir = '/home/vitoupro/code/image_captioning/data/raw/animals'
all_images = pd.read_csv(annotations_file, delimiter=' ', names=['image', 'caption'])

# Split dataset
train_images, eval_images, train_captions, eval_captions = train_test_split(
    all_images['image'].tolist(), all_images['caption'].tolist(), test_size=0.2, random_state=42
)

train_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': train_images, 'caption': train_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

eval_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': eval_images, 'caption': eval_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Unified embed size
embed_size = 256

encoder = EncoderCNN(embed_size=embed_size).to(device)
decoder = TransformerDecoder(
    vocab_size=len(word2idx),
    embed_size=embed_size,
    num_heads=8,
    num_layers=3
).to(device)


criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.parameters())  # include all trainable encoder parts
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-5)



def custom_transform(text):
    # Lowercase the text
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Remove multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    # Return as list of words
    return text.split()
    

def calculate_wer(gt, pred, epoch, file_path='metric.txt'):
    # Default empty strings to avoid UnboundLocalError
    content_pred = ""
    content_ground_true = ""

    # Extract prediction (remove after <END>)
    match_pred = re.search(r"^(.*?)<END>", pred)
    if match_pred:
        content_pred = match_pred.group(1)

    # Extract ground truth (between <START> and <END>)
    match_ground_true = re.search(r"<START>(.*?)<END>", gt)
    if match_ground_true:
        content_ground_true = match_ground_true.group(1)

    # Write to log file
    with open(file_path, 'a') as file:
        file.write(f"Epoch {epoch}\n")
        file.write("===========================\n")
        file.write(f"pred: {content_pred}\n")
        file.write(f"true: {content_ground_true}\n")
        file.write("===========================\n")

    # Ensure non-empty values for WER
    content_pred = content_pred.strip() or ""
    content_ground_true = content_ground_true.strip() or ""

    wer_score = jiwer.wer(content_ground_true, content_pred)
    return wer_score



def calculate_cer(gt, pred):
    content_pred = ""
    content_ground_true = ""

    match_pred = re.search(r"^(.*?)<END>", pred)
    if match_pred:
        content_pred = match_pred.group(1)

    match_ground_true = re.search(r"<START>(.*?)<END>", gt)
    if match_ground_true:
        content_ground_true = match_ground_true.group(1)

    content_pred = content_pred.strip() or ""
    content_ground_true = content_ground_true.strip() or ""

    return jiwer.cer(content_ground_true, content_pred)


def evaluate_model(encoder, decoder, dataloader, device, epoch):
    encoder.eval()
    decoder.eval()
    total_wer, total_cer, num_samples = 0, 0, 0
    with torch.no_grad():
        for images, captions in dataloader:
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions[:, :-1])
            predicted_captions = outputs.argmax(-1)
            
            for i in range(len(captions)):
                gt_caption = decode_indices(captions[i].tolist(), idx2word)[0]
                pred_caption = decode_indices(predicted_captions[i].tolist(), idx2word)[0]
                
                wer = calculate_wer(gt_caption, pred_caption, epoch)
                cer = calculate_cer(gt_caption, pred_caption)
                total_wer += wer
                total_cer += cer
                num_samples += 1

    avg_wer = total_wer / num_samples if num_samples > 0 else 0
    avg_cer = total_cer / num_samples if num_samples > 0 else 0
    print(f"Average WER: {avg_wer:.2f}, Average CER: {avg_cer:.2f}")
    return avg_wer, avg_cer



In [None]:
num_epochs = 15
best_wer = float('inf')

train_losses, wer_scores, cer_scores = [], [], []

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0

    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)

        features = encoder(images)                           # (B, 1, embed_size)
        outputs = decoder(features, captions[:, :-1])  # ✅ just 2 arguments: (image_features, input_captions)
       # (B, T, vocab_size)
        loss = criterion(outputs.reshape(-1, len(word2idx)), captions[:, 1:].reshape(-1))


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    print(f"Epoch {epoch+1}: Train Loss: {avg_loss:.4f}")

    avg_wer, avg_cer = evaluate_model(encoder, decoder, eval_loader, device, epoch)
    wer_scores.append(avg_wer)
    cer_scores.append(avg_cer)


Epoch 1: Train Loss: 1.7655
Average WER: 0.88, Average CER: 0.32
Epoch 2: Train Loss: 0.4657
Average WER: 0.66, Average CER: 0.18
Epoch 3: Train Loss: 0.3439
Average WER: 0.50, Average CER: 0.13
Epoch 4: Train Loss: 0.2834
Average WER: 0.50, Average CER: 0.13
Epoch 5: Train Loss: 0.2044
Average WER: 0.43, Average CER: 0.13
Epoch 6: Train Loss: 0.2094
Average WER: 0.38, Average CER: 0.10
Epoch 7: Train Loss: 0.1578
Average WER: 0.36, Average CER: 0.09
Epoch 8: Train Loss: 0.1738
Average WER: 0.43, Average CER: 0.12
Epoch 9: Train Loss: 0.1470
Average WER: 0.31, Average CER: 0.09
Epoch 10: Train Loss: 0.1461
Average WER: 0.32, Average CER: 0.09
Epoch 11: Train Loss: 0.1327
Average WER: 0.31, Average CER: 0.08
Epoch 12: Train Loss: 0.1176
Average WER: 0.26, Average CER: 0.07
Epoch 13: Train Loss: 0.1074
Average WER: 0.30, Average CER: 0.07
Epoch 14: Train Loss: 0.1060
Average WER: 0.40, Average CER: 0.12
Epoch 15: Train Loss: 0.1088
Average WER: 0.29, Average CER: 0.08


In [3]:

def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20):
    encoder.eval()
    decoder.eval()

    # Load and transform image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)

    # Encode image: (1, 1, embed_size)
    encoder_out = encoder(image)

    input_indices = [word2idx['<START>']]
    for _ in range(max_length):
        tgt = torch.tensor([input_indices], dtype=torch.long).to(device)  # shape: (1, T)
        tgt_embed = decoder.embedding(tgt)
        tgt_embed = decoder.pos_encoding(tgt_embed)  # (1, T, E)
        tgt_embed = tgt_embed.permute(1, 0, 2)        # (T, 1, E)
        memory = encoder_out.permute(1, 0, 2)         # (1, 1, E)
        tgt_mask = decoder.generate_square_subsequent_mask(tgt_embed.size(0)).to(device)

        output = decoder.transformer_decoder(tgt_embed, memory, tgt_mask=tgt_mask)
        output = decoder.fc_out(output[-1])  # (1, vocab_size) → last time step

        predicted_index = output.argmax(-1).item()
        if predicted_index == word2idx['<END>']:
            break

        input_indices.append(predicted_index)

    predicted_tokens = [idx2word[str(idx)] for idx in input_indices[1:]]  # skip <START>
    predicted_caption = ''.join(predicted_tokens)
    return predicted_caption


image_path = '/home/vitoupro/code/image_captioning/data/image.png'
caption = predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx)
print("Predicted Caption:", caption)


Predicted Caption: សេះ


In [4]:
if avg_wer < best_wer:
    best_wer = avg_wer
    torch.save(encoder.state_dict(), "encodertransf.pth")
    torch.save(decoder.state_dict(), "decodertransf.pth")
    print("✅ Saved best model!")

✅ Saved best model!


In [None]:
# Save the model state dictionary
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'captioning_model_transf.pth')
