In [7]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import AutoModel
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
from tqdm import tqdm
import torch
import os
from multiprocessing import Process, Queue, set_start_method
device = "cuda" if torch.cuda.is_available() else "cpu"
import logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
set_start_method('spawn', force=True)
logger.info("Multiprocessing start method set to 'spawn'.")

2025-06-17 00:54:40,712 - INFO - Multiprocessing start method set to 'spawn'.


# 1. Load models trained with DPO and SimPO

In [2]:
model_configs = [
    {"name": "princeton-nlp/Mistral-7B-Instruct-DPO", "gpu_id": 0, "alias": "Mistral-DPO"},
    {"name": "princeton-nlp/Mistral-7B-Instruct-SimPO", "gpu_id": 1, "alias": "Mistral-SimPO"},
]
torch_dtype = torch.bfloat16
loaded_models = {}
loaded_tokenizers = {}
for model_config in model_configs:
    model_name = model_config["name"]
    gpu_id = model_config["gpu_id"]
    alias = model_config["alias"]
    current_device = torch.device(f"cuda:{gpu_id}")

    logger.info(f"Loading model {model_name} on GPU {gpu_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="flash_attention_2", torch_dtype=torch_dtype).to(current_device)

    loaded_models[alias] = model
    loaded_tokenizers[alias] = tokenizer
    logger.info(f"Model {alias} loaded successfully")

2025-06-17 00:41:33,301 - INFO - Loading model princeton-nlp/Mistral-7B-Instruct-DPO on GPU 0
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 74.54it/s]
2025-06-17 00:41:38,704 - INFO - Model Mistral-DPO loaded successfully
2025-06-17 00:41:38,705 - INFO - Loading model princeton-nlp/Mistral-7B-Instruct-SimPO on GPU 1
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 44.89it/s]
2025-06-17 00:41:42,013 - INFO - Model Mistral-SimPO loaded successfully


# 2. Experiments
## 2.1 Relation between Log probability and |y|

In [8]:
def run_model_inference(gpu_id, model_name, alias, prompt, num_generations, request_queue, response_queue):
    try:
        device = torch.device(f"cuda:{gpu_id}")
        logger.info(f"Process {os.getpid()}: Loading model '{alias}' ({model_name}) on {device}...")

        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            attn_implementation="flash_attention_2",
            torch_dtype=torch_dtype,
            device_map=device, 
            trust_remote_code=True,
        )

        logger.info(f"Process {os.getpid()}: Model '{alias}' loaded successfully and is on device: {model.device}")

        log_likelihoods = []
        response_lengths = []

        prompt_inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(model.device)
        prompt_length = prompt_inputs.input_ids.shape[1]

        with torch.no_grad():
            for _ in tqdm(range(num_generations), desc=f"Generating with {alias}"):
                max_len = np.random.randint(20, 200)
                generated_ids = model.generate(
                    **prompt_inputs,
                    max_new_tokens=max_len,
                    do_sample=True,
                    top_k=50,
                    top_p=0.95,
                    temperature=0.7,
                    pad_token_id=tokenizer.eos_token_id
                )

                if generated_ids.shape[1] <= prompt_length:
                    continue

                response_ids = generated_ids[:, prompt_length:]
                response_length = response_ids.shape[1]

                if response_length == 0: 
                    continue

                full_ids = generated_ids 
                outputs = model(full_ids, labels=full_ids) 
                logits = outputs.logits 

                logits_for_response = logits[:, prompt_length - 1 : -1, :] 
                
                if logits_for_response.shape[1] != response_ids.shape[1]:
            
                    log_probs_full = torch.nn.functional.log_softmax(logits, dim=-1)
                    true_token_log_probs_full = log_probs_full.gather(
                        dim=-1, index=generated_ids.unsqueeze(-1)
                    ).squeeze(-1) 

                    avg_log_prob = true_token_log_probs_full[:, prompt_length:].mean().item()

                else: 
                    log_probs = torch.nn.functional.log_softmax(logits_for_response, dim=-1)
                    true_token_log_probs = torch.gather(log_probs, 2, response_ids.unsqueeze(-1)).squeeze(-1)
                    avg_log_prob = true_token_log_probs.mean().item()

                response_lengths.append(response_length)
                log_likelihoods.append(avg_log_prob)

        response_queue.put((alias, log_likelihoods, response_lengths))

    except Exception as e:
        logger.error(f"Process {os.getpid()}: Error in model '{alias}': {e}")
        response_queue.put((alias, None, str(e))) 

In [None]:
prompt_text = "Once upon a time in a land far, far away, there lived a dragon who"
num_generations = 100
results = {}
processes = []
response_queue = Queue() 

for config in model_configs:
    p = Process(
        target=run_model_inference,
        args=(
            config["gpu_id"],
            config["name"],
            config["alias"],
            prompt_text,
            num_generations,
            Queue(),
            response_queue 
        )
    )
    processes.append(p)
    p.start()

logger.info("Main process: All model processes started. Waiting for results...")

for _ in range(len(model_configs)):
    alias, log_likelihoods, response_lengths = response_queue.get()
    if log_likelihoods is not None:
        results[alias] = {
            "log_likelihoods": log_likelihoods,
            "response_lengths": response_lengths
        }
        logger.info(f"Main process: Received results for model '{alias}'.")
    else:
        logger.error(f"Main process: Model '{alias}' encountered an error: {response_lengths}")

for p in processes:
    p.join()
logger.info("Main process: All model processes finished.")

print("\n--- Final Results ---")
for alias, data in results.items():
    print(f"Model: {alias}")
    print(f"  Avg Log Likelihood: {np.mean(data['log_likelihoods']):.4f}")
    print(f"  Avg Response Length: {np.mean(data['response_lengths']):.2f}")
    print(f"  Total Generations: {len(data['log_likelihoods'])}")

2025-06-17 00:54:48,653 - INFO - Main process: All model processes started. Waiting for results...


Traceback (most recent call last):
Traceback (most recent call last):
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_main; [31mspawn_main[0m[1;31m(tracker_fd=155, pipe_handle=158)[0m
                                                  [31m~~~~~~~~~~[0m[1;31m^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^[0m
  File [35m"/data/satori_hdd1/mutyuu/miniconda3/envs/CS336/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m122[0m, in [35mspawn_main[0m
    exitcode = _main(fd, parent_sentinel)
  File [35m"/data/satori_hdd1/mutyuu/miniconda3/envs/CS336/lib/python3.13/multiprocessing/spawn.py"[0m, line [35m132[0m, in [35m_main[0m
    self = reduction.pickle.load(from_parent)
[1;35mAttributeError[0m: [35mCan't get attribute 'run_model_inference' on <module '__main__' (<class '_frozen_importlib.BuiltinImporter'>)>[0m
  File [35m"<string>"[0m, line [35m1[0m, in [35m<module>[0m
    from multiprocessing.spawn import spawn_

KeyboardInterrupt: 

: 

In [None]:
print("Plotting the results...")
correlation, _ = pearsonr(response_lengths, avg_log_probs)
rho_text = f"$\\rho = {correlation:.2f}$"

plt.figure(figsize=(10, 6))
plt.scatter(response_lengths, avg_log_probs, s=5, alpha=0.6, c='steelblue')
plt.title("Model's Average Log Probability vs. Response Length", fontsize=16)
plt.xlabel("Response length $|y|$", fontsize=14)
plt.ylabel("Avg. log prob. $p_{\\theta}(y|x)$", fontsize=14)
plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.minorticks_on()
plt.grid(which='minor', linestyle=':', linewidth=0.5)


plt.text(0.95, 0.15, rho_text, transform=plt.gca().transAxes,
         fontsize=15, verticalalignment='top', horizontalalignment='right',
         bbox=dict(boxstyle='round,pad=0.5', fc='white', ec='black', lw=2))

plt.show()
plt.savefig("../imgs/mistral_7b_DPO_avg_log_prob_vs_response_length.png", bbox_inches='tight', dpi=300)

Generating responses...


 16%|█▌        | 155/1000 [10:11<55:32,  3.94s/it]  


KeyboardInterrupt: 

4 GPUs detected.
