In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader 
from PIL import Image
import pandas as pd
import os
import re
import json
from sklearn.model_selection import train_test_split
import jiwer
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from timm import create_model

# Function to load idx2word and convert it to word2idx
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

# Load vocabulary
idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word_level_khmercut_420.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Encoding a list of words (word-level)
def encode_khmer_sentence(sentence, word2idx):
    words = sentence.strip().split()
    indices = []
    for word in words:
        index = word2idx.get(word)
        if index is None:
            return None, f"Word '{word}' not found in vocabulary!"
        indices.append(index)
    return indices, None

# Decoding a list of indices (word-level)
def decode_indices(indices, idx2word):
    words = []
    for index in indices:
        word = idx2word.get(str(index))
        if word is None:
            return None, f"Index '{index}' not found in idx2word!"
        words.append(word)
    return ' '.join(words), None

class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(encoder_dim + decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        hidden = hidden.unsqueeze(1).repeat(1, encoder_out.size(1), 1)
        attn_input = torch.cat((encoder_out, hidden), dim=2)
        energy = torch.tanh(self.attn(attn_input))
        attention = self.v(energy).squeeze(2)
        alpha = torch.softmax(attention, dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha


# Data-efficient Image Transformer (DeiT) Encoder
class DeiTEncoder(nn.Module):
    def __init__(self, embed_size):
        super(DeiTEncoder, self).__init__()
        import timm
        self.deit = timm.create_model('deit_tiny_patch16_224', pretrained=True)
        self.deit.head = nn.Identity()  # Remove classification head
        self.linear = nn.Linear(self.deit.num_features, embed_size)  # Project to embed_size

    def forward(self, images):
        features = self.deit.forward_features(images)  # (B, seq_len, 192)
        cls_token = features[:, 0, :]  # Take only [CLS] token: (B, 192)
        out = self.linear(cls_token)  # Project to (B, embed_size)
        return out.unsqueeze(1)  # (B, 1, embed_size)
       


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads=8, num_layers=3, ff_dim=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def forward(self, features, captions):
        tgt = self.embedding(captions)
        tgt = self.pos_encoding(tgt)
        tgt = tgt.permute(1, 0, 2)
        memory = features.permute(1, 0, 2)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        out = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        out = self.fc_out(out)
        return out.permute(1, 0, 2)

class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, img_labels, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        caption = self.img_labels.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        indices, error = encode_khmer_sentence(caption, self.vocab)
        if error:
            print(f"Error encoding caption: {error}")
            indices = [self.vocab['<UNK>']] * self.max_length
        tokens = [self.vocab['<START>']] + indices + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

annotations_file = '/home/vitoupro/code/image_captioning/data/processed/word_segmented_imglabel_khmercut.txt'
img_dir = '/home/vitoupro/code/image_captioning/data/raw/img'
image_names, captions = [], []

with open(annotations_file, 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split(' ', 1)
        if len(parts) == 2:
            image_name, caption = parts
            image_names.append(image_name)
            captions.append(caption)

all_images = pd.DataFrame({'image': image_names, 'caption': captions})
train_images, eval_images, train_captions, eval_captions = train_test_split(
    all_images['image'].tolist(), all_images['caption'].tolist(), test_size=0.2, random_state=42
)

train_dataset = ImageCaptionDataset(pd.DataFrame({'image': train_images, 'caption': train_captions}), img_dir, word2idx, transform, 20)
eval_dataset = ImageCaptionDataset(pd.DataFrame({'image': eval_images, 'caption': eval_captions}), img_dir, word2idx, transform, 20)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embed_size = 256
encoder = DeiTEncoder(embed_size=embed_size).to(device)
decoder = TransformerDecoder(len(word2idx), embed_size, num_heads=8, num_layers=3).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-5)

smoothing = SmoothingFunction().method1
rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

def evaluate_model(encoder, decoder, dataloader, device, epoch):
    encoder.eval()
    decoder.eval()
    total_wer, total_cer, total_bleu1, total_bleu2, total_bleu4, total_rougeL = 0, 0, 0, 0, 0, 0
    num_samples = 0

    with torch.no_grad():
        for images, captions in dataloader:
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions[:, :-1])
            predicted_captions = outputs.argmax(-1)

            for i in range(len(captions)):
                gt_caption = decode_indices(captions[i].tolist(), idx2word)[0]
                pred_caption = decode_indices(predicted_captions[i].tolist(), idx2word)[0]

                ref_match = re.search(r"<START>(.*?)<END>", gt_caption)
                pred_match = re.search(r"^(.*?)<END>", pred_caption)
                reference = ref_match.group(1).strip() if ref_match else ""
                prediction = pred_match.group(1).strip() if pred_match else ""
                if not reference or not prediction:
                    continue

                ref_tokens = reference.split()
                pred_tokens = prediction.split()

                total_wer += jiwer.wer(reference, prediction)
                total_cer += jiwer.cer(reference, prediction)
                total_bleu1 += sentence_bleu([ref_tokens], pred_tokens, weights=(1, 0, 0, 0), smoothing_function=smoothing)
                total_bleu2 += sentence_bleu([ref_tokens], pred_tokens, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing)
                total_bleu4 += sentence_bleu([ref_tokens], pred_tokens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing)
                total_rougeL += rouge.score(reference, prediction)['rougeL'].fmeasure
                num_samples += 1

    avg_wer = total_wer / num_samples if num_samples > 0 else 0
    avg_cer = total_cer / num_samples if num_samples > 0 else 0
    avg_bleu1 = total_bleu1 / num_samples if num_samples > 0 else 0
    avg_bleu2 = total_bleu2 / num_samples if num_samples > 0 else 0
    avg_bleu4 = total_bleu4 / num_samples if num_samples > 0 else 0
    avg_rougeL = total_rougeL / num_samples if num_samples > 0 else 0

    print(f"WER: {avg_wer:.2f}, CER: {avg_cer:.2f}, BLEU-1: {avg_bleu1:.2f}, BLEU-2: {avg_bleu2:.2f}, BLEU-4: {avg_bleu4:.2f}, ROUGE-L: {avg_rougeL:.2f}")
    return avg_wer, avg_cer, avg_bleu1, avg_bleu2, avg_bleu4, avg_rougeL


In [12]:
num_epochs = 15
best_wer = float('inf')
teacher_forcing_ratio = 0.9  # Start with 90% teacher forcing

train_losses, wer_scores, cer_scores = [], [], []
bleu1_scores, bleu2_scores, bleu4_scores, rougeL_scores = [], [], [], []

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0

    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)

        features = encoder(images)
        
        input_tokens = captions[:, :-1]
        targets = captions[:, 1:]

        if torch.rand(1).item() > teacher_forcing_ratio:
            # Scheduled sampling: use own prediction as input
            outputs = []
            batch_size = images.size(0)
            inputs = torch.full((batch_size, 1), word2idx['<START>'], dtype=torch.long, device=device)

            for _ in range(input_tokens.size(1)):
                output = decoder(features, inputs)
                last_output = output[:, -1, :]  # last timestep
                predicted = last_output.argmax(-1).unsqueeze(1)
                inputs = torch.cat((inputs, predicted), dim=1)
                outputs.append(last_output)

            outputs = torch.stack(outputs, dim=1).squeeze(2)  # (B, T, vocab)
        else:
            # Teacher forcing: normal
            outputs = decoder(features, input_tokens)

        batch_size, seq_len, vocab_size = outputs.size()
        loss = criterion(
            outputs.reshape(batch_size * seq_len, vocab_size),
            targets.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    print(f"Epoch {epoch+1}: Train Loss: {avg_loss:.4f}")

    avg_wer, avg_cer, avg_bleu1, avg_bleu2, avg_bleu4, avg_rougeL = evaluate_model(encoder, decoder, eval_loader, device, epoch)

    wer_scores.append(avg_wer)
    cer_scores.append(avg_cer)

    # Reduce teacher forcing slowly
    teacher_forcing_ratio = max(0.3, teacher_forcing_ratio * 0.9)


Epoch 1: Train Loss: 4.8700
WER: 0.91, CER: 0.78, BLEU-1: 0.11, BLEU-2: 0.03, BLEU-4: 0.02, ROUGE-L: 0.00
Epoch 2: Train Loss: 3.9889
WER: 0.87, CER: 0.73, BLEU-1: 0.24, BLEU-2: 0.12, BLEU-4: 0.06, ROUGE-L: 0.00
Epoch 3: Train Loss: 4.3184
WER: 0.82, CER: 0.72, BLEU-1: 0.20, BLEU-2: 0.08, BLEU-4: 0.04, ROUGE-L: 0.00
Epoch 4: Train Loss: 3.8762
WER: 0.82, CER: 0.70, BLEU-1: 0.25, BLEU-2: 0.12, BLEU-4: 0.05, ROUGE-L: 0.00
Epoch 5: Train Loss: 3.6982
WER: 0.80, CER: 0.67, BLEU-1: 0.25, BLEU-2: 0.13, BLEU-4: 0.06, ROUGE-L: 0.00
Epoch 6: Train Loss: 3.6491
WER: 0.90, CER: 0.73, BLEU-1: 0.28, BLEU-2: 0.15, BLEU-4: 0.06, ROUGE-L: 0.00
Epoch 7: Train Loss: 3.6855
WER: 0.89, CER: 0.72, BLEU-1: 0.27, BLEU-2: 0.13, BLEU-4: 0.06, ROUGE-L: 0.00
Epoch 8: Train Loss: 3.6568
WER: 0.75, CER: 0.61, BLEU-1: 0.33, BLEU-2: 0.18, BLEU-4: 0.08, ROUGE-L: 0.00
Epoch 9: Train Loss: 3.5825
WER: 0.78, CER: 0.65, BLEU-1: 0.30, BLEU-2: 0.16, BLEU-4: 0.07, ROUGE-L: 0.00
Epoch 10: Train Loss: 3.5989
WER: 0.94, CER: 0

In [14]:
@torch.no_grad()
def predict_caption(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20):
    encoder.eval()
    decoder.eval()

    # Load and transform image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)  # shape (1, 3, H, W)

    # Encode image
    encoder_out = encoder(image)  # (1, 1, embed_size)

    # Start generating
    generated_indices = [word2idx['<START>']]
    
    for _ in range(max_length):
        input_tensor = torch.tensor([generated_indices], dtype=torch.long).to(device)  # (1, T)
        
        # Predict next token
        output = decoder(encoder_out, input_tensor)  # (1, T, vocab_size)
        next_token_logits = output[:, -1, :]          # (1, vocab_size)
        predicted_index = next_token_logits.argmax(dim=-1).item()

        if predicted_index == word2idx['<END>']:
            break
        
        generated_indices.append(predicted_index)

    # Decode indices to words
    predicted_tokens = [idx2word[str(idx)] for idx in generated_indices[1:]]  # Skip <START>

    predicted_caption = ' '.join(predicted_tokens)  # 🔥 join words with space
    return predicted_caption

print("\n=== Testing a prediction after training ===\n")

test_image_path = '/home/vitoupro/code/image_captioning/data/9.png'  # your image

predicted_caption = predict_caption(
    image_path=test_image_path,
    encoder=encoder,
    decoder=decoder,
    transform=transform,
    device=device,
    idx2word=idx2word,
    word2idx=word2idx,
    max_length=20
)

print("Predicted Caption:", predicted_caption)


=== Testing a prediction after training ===

Predicted Caption: មាន នេះ មាន មួយ មួយ


In [16]:
@torch.no_grad()
def predict_caption_beam_search(image_path, encoder, decoder, transform, device, idx2word, word2idx, max_length=20, beam_size=3):
    encoder.eval()
    decoder.eval()

    # Load and transform image
    image = Image.open(image_path).convert('RGB')
    if transform:
        image = transform(image)
    image = image.unsqueeze(0).to(device)

    # Encode
    encoder_out = encoder(image)  # (1, 1, embed)

    sequences = [[list([word2idx['<START>']]), 0.0]]  # (sequence, score)

    for _ in range(max_length):
        all_candidates = []
        for seq, score in sequences:
            tgt = torch.tensor([seq], dtype=torch.long).to(device)
            outputs = decoder(encoder_out, tgt)
            outputs = outputs[:, -1, :]  # last step
            probs = torch.softmax(outputs, dim=-1)
            topk_probs, topk_indices = probs.topk(beam_size)

            for i in range(beam_size):
                candidate = seq + [topk_indices[0][i].item()]
                candidate_score = score - torch.log(topk_probs[0][i]).item()
                all_candidates.append((candidate, candidate_score))

        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        sequences = ordered[:beam_size]

    best_seq = sequences[0][0]
    predicted_tokens = [idx2word[str(idx)] for idx in best_seq[1:] if idx != word2idx['<END>']]  # skip <START>
    predicted_caption = ''.join(predicted_tokens)
    return predicted_caption
  
  
test_image_path = '/home/vitoupro/code/image_captioning/data/9.png'  # your image

predicted_caption = predict_caption_beam_search(
    image_path=test_image_path,
    encoder=encoder,
    decoder=decoder,
    transform=transform,
    device=device,
    idx2word=idx2word,
    word2idx=word2idx,
    max_length=20,
    beam_size=3  # you can test beam_size=3 or 5
)

print("Predicted Caption:", predicted_caption)


Predicted Caption: រូបនេះមានក្រឡថ្មមួយនិងស្លាបព្រាមួយ


In [1]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import os
import re
import json
from sklearn.model_selection import train_test_split
import jiwer
import matplotlib.pyplot as plt
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from timm import create_model

# Function to load idx2word and convert it to word2idx
def load_vocabulary(path):
    with open(path, 'r') as file:
        idx2word = json.load(file)
    word2idx = {v: int(k) for k, v in idx2word.items()}
    return idx2word, word2idx

# Load vocabulary
idx2word_path = '/home/vitoupro/code/image_captioning/data/processed/idx2word_level_khmercut_420.json'
idx2word, word2idx = load_vocabulary(idx2word_path)

# Encoding a list of words (word-level)
def encode_khmer_sentence(sentence, word2idx):
    words = sentence.strip().split()
    indices = []
    for word in words:
        index = word2idx.get(word)
        if index is None:
            return None, f"Word '{word}' not found in vocabulary!"
        indices.append(index)
    return indices, None

# Decoding a list of indices (word-level)
def decode_indices(indices, idx2word):
    words = []
    for index in indices:
        word = idx2word.get(str(index))
        if word is None:
            return None, f"Index '{index}' not found in idx2word!"
        words.append(word)
    return ' '.join(words), None

class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super(Attention, self).__init__()
        self.attn = nn.Linear(encoder_dim + decoder_dim, attention_dim)
        self.v = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, hidden):
        hidden = hidden.unsqueeze(1).repeat(1, encoder_out.size(1), 1)
        attn_input = torch.cat((encoder_out, hidden), dim=2)
        energy = torch.tanh(self.attn(attn_input))
        attention = self.v(energy).squeeze(2)
        alpha = torch.softmax(attention, dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha


# Data-efficient Image Transformer (DeiT) Encoder
class DeiTEncoder(nn.Module):
    def __init__(self, embed_size):
        super(DeiTEncoder, self).__init__()
        import timm
        self.deit = timm.create_model('deit_small_patch16_224', pretrained=True)
        self.deit.head = nn.Identity()  # Remove classification head
        self.linear = nn.Linear(self.deit.num_features, embed_size)  # Project to embed_size

    def forward(self, images):
        features = self.deit.forward_features(images)  # (B, seq_len, 192)
        cls_token = features[:, 0, :]  # Take only [CLS] token: (B, 192)
        out = self.linear(cls_token)  # Project to (B, embed_size)
        return out.unsqueeze(1)  # (B, 1, embed_size)
       


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        return x + self.pe[:, :x.size(1), :].to(x.device)

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_heads=8, num_layers=3, ff_dim=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = PositionalEncoding(embed_size)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_size, nhead=num_heads, dim_feedforward=ff_dim, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(embed_size, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

    def forward(self, features, captions):
        tgt = self.embedding(captions)
        tgt = self.pos_encoding(tgt)
        tgt = tgt.permute(1, 0, 2)
        memory = features.permute(1, 0, 2)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0)).to(tgt.device)
        out = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        out = self.fc_out(out)
        return out.permute(1, 0, 2)

class ImageCaptionDataset(torch.utils.data.Dataset):
    def __init__(self, img_labels, img_dir, vocab, transform=None, max_length=50):
        self.img_labels = img_labels
        self.img_dir = img_dir
        self.vocab = vocab
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        caption = self.img_labels.iloc[idx, 1]
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        indices, error = encode_khmer_sentence(caption, self.vocab)
        if error:
            print(f"Error encoding caption: {error}")
            indices = [self.vocab['<UNK>']] * self.max_length
        tokens = [self.vocab['<START>']] + indices + [self.vocab['<END>']]
        tokens += [self.vocab['<PAD>']] * (self.max_length - len(tokens))
        return image, torch.tensor(tokens[:self.max_length])

transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

annotations_file = '/home/vitoupro/code/image_captioning/data/processed/word_segmented_imglabel_khmercut.txt'
img_dir = '/home/vitoupro/code/image_captioning/data/raw/img'
image_names, captions = [], []

with open(annotations_file, 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split(' ', 1)
        if len(parts) == 2:
            image_name, caption = parts
            image_names.append(image_name)
            captions.append(caption)

all_images = pd.DataFrame({'image': image_names, 'caption': captions})
train_images, eval_images, train_captions, eval_captions = train_test_split(
    all_images['image'].tolist(), all_images['caption'].tolist(), test_size=0.2, random_state=42
)

train_dataset = ImageCaptionDataset(pd.DataFrame({'image': train_images, 'caption': train_captions}), img_dir, word2idx, transform, 20)
eval_dataset = ImageCaptionDataset(pd.DataFrame({'image': eval_images, 'caption': eval_captions}), img_dir, word2idx, transform, 20)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=32, shuffle=False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
embed_size = 256
encoder = DeiTEncoder(embed_size=embed_size).to(device)
decoder = TransformerDecoder(len(word2idx), embed_size, num_heads=8, num_layers=3).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=word2idx['<PAD>'])
params = list(decoder.parameters()) + list(encoder.parameters())
optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=1e-5)

smoothing = SmoothingFunction().method1
rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

def evaluate_model(encoder, decoder, dataloader, device, epoch):
    encoder.eval()
    decoder.eval()
    total_wer, total_cer, total_bleu1, total_bleu2, total_bleu4, total_rougeL = 0, 0, 0, 0, 0, 0
    num_samples = 0

    with torch.no_grad():
        for images, captions in dataloader:
            images, captions = images.to(device), captions.to(device)
            features = encoder(images)
            outputs = decoder(features, captions[:, :-1])
            predicted_captions = outputs.argmax(-1)

            for i in range(len(captions)):
                gt_caption = decode_indices(captions[i].tolist(), idx2word)[0]
                pred_caption = decode_indices(predicted_captions[i].tolist(), idx2word)[0]

                ref_match = re.search(r"<START>(.*?)<END>", gt_caption)
                pred_match = re.search(r"^(.*?)<END>", pred_caption)
                reference = ref_match.group(1).strip() if ref_match else ""
                prediction = pred_match.group(1).strip() if pred_match else ""
                if not reference or not prediction:
                    continue

                ref_tokens = reference.split()
                pred_tokens = prediction.split()

                total_wer += jiwer.wer(reference, prediction)
                total_cer += jiwer.cer(reference, prediction)
                total_bleu1 += sentence_bleu([ref_tokens], pred_tokens, weights=(1, 0, 0, 0), smoothing_function=smoothing)
                total_bleu2 += sentence_bleu([ref_tokens], pred_tokens, weights=(0.5, 0.5, 0, 0), smoothing_function=smoothing)
                total_bleu4 += sentence_bleu([ref_tokens], pred_tokens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing)
                total_rougeL += rouge.score(reference, prediction)['rougeL'].fmeasure
                num_samples += 1

    avg_wer = total_wer / num_samples if num_samples > 0 else 0
    avg_cer = total_cer / num_samples if num_samples > 0 else 0
    avg_bleu1 = total_bleu1 / num_samples if num_samples > 0 else 0
    avg_bleu2 = total_bleu2 / num_samples if num_samples > 0 else 0
    avg_bleu4 = total_bleu4 / num_samples if num_samples > 0 else 0
    avg_rougeL = total_rougeL / num_samples if num_samples > 0 else 0

    print(f"WER: {avg_wer:.2f}, CER: {avg_cer:.2f}, BLEU-1: {avg_bleu1:.2f}, BLEU-2: {avg_bleu2:.2f}, BLEU-4: {avg_bleu4:.2f}, ROUGE-L: {avg_rougeL:.2f}")
    return avg_wer, avg_cer, avg_bleu1, avg_bleu2, avg_bleu4, avg_rougeL


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
num_epochs = 15
best_wer = float('inf')
teacher_forcing_ratio = 0.9  # Start with 90% teacher forcing

train_losses, wer_scores, cer_scores = [], [], []
bleu1_scores, bleu2_scores, bleu4_scores, rougeL_scores = [], [], [], []

for epoch in range(num_epochs):
    encoder.train()
    decoder.train()
    total_loss = 0

    for images, captions in train_loader:
        images, captions = images.to(device), captions.to(device)

        features = encoder(images)
        
        input_tokens = captions[:, :-1]
        targets = captions[:, 1:]

        if torch.rand(1).item() > teacher_forcing_ratio:
            # Scheduled sampling: use own prediction as input
            outputs = []
            batch_size = images.size(0)
            inputs = torch.full((batch_size, 1), word2idx['<START>'], dtype=torch.long, device=device)

            for _ in range(input_tokens.size(1)):
                output = decoder(features, inputs)
                last_output = output[:, -1, :]  # last timestep
                predicted = last_output.argmax(-1).unsqueeze(1)
                inputs = torch.cat((inputs, predicted), dim=1)
                outputs.append(last_output)

            outputs = torch.stack(outputs, dim=1).squeeze(2)  # (B, T, vocab)
        else:
            # Teacher forcing: normal
            outputs = decoder(features, input_tokens)

        batch_size, seq_len, vocab_size = outputs.size()
        loss = criterion(
            outputs.reshape(batch_size * seq_len, vocab_size),
            targets.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    print(f"Epoch {epoch+1}: Train Loss: {avg_loss:.4f}")

    avg_wer, avg_cer, avg_bleu1, avg_bleu2, avg_bleu4, avg_rougeL = evaluate_model(encoder, decoder, eval_loader, device, epoch)

    wer_scores.append(avg_wer)
    cer_scores.append(avg_cer)

    # Reduce teacher forcing slowly
    teacher_forcing_ratio = max(0.3, teacher_forcing_ratio * 0.9)


Epoch 1: Train Loss: 4.7892
WER: 0.91, CER: 0.73, BLEU-1: 0.16, BLEU-2: 0.07, BLEU-4: 0.03, ROUGE-L: 0.00
Epoch 2: Train Loss: 4.2321
WER: 0.87, CER: 0.73, BLEU-1: 0.19, BLEU-2: 0.10, BLEU-4: 0.05, ROUGE-L: 0.00
Epoch 3: Train Loss: 4.0388
WER: 0.85, CER: 0.71, BLEU-1: 0.24, BLEU-2: 0.11, BLEU-4: 0.05, ROUGE-L: 0.00
Epoch 4: Train Loss: 3.8215
WER: 0.78, CER: 0.67, BLEU-1: 0.27, BLEU-2: 0.14, BLEU-4: 0.06, ROUGE-L: 0.00
Epoch 5: Train Loss: 3.9683
WER: 0.78, CER: 0.67, BLEU-1: 0.26, BLEU-2: 0.12, BLEU-4: 0.05, ROUGE-L: 0.00
Epoch 6: Train Loss: 4.0067
WER: 0.80, CER: 0.70, BLEU-1: 0.21, BLEU-2: 0.12, BLEU-4: 0.05, ROUGE-L: 0.00
Epoch 7: Train Loss: 3.5999
WER: 0.75, CER: 0.65, BLEU-1: 0.28, BLEU-2: 0.16, BLEU-4: 0.07, ROUGE-L: 0.00
Epoch 8: Train Loss: 3.5465
WER: 0.75, CER: 0.65, BLEU-1: 0.31, BLEU-2: 0.18, BLEU-4: 0.08, ROUGE-L: 0.00
Epoch 9: Train Loss: 3.5066
WER: 0.76, CER: 0.64, BLEU-1: 0.28, BLEU-2: 0.15, BLEU-4: 0.06, ROUGE-L: 0.00
Epoch 10: Train Loss: 4.0272
WER: 0.76, CER: 0