## Overview
This notebook trains a Transformer-based image captioning model with a multi-phase training approach using CLIP score optimization. This notebook contains the same code as the seperately uploaded project, with alterations to allow it to run on Google Colab. The model learns to generate descriptive captions for images through a combination of cross-entropy loss and CLIP-based policy gradient loss. I reduced the amount of images provided in the folder to be placed in the Google Drive to 10K images instead of the full 140K images I used. This is to reduce the size of the files to be handed in, while still allowing this notebook to be tested. The results of training on this small subset of the larger dataset is not representative of the results achieved on the full 140K image-caption pairs. The provided dataset serves merely to showcase the functionality of the code.

## Setup Instructions

### Prerequisites
- Google Colab with GPU runtime enabled
- Google Drive for storing datasets and models

### First-Time Setup

1. **Upload the provided zip file** to your Google Drive.
   - The zip file contains necessary images and caption data
   - Place it in the path specified in the notebook (default: `/content/drive/MyDrive/ImageCaptionColab/)
   - Unzip so that your drive contains the unzipped folder called ImageCaptionColab

2. **Run the empty/broken image finder** (third codeblock of the notebook) before running the main pipeline
   - This will identify and remove any corrupt images that might cause training issues
   - Adjust the image directory paths in this cell if needed

3. **Check configuration variables** at the beginning of the notebook:
   - `BASE_PATH`: Base directory for all files
   - `METRICS_PATH`: Path to metric calculation files
   - `ZIP_PATH`: Path to the zipped images
   - `IMAGE_FOLDER`: Where extracted images will be stored
   - `CAPTIONS_FILE`: Path to captions TSV file
   - `MODEL_PATH`: Path for loading pre-trained model (if any)
   - `OUTPUT_PATH`: Where to save the fine-tuned model

### Running the Training Pipeline

After placing the correct folder in the Google Drive you can run the cells in order from top to bottom:

1. **Copy metric files** and set up the environment
   - The notebook will copy necessary metric calculation files
   - Required Python packages will be installed

2. **Extract images from zip**
   - Images will be extracted to the specified folder
   - By default, it extracts up to 10,000 images for training

3. **Feature extraction and data preparation**
   - Image features are extracted using ResNet18
   - Features are cached to disk to speed up future runs
   - Vocabulary is built from captions

4. **Multi-phase training**
   - **Phase 1**: Training with cross-entropy loss only
   - **Phase 2**: Combined cross-entropy and CLIP-based loss
   - **Phase 3**: Fine-tuning with high CLIP loss weight

5. **Evaluation**
   - The model is evaluated using CLIP score
   - Training curves and sample captions are provided

## Training Parameters

The training process can be customized by modifying the `base_config` dictionary:

- `max_images`: Maximum number of images to use (default: 10,000)
- `batch_size`: Batch size for training (default: 64)
- `embed_size`, `hidden_size`, `num_layers`: Model architecture parameters
- `learning_rate`: Initial learning rate
- `num_epochs`: Number of epochs for each training phase
- `clip_loss_weight`: Weight for the CLIP-based loss component

## Output

After training completes, the following files are generated:

- Saved models for each training phase
- The best model based on validation loss and CLIP score
- Training curves showing loss and CLIP score progression
- Evaluation metrics and sample generated captions

In [None]:
# Part 0: Copy metric files to ensure they can be imported
import os
import shutil
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive', force_remount=True)

# Paths
BASE_PATH = "/content/drive/MyDrive"  # Update this path if needed
METRICS_PATH = f"{BASE_PATH}/ImageCaptionColab/TrainMetric"  # folder containing the metric py files
SOURCE_DIR = f"{BASE_PATH}/ImageCaptionColab/TrainMetric"
DEST_DIR = "/content"  # current working directory

# list of metric files to copy (if they're not in a metrics folder)
metric_files = [
    "base.py",
    "clip_scorer.py",
    "grammar_checker.py",
    "clip.py",
    "metric.py"
]

# copy each file to current directory
print("copying metric files to current directory...")
for filename in metric_files:
    source_path = os.path.join(SOURCE_DIR, filename)
    dest_path = os.path.join(DEST_DIR, filename)

    if os.path.exists(source_path):
        shutil.copy2(source_path, dest_path)
        print(f"  Copied {filename} ✓")
    else:
        print(f"  {filename} not found in source directory ✗")

!pip install language_tool_python transformers Pillow torch torchvision matplotlib pandas tqdm

In [None]:
# part 1: setup and import metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from collections import Counter
import math
import sys
import zipfile
import os
import shutil

# change the path to the metrics folder,
sys.path.append(METRICS_PATH)  # if they're in a folder in my google drive

# check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"using device: {device}")

# paths for the project
CAPTIONS_FILE = f"{BASE_PATH}/ImageCaptionColab/filtered_captions.tsv" # all captions for images that could be extracted (removed those for broken urls)
MODEL_PATH = f"{BASE_PATH}/best_model.pt"  # path to pre-trained model
OUTPUT_PATH = f"{BASE_PATH}/finetuned_model.pt"  # where to save fine-tuned model
FILTERED_CAPTIONS_FILE = f"{BASE_PATH}/filtered_captions_matched.tsv"


# Define your paths
BASE_PATH = "/content/drive/MyDrive"  # Update to your actual base path
ZIP_PATH = f"{BASE_PATH}/ImageCaptionColab/loadedimages.zip"
IMAGE_FOLDER = f"{BASE_PATH}/ImageCaptionColab/images/images"  # final destination folder

# Create directories
os.makedirs(IMAGE_FOLDER, exist_ok=True)

# Function to extract and flatten the directory structure
def extract_images(zip_path, dest_folder, max_images=10094):
    """
    extract images from ZIP file and place them directly in the destination folder,
    ignoring any folder structure in the ZIP
    """
    print(f"Extracting up to {max_images} images from {zip_path} to {dest_folder}...")

    # make temporary extraction folder
    temp_folder = os.path.join(os.path.dirname(dest_folder), "temp_extract")
    os.makedirs(temp_folder, exist_ok=True)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        # Get list of image files in the zip
        all_files = [f for f in zip_ref.namelist() if f.lower().endswith(('.jpg', '.jpeg', '.png'))]

        # Limit to max_images
        if len(all_files) > max_images:
            print(f"ZIP contains {len(all_files)} images, limiting extraction to {max_images}")
            # Shuffle the list to get a random sample
            random.seed(42)  # For reproducibility
            random.shuffle(all_files)
            file_list = all_files[:max_images]
        else:
            file_list = all_files

        # Extract the selected images to temp folder
        for file in tqdm(file_list, desc="extracting to temp folder"):
            zip_ref.extract(file, temp_folder)

        # move files from temp folder to destination, flattening directory structure
        # quick fix for an issue I ran into with images being placed in a folder in the folder
        count = 0
        skipped = 0
        for root, _, files in os.walk(temp_folder):
            for file in files:
                if file.lower().endswith(('.jpg', '.jpeg', '.png')):
                    # Get original filename without path
                    src_path = os.path.join(root, file)
                    # Use just the base filename for destination
                    dest_path = os.path.join(dest_folder, file)

                    # If file already exists, skip it
                    if os.path.exists(dest_path):
                        skipped += 1
                        continue

                    shutil.copy2(src_path, dest_path)
                    count += 1

        # remove temporary folder
        shutil.rmtree(temp_folder)

        print(f"extracted and flattened {count} images to {dest_folder}")
        return count

# Check if extraction is needed
if len([f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]) < 10000:
    # clear the destination folder first to avoid mixed content
    for file in os.listdir(IMAGE_FOLDER):
        file_path = os.path.join(IMAGE_FOLDER, file)
        if os.path.isfile(file_path):
            os.remove(file_path)


    extract_images(ZIP_PATH, IMAGE_FOLDER, MAX_EXTRACT_IMAGES)
else:
    print(f"images already extracted in {IMAGE_FOLDER}")
    print(f"found {len([f for f in os.listdir(IMAGE_FOLDER) if f.lower().endswith(('.jpg', '.jpeg', '.png'))])} image files")


def filter_captions_file(captions_file, image_folder, output_file=None):
    """
    make sure the captions file only has entries for images that actually exist
    """
    import os
    import pandas as pd
    from tqdm.notebook import tqdm

    print(f"getting captions from: {captions_file}")
    # read the captions file
    df = pd.read_csv(captions_file, sep='\t', header=None, names=['image_id', 'caption'])
    original_count = len(df)
    print(f"original captions count: {original_count}")

    # list of available images (directly in the image folder, no nested structure this time grr)
    print(f"checking available images in: {image_folder}")
    available_images = set()
    for file in os.listdir(image_folder):
        if file.lower().endswith(('.jpg', '.jpeg', '.png')):
            # both with and without extension
            base_name = os.path.splitext(file)[0]
            available_images.add(base_name)
            available_images.add(file)

    print(f"found {len(available_images)} available images")

    # filter captions to only include those for available images
    print("filtering captions...")
    filtered_df = df[df['image_id'].astype(str).isin(available_images)]
    filtered_count = len(filtered_df)
    print(f"nr of filtered captions: {filtered_count}")
    print(f"removed {original_count - filtered_count} captions for missing images")

    if output_file:
        print(f"saving filtered captions to: {output_file}...")
        filtered_df.to_csv(output_file, sep='\t', header=False, index=False)
        print(f"saved filtered captions file with {filtered_count} entries!")

    return filtered_df

# alternative version that combines train and val datasets
def filter_to_all_training_images(reference_captions, generated_captions, image_paths,
                                train_loader, val_loader):
    """
    filter evaluation images to match exactly those used in both training and validation
    """
    # extract image IDs from both datasets
    train_image_ids = set(train_loader.dataset.image_names)
    val_image_ids = set(val_loader.dataset.image_names)
    all_used_image_ids = train_image_ids.union(val_image_ids)

    print(f"found {len(train_image_ids)} training and {len(val_image_ids)} validation images")
    print(f"total of {len(all_used_image_ids)} unique images used in training/validation")

    # filter dictionaries to only include these IDs
    filtered_reference = {img_id: captions for img_id, captions in reference_captions.items()
                         if img_id in all_used_image_ids}
    filtered_generated = {img_id: caption for img_id, caption in generated_captions.items()
                         if img_id in all_used_image_ids}
    filtered_paths = {img_id: path for img_id, path in image_paths.items()
                     if img_id in all_used_image_ids}

    print(f"filtered to {len(filtered_reference)} reference images")
    print(f"filtered to {len(filtered_generated)} generated images")
    print(f"filtered to {len(filtered_paths)} image paths")

    return filtered_reference, filtered_generated, filtered_paths

In [None]:
#CAN SKIP THIS CELL WHEN NO NEW IMAGES WERE EXTRACTED FROM
# loadedimages.zip


# empty/broken image finder
# minimal checks, focuses on file size
# use once at the beginning (after extracting the desired amount of images from the zip file)
# to ensure all images actually exist

import os
import shutil
import time
import concurrent.futures
from tqdm.notebook import tqdm
from google.colab import drive

# try to mount drive
try:
    drive.mount('/content/drive')
    print("Google Drive mounted successfully!")
except:
    print("drive already mounted or could not be mounted")

# Configuration - ADJUST THESE
IMAGE_DIR = f"{BASE_PATH}/ImageCaptionColab/images/images"  # image directory
OUTPUT_DIR = f"{BASE_PATH}/ImageCaptionColab/broken_images"  # where to move broken images
EXTENSIONS = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp'] # really only have one of them, but you never know
MIN_SIZE_BYTES = 1000  # any file smaller than this is considered empty/broken
MAX_WORKERS = 4  # can be higher since we're doing very simple checks
BATCH_SIZE = 100  # can be much larger since we're not loading images
DELETE_BROKEN = False  # set to True to delete broken images instead of moving them

# function to check if a file is empty or too small
def is_broken_file(file_path, min_size=MIN_SIZE_BYTES):
    """
    check if a file is too small to be a valid image
    """
    try:
        # get file size
        size = os.path.getsize(file_path)
        if size < min_size:
            return True, f"file too small: {size} bytes"

        # optional quick header check for slightly larger files
        # I don't think this code was ever actually needed
        if size < 5000:  # only check headers for suspicious small files
            try:
                with open(file_path, 'rb') as f:
                    header = f.read(16)

                # common image file headers
                valid_headers = [
                    b'\xff\xd8\xff',  # JPEG
                    b'\x89PNG\r\n\x1a\n',  # PNG
                    b'GIF8',  # GIF
                    b'BM',  # BMP
                    b'RIFF'  # WEBP
                ]

                # check if header matches any valid image header
                if not any(header.startswith(h) for h in valid_headers):
                    return True, "invalid image header"
            except:
                # if we can't read the file, consider it broken
                return True, "can't read file"

        return False, "OK"
    except Exception as e:
        return True, str(e)

# function to process a batch of files
def process_batch(files):
    """
    process a batch of files, return list of broken files
    """
    broken = []
    for file_path in files:
        is_broken, reason = is_broken_file(file_path)
        if is_broken:
            broken.append((file_path, reason))
    return broken

# main function to find and remove empty/broken images
def find_and_remove_empty_files():
    """
    find and remove empty or broken image files
    """
    start_time = time.time()

    # make sure output directory exists if we're moving files
    if not DELETE_BROKEN and not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)

    # get all image files
    print(f"scanning directory: {IMAGE_DIR}")
    image_files = []

    for root, _, files in os.walk(IMAGE_DIR):
        for file in files:
            _, ext = os.path.splitext(file)
            if ext.lower() in EXTENSIONS:
                image_files.append(os.path.join(root, file))

    print(f"found {len(image_files)} image files to check")

    # split into batches for processing
    batches = [image_files[i:i+BATCH_SIZE] for i in range(0, len(image_files), BATCH_SIZE)]

    # process all batches
    all_broken_files = []
    files_processed = 0

    print(f"processing in {len(batches)} batches with {MAX_WORKERS} workers")

    for batch_idx, batch in enumerate(batches):
        batch_start = time.time()

        # process the batch
        with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
            # split into smaller sub-batches for each worker
            sub_batch_size = max(1, len(batch) // MAX_WORKERS)
            sub_batches = [batch[i:i+sub_batch_size] for i in range(0, len(batch), sub_batch_size)]

            # process sub-batches in parallel
            futures = [executor.submit(process_batch, sub_batch) for sub_batch in sub_batches]
            broken_batches = [future.result() for future in concurrent.futures.as_completed(futures)]

            # flatten results
            batch_broken = [item for sublist in broken_batches for item in sublist]

        # handle broken files
        for file_path, reason in batch_broken:
            all_broken_files.append((file_path, reason))

            # remove or move the file
            if DELETE_BROKEN:
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"Error deleting {file_path}: {e}")
            else:
                try:
                    filename = os.path.basename(file_path)
                    dest_path = os.path.join(OUTPUT_DIR, filename)

                    # handle filename problems
                    if os.path.exists(dest_path):
                        base, ext = os.path.splitext(filename)
                        dest_path = os.path.join(OUTPUT_DIR, f"{base}_{int(time.time()*1000) % 10000}{ext}")

                    shutil.move(file_path, dest_path)
                except Exception as e:
                    print(f"Error moving {file_path}: {e}")

        # update progress
        files_processed += len(batch)
        progress_percent = files_processed / len(image_files) * 100

        # calculate timing metrics
        batch_time = time.time() - batch_start
        files_per_second = len(batch) / max(0.001, batch_time)
        elapsed_time = time.time() - start_time

        # estimated time remaining
        remaining_files = len(image_files) - files_processed
        if files_per_second > 0:
            estimated_time_remaining = remaining_files / files_per_second
        else:
            estimated_time_remaining = 0

        # print progress
        print(f"batch {batch_idx+1}/{len(batches)} | "
              f"progress: {progress_percent:.1f}% | "
              f"found {len(batch_broken)} broken files in this batch | "
              f"speed: {files_per_second:.1f} files/sec")

    # make a little report
    report_path = os.path.join(OUTPUT_DIR if not DELETE_BROKEN else os.path.dirname(IMAGE_DIR),
                               "empty_files_report.txt")

    with open(report_path, 'w') as f:
        f.write("# empty/broken files report\n\n")
        f.write(f"total files processed: {len(image_files)}\n")
        f.write(f"empty/broken files found: {len(all_broken_files)}\n\n")
        f.write("## list of empty/broken files\n\n")

        for file_path, reason in all_broken_files:
            f.write(f"{file_path}\t{reason}\n")

    # Final stats
    total_time = time.time() - start_time
    print("\n" + "="*50)
    print("PROCESSING COMPLETE")
    print("="*50)
    print(f"processed {len(image_files)} files in {total_time:.1f} seconds ({total_time/60:.1f} minutes)")
    print(f"found {len(all_broken_files)} empty or broken files")
    print(f"average speed: {len(image_files)/total_time:.1f} files per second")
    print(f"report saved to: {report_path}")

    return all_broken_files

# execute the function
if __name__ == "__main__":
    print("starting empty/broken file finder...")
    broken_files = find_and_remove_empty_files()
    print("done! :)")

In [None]:
# part 2: model definition and datasets

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import random
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from collections import Counter
from transformers import CLIPProcessor, CLIPModel
import math
import language_tool_python

torch.serialization.add_safe_globals([
    np.core.multiarray.scalar,
    np.dtype,
    np.ndarray,
    np._globals
])

# check if GPU is available (again i guess)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"using: {device}")

class CaptionFeatureDataset(Dataset):
    def __init__(self, features_dict, captions_dict, word2idx, max_len=22, split='train'):
        """
        dataset for training/validation (now with extra safety checks).

        Args:
            features_dict: dictionary mapping image_id to image features
            captions_dict: dictionary mapping image_id to list of captions
            word2idx: Word to index mapping
            max_len: max caption length
            split: 'train' or 'val'
        """
        # first, ensure data consistency - only use images with both features and captions (i'm finally learning from my mistakes)
        common_ids = set(features_dict.keys()).intersection(set(captions_dict.keys()))
        print(f"Dataset '{split}': {len(common_ids)} valid images (from {len(features_dict)} features, {len(captions_dict)} caption sets)")

        if len(common_ids) == 0:
            raise ValueError("no valid images found with both features and captions! :(")

        # get all valid image IDs
        all_ids = list(common_ids)
        random.seed(42)  # for reproducibility
        random.shuffle(all_ids)

        # split into train/val (90%/10%), probably should have made this an argument as well
        if split == 'train':
            self.image_names = all_ids[:int(0.9 * len(all_ids))]
        else:
            self.image_names = all_ids[int(0.9 * len(all_ids)):]

        self.features_dict = features_dict
        self.captions_dict = captions_dict
        self.word2idx = word2idx
        self.max_len = max_len

        print(f"created {split} dataset with {len(self.image_names)} images")

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]

        # double safety check (should never fail due to the initialization, but i'm paranoid now)
        if img_name not in self.features_dict:
            raise KeyError(f"image {img_name} not found in features dictionary!")
        if img_name not in self.captions_dict:
            raise KeyError(f"image {img_name} not found in captions dictionary!")

        feature = self.features_dict[img_name]

        # for training, randomly choose one caption
        caption = random.choice(self.captions_dict[img_name]).split()

        # add < SOS > and <EOS> tokens
        caption = ["< SOS >"] + caption + ["<EOS>"]
        tokens = [self.word2idx.get(w, self.word2idx["<UNK>"]) for w in caption]
        tokens = tokens[:self.max_len]
        tokens += [self.word2idx["<PAD>"]] * (self.max_len - len(tokens))

        return feature, torch.tensor(tokens, dtype=torch.long)


class CaptionEvaluationDataset(Dataset):
    """
    dataset for evaluation with safety checks
    """
    def __init__(self, features_dict):
        self.image_names = list(features_dict.keys())
        self.features_dict = features_dict
        print(f"created evaluation dataset with {len(self.image_names)} images")

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        img_name = self.image_names[idx]
        feature = self.features_dict[img_name]
        return feature, img_name

# using Mirha Sidheek's Transformer Decoder model implementation
class Transformer_Decoder(nn.Module):
    def __init__(self, embed_size, vocab_size, hidden_size, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.pos_encoding = nn.Parameter(torch.zeros(1, 22, embed_size))  # max_len=22
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=embed_size,
            nhead=16,
            dim_feedforward=hidden_size,
            dropout=0.2,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(embed_size, vocab_size)

    def forward(self, features, captions, mask=None):
        batch_size = captions.size(0)
        seq_len = captions.size(1)

        # Embed captions and add positional encoding
        embedded = self.embedding(captions) + self.pos_encoding[:, :seq_len, :]

        # Create causal mask for autoregressive decoding if not provided
        if mask is None:
            mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(features.device)

        # Features as memory (B, 1, embed_size) for cross-attention
        memory = features.unsqueeze(1)  # Add sequence dimension

        # Decode
        output = self.decoder(tgt=embedded, memory=memory, tgt_mask=mask)
        return self.fc(output)

# Beam search caption generation
def beam_search_caption(image_feature, decoder, project_features, word2idx, idx2word, device,
                        beam_width=5, max_len=22):
    """
    Generate a caption using beam search.

    Args:
        image_feature: Image feature tensor
        decoder: The decoder model
        project_features: The projection layer
        word2idx: Word to index mapping
        idx2word: Index to word mapping
        device: Device to run on
        beam_width: Beam width
        max_len: Maximum caption length

    Returns:
        The generated caption
    """
    decoder.eval()
    project_features.eval()

    # Project the image features
    with torch.no_grad():
        projected = project_features(image_feature)

    # Beam search
    sequences = [[[word2idx["< SOS >"]], 0.0]]  # (sequence, log_prob)
    completed = []

    for _ in range(max_len):
        candidates = []
        for seq, score in sequences:
            if seq[-1] == word2idx["<EOS>"]:
                completed.append((seq, score))
                continue

            input_seq = torch.tensor([seq], dtype=torch.long).to(device)
            with torch.no_grad():
                output = decoder(projected, input_seq)
                probs = torch.softmax(output[:, -1, :], dim=-1)
                topk = torch.topk(probs, beam_width)

            for i in range(beam_width):
                token = topk.indices[0, i].item()
                token_prob = topk.values[0, i].item()
                new_seq = seq + [token]
                new_score = score + torch.log(torch.tensor(token_prob + 1e-10)).item()
                candidates.append((new_seq, new_score))

        # Select top beam_width candidates
        sequences = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_width]
        if len(sequences) == 0:
            break

    # Include completed sequences
    all_sequences = sequences + completed
    if not all_sequences:
        return ""

    best_seq = sorted(all_sequences, key=lambda x: x[1], reverse=True)[0][0]

    # Convert to words and return caption
    caption_tokens = [idx2word[t] for t in best_seq if t not in [word2idx["<PAD>"], word2idx["< SOS >"], word2idx["<EOS>"]]]
    return " ".join(caption_tokens)

def sample_caption(image_feature, decoder, project_features, word2idx, idx2word, device,
                   max_len=22, temperature=1.0):
    """
    sample a caption using the current policy (the decoder), keeping track of log probabilities
    for policy gradient training

    Args:
        image_feature: image feature tensor
        decoder: decoder model
        project_features: projection layer
        word2idx: word to index mapping
        idx2word: index to word mapping
        device: device to run on
        max_len: max caption length
        temperature: temperature for sampling (higher = more diverse)

    Returns:
        tokens: list of token ids for the sampled caption
        log_probs: sum of log probabilities for the sampled tokens
        caption: caption
    """
    decoder.eval()  # set to eval mode initially
    project_features.eval()

    # project image features
    with torch.no_grad():
        projected = project_features(image_feature.unsqueeze(0))

    # switch to train mode for sampling
    decoder.train()

    # start with SOS token
    tokens = [word2idx["< SOS >"]]
    log_probs_list = []

    # generate tokens one by one
    for i in range(max_len - 1):  # -1 because we've already added SOS
        # prepare current sequence
        curr_seq = torch.tensor([tokens], dtype=torch.long).to(device)

        # get predictions
        with torch.enable_grad():  # make sure we're tracking gradients
            outputs = decoder(projected, curr_seq)
            logits = outputs[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)

            # sample from the distribution
            dist = torch.distributions.Categorical(probs)
            next_token = dist.sample()
            log_prob = dist.log_prob(next_token)

            # add to tracking variables
            log_probs_list.append(log_prob)
            tokens.append(next_token.item())

            # stop if we encounter the EOS token
            if next_token.item() == word2idx["<EOS>"]:
                break

    # convert token IDs to words (excluding SOS, EOS, and PAD)
    caption_words = [idx2word[token] for token in tokens
                    if token not in [word2idx["<PAD>"], word2idx["< SOS >"], word2idx["<EOS>"]]]
    caption = " ".join(caption_words)

    # return the tokens, sum of log probabilities, and the caption text
    return tokens, torch.stack(log_probs_list).sum(), caption


In [None]:
# part 3: save and load model

def save_model(decoder, project_features, metrics, output_path, model_type="checkpoint"):
    """
    save model with comprehensive metadata

    Args:
        decoder: decoder model
        project_features: features projection layer
        metrics: dictionary of metrics to save with the model
        output_path: base path for saving
        model_type: type of model being saved (checkpoint, best_val, best_clip)
    """
    import os
    import torch
    import time
    from datetime import datetime

    # create a timestamped filename so i finally keep track properly
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if model_type == "checkpoint":
        filename = f"{output_path}_{model_type}_{timestamp}.pt"
    else:
        filename = f"{output_path}_{model_type}.pt"

    # prepare model save dictionary with metadata
    save_dict = {
        'decoder': decoder.state_dict(),
        'project_features': project_features.state_dict(),
        'metrics': metrics,
        'timestamp': timestamp,
        'type': model_type
    }

    # save model
    print(f"Saving {model_type} model to {filename}...")
    torch.save(save_dict, filename)

    # if this is a best model, also save a copy to the standard path, kinda wonky, but works for now
    if model_type in ["best_val", "best_clip"]:
        standard_path = f"{output_path}.pt"
        print(f"Also saving to standard path: {standard_path}")
        torch.save(save_dict, standard_path)

    return filename

def load_model(model_path, embed_size, vocab_size, hidden_size, num_layers, device):
    """
    load a saved model (now with proper error handling)
    """
    import os
    import torch

    print(f"Loading model from {model_path}...")

    try:
        # try loading the model with weights_only=False
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)

        # initialize the models
        decoder = Transformer_Decoder(embed_size, vocab_size, hidden_size, num_layers).to(device)
        project_features = nn.Linear(512, embed_size).to(device)

        # load the weights
        decoder.load_state_dict(checkpoint['decoder'])
        project_features.load_state_dict(checkpoint['project_features'])

        # extract metadata if available
        metadata = {}
        for key in ['metrics', 'timestamp', 'type']:
            if key in checkpoint:
                metadata[key] = checkpoint[key]

        print(f"model loaded successfully!")
        if 'type' in metadata:
            print(f"Model type: {metadata['type']}")
        if 'timestamp' in metadata:
            print(f"Saved on: {metadata['timestamp']}")

        return decoder, project_features, metadata

    except Exception as e:
        print(f"Error loading model: {e}")
        print("Initializing new model instead...")

        # initialize new models if there was none found
        decoder = Transformer_Decoder(embed_size, vocab_size, hidden_size, num_layers).to(device)
        project_features = nn.Linear(2048, embed_size).to(device)

        return decoder, project_features, {}

def find_best_model(base_path, model_type="best_clip"):
    """
    find the best saved model of a given type.

    Args:
        base_path: Base path where models are saved
        model_type: Type of model to find (best_val, best_clip)

    Returns:
        Path to the best model, or None if no model found
    """
    import os
    import glob

    # look for exact match first
    exact_path = f"{base_path}_{model_type}.pt"
    if os.path.exists(exact_path):
        return exact_path

    # look for timestamped versions
    pattern = f"{base_path}_{model_type}_*.pt"
    matches = glob.glob(pattern)

    if matches:
        # sort by modification time (most recent first)
        matches.sort(key=os.path.getmtime, reverse=True)
        return matches[0]

    # default to standard path
    standard_path = f"{base_path}.pt"
    if os.path.exists(standard_path):
        return standard_path

    return None

In [None]:
# part 4: learning rate schedule

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.1):
    """
    create a schedule with a learning rate that decreases following the values of the
    cosine function between the initial lr and 0, with a warmup period at the beginning

    Args:
        optimizer: optimizer for which to schedule the learning rate
        num_warmup_steps: nr of steps for the warmup phase
        num_training_steps: total nr of training steps
        min_lr_ratio: min learning rate ratio compared to the initial LR

    Returns:
        learning rate scheduler
    """
    def lr_lambda(current_step):
        # warmup phase
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))

        # cosine decay phase
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        decayed = (1 - min_lr_ratio) * cosine_decay + min_lr_ratio

        return decayed

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

def create_optimizer_and_scheduler(model_params, learning_rate, num_training_steps,
                                  warmup_ratio=0.1, weight_decay=0.01):
    """
    create an optimizer and learning rate scheduler with warmup.

    Args:
        model_params: parameters of the model to optimize
        learning_rate: max learning rate
        num_training_steps: total nr of training steps
        warmup_ratio: portion of training to use for warmup
        weight_decay: weight decay coefficient

    Returns:
        optimizer: AdamW optimizer
        scheduler: learning rate scheduler
    """
    # create optimizer with weight decay
    optimizer = torch.optim.AdamW(
        model_params,
        lr=learning_rate,
        weight_decay=weight_decay
    )

    # create scheduler with warmup
    num_warmup_steps = int(num_training_steps * warmup_ratio)
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    return optimizer, scheduler

In [None]:
# step 5: CLIP

def compute_clip_reward_loss(features, img_ids, decoder, project_features,
                         word2idx, idx2word, clip_calculator,
                         reference_captions, image_paths, device):
    """
    compute policy gradient loss using CLIP scores as rewards

    Args:
        features: batch of image features
        img_ids: image IDs corresponding to the features
        decoder: decoder model
        project_features: projection layer
        word2idx: word to index mapping
        idx2word: index to word mapping
        clip_calculator: CLIP calculator instance
        reference_captions: dictionary of reference captions
        image_paths: dictionary of image paths
        device: device to run on

    Returns:
        loss: policy gradient loss
        mean_clip_score: average CLIP score for the batch
    """
    # sample captions and collect log probabilities
    sampled_captions = {}
    log_probs = []
    valid_indices = []

    # generate captions with sampling to allow gradient flow
    for i, (feature, img_id) in enumerate(zip(features, img_ids)):
        img_id = str(img_id)

        # skip images without reference captions or paths
        if img_id not in reference_captions or img_id not in image_paths:
            continue

        # sample a caption and get log probability
        try:
            _, log_prob, caption = sample_caption(
                feature, decoder, project_features,
                word2idx, idx2word, device
            )

            sampled_captions[img_id] = caption
            log_probs.append(log_prob)
            valid_indices.append(i)
        except Exception as e:
            print(f"error sampling caption: {e}")
            continue

    # if no valid captions, return zero loss
    if not sampled_captions:
        return torch.tensor(0.0, device=device), 0.0

    # prepare data for CLIP evaluation
    batch_references = {img_id: reference_captions[img_id]
                       for img_id in sampled_captions.keys()}
    batch_paths = {img_id: image_paths[img_id]
                  for img_id in sampled_captions.keys()}

    # calculate CLIP scores
    try:
        clip_scores = clip_calculator.compute(
            batch_references, sampled_captions, batch_paths
        )

        rewards = clip_scores

        # convert to tensor
        rewards = torch.tensor(rewards, device=device)
    except Exception as e:
        print(f"error calculating CLIP score: {e}")
        return torch.tensor(0.0, device=device), 0.0

    # calculate policy gradient loss
    # higher CLIP score ( better CLIP score) = Lower loss (hence the negative sign)
    policy_loss = -rewards * torch.stack(log_probs).mean()

    return policy_loss, clip_scores

def evaluate_model_with_clip_score(decoder, project_features, eval_loader, word2idx, idx2word,
                                  clip_calculator, tsv_path, image_dir,
                                  train_loader, val_loader, max_eval_images=100, cached_refs=None, cached_paths=None):
    """
    evaluate model using CLIP score (now with optimization for larger datasets).

    Args:
        decoder: decoder model
        project_features: projection model
        eval_loader: dataloader for evaluation
        word2idx: word to index mapping
        idx2word: index to word mapping
        clip_calculator: CLIPCalculator instance
        tsv_path: path to TSV file with reference captions
        image_dir: directory with images
        train_loader, val_loader: training/validation data loaders
        max_eval_images: maximum number of images to evaluate (for speed)
        cached_refs, cached_paths: optional cached references and paths

    Returns:
        clip_score: CLIP score
        generated_captions: dict of generated captions
    """
    import torch
    import random
    from tqdm.notebook import tqdm

    decoder.eval()
    project_features.eval()

    # get training and validation image IDs to ensure we only use images from our datasets
    train_image_ids = set(train_loader.dataset.image_names)
    val_image_ids = set(val_loader.dataset.image_names)
    all_used_image_ids = train_image_ids.union(val_image_ids)

    # load only necessary references (either from cache or from file)
    if cached_refs is not None and cached_paths is not None:
        reference_captions = cached_refs
        image_paths = cached_paths
        print(f"using cached references with {len(reference_captions)} images")
    else:
        # load references for all training/validation images
        print(f"loading reference captions for {len(all_used_image_ids)} training/validation images...")
        reference_captions, image_paths = load_references(tsv_path, image_dir, filter_ids=all_used_image_ids)
        print(f"loaded {len(reference_captions)} valid reference images with captions")

    # find images that have both features and reference captions
    valid_image_ids = set(reference_captions.keys()).intersection(all_used_image_ids)
    print(f"found {len(valid_image_ids)} images with both features and reference captions")

    # limit to max_eval_images if needed (for faster evaluation)
    if len(valid_image_ids) > max_eval_images:
        print(f"limiting CLIP evaluation to {max_eval_images} random images (out of {len(valid_image_ids)})")
        random.seed(42)  # for reproducibility
        eval_image_ids = set(random.sample(list(valid_image_ids), max_eval_images))
    else:
        eval_image_ids = valid_image_ids
        print(f"evalauting all {len(eval_image_ids)} valid images with CLIP")

    # generate captions for each image in the evaluation set
    generated_captions = {}

    with torch.no_grad():
        for features, img_ids in tqdm(eval_loader, desc="generating captions"):
            features = features.to(device)

            for i, (feature, img_id) in enumerate(zip(features, img_ids)):
                img_id = str(img_id.item() if hasattr(img_id, 'item') else img_id)

                # skip if not in our evaluation set
                if img_id not in eval_image_ids:
                    continue

                # double-check that we have a reference caption and image path
                if img_id not in reference_captions or img_id not in image_paths:
                    continue

                img_feature = feature.unsqueeze(0)

                # generate caption with beam search
                caption = beam_search_caption(
                    img_feature, decoder, project_features,
                    word2idx, idx2word, device, beam_width=5, max_len=22
                )

                generated_captions[img_id] = caption

    print(f"generated {len(generated_captions)} captions")

    # final filtering to ensure all IDs are present in all dictionaries once and for all!! :(
    common_ids = set(generated_captions.keys()).intersection(
                    set(reference_captions.keys())).intersection(
                    set(image_paths.keys()))

    filtered_generated = {img_id: generated_captions[img_id] for img_id in common_ids}
    filtered_reference = {img_id: reference_captions[img_id] for img_id in common_ids}
    filtered_paths = {img_id: image_paths[img_id] for img_id in common_ids}

    print(f"final evaluation set: {len(common_ids)} images with all required data")

    if len(common_ids) == 0:
        print("ERROR: no images have both generated captions, reference captions, and image paths!")
        return 0.0, {}

    # calculate CLIP scores
    print("computing CLIP scores...")
    generated_score = clip_calculator.compute(filtered_reference, filtered_generated, filtered_paths)
    reference_score = clip_calculator.compute_reference_clip(filtered_reference, filtered_paths)

    print(f"CLIP scores - generated: {generated_score:.4f}, reference: {reference_score:.4f}")

    return generated_score, generated_captions


In [None]:
# part 6: training

def train_in_phases(train_loader, val_loader, word2idx, idx2word, config):
    """
    train a model in multiple phases with different objectives

    Args:
        train_loader: dataloader for training
        val_loader: dataloader for validation
        word2idx, idx2word: vocabulary mappings
        config: base configuration for training

    Returns:
        decoder, project_features: trained models
        histories: list of training histories for each phase
    """
    import os
    import time
    from datetime import datetime

    # make a unique experiment directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_dir = os.path.join(BASE_PATH, f"training_experiment_{timestamp}")
    os.makedirs(experiment_dir, exist_ok=True)

    print(f"Starting multi-phase training in {experiment_dir}")

    # save configuration
    with open(os.path.join(experiment_dir, "config.txt"), "w") as f:
        for key, value in config.items():
            f.write(f"{key}: {value}\n")

    # record histories for each phase
    histories = []

    # phase 1: initial training with cross-entropy loss only
    print("\n" + "="*50)
    print("PHASE 1: Cross-Entropy Training")
    print("="*50)

    phase1_config = config.copy()
    phase1_config.update({
        'num_epochs': 10,
        'clip_loss_weight': 0.0,  # no CLIP loss
        'learning_rate': 0.0003,
        'output_path': os.path.join(experiment_dir, "phase1_model")
    })

    # train with cross-entropy only
    decoder, project_features, history1 = train_model_enhanced(
        train_loader=train_loader,
        val_loader=val_loader,
        word2idx=word2idx,
        idx2word=idx2word,
        embed_size=phase1_config['embed_size'],
        hidden_size=phase1_config['hidden_size'],
        num_layers=phase1_config['num_layers'],
        learning_rate=phase1_config['learning_rate'],
        num_epochs=phase1_config['num_epochs'],
        early_stopping_patience=phase1_config['early_stopping_patience'],
        checkpoint_frequency=phase1_config['checkpoint_frequency'],
        model_path=None,  # Start from scratch
        output_path=phase1_config['output_path'],
        feature_dim=phase1_config['feature_dim'],
        clip_loss_weight=phase1_config['clip_loss_weight'],
        clip_batch_size=phase1_config['clip_batch_size'],
        clip_eval_frequency=phase1_config['clip_eval_frequency']
    )

    histories.append(history1)

    # save phase 1 training curves
    plot_training_curves(history1)
    plt.savefig(os.path.join(experiment_dir, "phase1_curves.png"))
    plt.close()

    # find best model from phase 1
    phase1_best_model = find_best_model(phase1_config['output_path'], "best_val")

    # phase 2: balanced training with both losses
    print("\n" + "="*50)
    print("PHASE 2: Combined Cross-Entropy and CLIP Training")
    print("="*50)

    phase2_config = config.copy()
    phase2_config.update({
        'num_epochs': 10,
        'clip_loss_weight': 0.3,  # moderate CLIP influence
        'learning_rate': 0.0001,  # lower learning rate for fine-tuning
        'output_path': os.path.join(experiment_dir, "phase2_model")
    })

    # train with combined loss
    decoder, project_features, history2 = train_model_enhanced(
        train_loader=train_loader,
        val_loader=val_loader,
        word2idx=word2idx,
        idx2word=idx2word,
        embed_size=phase2_config['embed_size'],
        hidden_size=phase2_config['hidden_size'],
        num_layers=phase2_config['num_layers'],
        learning_rate=phase2_config['learning_rate'],
        num_epochs=phase2_config['num_epochs'],
        early_stopping_patience=phase2_config['early_stopping_patience'],
        checkpoint_frequency=phase2_config['checkpoint_frequency'],
        model_path=phase1_best_model,  # Continue from phase 1
        output_path=phase2_config['output_path'],
        feature_dim=phase2_config['feature_dim'],
        clip_loss_weight=phase2_config['clip_loss_weight'],
        clip_batch_size=phase2_config['clip_batch_size'],
        clip_eval_frequency=phase2_config['clip_eval_frequency']
    )

    histories.append(history2)

    # save phase 2 training curves
    plot_training_curves(history2)
    plt.savefig(os.path.join(experiment_dir, "phase2_curves.png"))
    plt.close()

    # find best model from phase 2
    phase2_best_model = find_best_model(phase2_config['output_path'], "best_clip")

    # phase 3: CLIP fine-tuning
    print("\n" + "="*50)
    print("PHASE 3: CLIP-Only Fine-Tuning")
    print("="*50)

    phase3_config = config.copy()
    phase3_config.update({
        'num_epochs': 8,
        'clip_loss_weight': 0.9,  # high CLIP influence
        'learning_rate': 5e-5,  # very low learning rate for fine-tuning
        'output_path': os.path.join(experiment_dir, "phase3_model")
    })

    # train with high CLIP weight
    decoder, project_features, history3 = train_model_enhanced(
        train_loader=train_loader,
        val_loader=val_loader,
        word2idx=word2idx,
        idx2word=idx2word,
        embed_size=phase3_config['embed_size'],
        hidden_size=phase3_config['hidden_size'],
        num_layers=phase3_config['num_layers'],
        learning_rate=phase3_config['learning_rate'],
        num_epochs=phase3_config['num_epochs'],
        early_stopping_patience=phase3_config['early_stopping_patience'],
        checkpoint_frequency=phase3_config['checkpoint_frequency'],
        model_path=phase2_best_model,  # Continue from phase 2
        output_path=phase3_config['output_path'],
        feature_dim=phase3_config['feature_dim'],
        clip_loss_weight=phase3_config['clip_loss_weight'],
        clip_batch_size=phase3_config['clip_batch_size'],
        clip_eval_frequency=phase3_config['clip_eval_frequency']
    )

    histories.append(history3)

    # save phase 3 training curves
    plot_training_curves(history3)
    plt.savefig(os.path.join(experiment_dir, "phase3_curves.png"))
    plt.close()

    # find the final best model
    final_best_model = find_best_model(phase3_config['output_path'], "best_clip")

    # if no phase 3 model is better, use phase 2's best
    if not final_best_model:
        final_best_model = phase2_best_model

    # load the best overall model
    decoder, project_features, _ = load_model(
        final_best_model,
        phase3_config['embed_size'],
        len(word2idx),
        phase3_config['hidden_size'],
        phase3_config['num_layers'],
        device
    )

    # final evaluation with CLIP score
    print("\n" + "="*50)
    print("FINAL EVALUATION!")
    print("="*50)

    eval_dataset = CaptionEvaluationDataset(val_loader.dataset.features_dict)
    eval_loader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=2)

    clip_calculator = CLIPCalculator()
    final_clip_score, _ = evaluate_model_with_clip_score(
        decoder, project_features, eval_loader, word2idx, idx2word,
        clip_calculator, CAPTIONS_FILE, IMAGE_FOLDER,
        train_loader, val_loader, max_eval_images=100
    )

    print(f"final CLIP Score: {final_clip_score:.4f}")

    # create and save a final comparison plot
    plot_training_phases_comparison(histories, experiment_dir)

    return decoder, project_features, histories

def plot_training_phases_comparison(histories, output_dir):
    """
    create a plot comparing metrics across training phases

    Args:
        histories: list of training histories for each phase
        output_dir: directory to save the plot
    """
    import matplotlib.pyplot as plt
    import numpy as np

    plt.figure(figsize=(15, 12))

    # plot 1: training loss across phases
    ax1 = plt.subplot(2, 2, 1)
    colors = ['b', 'g', 'r']

    for i, history in enumerate(histories):
        if 'train_loss' in history and history['train_loss']:
            # create epoch numbers for this phase
            epochs = np.arange(len(history['train_loss']))

            # offset epochs for phases after the first
            if i > 0:
                offset = sum(len(h['train_loss']) for h in histories[:i])
                epochs = epochs + offset

            ax1.plot(epochs, history['train_loss'], f'{colors[i]}-',
                    label=f'Phase {i+1} Training Loss')

    ax1.set_title('Training Loss Across Phases')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    # plot 2: validation loss across phases
    ax2 = plt.subplot(2, 2, 2)

    for i, history in enumerate(histories):
        if 'val_loss' in history and history['val_loss']:
            # create epoch numbers for this phase
            epochs = np.arange(len(history['val_loss']))

            # offset epochs for phases after the first
            if i > 0:
                offset = sum(len(h['val_loss']) for h in histories[:i])
                epochs = epochs + offset

            ax2.plot(epochs, history['val_loss'], f'{colors[i]}-',
                    label=f'Phase {i+1} Validation Loss')

    ax2.set_title('Validation Loss Across Phases')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Loss')
    ax2.legend()
    ax2.grid(True)

    # plot 3: CLIP scores across phases
    ax3 = plt.subplot(2, 2, 3)

    for i, history in enumerate(histories):
        if 'clip_scores' in history and history['clip_scores'] and 'eval_epochs' in history:
            # get evaluation epochs for this phase
            eval_epochs = history['eval_epochs']

            # Offset epochs for phases after the first
            if i > 0:
                offset = sum(len(h['train_loss']) for h in histories[:i])
                eval_epochs = [e + offset for e in eval_epochs]

            ax3.plot(eval_epochs, history['clip_scores'], f'{colors[i]}-o',
                    label=f'Phase {i+1} CLIP Score')

    ax3.set_title('CLIP Scores Across Phases')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('CLIP Score')
    ax3.legend()
    ax3.grid(True)

    # plot 4: learning rates across phases
    ax4 = plt.subplot(2, 2, 4)

    for i, history in enumerate(histories):
        if 'learning_rates' in history and history['learning_rates']:
            steps = np.arange(len(history['learning_rates']))


            if i > 0:
                offset = sum(len(h['learning_rates']) for h in histories[:i])
                steps = steps + offset

            ax4.plot(steps, history['learning_rates'], f'{colors[i]}-',
                    label=f'Phase {i+1} Learning Rate')

    ax4.set_title('Learning Rate Schedule Across Phases')
    ax4.set_xlabel('Training Step')
    ax4.set_ylabel('Learning Rate')
    ax4.set_yscale('log')
    ax4.legend()
    ax4.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'all_phases_comparison.png'), dpi=150)
    plt.close()

In [None]:
# part 7: data preperation

def extract_and_cache_features_threadpool(image_folder, captions_dict, cache_path=None, force_reload=False, num_workers=4, batch_size=64):
    """
    extract image features using ResNet18 with ThreadPoolExecutor and cache them to disk.
    this makes sure we don't have to extract the image features every time we want to start training.
    and also its faster becaue i have to pay for colab aaa.

    tried to optimize it for Google Colab T4 GPUs.
    """
    import os
    import torch
    import torchvision.models as models
    import torchvision.transforms as transforms
    from PIL import Image
    from tqdm.notebook import tqdm
    import time
    import concurrent.futures
    import threading

    # set default cache path if not provided
    if cache_path is None:
        cache_path = os.path.join(BASE_PATH, "cached_features.pt")

    # check if cache exists and load it
    all_features_dict = {}
    if not force_reload and os.path.exists(cache_path):
        print(f"loading cached features from {cache_path}...")
        try:
            all_features_dict = torch.load(cache_path)
            print(f"loaded features for {len(all_features_dict)} images from cache")
        except Exception as e:
            print(f"error loading cache: {e} :( ")
            all_features_dict = {}

    # create a new features_dict that only contains images from captions_dict
    print("synchronizing feature and caption dictionaries...")
    features_dict = {}
    missing_images = []

    # first, add features from cache for images in captions_dict
    for img_id in captions_dict.keys():
        if img_id in all_features_dict:
            features_dict[img_id] = all_features_dict[img_id]
        else:
            missing_images.append(img_id)

    print(f"using {len(features_dict)} cached features")
    print(f"need to extract features for {len(missing_images)} more images")

    # if we already have all features we need, return them
    if not missing_images:
        print("All required features found in cache!")
        return features_dict

    # otherwise, we need to extract features for missing images
    # load resnet18 and remove final classification layer
    print("loading resnet 18 model...")
    resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
    resnet.eval().to(device)

    # image transformation
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
    ])

    def load_image(img_name):
        """load and preprocess a single image"""
        img_path = os.path.join(image_folder, f"{img_name}.jpg")
        if not os.path.exists(img_path):
            return None, None

        try:
            image = Image.open(img_path).convert("RGB")
            img_tensor = transform(image)
            return img_name, img_tensor
        except Exception as e:
            print(f"Error loading {img_name}: {e}")
            return None, None

    # start extraction timer
    print(f"extracting features for {len(missing_images)} images using {num_workers} threads and batch size {batch_size}...")
    start_time = time.time()

    # result dictionary and lock for thread safety
    result_lock = threading.Lock()

    # process images in batches
    extracted_count = 0
    with tqdm(total=len(missing_images)) as pbar:
        # process all images in batches
        for i in range(0, len(missing_images), batch_size):
            batch_names = missing_images[i:i + batch_size]

            # load images in parallel
            loaded_images = []
            with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
                for result in executor.map(load_image, batch_names):
                    if result[0] is not None:  # skip failed images
                        loaded_images.append(result)

            # skip empty batches
            if not loaded_images:
                pbar.update(len(batch_names))
                continue

            # separate names and tensors
            batch_img_names = [img[0] for img in loaded_images]
            batch_tensors = [img[1] for img in loaded_images]

            # process batch on GPU
            try:
                # stack tensors and move to GPU
                stacked_tensors = torch.stack(batch_tensors).to(device)

                # extract features in a single forward pass
                with torch.no_grad():
                    features = resnet(stacked_tensors).squeeze(-1).squeeze(-1).cpu()

                # save features to dictionanry
                with result_lock:
                    for j, img_name in enumerate(batch_img_names):
                        features_dict[img_name] = features[j]
                        all_features_dict[img_name] = features[j]  # update the full cache
                        extracted_count += 1

                    # save checkpoint every 1000 images (you live and you learn...)
                    if extracted_count % 1000 == 0:
                        print(f"\nsaving checkpoint after {extracted_count} images...")
                        torch.save(all_features_dict, cache_path + ".temp")
            except Exception as e:
                print(f"error processing batch: {e}")

            # update progress
            pbar.update(len(batch_names))

    # save the final features dictionary
    print(f"extracted features for {extracted_count} images in {time.time() - start_time:.2f} seconds")
    print(f"ttotal features in cache: {len(all_features_dict)}")
    print(f"features for this run: {len(features_dict)}")
    print(f"saving features to {cache_path}...")
    torch.save(all_features_dict, cache_path)

    # remove temporary file if it exists
    if os.path.exists(cache_path + ".temp"):
        os.remove(cache_path + ".temp")

    return features_dict


def prepare_data_with_cache(captions_file, image_folder, max_images=19000, cache_path=None, force_reload=False, num_workers=8, batch_size=64):
    """
    prepare data for training, with feature caching support this time.
    """
    print("loading captions...")
    # load and clean captions
    df = pd.read_csv(captions_file, sep="\t", names=["image", "caption"])
    df = df.dropna()
    df["image"] = df["image"].astype(str).str.strip()

    # remove .jpg extension if still there
    df["image"] = df["image"].apply(lambda x: x[:-4] if x.endswith('.jpg') else x)

    # build dictionary: image -> [captions]
    captions_dict = {}
    for _, row in df.iterrows():
        img_name = row["image"]
        caption = row["caption"]
        if img_name not in captions_dict:
            captions_dict[img_name] = []
        captions_dict[img_name].append(caption)

    # limit to max_images if needed
    if len(captions_dict) > max_images:
        print(f"limiting dataset from {len(captions_dict)} to {max_images} images")
        random.seed(42)
        subset_ids = random.sample(list(captions_dict.keys()), max_images)
        subset_captions_dict = {img_id: captions_dict[img_id] for img_id in subset_ids}
        captions_dict = subset_captions_dict

    print(f"using {len(captions_dict)} images with captions")

    # make vocabulary
    all_captions = sum(captions_dict.values(), [])
    words = [word for caption in all_captions for word in caption.split()]
    most_common = Counter(words).most_common(4900)

    # special tokens first
    vocab = ["<PAD>", "< SOS >", "<EOS>", "<UNK>"] + [w for w, _ in most_common]
    word2idx = {w: i for i, w in enumerate(vocab)}
    idx2word = {i: w for w, i in word2idx.items()}
    vocab_size = len(vocab)
    print(f"vocab size: {vocab_size}")

    # extract image features with caching
    print("extracting image features (with multithreaded caching)...")
    features_dict = extract_and_cache_features_threadpool(
        image_folder,
        captions_dict,
        cache_path=cache_path,
        force_reload=force_reload,
        num_workers=num_workers,
        batch_size=batch_size
    )


    caption_keys = set(captions_dict.keys())
    feature_keys = set(features_dict.keys())

    print(f"caption keys: {len(caption_keys)}, feature keys: {len(feature_keys)}")

    # find any discrepancies
    missing_captions = feature_keys - caption_keys
    missing_features = caption_keys - feature_keys

    if missing_captions:
        print(f"WARNING: {len(missing_captions)} images have features but no captions")
        # remove them from features_dict
        for img_id in missing_captions:
            if img_id in features_dict:
                del features_dict[img_id]

    if missing_features:
        print(f"WARNING: {len(missing_features)} images have captions but no features")
        # remove them from captions_dict
        for img_id in missing_features:
            if img_id in captions_dict:
                del captions_dict[img_id]

    # final check
    print(f"final dataset: {len(features_dict)} images with both features and captions")

    return features_dict, captions_dict, word2idx, idx2word

In [None]:


def train_model_enhanced(
        train_loader, val_loader, word2idx, idx2word,
        embed_size=256, hidden_size=512, num_layers=4,
        learning_rate=0.0003, num_epochs=20,
        early_stopping_patience=5, checkpoint_frequency=1,
        model_path=None, output_path=None, feature_dim=512,
        clip_loss_weight=0.5, clip_batch_size=16,
        clip_eval_frequency=50, warmup_ratio=0.1,
        weight_decay=0.01):
    """
    enhanced training function that integrates CLIP scores using policy gradients

    Args:
        train_loader: dataloader for training
        val_loader: dataloader for validation
        word2idx, idx2word: vocabulary mappings
        embed_size, hidden_size, num_layers: model architecture parameters
        learning_rate: learning rate
        num_epochs: nr of epochs
        early_stopping_patience: nr of epochs to wait before early stopping
        checkpoint_frequency: how often to save checkpoints (in epochs)
        model_path: path to load pre-trained model (if resuming training)
        output_path: where to save models
        feature_dim: dimension of image features (512 for ResNet18, 2048 for ResNet50)
        clip_loss_weight: weight for the CLIP-based loss component (0-1)
        clip_batch_size: nr of samples for CLIP evaluation in each batch
        clip_eval_frequency: how often to evaluate CLIP loss (in training steps)
        warmup_ratio: portion of training to use for warmup
        weight_decay: weight decay coefficient

    Returns:
        trained decoder and project_features models, training history
    """
    import time
    from datetime import datetime, timedelta
    from tqdm.notebook import tqdm

    # initialize training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'clip_scores': [],
        'clip_batch_scores': [],
        'learning_rates': [],
        'eval_epochs': [],
        'best_val_loss': float('inf'),
        'best_clip_score': 0.0,
        'epochs_without_improvement': 0,
        'start_time': time.time(),
        'total_training_time': 0
    }

    vocab_size = len(word2idx)

    # initialize CLIP calculator
    clip_calculator = CLIPCalculator()

    # cache reference captions and image paths
    print("loading and caching reference captions...")
    train_image_ids = set(train_loader.dataset.image_names)
    val_image_ids = set(val_loader.dataset.image_names)
    all_used_image_ids = train_image_ids.union(val_image_ids)
    cached_references, cached_image_paths = load_references(
        CAPTIONS_FILE, IMAGE_FOLDER, filter_ids=all_used_image_ids)
    print(f"cached {len(cached_references)} reference cations")

    # initialize or load modsels
    if model_path:
        try:
            decoder, project_features, metadata = load_model(
                model_path, embed_size, vocab_size, hidden_size, num_layers, device
            )

            # check if feature dimensions match
            if project_features.weight.size(1) != feature_dim:
                print(f"warning: model expects {project_features.weight.size(1)}-dim features, but we have {feature_dim}-dim features")
                print("re-initializing projection layer...")
                project_features = nn.Linear(feature_dim, embed_size).to(device)

            # Load training history if available
            if 'metrics' in metadata and isinstance(metadata['metrics'], dict):
                for key in history:
                    if key in metadata['metrics']:
                        history[key] = metadata['metrics'][key]
                print(f"continue training from epoch {len(history['train_loss'])+1}")
        except Exception as e:
            print(f"error loading model: {e}")
            print("initializing new model...")
            decoder = Transformer_Decoder(embed_size, vocab_size, hidden_size, num_layers).to(device)
            project_features = nn.Linear(feature_dim, embed_size).to(device)
    else:
        # initialize new models
        decoder = Transformer_Decoder(embed_size, vocab_size, hidden_size, num_layers).to(device)
        project_features = nn.Linear(feature_dim, embed_size).to(device)
        print("initialized new model")

    # create optimizer and scheduler
    total_steps = len(train_loader) * num_epochs
    optimizer, scheduler = create_optimizer_and_scheduler(
        list(decoder.parameters()) + list(project_features.parameters()),
        learning_rate=learning_rate,
        num_training_steps=total_steps,
        warmup_ratio=warmup_ratio,
        weight_decay=weight_decay
    )

    # cross-entropy loss for supervised training
    criterion = nn.CrossEntropyLoss(ignore_index=word2idx["<PAD>"])

    # evaluation dataset for CLIP scoring
    eval_dataset = CaptionEvaluationDataset(val_loader.dataset.features_dict)
    eval_loader = DataLoader(eval_dataset, batch_size=16, shuffle=False, num_workers=2)

    print(f"starting training with {num_epochs} epochs")
    print(f"feature dimension: {feature_dim}")
    print(f"architecture: {num_layers} layers, {hidden_size} hidden dim, {embed_size} embed dim")
    print(f"CLIP loss weight: {clip_loss_weight}")
    print(f"evaluating CLIP every {clip_eval_frequency} steps")

    start_epoch = len(history['train_loss'])
    global_step = 0

    for epoch in range(start_epoch, start_epoch + num_epochs):
        epoch_start_time = time.time()

        # training phase
        decoder.train()
        project_features.train()
        total_ce_loss = 0
        total_clip_loss = 0
        total_combined_loss = 0
        epoch_clip_scores = []
        num_clip_batches = 0

        progress_bar = tqdm(train_loader, desc=f"epoch {epoch+1}/{start_epoch+num_epochs} [Train]")
        for batch_idx, (features, captions) in enumerate(progress_bar):
            features, captions = features.to(device), captions.to(device)

            # standard cross-entropy loss
            optimizer.zero_grad()
            projected = project_features(features)
            output = decoder(projected, captions[:, :-1])
            ce_loss = criterion(output.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))

            # determine whether to compute CLIP reward loss
            compute_clip = (global_step % clip_eval_frequency == 0) and (clip_loss_weight > 0)
            clip_loss = torch.tensor(0.0, device=device)
            clip_score = 0.0

            # compute policy gradient loss with CLIP rewards (occasionally)
            if compute_clip:
                # select a subset of images for CLIP evaluation
                clip_eval_indices = random.sample(
                    range(len(features)), min(clip_batch_size, len(features)))

                clip_features = features[clip_eval_indices]
                img_ids = [train_loader.dataset.image_names[batch_idx * train_loader.batch_size + i]
                          for i in clip_eval_indices]

                # calculate policy gradient loss with CLIP as reward
                clip_loss, clip_score = compute_clip_reward_loss(
                    clip_features, img_ids, decoder, project_features,
                    word2idx, idx2word, clip_calculator,
                    cached_references, cached_image_paths, device
                )

                if clip_score > 0:
                    epoch_clip_scores.append(clip_score)
                    num_clip_batches += 1

            # combine losses
            if compute_clip and clip_score > 0:
                combined_loss = (1 - clip_loss_weight) * ce_loss + clip_loss_weight * clip_loss
                total_clip_loss += clip_loss.item()
            else:
                combined_loss = ce_loss

            # backprop and optimization
            combined_loss.backward()

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(project_features.parameters(), max_norm=1.0)

            # update weights
            optimizer.step()
            scheduler.step()

            # track current learning rate
            current_lr = scheduler.get_last_lr()[0]

            # update progress bar and tracking variables
            total_ce_loss += ce_loss.item()
            total_combined_loss += combined_loss.item()

            if compute_clip and clip_score > 0:
                progress_bar.set_postfix(
                    CE_Loss=f"{ce_loss.item():.4f}",
                    CLIP_Loss=f"{clip_loss.item():.4f}",
                    CLIP_Score=f"{clip_score:.4f}",
                    LR=f"{current_lr:.6f}"
                )
            else:
                progress_bar.set_postfix(
                    CE_Loss=f"{ce_loss.item():.4f}",
                    LR=f"{current_lr:.6f}"
                )

            global_step += 1

        # calculate average losses for the epoch
        avg_ce_loss = total_ce_loss / len(train_loader)
        avg_combined_loss = total_combined_loss / len(train_loader)

        # calculate average CLIP score if available
        avg_clip_score = sum(epoch_clip_scores) / num_clip_batches if num_clip_batches > 0 else 0.0

        # update the history
        history['train_loss'].append(avg_ce_loss)
        history['clip_batch_scores'].append(avg_clip_score)
        history['learning_rates'].append(current_lr)

        # validation phase
        decoder.eval()
        project_features.eval()
        total_val_loss = 0

        with torch.no_grad():
            for features, captions in tqdm(val_loader, desc=f"Epoch {epoch+1}/{start_epoch+num_epochs} [Val]"):
                features, captions = features.to(device), captions.to(device)
                projected = project_features(features)
                output = decoder(projected, captions[:, :-1])
                loss = criterion(output.reshape(-1, vocab_size), captions[:, 1:].reshape(-1))
                total_val_loss += loss.item()

        avg_val_loss = total_val_loss / len(val_loader)
        history['val_loss'].append(avg_val_loss)

        # calculate epoch time
        epoch_time = time.time() - epoch_start_time
        total_time = time.time() - history['start_time']
        history['total_training_time'] = total_time

        # print progress summary
        print(f"epoch {epoch+1}/{start_epoch+num_epochs} - "
              f"train CE Loss: {avg_ce_loss:.4f}, "
              f"val Loss: {avg_val_loss:.4f}, "
              f"CLIP Score: {avg_clip_score:.4f}, "
              f"time: {epoch_time:.1f}s, "
              f"total: {timedelta(seconds=int(total_time))}")

        # save checkpoint if needed
        if (epoch + 1) % checkpoint_frequency == 0:
            save_model(
                decoder, project_features, history, output_path,
                model_type=f"checkpoint"
            )

        # track best validation loss
        if avg_val_loss < history['best_val_loss']:
            history['best_val_loss'] = avg_val_loss
            history['epochs_without_improvement'] = 0
            save_model(
                decoder, project_features, history, output_path,
                model_type="best_val"
            )
            print(f"new best model saved based on validation loss: {avg_val_loss:.4f}")
        else:
            history['epochs_without_improvement'] += 1

        # Evaluate with CLIP score every few epochs or at the end
        evaluate_clip = ((epoch + 1) % 2 == 0) or (epoch == start_epoch + num_epochs - 1)

        if evaluate_clip:
            print("evaluating with CLIP score...")
            clip_score, generated_captions = evaluate_model_with_clip_score(
                decoder, project_features, eval_loader, word2idx, idx2word,
                clip_calculator, CAPTIONS_FILE, IMAGE_FOLDER,
                train_loader, val_loader, max_eval_images=100,
                cached_refs=cached_references, cached_paths=cached_image_paths
            )

            history['clip_scores'].append(clip_score)
            history['eval_epochs'].append(epoch)

            # Save model if CLIP score improved
            if clip_score > history['best_clip_score']:
                history['best_clip_score'] = clip_score
                save_model(
                    decoder, project_features, history, output_path,
                    model_type="best_clip"
                )
                print(f"new best model saved based on CLIP score: {clip_score:.4f}")

        # Sample some captions every few epochs
        if (epoch + 1) % 2 == 0:
            print("\nsample captions:")
            sample_features, _ = next(iter(val_loader))
            for i in range(min(3, len(sample_features))):
                feature = sample_features[i].unsqueeze(0).to(device)
                caption = beam_search_caption(
                    feature, decoder, project_features,
                    word2idx, idx2word, device, beam_width=3
                )
                print(f"sampel {i+1}: {caption}")
            print()

        # Early stopping check
        if history['epochs_without_improvement'] >= early_stopping_patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    # Plot training curves
    plot_training_curves(history)

    # Load the best model (using CLIP score)
    best_model_path = find_best_model(output_path, "best_clip")
    if best_model_path:
        print(f"loading best model from {best_model_path}")
        decoder, project_features, _ = load_model(
            best_model_path, embed_size, vocab_size, hidden_size, num_layers, device
        )

    return decoder, project_features, history

def plot_training_curves(history):
    """
    plot enhanced training curves with more metrics

    Args:
        history: training history dictionary
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # multi-panel figure
    fig = plt.figure(figsize=(18, 12))

    # 1: loss curves
    ax1 = fig.add_subplot(2, 2, 1)
    ax1.plot(history['train_loss'], 'b-', label='Training CE Loss')
    ax1.plot(history['val_loss'], 'r-', label='Validation Loss')

    if 'best_val_loss' in history:
        ax1.axhline(y=history['best_val_loss'], color='r', linestyle='--',
                   label=f'Best Val Loss: {history["best_val_loss"]:.4f}')

    ax1.set_title('Loss Curves')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    # 2: CLIP scores
    ax2 = fig.add_subplot(2, 2, 2)

    # plot CLIP batch scores if available
    if 'clip_batch_scores' in history and any(history['clip_batch_scores']):
        clip_batch_x = list(range(len(history['clip_batch_scores'])))
        ax2.plot(clip_batch_x, history['clip_batch_scores'], 'g-', alpha=0.5,
                label='Training CLIP Scores')

    # lot evaluation CLIP scores if available
    if 'clip_scores' in history and 'eval_epochs' in history and history['clip_scores']:
        ax2.plot(history['eval_epochs'], history['clip_scores'], 'g-o',
                label='Evaluation CLIP Scores')

        if 'best_clip_score' in history:
            ax2.axhline(y=history['best_clip_score'], color='r', linestyle='--',
                       label=f'Best CLIP Score: {history["best_clip_score"]:.4f}')

    ax2.set_title('CLIP Score Progression')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('CLIP Score')
    ax2.legend()
    ax2.grid(True)

    # 3: lr
    ax3 = fig.add_subplot(2, 2, 3)
    if 'learning_rates' in history and history['learning_rates']:
        ax3.plot(history['learning_rates'], 'c-')
        ax3.set_title('Learning Rate Schedule')
        ax3.set_xlabel('Step')
        ax3.set_ylabel('Learning Rate')
        ax3.set_yscale('log')  # Log scale for better visualization
        ax3.grid(True)

    # 4: combined metrics (optional)
    ax4 = fig.add_subplot(2, 2, 4)

    # secondary axis for CLIP score
    if ('clip_scores' in history and history['clip_scores'] and
        'eval_epochs' in history and 'val_loss' in history):

        # plot validation loss on primary axis
        epochs = list(range(len(history['val_loss'])))
        line1 = ax4.plot(epochs, history['val_loss'], 'r-', label='Validation Loss')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Validation Loss', color='r')
        ax4.tick_params(axis='y', labelcolor='r')

        # secondary axis for CLIP score
        ax4_twin = ax4.twinx()
        line2 = ax4_twin.plot(history['eval_epochs'], history['clip_scores'], 'g-o',
                             label='CLIP Score')
        ax4_twin.set_ylabel('CLIP Score', color='g')
        ax4_twin.tick_params(axis='y', labelcolor='g')

        # combine legends
        lines = line1 + line2
        labels = [l.get_label() for l in lines]
        ax4.legend(lines, labels, loc='upper right')

        ax4.set_title('Validation Loss vs CLIP Score')
        ax4.grid(True)

    plt.tight_layout()
    plt.savefig(f"{BASE_PATH}/enhanced_training_curves.png", dpi=150)
    plt.show()

In [None]:
def main_pipeline():
    """
    main execution pipeline for the improved image captioning fine-tuning
    """
    import os
    import torch
    import random
    import numpy as np
    from torch.utils.data import DataLoader
    from datetime import datetime

    print("\n" + "="*60)
    print(f"starting image captioning fine-tuning pipeline at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"using: {device}")
    print("="*60)

    # set random seeds for reproducibility
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    # base configuration
    base_config = {
        'captions_file': CAPTIONS_FILE,
        'image_folder': f"{BASE_PATH}/images/loadedimages",
        'max_images': 9000,
        'batch_size': 64,
        'embed_size': 256,
        'hidden_size': 768,
        'num_layers': 6,
        'learning_rate': 0.0003,
        'num_epochs': 20,
        'early_stopping_patience': 5,
        'checkpoint_frequency': 2,
        'num_workers': 4,
        'feature_extraction_workers': 96,
        'feature_extraction_batch_size': 128,
        'clip_loss_weight': 0.3,
        'clip_batch_size': 16,
        'clip_eval_frequency': 100,  # evaluate less frequently to speed up training
        'seed': 42,
        'feature_dim': 512,  # specific for resnet18!!
        'feature_cache_path': os.path.join(BASE_PATH, "cached_features.pt"),
        'force_reload_features': False,
    }

    print("\nConfiguration:")
    for key, value in base_config.items():
        print(f"  {key}: {value}")

    # prepare data
    print("\n preparing data...")
    features_dict, captions_dict, word2idx, idx2word = prepare_data_with_cache(
        base_config['captions_file'],
        base_config['image_folder'],
        max_images=base_config['max_images'],
        cache_path=base_config['feature_cache_path'],
        force_reload=base_config['force_reload_features'],
        num_workers=base_config['feature_extraction_workers'],
        batch_size=base_config['feature_extraction_batch_size']
    )

    # verify feature dimension
    sample_feature = next(iter(features_dict.values()))
    actual_feature_dim = sample_feature.size(0)
    print(f"detected feature dimension: {actual_feature_dim}")
    if actual_feature_dim != base_config['feature_dim']:
        print(f"updating feature dimension from {base_config['feature_dim']} to {actual_feature_dim}")
        base_config['feature_dim'] = actual_feature_dim

    # create datasets
    print("\n creating datasets...")
    train_dataset = CaptionFeatureDataset(
        features_dict, captions_dict, word2idx, max_len=22, split='train'
    )
    val_dataset = CaptionFeatureDataset(
        features_dict, captions_dict, word2idx, max_len=22, split='val'
    )

    # create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=base_config['batch_size'],
        shuffle=True,
        num_workers=base_config['num_workers'],
        pin_memory=True if device == "cuda" else False
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=base_config['batch_size'],
        shuffle=False,
        num_workers=base_config['num_workers'],
        pin_memory=True if device == "cuda" else False
    )

    print(f"training set: {len(train_dataset)} samples")
    print(f"validation set: {len(val_dataset)} samples")

    # train in phases
    print("\n running multi-phase training...")
    decoder, project_features, histories = train_in_phases(
        train_loader, val_loader, word2idx, idx2word, base_config
    )

    # final evaluation
    print("\n performing final evaluation...")
    # make evaluation dataset
    eval_dataset = CaptionEvaluationDataset(val_loader.dataset.features_dict)
    eval_loader = DataLoader(
        eval_dataset, batch_size=16, shuffle=False, num_workers=2
    )

    # evaluate with CLIP score
    clip_calculator = CLIPCalculator()
    final_clip_score, generated_captions = evaluate_model_with_clip_score(
        decoder, project_features, eval_loader, word2idx, idx2word,
        clip_calculator, CAPTIONS_FILE, IMAGE_FOLDER,
        train_loader, val_loader, max_eval_images=100
    )

    print(f"\n final CLIP Score: {final_clip_score:.4f}")

    print("\n" + "="*60)
    print("image captioning fine-tuning pipeline complete!")
    print("="*60)

    return decoder, project_features, histories


if __name__ == "__main__":
    main_pipeline()