In [None]:
import numpy as np
from datasets import load_dataset
import pandas as pd
from collections.abc import Callable, Iterable
import datetime
import uuid

from tqdm import tqdm


In [None]:
# Load the WildChat training split, then show the Python type of every field
# in the first record so the schema is visible before anything is filtered.
wildchat_dataset = load_dataset("allenai/WildChat")["train"]
print("Wildchat schema:")
first_row = wildchat_dataset[0]
for field_name in first_row:
    print(f"{field_name}: {type(first_row[field_name])}")

# Keep only the rows whose "language" field is exactly "English".
wildchat_dataset = wildchat_dataset.filter(lambda row: row["language"] == "English")

In [None]:
# Seed NumPy's global RNG so the shuffle (and therefore the split) below is
# the same on every fresh run.
np.random.seed(42)

# Split into 1000 mock logs, and use the rest for private logs.
# MAX_PRIVATE_LOGS = None keeps every remaining row; an integer caps the count.
MAX_PRIVATE_LOGS = None
MAX_MOCK_LOGS = 1000

all_indices = np.arange(len(wildchat_dataset))
np.random.shuffle(all_indices)

assert MAX_MOCK_LOGS <= len(
    all_indices
), f"Dataset is too small for {MAX_MOCK_LOGS} mock logs"

# The two slices of the shuffled index array do not overlap, so no row lands
# in both the mock set and the private set.
mock_logs = wildchat_dataset.select(all_indices[:MAX_MOCK_LOGS])
private_logs = wildchat_dataset.select(all_indices[MAX_MOCK_LOGS:])

if MAX_PRIVATE_LOGS is not None:
    # Treat the setting as an upper bound: clamp it to the rows actually
    # available so an oversized cap keeps everything instead of making
    # `select` fail on out-of-range indices.
    private_logs = private_logs.select(
        range(min(MAX_PRIVATE_LOGS, len(private_logs)))
    )

print(f"Mock logs: {len(mock_logs)}")
print(f"Private logs: {len(private_logs):,}")

In [None]:
def clean_json(obj):
    """Recursively convert a value into plain, JSON-friendly Python data.

    Conversion rules:
        * dict  -> new dict with every value cleaned (keys are left untouched)
        * list  -> new list with every element cleaned
        * str   -> same text with all NUL ("\\x00") characters removed
        * UUID  -> its canonical string form
        * date / datetime -> ISO 8601 string via ``isoformat()``
        * anything else (numbers, None, tuples, ...) -> returned unchanged

    Args:
        obj: Any value, possibly nested inside lists and dicts.

    Returns:
        The cleaned equivalent of ``obj``; containers are rebuilt, not mutated.
    """
    # Containers first: rebuild them around their cleaned contents.
    if isinstance(obj, dict):
        return {key: clean_json(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return list(map(clean_json, obj))

    # Leaf values that need converting.
    if isinstance(obj, str):
        return obj.replace("\x00", "")
    if isinstance(obj, uuid.UUID):
        return str(obj)
    # datetime.datetime subclasses datetime.date, so one check covers both.
    if isinstance(obj, datetime.date):
        return obj.isoformat()

    return obj


def process_wildchat_log(log_data: dict, max_content_length: int = 4096) -> list[dict]:
    """
    Flatten one WildChat conversation row into a list of per-message records.

    Every entry of ``log_data["conversation"]`` becomes one dict shaped like:

        {
            "id": "1234",                   # UUIDv5 of "<conversation_id>_<message_idx>"
            "text": "Hello, how are you?",  # message content, truncated if too long
            "log_id": ...,                  # the row's conversation_id
            "model": ...,
            "timestamp": ...,
            "total_turns": ...,             # the row's "turn" value
            "language": ...,                # the message's own language if it has one,
                                            # otherwise the row-level language
            "role": ...,
            "message_idx": ...,             # position of the message in the conversation
        }

    The "id" is deterministic: the same conversation id and message position
    always give the same UUID. It is unique across messages as long as
    conversation ids are unique.

    Args:
        log_data (dict): A single row of wildchat data.
        max_content_length (int, optional): Maximum length of the "text" field, depends
            on the embedding model. Defaults to 4096 (~= 1000 tokens).

    Returns:
        list[dict]: One record per message, passed through ``clean_json``.
    """

    # Fields that are identical for every message of this conversation.
    conversation_id = log_data["conversation_id"]
    shared_fields = {
        "log_id": conversation_id,
        "model": log_data["model"],
        "timestamp": log_data["timestamp"],
        "total_turns": log_data["turn"],
        "language": log_data["language"],
    }

    records = []
    for position, turn in enumerate(log_data["conversation"]):
        raw_text = turn["content"]
        text = (
            raw_text
            if len(raw_text) <= max_content_length
            else raw_text[:max_content_length]
        )

        # Start from the shared fields, then layer the message-specific ones on
        # top; "language" is overwritten in place so the key order is stable.
        per_message_fields = dict(shared_fields)
        per_message_fields["role"] = turn["role"]
        per_message_fields["language"] = turn.get(
            "language", shared_fields["language"]
        )
        per_message_fields["message_idx"] = position

        # Name-based UUID: reproducible for a given conversation id + position.
        message_uuid = uuid.uuid5(uuid.NAMESPACE_URL, f"{conversation_id}_{position}")

        record = {"id": str(message_uuid), "text": text}
        record.update(per_message_fields)
        records.append(record)

    # Clean the records so they can be inserted into Postgres.
    return clean_json(records)


def preprocess_dataset(
    dataset: Iterable[dict],
    size: int | None = None,
) -> list[dict]:
    """Run ``process_wildchat_log`` over every row and flatten the results.

    Args:
        dataset: Iterable of WildChat rows, each one a dict accepted by
            ``process_wildchat_log``.
        size: Forwarded to tqdm as ``total`` for the progress bar; it does not
            limit how many rows are processed.

    Returns:
        A single flat list with the message records of all rows, in the order
        the rows were iterated.
    """
    progress = tqdm(dataset, desc="Processing logs", total=size)
    return [message for row in progress for message in process_wildchat_log(row)]

In [None]:
# Flatten both splits into per-message records. The private split runs first,
# so its progress bar is the first one shown.
private_logs_processed = preprocess_dataset(private_logs, size=len(private_logs))
mock_logs_processed = preprocess_dataset(mock_logs, size=len(mock_logs))

# Report how many message records each split produced.
print(f"Processed private logs: {len(private_logs_processed):,} messages")
print(f"Processed mock logs: {len(mock_logs_processed):,} messages")

In [None]:
# Turn each list of per-message record dicts into its own DataFrame
# (private split first, then mock split).
private_df, mock_df = map(
    pd.json_normalize, (private_logs_processed, mock_logs_processed)
)

In [None]:
# Output locations, relative to the notebook's working directory. The
# file-size check in the following cell reads these two variables.
private_data_path = "private_logs.parquet"
mock_data_path = "mock_logs.parquet"

# Persist both frames as Parquet; index=False leaves the DataFrame index out
# of the written files.
private_df.to_parquet(path=private_data_path, index=False)
mock_df.to_parquet(path=mock_data_path, index=False)

In [None]:
# Check file sizes: list both Parquet files written above with human-readable
# sizes. `!` is an IPython shell escape and `{...}` interpolates the Python
# path variables, so this line only runs inside IPython/Jupyter with `ls`
# available on the system.

!ls -lh {private_data_path} {mock_data_path}