In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import json
import os
import time
import pickle
from tqdm import tqdm
import logging
from datetime import datetime
from datasets import load_dataset, load_from_disk
from peft import AutoPeftModelForCausalLM, LoraConfig,PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    set_seed,
)
# some experiments for the Llama model on safety
import os
import sys
import socket
import re
import random
import numpy as np
import torch
import pickle
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict

[2025-05-25 15:34:42,079] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2025-05-25 15:34:46.980238: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-05-25 15:34:46.980264: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-25 15:34:46.981173: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Parse the prompts to check what the fields represents and whether running friction intervention models would even make sense. 

In [2]:
import json 
 
test_path = f"diplomacy_prompts/sample_1K_denis_testdata_with_prompts.json"
with open(test_path, "r") as file:
    test_data = json.load(file)

In [3]:
test_data[0]

{'sender': 'ITALY',
 'recipient': 'AUSTRIA',
 'message': 'Also, seems like there is plenty of risk that Russia could try for Rumania. Why not use Bulg to support Rum, while having Rum and BLS Support Arm to Sev?',
 'sender_labels': False,
 'phase': 'F1907M',
 'units': {'AUSTRIA': ['A BUD', 'A SER', 'A ARM', 'F BLA', 'A BUL', 'A RUM'],
  'ENGLAND': ['F ENG', 'F NWY', 'A PIC', 'F NWG'],
  'FRANCE': ['F BRE'],
  'GERMANY': ['F BEL', 'A PAR', 'A HOL', 'F NTH', 'A MUN', 'A BUR', 'F HEL'],
  'ITALY': ['A SMY',
   'F AEG',
   'A VIE',
   'A MAR',
   'F MAO',
   'F TYS',
   'A TYR',
   'A GAS',
   'F SPA',
   'F NAO'],
  'RUSSIA': ['A SEV', 'A BOH', 'A UKR', 'A MOS', 'A GAL'],
  'TURKEY': []},
 'predicted_orders': {'AUSTRIA': {'ENGLAND': ['F NWG - NWY',
    'F ENG S A PIC - BEL',
    'F NWY - SWE',
    'A PIC - BEL'],
   'FRANCE': ['F BRE H'],
   'GERMANY': ['F NTH - LON',
    'F BEL S F HEL - NTH',
    'F HEL - NTH',
    'A HOL S F BEL',
    'A BUR S A PAR',
    'A PAR S F BRE',
    'A MUN S 

In [5]:
from datasets import load_dataset, load_from_disk
selected_validation_dataset = load_from_disk("selected_validation_dataset_friction_mixed_with_original") # Test set for WTD

In [6]:
selected_validation_dataset

Dataset({
    features: ['friction_statement', 'rationale', 'task_state', 'belief_state', 'context', 'dialog_id', 'friction_participant', 'type', 'P1_personality_type', 'P1_facet', 'P2_personality_type', 'P2_facet', 'P3_personality_type', 'P3_facet', 'unique_id', 'message_id', 'original_wtd_friction_data', '__index_level_0__'],
    num_rows: 500
})

In [14]:
len(test_data)

1000

In [4]:

def process_diplomacy_for_friction(test_data):

    processed_messages = []  # Create a list to store all processed messages
    count = 0
    for message in test_data:
        count = count+1
        sender = message['sender']
        recipient = message['recipient']

        if message['recipient'] == 'ALL' or message['recipient'] == message['sender'] or message["phase"].endswith('A') or message["phase"].endswith('R'):
            continue
        else:
            #board states format
            formatted_board = ""
            for country, units in message['units'].items():
                formatted_board += f"{country}: {', '.join(units)}\n"
            formatted_board = formatted_board.strip()
            #predicted orders for opponent
            opponent_orders = ""
            for country, units in message['predicted_orders'][message['recipient']].items():
                opponent_orders += f"{country}: {', '.join(units)}\n"
            formatted_opponent_orders = opponent_orders.strip()
            #message history
            message_history = ""
            start_index = max(0, len(message['prev_5_message']) - 5)
            for msg_info in message['prev_5_message'][start_index:]:
                sender = msg_info['sender']
                text = msg_info['message']
                message_history += f"Message from {sender}: {text}\n"
            message_history += f"Message from {message['sender']}: {message['message']}"
            #recommended orders for myself.
            recommended_orders = message['predicted_orders'][message['sender']][message['recipient']]
            formatted_recommended_orders = ",\n".join(recommended_orders)

                # Modify the prompts for friction interventions
            Prompt_level1 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.

        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly address the strategic concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders:**  
        {formatted_recommended_orders}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """

            Prompt_level2 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.
        3. The potential orders for every power.
        

        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly address the strategic concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders for {message['recipient']}:**  
        {formatted_recommended_orders}

        **Potential Orders for other powers:**
        {formatted_opponent_orders}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """

            Prompt_level3 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.
        3. The potential orders for every power.
        4. The message history between the player and other players.

        Your goal is to identify potential friction points in the player's strategy and diplomatic approach, then provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or misalignments in the recommended orders compared to the board state, other powers' likely moves, and diplomatic communications.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned in strategy or diplomacy, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy or diplomatic approach. The first sentence should directly address the concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders for {message['recipient']}:**  
        {formatted_recommended_orders}

        **Potential Orders for other powers:**
        {formatted_opponent_orders}

        **Message History:**
        {message_history}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """
            
            processed_message = message.copy()
         # Add the prompts to the copied message
            processed_message['Prompt_level1'] = Prompt_level1
            processed_message['Prompt_level2'] = Prompt_level2
            processed_message['Prompt_level3'] = Prompt_level3
             # Add this processed message to our list
            processed_messages.append(processed_message)

    return processed_messages

In [5]:
test_prompts = process_diplomacy_for_friction(test_data[0:2])

In [8]:
diplomacy_detailed_explanation = """
Based on the board state and recommended orders, your strategy appears focused on attacking Sevastopol while also moving into Galicia. However, this leaves you vulnerable to a critical threat: Russia plans to attack Vienna (A BOH - VIE with support from A GAL). 

While Italy currently controls Vienna and their orders show support for your A BUD - GAL move, this leaves Vienna exposed. Losing Vienna, a home center, would be devastating to your position and ability to build new units.

Your eastern offensive against Russia in Sevastopol diverts resources from this more urgent defensive need. Consider redirecting forces to secure Vienna - perhaps keeping A BUD in position or moving A SER to Trieste to provide support. Remember that in Diplomacy, protecting home centers takes priority over gaining new territories when your core position is under threat.

The recommended shift is to "Modify your orders to focus on securing Vienna instead of attacking Sevastopol" because Russia's coordinated attack with support from Galicia represents an immediate threat to your power base, while Sevastopol can be targeted later from a position of greater security.
"""
diplomacy_detailed_explanation

'\nBased on the board state and recommended orders, your strategy appears focused on attacking Sevastopol while also moving into Galicia. However, this leaves you vulnerable to a critical threat: Russia plans to attack Vienna (A BOH - VIE with support from A GAL). \n\nWhile Italy currently controls Vienna and their orders show support for your A BUD - GAL move, this leaves Vienna exposed. Losing Vienna, a home center, would be devastating to your position and ability to build new units.\n\nYour eastern offensive against Russia in Sevastopol diverts resources from this more urgent defensive need. Consider redirecting forces to secure Vienna - perhaps keeping A BUD in position or moving A SER to Trieste to provide support. Remember that in Diplomacy, protecting home centers takes priority over gaining new territories when your core position is under threat.\n\nThe recommended shift is to "Modify your orders to focus on securing Vienna instead of attacking Sevastopol" because Russia\'s co

In [6]:
test_prompts[0]

{'sender': 'ITALY',
 'recipient': 'AUSTRIA',
 'message': 'Also, seems like there is plenty of risk that Russia could try for Rumania. Why not use Bulg to support Rum, while having Rum and BLS Support Arm to Sev?',
 'sender_labels': False,
 'phase': 'F1907M',
 'units': {'AUSTRIA': ['A BUD', 'A SER', 'A ARM', 'F BLA', 'A BUL', 'A RUM'],
  'ENGLAND': ['F ENG', 'F NWY', 'A PIC', 'F NWG'],
  'FRANCE': ['F BRE'],
  'GERMANY': ['F BEL', 'A PAR', 'A HOL', 'F NTH', 'A MUN', 'A BUR', 'F HEL'],
  'ITALY': ['A SMY',
   'F AEG',
   'A VIE',
   'A MAR',
   'F MAO',
   'F TYS',
   'A TYR',
   'A GAS',
   'F SPA',
   'F NAO'],
  'RUSSIA': ['A SEV', 'A BOH', 'A UKR', 'A MOS', 'A GAL'],
  'TURKEY': []},
 'predicted_orders': {'AUSTRIA': {'ENGLAND': ['F NWG - NWY',
    'F ENG S A PIC - BEL',
    'F NWY - SWE',
    'A PIC - BEL'],
   'FRANCE': ['F BRE H'],
   'GERMANY': ['F NTH - LON',
    'F BEL S F HEL - NTH',
    'F HEL - NTH',
    'A HOL S F BEL',
    'A BUR S A PAR',
    'A PAR S F BRE',
    'A MUN S 

In [40]:
test_prompts['Prompt_level2']

"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:\n        1. The current board state.\n        2. The recommended orders for the player.\n        3. The potential orders for every power.\n        \n\n        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:\n\n        <belief_state>\n        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.\n        </belief_state>\n\n        <rationale>\n        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.\n        </rationale>\n\n        <friction>\n        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their str

In [41]:
test_prompts['Prompt_level3']

"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:\n        1. The current board state.\n        2. The recommended orders for the player.\n        3. The potential orders for every power.\n        4. The message history between the player and other players.\n\n        Your goal is to identify potential friction points in the player's strategy and diplomatic approach, then provide a constructive intervention:\n\n        <belief_state>\n        Identify contradictions, oversights, or misalignments in the recommended orders compared to the board state, other powers' likely moves, and diplomatic communications.\n        </belief_state>\n\n        <rationale>\n        Explain why an intervention is needed—what's misaligned in strategy or diplomacy, its potential impact, and how alternative thinking could benefit the player.\n        </rationale>\n\n 

In [42]:
models_list_deli_weights =[
        
    #     'FAAF_friction_run_with_69ksamples_good_vs_rogue_weights/checkpoint-500',
    # 'FAAF_friction_run_with_69ksamples_good_vs_rogue_weights/checkpoint-2000',
    # 'WTD_faaf_0.01', 'WTD_faaf_1', 'WTD_faaf_5', 
    
    # 'WTD_faaf_10',
#     'sft_deli_multiturn_rogue_cleaned/checkpoint-1000',
#     'sft_deli_multiturn_rogue_cleaned/checkpoint-2000',
        'sft_deli_multiturn_rogue_cleaned/checkpoint-3000',
    'DELI_all_weights/DELI_faaf_weights/checkpoint-2000',
#     # 'DELI_all_weights/DELI_faaf_unconditioned_weights/checkpoint-2000',
#     # 'DELI_all_weights/DELI_faaf_first_part_weights/checkpoint-2000',

#     # 'dpo_friction_run_with_69ksamples/checkpoint-500',
#     # 'dpo_friction_run_with_69ksamples/checkpoint-2000',
#     'DELI_all_weights/DELI_dpo_weights/checkpoint-2000',
#      'DELI_all_weights/DELI_dpo_weights/checkpoint-3500',
#     # 'dpo_friction_run_with_69ksamples/checkpoint-3000',
#     # 'WTD_IPO_weights/checkpoint-2000',
#     'DELI_all_weights/DELI_ipo_weights/checkpoint-4500',
#         "DELI_all_weights/DELI_sft_weights/checkpoint-4000",
#         "DELI_all_weights/DELI_sft_weights/checkpoint-small-1500",
#          "DELI_all_weights/DELI_ppo_weights/ppo_checkpoint_epoch_1_batch_800",
 
    ]

In [48]:
models_list_deli_weights =[
        

    'WTD_faaf_10',
 
        'sft_deli_multiturn_rogue_cleaned/checkpoint-3000',
    'DELI_all_weights/DELI_faaf_weights/checkpoint-2000',
 
     'DELI_all_weights/DELI_dpo_weights/checkpoint-3500',
 
    'DELI_all_weights/DELI_ipo_weights/checkpoint-4500',

        "DELI_all_weights/DELI_sft_weights/checkpoint-small-1500",
         "DELI_all_weights/DELI_ppo_weights/ppo_checkpoint_epoch_1_batch_800",
 
    ]
models_list_deli_weights

['WTD_faaf_10',
 'sft_deli_multiturn_rogue_cleaned/checkpoint-3000',
 'DELI_all_weights/DELI_faaf_weights/checkpoint-2000',
 'DELI_all_weights/DELI_dpo_weights/checkpoint-3500',
 'DELI_all_weights/DELI_ipo_weights/checkpoint-4500',
 'DELI_all_weights/DELI_sft_weights/checkpoint-small-1500',
 'DELI_all_weights/DELI_ppo_weights/ppo_checkpoint_epoch_1_batch_800']

In [22]:
generation_args = {
    "max_new_tokens": 356,
    "temperature": 0.9,
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.9,
    "num_beams": 5,
    "min_length": 100,
    'num_return_sequences': 1,
    "sampling_strategy": "top_p_sampling"
    }
    

In [49]:
import torch
import gc
import pickle
from tqdm import tqdm
from transformers import AutoTokenizer
 

def process_diplomacy_prompts_with_models(test_data, models_list, generation_args, output_dir = None):
    """
    Process Diplomacy data with multiple models and different prompt levels.
    
    Args:
        test_data: List of dictionaries containing the message data with prompts
        models_list: List of model paths to process
        generation_args: Arguments for text generation
    """
    
    results_by_model = {}
    os.makedirs(output_dir, exist_ok=True)
    for agent_model_name in tqdm(models_list, desc="Processing Models"):
        # Format model name for results
        loading_model_name = agent_model_name
        if "/" in agent_model_name:
            parts = agent_model_name.split("/")
            if len(parts) >= 3:
                # Combine the first two parts
                agent_model_name = parts[1] + "_" + parts[2]
            elif len(parts) == 2:
                agent_model_name = parts[0] + "_" + parts[1]
        
        print(f"\n===== Processing Model: {agent_model_name} =====\n")
        
        # Load model and tokenizer
        try:
            # Load base model
            lora_model = AutoModelForCausalLM.from_pretrained(
                "llama3_8b_instruct",
                device_map="auto",
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            
            # Apply LoRA adapter
            lora_model = PeftModel.from_pretrained(
                lora_model,
                loading_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            
            # Merge the model
            print("Merging LoRA adapter...")
            merged_model = lora_model.merge_and_unload()
            
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(loading_model_name)
            tokenizer.pad_token = "<|reserved_special_token_0|>"
            tokenizer.padding_side = "right"
            
            # Initialize results for this model
            model_results = {
                "model_name": agent_model_name,
                "samples": []
            }
            
            # Process each message/sample
            for message_idx, message in enumerate(tqdm(test_data, desc=f"Processing samples for {agent_model_name}")):
                # Skip if we don't have the prompts
                if not all(key in message for key in ['Prompt_level1', 'Prompt_level2', 'Prompt_level3']):
                    continue
                
                sample_results = {
                    "message_id": message_idx,
                    "sender": message.get('sender', ''),
                    "recipient": message.get('recipient', ''),
                    "prompt_results": []
                }
                
                # Process each prompt level
                for level_name in ['Prompt_level1', 'Prompt_level2', 'Prompt_level3']:
                    print(f"Processing {level_name} for message {message_idx}")
                    
                    # Generate responses
                    generated_texts, all_generated_texts = generate_multiple_sequences_with_intrinsic_metrics(
                        merged_model, 
                        tokenizer, 
                        message[level_name], 
                        generation_args, 
                        None,
                        strategy="top_p_sampling", 
                        batched=True, 
                    )
                    
                    # Process the generated text
                    if generated_texts and isinstance(generated_texts, list):
                        text_to_parse = generated_texts[0][0] if (generated_texts[0] and isinstance(generated_texts[0], list)) else generated_texts[0]
                        
                        # Parse components
                        tags_for_parsing = ["friction", "rationale", "belief_state"]
                        parsed_components = parse_tags_robust(text_to_parse, tags_for_parsing)
                        
                        # Extract components
                        friction_intervention = ' '.join(parsed_components.get('friction', []))
                        if not friction_intervention:
                            friction_intervention = handle_friction_logic(text_to_parse)
                        
                        belief_state = ' '.join(parsed_components.get('belief_state', []))
                        rationale = ' '.join(parsed_components.get('rationale', []))
                        
                        # Save results for this prompt level
                        prompt_result = {
                            "prompt_level": level_name,
                            "raw_prompt": message[level_name],
                            "generated_text": text_to_parse,
                            "parsed_components": {
                                "friction": friction_intervention,
                                "rationale": rationale,
                                "belief_state": belief_state
                            }
                        }
                        
                        sample_results["prompt_results"].append(prompt_result)
                
                # Add this sample's results to the model results
                model_results["samples"].append(sample_results)
            
            # Save results for this model to a pickle file
            safe_model_name = agent_model_name.replace('/', '_')
            
            with open(f"{output_dir}/results_{safe_model_name}.pkl", 'wb') as f:
                pickle.dump(model_results, f)
            
            # Also store in our overall results dictionary
            results_by_model[agent_model_name] = model_results
            
            # Clean up the model to free memory
            del merged_model, tokenizer, lora_model
            torch.cuda.empty_cache()
            gc.collect()
            
        except Exception as e:
            print(f"Error processing model {agent_model_name}: {str(e)}")
            continue
    
    return results_by_model

output_dir = "diplomacy_prompts/init_generations"
results = process_diplomacy_prompts_with_models(test_prompts, models_list_deli_weights, generation_args, output_dir= output_dir)

Processing Models:   0%|                                                                                                                                                                                                 | 0/7 [00:00<?, ?it/s]


===== Processing Model: WTD_faaf_10 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for WTD_faaf_10: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 295863.77it/s][A
Processing Models:  14%|██████████████████████████▍                                                                                                                                                              | 1/7 [00:06<00:38,  6.48s/it]


===== Processing Model: sft_deli_multiturn_rogue_cleaned_checkpoint-3000 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for sft_deli_multiturn_rogue_cleaned_checkpoint-3000: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 166207.85it/s][A
Processing Models:  29%|████████████████████████████████████████████████████▊                                                                                                                                    | 2/7 [00:12<00:31,  6.28s/it]


===== Processing Model: DELI_faaf_weights_checkpoint-2000 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for DELI_faaf_weights_checkpoint-2000: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 273192.21it/s][A
Processing Models:  43%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                         | 3/7 [00:18<00:24,  6.25s/it]


===== Processing Model: DELI_dpo_weights_checkpoint-3500 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for DELI_dpo_weights_checkpoint-3500: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 284076.37it/s][A
Processing Models:  57%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                               | 4/7 [00:25<00:18,  6.22s/it]


===== Processing Model: DELI_ipo_weights_checkpoint-4500 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for DELI_ipo_weights_checkpoint-4500: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 278528.00it/s][A
Processing Models:  71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                    | 5/7 [00:31<00:12,  6.20s/it]


===== Processing Model: DELI_sft_weights_checkpoint-small-1500 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for DELI_sft_weights_checkpoint-small-1500: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 289850.28it/s][A
Processing Models:  86%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                          | 6/7 [00:37<00:06,  6.18s/it]


===== Processing Model: DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.

Processing samples for DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 281830.70it/s][A
Processing Models: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:43<00:00,  6.21s/it]


In [70]:
import gc
from tqdm import tqdm
from collections import defaultdict
import re

def process_models_with_formatted_prompts(models_list, test_data, generation_args, output_dir = None):
    """
    Process formatted Diplomacy prompts with multiple models.
    
    Args:
        models_list: List of model paths to process
        generation_args: Arguments for text generation
    """
    os.makedirs(output_dir, exist_ok=True)
    # First, create the formatted prompts
    processed_data = process_diplomacy_for_friction(test_data)
    
    results_by_model = {}
    
    for agent_model_name in tqdm(models_list, desc="Processing Models"):
        # Format model name for results
        loading_model_name = agent_model_name
        if "/" in agent_model_name:
            parts = agent_model_name.split("/")
            if len(parts) >= 3:
                agent_model_name = parts[1] + "_" + parts[2]
            elif len(parts) == 2:
                agent_model_name = parts[0] + "_" + parts[1]
        
        print(f"\n===== Processing Model: {agent_model_name} =====\n")
        
        # Load model and tokenizer
        try:
            # Load base model
            lora_model = AutoModelForCausalLM.from_pretrained(
                "llama3_8b_instruct",
                device_map="auto",
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            
            # Apply LoRA adapter
            lora_model = PeftModel.from_pretrained(
                lora_model,
                loading_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            
            # Merge the model
            print("Merging LoRA adapter...")
            merged_model = lora_model.merge_and_unload()
            
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(loading_model_name)
            tokenizer.pad_token = "<|reserved_special_token_0|>"
            tokenizer.padding_side = "right"
            
            # Initialize results for this model
            model_results = {
                "model_name": agent_model_name,
                "samples": []
            }
            
             # Process each sample
            for sample_idx, sample in enumerate(processed_data):
                sample_results = {
                    "sample_id": sample_idx,
                    "message_id": sample.get('message_id', sample_idx),
                    "sender": sample.get('sender', ''),
                    "recipient": sample.get('recipient', ''),
                    "prompt_results": []
                }
                # Process each prompt level
                
                for level_key in ['Prompt_level1', 'Prompt_level2', 'Prompt_level3']:
                    if level_key in sample:
                        print(f"Processing {level_key} for sample {sample_idx}")
                        
                        # Generate responses
                        generated_texts, all_generated_texts = generate_multiple_sequences_with_intrinsic_metrics(
                            merged_model, 
                            tokenizer, 
                            sample[level_key], 
                            generation_args, 
                            None,
                            strategy="top_p_sampling", 
                            batched=True, 
                        )
                        
                        # Process the generated text
                        if generated_texts and isinstance(generated_texts, list):
                            text_to_parse = generated_texts[0][0] if (generated_texts[0] and isinstance(generated_texts[0], list)) else generated_texts[0]
                            
                            # Parse components
                            tags_for_parsing = ["friction", "rationale", "belief_state"]
                            parsed_components = parse_tags_robust(text_to_parse, tags_for_parsing)
                            
                            # Extract components
                            friction_intervention = ' '.join(parsed_components.get('friction', []))
                            if not friction_intervention:
                                friction_intervention = handle_friction_logic(text_to_parse)
                            
                            belief_state = ' '.join(parsed_components.get('belief_state', []))
                            rationale = ' '.join(parsed_components.get('rationale', []))
                            
                            # Save results for this prompt level
                            prompt_result = {
                                "prompt_level": level_key,
                                "raw_prompt": sample[level_key],
                                "generated_text": text_to_parse,
                                "parsed_components": {
                                    "friction": friction_intervention,
                                    "rationale": rationale,
                                    "belief_state": belief_state
                                }
                            }
                            
                            sample_results["prompt_results"].append(prompt_result)
                
                # Add this sample's results to the model results
                model_results["samples"].append(sample_results)
      
            
            # Save results for this model to a pickle file
            safe_model_name = agent_model_name.replace('/', '_')
            with open(f"{output_dir}/results_{safe_model_name}.pkl", 'wb') as f:
                pickle.dump(model_results, f)
            
            # Also store in our overall results dictionary
            results_by_model[agent_model_name] = model_results
            
            # Clean up the model to free memory
            del merged_model, tokenizer, lora_model
            torch.cuda.empty_cache()
            gc.collect()
            
        except Exception as e:
            print(f"Error processing model {agent_model_name}: {str(e)}")
            continue
    de
    return results_by_model
output_dir = "diplomacy_prompts/init_generations"
results_by_model = process_models_with_formatted_prompts(models_list_deli_weights,test_data[0:1], generation_args, output_dir = output_dir)

Processing Models:   0%|                                                                                                                                                                                                 | 0/7 [00:00<?, ?it/s]


===== Processing Model: WTD_faaf_10 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models:  14%|██████████████████████████▍                                                                                                                                                              | 1/7 [00:20<02:04, 20.82s/it]


===== Processing Model: sft_deli_multiturn_rogue_cleaned_checkpoint-3000 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models:  29%|████████████████████████████████████████████████████▊                                                                                                                                    | 2/7 [00:40<01:39, 19.91s/it]


===== Processing Model: DELI_faaf_weights_checkpoint-2000 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models:  43%|███████████████████████████████████████████████████████████████████████████████▎                                                                                                         | 3/7 [01:13<01:43, 25.98s/it]


===== Processing Model: DELI_dpo_weights_checkpoint-3500 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models:  57%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                               | 4/7 [01:45<01:25, 28.44s/it]


===== Processing Model: DELI_ipo_weights_checkpoint-4500 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models:  71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                    | 5/7 [02:10<00:54, 27.13s/it]


===== Processing Model: DELI_sft_weights_checkpoint-small-1500 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models:  86%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                          | 6/7 [02:29<00:24, 24.31s/it]


===== Processing Model: DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0
Processing Prompt_level3 for sample 0


Processing Models: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [02:46<00:00, 23.82s/it]


In [54]:
results_by_model

{}

In [29]:
results

{'sft_deli_multiturn_rogue_cleaned_checkpoint-1000': {'model_name': 'sft_deli_multiturn_rogue_cleaned_checkpoint-1000',
  'samples': [{'message_id': 0,
    'sender': 'ITALY',
    'recipient': 'AUSTRIA',
    'prompt_results': [{'prompt_level': 'Prompt_level1',
      'raw_prompt': "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:\n        1. The current board state.\n        2. The recommended orders for the player.\n\n        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:\n\n        <belief_state>\n        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state.\n        </belief_state>\n\n        <rationale>\n        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative t

In [43]:
import re

def parse_tags_robust(text, tags=None):
    """
    Enhanced parsing function that handles XML-style tags with improved robustness 
    for Diplomacy-specific content.
    
    Args:
        text (str): The text to parse.
        tags (list): List of tags to parse. If None, defaults to standard tags.
        
    Returns:
        dict: Dictionary with tag names as keys and lists of extracted content as values.
    """
    if tags is None:
        tags = ["friction", "rationale", "belief_state"]
    
    # Initialize result
    result = {tag: [] for tag in tags}
    
    # Tag mapping for variations
    tag_variations = {
        "belief_state": ["belief_state", "belief state", "beliefstate", "belief-state", "b"],
        "rationale": ["rationale", "rational", "reasoning", "reason", "r"],
        "friction": ["friction", "intervention", "advice", "f"]
    }
    
    # Reverse mapping for normalization
    tag_normalization = {}
    for normalized, variations in tag_variations.items():
        for variation in variations:
            tag_normalization[variation.lower()] = normalized
    
    # First attempt: Try exact XML-style tags
    for tag in tags:
        # Try variations of the tag name
        possible_tags = [tag]
        if tag in tag_variations:
            possible_tags.extend(tag_variations[tag])
        
        found = False
        for tag_var in possible_tags:
            # Try with and without underscores/hyphens
            tag_patterns = [
                tag_var,
                tag_var.replace("_", " "),
                tag_var.replace("-", " "),
                tag_var.replace(" ", "_"),
                tag_var.replace(" ", "-")
            ]
            
            for t_pattern in tag_patterns:
                # Match both <tag> and <tag></tag> patterns
                open_tag_pattern = f"<{t_pattern}>"
                close_tag_pattern = f"</{t_pattern}>"
                open_only_pattern = re.escape(open_tag_pattern) + r"(.*?)(?=" + re.escape(close_tag_pattern) + r"|\Z)"
                
                matches = re.findall(open_only_pattern, text, re.DOTALL | re.IGNORECASE)
                if matches:
                    result[tag].extend([m.strip() for m in matches])
                    found = True
                    break
            
            if found:
                break
    
    # Second attempt: Try markdown-style headings and formatting
    for tag in tags:
        if not result[tag]:  # Only if we didn't find anything in first pass
            possible_tags = [tag]
            if tag in tag_variations:
                possible_tags.extend(tag_variations[tag])
            
            for tag_var in possible_tags:
                # Various markdown and formatting patterns
                patterns = [
                    # Markdown headings
                    r'#{1,3}\s*' + re.escape(tag_var) + r'[:\s]*\n(.*?)(?=\n#|\n\n|\Z)',
                    # Bold with asterisks
                    r'\*\*\s*' + re.escape(tag_var) + r'[:\s]*\*\*\s*\n(.*?)(?=\n\*\*|\n\n|\Z)',
                    # Bold with underscores
                    r'__\s*' + re.escape(tag_var) + r'[:\s]*__\s*\n(.*?)(?=\n__|\n\n|\Z)',
                    # Plain text with colon
                    r'\b' + re.escape(tag_var) + r'[:\s]+(.*?)(?=\n\n|\Z)'
                ]
                
                for pattern in patterns:
                    matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
                    if matches:
                        result[tag].extend([m.strip() for m in matches])
                        break
    
    # Third attempt: Content-specific extraction based on sentence patterns
    for tag in tags:
        if not result[tag]:  # Only if we still didn't find anything
            if tag == "friction" or tag in tag_variations.get("friction", []):
                # Look for sentences that suggest a friction intervention
                question_patterns = [
                    # Questions about strategy
                    r'(?:Have you considered|Did you think about|What if|Could you|Why not).*?\?',
                    # Suggestions phrased as questions
                    r'(?:Wouldn\'t it be|Might it be|Isn\'t there).*?\?',
                ]
                
                # Extract any matching sentences
                for pattern in question_patterns:
                    matches = re.findall(pattern, text, re.IGNORECASE)
                    if matches:
                        # Take the first 2 matching questions as friction
                        combined = " ".join(matches[:2]).strip()
                        if combined:
                            result[tag].append(combined)
                            break
    
    # Fourth attempt: Fall back to content structure analysis
    for tag in tags:
        if not result[tag]:  # If still nothing found
            # Split by double newlines to get paragraphs
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            
            if tag == "friction" and paragraphs:
                if len(paragraphs) >= 3:
                    # Friction is often at the end
                    result[tag].append(paragraphs[-1])
                else:
                    # Just use the last paragraph
                    result[tag].append(paragraphs[-1])
            
            elif tag == "rationale" and paragraphs:
                if len(paragraphs) >= 3:
                    # Rationale is often in the middle
                    result[tag].append(paragraphs[len(paragraphs)//2])
                elif len(paragraphs) == 2:
                    # Use first paragraph
                    result[tag].append(paragraphs[0])
            
            elif tag == "belief_state" and paragraphs:
                # Belief state is often at the beginning
                result[tag].append(paragraphs[0])
    
    # Final cleanup - remove quotes or special formatting from content
    for tag in result:
        result[tag] = [
            re.sub(r'^["\']\s*|\s*["\']$', '', item)  # Remove surrounding quotes
            for item in result[tag]
        ]
    
    return result


def handle_friction_logic(text):
    """
    Extract a meaningful friction intervention from text even when proper tags are missing.
    Improved to handle Diplomacy-specific content patterns.
    
    Args:
        text (str): The input text
        
    Returns:
        str: The extracted friction intervention
    """
    # First check if there's already a well-formatted friction section
    friction_match = re.search(r'<friction>(.*?)(?:</friction>|\Z)', text, re.DOTALL | re.IGNORECASE)
    if friction_match:
        return friction_match.group(1).strip()
    
    # Look for quotes that might be friction interventions
    quote_match = re.search(r'["\'](.*?)["\']', text, re.DOTALL)
    if quote_match and ("?" in quote_match.group(1) or "consider" in quote_match.group(1).lower()):
        return quote_match.group(1).strip()
    
    # Look for questions that might be friction interventions
    questions = re.findall(r'[^.!?]*\?', text)
    strategic_questions = [q for q in questions if any(term in q.lower() for term in 
                                                   ["strategy", "consider", "think", "reflect", 
                                                    "alternative", "option", "might", "could", 
                                                    "why not", "what if", "balance"])]
    if strategic_questions and len(strategic_questions) >= 1:
        return " ".join(strategic_questions[:2]).strip()
    
    # Look for sentences that have Diplomacy-specific strategy terms
    strategy_terms = ["support", "attack", "move", "defend", "ally", "alliance", "secure", 
                      "position", "flank", "threat", "opportunity", "risk", "coordinate"]
    
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text.strip())
    strategic_sentences = [s for s in sentences if any(term in s.lower() for term in strategy_terms)]
    
    if strategic_sentences:
        if len(strategic_sentences) >= 2:
            return f"{strategic_sentences[0]} {strategic_sentences[-1]}"
        else:
            return strategic_sentences[0]
    
    # Fall back to the original logic if nothing better is found
    if sentences:
        if len(sentences) >= 3:
            return f"{sentences[0]} {sentences[-2]} {sentences[-1]}"
        else:
            return " ".join(sentences)
    
    return ""

In [26]:
def generate_multiple_sequences_with_intrinsic_metrics(model, tokenizer, prompts, generation_args, device, 
                                                       strategy="beam_search", batched=False, 
                                                       reward_model=None, best_of_n=None, top_k_candidates=1, rm_tokenizer = None
                                                      , rm_max_length = None):
    """
    Generate multiple sequences using various strategies including best-of-N sampling.
    
    Args:
        model: Language model for generation
        tokenizer: Tokenizer for the model
        prompts: Input prompts
        generation_args: Arguments for generation
        device: Device to place tensors on
        strategy: Generation strategy ("beam_search", "top_k_sampling", "top_p_sampling", or "best_of_n")
        batched: Whether inputs are batched
        reward_model: Reward model for scoring in best-of-N sampling (AutoModelForSequenceClassification)
        best_of_n: Number of samples to generate for best-of-N sampling (default: None)
        top_k_candidates: Number of top candidates to return from best-of-N sampling (default: 1)
        
    Returns:
        generated_texts: List of generated texts
        all_generated_texts: List of all generated texts
    """
    if batched:
        tokenizer.pad_token = "<|reserved_special_token_0|>"  # new pad token for this run
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        tokenizer.padding_side = 'right'

        cleaned_prompts = prompts.replace("\n", " ")  
        inputs = tokenizer(cleaned_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    else:
        tokenizer.pad_token = "<|reserved_special_token_0|>"  # new pad token for this run
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        tokenizer.padding_side = 'right'
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)

    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    
    # Handle best-of-N sampling strategy
    if strategy == "best_of_n":
        if reward_model is None:
            raise ValueError("Reward model must be provided for best-of-N sampling")
        
        if best_of_n is None or best_of_n <= 0:
            best_of_n = 4  # Default sample size
        
        with torch.no_grad():
            # Generate multiple candidates for each prompt
            all_candidates = []
            all_prompt_candidates = []
            
            # Use top_p or top_k sampling to generate diverse candidates
            sampling_strategy = generation_args.get("sampling_strategy", "top_p_sampling")
            # print("BON sampling strategy", sampling_strategy)
            for _ in range(best_of_n):
                if sampling_strategy == "top_p_sampling": 
                    print("RUNNING TopP sampling for BON")
                    outputs = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                        temperature=generation_args.get("temperature", 0.7),
                        top_p=generation_args.get("top_p", 0.9),
                        do_sample=True,
                        num_return_sequences=1,
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.pad_token_id,
                        return_dict_in_generate=True,
                    )
                else:  # Default to top_k_sampling
                    outputs = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                        temperature=generation_args.get("temperature", 0.7),
                        top_k=generation_args.get("top_k", 50),
                        do_sample=True,
                        num_return_sequences=1,
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.pad_token_id,
                        return_dict_in_generate=True,
                    )
                
                # Process the generated sequence
                for i in range(len(outputs.sequences)):
                    sequence = outputs.sequences[i]
                    prompt_length = input_ids.shape[-1]
                    new_tokens = sequence[prompt_length:]
                    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
                    
                    if len(all_candidates) <= i:
                        all_candidates.append([])
                    
                    all_candidates[i].append(generated_text)
            
            # Score candidates with the reward model
            best_candidates = []
            all_prompt_candidates = []
            parsed_candidates = []
            tags_for_parsing = ["friction", "rationale", "t", "b"]  
            for candidate_index, candidates in enumerate(all_candidates):
                print(f"\n=== Processing candidates for prompt {candidate_index} ===")
                print(f"Number of candidates: {len(candidates)}")

                #parse the generated model outputs to get the friction + rationale

                        
                for candidate in candidates:
                    parsed_frictive_states_and_friction = parse_tags_robust(candidate, tags_for_parsing)
                    friction_intervention = ' '.join(parsed_frictive_states_and_friction.get('friction', []))
                    if not friction_intervention:
                        friction_intervention = handle_friction_logic(candidate)

                    rationale = ' '.join(parsed_frictive_states_and_friction.get('rationale', []))
                    friction_and_rationale = rationale + friction_intervention
                    parsed_candidates.append(friction_and_rationale)
                    # print("PARSED friction + rationale",friction_and_rationale )
                # For each candidate, prepare input for reward model
                candidate_inputs = [prompts + " " + f"</s> {candidate} </s>" for candidate in parsed_candidates]
                tokenized_inputs = rm_tokenizer(candidate_inputs, return_tensors="pt", padding=True, truncation=True, max_length=rm_max_length).to(device)
                
                # Get scores from reward model
                reward_outputs = reward_model(**tokenized_inputs) 
                scores = reward_outputs.logits.squeeze(-1)
                
                ## Print all candidates with their scores
                print("\nAll candidates with scores:")
                for i, (candidate, score) in enumerate(zip(candidates, scores)):
                    print(f"Candidate {i}: Score = {score:.4f}")
                    print(f"Text snippet: {candidate[:50]}...")
                
                # Get top-k indices
                if top_k_candidates > len(candidates):
                    top_k_candidates = len(candidates)
                
                # Fix the error by converting bfloat16 to float32 before calling numpy()
                top_result = torch.topk(scores, top_k_candidates)
                top_indices = top_result.indices.cpu().numpy()
                top_values = top_result.values.cpu().float().numpy()  # Convert to float32 first
                
                print(f"\nTop {top_k_candidates} candidates:")
                # Print only the top-k candidates
                for rank, (idx, score) in enumerate(zip(top_indices, top_values)):
                    print(f"Rank {rank+1}: Candidate {idx} with score {score:.4f}")
                    print(f"Text: {candidates[idx][:300]}...")
                
                # Store the chosen candidates for verification
                prompt_best_candidates = [candidates[idx] for idx in top_indices]
                
                # Verification check
                max_score_idx = scores.argmax().item()
                if max_score_idx != top_indices[0]:
                    print(f"WARNING: Discrepancy detected! argmax={max_score_idx} but topk.indices[0]={top_indices[0]}")
                else:
                    print(f"VERIFIED: Top candidate is correctly selected (index {top_indices[0]})")



                
                best_candidates.append(prompt_best_candidates)
                all_prompt_candidates.extend(candidates)
            
            return best_candidates, all_prompt_candidates
    
    # Original strategies
    with torch.no_grad():
        if strategy == "beam_search":
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                num_beams=generation_args["num_beams"],
                num_return_sequences=generation_args["num_return_sequences"],
                early_stopping=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )
        elif strategy == "top_k_sampling":
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                temperature=generation_args["temperature"],
                top_k=generation_args["top_k"],
                do_sample=True,
                num_return_sequences=generation_args["num_return_sequences"],
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                min_length=generation_args.get("min_length", 0),
                return_dict_in_generate=True,
                output_scores=True
            )
        elif strategy == "top_p_sampling":
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                temperature=generation_args["temperature"],
                top_p=generation_args["top_p"],
                do_sample=True,
                num_return_sequences=generation_args["num_return_sequences"],
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                return_dict_in_generate=True,
            )
        else:
            raise ValueError("Unsupported strategy. Use 'beam_search', 'top_k_sampling', 'top_p_sampling', or 'best_of_n'.")

    # Decode the generated tokens for each prompt in the batch
    generated_texts = []
    all_generated_texts = []

    for i in range(0, len(outputs.sequences), generation_args["num_return_sequences"]):
        prompt_texts = []
        prompt_only = []
        for j in range(generation_args["num_return_sequences"]):
            sequence_index = i + j  # Global index for the current sequence
            output = outputs.sequences[sequence_index]
            prompt_length = input_ids.shape[-1]  # Length of the input prompt
            new_tokens = output[prompt_length:]  # Get only the generated tokens
            generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            prompt_tokens = output[:prompt_length]
            prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=True)
    
            prompt_texts.append(generated_text)
            prompt_only.append(prompt_text)

        generated_texts.append(prompt_texts)
        all_generated_texts.extend(prompt_only)
    
    return generated_texts, all_generated_texts


In [2]:
import pickle 
 

In [3]:
output_folder = "diplomacy_prompts/init_generations"

# Load existing results
results = pickle.load(open(f"{output_folder}/combined_results.pkl", "rb"))

In [4]:
def analyze_data_structure(all_data):
    print("==== OVERALL STRUCTURE ====")
    print(f"Type of all_data: {type(all_data)}")
    print(f"Number of top-level items: {len(all_data)}")
    
    # Check the first few items
    print("\n==== SAMPLE OF FIRST FEW ITEMS ====")
    for i, item in enumerate(all_data[:3]):  # Just check first 3 items
        print(f"\nItem {i+1}:")
        print(f"  Type: {type(item)}")
        
        if isinstance(item, dict):
            print("  Dictionary keys:")
            for key in list(item.keys())[:10]:  # Show up to 10 keys
                value = item[key]
                print(f"    {key}: {type(value)}")
                
                # If the value is complex, show a bit more detail
                if isinstance(value, dict):
                    print(f"      Sample keys: {list(value.keys())[:3]}...")
                elif isinstance(value, list):
                    print(f"      Length: {len(value)}")
                    if len(value) > 0:
                        print(f"      First element type: {type(value[0])}")
        
        elif isinstance(item, tuple):
            print(f"  Tuple length: {len(item)}")
            for j, element in enumerate(item):
                print(f"    Element {j+1} type: {type(element)}")
                if isinstance(element, dict):
                    print(f"      Sample keys: {list(element.keys())[:5]}...")
        
        elif isinstance(item, list):
            print(f"  List length: {len(item)}")
            if len(item) > 0:
                print(f"  First element type: {type(item[0])}")
                if isinstance(item[0], dict):
                    print(f"    Sample keys: {list(item[0].keys())[:5]}...")
    
    # Check if the data structure is consistent
    print("\n==== CONSISTENCY CHECK ====")
    types_count = {}
    for item in all_data:
        item_type = type(item).__name__
        if item_type not in types_count:
            types_count[item_type] = 0
        types_count[item_type] += 1
    
    print("Distribution of item types:")
    for item_type, count in types_count.items():
        print(f"  {item_type}: {count} items ({count/len(all_data)*100:.1f}%)")
    
    # If there are dictionaries, check key consistency
    if 'dict' in types_count and types_count['dict'] > 0:
        print("\nChecking dictionary key consistency:")
        all_keys = set()
        key_counts = {}
        
        for item in all_data:
            if isinstance(item, dict):
                for key in item.keys():
                    all_keys.add(key)
                    if key not in key_counts:
                        key_counts[key] = 0
                    key_counts[key] += 1
        
        print(f"  Total unique keys across all dictionaries: {len(all_keys)}")
        print("  Most common keys:")
        for key, count in sorted(key_counts.items(), key=lambda x: x[1], reverse=True)[:10]:
            print(f"    {key}: {count} occurrences ({count/types_count['dict']*100:.1f}% of dicts)")

# Usage
analyze_data_structure(results)

==== OVERALL STRUCTURE ====
Type of all_data: <class 'list'>
Number of top-level items: 6

==== SAMPLE OF FIRST FEW ITEMS ====

Item 1:
  Type: <class 'tuple'>
  Tuple length: 2
    Element 1 type: <class 'dict'>
      Sample keys: ['model_name', 'samples']...
    Element 2 type: <class 'dict'>
      Sample keys: ['file_name', 'file_size', 'file_modified', 'model_type']...

Item 2:
  Type: <class 'tuple'>
  Tuple length: 2
    Element 1 type: <class 'dict'>
      Sample keys: ['model_name', 'samples']...
    Element 2 type: <class 'dict'>
      Sample keys: ['file_name', 'file_size', 'file_modified', 'model_type']...

Item 3:
  Type: <class 'tuple'>
  Tuple length: 2
    Element 1 type: <class 'dict'>
      Sample keys: ['model_name', 'samples']...
    Element 2 type: <class 'dict'>
      Sample keys: ['file_name', 'file_size', 'file_modified', 'model_type']...

==== CONSISTENCY CHECK ====
Distribution of item types:
  tuple: 6 items (100.0%)


In [5]:
def examine_model_samples(all_data):
    print("==== EXAMINING MODEL SAMPLES ====")
    
    for file_idx, file_data in enumerate(all_data[:3]):  # Look at first 3 files
        if isinstance(file_data, tuple) and len(file_data) >= 2:
            # Extract data and metadata
            data, metadata = file_data[0], file_data[1]
            model_type = metadata.get('model_type', 'Unknown')
            file_name = metadata.get('file_name', 'Unknown')
        else:
            # Handle case where the structure is different
            data = file_data
            model_type = file_data.get('_model_type', 'Unknown') if isinstance(file_data, dict) else 'Unknown'
            file_name = file_data.get('_file_name', 'Unknown') if isinstance(file_data, dict) else 'Unknown'
        
        print(f"\n=== FILE {file_idx+1}: {file_name} ({model_type}) ===")
        
        # Try to access the samples based on different possible structures
        samples = None
        if isinstance(data, dict):
            if 'samples' in data:
                samples = data['samples']
            elif 'results' in data and isinstance(data['results'], list):
                samples = data['results']
        
        if not samples and isinstance(data, list):
            # If the data itself is a list of samples
            samples = data
        
        if not samples:
            print("Could not find samples in this data structure")
            continue
        
        print(f"Found {len(samples)} samples")
        
        # Look at the first sample
        if len(samples) > 0:
            sample = samples[0]
            print("\nExamining first sample:")
            
            # Try to access prompt_results
            prompt_results = None
            if isinstance(sample, dict):
                if 'prompt_results' in sample:
                    prompt_results = sample['prompt_results']
                elif 'results' in sample:
                    prompt_results = sample['results']
            
            if not prompt_results and isinstance(sample, list):
                prompt_results = sample
            
            if not prompt_results:
                print("Could not find prompt_results in this sample")
                continue
            
            print(f"Found {len(prompt_results)} prompt results")
            
            # Look at the first prompt result
            if len(prompt_results) > 0:
                result = prompt_results[0]
                print("\nExamining first prompt result:")
                
                # Try to access parsed_components
                parsed_components = None
                if isinstance(result, dict):
                    if 'parsed_components' in result:
                        parsed_components = result['parsed_components']
                    elif 'components' in result:
                        parsed_components = result['components']
                
                if not parsed_components and isinstance(result, dict):
                    # Try to check if friction/rationale/belief are directly in the result
                    if any(k in result for k in ['friction', 'rationale', 'belief_state']):
                        parsed_components = result
                
                if not parsed_components:
                    print("Could not find parsed_components in this result")
                    # Show the available keys in the result
                    if isinstance(result, dict):
                        print(f"Available keys: {list(result.keys())}")
                    continue
                
                # Now examine the components we're interested in
                for component in ['friction', 'rationale', 'belief_state']:
                    if component in parsed_components:
                        value = parsed_components[component]
                        print(f"\n{component.upper()}:")
                        if isinstance(value, str):
                            print(f"First 100 chars: {value[:100]}...")
                        else:
                            print(f"Type: {type(value)}")
                
                # Also show the generated text if available
                if isinstance(result, dict) and 'generated_text' in result:
                    print("\nGENERATED_TEXT:")
                    print(f"First 100 chars: {result['generated_text'][:100]}...")

# Usage
examine_model_samples(results)

==== EXAMINING MODEL SAMPLES ====

=== FILE 1: results_sft_deli_multiturn_rogue_cleaned_checkpoint-3000.pkl (SFT_multiturn) ===
Found 500 samples

Examining first sample:
Found 3 prompt results

Examining first prompt result:

FRICTION:
First 100 chars: Friction Intervention: "Consider the potential risks of building up in the Balkans while neglecting ...

RATIONALE:
First 100 chars: The player's recommended orders show a focus on supporting AUSTRIA's military build-up in the Balkan...

BELIEF_STATE:
First 100 chars: The player's strategy appears to focus on building up their forces in the Balkans without considerin...

GENERATED_TEXT:
First 100 chars:  <rationale> The player's recommended orders show a focus on supporting AUSTRIA's military build-up ...

=== FILE 2: results_DELI_faaf_weights_checkpoint-2000.pkl (FAAF) ===
Found 500 samples

Examining first sample:
Found 3 prompt results

Examining first prompt result:

FRICTION:
First 100 chars: Your strategy seems to be heavily focus

In [6]:
results[0][0]['samples'][0]['prompt_results'][0].keys()
dict_keys(['prompt_level', 'raw_prompt', 'generated_text', 'parsed_components'])
results[0][0]['samples'][0]['prompt_results'][0]['parsed_components']

{'friction': 'Friction Intervention: "Consider the potential risks of building up in the Balkans while neglecting other parts of the board. How can you balance your strategic focus to ensure a well-rounded approach that also addresses other crucial areas of the board?',
 'rationale': "The player's recommended orders show a focus on supporting AUSTRIA's military build-up in the Balkans, which is a crucial step in their overall strategy. However, there might be an oversight in considering how AUSTRIA's moves affect the overall board dynamics and potential opportunities or risks for other players.",
 'belief_state': "The player's strategy appears to focus on building up their forces in the Balkans without considering potential threats or opportunities on other fronts. This might lead to an imbalance in their strategy."}


NameError: name 'dict_keys' is not defined

In [22]:
results[0][0]['samples'][0]['prompt_results'][0]['parsed_components']

{'friction': 'Friction Intervention: "Consider the potential risks of building up in the Balkans while neglecting other parts of the board. How can you balance your strategic focus to ensure a well-rounded approach that also addresses other crucial areas of the board?',
 'rationale': "The player's recommended orders show a focus on supporting AUSTRIA's military build-up in the Balkans, which is a crucial step in their overall strategy. However, there might be an oversight in considering how AUSTRIA's moves affect the overall board dynamics and potential opportunities or risks for other players.",
 'belief_state': "The player's strategy appears to focus on building up their forces in the Balkans without considering potential threats or opportunities on other fronts. This might lead to an imbalance in their strategy."}

{'friction': 'Friction Intervention: "Consider the potential risks of building up in the Balkans while neglecting other parts of the board. How can you balance your strategic focus to ensure a well-rounded approach that also addresses other crucial areas of the board?',
 'rationale': "The player's recommended orders show a focus on supporting AUSTRIA's military build-up in the Balkans, which is a crucial step in their overall strategy. However, there might be an oversight in considering how AUSTRIA's moves affect the overall board dynamics and potential opportunities or risks for other players.",
 'belief_state': "The player's strategy appears to focus on building up their forces in the Balkans without considering potential threats or opportunities on other fronts. This might lead to an imbalance in their strategy."}

In [8]:
import pickle
import os
import pandas as pd

# List of files to load
files = [
    "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/results_sft_deli_multiturn_rogue_cleaned_checkpoint-3000.pkl",
    "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/results_DELI_faaf_weights_checkpoint-2000.pkl",
    "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/results_DELI_dpo_weights_checkpoint-3500.pkl",
    "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/results_DELI_ipo_weights_checkpoint-4500.pkl",
    "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/results_DELI_sft_weights_checkpoint-small-1500.pkl",
    "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/results_DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800.pkl"
]

# Dictionary to store the data from each file along with metadata
all_data = []

# Load each file and store the data
for file_path in files:
    try:
        # Get file metadata
        file_name = os.path.basename(file_path)
        file_size = os.path.getsize(file_path)
        file_modified = os.path.getmtime(file_path)
        
        # Load the pickle file
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
        
        # Extract the model type from the filename
        if "sft_deli_multiturn" in file_name:
            model_type = "SFT_multiturn"
        elif "faaf_weights" in file_name:
            model_type = "FAAF"
        elif "dpo_weights" in file_name:
            model_type = "DPO"
        elif "ipo_weights" in file_name:
            model_type = "IPO"
        elif "sft_weights" in file_name:
            model_type = "SFT"
        elif "ppo_weights" in file_name:
            model_type = "PPO"
        else:
            model_type = "Unknown"
        
        # Add metadata to the data
        if isinstance(data, dict):
            # If data is already a dictionary, add metadata
            metadata = {
                'file_name': file_name,
                'file_size': file_size,
                'file_modified': file_modified,
                'model_type': model_type
            }
            # Store as tuple (data, metadata)
            all_data.append((data, metadata))
        elif isinstance(data, list):
            # If data is a list, convert to list of dictionaries with metadata
            for item in data:
                if isinstance(item, dict):
                    item_with_meta = item.copy()
                    item_with_meta['_file_name'] = file_name
                    item_with_meta['_file_size'] = file_size
                    item_with_meta['_file_modified'] = file_modified
                    item_with_meta['_model_type'] = model_type
                    all_data.append(item_with_meta)
                else:
                    # If list items aren't dictionaries, store with metadata separately
                    all_data.append({
                        'data': item,
                        '_file_name': file_name,
                        '_file_size': file_size,
                        '_file_modified': file_modified,
                        '_model_type': model_type
                    })
        else:
            # For other data types, create a dictionary with the data and metadata
            all_data.append({
                'data': data,
                '_file_name': file_name,
                '_file_size': file_size,
                '_file_modified': file_modified,
                '_model_type': model_type
            })
            
        print(f"Successfully loaded: {file_name} ({model_type})")
        
    except Exception as e:
        print(f"Error loading {file_path}: {e}")

# Save the combined data
output_path = "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/combined_results.pkl"
with open(output_path, 'wb') as f:
    pickle.dump(all_data, f)

print(f"\nCombined data saved to: {output_path}")
print(f"Total items in combined file: {len(all_data)}")

# Create a summary of the data
print("\nSummary of combined data:")
model_types = {}
for item in all_data:
    if isinstance(item, dict) and '_model_type' in item:
        model_type = item['_model_type']
    elif isinstance(item, tuple) and len(item) > 1 and 'model_type' in item[1]:
        model_type = item[1]['model_type']
    else:
        model_type = "Unknown"
        
    if model_type not in model_types:
        model_types[model_type] = 0
    model_types[model_type] += 1

for model_type, count in model_types.items():
    print(f"  {model_type}: {count} items")

Successfully loaded: results_sft_deli_multiturn_rogue_cleaned_checkpoint-3000.pkl (SFT_multiturn)
Successfully loaded: results_DELI_faaf_weights_checkpoint-2000.pkl (FAAF)
Successfully loaded: results_DELI_dpo_weights_checkpoint-3500.pkl (DPO)
Successfully loaded: results_DELI_ipo_weights_checkpoint-4500.pkl (IPO)
Successfully loaded: results_DELI_sft_weights_checkpoint-small-1500.pkl (SFT)
Successfully loaded: results_DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800.pkl (PPO)

Combined data saved to: /s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations/combined_results.pkl
Total items in combined file: 6

Summary of combined data:
  SFT_multiturn: 1 items
  FAAF: 1 items
  DPO: 1 items
  IPO: 1 items
  SFT: 1 items
  PPO: 1 items


In [27]:
print(results[0][0]['samples'][0]['prompt_results'][0])

{'prompt_level': 'Prompt_level1', 'raw_prompt': "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:\n        1. The current board state.\n        2. The recommended orders for the player.\n\n        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:\n\n        <belief_state>\n        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state.\n        </belief_state>\n\n        <rationale>\n        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.\n        </rationale>\n\n        <friction>\n        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly a

In [7]:
import random

def print_random_samples(all_data, num_samples=5):
    print("==== RANDOM SAMPLE RESULTS ====\n")
    
    # Collect all samples from all files
    all_samples = []
    
    for file_data in all_data:
        # Extract the data depending on structure
        if isinstance(file_data, tuple) and len(file_data) >= 2:
            data, metadata = file_data[0], file_data[1]
            model_type = metadata.get('model_type', 'Unknown')
        else:
            data = file_data
            model_type = file_data.get('_model_type', 'Unknown') if isinstance(file_data, dict) else 'Unknown'
        
        # Try to access the samples based on different possible structures
        samples = None
        if isinstance(data, dict):
            if 'samples' in data:
                samples = data['samples']
            elif 'results' in data and isinstance(data['results'], list):
                samples = data['results']
        
        if not samples and isinstance(data, list):
            # If the data itself is a list of samples
            samples = data
        
        if samples:
            # Add metadata to each sample
            for sample in samples:
                if isinstance(sample, dict):
                    sample_with_meta = {'sample': sample, 'model_type': model_type}
                    all_samples.append(sample_with_meta)
    
    print(f"Total samples collected: {len(all_samples)}")
    
    # Select random samples
    if len(all_samples) > 0:
        random_samples = random.sample(all_samples, min(num_samples, len(all_samples)))
        
        for i, sample_data in enumerate(random_samples):
            print(f"\n{'='*80}")
            print(f"SAMPLE {i+1} (Model: {sample_data['model_type']})")
            print(f"{'='*80}\n")
            
            sample = sample_data['sample']
            
            # Get prompt results
            prompt_results = None
            if isinstance(sample, dict):
                if 'prompt_results' in sample:
                    prompt_results = sample['prompt_results']
                elif 'results' in sample:
                    prompt_results = sample['results']
            
            if not prompt_results and isinstance(sample, list):
                prompt_results = sample
            
            if not prompt_results:
                print("Could not find prompt results in this sample")
                continue
            
            # Select a random prompt result
            if len(prompt_results) > 0:
                result = random.choice(prompt_results)
                
                # Print prompt level if available
                if isinstance(result, dict) and 'prompt_level' in result:
                    print(f"PROMPT LEVEL: {result['prompt_level']}")
                    print("-" * 40)
                
                # Print raw prompt if available
                if isinstance(result, dict) and 'raw_prompt' in result:
                    print("RAW PROMPT:")
                    print("-" * 40)
                    # Split by newlines to make it more readable
                    prompt_lines = result['raw_prompt'].split('\n')
                    for line in prompt_lines:
                        print(line)
                    print("-" * 40)
                
                # Get parsed components
                parsed_components = None
                if isinstance(result, dict):
                    if 'parsed_components' in result:
                        parsed_components = result['parsed_components']
                    elif 'components' in result:
                        parsed_components = result['components']
                
                if not parsed_components and isinstance(result, dict):
                    # Try to check if friction/rationale/belief are directly in the result
                    if any(k in result for k in ['friction', 'rationale', 'belief_state']):
                        parsed_components = result
                
                # Print components
                if parsed_components:
                    for component in ['friction', 'rationale', 'belief_state']:
                        if component in parsed_components:
                            print(f"\n{component.upper()}:")
                            print("-" * 40)
                            value = parsed_components[component]
                            if isinstance(value, str):
                                # Split by newlines to make it more readable
                                component_lines = value.split('\n')
                                for line in component_lines:
                                    print(line)
                            else:
                                print(f"(Not a string: {type(value)})")
                            print("-" * 40)
                
                # Print generated text
                if isinstance(result, dict) and 'generated_text' in result:
                    print("\nGENERATED TEXT:")
                    print("-" * 40)
                    # Split by newlines to make it more readable
                    text_lines = result['generated_text'].split('\n')
                    for line in text_lines:
                        print(line)
                    print("-" * 40)
    else:
        print("No samples found in the data")

# Usage
print_random_samples(results, num_samples=5)

==== RANDOM SAMPLE RESULTS ====

Total samples collected: 3000

SAMPLE 1 (Model: PPO)

PROMPT LEVEL: Prompt_level2
----------------------------------------
RAW PROMPT:
----------------------------------------
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.
        3. The potential orders for every power.
        

        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternat

## RM Experiments for CVPR friction


In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pandas as pd
import json
import os
import time
import pickle
 
import gc
 
from tqdm import tqdm
import logging
from datetime import datetime
from datasets import load_dataset, load_from_disk
from peft import AutoPeftModelForCausalLM, LoraConfig,PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    set_seed,
)
# some experiments for the Llama model on safety
import os
import sys
import socket
import re
import random
import numpy as np
import torch
import pickle
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Optional
from dataclasses import dataclass, asdict

import re

def parse_tags_robust(text, tags=None):
    """
    Enhanced parsing function that handles XML-style tags with improved robustness 
    for Diplomacy-specific content.
    
    Args:
        text (str): The text to parse.
        tags (list): List of tags to parse. If None, defaults to standard tags.
        
    Returns:
        dict: Dictionary with tag names as keys and lists of extracted content as values.
    """
    if tags is None:
        tags = ["friction", "rationale", "belief_state"]
    
    # Initialize result
    result = {tag: [] for tag in tags}
    
    # Tag mapping for variations
    tag_variations = {
        "belief_state": ["belief_state", "belief state", "beliefstate", "belief-state", "b"],
        "rationale": ["rationale", "rational", "reasoning", "reason", "r"],
        "friction": ["friction", "intervention", "advice", "f"]
    }
    
    # Reverse mapping for normalization
    tag_normalization = {}
    for normalized, variations in tag_variations.items():
        for variation in variations:
            tag_normalization[variation.lower()] = normalized
    
    # First attempt: Try exact XML-style tags
    for tag in tags:
        # Try variations of the tag name
        possible_tags = [tag]
        if tag in tag_variations:
            possible_tags.extend(tag_variations[tag])
        
        found = False
        for tag_var in possible_tags:
            # Try with and without underscores/hyphens
            tag_patterns = [
                tag_var,
                tag_var.replace("_", " "),
                tag_var.replace("-", " "),
                tag_var.replace(" ", "_"),
                tag_var.replace(" ", "-")
            ]
            
            for t_pattern in tag_patterns:
                # Match both <tag> and <tag></tag> patterns
                open_tag_pattern = f"<{t_pattern}>"
                close_tag_pattern = f"</{t_pattern}>"
                open_only_pattern = re.escape(open_tag_pattern) + r"(.*?)(?=" + re.escape(close_tag_pattern) + r"|\Z)"
                
                matches = re.findall(open_only_pattern, text, re.DOTALL | re.IGNORECASE)
                if matches:
                    result[tag].extend([m.strip() for m in matches])
                    found = True
                    break
            
            if found:
                break
    
    # Second attempt: Try markdown-style headings and formatting
    for tag in tags:
        if not result[tag]:  # Only if we didn't find anything in first pass
            possible_tags = [tag]
            if tag in tag_variations:
                possible_tags.extend(tag_variations[tag])
            
            for tag_var in possible_tags:
                # Various markdown and formatting patterns
                patterns = [
                    # Markdown headings
                    r'#{1,3}\s*' + re.escape(tag_var) + r'[:\s]*\n(.*?)(?=\n#|\n\n|\Z)',
                    # Bold with asterisks
                    r'\*\*\s*' + re.escape(tag_var) + r'[:\s]*\*\*\s*\n(.*?)(?=\n\*\*|\n\n|\Z)',
                    # Bold with underscores
                    r'__\s*' + re.escape(tag_var) + r'[:\s]*__\s*\n(.*?)(?=\n__|\n\n|\Z)',
                    # Plain text with colon
                    r'\b' + re.escape(tag_var) + r'[:\s]+(.*?)(?=\n\n|\Z)'
                ]
                
                for pattern in patterns:
                    matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
                    if matches:
                        result[tag].extend([m.strip() for m in matches])
                        break
    
    # Third attempt: Content-specific extraction based on sentence patterns
    for tag in tags:
        if not result[tag]:  # Only if we still didn't find anything
            if tag == "friction" or tag in tag_variations.get("friction", []):
                # Look for sentences that suggest a friction intervention
                question_patterns = [
                    # Questions about strategy
                    r'(?:Have you considered|Did you think about|What if|Could you|Why not).*?\?',
                    # Suggestions phrased as questions
                    r'(?:Wouldn\'t it be|Might it be|Isn\'t there).*?\?',
                ]
                
                # Extract any matching sentences
                for pattern in question_patterns:
                    matches = re.findall(pattern, text, re.IGNORECASE)
                    if matches:
                        # Take the first 2 matching questions as friction
                        combined = " ".join(matches[:2]).strip()
                        if combined:
                            result[tag].append(combined)
                            break
    
    # Fourth attempt: Fall back to content structure analysis
    for tag in tags:
        if not result[tag]:  # If still nothing found
            # Split by double newlines to get paragraphs
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            
            if tag == "friction" and paragraphs:
                if len(paragraphs) >= 3:
                    # Friction is often at the end
                    result[tag].append(paragraphs[-1])
                else:
                    # Just use the last paragraph
                    result[tag].append(paragraphs[-1])
            
            elif tag == "rationale" and paragraphs:
                if len(paragraphs) >= 3:
                    # Rationale is often in the middle
                    result[tag].append(paragraphs[len(paragraphs)//2])
                elif len(paragraphs) == 2:
                    # Use first paragraph
                    result[tag].append(paragraphs[0])
            
            elif tag == "belief_state" and paragraphs:
                # Belief state is often at the beginning
                result[tag].append(paragraphs[0])
    
    # Final cleanup - remove quotes or special formatting from content
    for tag in result:
        result[tag] = [
            re.sub(r'^["\']\s*|\s*["\']$', '', item)  # Remove surrounding quotes
            for item in result[tag]
        ]
    
    return result


def handle_friction_logic(text):
    """
    Extract a meaningful friction intervention from text even when proper tags are missing.
    Improved to handle Diplomacy-specific content patterns.
    
    Args:
        text (str): The input text
        
    Returns:
        str: The extracted friction intervention
    """
    # First check if there's already a well-formatted friction section
    friction_match = re.search(r'<friction>(.*?)(?:</friction>|\Z)', text, re.DOTALL | re.IGNORECASE)
    if friction_match:
        return friction_match.group(1).strip()
    
    # Look for quotes that might be friction interventions
    quote_match = re.search(r'["\'](.*?)["\']', text, re.DOTALL)
    if quote_match and ("?" in quote_match.group(1) or "consider" in quote_match.group(1).lower()):
        return quote_match.group(1).strip()
    
    # Look for questions that might be friction interventions
    questions = re.findall(r'[^.!?]*\?', text)
    strategic_questions = [q for q in questions if any(term in q.lower() for term in 
                                                   ["strategy", "consider", "think", "reflect", 
                                                    "alternative", "option", "might", "could", 
                                                    "why not", "what if", "balance"])]
    if strategic_questions and len(strategic_questions) >= 1:
        return " ".join(strategic_questions[:2]).strip()
    
    # Look for sentences that have Diplomacy-specific strategy terms
    strategy_terms = ["support", "attack", "move", "defend", "ally", "alliance", "secure", 
                      "position", "flank", "threat", "opportunity", "risk", "coordinate"]
    
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s', text.strip())
    strategic_sentences = [s for s in sentences if any(term in s.lower() for term in strategy_terms)]
    
    if strategic_sentences:
        if len(strategic_sentences) >= 2:
            return f"{strategic_sentences[0]} {strategic_sentences[-1]}"
        else:
            return strategic_sentences[0]
    
    # Fall back to the original logic if nothing better is found
    if sentences:
        if len(sentences) >= 3:
            return f"{sentences[0]} {sentences[-2]} {sentences[-1]}"
        else:
            return " ".join(sentences)
    
    return ""


def process_diplomacy_for_friction(test_data):

    processed_messages = []  # Create a list to store all processed messages
    count = 0
    for message in test_data:
        count = count+1
        sender = message['sender']
        recipient = message['recipient']

        if message['recipient'] == 'ALL' or message['recipient'] == message['sender'] or message["phase"].endswith('A') or message["phase"].endswith('R'):
            continue
        else:
            #board states format
            formatted_board = ""
            for country, units in message['units'].items():
                formatted_board += f"{country}: {', '.join(units)}\n"
            formatted_board = formatted_board.strip()
            #predicted orders for opponent
            opponent_orders = ""
            for country, units in message['predicted_orders'][message['recipient']].items():
                opponent_orders += f"{country}: {', '.join(units)}\n"
            formatted_opponent_orders = opponent_orders.strip()
            #message history
            message_history = ""
            start_index = max(0, len(message['prev_5_message']) - 5)
            for msg_info in message['prev_5_message'][start_index:]:
                sender = msg_info['sender']
                text = msg_info['message']
                message_history += f"Message from {sender}: {text}\n"
            message_history += f"Message from {message['sender']}: {message['message']}"
            #recommended orders for myself.
            recommended_orders = message['predicted_orders'][message['sender']][message['recipient']]
            formatted_recommended_orders = ",\n".join(recommended_orders)

                # Modify the prompts for friction interventions
            Prompt_level1 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.

        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly address the strategic concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders:**  
        {formatted_recommended_orders}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """

            Prompt_level2 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.
        3. The potential orders for every power.
        

        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly address the strategic concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders for {message['recipient']}:**  
        {formatted_recommended_orders}

        **Potential Orders for other powers:**
        {formatted_opponent_orders}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """

            Prompt_level3 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.
        3. The potential orders for every power.
        4. The message history between the player and other players.

        Your goal is to identify potential friction points in the player's strategy and diplomatic approach, then provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or misalignments in the recommended orders compared to the board state, other powers' likely moves, and diplomatic communications.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned in strategy or diplomacy, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy or diplomatic approach. The first sentence should directly address the concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders for {message['recipient']}:**  
        {formatted_recommended_orders}

        **Potential Orders for other powers:**
        {formatted_opponent_orders}

        **Message History:**
        {message_history}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """
            
            processed_message = message.copy()
         # Add the prompts to the copied message
            processed_message['Prompt_level1'] = Prompt_level1
            processed_message['Prompt_level2'] = Prompt_level2
            processed_message['Prompt_level3'] = Prompt_level3
             # Add this processed message to our list
            processed_messages.append(processed_message)

    return processed_messages

def generate_multiple_sequences_with_intrinsic_metrics(model, tokenizer, prompts, generation_args, device, 
                                                       strategy="beam_search", batched=False, 
                                                       reward_model=None, best_of_n=None, top_k_candidates=1, rm_tokenizer = None
                                                      , rm_max_length = None):
    """
    Generate multiple sequences using various strategies including best-of-N sampling.
    
    Args:
        model: Language model for generation
        tokenizer: Tokenizer for the model
        prompts: Input prompts
        generation_args: Arguments for generation
        device: Device to place tensors on
        strategy: Generation strategy ("beam_search", "top_k_sampling", "top_p_sampling", or "best_of_n")
        batched: Whether inputs are batched
        reward_model: Reward model for scoring in best-of-N sampling (AutoModelForSequenceClassification)
        best_of_n: Number of samples to generate for best-of-N sampling (default: None)
        top_k_candidates: Number of top candidates to return from best-of-N sampling (default: 1)
        
    Returns:
        generated_texts: List of generated texts
        all_generated_texts: List of all generated texts
    """
    if batched:
        tokenizer.pad_token = "<|reserved_special_token_0|>"  # new pad token for this run
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        tokenizer.padding_side = 'right'

        cleaned_prompts = prompts.replace("\n", " ")  
        inputs = tokenizer(cleaned_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    else:
        tokenizer.pad_token = "<|reserved_special_token_0|>"  # new pad token for this run
        tokenizer.pad_token_id = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
        tokenizer.padding_side = 'right'
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)

    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    
    # Handle best-of-N sampling strategy
    if strategy == "best_of_n":
        if reward_model is None:
            raise ValueError("Reward model must be provided for best-of-N sampling")
        
        if best_of_n is None or best_of_n <= 0:
            best_of_n = 4  # Default sample size
        
        with torch.no_grad():
            # Generate multiple candidates for each prompt
            all_candidates = []
            all_prompt_candidates = []
            
            # Use top_p or top_k sampling to generate diverse candidates
            sampling_strategy = generation_args.get("sampling_strategy", "top_p_sampling")
            # print("BON sampling strategy", sampling_strategy)
            for _ in range(best_of_n):
                if sampling_strategy == "top_p_sampling": 
                    print("RUNNING TopP sampling for BON")
                    outputs = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                        temperature=generation_args.get("temperature", 0.7),
                        top_p=generation_args.get("top_p", 0.9),
                        do_sample=True,
                        num_return_sequences=1,
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.pad_token_id,
                        return_dict_in_generate=True,
                    )
                else:  # Default to top_k_sampling
                    outputs = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                        temperature=generation_args.get("temperature", 0.7),
                        top_k=generation_args.get("top_k", 50),
                        do_sample=True,
                        num_return_sequences=1,
                        eos_token_id=tokenizer.eos_token_id,
                        pad_token_id=tokenizer.pad_token_id,
                        return_dict_in_generate=True,
                    )
                
                # Process the generated sequence
                for i in range(len(outputs.sequences)):
                    sequence = outputs.sequences[i]
                    prompt_length = input_ids.shape[-1]
                    new_tokens = sequence[prompt_length:]
                    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
                    
                    if len(all_candidates) <= i:
                        all_candidates.append([])
                    
                    all_candidates[i].append(generated_text)
            
            # Score candidates with the reward model
            best_candidates = []
            all_prompt_candidates = []
            parsed_candidates = []
            tags_for_parsing = ["friction", "rationale", "t", "b"]  
            for candidate_index, candidates in enumerate(all_candidates):
                print(f"\n=== Processing candidates for prompt {candidate_index} ===")
                print(f"Number of candidates: {len(candidates)}")

                #parse the generated model outputs to get the friction + rationale

                        
                for candidate in candidates:
                    parsed_frictive_states_and_friction = parse_tags_robust(candidate, tags_for_parsing)
                    friction_intervention = ' '.join(parsed_frictive_states_and_friction.get('friction', []))
                    if not friction_intervention:
                        friction_intervention = handle_friction_logic(candidate)

                    rationale = ' '.join(parsed_frictive_states_and_friction.get('rationale', []))
                    friction_and_rationale = rationale + friction_intervention
                    parsed_candidates.append(friction_and_rationale)
                    # print("PARSED friction + rationale",friction_and_rationale )
                # For each candidate, prepare input for reward model
                candidate_inputs = [prompts + " " + f"</s> {candidate} </s>" for candidate in parsed_candidates]
                tokenized_inputs = rm_tokenizer(candidate_inputs, return_tensors="pt", padding=True, truncation=True, max_length=rm_max_length).to(device)
                
                # Get scores from reward model
                reward_outputs = reward_model(**tokenized_inputs) 
                scores = reward_outputs.logits.squeeze(-1)
                
                ## Print all candidates with their scores
                print("\nAll candidates with scores:")
                for i, (candidate, score) in enumerate(zip(candidates, scores)):
                    print(f"Candidate {i}: Score = {score:.4f}")
                    print(f"Text snippet: {candidate[:50]}...")
                
                # Get top-k indices
                if top_k_candidates > len(candidates):
                    top_k_candidates = len(candidates)
                
                # Fix the error by converting bfloat16 to float32 before calling numpy()
                top_result = torch.topk(scores, top_k_candidates)
                top_indices = top_result.indices.cpu().numpy()
                top_values = top_result.values.cpu().float().numpy()  # Convert to float32 first
                
                print(f"\nTop {top_k_candidates} candidates:")
                # Print only the top-k candidates
                for rank, (idx, score) in enumerate(zip(top_indices, top_values)):
                    print(f"Rank {rank+1}: Candidate {idx} with score {score:.4f}")
                    print(f"Text: {candidates[idx][:300]}...")
                
                # Store the chosen candidates for verification
                prompt_best_candidates = [candidates[idx] for idx in top_indices]
                
                # Verification check
                max_score_idx = scores.argmax().item()
                if max_score_idx != top_indices[0]:
                    print(f"WARNING: Discrepancy detected! argmax={max_score_idx} but topk.indices[0]={top_indices[0]}")
                else:
                    print(f"VERIFIED: Top candidate is correctly selected (index {top_indices[0]})")



                
                best_candidates.append(prompt_best_candidates)
                all_prompt_candidates.extend(candidates)
            
            return best_candidates, all_prompt_candidates
    
    # Original strategies
    with torch.no_grad():
        if strategy == "beam_search":
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                num_beams=generation_args["num_beams"],
                num_return_sequences=generation_args["num_return_sequences"],
                early_stopping=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                return_dict_in_generate=True,
                output_scores=True
            )
        elif strategy == "top_k_sampling":
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                temperature=generation_args["temperature"],
                top_k=generation_args["top_k"],
                do_sample=True,
                num_return_sequences=generation_args["num_return_sequences"],
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                min_length=generation_args.get("min_length", 0),
                return_dict_in_generate=True,
                output_scores=True
            )
        elif strategy == "top_p_sampling":
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=input_ids.size(1) + generation_args["max_new_tokens"],
                temperature=generation_args["temperature"],
                top_p=generation_args["top_p"],
                do_sample=True,
                num_return_sequences=generation_args["num_return_sequences"],
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                return_dict_in_generate=True,
            )
        else:
            raise ValueError("Unsupported strategy. Use 'beam_search', 'top_k_sampling', 'top_p_sampling', or 'best_of_n'.")

    # Decode the generated tokens for each prompt in the batch
    generated_texts = []
    all_generated_texts = []

    for i in range(0, len(outputs.sequences), generation_args["num_return_sequences"]):
        prompt_texts = []
        prompt_only = []
        for j in range(generation_args["num_return_sequences"]):
            sequence_index = i + j  # Global index for the current sequence
            output = outputs.sequences[sequence_index]
            prompt_length = input_ids.shape[-1]  # Length of the input prompt
            new_tokens = output[prompt_length:]  # Get only the generated tokens
            generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
            prompt_tokens = output[:prompt_length]
            prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=True)
    
            prompt_texts.append(generated_text)
            prompt_only.append(prompt_text)

        generated_texts.append(prompt_texts)
        all_generated_texts.extend(prompt_only)
    
    return generated_texts, all_generated_texts




def process_models_with_formatted_prompts(models_list, test_data, generation_args, output_dir = None):
    """
    Process formatted Diplomacy prompts with multiple models.
    
    Args:
        models_list: List of model paths to process
        generation_args: Arguments for text generation
    """
    os.makedirs(output_dir, exist_ok=True)
    # First, create the formatted prompts
    processed_data = process_diplomacy_for_friction(test_data)
    print("length of processed prompts", len(processed_data))
    
    results_by_model = {}
    
    for agent_model_name in tqdm(models_list, desc="Processing Models"):
        # Format model name for results
        loading_model_name = agent_model_name
        if "/" in agent_model_name:
            parts = agent_model_name.split("/")
            if len(parts) >= 3:
                agent_model_name = parts[1] + "_" + parts[2]
            elif len(parts) == 2:
                agent_model_name = parts[0] + "_" + parts[1]
        
        print(f"\n===== Processing Model: {agent_model_name} =====\n")
        
        # Load model and tokenizer
        try:
            # Load base model
            lora_model = AutoModelForCausalLM.from_pretrained(
                "llama3_8b_instruct",
                device_map="auto",
                low_cpu_mem_usage=True,
                torch_dtype=torch.float16,
                trust_remote_code=True,
            )
            
            # Apply LoRA adapter
            lora_model = PeftModel.from_pretrained(
                lora_model,
                loading_model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
            
            # Merge the model
            print("Merging LoRA adapter...")
            merged_model = lora_model.merge_and_unload()
            
            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(loading_model_name)
            tokenizer.pad_token = "<|reserved_special_token_0|>"
            tokenizer.padding_side = "right"
            
            # Initialize results for this model
            model_results = {
                "model_name": agent_model_name,
                "samples": []
            }
            
             # Process each sample
            for sample_idx, sample in enumerate(processed_data):
                sample_results = {
                    "sample_id": sample_idx,
                    "message_id": sample.get('message_id', sample_idx),
                    "sender": sample.get('sender', ''),
                    "recipient": sample.get('recipient', ''),
                    "prompt_results": []
                }
                # Process each prompt level
                
                for level_key in ['Prompt_level1', 'Prompt_level2', 'Prompt_level3']:
                    if level_key in sample:
                        print(f"Processing {level_key} for sample {sample_idx}")
                        
                        # Generate responses
                        generated_texts, all_generated_texts = generate_multiple_sequences_with_intrinsic_metrics(
                            merged_model, 
                            tokenizer, 
                            sample[level_key], 
                            generation_args, 
                            None,
                            strategy="top_p_sampling", 
                            batched=True, 
                        )
                        
                        # Process the generated text
                        if generated_texts and isinstance(generated_texts, list):
                            text_to_parse = generated_texts[0][0] if (generated_texts[0] and isinstance(generated_texts[0], list)) else generated_texts[0]
                            
                            # Parse components
                            tags_for_parsing = ["friction", "rationale", "belief_state"]
                            parsed_components = parse_tags_robust(text_to_parse, tags_for_parsing)
                            
                            # Extract components
                            friction_intervention = ' '.join(parsed_components.get('friction', []))
                            if not friction_intervention:
                                friction_intervention = handle_friction_logic(text_to_parse)
                            
                            belief_state = ' '.join(parsed_components.get('belief_state', []))
                            rationale = ' '.join(parsed_components.get('rationale', []))
                            
                            # Save results for this prompt level
                            prompt_result = {
                                "prompt_level": level_key,
                                "raw_prompt": sample[level_key],
                                "generated_text": text_to_parse,
                                "parsed_components": {
                                    "friction": friction_intervention,
                                    "rationale": rationale,
                                    "belief_state": belief_state
                                }
                            }
                            
                            sample_results["prompt_results"].append(prompt_result)
                
                # Add this sample's results to the model results
                model_results["samples"].append(sample_results)
      
            
            # Save results for this model to a pickle file
            safe_model_name = agent_model_name.replace('/', '_')
            with open(f"{output_dir}/results_{safe_model_name}.pkl", 'wb') as f:
                pickle.dump(model_results, f)
            
            # Also store in our overall results dictionary
            results_by_model[agent_model_name] = model_results
            
            # Clean up the model to free memory
            del merged_model, tokenizer, lora_model
            torch.cuda.empty_cache()
            gc.collect()
            
        except Exception as e:
            print(f"Error processing model {agent_model_name}: {str(e)}")
            continue
    
    return results_by_model


if __name__ == "__main__":
    #load json raw data for diplomacy prompt construction
    test_path = f"diplomacy_prompts/sample_1K_denis_testdata_with_prompts.json"
    with open(test_path, "r") as file:
        test_data = json.load(file)

    # create the test prompts in friction generation format with xml tags for specific fields
    
    # test_prompts = process_diplomacy_for_friction(test_data[0:500]) # select the first 500 prompts to get more model samples at first. 
    models_list_deli_weights =[
        

#     'WTD_faaf_10',
 
        'sft_deli_multiturn_rogue_cleaned/checkpoint-3000',
    'DELI_all_weights/DELI_faaf_weights/checkpoint-2000',
 
     'DELI_all_weights/DELI_dpo_weights/checkpoint-3500',
 
    'DELI_all_weights/DELI_ipo_weights/checkpoint-4500',

        "DELI_all_weights/DELI_sft_weights/checkpoint-small-1500",
         "DELI_all_weights/DELI_ppo_weights/ppo_checkpoint_epoch_1_batch_800",
 
    ]
    #sample generation arguments -- for sampling, tokens to generate etc
    generation_args = {
    "max_new_tokens": 356,
    "temperature": 0.9,
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.9,
    "num_beams": 5,
    "min_length": 100,
    'num_return_sequences': 1,
    "sampling_strategy": "top_p_sampling"
    }

    output_dir = "diplomacy_prompts/init_generations"
    results_by_model = process_models_with_formatted_prompts(models_list_deli_weights,test_data[0:500], generation_args, output_dir = output_dir)
    

[2025-04-15 01:36:44,385] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2025-04-15 01:36:49.105683: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-15 01:36:49.105716: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-15 01:36:49.106763: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


length of processed prompts 500


Processing Models:   0%|                                                                                                                                                                                                 | 0/6 [00:00<?, ?it/s]


===== Processing Model: sft_deli_multiturn_rogue_cleaned_checkpoint-3000 =====



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merging LoRA adapter...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Processing Prompt_level1 for sample 0
Processing Prompt_level2 for sample 0


Processing Models:   0%|                                                                                                                                                                                                 | 0/6 [00:18<?, ?it/s]


KeyboardInterrupt: 

## Changes asked: instead of friction suggestion, we need to get the friction models to explain why the recommended orders were given. 


In [23]:

def process_diplomacy_for_friction_explanations_only(test_data):
    processed_message = {}
    processed_messages = []  # Create a list to store all processed messages
    count = 0
    for message in test_data:
        count = count+1
        sender = message['sender']
        recipient = message['recipient']

        if message['recipient'] == 'ALL' or message['recipient'] == message['sender'] or message["phase"].endswith('A') or message["phase"].endswith('R'):
            continue
        else:
            #board states format
            formatted_board = ""
            for country, units in message['units'].items():
                formatted_board += f"{country}: {', '.join(units)}\n"
            formatted_board = formatted_board.strip()
            #predicted orders for opponent
            opponent_orders = ""
            for country, units in message['predicted_orders'][message['recipient']].items():
                opponent_orders += f"{country}: {', '.join(units)}\n"
            formatted_opponent_orders = opponent_orders.strip()
            #message history
            message_history = ""
            start_index = max(0, len(message['prev_5_message']) - 5)
            for msg_info in message['prev_5_message'][start_index:]:
                sender = msg_info['sender']
                text = msg_info['message']
                message_history += f"Message from {sender}: {text}\n"
            message_history += f"Message from {message['sender']}: {message['message']}"
            #recommended orders for myself.
            recommended_orders = message['predicted_orders'][message['sender']][message['recipient']]
            formatted_recommended_orders = ",\n".join(recommended_orders)

                # Modify the prompts for friction interventions


            Prompt_level2 = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:
        1. The current board state.
        2. The recommended orders for the player.
        3. The potential orders for every power.
        

        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:

        <belief_state>
        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.
        </belief_state>

        <rationale>
        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.
        </rationale>

        <friction>
        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly address the strategic concern, and the second should offer a reflective question or alternative perspective.
        </friction>
        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
        <|eot_id|><|start_header_id|>user<|end_header_id|>
        **Board State:**  
        {formatted_board}

        **Recommended Orders for {message['recipient']}:**  
        {formatted_recommended_orders}

        **Potential Orders for other powers:**
        {formatted_opponent_orders}

        **Advice Request:**  
        You are advising the player controlling {message['recipient']}. Identify potential friction points and provide constructive intervention.
        <|eot_id|><|start_header_id|>assistant<|end_header_id|>
        """

        
            processed_message = message.copy()
         # Add the prompts to the copied message
      
            processed_message['Prompt_level2'] = Prompt_level2
  
             # Add this processed message to our list
            processed_messages.append(processed_message)

    return processed_messages


test_prompts_new = process_diplomacy_for_friction_explanations_only(test_data[0:1])

print(test_prompts_new)
{'Prompt_level2': "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:\n        1. The current board state.\n        2. The recommended orders for the player.\n        3. The potential orders for every power.\n        \n\n        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:\n\n        <belief_state>\n        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.\n        </belief_state>\n\n        <rationale>\n        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.\n        </rationale>\n\n        <friction>\n        Provide a concise friction intervention (max 2 sentences) that prompts self-reflection about their strategy. The first sentence should directly address the strategic concern, and the second should offer a reflective question or alternative perspective.\n        </friction>\n        Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.\n        <|eot_id|><|start_header_id|>user<|end_header_id|>\n        **Board State:**  \n        AUSTRIA: A BUD, A TRI, F ALB, A SER, A BUL\nENGLAND: A MOS, F MAO, A LVN, F ENG, A MAR, A GAS, F EDI, A BRE\nFRANCE: A POR\nGERMANY: A PAR, A GAL, A VIE, A BOH, F BEL, F HOL, A UKR, A WAR, A TYR, A MUN\nITALY: F WES, A NAF, A VEN, F SPA, A PIE\nRUSSIA: \nTURKEY: F BLA, A RUM, A CON, F AEG\n\n        **Recommended Orders for GERMANY:**  \n        F BEL H,\nF HOL H,\nA PAR H,\nA MUN S A TYR,\nA BOH S A VIE,\nA TYR - TRI,\nA WAR - GAL,\nA UKR S A RUM,\nA GAL - BUD,\nA VIE S A GAL - BUD\n\n        **Potential Orders for other powers:**\n        AUSTRIA: A TRI S A BUD, F ALB S A TRI, A BUD S A TRI, A SER S A BUD, A BUL - GRE\nENGLAND: F EDI H, F ENG - MAO, A BRE H, F MAO - SPA/SC, A GAS S F MAO - SPA, A MAR S F MAO - SPA, A LVN S A MOS, A MOS S A LVN\nFRANCE: A POR H\nITALY: A NAF H, F SPA/SC H, F WES S F SPA/SC, A PIE - MAR, A VEN H\nRUSSIA: \nTURKEY: F BLA S A CON - BUL, A RUM S A GAL - BUD, F AEG - GRE, A CON - BUL\n\n        **Advice Request:**  \n        You are advising the player controlling GERMANY. Identify potential friction points and provide constructive intervention.\n        <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n        "}

In [24]:
diplomacy_detailed_explanation = """
Based on the board state and recommended orders, your strategy appears focused on attacking Sevastopol while also moving into Galicia. However, this leaves you vulnerable to a critical threat: Russia plans to attack Vienna (A BOH - VIE with support from A GAL). 

While Italy currently controls Vienna and their orders show support for your A BUD - GAL move, this leaves Vienna exposed. Losing Vienna, a home center, would be devastating to your position and ability to build new units.

Your eastern offensive against Russia in Sevastopol diverts resources from this more urgent defensive need. Consider redirecting forces to secure Vienna - perhaps keeping A BUD in position or moving A SER to Trieste to provide support. Remember that in Diplomacy, protecting home centers takes priority over gaining new territories when your core position is under threat.

The recommended shift is to "Modify your orders to focus on securing Vienna instead of attacking Sevastopol" because Russia's coordinated attack with support from Galicia represents an immediate threat to your power base, while Sevastopol can be targeted later from a position of greater security.
"""
diplomacy_detailed_explanation

'\nBased on the board state and recommended orders, your strategy appears focused on attacking Sevastopol while also moving into Galicia. However, this leaves you vulnerable to a critical threat: Russia plans to attack Vienna (A BOH - VIE with support from A GAL). \n\nWhile Italy currently controls Vienna and their orders show support for your A BUD - GAL move, this leaves Vienna exposed. Losing Vienna, a home center, would be devastating to your position and ability to build new units.\n\nYour eastern offensive against Russia in Sevastopol diverts resources from this more urgent defensive need. Consider redirecting forces to secure Vienna - perhaps keeping A BUD in position or moving A SER to Trieste to provide support. Remember that in Diplomacy, protecting home centers takes priority over gaining new territories when your core position is under threat.\n\nThe recommended shift is to "Modify your orders to focus on securing Vienna instead of attacking Sevastopol" because Russia\'s co

In [27]:
test_prompts_new = process_diplomacy_for_friction_explanations_only(test_data[0:1])

In [32]:
test_prompts_new[0]

{'Prompt_level2': "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n        You are an expert assistant specializing in the Diplomacy board game and collaborative reasoning. Your role is to analyze:\n        1. The current board state.\n        2. The recommended orders for the player.\n        3. The potential orders for every power.\n        \n\n        Your goal is to identify potential friction points in the player's strategy and provide a constructive intervention:\n\n        <belief_state>\n        Identify contradictions, oversights, or strategic misalignments in the recommended orders compared to the board state and other powers' likely moves.\n        </belief_state>\n\n        <rationale>\n        Explain why an intervention is needed—what's misaligned, its potential impact, and how alternative thinking could benefit the player.\n        </rationale>\n\n        <friction>\n        Provide a concise friction intervention (max 2 sentences) that prompts self-reflecti

In [72]:
def process_diplomacy_for_explanations(test_data, prompt_type="full"):
    """
    Process diplomacy data to create explanation prompts.
    
    Args:
        test_data: List of diplomacy game data
        prompt_type: Either "single" for one order per prompt or "full" for all orders
        
    Returns:
        Dictionary of processed prompts with appropriate keys
    """
    processed_prompts = {}
    overall_index = 0
    
    for message_idx, message in enumerate(test_data):
        # Skip messages that don't need processing
        if (message['recipient'] == 'ALL' or 
            message['recipient'] == message['sender'] or 
            message["phase"].endswith('A') or 
            message["phase"].endswith('R')):
            continue
        
        # Format board state
        formatted_board = ""
        for country, units in message['units'].items():
            formatted_board += f"{country}: {', '.join(units)}\n"
        formatted_board = formatted_board.strip()
        
        # Format opponent orders
        opponent_orders = ""
        for country, units in message['predicted_orders'][message['recipient']].items():
            opponent_orders += f"{country}: {', '.join(units)}\n"
        formatted_opponent_orders = opponent_orders.strip()
        
        # Format message history (keeping this part from original function)
        message_history = ""
        start_index = max(0, len(message['prev_5_message']) - 5)
        for msg_info in message['prev_5_message'][start_index:]:
            sender = msg_info['sender']
            text = msg_info['message']
            message_history += f"Message from {sender}: {text}\n"
        message_history += f"Message from {message['sender']}: {message['message']}"
        
        # Get recommended orders
        recommended_orders = message['predicted_orders'][message['sender']][message['recipient']]
        
        if prompt_type == "full":
            # Create a single prompt for all orders
            formatted_recommended_orders = ",\n".join(recommended_orders)
            
            # Create prompt for full set of orders
            prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert assistant specializing in the Diplomacy board game. Your role is to analyze:
1. The current board state.
2. The recommended orders for the player.
3. The potential orders for every power.

Your goal is to provide a comprehensive explanation of the recommended strategy:

<belief_state>
Outline the key strategic considerations and assumptions that underlie these recommended orders, including relevant board positions, territorial control, and the overall strategic context.
</belief_state>

<rationale>
Explain the strategic logic behind these recommended orders, how they work together to advance the player's position, and why they're optimal compared to alternatives given the current board state and potential moves by other powers.
</rationale>

<friction>
Summarize the most crucial insights about this strategy that the player should understand, highlighting the key moves and their importance to the overall plan.
</friction>

Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
<|eot_id|><|start_header_id|>user<|end_header_id|>
**Board State:**  
{formatted_board}

**Recommended Orders for {message['recipient']}:**  
{formatted_recommended_orders}

**Potential Orders for other powers:**
{formatted_opponent_orders}

**Request:**  
You are advising the player controlling {message['recipient']}. Explain the strategic rationale behind the recommended orders.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
            
            # Store the prompt with a unique key
            processed_prompts[f'Prompt_level2_game_{message_idx}'] = prompt
            
        elif prompt_type == "single":
            # Create individual prompts for each order
            for order_idx, order in enumerate(recommended_orders):
                prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are an expert assistant specializing in the Diplomacy board game. Your role is to analyze:
1. The current board state.
2. A specific recommended order for the player.
3. The potential orders for every power.

Your goal is to provide a detailed explanation for this specific recommended order:

<belief_state>
Outline the key strategic considerations and assumptions that underlie this specific recommended order, including relevant board positions and how it fits into the overall strategic context.
</belief_state>

<rationale>
Explain the strategic logic behind this specific recommended order, how it advances the player's position, and why it's optimal compared to alternatives given the current board state and potential moves by other powers.
</rationale>

<friction>
Summarize the most crucial insight about this specific order that the player should understand, emphasizing its relationship to the broader game state and strategy.
</friction>

Your response MUST include all three required XML tags (<belief_state>, <rationale>, and <friction>) with complete content for each.
<|eot_id|><|start_header_id|>user<|end_header_id|>
**Board State:**  
{formatted_board}

**Recommended Order for {message['recipient']}:**  
{order}

**Potential Orders for other powers:**
{formatted_opponent_orders}

**Request:**  
You are advising the player controlling {message['recipient']}. Explain the strategic rationale behind this specific recommended order.
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
                # Store the prompt with a unique key that won't repeat
                processed_prompts[f'Prompt_level2_game_{message_idx}_order_{order_idx}'] = prompt
                overall_index += 1
    
    return processed_prompts

In [80]:
full_explanation_prompts = process_diplomacy_for_explanations(test_data[0:1], prompt_type="full")
single_explanation_prompts = process_diplomacy_for_explanations(test_data[0:200], prompt_type="single")

In [78]:
full_explanation_prompts

{'Prompt_level2_game_0': "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are an expert assistant specializing in the Diplomacy board game. Your role is to analyze:\n1. The current board state.\n2. The recommended orders for the player.\n3. The potential orders for every power.\n\nYour goal is to provide a comprehensive explanation of the recommended strategy:\n\n<belief_state>\nOutline the key strategic considerations and assumptions that underlie these recommended orders, including relevant board positions, territorial control, and the overall strategic context.\n</belief_state>\n\n<rationale>\nExplain the strategic logic behind these recommended orders, how they work together to advance the player's position, and why they're optimal compared to alternatives given the current board state and potential moves by other powers.\n</rationale>\n\n<friction>\nSummarize the most crucial insights about this strategy that the player should understand, highlighting the key move

In [13]:
import pickle
# len(single_explanation_prompts)
# single_explanation_prompts.keys()

output_folder = "diplomacy_prompts/init_generations_single_explanation"

# Load existing results
results = pickle.load(open(f"{output_folder}/explanation_results_single_DELI_faaf_weights_checkpoint-2000.pkl", "rb"))
#load all these files pickle and combine them without worrying about duploicates 
#then, check the structure of the log below and parse the recommended orders field (the actual order) from the raw prompt for eacg sample
# and then add to that sample apart from the other fields like game ids etc. 
# finally write debugging and analysis code if this was done correctly. 
/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations_single_explanation/explanation_results_single_sft_deli_multiturn_rogue_cleaned_checkpoint-3000.pkl,6487045,1745408164000
/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations_single_explanation/explanation_results_single_DELI_dpo_weights_checkpoint-3500.pkl,7224099,1745422110000
/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations_single_explanation/explanation_results_single_DELI_sft_weights_checkpoint-small-1500.pkl,5980528,1745429657000
/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations_single_explanation/explanation_results_single_DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800.pkl,5519109,1745435084000
/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations_single_explanation/explanation_results_single_DELI_faaf_weights_checkpoint-2000.pkl,7068394,1745451419000

results.keys
dict_keys(['model_name', 'prompt_type', 'explanations'])
results['explanations'][34]
{'prompt_key': 'Prompt_level2_game_7_order_4',
 'game_idx': 7,
 'order_idx': 4,
 'type': 'single',
 'raw_prompt': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are an expert assistant specializing in the Diplomacy board game. Your role is to analyze:\n1. The current board state.\n2. A specific recommended order for the player.\n3. The potential orders for every power.\n\n\nIMPORTANT: This is a snapshot of an ongoing Diplomacy game. The board state shows 

In [33]:
import pickle
import os
import re # Import regular expressions module

# --- 1. Define File Paths ---
# Using the paths provided in the prompt
output_folder = "/s/chopin/d/proj/ramfis-aida/XE/diplomacy_prompts/init_generations_single_explanation" # Base folder path

pickle_files = [
    os.path.join(output_folder, "explanation_results_single_sft_deli_multiturn_rogue_cleaned_checkpoint-3000.pkl"),
    os.path.join(output_folder, "explanation_results_single_DELI_dpo_weights_checkpoint-3500.pkl"),
    os.path.join(output_folder, "explanation_results_single_DELI_sft_weights_checkpoint-small-1500.pkl"),
    os.path.join(output_folder, "explanation_results_single_DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800.pkl"),
    os.path.join(output_folder, "explanation_results_single_DELI_faaf_weights_checkpoint-2000.pkl"),
    # Add the initial file mentioned too, in case it wasn't in the list
#     os.path.join(output_folder, "explanation_results_single_DELI_faaf_weights_checkpoint-2000.pkl") # Appears twice, harmless as per prompt
]

# Filter out duplicates if necessary (though prompt said not to worry)
# pickle_files = list(set(pickle_files)) # Optional: remove duplicate file paths


# --- 2. Load and Combine Data ---
all_explanations = []
loaded_files_count = 0
failed_files = []

print("Loading pickle files...")
for file_path in pickle_files:
    try:
        if not os.path.exists(file_path):
            print(f"Warning: File not found, skipping: {file_path}")
            failed_files.append(file_path + " (Not Found)")
            continue

        with open(file_path, "rb") as f:
            data = pickle.load(f)
            # Check if the loaded data has the expected structure
            if isinstance(data, dict) and 'explanations' in data and isinstance(data['explanations'], list):
                
                # --- MODIFICATION START ---
                model_name = data.get('model_name', os.path.basename(file_path)) # Get model name or use filename as fallback
                
                temp_explanations = []
                for explanation_dict in data['explanations']:
                    # Add the model name to each individual explanation dictionary
                    explanation_dict['source_model'] = model_name 
                    temp_explanations.append(explanation_dict)
                
                all_explanations.extend(temp_explanations) # Extend with modified dictionaries
                # --- MODIFICATION END ---

                print(f"Successfully loaded and combined {len(data['explanations'])} explanations from: {os.path.basename(file_path)} (Model: {model_name})")
                loaded_files_count += 1
            else:
                print(f"Warning: Unexpected data structure in file, skipping: {file_path}")
                failed_files.append(file_path + " (Bad Structure)")
                
    except FileNotFoundError:
        print(f"Error: File not found (double check): {file_path}")
        failed_files.append(file_path + " (Not Found - Error)")
    except pickle.UnpicklingError:
        print(f"Error: Failed to unpickle file (corrupted?): {file_path}")
        failed_files.append(file_path + " (Unpickling Error)")
    except Exception as e:
        print(f"An unexpected error occurred while loading {file_path}: {e}")
        failed_files.append(file_path + f" (Other Error: {e})")


print(f"\nFinished loading. Combined data from {loaded_files_count} files.")
if failed_files:
    print("Failed to load or process the following files:")
    for f in failed_files:
        print(f"- {f}")

print(f"Total explanations combined: {len(all_explanations)}")

# --- 3 & 4. Parse Raw Prompt and Add New Field ---

# Compile the regex for efficiency.
# Looks for "**Recommended Order for [COUNTRY NAME]:**" followed by a newline,
# then captures the entire next line until the following newline.
# Uses re.IGNORECASE for country name just in case, and re.MULTILINE
# order_pattern = re.compile(r"^\*\*Recommended Order for \w+:\*\*\s*?\n(.*?)$", re.MULTILINE | re.IGNORECASE)
# Simpler pattern assuming order is always on the line immediately following the header:
order_pattern = re.compile(r"\*\*Recommended Order for \w+:\*\*\s*\n([^\n]+)")


print("\nParsing raw prompts to extract recommended orders...")
parsing_errors = 0
for sample in all_explanations:
    extracted_order = None # Default value
    if 'raw_prompt' in sample and isinstance(sample['raw_prompt'], str):
        match = order_pattern.search(sample['raw_prompt'])
        if match:
            # Group 1 contains the captured order line
            extracted_order = match.group(1).strip()
        else:
            # If pattern not found in this specific sample's prompt
            parsing_errors += 1
            # You could log which prompt_key failed here if needed
            # print(f"Warning: Could not find order pattern for prompt_key: {sample.get('prompt_key', 'N/A')}")
    else:
        # If raw_prompt is missing or not a string
        parsing_errors += 1
        # print(f"Warning: Missing or invalid 'raw_prompt' for prompt_key: {sample.get('prompt_key', 'N/A')}")

    # Add the extracted order (or None if failed) to the dictionary
    sample['extracted_recommended_order'] = extracted_order

print(f"Finished parsing. Found order in {len(all_explanations) - parsing_errors} prompts.")
if parsing_errors > 0:
    print(f"Could not extract order from {parsing_errors} prompts (pattern not found or missing/invalid raw_prompt).")


# --- 5. Verification and Analysis ---
print("\n--- Verification ---")

# Check the first few samples
num_samples_to_check = 10
print(f"\nDisplaying the first {min(num_samples_to_check, len(all_explanations))} samples with extracted orders:")
for i, sample in enumerate(all_explanations[:num_samples_to_check]):
    print("-" * 20)
    print(f"Sample Index: {i}")
    print(f"Prompt Key: {sample.get('prompt_key', 'N/A')}")
    print(f"Extracted Recommended Order: '{sample.get('extracted_recommended_order', 'EXTRACTION FAILED')}'")
    # Optionally print a snippet of the raw prompt to manually verify
    # prompt_snippet = sample.get('raw_prompt', '')[:200] # First 200 chars
    # print(f"Raw Prompt Snippet:\n{prompt_snippet}...")

# Check for extraction failures
print("\nChecking for samples where extraction might have failed...")
failed_extractions = [s for s in all_explanations if s.get('extracted_recommended_order') is None]

if not failed_extractions:
    print("No extraction failures found (all samples have an extracted order string or were expected to fail).")
else:
    print(f"Found {len(failed_extractions)} samples where extraction failed (result is None):")
    # Print keys of the first few failures
    for i, sample in enumerate(failed_extractions[:5]):
         print(f" - Failed Sample Index (overall list): {all_explanations.index(sample)}, Prompt Key: {sample.get('prompt_key', 'N/A')}")
    if len(failed_extractions) > 5:
        print(f"   (... and {len(failed_extractions) - 5} more)")

print("\nVerification complete.")

Loading pickle files...
Successfully loaded and combined 998 explanations from: explanation_results_single_sft_deli_multiturn_rogue_cleaned_checkpoint-3000.pkl (Model: sft_deli_multiturn_rogue_cleaned_checkpoint-3000)
Successfully loaded and combined 998 explanations from: explanation_results_single_DELI_dpo_weights_checkpoint-3500.pkl (Model: DELI_dpo_weights_checkpoint-3500)
Successfully loaded and combined 998 explanations from: explanation_results_single_DELI_sft_weights_checkpoint-small-1500.pkl (Model: DELI_sft_weights_checkpoint-small-1500)
Successfully loaded and combined 998 explanations from: explanation_results_single_DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800.pkl (Model: DELI_ppo_weights_ppo_checkpoint_epoch_1_batch_800)
Successfully loaded and combined 998 explanations from: explanation_results_single_DELI_faaf_weights_checkpoint-2000.pkl (Model: DELI_faaf_weights_checkpoint-2000)

Finished loading. Combined data from 5 files.
Total explanations combined: 4990

Pars

In [34]:
len(all_explanations)
5988

5988

In [48]:
import collections

# Assume 'all_explanations' is your list of 5988 dictionaries.
# IMPORTANT: This assumes each dictionary in all_explanations has a key
# like 'source_model' indicating which model generated it.
# Replace 'source_model' with the actual key name if it's different.
model_key_name = 'source_model'

# Use defaultdict for easy appending to lists
explanations_by_model = collections.defaultdict(list)

missing_model_key_count = 0

# Iterate through the flat list
for explanation_dict in all_explanations:
    # Get the model identifier from the dictionary
    # Use .get() with a default value to handle cases where the key might be missing
    model_name = explanation_dict.get(model_key_name)
    
 # Extract the model type from the filename
    if "sft_deli_multiturn" in model_name:
        model_name = "SFT_multiturn"
    elif "faaf_weights" in model_name:
        model_name = "FAAF"
    elif "dpo_weights" in model_name:
        model_name = "DPO"
    elif "ipo_weights" in model_name:
        model_name = "IPO"
    elif "sft_weights" in model_name:
        model_name = "SFT"
    elif "ppo_weights" in model_name:
        model_name = "PPO"
    else:
        model_name = "Unknown"

    if model_name is not None:
        # Append the entire explanation dictionary to the list for this model
        explanations_by_model[model_name].append(explanation_dict)
    else:
        # Handle dictionaries missing the model key (optional)
        explanations_by_model['Unknown_Or_Missing_Model'].append(explanation_dict)
        missing_model_key_count += 1
    explanation_dict['source_model'] = model_name


# Optional: Convert defaultdict back to a regular dict if needed for compatibility
explanations_by_model = dict(explanations_by_model)

# --- Verification ---
print(f"Successfully restructured {len(all_explanations)} explanations into a dictionary.")
print(f"Number of unique model keys found: {len(explanations_by_model)}")

if missing_model_key_count > 0:
     print(f"Warning: {missing_model_key_count} explanations were missing the key '{model_key_name}' and grouped under 'Unknown_Or_Missing_Model'.")

print("\nCounts per model:")
total_counted = 0
for model, explanations_list in explanations_by_model.items():
    count = len(explanations_list)
    print(f"- {model}: {count} explanations")
    total_counted += count

print(f"\nTotal explanations accounted for in the new dictionary: {total_counted}")

with open(f"{output_folder}/combined_parsed_results.pkl", 'wb') as f:
    pickle.dump(explanations_by_model, f)

# You can now access explanations for a specific model like this:
# e.g., first_model_key = list(explanations_by_model.keys())[0]
# specific_model_explanations = explanations_by_model[first_model_key]
# print(f"\nExample: First explanation for model '{first_model_key}':")
# print(specific_model_explanations[0])

Successfully restructured 4990 explanations into a dictionary.
Number of unique model keys found: 5

Counts per model:
- SFT_multiturn: 998 explanations
- DPO: 998 explanations
- SFT: 998 explanations
- PPO: 998 explanations
- FAAF: 998 explanations

Total explanations accounted for in the new dictionary: 4990


In [50]:
explanations_by_model['PPO'][400]

{'prompt_key': 'Prompt_level2_game_81_order_0',
 'game_idx': 81,
 'order_idx': 0,
 'type': 'single',
 'raw_prompt': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are an expert assistant specializing in the Diplomacy board game. Your role is to analyze:\n1. The current board state.\n2. A specific recommended order for the player.\n3. The potential orders for every power.\n\n\nIMPORTANT: This is a snapshot of an ongoing Diplomacy game. The board state shows each country\'s current units:\n- Units prefixed with "A" are Armies located on land territories\n- Units prefixed with "F" are Fleets located on sea spaces or coastal territories\n- Each unit belongs ONLY to the country listed before the colon (e.g., all units under "GERMANY:" belong to GERMANY)\n- Supply centers are critical territories that allow powers to build new units\n- Home supply centers are especially important to protect\n\nThe "Recommended Order" is a specific move being suggested for your power (GERMA

In [17]:
results['explanations'][34]

{'prompt_key': 'Prompt_level2_game_7_order_4',
 'game_idx': 7,
 'order_idx': 4,
 'type': 'single',
 'raw_prompt': '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\nYou are an expert assistant specializing in the Diplomacy board game. Your role is to analyze:\n1. The current board state.\n2. A specific recommended order for the player.\n3. The potential orders for every power.\n\n\nIMPORTANT: This is a snapshot of an ongoing Diplomacy game. The board state shows each country\'s current units:\n- Units prefixed with "A" are Armies located on land territories\n- Units prefixed with "F" are Fleets located on sea spaces or coastal territories\n- Each unit belongs ONLY to the country listed before the colon (e.g., all units under "AUSTRIA:" belong to AUSTRIA)\n- Supply centers are critical territories that allow powers to build new units\n- Home supply centers are especially important to protect\n\nThe "Recommended Order" is a specific move being suggested for your power (AUSTRIA

In [1]:
import pandas as pd

In [50]:
group1_noEmbodiment_dpo= pd.read_csv("hannah_friction_cvpr/group1_noEmbodiment_dpo.csv")
group1_withEmbodiment_dpo= pd.read_csv("hannah_friction_cvpr/group1_withEmbodiment_dpo.csv")
group2_noEmbodiment_dpo= pd.read_csv("hannah_friction_cvpr/group2_noEmbodiment_dpo.csv")
group2_withEmbodiment_dpo = pd.read_csv("hannah_friction_cvpr/group2_withEmbodiment_dpo.csv")

In [17]:
fricion1 

Unnamed: 0,Transcriptions,Friction
0,it's right there\um it's turned off\ok well it...,No Friction
1,um it's turned off\ok well it would be nine mi...,"Friction: Wait, aren't we assuming that the we..."
2,ok well it would be nine minus ten\that's ten ...,No Friction
3,that's ten so then\just like put it on\it seem...,"Friction: Hey, let's not jump to conclusions j..."
4,just like put it on\it seems pretty balanced\y...,No Friction
5,it seems pretty balanced\yeah ok so now we kno...,"Friction: Hold on, before we assume the next b..."
6,yeah ok so now we know that this is also ten\w...,"Friction: ""Hey, I think we're moving too fast..."
7,what would we maybe put that one on there too\...,No Friction
8,it's more\lets try this one oh gee louise this...,"Friction: ""Hey, wait a minute! We're making s..."
9,lets try this one oh gee louise this one\ok so...,"Friction: I'm not sure about that, guys. If we..."


In [None]:
# format with special OPT token, format with the RM prompt and then process and tokenizer, call forward get rewards 
# store raw logits, compute average and STD etc 

In [42]:
friction1_parsed.columns

Index(['Transcriptions', 'Friction', 'parsed_friction', 'parsed_rationale'], dtype='object')

In [59]:

def process_data_template_for_reward_model(row, no_rationale = False):
    system_prompt_rm = (
        "Please rate the following friction intervention in light of the **dialogue history** of a *game* provided below. "
        "A friction intervention is a statement that acts as indirect persuasion and prompts participants to "
        "reevaluate their beliefs and assumptions about the task, primarily—but not exclusively—in response to new evidence "
        "that challenges their preconceived notions about the state of the task or the block weights."
    )

    friction_definition_game_definition_prompt_rm = (
        "*Game details and ground-truth*: The game is called 'Game of Weights.' The participants (P1, P2, and P3) are "
        "trying to determine the weight of various blocks. The blocks are of different colors and have specific weights in grams: "
        "the red block is 10 grams, the blue block is 10 grams, the green block is 20 grams, the purple block is 30 grams, and "
        "the yellow block is 50 grams. At the start of the game, participants are only allowed to weigh two blocks at a time, "
        "and they are told the weight of the red block. The participants must figure out how to determine the weight of each block. "
        "At the beginning of the game, they are unaware of the actual weights. Additionally, we inform the participants that they "
        "don’t need to use the scale's slider. The actual reason is that the blocks are in increments of 10 grams. "
        "The **dialogue history** is given below: "
    )

  
    prompt = (system_prompt_rm + row['Transcriptions']).replace('\n', ' ')
    
    text = prompt + " " + f"</s> {row['parsed_friction']} </s>"
    if no_rationale:
        text = prompt + " " + f"</s> {row['parsed_friction']} </s>"
    
    return text

In [52]:
import pandas as pd
import re

# Function to check if a row contains friction
def has_friction(text):
    # Check if the text contains explicit "No Friction" statement
    if re.search(r'No\s+Friction', text, re.IGNORECASE):
        return False
    
    # Check if the text contains "Friction:" marker
    if re.search(r'Friction:', text):
        return True
    
    return False

# Function to extract friction and rationale
def extract_friction_rationale(text):
    # Look for the specific pattern with "r*Rationale" possibly attached to the preceding text
    parts = re.split(r'(r\*Rationale)', text, 1)
    
    # If we found the pattern
    if len(parts) >= 2:
        # First part is everything before r*Rationale (the friction part)
        friction_text = parts[0].replace("Friction:", "", 1).strip()
        
        # Last part is everything after r*Rationale (the rationale part)
        rationale_text = parts[-1].strip()
        
        return friction_text, rationale_text
    else:
        # If the pattern isn't found, check for just the friction part
        friction_match = re.search(r'Friction:(.*?)$', text, re.DOTALL)
        if friction_match:
            return friction_match.group(1).strip(), None
        else:
            return None, None

# Process dataframe - filter and extract columns
def process_dataframe(df):
    # Make a copy to avoid modifying the original
    df_processed = df.copy()
    
    # Filter rows to keep only those with friction
    df_processed = df_processed[df_processed['Friction'].apply(has_friction)]
    
    # Apply extraction to create new columns
    friction_rationale_tuples = df_processed['Friction'].apply(extract_friction_rationale)
    
    # Split the tuples into separate columns
    df_processed['parsed_friction'] = [t[0] for t in friction_rationale_tuples]
    df_processed['parsed_rationale'] = [t[1] for t in friction_rationale_tuples]
    
    return df_processed
 

group1_noEmbodiment_dpo = process_dataframe(group1_noEmbodiment_dpo)
group1_withEmbodiment_dpo = process_dataframe(group1_withEmbodiment_dpo)
group2_noEmbodiment_dpo = process_dataframe(group2_noEmbodiment_dpo)
group2_withEmbodiment_dpo = process_dataframe(group2_withEmbodiment_dpo)
 

In [53]:
group1_noEmbodiment_dpo.shape, group1_withEmbodiment_dpo.shape, group2_noEmbodiment_dpo.shape,group2_withEmbodiment_dpo.shape 

((30, 4), (30, 4), (28, 4), (30, 4))

In [39]:
import gc
from tqdm import tqdm
from collections import defaultdict
import json
import time
import os
from datetime import datetime
import os
import pandas as pd
from collections import defaultdict
from collections.abc import Mapping
import wandb
# wandb.init(project="friction_agent_inference", name="log_friction_interventions") 
from datasets import Dataset,load_dataset, DatasetDict
from datasets import load_from_disk
import re
import matplotlib.pyplot as plt
import torch
import random
import numpy as np
from tqdm import tqdm
# from datasets import load_metric
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util
import sys
import pickle
from dataclasses import dataclass, field
from typing import Optional, List
import pickle
import pandas as pd
from datasets import Dataset, DatasetDict
from itertools import combinations
import torch
from accelerate import Accelerator
from datasets import load_dataset, load_from_disk
from peft import AutoPeftModelForCausalLM, LoraConfig,PeftModel
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    set_seed,
)
import random
import os
import json
import torch
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
import os
from datetime import datetime
import pandas as pd
import seaborn as sns
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util
import torch
from transformers import pipeline


reward_model_name = "friction_rm_alldata_results"
device = 'cuda:1'
torch_dtype = torch.bfloat16

reward_model = AutoModelForSequenceClassification.from_pretrained(
    reward_model_name,
    num_labels=1,  # Set to 1 for reward model classification
    trust_remote_code=True,
     device_map="auto",
    torch_dtype=torch.float16,

    )

rm_tokenizer = AutoTokenizer.from_pretrained(
reward_model_name, 
trust_remote_code=True, 
use_fast=True
)
all_results = {}

[2025-04-16 23:33:08,553] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


2025-04-16 23:33:13.189536: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-04-16 23:33:13.189570: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-04-16 23:33:13.190688: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [41]:
friction1_parsed.columns

Index(['Transcriptions', 'Friction', 'parsed_friction', 'parsed_rationale'], dtype='object')

In [61]:
import numpy as np
from tqdm import tqdm
import pandas as pd

# Function to manually calculate standard error of mean
def sem(data):
    """Calculate standard error of the mean"""
    return np.std(data, ddof=1) / np.sqrt(len(data))

# Function to compute rewards for a dataframe
def compute_rewards(df, df_name):
    # Lists to store individual sample rewards
    raw_rewards = []
    
    for index, row in tqdm(df.iterrows(), total=len(df), desc=f"Processing {df_name}"):
        # Get the friction and rationale if present
        rationale = row.get('parsed_rationale')
        
        if rationale:
            processed_inputs = process_data_template_for_reward_model(row)
        else:
            processed_inputs = process_data_template_for_reward_model(row, no_rationale=True)
        
        # Tokenize the inputs
        tokenized_inputs = rm_tokenizer(processed_inputs, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        
        # Run the model on the inputs
        with torch.no_grad():
            reward_outputs = reward_model(**tokenized_inputs)
        
        # Extract score
        score = reward_outputs.logits.squeeze(-1).cpu().numpy().item()
        raw_rewards.append(score)
    
    # Convert to numpy array for calculations
    raw_rewards = np.array(raw_rewards)
    
 
    # 2. Min-max normalization to [0, 1]
    min_max_rewards = (raw_rewards - np.min(raw_rewards)) / (np.max(raw_rewards) - np.min(raw_rewards))
 
    # Calculate statistics for each method
    stats_dict = {
        'raw': {
            'values': raw_rewards,
            'mean': np.mean(raw_rewards),
            'sem': sem(raw_rewards)
        },
        
        'min_max': {
            'values': min_max_rewards,
            'mean': np.mean(min_max_rewards),
            'sem': sem(min_max_rewards)
        },
        
        
    }
    
    # Create a DataFrame with individual sample rewards
    sample_rewards_df = pd.DataFrame({
        'sample_id': df.index,
        'raw_reward': raw_rewards,
         
        'min_max_reward': min_max_rewards,
   
        
    })
    
    return stats_dict, sample_rewards_df

# Define the dataframes and their names
dataframes = {
    'group1_noEmbodiment': group1_noEmbodiment_dpo,
    'group1_withEmbodiment': group1_withEmbodiment_dpo,
    'group2_noEmbodiment': group2_noEmbodiment_dpo,
    'group2_withEmbodiment': group2_withEmbodiment_dpo
}

# Print dataframe shapes
print("Dataframe shapes:")
for name, df in dataframes.items():
    print(f"{name}: {df.shape}")

# Create dictionaries to store results
stats_results = {}
sample_dfs = {}

# Compute rewards for all dataframes
for name, df in dataframes.items():
    print(f"\nProcessing {name}...")
    stats_results[name], sample_dfs[name] = compute_rewards(df, name)
    
    # Print results for this dataframe
    print(f"\nResults for {name}:")
    for norm_method, stats in stats_results[name].items():
        print(f"Average {norm_method} reward: {stats['mean']:.4f} ± {stats['sem']:.4f}")
    
    # Save individual sample rewards
    sample_dfs[name].to_csv(f"{name}_sample_rewards.csv", index=False)

# Create summary dataframe
summary_data = []
for name in dataframes.keys():
    row_data = {'Dataset': name}
    for norm_method, stats in stats_results[name].items():
        row_data[f'Avg_{norm_method.capitalize()}_Reward'] = stats['mean']
        row_data[f'SEM_{norm_method.capitalize()}_Reward'] = stats['sem']
    summary_data.append(row_data)

summary_df = pd.DataFrame(summary_data)

print("\nSummary of results:")
print(summary_df)

# Save summary to CSV
summary_df.to_csv("cvpr_reward_summary.csv", index=False)

Dataframe shapes:
group1_noEmbodiment: (30, 4)
group1_withEmbodiment: (30, 4)
group2_noEmbodiment: (28, 4)
group2_withEmbodiment: (30, 4)

Processing group1_noEmbodiment...


Processing group1_noEmbodiment: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 27.89it/s]



Results for group1_noEmbodiment:
Average raw reward: 6.9441 ± 0.3226
Average min_max reward: 0.5574 ± 0.0383

Processing group1_withEmbodiment...


Processing group1_withEmbodiment: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 28.42it/s]



Results for group1_withEmbodiment:
Average raw reward: 7.7103 ± 0.2961
Average min_max reward: 0.5930 ± 0.0525

Processing group2_noEmbodiment...


Processing group2_noEmbodiment: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:00<00:00, 29.35it/s]



Results for group2_noEmbodiment:
Average raw reward: 6.5367 ± 0.3030
Average min_max reward: 0.4850 ± 0.0608

Processing group2_withEmbodiment...


Processing group2_withEmbodiment: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 29.31it/s]


Results for group2_withEmbodiment:
Average raw reward: 7.0983 ± 0.2574
Average min_max reward: 0.5142 ± 0.0465

Summary of results:
                 Dataset  Avg_Raw_Reward  SEM_Raw_Reward  Avg_Min_max_Reward   
0    group1_noEmbodiment        6.944141        0.322640            0.557375  \
1  group1_withEmbodiment        7.710286        0.296092            0.592964   
2    group2_noEmbodiment        6.536691        0.303049            0.485014   
3  group2_withEmbodiment        7.098307        0.257365            0.514222   

   SEM_Min_max_Reward  
0            0.038310  
1            0.052456  
2            0.060847  
3            0.046464  





In [34]:
# Test with the provided example
example = "Friction: Wait, aren't we assuming that the weight of the red block is 9? But we never actually weighed it against anything to confirm that. Could we be miscounting? Maybe we should recheck our calculations and make sure we're not jumping to conclusions?r*RationaleP1's assumption about the weight of the red block being 9 is incorrect, and it's essential to challenge this belief to ensure accurate weight determination. P3's assumption about the weight of the blue block being 20 is also uncertain and needs to be questioned. By introducing a friction statement, we can encourage participants to reevaluate their beliefs and assumptions about the task."

friction, rationale = extract_friction_rationale(example)
print("Parsed Friction:", friction)
print("Parsed Rationale:", rationale)

Parsed Friction: Wait, aren't we assuming that the weight of the red block is 9? But we never actually weighed it against anything to confirm that. Could we be miscounting? Maybe we should recheck our calculations and make sure we're not jumping to conclusions?
Parsed Rationale: P1's assumption about the weight of the red block being 9 is incorrect, and it's essential to challenge this belief to ensure accurate weight determination. P3's assumption about the weight of the blue block being 20 is also uncertain and needs to be questioned. By introducing a friction statement, we can encourage participants to reevaluate their beliefs and assumptions about the task.


In [27]:
friction1_parsed['Friction'][3]

"Friction: Hey, let's not jump to conclusions just yet. What if we're not looking at a simple addition pattern here? What if the weights are more complex? Maybe we should try weighing a different combination to test our assumption.r*RationaleA friction statement is needed because P1 and P2 are making incorrect assumptions about the weights being evenly distributed and increasing in a simple pattern. They are not considering other possible weight combinations or patterns, and their conversation is leading them to make rash decisions about the weights."