# Exploration of Sampling Techniques

This notebook explores sampling techniques for the problem of image captioning. The notebook tests beam search, argmax, and top-p sampling with the trained VLM models.

Note: the experiments must be run to create saved model weights before running this notebook. 

In [None]:
import os
import sys
import torch
import torch.nn as nn
from transformers import AutoTokenizer
import datetime
import logging
# Only let ERROR-level (and above) records through the "transformers" logger.
logging.getLogger("transformers").setLevel(logging.ERROR)

# Support two ways of running this notebook: in Google Colab (upload the
# notebook and mount the Drive directory that holds the project), or locally
# from either the project root or its "experiments" folder.
IN_COLAB = 'google.colab' in sys.modules


def resolve_project_root(cwd):
    """Return the project root for a local run started from ``cwd``.

    Args:
        cwd: Directory the notebook was launched from.

    Returns:
        The absolute parent of ``cwd`` when ``cwd`` is a folder named exactly
        "experiments"; otherwise ``cwd`` itself, unchanged.
    """
    # Compare the final path component instead of using str.endswith, so a
    # folder such as "old_experiments" is not mistaken for "experiments".
    if os.path.basename(os.path.normpath(cwd)) == "experiments":
        return os.path.abspath(os.path.join(cwd, os.pardir))
    return cwd


if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')

    # NOTE: change the drive path if running with a mounted Google Drive in Colab
    project_root = "/content/drive/Othercomputers/My MacBook Pro/image-captioning"
else:
    cwd = os.getcwd()
    project_root = resolve_project_root(cwd)

# Make the project's modules importable from the notebook.
if project_root not in sys.path:
    sys.path.append(project_root)

print("Project root:", project_root)

# Install extra packages only when running in Colab.
if IN_COLAB:
    # `!` is IPython shell-escape syntax (not plain Python). stdout and stderr
    # are redirected to /dev/null, so a failed install produces no output here.
    !pip install evaluate > /dev/null 2>&1
    !pip install pycocoevalcap > /dev/null 2>&1


from vision_language_model import VisionLanguageModel
import train as train
import data_processing as dp
import download_data as get_data
import evaluation as eval


# Pick the compute backend in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
else:
    backend_name = "cpu"
device = torch.device(backend_name)

# Identifiers used to name this notebook's output files.
experiment = "sampling_exploration"
timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"

In [None]:
# Decide where the dataset lives. In Colab it is kept outside of the mounted
# directory to avoid data transfer overhead.
data_dir = "/content/flickr30k_data" if IN_COLAB else os.path.join(project_root, "flickr30k_data")

# Download only when the directory is missing or empty.
data_ready = os.path.exists(data_dir) and bool(os.listdir(data_dir))
if data_ready:
    print(f"Data already exists in {data_dir}, skipping download.")
else:
    os.makedirs(data_dir, exist_ok=True)
    get_data.download_and_partition(data_dir)

# Output folders under the project root: one for model weights, one for
# evaluation artifacts. Both are created if they do not exist yet.
model_weights_dir, evaluations_dir = (
    os.path.join(project_root, folder) for folder in ("model_weights", "evaluations")
)
for output_dir in (model_weights_dir, evaluations_dir):
    os.makedirs(output_dir, exist_ok=True)

## ViT/GPT2 Sampling Exploration

In [None]:
# Fail fast, before building the model, if the fine-tuned weights are missing.
# The weights file is produced by running the experiments beforehand.
model_path = os.path.join(model_weights_dir, "experiment_1_finetune_weights_2025-04-13_23-08-39.pt")
if not os.path.isfile(model_path):
    raise FileNotFoundError(
        f"Model weights not found at {model_path}. "
        "Run the experiments to create the saved weights before running this notebook."
    )

# Reuse the EOS token as the padding token for the GPT-2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# model config for stage one
# (ViT image encoder, GPT-2 decoder, llava_projections on, cross_attention off)
model = VisionLanguageModel(
    image_encoder_type="vit",
    llava_projections=True,
    cross_attention=False,
    debug=False,
    decoder_type="gpt2",
    d_model=768,
    tokenizer=tokenizer
    ).to(device)

state_dict = torch.load(model_path, map_location=device)

# This notebook only evaluates, so switch to eval mode explicitly.
# NOTE(review): harmless if evaluate_bleu_cider already does this — confirm.
model.eval()

# Kept as the last expression so the cell still displays the load result.
model.load_state_dict(state_dict)

### Beam Search: 2 beams

In [None]:
# Fresh test stream for this run (batch size 4, fixed seed 32).
test_loader = dp.batch_stream("captions.txt", os.path.join(data_dir, "test"), batch_size=4, eval_mode=True, seed=32)

# Tag the captions file with the model and decoding strategy. `experiment` and
# `timestamp` are set once for the whole notebook, so an untagged name would be
# the same path for every evaluation run.
captions_path = os.path.join(evaluations_dir, f"{experiment}_vit_gpt2_beam2_captions_{timestamp}.jpg")

# Beam search with 2 beams.
bleu, cider = eval.evaluate_bleu_cider(
    model=model,
    data_loader=test_loader,
    display_captions=True,
    save_captions_path=captions_path,
    max_new_tokens=15,
    max_batches=250,
    num_beams=2
)

print(f"BLEU: {bleu:.4f}, CIDEr: {cider:.4f}")

### Top P Sampling: p=0.8, temperature=0.7

In [None]:
# Fresh test stream for this run (batch size 4, fixed seed 32).
test_loader = dp.batch_stream("captions.txt", os.path.join(data_dir, "test"), batch_size=4, eval_mode=True, seed=32)

# Seed PyTorch's RNG so the sampled captions are repeatable between runs.
# NOTE(review): assumes generation draws from the global torch RNG — confirm.
torch.manual_seed(32)

# Tag the captions file with the model and decoding strategy. `experiment` and
# `timestamp` are set once for the whole notebook, so an untagged name would be
# the same path for every evaluation run.
captions_path = os.path.join(evaluations_dir, f"{experiment}_vit_gpt2_top_p_captions_{timestamp}.jpg")

# Top-p sampling with p=0.8 and temperature=0.7.
bleu, cider = eval.evaluate_bleu_cider(
    model=model,
    data_loader=test_loader,
    display_captions=True,
    save_captions_path=captions_path,
    max_new_tokens=15,
    max_batches=250,
    do_sample=True,
    top_p=0.8,
    temperature=0.7
)

print(f"BLEU: {bleu:.4f}, CIDEr: {cider:.4f}")

### Argmax: 1 beam

In [None]:
# Fresh test stream for this run (batch size 4, fixed seed 32).
test_loader = dp.batch_stream("captions.txt", os.path.join(data_dir, "test"), batch_size=4, eval_mode=True, seed=32)

# Tag the captions file with the model and decoding strategy. `experiment` and
# `timestamp` are set once for the whole notebook, so an untagged name would be
# the same path for every evaluation run.
captions_path = os.path.join(evaluations_dir, f"{experiment}_vit_gpt2_argmax_captions_{timestamp}.jpg")

# Argmax decoding, expressed as a single beam.
bleu, cider = eval.evaluate_bleu_cider(
    model=model,
    data_loader=test_loader,
    display_captions=True,
    save_captions_path=captions_path,
    max_new_tokens=15,
    max_batches=250,
    num_beams=1
)

print(f"BLEU: {bleu:.4f}, CIDEr: {cider:.4f}")

## Clip/Vicuna 7b Sampling Exploration

In [None]:
# Fail fast, before building the model, if the fine-tuned weights are missing.
# The weights file is produced by running the experiments beforehand.
model_path = os.path.join(model_weights_dir, "experiment_4_finetune_weights_2025-04-14_14-56-03.pt")
if not os.path.isfile(model_path):
    raise FileNotFoundError(
        f"Model weights not found at {model_path}. "
        "Run the experiments to create the saved weights before running this notebook."
    )

# Reuse the EOS token as the padding token for the Vicuna tokenizer.
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
tokenizer.pad_token = tokenizer.eos_token

# model config for stage one
# (CLIP image encoder, Vicuna decoder, llava_projections on, cross_attention off)
model = VisionLanguageModel(
    image_encoder_type="clip",
    llava_projections=True,
    cross_attention=False,
    debug=False,
    decoder_type="vicuna",
    d_model=4096,
    tokenizer=tokenizer
    ).to(device)

# Cross-entropy loss that ignores padding positions. Not referenced by the
# evaluation cells visible in this notebook; kept in case other code uses it.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

state_dict = torch.load(model_path, map_location=device)

# This notebook only evaluates, so switch to eval mode explicitly.
# NOTE(review): harmless if evaluate_bleu_cider already does this — confirm.
model.eval()

# Kept as the last expression so the cell still displays the load result.
model.load_state_dict(state_dict)

### Beam Search: 2 beams

In [None]:
# Fresh test stream for this run (batch size 2, fixed seed 32).
test_loader = dp.batch_stream("captions.txt", os.path.join(data_dir, "test"), batch_size=2, eval_mode=True, seed=32)

# Tag the captions file with the model and decoding strategy. `experiment` and
# `timestamp` are set once for the whole notebook, so an untagged name would be
# the same path for every evaluation run.
captions_path = os.path.join(evaluations_dir, f"{experiment}_clip_vicuna_beam2_captions_{timestamp}.jpg")

# Beam search with 2 beams.
bleu, cider = eval.evaluate_bleu_cider(
    model=model,
    data_loader=test_loader,
    display_captions=True,
    save_captions_path=captions_path,
    max_new_tokens=15,
    max_batches=500,
    num_beams=2
)

print(f"BLEU: {bleu:.4f}, CIDEr: {cider:.4f}")

### Top P Sampling: p=0.8, temperature=0.7

In [None]:
# Fresh test stream for this run (batch size 2, fixed seed 32).
test_loader = dp.batch_stream("captions.txt", os.path.join(data_dir, "test"), batch_size=2, eval_mode=True, seed=32)

# Seed PyTorch's RNG so the sampled captions are repeatable between runs.
# NOTE(review): assumes generation draws from the global torch RNG — confirm.
torch.manual_seed(32)

# Tag the captions file with the model and decoding strategy. `experiment` and
# `timestamp` are set once for the whole notebook, so an untagged name would be
# the same path for every evaluation run.
captions_path = os.path.join(evaluations_dir, f"{experiment}_clip_vicuna_top_p_captions_{timestamp}.jpg")

# Top-p sampling with p=0.8 and temperature=0.7.
bleu, cider = eval.evaluate_bleu_cider(
    model=model,
    data_loader=test_loader,
    display_captions=True,
    save_captions_path=captions_path,
    max_new_tokens=15,
    max_batches=500,
    do_sample=True,
    top_p=0.8,
    temperature=0.7
)

print(f"BLEU: {bleu:.4f}, CIDEr: {cider:.4f}")

### Argmax: 1 beam

In [None]:
# Fresh test stream for this run (batch size 2, fixed seed 32).
test_loader = dp.batch_stream("captions.txt", os.path.join(data_dir, "test"), batch_size=2, eval_mode=True, seed=32)

# Tag the captions file with the model and decoding strategy. `experiment` and
# `timestamp` are set once for the whole notebook, so an untagged name would be
# the same path for every evaluation run.
captions_path = os.path.join(evaluations_dir, f"{experiment}_clip_vicuna_argmax_captions_{timestamp}.jpg")

# Argmax decoding, expressed as a single beam.
bleu, cider = eval.evaluate_bleu_cider(
    model=model,
    data_loader=test_loader,
    display_captions=True,
    save_captions_path=captions_path,
    max_new_tokens=15,
    max_batches=500,
    num_beams=1
)

print(f"BLEU: {bleu:.4f}, CIDEr: {cider:.4f}")