# Pivotal Moments Demo
This notebook demonstrates the pivotal-moments framework introduced in our paper *Hanging in the Balance: Pivotal Moments in Crisis Counseling Conversations*. In the paper, we define *pivotal* moments as moments where what is said next matters for the eventual outcome of the conversation.

Here, we demo our framework on online conversations in the CGA-CMV (Conversations Gone Awry-Change My View) setting, consisting of conversations that may derail into personal attacks (Chang and Danescu-Niculescu-Mizil, 2019). We provide an initial exploration into identifying pivotal moments with respect to the outcome of conversation derailment. Furthermore, we release the demo to encourage future work and applications in other domains.

We first import all the necessary packages and modules that will be used in this demo.

In [None]:
# !pip install convokit
# %load_ext autoreload
# %autoreload 2

import os
# Expose only GPU 1 to this process. This must be set before any CUDA library
# is loaded; change the index to match your machine.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import random
from functools import partial

from convokit import Corpus, download
from convokit.pivotal_framework.pivotal import PIV
from convokit.pivotal_framework.simulator.unslothUtteranceSimulatorModel import UnslothUtteranceSimulatorModel
from convokit.forecaster.BERTCGAModel import BERTCGAModel

We then download the `conversations-gone-awry-cmv-corpus` corpus used throughout the demo. If you already have the corpus saved locally, you can load it from that path instead.

In [None]:
corpus = Corpus(filename=download("conversations-gone-awry-cmv-corpus"))
# If you have the corpus saved locally, load it as follows: 
# corpus = Corpus("<PATH_TO_CORPUS>")
corpus.print_summary_stats()

The `conversations-gone-awry-cmv-corpus` corpus consists of Reddit conversations from the ChangeMyView subreddit, some of which derail into personal attacks while others remain calm.

The conversations in the corpus are paired based on the length (number of utterances) of the conversation, where each pair consists of one *derailed* conversation and one *calm* conversation (indicated by the `has_removed_comment` metadata field). In our demo, we will select conversations to train, validate, and test our framework. To maintain this pairing in the data selection, we first create a set of conversation ids, consisting of *one* conversation id from each pair, which we will use to sample from.
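The selection trick used in the next cell can be checked on toy data. This is a minimal, standalone sketch: `pair_of` is a hypothetical stand-in for the `pair_id` metadata, mapping each conversation id to its partner's id (pairs are assumed symmetric). Like the notebook's loop, it records a conversation's *partner* id unless the conversation itself has already been recorded, so each pair ends up represented exactly once; `expand_pairs` (also hypothetical) then recovers both conversations of every pair.

```python
def one_id_per_pair(pair_of):
    """Return a sorted list holding exactly one conversation id from each pair.

    pair_of: hypothetical dict mapping a conversation id to its partner's id,
    standing in for convo.meta['pair_id']; pairs are assumed to be symmetric.
    """
    chosen = set()
    for convo_id, partner_id in pair_of.items():
        # Already recorded as someone's partner: this pair is covered.
        if convo_id in chosen:
            continue
        chosen.add(partner_id)
    return sorted(chosen)


def expand_pairs(ids, pair_of):
    """Map each chosen id back to both conversations of its pair."""
    return [cid for i in ids for cid in (i, pair_of[i])]


# Toy corpus: a<->b and c<->d are pairs.
pair_of = {"a": "b", "b": "a", "c": "d", "d": "c"}
chosen = one_id_per_pair(pair_of)
print(chosen)                         # ['b', 'd']
print(expand_pairs(chosen, pair_of))  # ['b', 'a', 'd', 'c']
```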

In [None]:
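# Collect one conversation id per pair. When we meet a conversation that has
# not been recorded yet, we record its partner's id; when we later meet that
# partner, its id is already in the set, so the pair is skipped.
pair_ids = set()
for convo in corpus.iter_conversations():
    pair_id = convo.meta['pair_id']  # id of the paired conversation
    if convo.id in pair_ids:
        continue
    pair_ids.add(pair_id)

# Sort so that seeded sampling is reproducible (set iteration order can vary between runs).
pair_ids = sorted(pair_ids)
print(len(pair_ids))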

Our pivotal-moments framework consists of two main components: (1) *simulator model* to simulate potential next responses and (2) *forecaster model* to predict the likelihood of the outcome based on these potential responses. 
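To make the interplay of the two components concrete, here is a toy sketch. It assumes a hypothetical `forecast` function that returns the probability of derailment for a conversation extended with one candidate next reply, and it scores a moment by how much that probability varies across the simulated replies. The spread measure (max minus min) is an illustrative assumption, not necessarily the exact definition used by the `PIV` transformer or the paper.

```python
from typing import Callable, Sequence


def spread_score(context: Sequence[str],
                 simulated_replies: Sequence[str],
                 forecast: Callable[[Sequence[str]], float]) -> float:
    """Toy pivotal-moment score for one point in a conversation.

    For each simulated next reply, forecast the outcome probability of the
    context extended with that reply, then return the range (max - min) of
    those probabilities. ASSUMPTION: the range is an illustrative choice;
    the PIV transformer may aggregate the forecasts differently.
    """
    probs = [forecast(list(context) + [reply]) for reply in simulated_replies]
    return max(probs) - min(probs)


def toy_forecast(convo: Sequence[str]) -> float:
    """Hypothetical forecaster: a reply containing an insult makes derailment likely."""
    return 0.9 if "idiot" in convo[-1].lower() else 0.2


context = ["I think the policy is wrong.", "Why do you say that?"]
calm_replies = ["Mostly because of the cost.", "I read a study on it."]
mixed_replies = ["Mostly because of the cost.", "Only an idiot would ask."]

print(round(spread_score(context, calm_replies, toy_forecast), 3))   # 0.0
print(round(spread_score(context, mixed_replies, toy_forecast), 3))  # 0.7
```

Under this toy measure, a moment where every plausible reply leads to the same forecast scores 0, while a moment where the choice of reply swings the forecast scores high.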

To train, validate, and test these components, we sample pairs of conversations from the corpus, where each pair consists of one *derailed* conversation and one *calm* conversation. Therefore, by selecting `x` pairs, we are selecting `x * 2` conversations to be included in a given set.

Here, we sample 500 pairs for the forecaster, split 80/10/10 into train/val/test, and 500 further pairs, disjoint from the first 500, for fine-tuning the simulator, split 90/10 into train/val.

Finally, we sample 10 more pairs, disjoint from both, as our analysis set, on which we will compute pivotal scores.

The framework also supports using pre-existing trained models for simulation and forecasting; if you take that route, you can skip the training and validation splits.

In [None]:
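random.seed(42)  # fix the seed so that the sampled splits are reproducible

# Forecaster: 500 pairs, split 80/10/10 into train/val/test.
# random.sample already returns the ids in random order, so no extra shuffle is needed.
forecast_pair_ids = random.sample(pair_ids, 500)
size = len(forecast_pair_ids)
forecast_pair_train_ids = forecast_pair_ids[:int(0.8*size)]
forecast_pair_val_ids = forecast_pair_ids[int(0.8*size):int(0.9*size)]
forecast_pair_test_ids = forecast_pair_ids[int(0.9*size):]

# Simulator: 500 further pairs, disjoint from the forecaster pairs, split 90/10 into train/val.
used = set(forecast_pair_ids)  # set for fast membership tests
pair_ids_filt = [pair_id for pair_id in pair_ids if pair_id not in used]
sim_pair_ids = random.sample(pair_ids_filt, 500)
size = len(sim_pair_ids)
sim_pair_train_ids = sim_pair_ids[:int(0.9*size)]
sim_pair_val_ids = sim_pair_ids[int(0.9*size):]

# Analysis: 10 further pairs, disjoint from both the forecaster and simulator pairs.
used.update(sim_pair_ids)
pair_ids_filt = [pair_id for pair_id in pair_ids if pair_id not in used]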
analysis_pair_ids = random.sample(pair_ids_filt, 10)
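The resulting set sizes follow from the slicing arithmetic above. The sketch below uses a standalone helper, `split_sizes`, introduced here only for illustration: it reproduces the same `int(fraction * size)` cut points and doubles each count, since every sampled id stands for a pair of conversations.

```python
def split_sizes(n_pairs, cut_fractions):
    """Sizes of consecutive slices cut at int(f * n_pairs) for each f in cut_fractions.

    Returns (pairs_per_split, conversations_per_split); each pair contributes
    two conversations (one derailed, one calm).
    """
    cuts = [0] + [int(f * n_pairs) for f in cut_fractions] + [n_pairs]
    pairs = [end - start for start, end in zip(cuts, cuts[1:])]
    return pairs, [2 * p for p in pairs]


print(split_sizes(500, [0.8, 0.9]))  # forecaster: ([400, 50, 50], [800, 100, 100])
print(split_sizes(500, [0.9]))       # simulator: ([450, 50], [900, 100])
```

The 10 analysis pairs likewise yield 20 conversations.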


Each sampled id represents one conversation of a pair, so we expand every set to contain both conversations of each pair: the sampled conversation and its partner given by `pair_id`. We then print the number of conversations in each set.

In [None]:
def get_paired_sample(sample):
  """Return the conversations for the given ids, each followed by its paired conversation."""
  result = []
  for convo_id in sample:
    convo = corpus.get_conversation(convo_id)
    result.append(convo)
    paired_convo_id = convo.meta['pair_id']
    result.append(corpus.get_conversation(paired_convo_id))
  return result

forecast_train = get_paired_sample(forecast_pair_train_ids)
forecast_val = get_paired_sample(forecast_pair_val_ids)
forecast_test = get_paired_sample(forecast_pair_test_ids)

sim_train = get_paired_sample(sim_pair_train_ids)
sim_val = get_paired_sample(sim_pair_val_ids)

analysis = get_paired_sample(analysis_pair_ids)

print("Forecaster (train, val, test)")
print(len(forecast_train), len(forecast_val), len(forecast_test))

print("Simulator (train, val)")
print(len(sim_train), len(sim_val))

print("Analysis")
print(len(analysis))
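Because every set is built from whole pairs drawn without overlap, no conversation should appear in more than one set. A minimal standalone check of that property is sketched below; `find_overlaps` is a hypothetical helper that takes plain lists of conversation ids.

```python
from itertools import combinations


def find_overlaps(splits):
    """Return {(name_a, name_b): shared_ids} for every pair of splits that share ids.

    splits: dict mapping a split name to a list of conversation ids.
    An empty result means the splits are pairwise disjoint.
    """
    overlaps = {}
    for (name_a, ids_a), (name_b, ids_b) in combinations(splits.items(), 2):
        shared = set(ids_a) & set(ids_b)
        if shared:
            overlaps[(name_a, name_b)] = sorted(shared)
    return overlaps


toy = {"train": ["a", "b"], "val": ["c", "d"], "analysis": ["e", "f"]}
print(find_overlaps(toy))                          # {}
print(find_overlaps({**toy, "test": ["b", "g"]}))  # {('train', 'test'): ['b']}
```

In this notebook, the input could be built as `{"forecast_train": [c.id for c in forecast_train], "sim_train": [c.id for c in sim_train], ...}`, and the expected result is an empty dict.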

Next, we record each conversation's split in the conversation metadata field `data_split`. We first reset the field to `None` for every conversation in the corpus, so conversations outside the sampled sets belong to no split.

In [None]:
def label_split(convos, split):
  for convo in convos:
    convo.meta["data_split"] = split

label_split(corpus.iter_conversations(), None)

label_split(forecast_train, "forecast_train")
label_split(forecast_val, "forecast_val")
label_split(forecast_test, "forecast_test")

label_split(sim_train, "sim_train")
label_split(sim_val, "sim_val")

label_split(analysis, "analysis")

The function below creates a generic selector for a given split: the selector returns `True` if the context's conversation has that `data_split` value. We use such selectors to choose the contexts for each part of the framework.

In [None]:
def make_data_selector(split):
  return lambda context_tuple: context_tuple.current_utterance.get_conversation().meta.get("data_split") == split

Here, we define the selector functions used specifically to fit (train) the forecaster and the simulator models; their docstrings describe which contexts each one keeps.

In [None]:
def forecaster_fit_selector(context_tuple, split):
  """
  We use this generic function for both training and validation data.
  In both cases, its job is to select only those contexts for which the
  FUTURE context is empty. This follows how the CRAFT model was
  originally trained on CGA-CMV, taking the last context from each
  conversation ("last" meaning up to and including the chronologically
  last utterance as recorded in the corpus).
  """
  matches_split = (context_tuple.current_utterance.get_conversation().meta.get("data_split") == split)
  is_end = (len(context_tuple.future_context) == 0)
  return (matches_split and is_end)

def simulator_fit_selector(context_tuple, split):
  """
  We use this generic function for both training and validation data.
  In both cases, its job is to select only those contexts for which the
  FUTURE context is not empty, so we have a next utterance to predict.
  """
  matches_split = (context_tuple.current_utterance.get_conversation().meta.get("data_split") == split)
  is_end = (len(context_tuple.future_context) == 0)
  return (matches_split and not is_end)
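To see what the two selectors keep, consider a single three-utterance conversation. The sketch below assumes, as the `future_context` field suggests, that the framework builds one context per utterance (the utterances so far, plus those still to come); `contexts` is a hypothetical standalone helper, not a ConvoKit function.

```python
def contexts(utterances):
    """Yield (context, future) for each prefix of a conversation: the
    utterances seen so far and the utterances still to come."""
    for i in range(1, len(utterances) + 1):
        yield utterances[:i], utterances[i:]


convo = ["u1", "u2", "u3"]

# Forecaster fitting: only the context whose future is empty (the full conversation).
forecaster_keeps = [ctx for ctx, future in contexts(convo) if len(future) == 0]

# Simulator fitting: every context with a future, paired with the next utterance as target.
simulator_keeps = [(ctx, future[0]) for ctx, future in contexts(convo) if len(future) > 0]

print(forecaster_keeps)  # [['u1', 'u2', 'u3']]
print(simulator_keeps)   # [(['u1'], 'u2'), (['u1', 'u2'], 'u3')]
```

So a conversation with n utterances contributes one example to forecaster fitting and n - 1 examples to simulator fitting.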


Next, we initialize the framework's two components: a simulator of type `UtteranceSimulatorModel` and a forecaster of type `ForecasterModel`.

(1) For the `UtteranceSimulatorModel`, we use `UnslothUtteranceSimulatorModel`, a general wrapper built on the Unsloth framework. Here, we use the 4-bit quantized Llama 3.1 8B base model (`unsloth/Meta-Llama-3.1-8B-bnb-4bit`), but any model supported by Unsloth can be used. It is also possible to load an existing local model by specifying its saved path. Other models of type `UtteranceSimulatorModel` can be customized and used in the same way.

We can also optionally specify a custom `prompt_fn` that converts contexts into prompts for the model.

We use the following default configs, which can be modified.

In [None]:
"""
DEFAULT_NUM_SIMULATIONS = 10

DEFAULT_LLAMA_CHAT_TEMPLATE = "llama3"
DEFAULT_LLAMA_CHAT_TEMPLATE_MAPPING = {
    "role": "from",
    "content": "value",
    "user": "human",
    "assistant": "gpt",
}

DEFAULT_MODEL_CONFIG = {
    "load_in_4bit": True,
    "max_seq_length": 2048,
    "dtype": None,
    "target_modules": [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "embed_tokens",
        "lm_head",
    ],
    "r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0,
    "bias": "none",
    "use_gradient_checkpointing": "unsloth",
    "use_rslora": False,
    "loftq_config": None,
}

DEFAULT_TRAIN_CONFIG = {
    "per_device_train_batch_size": 16,
    "per_device_eval_batch_size": 16,
    "eval_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 30,
    "gradient_accumulation_steps": 4,
    "warmup_steps": 5,
    "num_train_epochs": 1,
    "eval_steps": 30,
    "learning_rate": 2e-4,
    "logging_steps": 5,
    "optim": "adamw_8bit",
    "weight_decay": 0.01,
    "lr_scheduler_type": "linear",
    "output_dir": "outputs",
    "logging_dir": "logs",
    "load_best_model_at_end": True,
}
"""

In [None]:
DEVICE = "cuda"
simulator_model = UnslothUtteranceSimulatorModel(
  model_name="unsloth/Meta-Llama-3.1-8B-bnb-4bit",
  device=DEVICE,
  num_simulations=10,
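  # Optional overrides. The DEFAULT_* names refer to the defaults listed above,
  # which this notebook only shows inside a string, so define them (and any
  # custom prompt_fn) before uncommenting these arguments.
  # model_config=DEFAULT_MODEL_CONFIG,
  # train_config=DEFAULT_TRAIN_CONFIG,
  # chat_template=DEFAULT_LLAMA_CHAT_TEMPLATE,
  # chat_template_mapping=DEFAULT_LLAMA_CHAT_TEMPLATE_MAPPING,
  # prompt_fn=default_prompt_fn,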
)
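The commented-out `prompt_fn` argument accepts a custom function that turns a context into a prompt. This notebook does not show the signature the framework expects, so the sketch below is hypothetical: `format_reply_prompt` takes a list of `(speaker, text)` turns and returns a plain-text prompt asking for the next reply. Check the framework's default prompt function before adapting it.

```python
def format_reply_prompt(turns, max_turns=10):
    """Hypothetical prompt builder for next-reply simulation.

    turns: list of (speaker, text) tuples in chronological order.
    Only the last `max_turns` turns are kept to bound the prompt length.
    """
    recent = turns[-max_turns:]
    lines = [
        "Below is a Reddit conversation from ChangeMyView.",
        "Write the next reply in the conversation.",
        "",
    ]
    lines += [f"{speaker}: {text}" for speaker, text in recent]
    lines.append("Next reply:")
    return "\n".join(lines)


print(format_reply_prompt([("A", "CMV: cities should ban cars."),
                           ("B", "What about deliveries?")]))
```

Truncating to the most recent turns is one simple way to stay within the model's `max_seq_length`; other strategies (such as always keeping the opening post) are equally plausible.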

(2) For the `ForecasterModel`, we use `BERTCGAModel`, a general wrapper for BERT-based forecasting models; in this demo, we use `roberta-large`. It is also possible to load an existing trained model by specifying its local path instead. Other forecasting models of type `ForecasterModel` can be used as well.

We use the following default config, which can be modified. Set `output_dir` to the directory where the trained forecaster should be saved.

In [None]:
model_name_or_path = 'roberta-large'
config_dict = {
    "output_dir": "YOUR_SAVING_DIRECTORY", 
    "per_device_batch_size": 4, 
    "num_train_epochs": 4, 
    "learning_rate": 6.7e-6,
    "random_seed": 1,
    "device": DEVICE
}
forecaster_model = BERTCGAModel(model_name_or_path, config=config_dict)
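For a rough sense of training cost with this config: `forecaster_fit_selector` keeps one context per conversation, so the 800 `forecast_train` conversations give about 800 training examples. The sketch below assumes one GPU, no gradient accumulation, and that a partial final batch still counts as a step; `BERTCGAModel`'s actual training loop may differ, and `training_steps` is an illustrative helper.

```python
import math


def training_steps(num_examples, batch_size, epochs):
    """Optimizer steps for plain mini-batch training (last partial batch included)."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


print(training_steps(800, 4, 4))  # (200, 800)
```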

We now initialize the `PIV` transformer, which computes pivotal scores using the `UtteranceSimulatorModel` and `ForecasterModel` defined above. We also specify the metadata fields in which to save the scores. Lastly, `labeler` names the conversation metadata field that holds the outcome; here, `has_removed_comment` indicates whether the conversation derailed.

In [None]:
piv_transformer = PIV(
  simulator_model=simulator_model,
  forecaster_model=forecaster_model,
  piv_attribute_name="PIV",
  simulated_reply_attribute_name="sim_replies",
  simulated_reply_forecast_attribute_name="sim_replies_forecasts",
  simulated_reply_forecast_prob_attribute_name="sim_replies_forecast_probs",
  forecast_attribute_name="forecast",
  forecast_prob_attribute_name="forecast_prob",
  labeler="has_removed_comment",
)

Next, we fit the transformer. The two models can be fit separately: `fit_forecaster` trains the forecaster and `fit_simulator` fine-tunes the simulator, each on the contexts chosen by the given selectors. Alternatively, we can call `fit` to run the whole pipeline.

If we use existing trained models, we can skip this step.

In [None]:
piv_transformer.fit_forecaster(
  corpus=corpus,
  train_context_selector=partial(forecaster_fit_selector, split="forecast_train"),
  val_context_selector=partial(forecaster_fit_selector, split="forecast_val"),
  test_context_selector=make_data_selector("forecast_test"),
)

In [None]:
piv_transformer.fit_simulator(
  corpus=corpus,
  train_context_selector=partial(simulator_fit_selector, split="sim_train"),
  val_context_selector=partial(simulator_fit_selector, split="sim_val"),
)

Now that the PIV transformer is fitted, we simply call `transform` to compute pivotal scores on the analysis set.

In [None]:
piv_transformer.transform(
  corpus=corpus,
  context_selector=make_data_selector("analysis"),
)

We can now inspect a randomly chosen analysis conversation, with each utterance's pivotal score shown in brackets.

In [None]:
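def print_random_convo(test_convos):
  """Print a random conversation's outcome label and each utterance with its pivotal score."""
  convo = random.choice(test_convos)
  print("has_removed_comment:", convo.meta["has_removed_comment"])
  print()
  # Print utterances in chronological order.
  for ut in convo.get_chronological_utterance_list():
    piv = ut.meta.get("PIV")
    # Guard against utterances that did not receive a score.
    score = round(piv, 5) if piv is not None else None
    print("[", score, "]", ut.text, "\n")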

In [None]:
print_random_convo(analysis)
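Beyond inspecting one conversation at a time, one may want the most pivotal moments across the whole analysis set. A minimal standalone sketch: `top_pivotal` is a hypothetical helper that ranks `(utterance_id, score)` pairs, which in this notebook could be collected with `[(ut.id, ut.meta.get("PIV")) for convo in analysis for ut in convo.iter_utterances()]`. Scores of `None` are skipped in case some utterances were not scored.

```python
import heapq


def top_pivotal(scored, k=3):
    """Return the k (utterance_id, score) pairs with the highest scores, skipping None."""
    valid = [(uid, s) for uid, s in scored if s is not None]
    return heapq.nlargest(k, valid, key=lambda pair: pair[1])


scored = [("u1", 0.05), ("u2", 0.41), ("u3", None), ("u4", 0.12), ("u5", 0.33)]
print(top_pivotal(scored, k=2))  # [('u2', 0.41), ('u5', 0.33)]
```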