In [1]:
# ---------------- Imports ----------------
import os
import json
import sys

import yaml

from transformers import AutoTokenizer



In [2]:
# ---------------- Args ----------------
# Name of the input dataset folder; it is joined under the project data directory in the Config cell.
dataset = "yield-v1-small10pct-factualnovelty-with-rtg"
# Model identifier; it is joined under the configured models folder to locate the tokenizer files.
tokenizer_model = "meta-llama/Llama-3.1-8B-Instruct"
TOKEN_LIMIT = 512  # max tokens allowed per emitted window (counted on the chat-templated message history)
WINDOW_SIZE = 6 # number of consecutive merged same-role blocks in each sliding window


In [3]:
# ---------------- Config ----------------
# NOTE: this path is relative to the working directory the notebook runs in.
with open("../../config/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

proj_store = config["paths"]["proj_store"]
data_path = os.path.join(proj_store, "data")
models_folderpath = config["paths"]["models"]

base_folder = os.path.join(data_path, dataset)
# Derive the output name from the dataset name alone, so that a "-with-rtg"
# substring elsewhere in the storage path is never rewritten by accident.
output_folder = os.path.join(data_path, f"{dataset.replace('-with-rtg', '')}-rl")
os.makedirs(output_folder, exist_ok=True)

# Initialize tokenizer globally (loaded from the local models folder only; no download)
MODEL_NAME = os.path.join(models_folderpath, tokenizer_model)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, local_files_only=True)



In [4]:
def _merge_same_role_turns(turns, first_index):
    """Collapse consecutive same-role turns, from ``first_index`` on, into blocks.

    Returns three parallel lists: ``(role, merged_text)`` blocks, the index of
    the first turn of each block and the index of the last turn of each block.
    Utterances are stripped and joined with a blank line.
    """
    blocks = []
    first_turn_of_block = []
    last_turn_of_block = []

    run_role = None
    run_texts = []
    run_start = first_index

    for position in range(first_index, len(turns)):
        turn = turns[position]
        role = turn["role"]
        text = turn["utterance"].strip()

        if run_texts and role == run_role:
            run_texts.append(text)
            continue

        # Role changed: close the run collected so far (if any) and open a new one.
        if run_texts:
            blocks.append((run_role, "\n\n".join(run_texts)))
            first_turn_of_block.append(run_start)
            last_turn_of_block.append(position - 1)

        run_role = role
        run_texts = [text]
        run_start = position

    if run_texts:
        blocks.append((run_role, "\n\n".join(run_texts)))
        first_turn_of_block.append(run_start)
        last_turn_of_block.append(len(turns) - 1)

    return blocks, first_turn_of_block, last_turn_of_block


def process_dialogue(dialogue, file_path, *, window_size=None, token_limit=None, chat_tokenizer=None):
    """Turn one dialogue into sliding-window chat examples that end on an elicitor block.

    Leading elicitor turns are skipped, consecutive turns of the same role are
    merged into blocks, and every window of ``window_size`` consecutive blocks
    whose last block is an elicitor block and which is directly followed by a
    respondent block becomes one example. The example's score and return are
    read from the LAST turn of that following respondent block. Windows whose
    chat-templated text exceeds ``token_limit`` tokens are dropped.

    Args:
        dialogue: Dict with "turns" (list of dicts with "role" and "utterance")
            and "domain"; "dialogue_id" is used for ids and error messages.
        file_path: Source file path, used only in error messages.
        window_size: Blocks per window. Defaults to the global ``WINDOW_SIZE``.
        token_limit: Max tokens per window. Defaults to the global ``TOKEN_LIMIT``.
        chat_tokenizer: Object providing ``apply_chat_template`` and a call
            returning ``{"input_ids": ...}``. Defaults to the global ``tokenizer``.

    Returns:
        List of dicts with keys "block_id", "domain", "factual_novelty_score",
        "return_to_go" and "messages".

    Raises:
        ValueError: No turns, missing "domain", no respondent turn, a
            non-positive window size, or a role other than elicitor/respondent
            inside an emitted window.
    """
    turns = dialogue.get("turns", [])
    if not turns:
        raise ValueError(f"Dialogue {dialogue.get('dialogue_id')} in {file_path} has no turns.")

    try:
        domain = dialogue["domain"]
    except KeyError as err:
        raise ValueError(f"Missing required field 'domain' in dialogue {dialogue.get('dialogue_id')} in {file_path}") from err

    system_prompt = f"Act as an information elicitation agent for {domain.replace('_', ' ')}."

    # skip leading elicitor-only turns
    index = 0
    while index < len(turns) and turns[index]["role"] == "elicitor":
        index += 1

    if index >= len(turns):
        raise ValueError(f"Dialogue {dialogue.get('dialogue_id')} in {file_path} has no respondent turn to start with.")

    merged_blocks, block_start_indices, block_end_indices = _merge_same_role_turns(turns, index)

    # Resolve the notebook-level defaults lazily, so the validation above does
    # not depend on them and re-running the Args cell is always picked up.
    if window_size is None:
        window_size = WINDOW_SIZE
    if token_limit is None:
        token_limit = TOKEN_LIMIT
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer

    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    results = []

    # SLIDING WINDOW
    for window_start in range(len(merged_blocks) - window_size + 1):
        window_end = window_start + window_size  # index of the block right after the window
        window = merged_blocks[window_start:window_end]

        # Only windows that end on an elicitor block and are answered by a
        # respondent block produce an example.
        if window[-1][0] != "elicitor":
            continue
        if window_end >= len(merged_blocks):
            continue
        if merged_blocks[window_end][0] != "respondent":
            continue

        ensuing_turn = turns[block_end_indices[window_end]]
        ensuing_score = ensuing_turn.get("factual_novelty_score")
        ensuing_return = ensuing_turn.get("return_to_go")

        # Build message history
        messages = [{"role": "system", "content": system_prompt}]
        for role, content in window:
            if role == "elicitor":
                messages.append({"role": "assistant", "content": content})
            elif role == "respondent":
                messages.append({"role": "user", "content": content})
            else:
                raise ValueError(f"Unexpected role: {role}")

        # Format conversation and count tokens
        chat_text = chat_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False
        )
        tokens = chat_tokenizer(chat_text, return_tensors=None, add_special_tokens=False)["input_ids"]

        if len(tokens) <= token_limit:
            first_block_turn_index = block_start_indices[window_start]
            results.append({
                "block_id": f"{dialogue.get('dialogue_id')}:{first_block_turn_index}",
                "domain": domain,
                "factual_novelty_score": ensuing_score,
                "return_to_go": ensuing_return,
                "messages": messages,
            })

    return results




def process_file(input_file):
    """Convert one JSON file of dialogues into serialized training windows.

    Args:
        input_file: Path to a UTF-8 JSON file whose top-level value is a list
            of dialogue dicts.

    Returns:
        A list of JSON strings (non-ASCII characters kept as-is), one per
        window returned by ``process_dialogue``, in dialogue order.

    Raises:
        ValueError: The top-level JSON value is not a list.
        RuntimeError: Processing or serializing a dialogue failed; the original
            exception is attached as ``__cause__``.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        dialogues = json.load(f)

    if not isinstance(dialogues, list):
        raise ValueError(f"Expected list of dialogues in file {input_file}, found {type(dialogues)}")

    processed_dialogues = []

    for position, dialogue in enumerate(dialogues):
        try:
            processed_windows = process_dialogue(dialogue, input_file)
            processed_dialogues.extend(
                json.dumps(window, ensure_ascii=False) for window in processed_windows
            )
        except Exception as e:
            # Report where in the file the failure happened and keep the
            # original traceback reachable through explicit chaining.
            raise RuntimeError(f"Error in {input_file} (dialogue index {position}): {e}") from e

    return processed_dialogues



def process_split(input_folder, output_folder):
    """Convert every ``.json`` file in a split folder into a ``.jsonl`` file.

    Each input file is run through ``process_file`` and its serialized windows
    are written, one per line, to a file with the same stem in
    ``output_folder``. A summary line is printed per file.

    Args:
        input_folder: Directory containing the split's ``.json`` files.
        output_folder: Existing directory that receives the ``.jsonl`` files.
    """
    # Sort so files are processed (and logged) in a reproducible order;
    # os.listdir makes no ordering guarantee.
    json_files = sorted(name for name in os.listdir(input_folder) if name.endswith('.json'))

    for json_file in json_files:
        file_path = os.path.join(input_folder, json_file)
        processed = process_file(file_path)

        # Swap only the extension; str.replace would also rewrite a ".json"
        # that happens to occur earlier in the file name.
        stem, _ = os.path.splitext(json_file)
        output_file_path = os.path.join(output_folder, stem + '.jsonl')

        with open(output_file_path, 'w', encoding='utf-8') as f_out:
            f_out.writelines(line + '\n' for line in processed)

        # The count is the number of emitted windows, not input dialogues.
        print(f"Processed {len(processed)} windows from {json_file} into {output_file_path}")




In [5]:
# Procedure
# Convert every split: read the .json files in <base_folder>/<split>/ and write
# one .jsonl file per input file into <output_folder>/<split>/.
splits = ["train", "dev", "test"]
os.makedirs(output_folder, exist_ok=True)  # no-op if the folder already exists (exist_ok=True)

for split in splits:
    input_split_folder = os.path.join(base_folder, split)
    output_split_folder = os.path.join(output_folder, split)
    # process_split writes into this folder, so create it first.
    os.makedirs(output_split_folder, exist_ok=True)

    process_split(input_split_folder, output_split_folder)


Processed 2528 dialogues from train-001.json into /data/sequential-ieas/data/yield-v1-small10pct-factualnovelty-rl/train/train-001.jsonl
Processed 8702 dialogues from train-000.json into /data/sequential-ieas/data/yield-v1-small10pct-factualnovelty-rl/train/train-000.jsonl
Processed 2040 dialogues from dev-000.json into /data/sequential-ieas/data/yield-v1-small10pct-factualnovelty-rl/dev/dev-000.jsonl
Processed 1044 dialogues from test-000.json into /data/sequential-ieas/data/yield-v1-small10pct-factualnovelty-rl/test/test-000.jsonl
