In [None]:

# Import required libraries
import math
import json
import os
import torch
import torch.nn as nn
import torch.utils.data as data
import numpy as np
from collections import defaultdict
from torchvision import transforms
from tqdm.notebook import tqdm

# Import custom modules
from encoder import EncoderCNN
from decoder import DecoderRNN
from dataset import get_loader
from vocab import Vocabulary


In [None]:
# NLP utilities
def clean_sentence(output, idx2word, skip_idx=0, end_idx=1, punct_idx=18):
    """Convert a sequence of word indices into a clean sentence.

    Args:
        output: Iterable of vocabulary indices, in generation order.
        idx2word: Mapping from vocabulary index to word string.
        skip_idx: Index that is dropped from the sentence. Defaults to 0
            (the original inline comment labels it ``<pad>``; confirm
            against the Vocabulary class).
        end_idx: Index that terminates the sentence; it and everything
            after it are discarded. Defaults to 1.
        punct_idx: Index whose word is attached to the previous word
            without a separating space. Defaults to 18, which is
            vocabulary-dependent -- pass the real index if the vocabulary
            is rebuilt with different settings.

    Returns:
        The words joined by single spaces, with surrounding whitespace
        stripped. Returns "" for empty input.

    Raises:
        KeyError: If an index in ``output`` is missing from ``idx2word``
            (when ``idx2word`` is a plain dict).
    """
    pieces = []
    for i in output:
        if i == skip_idx:
            continue
        if i == end_idx:
            break
        word = idx2word[i]
        # Punctuation sticks to the preceding word; everything else is
        # space-separated. The leading space is removed by strip() below.
        pieces.append(word if i == punct_idx else " " + word)
    return "".join(pieces).strip()

from nltk.translate.bleu_score import corpus_bleu

def bleu_score(true_sentences, predicted_sentences):
    """Calculate the corpus-level BLEU score.

    Only image ids present in both mappings are scored. For each such
    image, every reference caption is whitespace-tokenized and the first
    predicted caption is used as the hypothesis.

    Args:
        true_sentences: Mapping of image id -> list of reference caption
            strings.
        predicted_sentences: Mapping of image id -> list of predicted
            caption strings; only element 0 is scored.

    Returns:
        The value returned by ``corpus_bleu``, or 0.0 when no image has
        both at least one reference and at least one prediction.
    """
    hypotheses = []
    references = []
    shared_ids = set(true_sentences).intersection(predicted_sentences)
    for img_id in shared_ids:
        refs = true_sentences[img_id]
        preds = predicted_sentences[img_id]
        # An id can map to an empty list (e.g. a defaultdict entry created
        # by a lookup); indexing preds[0] would raise, and an empty
        # reference set cannot be scored, so skip the image instead.
        if not refs or not preds:
            continue
        references.append([cap.split() for cap in refs])
        hypotheses.append(preds[0].strip().split())
    if not hypotheses:
        # Nothing to score: report zero rather than handing empty lists
        # to corpus_bleu.
        return 0.0
    return corpus_bleu(references, hypotheses)


In [None]:
# Training parameters
batch_size = 128          # batch size handed to the training data loader
vocab_threshold = 5       # forwarded to get_loader; presumably the minimum word count for the vocabulary -- confirm in vocab.py
vocab_from_file = True    # forwarded to get_loader; presumably reuses a saved vocabulary file -- confirm it exists on a fresh checkout
embed_size = 256          # passed to both EncoderCNN and DecoderRNN
hidden_size = 512         # passed to DecoderRNN
num_epochs = 3            # number of passes of the training loop
save_every = 1            # write encoder/decoder checkpoints every N epochs
print_every = 20          # print training statistics every N steps
log_file = "training_log.txt"  # per-step statistics are written here (overwritten on each run)
learning_rate = 0.001     # Adam learning rate

# Reproducibility: seed the torch and NumPy global generators before any
# model is built or any batch is sampled, so Restart & Run All repeats.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Dataset location
cocoapi_dir = "./coco2017/"

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


In [None]:
# Shared normalization step. The constants match the ImageNet channel
# mean/std documented for torchvision's pretrained models; presumably
# EncoderCNN wraps such a backbone -- confirm in encoder.py.
normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# Training transformations: random crop + flip for augmentation
transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Validation/test transformations: deterministic counterpart of the
# training pipeline. Resize(int) only fixes the shorter edge, so a
# CenterCrop is needed to get the same 224x224 input the model was
# trained on (and tensors of one size).
transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])


In [None]:
# Build the data loaders.
# The training loader receives the batch size and vocabulary settings from
# the configuration cell; the validation loader relies on get_loader's
# defaults for everything except the transform, mode and dataset location.
train_loader_kwargs = dict(
    transform=transform_train,
    mode='train',
    batch_size=batch_size,
    vocab_threshold=vocab_threshold,
    vocab_from_file=vocab_from_file,
    cocoapi_loc=cocoapi_dir,
)
train_loader = get_loader(**train_loader_kwargs)

val_loader = get_loader(transform=transform_test, mode='valid', cocoapi_loc=cocoapi_dir)


In [None]:
# The vocabulary built by the training dataset sets the decoder's vocab_size argument
vocab_size = len(train_loader.dataset.vocab)
print(f"Vocabulary size: {vocab_size}")

# Build both networks, then place them on the selected device
encoder = EncoderCNN(embed_size)
decoder = DecoderRNN(embed_size, hidden_size, vocab_size)
for module in (encoder, decoder):
    module.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Optimizer: updates the whole decoder plus the encoder's `embed` layer only
params = [*decoder.parameters(), *encoder.embed.parameters()]
optimizer = torch.optim.Adam(params, lr=learning_rate)

# Steps per epoch = dataset size / batch size, rounded up
num_train_samples = len(train_loader.dataset)
samples_per_batch = train_loader.batch_sampler.batch_size
total_step = math.ceil(num_train_samples / samples_per_batch)
print(f"Total training steps per epoch: {total_step}")


In [None]:
# Create models directory
os.makedirs('./models', exist_ok=True)

# Training loop: one line of statistics per step goes to `log_file`,
# and checkpoints are written to ./models every `save_every` epochs.
with open(log_file, 'w') as f:
    for epoch in range(1, num_epochs + 1):
        # Put both networks in training mode explicitly. A later cell
        # switches them to eval(); without this, re-running the training
        # cell afterwards would silently train in eval mode.
        # NOTE(review): if EncoderCNN deliberately keeps its backbone in
        # eval mode, restrict this to the trainable sub-modules.
        encoder.train()
        decoder.train()

        for i_step in range(1, total_step + 1):

            # Ask the dataset for this step's indices (per the original
            # comment, captions sharing one length -- see dataset.py) and
            # install a sampler restricted to them.
            indices = train_loader.dataset.get_train_indices()
            new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
            train_loader.batch_sampler.sampler = new_sampler

            # A fresh iterator is required so the new sampler takes effect
            images, captions = next(iter(train_loader))
            images = images.to(device)
            captions = captions.to(device)

            # Zero gradients
            decoder.zero_grad()
            encoder.zero_grad()

            # Forward pass
            features = encoder(images)
            outputs = decoder(features, captions)

            # Flatten predictions and targets so every token position is
            # one classification example
            loss = criterion(outputs.view(-1, vocab_size), captions.view(-1))

            # Backward pass
            loss.backward()
            optimizer.step()

            # Log statistics (read the scalar loss once per step)
            loss_value = loss.item()
            stats = (f"Epoch [{epoch}/{num_epochs}], Step [{i_step}/{total_step}], "
                     f"Loss: {loss_value:.4f}, Perplexity: {np.exp(loss_value):.4f}")

            f.write(stats + "\n")
            f.flush()  # keep the log readable while training is running

            if i_step % print_every == 0:
                print(stats)

        # Save model weights
        if epoch % save_every == 0:
            torch.save(decoder.state_dict(), f'./models/decoder-{epoch}.pkl')
            torch.save(encoder.state_dict(), f'./models/encoder-{epoch}.pkl')
            print(f"Model saved at epoch {epoch}")

print("Training completed!")


In [None]:
# Load trained model
# Checkpoints are only written on epochs divisible by `save_every`, so pick
# the last epoch that was actually saved (equals num_epochs when it divides).
last_saved_epoch = num_epochs - (num_epochs % save_every)
encoder_file = f"encoder-{last_saved_epoch}.pkl"
decoder_file = f"decoder-{last_saved_epoch}.pkl"

# map_location lets weights saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
encoder.load_state_dict(torch.load(f'./models/{encoder_file}', map_location=device))
decoder.load_state_dict(torch.load(f'./models/{decoder_file}', map_location=device))

# Inference mode for evaluation
encoder.eval()
decoder.eval()

print("Model loaded and set to evaluation mode.")


In [None]:
# Caption every validation image.
# pred_result maps image id -> list of generated caption strings.
pred_result = defaultdict(list)
idx2word = val_loader.dataset.vocab.idx2word  # index -> word lookup used for decoding

with torch.no_grad():  # no gradients are needed for inference
    for img_id, img in tqdm(val_loader, desc="Generating captions"):
        # img_id.item() below requires a single-element tensor, i.e. one
        # image per batch from val_loader.
        features = encoder(img.to(device)).unsqueeze(1)
        token_ids = decoder.sample(features)
        caption = clean_sentence(token_ids, idx2word)
        pred_result[img_id.item()].append(caption)

print(f"Generated captions for {len(pred_result)} images.")


In [None]:
# Load ground truth captions
# Read the annotation file as UTF-8 explicitly: the platform default
# encoding is locale-dependent and can fail on non-ASCII captions.
annotations_path = os.path.join(cocoapi_dir, "annotations/captions_val2017.json")
with open(annotations_path, "r", encoding="utf-8") as f:
    caption_data = json.load(f)

# Group the lower-cased reference captions by image id
valid_annot = caption_data["annotations"]
valid_result = defaultdict(list)
for annotation in valid_annot:
    valid_result[annotation["image_id"]].append(annotation["caption"].lower())

print(f"Loaded ground truth captions for {len(valid_result)} images.")


In [None]:
# Calculate BLEU score over images that have both references and predictions
bleu = bleu_score(true_sentences=valid_result, predicted_sentences=pred_result)
print(f"BLEU Score: {bleu:.4f}")

# Show some examples: first prediction and up to two references per image
num_examples = 3
print("\nExample predictions:")
for img_id, preds in list(pred_result.items())[:num_examples]:
    print(f"\nImage {img_id}:")
    print(f"Predicted: {preds[0]}")
    # Membership test (not indexing) so the defaultdict gains no empty entry
    if img_id in valid_result:
        print(f"Ground truth: {valid_result[img_id][:2]}")


In [None]:
# Bundle both sets of weights with the hyperparameters passed to the
# EncoderCNN / DecoderRNN constructors (plus vocab_threshold), in one checkpoint.
final_checkpoint = {
    'encoder_state_dict': encoder.state_dict(),
    'decoder_state_dict': decoder.state_dict(),
    'embed_size': embed_size,
    'hidden_size': hidden_size,
    'vocab_size': vocab_size,
    'vocab_threshold': vocab_threshold,
}
torch.save(final_checkpoint, './models/final_model.pth')

print("Final model saved as 'final_model.pth'")
print(f"Model achieved BLEU score of {bleu:.4f} on validation set")
