# Depolarizing Chatroom pre-analysis processing


## Setup


In [None]:
import pandas as pd
import numpy as np
from pathlib import Path


In [None]:
# Raw exports are read from data_dir; processed CSVs are written to output_dir.
data_dir = Path("data", "complete-launch")
output_dir = Path("data", "complete-launch-processed")


## Load CSV data


### Load Qualtrics post-survey data


In [None]:
# NOTE(review): the first two rows of each export are discarded below — presumably
# the extra Qualtrics header rows; confirm against the raw files. Rows without a
# ResponseId are dropped as well.

# Users who made it into a chat
# TODO: Clarify the exact conditions of this
post_survey_chat_df = (
    pd.read_csv(data_dir / "post_survey_chat.csv")
    .iloc[2:]
    .dropna(subset=["ResponseId"])
)
# Users who did not make it into a chat
# TODO: Also clarify here
post_survey_no_chat_df = (
    pd.read_csv(data_dir / "post_survey_no_chat.csv")
    .iloc[2:]
    .dropna(subset=["ResponseId"])
)


### Load exported chatroom data


In [None]:
# Each exported chatroom table is read from its own CSV under data_dir.
users_df, chatrooms_df, messages_df, rephrasings_df = (
    pd.read_csv(data_dir / file_name)
    for file_name in ("users.csv", "chatrooms.csv", "messages.csv", "rephrasings.csv")
)


## Sort messages by send time, ascending


In [None]:
# Later cells walk messages in order, so put them in ascending send_time order
# (ascending is the sort_values default).
messages_df = messages_df.sort_values("send_time")


## Add convenience columns


In [None]:
# Carry each chatroom's stored limit_reached flag over to its users. The value comes
# straight from the chatroom row, so both users of a chatroom share it.
chatroom_limit_flags = chatrooms_df[["id", "limit_reached"]].rename(
    columns={"id": "chatroom_id"}
)
users_df = users_df.merge(chatroom_limit_flags, on="chatroom_id", how="left")

# Respondent IDs seen in the no-chat and chat post-surveys, respectively
no_chat_ids = set(post_survey_no_chat_df["RESPONDENT_ID"].tolist())
chat_ids = set(post_survey_chat_df["RESPONDENT_ID"].tolist())


def _post_survey_taken(response_id):
    # "chat" wins if an ID shows up in both surveys; NaN means neither survey.
    if response_id in chat_ids:
        return "chat"
    if response_id in no_chat_ids:
        return "no_chat"
    return float("nan")


# Record which post-survey each user took: "chat", "no_chat", or NaN
users_df["post_survey"] = users_df["response_id"].apply(_post_survey_taken)

# Stack both post-survey dataframes into one
post_survey_df = pd.concat([post_survey_chat_df, post_survey_no_chat_df])

# Attach the technical difficulties (if any) each user reported in the post-survey.
# NOTE that this is distinct from leave_reason, which the user provides if they attempt
# to leave the chat early. Both fields help determine what technical difficulties
# users might have experienced.
reported_difficulties = post_survey_df[["RESPONDENT_ID", "difficulties_what"]].rename(
    columns={"RESPONDENT_ID": "response_id"}
)
users_df = users_df.merge(reported_difficulties, on="response_id", how="left")

# The flag copied from the chatroom export is kept under "old_limit_reached"
users_df = users_df.rename(columns={"limit_reached": "old_limit_reached"})

# Tag every message with the position of the user who sent it
sender_positions = users_df[["id", "position"]].rename(columns={"id": "sender_id"})
messages_df = messages_df.merge(sender_positions, on="sender_id", how="left")

# Users keyed by the chatroom they belong to
users_grouped = {
    chatroom_id: members for chatroom_id, members in users_df.groupby("chatroom_id")
}

# A chatroom counts as treated when any of its users has treatment "TREATED". The
# chatroom rows hold no references to users, so membership is resolved through the
# users table grouped by chatroom_id above.
chatrooms_df["treated"] = chatrooms_df["id"].apply(
    lambda chatroom_id: bool(
        (users_grouped[chatroom_id]["treatment"] == "TREATED").any()
    )
)


## Add rephrasing counts to chatrooms where applicable

For each chatroom:

| `rephrasing_original_count`                                                   | `rephrasing_validate_count`                | `rephrasing_restate_count`                | `rephrasing_polite_count`                | `rephrasing_accepted_count`     | `rephrasing_total_count`                          |
| ----------------------------------------------------------------------------- | ------------------------------------------ | ----------------------------------------- | ---------------------------------------- | ------------------------------- | ------------------------------------------------- |
| rephrasing opportunities the user ignored,<br/>sending their original message | rephrasings sent using 'validate' strategy | rephrasings sent using 'restate' strategy | rephrasings sent using 'polite' strategy | sum of all accepted rephrasings | sum of all offered rephrasings, including ignored |


In [None]:
# Index the rephrasings table once instead of re-filtering it for every message.
# IDs of messages for which at least one rephrasing row exists (i.e. the sender was
# offered rephrasings).
offered_message_ids = set(rephrasings_df["message_id"].dropna())
# Strategy of each rephrasing, keyed by rephrasing ID. The first row wins if an ID is
# repeated, matching a positional "first match" lookup.
strategy_by_rephrasing_id = (
    rephrasings_df.drop_duplicates(subset="id").set_index("id")["strategy"].to_dict()
)

# Per-chatroom tallies: chatroom ID -> {"original" | strategy name -> count}
chatroom_rephrasings = {}

for chatroom_id in chatrooms_df["id"]:
    counts = {
        "original": 0,
        "validate": 0,
        "restate": 0,
        "polite": 0,
    }
    chatroom_rephrasings[chatroom_id] = counts
    # Only messages of this chatroom that have a body are considered
    chatroom_messages = messages_df[
        (messages_df["chatroom_id"] == chatroom_id) & (messages_df["body"].notna())
    ]
    for message_id, accepted_rephrasing_id in zip(
        chatroom_messages["id"], chatroom_messages["accepted_rephrasing_id"]
    ):
        if not pd.isna(accepted_rephrasing_id):
            # The sender accepted a rephrasing: count it under that rephrasing's
            # strategy. An accepted ID missing from rephrasings_df raises KeyError.
            counts[strategy_by_rephrasing_id[accepted_rephrasing_id]] += 1
        elif message_id in offered_message_ids:
            # Rephrasings were offered but the sender kept the original message
            counts["original"] += 1

# Add one column per tally to chatrooms_df
for strategy in ["original", "validate", "restate", "polite"]:
    chatrooms_df[f"rephrasing_{strategy}_count"] = chatrooms_df["id"].apply(
        lambda chatroom_id, strategy=strategy: chatroom_rephrasings[chatroom_id][
            strategy
        ]
    )

# Accepted = rephrasings sent under any strategy; total additionally includes the
# offers the sender ignored.
chatrooms_df["rephrasing_accepted_count"] = (
    chatrooms_df["rephrasing_validate_count"]
    + chatrooms_df["rephrasing_restate_count"]
    + chatrooms_df["rephrasing_polite_count"]
)
chatrooms_df["rephrasing_total_count"] = (
    chatrooms_df["rephrasing_original_count"]
    + chatrooms_df["rephrasing_accepted_count"]
)


## Add conversation counts to chatrooms

For each chatroom:

| `message_count`                                   | `turn_count`                                                                                                    |
| ------------------------------------------------- | --------------------------------------------------------------------------------------------------------------- |
| total number of messages sent during conversation | calculated using same method as production chatroom;<br/>collapses multiple messages and ignores short messages |


In [None]:
from typing import List, Dict, Any, Tuple, Optional, TypedDict

# NOTE: This is copied from the chatroom code.
# YOU SHOULD CHECK IF THIS VALUE IS STILL CORRECT IF YOU ARE UNFAMILIAR WITH THIS CODE
# Reference:
# https://github.com/BYU-PCCL/depolarizing-chatroom/blob/5c841c31443e7c7f35952772f4757b02d4321749/depolarizing_chatroom/constants.py#L3
# Minimum number of whitespace-separated words a message needs for the turn counters
# below to count it on length alone.
MIN_COUNTED_MESSAGE_WORD_COUNT = 3

# NOTE: This is copied from the chatroom code.
# YOU SHOULD CHECK IF THIS VALUE IS STILL CORRECT IF YOU ARE UNFAMILIAR WITH THIS CODE
# Reference:
# https://github.com/BYU-PCCL/depolarizing-chatroom/blob/ee5f08c40ad5b4994bc507b45f421338ef20ac2a/depolarizing_chatroom/constants.py#LL4-L4C25
# Passed as min_turns when converting turn counts to rephrasing counts.
MIN_REPHRASING_TURNS = 2

# NOTE: Same caveat as the constants above (the original comment here was just
# "^^^"): presumably also copied from the chatroom code — verify the value is still
# correct before relying on it.
# Used as the turn interval when converting turn counts to rephrasing counts and in
# the limit_reached checks.
REPHRASE_EVERY_N_TURNS = 2


# Shape of one message as consumed by the turn-counting functions.
class Message(TypedDict):
    # Text whose whitespace-separated word count decides whether the message counts
    body: str
    # Position of the sender; compared against "SUPPORT"/"OPPOSE" values elsewhere
    position: str
    # True when at least one rephrasing exists for the message (offered, not
    # necessarily accepted). NOTE(review): readers use .get("rephrased", False) and
    # the test conversations omit this key, so it is effectively optional.
    rephrased: bool


# This is also copied from the chatroom code, so I try to leave it as untouched as possible. This is why we go through
# the extra effort of converting the message rows to the list of dictionaries this function accepts.
def calculate_turns(
    messages: List[Message], current_position: str
) -> Tuple[int, int, int, bool, List[List[Message]]]:
    """Group messages into turns and count the turns that qualify.

    A new turn starts whenever the sender position differs from the previous
    message's. A turn is counted at most once, on its first message that has at
    least MIN_COUNTED_MESSAGE_WORD_COUNT words or a truthy "rephrased" flag.

    Args:
        messages: Messages in the order they should be processed.
        current_position: Position treated as "the user"; counted turns from any
            other position are attributed to the partner.

    Returns:
        A tuple of (counted turns overall, counted turns by current_position,
        counted turns by the other position, whether the final turn already has a
        counted message, list of turns where each turn is a list of messages).
    """
    turns = []
    # Note that this number is not the length of the returned list because a turn is
    # appended on every change of sender, but it is only counted once it contains a
    # qualifying message.
    counted_turn_count = 0
    user_turn_count = 0
    partner_turn_count = 0
    last_message_position = None
    turn_has_counted_message = False
    # TODO: Justify this to a reader of this code
    # NOTE: turn_switch is assigned here but never read anywhere in this function.
    turn_switch = False
    for message in messages:
        # A change of sender always opens a new turn, whether or not that turn ends up
        # being counted.
        message_position = message["position"]
        if message_position != last_message_position:
            turns.append([])
            turn_has_counted_message = False
            last_message_position = message_position
        # Count the turn once, on its first message that is at least
        # MIN_COUNTED_MESSAGE_WORD_COUNT words long (or was offered rephrasings).
        if not turn_has_counted_message and (
            len(message["body"].split()) >= MIN_COUNTED_MESSAGE_WORD_COUNT
            # Also count a message as part of a turn if it was the result of a
            # rephrasing, regardless of whether it was the original message or a
            # rephrasing, because it means that the user received a rephrasing
            or message.get("rephrased", False)
        ):
            counted_turn_count += 1
            turn_has_counted_message = True
            if message_position == current_position:
                user_turn_count += 1
            else:
                partner_turn_count += 1
        turns[-1].append(message)
    return (
        counted_turn_count,
        user_turn_count,
        partner_turn_count,
        turn_has_counted_message,
        turns,
    )


def rows_to_messages(rows: pd.DataFrame) -> List[Message]:
    """Convert message rows into the dictionaries the turn counters accept.

    Args:
        rows: Message rows (a DataFrame; it is iterated with ``iterrows``) providing
            the columns "id", "body", "edited_body", "position" and
            "accepted_rephrasing_id".

    Returns:
        One Message per row, in row order. "body" is the displayed text (see
        ``resolve_body``); "rephrased" is True when rephrasings_df holds at least
        one rephrasing for the message, whether or not one was accepted.

    Raises:
        AssertionError: If a resolved body is missing, or an accepted rephrasing ID
            matches more than one row of rephrasings_df.
    """
    # Built once rather than re-filtering rephrasings_df for every row: IDs of all
    # messages that were offered rephrasings.
    offered_message_ids = set(rephrasings_df["message_id"].dropna())

    # NOTE that we have to do extra work to resolve the displayed body of the message
    # here because row["body"] is not necessarily the displayed body of the message.
    def resolve_body(row: pd.Series) -> str:
        # A missing accepted_rephrasing_id (NaN) never compares equal, so this is
        # empty unless the sender accepted a rephrasing.
        rephrasing = rephrasings_df[
            rephrasings_df["id"] == row["accepted_rephrasing_id"]
        ]
        if not rephrasing.empty:
            assert len(rephrasing) == 1
            rephrasing = rephrasing.iloc[0]
            # An edited rephrasing takes precedence over the generated text
            return (
                rephrasing["edited_body"]
                if pd.notna(rephrasing["edited_body"])
                else rephrasing["body"]
            )

        return row["edited_body"] if pd.notna(row["edited_body"]) else row["body"]

    def resolve_body_assert_notna(row: pd.Series) -> str:
        body = resolve_body(row)
        if pd.isna(body):
            # Show the offending row before failing. Raised explicitly (rather than
            # via `assert`) so the check still runs under `python -O`.
            print(row)
            raise AssertionError("message has no displayable body")
        return body

    return [
        {
            "body": resolve_body_assert_notna(row),
            "position": row["position"],
            # Not the same thing as "accepted_rephrasing_id" being set: "rephrased"
            # is true whenever rephrasings were offered for the message, even if the
            # user did not accept any of them.
            "rephrased": row["id"] in offered_message_ids,
        }
        for _, row in rows.iterrows()
    ]


In [None]:
def calculate_rephrasing_count_from_turn_count(
    turn_count, min_turns, rephrase_every_n_turns
):
    """Return how many rephrasings a given number of turns corresponds to.

    The first rephrasing falls on turn ``min_turns`` and another one follows every
    ``rephrase_every_n_turns`` turns after that.

    Args:
        turn_count: Number of counted turns.
        min_turns: Turn count at which the first rephrasing happens.
        rephrase_every_n_turns: Number of turns between consecutive rephrasings.

    Returns:
        The rephrasing count, never less than 0.
    """
    # Contributed by Chris
    rephrasing_count = 1 + ((turn_count - min_turns) // rephrase_every_n_turns)
    # Floor division goes below -1 once turn_count is more than one interval short of
    # min_turns, which would otherwise produce a negative count.
    return max(0, rephrasing_count)


In [None]:
def calculate_turns_improved(
    messages: List[Message], current_position: str, treatment: str
) -> Tuple[int, int, int, bool]:
    """Count turns and determine whether the conversation limit was reached.

    A message counts as a turn when its position differs from the position of the
    last *counted* message and it either has at least
    MIN_COUNTED_MESSAGE_WORD_COUNT words or has a truthy "rephrased" flag.

    The limit check runs after every message and depends on ``treatment``:

    * "TREATED": own turns / REPHRASE_EVERY_N_TURNS is at least 4 and the message
      comes from the other position with enough words.
    * "UNTREATED": partner turns / REPHRASE_EVERY_N_TURNS is at least 4 and the
      message comes from ``current_position`` with enough words.
    * "CONTROL": the larger of the two turn counts / REPHRASE_EVERY_N_TURNS is at
      least 4.

    Any other treatment value never reaches the limit.

    Args:
        messages: Messages in the order they should be processed.
        current_position: Position whose turns are reported as the user's.
        treatment: Treatment value selecting the limit rule above.

    Returns:
        (counted turns overall, turns by current_position, turns by the other
        position, whether the limit was reached).
    """

    def has_enough_words(candidate) -> bool:
        return len(candidate["body"].split()) >= MIN_COUNTED_MESSAGE_WORD_COUNT

    total_turns = 0
    own_turns = 0
    other_turns = 0
    last_counted_position = None
    limit_reached = False
    for message in messages:
        position = message["position"]
        sent_by_current = position == current_position
        # Messages that were offered rephrasings count regardless of their length.
        if position != last_counted_position and (
            has_enough_words(message) or message.get("rephrased", False)
        ):
            # Only a counted message moves the "last position" marker forward
            last_counted_position = position
            total_turns += 1
            if sent_by_current:
                own_turns += 1
            else:
                other_turns += 1
        if treatment == "TREATED":
            if (
                own_turns / REPHRASE_EVERY_N_TURNS >= 4
                and not sent_by_current
                and has_enough_words(message)
            ):
                limit_reached = True
        elif treatment == "UNTREATED":
            if (
                other_turns / REPHRASE_EVERY_N_TURNS >= 4
                and sent_by_current
                and has_enough_words(message)
            ):
                limit_reached = True
        elif treatment == "CONTROL":
            if max(own_turns, other_turns) / REPHRASE_EVERY_N_TURNS >= 4:
                limit_reached = True
    return (total_turns, own_turns, other_turns, limit_reached)


In [None]:
import sys

# Results collected by the loop below. The chatroom_* / corrected_limit_reached
# dictionaries are keyed by chatroom ID, the user_* dictionaries by user ID.
# "corrected_*" values come from calculate_turns_improved; the others come from
# calculate_turns.
chatroom_turn_counts = {}
corrected_chatroom_turn_counts = {}
chatroom_message_counts = {}
user_turn_counts = {}
corrected_user_turn_counts = {}
user_message_counts = {}
user_conversation_rephrasing_lengths = {}
corrected_user_conversation_rephrasing_lengths = {}
corrected_limit_reached = {}

# Iterate over chatroom rows
for _, row in chatrooms_df.iterrows():
    chatroom_id = row["id"]
    # Get all messages for this chatroom
    chatroom_messages = messages_df[messages_df["chatroom_id"] == chatroom_id]
    # Flag chatrooms that contain any message with a missing body. NOTE: the chatroom
    # is not skipped; its remaining messages are still processed below.
    if not chatroom_messages[(chatroom_messages["body"].isna())].empty:
        # Get number of empty messages
        empty_message_count = len(chatroom_messages[(chatroom_messages["body"].isna())])
        print(
            f"Chatroom {chatroom_id} contains {empty_message_count} empty message(s). "
            f"It had {len(chatroom_messages)} messages total. "
            "Setting 'error' field to True.",
            file=sys.stderr,
        )
        chatrooms_df.loc[chatrooms_df["id"] == chatroom_id, "error"] = True
    # Get all users for this chatroom
    chatroom_users = users_df[users_df["chatroom_id"] == chatroom_id]
    # Convert chatroom messages to format used by turn counting code (messages with a
    # missing body are left out)
    chatroom_messages = rows_to_messages(
        chatroom_messages[chatroom_messages["body"].notna()]
    )
    # Flip first two messages if there are at least 2 messages and the chatroom's
    # swap_view_messages field is true
    if row["swap_view_messages"] and len(chatroom_messages) >= 2:
        chatroom_messages[0], chatroom_messages[1] = (
            chatroom_messages[1],
            chatroom_messages[0],
        )
    # NOTE: .iloc[0] raises IndexError if the chatroom has no user in that position
    supporter = chatroom_users[chatroom_users["position"] == "SUPPORT"].iloc[0]
    opponent = chatroom_users[chatroom_users["position"] == "OPPOSE"].iloc[0]
    # Get supporter ID
    supporter_id = supporter["id"]
    # Get opponent ID
    opponent_id = opponent["id"]
    # Calculate turns from the supporter's point of view; the "partner" count is
    # therefore the opponent's
    (
        total_turn_count,
        supporter_turn_count,
        opponent_turn_count,
        # The chatroom turn counting code returns fields that are useful in a live
        # conversation, but we don't need them here. So we just ignore them.
        *_,
    ) = calculate_turns(chatroom_messages, "SUPPORT")
    (
        corrected_total_turn_count,
        corrected_supporter_turn_count,
        corrected_opponent_turn_count,
        # Unlike above, every returned field is used. The limit is evaluated from the
        # supporter's point of view, using the supporter's treatment.
        limit_reached,
    ) = calculate_turns_improved(chatroom_messages, "SUPPORT", supporter["treatment"])
    # Save turn count for chatroom
    chatroom_turn_counts[chatroom_id] = total_turn_count
    corrected_chatroom_turn_counts[chatroom_id] = corrected_total_turn_count
    # Save message count for chatroom
    chatroom_message_counts[chatroom_id] = len(chatroom_messages)
    # Save turn count for supporter
    user_turn_counts[supporter_id] = supporter_turn_count
    # Save turn count for opponent
    user_turn_counts[opponent_id] = opponent_turn_count
    # Save corrected turn count for supporter
    corrected_user_turn_counts[supporter_id] = corrected_supporter_turn_count
    # Save corrected turn count for opponent
    corrected_user_turn_counts[opponent_id] = corrected_opponent_turn_count
    # Save message count for supporter (messages with a body sent by this user)
    user_message_counts[supporter_id] = len(
        messages_df[
            (messages_df["sender_id"] == supporter_id) & messages_df["body"].notna()
        ]
    )
    # Save message count for opponent (messages with a body sent by this user)
    user_message_counts[opponent_id] = len(
        messages_df[
            (messages_df["sender_id"] == opponent_id) & messages_df["body"].notna()
        ]
    )
    # Save conversation length in rephrasings for each user. That is, regardless of
    # whether they were treated, how many rephrasings would they have received if they
    # were?
    user_conversation_rephrasing_lengths[
        supporter_id
    ] = calculate_rephrasing_count_from_turn_count(
        supporter_turn_count, MIN_REPHRASING_TURNS, REPHRASE_EVERY_N_TURNS
    )
    user_conversation_rephrasing_lengths[
        opponent_id
    ] = calculate_rephrasing_count_from_turn_count(
        opponent_turn_count, MIN_REPHRASING_TURNS, REPHRASE_EVERY_N_TURNS
    )
    # Same conversion, based on the corrected turn counts
    corrected_user_conversation_rephrasing_lengths[
        supporter_id
    ] = calculate_rephrasing_count_from_turn_count(
        corrected_supporter_turn_count, MIN_REPHRASING_TURNS, REPHRASE_EVERY_N_TURNS
    )
    corrected_user_conversation_rephrasing_lengths[
        opponent_id
    ] = calculate_rephrasing_count_from_turn_count(
        corrected_opponent_turn_count, MIN_REPHRASING_TURNS, REPHRASE_EVERY_N_TURNS
    )
    corrected_limit_reached[chatroom_id] = limit_reached

# Copy the per-chatroom results onto chatrooms_df. Results from calculate_turns get
# the "old_" prefix; the unprefixed columns hold the corrected values.
chatrooms_df["message_count"] = chatrooms_df["id"].map(chatroom_message_counts)
chatrooms_df["old_turn_count"] = chatrooms_df["id"].map(chatroom_turn_counts)
chatrooms_df["turn_count"] = chatrooms_df["id"].map(corrected_chatroom_turn_counts)
chatrooms_df["limit_reached"] = chatrooms_df["id"].map(corrected_limit_reached)


## Add turn counts to users

For each user:

| `message_count`                   | `turn_count`                                     | `conversation_rephrasing_length`                                                                         | `chatroom_message_count`        | `chatroom_turn_count`           |
| --------------------------------- | ------------------------------------------------ | -------------------------------------------------------------------------------------------------------- | ------------------------------- | ------------------------------- |
| number of messages this user sent | number of counted turns calculated for this user | number of rephrasings this user would have received if they were treated, computed from turn count only. | copied from associated chatroom | copied from associated chatroom |


In [None]:
# (new column, users_df column holding the lookup key, lookup dictionary), in the
# order the columns should appear. "old_" columns hold the uncorrected values.
for new_column, key_column, lookup in (
    ("message_count", "id", user_message_counts),
    ("old_turn_count", "id", user_turn_counts),
    ("turn_count", "id", corrected_user_turn_counts),
    ("old_conversation_rephrasing_length", "id", user_conversation_rephrasing_lengths),
    (
        "conversation_rephrasing_length",
        "id",
        corrected_user_conversation_rephrasing_lengths,
    ),
    ("chatroom_message_count", "chatroom_id", chatroom_message_counts),
    ("old_chatroom_turn_count", "chatroom_id", chatroom_turn_counts),
    ("chatroom_turn_count", "chatroom_id", corrected_chatroom_turn_counts),
):
    users_df[new_column] = users_df[key_column].map(lookup)

# Users whose chatroom has no corrected value get 0; the flag is stored as an int.
chatroom_limit_reached = users_df["chatroom_id"].map(corrected_limit_reached)
users_df["limit_reached"] = chatroom_limit_reached.fillna(0).astype(int)


## Add rephrasing counts to users

These are derived from chatroom numbers. Note that these columns are NA where not applicable; this data doesn't exist for users in untreated control conversations.

1. **Rephrasing counts if the user was offered rephrasings:**

| `self_rephrasing_original_count`                                              | `self_rephrasing_validate_count`           | `self_rephrasing_restate_count`           | `self_rephrasing_polite_count`           | `self_rephrasing_accepted_count` | `self_rephrasing_total_count`                     |
| ----------------------------------------------------------------------------- | ------------------------------------------ | ----------------------------------------- | ---------------------------------------- | -------------------------------- | ------------------------------------------------- |
| rephrasing opportunities the user ignored,<br/>sending their original message | rephrasings sent using 'validate' strategy | rephrasings sent using 'restate' strategy | rephrasings sent using 'polite' strategy | sum of all accepted rephrasings  | sum of all offered rephrasings, including ignored |

2. **Then rephrasing counts if the user's partner was offered rephrasing:**

| `partner_rephrasing_original_count`                  | `partner_rephrasing_validate_count` | `partner_rephrasing_restate_count` | `partner_rephrasing_polite_count` | `partner_rephrasing_accepted_count` | `partner_rephrasing_total_count` |
| ---------------------------------------------------- | ----------------------------------- | ---------------------------------- | --------------------------------- | ----------------------------------- | -------------------------------- |
| similar to `self_rephrasing_*_count` but for partner | ...                                 | ...                                | ...                               | ...                                 | ...                              |


In [None]:
def create_rephrasing_count_column(chatroom_column_name, treatment):
    """Build a per-user column from a per-chatroom column.

    Args:
        chatroom_column_name: Name of the chatrooms_df column to copy.
        treatment: Only users with this treatment value receive the chatroom's
            value.

    Returns:
        A Series aligned with users_df: the chatroom's value for users whose
        treatment matches, NaN for everyone else. A matching user whose
        chatroom_id is not present in chatrooms_df also gets NaN.
    """
    # One value per chatroom ID, looked up in a single vectorised map instead of
    # filtering chatrooms_df once per user row. The first row wins if an ID repeats.
    value_by_chatroom_id = chatrooms_df.drop_duplicates(subset="id").set_index("id")[
        chatroom_column_name
    ]
    chatroom_values = users_df["chatroom_id"].map(value_by_chatroom_id)
    # Blank out the value for users in any other treatment group
    return chatroom_values.where(users_df["treatment"] == treatment)


# Create self_ columns for TREATED users

# TREATED users get the chatroom's counts as self_* columns; UNTREATED users get the
# same chatroom counts as partner_* columns. Everyone else is left as NaN.
for column_prefix, treatment in (("self", "TREATED"), ("partner", "UNTREATED")):
    for count_name in (
        "original",
        "validate",
        "restate",
        "polite",
        "accepted",
        "total",
    ):
        users_df[
            f"{column_prefix}_rephrasing_{count_name}_count"
        ] = create_rephrasing_count_column(f"rephrasing_{count_name}_count", treatment)


## Tests


## Turn counting


In [None]:
# Each case is (treatment, position, pattern, expected). `position` is either the
# starting position (also used as the considered position) or a
# (starting position, considered position) pair. `expected` is
# (considered position's turns, other position's turns, limit reached).
LIMIT_REACHED_TURN_TEST_CASES = [
    ("CONTROL", position, pattern, expected)
    # The second starter/partner pair is just the swapped version of the first—it
    # really shouldn't matter
    for starter, partner in (("SUPPORT", "OPPOSE"), ("OPPOSE", "SUPPORT"))
    for pattern, starter_turns, partner_turns, limit_reached in (
        ("l l l l l l l l l l l l l", 7, 6, False),
        ("l l l l l l l l l l l l l l l", 8, 7, True),
    )
    for position, expected in (
        (starter, (starter_turns, partner_turns, limit_reached)),
        ((starter, partner), (partner_turns, starter_turns, limit_reached)),
    )
]


In [None]:
# Each case is (treatment, starting position, pattern, expected turn counts). In a
# pattern every space-separated group is one sender's run of messages, "s" being a
# short message and anything else a long one.
# The same patterns are repeated for every treatment/position combination—it
# shouldn't matter but that is what we're testing.
MESSAGE_LENGTH_TURN_TEST_CASES = [
    (treatment, position, pattern, expected)
    for position in ("SUPPORT", "OPPOSE")
    for treatment in ("TREATED", "UNTREATED")
    for pattern, expected in (
        ("s s s s s s s s s", (0, 0)),
        ("l l l l l l l l l", (5, 4)),
        ("l s l s l s l s l", (1, 0)),
        ("s l s l s l s l s", (0, 1)),
        ("l s l s s l l l s", (2, 2)),
        ("l l s s l s s l l", (3, 2)),
        ("l l l l s l l l l", (4, 3)),
        ("l l s l l l s l l", (3, 2)),
        ("l l s l s l s l l", (2, 1)),
        ("l sss l s l l l s", (2, 1)),
        ("l lll lll l l s", (3, 2)),
        ("l lsl ssl l s l", (2, 2)),
    )
]


In [None]:
import random


def construct_convo(spec):
    """Build a synthetic conversation from a test-case spec.

    Args:
        spec: A 4-tuple (treatment, position, pattern, turns). ``treatment`` and
            ``position`` describe the sender of the first run of messages;
            ``pattern`` holds space-separated runs in which each "s" is a short
            message and any other character a long one; the last element is
            ignored.

    Returns:
        A list of message dicts with "body", "position" and "treatment" keys. After
        every run the position flips between "SUPPORT" and "OPPOSE" and the
        treatment flips between "TREATED" and "UNTREATED" (any other treatment is
        left as is).
    """
    treatment, position, pattern, _ = spec

    long_bodies = [
        "This is a long message",
        "This is a longer message",
        "This is the longest message",
    ]
    short_bodies = ["short", "shorter", "shortest"]

    # Who speaks next, and under which treatment, once a run ends
    following_position = {"SUPPORT": "OPPOSE"}
    following_treatment = {"TREATED": "UNTREATED", "UNTREATED": "TREATED"}

    convo = []
    for run in pattern.split():
        for length_code in run:
            candidates = short_bodies if length_code == "s" else long_bodies
            convo.append(
                {
                    "body": random.choice(candidates),
                    "position": position,
                    "treatment": treatment,
                }
            )
        # Anything other than "SUPPORT" is followed by "SUPPORT"
        position = following_position.get(position, "SUPPORT")
        treatment = following_treatment.get(treatment, treatment)
    return convo


In [None]:
def test_turn_case(case):
    """Check calculate_turns_improved against one test case.

    Args:
        case: (treatment, position, pattern, expected). ``position`` is either a
            single position, used both to start the conversation and as the
            considered position, or a (first position, considered position) tuple.
            ``expected`` is (first turns, second turns) with an optional third
            element giving the expected limit_reached value.

    Raises:
        AssertionError: If a turn count (or, when given, limit_reached) differs
            from the expectation.
    """
    treatment, position, pattern, expected = case
    if len(expected) == 3:
        *expected_turns, expected_limit_reached = expected
    else:
        # No limit_reached expectation supplied, so that check is skipped
        expected_turns, expected_limit_reached = expected, None
    expected_first_turns, expected_second_turns = expected_turns

    first_position, considered_position = (
        position if isinstance(position, tuple) else (position, position)
    )

    _, first_turns, second_turns, limit_reached = calculate_turns_improved(
        construct_convo((treatment, first_position, pattern, None)),
        considered_position,
        treatment,
    )
    assert first_turns == expected_first_turns
    assert second_turns == expected_second_turns
    assert expected_limit_reached is None or limit_reached == expected_limit_reached


In [None]:
# Run on message length test cases
# Run every message-length case; test_turn_case raises AssertionError on the first
# case whose turn counts do not match, which stops the loop.
for case in MESSAGE_LENGTH_TURN_TEST_CASES:
    test_turn_case(case)


In [None]:
# Run on limit reached test cases
# Run every limit-reached case; these also carry an expected limit_reached value, so
# test_turn_case checks it in addition to the turn counts.
for case in LIMIT_REACHED_TURN_TEST_CASES:
    test_turn_case(case)


## `corrected_limit_reached` has correct number of turns


In [None]:
def test_limit_reached_turn_count(row):
    """Assert that a user flagged as limit_reached has enough turns.

    Args:
        row: A users_df row providing "turn_count", "treatment" and
            "limit_reached".

    Raises:
        AssertionError: If limit_reached is truthy while turn_count is below 8
            ("TREATED"/"UNTREATED") or below 7 ("CONTROL"). Rows with any other
            treatment are not checked.
    """
    turn_count = row["turn_count"]
    limit_reached = bool(row["limit_reached"])
    minimum_turns = {"TREATED": 8, "UNTREATED": 8, "CONTROL": 7}.get(row["treatment"])
    if minimum_turns is not None and limit_reached:
        assert turn_count >= minimum_turns


# Check every user row
users_df.apply(test_limit_reached_turn_count, axis=1)


In [None]:
# Number of messages that have an accepted rephrasing recorded
messages_df["accepted_rephrasing_id"].notna().sum()


## View distributions of generated columns


In [None]:
# limit_reached is stored as an int; convert to bool before summarizing it
users_df["limit_reached"].apply(bool).describe()


In [None]:
# Summary statistics of the corrected per-user turn count
users_df["turn_count"].describe()


In [None]:
# Summary statistics of the self_* rephrasing counts, restricted to TREATED users
self_rephrasing_columns = [
    "self_rephrasing_original_count",
    "self_rephrasing_validate_count",
    "self_rephrasing_restate_count",
    "self_rephrasing_polite_count",
    "self_rephrasing_accepted_count",
    "self_rephrasing_total_count",
]
users_df.loc[users_df["treatment"] == "TREATED", self_rephrasing_columns].describe()


In [None]:
# Summary statistics of the partner_* rephrasing counts, restricted to UNTREATED users
partner_rephrasing_columns = [
    "partner_rephrasing_original_count",
    "partner_rephrasing_validate_count",
    "partner_rephrasing_restate_count",
    "partner_rephrasing_polite_count",
    "partner_rephrasing_accepted_count",
    "partner_rephrasing_total_count",
]
users_df.loc[
    users_df["treatment"] == "UNTREATED", partner_rephrasing_columns
].describe()


## Export CSVs


In [None]:
# First create output_dir if it doesn't exist
output_dir.mkdir(parents=True, exist_ok=True)

# Write every table to output_dir under its original file name, without the index
for file_name, table in (
    ("chatrooms.csv", chatrooms_df),
    ("users.csv", users_df),
    ("messages.csv", messages_df),
    ("rephrasings.csv", rephrasings_df),
):
    table.to_csv(output_dir / file_name, index=False)
