# Redirection Demo: U.S. Supreme Court Oral Arguments

This notebook demonstrates the redirection framework introduced in our paper: **Taking a turn for the better: Conversation redirection throughout the course of mental-health therapy.** In the paper, we define redirection as the extent to which speakers shift the immediate focus of the conversation, and we apply our measure in the context of long-term messaging therapy. In this demo, we provide an initial exploration of how the framework can be applied in other domains, in particular to a publicly available dataset of U.S. Supreme Court oral arguments (Danescu-Niculescu-Mizil et al., 2012; Chang et al., 2020). Although court proceedings differ from therapy in their topics, goals, and interaction styles, their relatively unstructured and dynamic nature enables an initial exploration of how such discussions are redirected.

In this setting, we focus on the interactions between justices and lawyers. The power dynamics between these distinct roles reflect the asymmetric relationship between therapists and patients in mental-health domains, where one party generally holds more influence over the direction of the conversation.

We first install and import all the necessary packages from ConvoKit, including our wrapper models and config files.

In [None]:
!pip install git+https://github.com/vianxnguyen/ConvoKit.git
# !pip install -q convokit

In [None]:
from convokit import Corpus, download
from convokit.redirection.likelihoodModel import LikelihoodModel
from convokit.redirection.gemmaLikelihoodModel import GemmaLikelihoodModel
from convokit.redirection.redirection import Redirection
from convokit.redirection.config import DEFAULT_BNB_CONFIG, DEFAULT_LORA_CONFIG, DEFAULT_TRAIN_CONFIG
import random
from sklearn.model_selection import train_test_split
import numpy as np
from scipy.stats import wilcoxon

We then download the `supreme-court` corpus we will be using for training and analysis. If you already have the corpus saved locally, you can specify the path to load the corpus from.

In [None]:
# If you already have the corpus saved locally, load the corpus from the saved path.
# DATA_DIR = '/Users/vian/.convokit/downloads/supreme-corpus'
# corpus = Corpus(DATA_DIR)

# Otherwise download the corpus
corpus = Corpus(filename=download('supreme-corpus'))
corpus.print_summary_stats()

For the purposes of the demo, we randomly sample a subset of 50 conversations (~20k utterances) for our analysis. Since this demonstration focuses on interactions between two distinct roles, justices and lawyers, we label each utterance with its speaker's role (either justice or lawyer).

In [None]:
random.seed(10)  # fix the seed so that the same 50 conversations are sampled on every run
convos = list(corpus.iter_conversations())
sample_convos = random.sample(convos, 50)
print(len(sample_convos))


In [5]:
for convo in sample_convos:
  for utt in convo.iter_utterances():
    if utt.speaker.id.startswith("j_"):
      utt.meta["role"] = "justice"
    else:
      utt.meta["role"] = "lawyer"
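Before moving on, it can help to sanity-check the role labels. The sketch below applies the same `startswith("j_")` rule to a few plain speaker-id strings; the ids are hypothetical examples chosen for illustration, not ids taken from the sample.

```python
from collections import Counter


def label_role(speaker_id: str) -> str:
    """Same rule as the cell above: speaker ids starting with "j_" are justices."""
    return "justice" if speaker_id.startswith("j_") else "lawyer"


# Hypothetical speaker ids, used only to illustrate the rule ("jane_doe" starts with "j" but not "j_").
speaker_ids = ["j__john_g_roberts_jr", "jane_doe", "j__ruth_bader_ginsburg", "john_smith"]
role_counts = Counter(label_role(s) for s in speaker_ids)
print(role_counts)  # Counter({'justice': 2, 'lawyer': 2})
```

On the actual sample, the equivalent check is `Counter(utt.meta["role"] for convo in sample_convos for utt in convo.iter_utterances())`; a lopsided count would suggest that the speaker-id convention differs from what the labeling rule assumes.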

We use an 80/10/10 train/val/test split and then label each conversation with its corresponding split.

In [None]:
train_convos, temp_convos = train_test_split(sample_convos, test_size=0.2, random_state=10)
val_convos, test_convos = train_test_split(temp_convos, test_size=0.5, random_state=10)
print(len(train_convos), len(val_convos), len(test_convos))

for convo in train_convos:
  convo.meta["train"] = True
for convo in val_convos: 
  convo.meta["val"] = True 
for convo in test_convos:
  convo.meta["test"] = True 
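To see where the 80/10/10 proportions come from, the sketch below reproduces the split sizes of the two chained `train_test_split` calls without scikit-learn. It assumes scikit-learn's rounding rule for a float `test_size` (the held-out part receives `ceil(test_size * n)` items and the remainder stays in the first part); `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(n: int, test_size: float = 0.2, second_test_size: float = 0.5):
    """Sizes produced by splitting n items, then splitting the held-out part again."""
    n_holdout = math.ceil(test_size * n)          # first call: held-out (val + test) part
    n_train = n - n_holdout
    n_test = math.ceil(second_test_size * n_holdout)  # second call: test part of the held-out items
    n_val = n_holdout - n_test
    return n_train, n_val, n_test


print(split_sizes(50))  # (40, 5, 5): an 80/10/10 split of the 50 sampled conversations
```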

Now, we define our likelihood model, which computes utterance likelihoods given the provided context. These likelihoods are later used to compute a redirection score for each utterance. Here, we define a likelihood model based on Gemma-2B, called `GemmaLikelihoodModel`, which inherits from a default `LikelihoodModel` interface. Different models (Gemma, Llama, Mistral, etc.) can be supported by inheriting from this base interface.
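To make the inheritance pattern concrete, here is a toy sketch: an abstract base class with a single `log_likelihood(context, utterance)` method and a deliberately simple add-one-smoothed unigram subclass. Both classes and the method name are hypothetical stand-ins; ConvoKit's actual `LikelihoodModel` interface defines its own methods, which a new backend would implement instead.

```python
import math
from abc import ABC, abstractmethod
from collections import Counter


class ToyLikelihoodInterface(ABC):
    """Hypothetical stand-in for a likelihood-model interface (not ConvoKit's API)."""

    @abstractmethod
    def log_likelihood(self, context: str, utterance: str) -> float:
        """Return log P(utterance | context) under the model."""


class ToyUnigramModel(ToyLikelihoodInterface):
    """Scores an utterance with add-one-smoothed unigram counts taken from its context."""

    def log_likelihood(self, context: str, utterance: str) -> float:
        context_tokens = context.lower().split()
        utterance_tokens = utterance.lower().split()
        counts = Counter(context_tokens)
        vocab = set(context_tokens) | set(utterance_tokens)
        denom = len(context_tokens) + len(vocab)  # add-one smoothing over the joint vocabulary
        return sum(math.log((counts[tok] + 1) / denom) for tok in utterance_tokens)


model = ToyUnigramModel()
ctx = "the statute covers this case"
# The first utterance reuses the context's words, so it receives the higher score.
print(model.log_likelihood(ctx, "the statute"), model.log_likelihood(ctx, "the weather"))
```

The unigram scorer only exists to make the pattern runnable; in practice the subclass would wrap a causal language model such as Gemma-2B.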

Because this demo accesses Gemma-2B through Hugging Face, we need to provide an authentication token for access to the model.

In [None]:
gemma_likelihood_model = \
    GemmaLikelihoodModel(
        hf_token = "TODO: ADD HUGGINGFACE AUTH TOKEN",
        model_id = "google/gemma-2b", 
        train_config = DEFAULT_TRAIN_CONFIG,
        bnb_config = DEFAULT_BNB_CONFIG,
        lora_config = DEFAULT_LORA_CONFIG,
    )
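Rather than pasting the token into the notebook, one option is to read it from an environment variable. The sketch below assumes the token is stored in `HF_TOKEN`, the variable name the `huggingface_hub` library also reads; `get_hf_token` is a hypothetical helper. Gemma checkpoints on Hugging Face are gated, so the account behind the token typically needs to have accepted the model's license first.

```python
import os


def get_hf_token(env_var: str = "HF_TOKEN") -> str:
    """Read a Hugging Face access token from the environment instead of hard-coding it."""
    token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable to your Hugging Face access token.")
    return token
```

With the variable set, the model above can be created with `hf_token=get_hf_token()` in place of the placeholder string.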

We use the following default configs and parameters for fine-tuning. However, you may override these by defining your own configs and passing them to the `GemmaLikelihoodModel`.

In [None]:
"""
DEFAULT_BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

DEFAULT_LORA_CONFIG = LoraConfig(
    r=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

DEFAULT_TRAIN_CONFIG = {
    "output_dir": "checkpoints",
    "logging_dir": "logging",
    "logging_steps": 25,
    "eval_steps": 50, 
    "num_train_epochs": 2, 
    "per_device_train_batch_size": 1,  
    "per_device_eval_batch_size": 1,   
    "evaluation_strategy": "steps",
    "save_strategy": "steps",
    "save_steps": 50,
    "optim": "paged_adamw_8bit",
    "learning_rate": 2e-4,
    "max_seq_length": 512,
    "load_best_model_at_end": True,
}
"""

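A simple way to override only a few training settings is to build a new dict from the defaults. The sketch uses a trimmed, local copy of some `DEFAULT_TRAIN_CONFIG` entries (named `example_train_defaults` so that it does not shadow the imported config); the override values are illustrative, not recommendations.

```python
# Trimmed local copy of a few DEFAULT_TRAIN_CONFIG entries so this sketch runs on its own.
example_train_defaults = {
    "num_train_epochs": 2,
    "learning_rate": 2e-4,
    "max_seq_length": 512,
    "per_device_train_batch_size": 1,
}

# Dict unpacking builds a new dict: overridden keys change, the defaults stay untouched.
custom_train_config = {**example_train_defaults, "num_train_epochs": 1, "learning_rate": 1e-4}
print(custom_train_config)
```

In the notebook, the same pattern would start from the imported config, e.g. `train_config={**DEFAULT_TRAIN_CONFIG, "num_train_epochs": 1}` when constructing `GemmaLikelihoodModel`.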
Now we can define our redirection model, providing the initialized `gemma_likelihood_model` as its `LikelihoodModel`. The `redirection_attribute_name` is the name of the metadata field in which the redirection scores are saved in the corpus.

It is also possible to define your own `previous_context_selector` and `future_context_selector` to control which contexts are used to compute the likelihoods. Each function takes an utterance as input and returns the previous (actual and reference) or future contexts for that utterance. By default, we use the immediate contexts described in our paper. Note that the default implementation of these contexts assumes two distinct speaker roles; to handle more than two speaker types, you can write your own context selectors.
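The exact selector signature and the definitions of the actual and reference contexts come from ConvoKit and the paper. The sketch below only illustrates, on plain `(role, text)` tuples, one simple way to locate the nearest earlier and later turns by the other role, which is the kind of lookup a custom selector might contain. The function names, the tuple representation, and the example turns are all hypothetical.

```python
from typing import List, Optional, Tuple

# Hypothetical representation of a conversation: an ordered list of (role, text) turns.
Turn = Tuple[str, str]


def previous_other_role(turns: List[Turn], i: int) -> Optional[Turn]:
    """Return the closest earlier turn whose role differs from turn i, or None."""
    role = turns[i][0]
    for j in range(i - 1, -1, -1):
        if turns[j][0] != role:
            return turns[j]
    return None


def next_other_role(turns: List[Turn], i: int) -> Optional[Turn]:
    """Return the closest later turn whose role differs from turn i, or None."""
    role = turns[i][0]
    for j in range(i + 1, len(turns)):
        if turns[j][0] != role:
            return turns[j]
    return None


# Invented turns, for illustration only.
turns = [
    ("justice", "Counsel, does the statute reach this conduct?"),
    ("lawyer", "It does, Your Honor."),
    ("lawyer", "The text is explicit on that point."),
    ("justice", "Then why did Congress amend it?"),
]
print(previous_other_role(turns, 2))  # the justice's opening question
print(next_other_role(turns, 1))      # the justice's follow-up question
```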

In [None]:
redirection = \
    Redirection(
        likelihood_model = gemma_likelihood_model,
        redirection_attribute_name = "redirection",
#         previous_context_selector = <YOUR OWN PREVIOUS CONTEXT SELECTOR>, 
#         future_context_selector = <YOUR OWN FUTURE CONTEXT SELECTOR>,
    )

Next, we call `fit` to fine-tune the model on a subset of the conversations in the corpus. The selector functions restrict fine-tuning to the `train` conversations and validation to the `val` conversations. Alternatively, if you have already saved a fine-tuned model, you can load it into memory using `load_from_disk`.

In [None]:
redirection.fit(corpus,
                train_selector=lambda convo: "train" in convo.meta,
                val_selector=lambda convo: "val" in convo.meta)

After we have our fine-tuned model, we can then run inference on the test conversations in order to compute the redirection scores. 

In [None]:
redirection.transform(corpus, selector=lambda convo: "test" in convo.meta)

We can then call `summarize` to view examples of high- and low-redirecting utterances from each speaker role.

In [None]:
redirection.summarize(corpus)

We can also perform a FightingWords analysis to see distinguishing bigrams indicating high vs. low redirection from both speakers.

In [10]:
from convokit import FightingWords
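FightingWords is based on the "Fightin' Words" method of Monroe et al. (2008), which ranks n-grams by a z-scored log-odds ratio with a Dirichlet prior. For intuition, here is a minimal sketch with a uniform prior `alpha` over pre-tokenized bigrams; ConvoKit's own prior, tokenization, and defaults may differ, and `log_odds_z_scores` is a hypothetical helper.

```python
import math
from collections import Counter
from typing import Dict, List


def log_odds_z_scores(class1_tokens: List[str], class2_tokens: List[str], alpha: float = 0.1) -> Dict[str, float]:
    """Log-odds ratio with a uniform Dirichlet prior, divided by its approximate standard deviation.

    Positive z-scores lean towards class 1, negative ones towards class 2.
    Assumes at least two distinct n-grams, so every denominator stays positive.
    """
    c1, c2 = Counter(class1_tokens), Counter(class2_tokens)
    vocab = set(c1) | set(c2)
    n1, n2 = sum(c1.values()), sum(c2.values())
    alpha0 = alpha * len(vocab)  # total prior mass
    scores = {}
    for w in vocab:
        y1, y2 = c1[w], c2[w]
        log_odds1 = math.log((y1 + alpha) / (n1 + alpha0 - y1 - alpha))
        log_odds2 = math.log((y2 + alpha) / (n2 + alpha0 - y2 - alpha))
        variance = 1.0 / (y1 + alpha) + 1.0 / (y2 + alpha)
        scores[w] = (log_odds1 - log_odds2) / math.sqrt(variance)
    return scores


# Toy bigram lists, invented for illustration.
high = ["i think"] * 5 + ["your honor"]
low = ["your honor"] * 5 + ["i think"]
print(sorted(log_odds_z_scores(high, low).items(), key=lambda kv: kv[1]))
```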

We first label the top 20% and bottom 20% of scored utterances for each speaker role, based on their redirection scores.

In [None]:
justice_utts = []
lawyer_utts = []

for convo in test_convos: 
  for utt in convo.iter_utterances():
    if "redirection" in utt.meta:
      if utt.meta["role"] == "justice":
        justice_utts.append(utt)
      else:
        lawyer_utts.append(utt)

justice_utts = sorted(justice_utts, key=lambda utt: utt.meta["redirection"])
lawyer_utts = sorted(lawyer_utts, key=lambda utt: utt.meta["redirection"])

justice_threshold = int(len(justice_utts) * 0.20)
lawyer_threshold = int(len(lawyer_utts) * 0.20)

# Slice from len(...) - threshold rather than -threshold: [-0:] would select every utterance.
for utt in justice_utts[:justice_threshold]:
  utt.meta['type'] = "justice_low"
for utt in justice_utts[len(justice_utts) - justice_threshold:]:
  utt.meta['type'] = "justice_high"

for utt in lawyer_utts[:lawyer_threshold]:
  utt.meta['type'] = "lawyer_low"
for utt in lawyer_utts[len(lawyer_utts) - lawyer_threshold:]:
  utt.meta['type'] = "lawyer_high"
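Slicing with a negative threshold (`utts[-k:]`) returns the whole list when `k` is 0, which happens whenever fewer than five utterances from a role were scored. The sketch below shows the same top/bottom-20% labeling on a plain list of scores, starting the "high" slice at `len(order) - k` so that `k = 0` yields empty groups; `label_extremes` is a hypothetical helper.

```python
from typing import Dict, List


def label_extremes(scores: List[float], frac: float = 0.2) -> Dict[str, List[int]]:
    """Return the indices of the lowest and highest `frac` of scores.

    Slicing from len(order) - k (instead of -k) keeps the "high" group empty
    when k == 0, rather than selecting every item.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i])
    k = int(len(order) * frac)
    return {"low": order[:k], "high": order[len(order) - k:]}


print(label_extremes([0.5, 0.1, 0.9, 0.3, 0.7]))  # {'low': [1], 'high': [2]}
print(label_extremes([0.5, 0.1, 0.9, 0.3]))       # {'low': [], 'high': []}
```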

Here we first show phrasings indicative of low redirection from justices.

In [None]:
fw_justice = FightingWords(ngram_range=(2,2))
class1 = 'justice_high'
class2 = 'justice_low'
fw_justice.fit(corpus, class1_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class1, 
               class2_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class2)
justice = fw_justice.summarize(corpus, plot=False, class1_name=class1, class2_name=class2)
justice.head(20)

Here we show phrasings indicative of high redirection from justices.

In [None]:
justice.tail(20)[::-1]

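We can perform the corresponding analysis for lawyers. As for justices, the first table lists bigrams indicative of low redirection and the second lists those indicative of high redirection.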

In [None]:
fw_lawyer = FightingWords(ngram_range=(2,2))
class1 = 'lawyer_high'
class2 = 'lawyer_low'
fw_lawyer.fit(corpus, class1_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class1, 
               class2_func=lambda utt: 'type' in utt.meta and utt.meta['type'] == class2)
lawyer = fw_lawyer.summarize(corpus, plot=False, class1_name=class1, class2_name=class2)
lawyer.head(20)

In [None]:
lawyer.tail(20)[::-1]

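We can also compare the average redirection of justices and lawyers across the test conversations. Because each conversation contributes one average per role, we use a paired Wilcoxon signed-rank test.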

In [None]:
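convo_justices = []
convo_lawyers = []
for convo in test_convos:
    # Collect the redirection scores (not the Utterance objects) for each role.
    # Distinct names also avoid overwriting the `justice` FightingWords table above.
    justice_scores = []
    lawyer_scores = []
    for utt in convo.iter_utterances():
        if "redirection" in utt.meta:
            if utt.meta["role"] == "justice":
                justice_scores.append(utt.meta["redirection"])
            else:
                lawyer_scores.append(utt.meta["redirection"])
    # Keep only conversations where both roles were scored, so the samples stay paired.
    if justice_scores and lawyer_scores:
        convo_justices.append(np.mean(justice_scores))
        convo_lawyers.append(np.mean(lawyer_scores))

print("Average justice:", np.mean(convo_justices))
print("Average lawyer:", np.mean(convo_lawyers))
# Paired test: each conversation contributes one justice mean and one lawyer mean.
stat, p_value = wilcoxon(convo_justices, convo_lawyers)
print(f"Statistic: {stat}, P-value: {p_value}")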
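The same per-conversation aggregation can be sketched without ConvoKit objects. Each conversation is represented here as a hypothetical list of `(role, score)` pairs with invented scores, and `paired_role_means` (also hypothetical) keeps only conversations in which both roles have scores, so that the two returned lists stay aligned as paired samples for `scipy.stats.wilcoxon`.

```python
from statistics import mean
from typing import List, Tuple


def paired_role_means(convos: List[List[Tuple[str, float]]]) -> Tuple[List[float], List[float]]:
    """Per-conversation mean score for each role, keeping only conversations where both roles appear."""
    justice_means, lawyer_means = [], []
    for utts in convos:
        justice_scores = [score for role, score in utts if role == "justice"]
        lawyer_scores = [score for role, score in utts if role == "lawyer"]
        if justice_scores and lawyer_scores:
            justice_means.append(mean(justice_scores))
            lawyer_means.append(mean(lawyer_scores))
    return justice_means, lawyer_means


example = [
    [("justice", 1.0), ("lawyer", 0.0), ("justice", 3.0)],
    [("lawyer", 2.0)],  # no justice scores: dropped to keep the pairs aligned
    [("justice", 0.5), ("lawyer", 1.5)],
]
print(paired_role_means(example))  # ([2.0, 0.5], [0.0, 1.5])
```

With only five test conversations (the 10% test split of 50 sampled conversations), even a difference in the same direction in every conversation gives an exact two-sided Wilcoxon p-value no lower than 2/32 = 0.0625, so results from a sample of this size are best read as exploratory.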