In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
import re
import json
from sklearn.model_selection import train_test_split
import jiwer
import matplotlib.pyplot as plt

# Function to load idx2word and convert it to word2idx
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

# Load vocabulary
idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Encoding and decoding functions
def encode_khmer_word(word, word2idx):
    indices = []
    for character in word:
        index = word2idx.get(character)
        if index is None:
            return None, f"Character '{character}' not found in vocabulary!"
        indices.append(index)
    return indices, None

def decode_indices(indices, idx2word):
    characters = []
    for index in indices:
        character = idx2word.get(str(index))
        if character is None:
            return None, f"Index '{index}' not found in idx2word!"
        characters.append(character)
    return ''.join(characters), None

# Model Definitions (EncoderCNN and DecoderRNN)
class EncoderCNN(nn.Module):
    def __init__(self, embed_size):
        super(EncoderCNN, self).__init__()
        resnet = models.resnet50(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        self.embed = nn.Linear(resnet.fc.in_features, embed_size)

    def forward(self, images):
        features = self.resnet(images)
        features = features.reshape(features.size(0), -1)
        features = self.embed(features)
        return features

class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.init_h = nn.Linear(hidden_size, hidden_size)  # Initialize LSTM hidden state
        self.init_c = nn.Linear(hidden_size, hidden_size)  # Initialize LSTM cell state

    def forward(self, features, captions):
        embeddings = self.embed(captions)
        h0 = self.init_h(features).unsqueeze(0).repeat(self.num_layers, 1, 1)
        c0 = self.init_c(features).unsqueeze(0).repeat(self.num_layers, 1, 1)
        lstm_out, _ = self.lstm(embeddings, (h0, c0))
        outputs = self.linear(lstm_out)
        return outputs

# Image Captioning Dataset
class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, img_labels, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        caption = self.img_labels.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        indices, error = encode_khmer_word(caption, self.vocab)
        if error:
            print(f"Error encoding caption: {error}")
            indices = [self.vocab['<UNK>']] * self.max_length
        tokens = [self.vocab['<START>']] + indices + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

# Define transformations
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]) 

# Load dataset
annotations_file = '/home/vitoupro/code/image_captioning/data/raw/annotation.txt'
img_dir = '/home/vitoupro/code/image_captioning/data/raw/animals'
all_images = pd.read_csv(annotations_file, delimiter=' ', names=['image', 'caption'])

# Split dataset
train_images, eval_images, train_captions, eval_captions = train_test_split(
    all_images['image'].tolist(), all_images['caption'].tolist(), test_size=0.2, random_state=42
)

train_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': train_images, 'caption': train_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

eval_dataset = ImageCaptionDataset(
    img_labels=pd.DataFrame({'image': eval_images, 'caption': eval_captions}),
    img_dir=img_dir,
    vocab=word2idx,
    transform=transform,
    max_length=20
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
encoder = EncoderCNN(embed_size=512).to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(word2idx), num_layers=1).to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.embed.parameters())
optimizer = torch.optim.Adam(params, lr=0.001)

def custom_transform(text):
    # Lowercase the text
    text = text.lower()
    # Remove punctuation
    text = re.sub(r'[^\w\s]', '', text)
    # Remove multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    # Return as list of words
    return text.split()
    

def calculate_wer(gt, pred, epoch, file_path='metric.txt'):
    
    with open(file_path, 'a') as file:  # Open file in append mode
        file.write(f"Epoch {epoch}\n")
        
        file.write("===========================\n")
        match_pred = re.search(r"^(.*?)<END>", pred)
        if match_pred:
            content_pred = match_pred.group(1)
        file.write(f"pred: {content_pred}\n")
        match_ground_true = re.search(r"<START>(.*?)<END>", gt)
        if match_ground_true:
            content_ground_true = match_ground_true.group(1)
        file.write(f"true: {content_ground_true}\n")
        file.write("===========================\n")
    
    if not content_ground_true:  # Ensure non-empty
        content_ground_true = ['']
    if not content_pred:  # Ensure non-empty
        content_pred = ['']
    wer_score = jiwer.wer(content_ground_true, content_pred)
    
    return wer_score

def calculate_cer(gt, pred):
    match_pred = re.search(r"^(.*?)<END>", pred)
    if match_pred:
        content_pred = match_pred.group(1)
    
    match_ground_true = re.search(r"<START>(.*?)<END>", gt)
    if match_ground_true:
        content_ground_true = match_ground_true.group(1)

    return jiwer.cer(content_ground_true, content_pred)

def evaluate_model(encoder, decoder, dataloader, device, epoch):
    encoder.eval()
    decoder.eval()
    total_wer, total_cer, num_samples = 0, 0, 0
    with torch.no_grad():
        for images, captions in dataloader:
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions[:, :-1])
            predicted_captions = outputs.argmax(-1)
            
            for i in range(len(captions)):
                gt_caption = decode_indices(captions[i].tolist(), idx2word)[0]
                pred_caption = decode_indices(predicted_captions[i].tolist(), idx2word)[0]
                
                wer = calculate_wer(gt_caption, pred_caption, epoch)
                cer = calculate_cer(gt_caption, pred_caption)
                total_wer += wer
                total_cer += cer
                num_samples += 1

    avg_wer = total_wer / num_samples if num_samples > 0 else 0
    avg_cer = total_cer / num_samples if num_samples > 0 else 0
    print(f"Average WER: {avg_wer:.2f}, Average CER: {avg_cer:.2f}")
    return avg_wer, avg_cer




In [15]:
# Training Loop
num_epochs = 15
best_wer = float('inf')

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0
    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)
        features = encoder(images)
        outputs = decoder(features, captions[:, :-1])
        loss = criterion(outputs.view(-1, len(word2idx)), captions[:, 1:].reshape(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f'Epoch {epoch+1}: Train Loss: {total_loss/len(train_loader)}')
    _, wer = evaluate_model(encoder, decoder, eval_loader, device, epoch)

    if wer < best_wer:
        best_wer = wer


Epoch 1: Train Loss: 1.7804002950831157
Average WER: 0.97, Average CER: 0.25
Epoch 2: Train Loss: 0.5050743563873011
Average WER: 0.83, Average CER: 0.20
Epoch 3: Train Loss: 0.38889224136747963
Average WER: 0.57, Average CER: 0.15
Epoch 4: Train Loss: 0.26699337290554515
Average WER: 0.32, Average CER: 0.09
Epoch 5: Train Loss: 0.1642110865654015
Average WER: 0.22, Average CER: 0.06
Epoch 6: Train Loss: 0.08937570180107908
Average WER: 0.16, Average CER: 0.04
Epoch 7: Train Loss: 0.055989182258887986
Average WER: 0.12, Average CER: 0.03
Epoch 8: Train Loss: 0.048371379013832025
Average WER: 0.11, Average CER: 0.03
Epoch 9: Train Loss: 0.03615249984148072
Average WER: 0.10, Average CER: 0.03
Epoch 10: Train Loss: 0.03486498088644045
Average WER: 0.10, Average CER: 0.03
Epoch 11: Train Loss: 0.02760678636500748
Average WER: 0.12, Average CER: 0.03
Epoch 12: Train Loss: 0.024037751562257365
Average WER: 0.07, Average CER: 0.02
Epoch 13: Train Loss: 0.018246573611821342
Average WER: 0.10,

In [20]:
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx):
    encoder.eval()
    decoder.eval()
    
    # Load and transform the image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)  # Add batch dimension and transfer to device
    
    # Generate features from the image using the encoder
    features = encoder(image)
    
    # Start the sequence with the <START> token
    predicted_indices = [word2idx['<START>']]
    predictions = []
    
    # Initial input to the LSTM is the <START> token
    input_idx = torch.tensor([predicted_indices[-1]], dtype=torch.long).to(device)
    
    # Initialize the LSTM state
    h, c = None, None
    
    # Generate words until the <END> token is predicted or the max length is reached
    for _ in range(20):  # Assuming max length of 20 for safety
        input_idx = input_idx.unsqueeze(0)  # Add batch dimension for single time-step prediction
        if h is None and c is None:
            # Generate initial hidden states from features
            h = decoder.init_h(features).unsqueeze(0).repeat(decoder.num_layers, 1, 1)
            c = decoder.init_c(features).unsqueeze(0).repeat(decoder.num_layers, 1, 1)
        
        outputs, (h, c) = decoder.lstm(decoder.embed(input_idx), (h, c))
        outputs = decoder.linear(outputs.squeeze(1))
        
        # Get the predicted word index
        predicted_index = outputs.argmax(-1).item()
        predicted_indices.append(predicted_index)
        predictions.append(idx2word[str(predicted_index)])  # Decode to word
        
        # Prepare the next input
        input_idx = torch.tensor([predicted_index], dtype=torch.long).to(device)
        
        # Stop if the <END> token is predicted
        if predicted_index == word2idx['<END>']:
            break
    
    predicted_caption = ' '.join(predictions)  # Join the predicted words
    
    return predicted_caption

# Example usage
image_path = '/home/vitoupro/code/image_captioning/data/raw/animals/dog/0be3797d3d.jpg'
predicted_caption = predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx)
print("Predicted Caption:", predicted_caption.replace(" ", ""))


Predicted Caption: ឆ្កែ<END>


In [22]:
# Save the encoder and decoder models
torch.save(encoder.state_dict(), 'encoder_v4.pth')
torch.save(decoder.state_dict(), 'decoder_v4.pth')

In [9]:
# Save the model state dictionary
torch.save({
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'captioning_model_3z.pth')
