# ConvoKit Forecaster framework: CRAFT demo

The `Forecaster` class provides a generic interface to *conversational forecasting models*, a class of models designed to computationally capture the trajectory of conversations in order to predict future events. Though individual conversational forecasting models can get quite complex, the `Forecaster` API abstracts the implementation details away behind a standard fit-transform interface. To demonstrate this, the notebook walks through an example of fine-tuning the CRAFT conversational forecasting model (Chang and Danescu-Niculescu-Mizil, 2019) on the CGA-CMV corpus. You will see how the `Forecaster` API lets us load the data, select training, validation, and testing samples, train the CRAFT model, and perform evaluation. This replicates the original paper's full pipeline (minus pre-training, which is considered outside the scope of ConvoKit) in only a few lines of code!

Let's start by importing the necessary ConvoKit classes and functions, and loading the CGA-CMV corpus.

In [1]:
from convokit import download, Corpus, Forecaster, CRAFTModel
from functools import partial

In [2]:
corpus = Corpus(filename=download("conversations-gone-awry-cmv-corpus"))

Dataset already exists at /home/jonathan/.convokit/downloads/conversations-gone-awry-cmv-corpus


## Define selectors for the Forecaster

Core to the flexibility of the `Forecaster` framework is the concept of *selectors*. 

To capture the temporal dimension of the conversational forecasting task, `Forecaster` iterates through each conversation in chronological utterance order, at each step presenting the backend forecasting model with a "context tuple" containing both the current utterance and its "context". As a general framework, `Forecaster` on its own makes no further assumptions about what "context" should contain or look like; it simply presents context as a chronologically ordered list of all utterances up to and including the current one.
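
The iteration described above can be pictured with a small standalone sketch. It does not use ConvoKit: `ToyContextTuple` and `iter_contexts` are hypothetical names introduced only for illustration, and plain strings stand in for utterances. The `future_context` field is included because the selectors later in this notebook read an attribute of that name.

```python
from collections import namedtuple

# Hypothetical stand-in for a context tuple; NOT ConvoKit's own class.
ToyContextTuple = namedtuple(
    "ToyContextTuple", ["context", "current_utterance", "future_context"]
)

def iter_contexts(utterances):
    """Yield one toy context tuple per utterance, in chronological order."""
    for i, utt in enumerate(utterances):
        yield ToyContextTuple(
            context=utterances[: i + 1],         # up to and including the current utterance
            current_utterance=utt,
            future_context=utterances[i + 1 :],  # everything that comes afterwards
        )

conversation = ["u1", "u2", "u3"]  # toy conversation, already in chronological order
contexts = list(iter_contexts(conversation))
for ct in contexts:
    print(ct.current_utterance, ct.context, ct.future_context)
```

Note that only the chronologically last tuple has an empty `future_context`, which is the property the CRAFT training selector checks for.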

In practice, though, we often want to be pickier about which contexts we use. At a basic level, we might want to select different contexts during training than during evaluation. The simplest case is splitting conversations into training and testing sets, but we might also want to select only certain contexts within a conversation. This is necessary for CRAFT training, which takes only the chronologically last context of each conversation (i.e., all utterances up to and not including the toxic comment, or up to the end of the conversation) as a labeled training instance. This is where selectors come in! A selector is a user-provided function that takes in a context tuple and returns a boolean indicating whether or not that context should be used. You can provide separate selectors for `fit` and `transform`, and `fit` also takes a second selector that you can use to define validation data.

Here we show how to implement the necessary selectors for CRAFT.

In [3]:
def generic_fit_selector(context_tuple, split):
    """
    We use this generic function for both training and validation data.
    In both cases, its job is to select only those contexts for which the
    FUTURE context is empty. This is in accordance with how CRAFT was
    originally trained on CGA-CMV, taking the last context from each
    conversation ("last" defined as being up to and including the chronologically
    last utterance as recorded in the corpus)
    """
    matches_split = (context_tuple.current_utterance.get_conversation().meta["split"] == split)
    is_end = (len(context_tuple.future_context) == 0)
    return (matches_split and is_end)

def transform_selector(context_tuple):
    """
    For transform we only need to check that the conversation is in the test split
    """
    return (context_tuple.current_utterance.get_conversation().meta["split"] == "test")
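
To see how `partial` turns the two-argument `generic_fit_selector` into the one-argument function that a selector needs to be, here is a minimal sketch that runs without ConvoKit. The helper `make_toy_context` is hypothetical: it builds stand-in objects exposing only the attributes the selector reads (`current_utterance.get_conversation().meta` and `future_context`). The selector body is repeated so the sketch runs on its own.

```python
from functools import partial
from types import SimpleNamespace

def make_toy_context(split, n_future):
    """Hypothetical stand-in exposing only what the selectors read."""
    conversation = SimpleNamespace(meta={"split": split})
    utterance = SimpleNamespace(get_conversation=lambda: conversation)
    return SimpleNamespace(
        current_utterance=utterance,
        future_context=[None] * n_future,  # placeholders for later utterances
    )

def generic_fit_selector(context_tuple, split):
    # Same logic as the selector defined above.
    matches_split = (
        context_tuple.current_utterance.get_conversation().meta["split"] == split
    )
    is_end = len(context_tuple.future_context) == 0
    return matches_split and is_end

# Binding `split` leaves a function of a single context tuple.
train_selector = partial(generic_fit_selector, split="train")

last_train = make_toy_context("train", n_future=0)  # last context, train split
mid_train = make_toy_context("train", n_future=2)   # mid-conversation context
last_val = make_toy_context("val", n_future=0)      # last context, wrong split

results = [train_selector(c) for c in (last_train, mid_train, last_val)]
print(results)  # [True, False, False]
```

Only the first toy context passes: the second is rejected because utterances still follow it, and the third because its conversation belongs to a different split.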

## Initialize the Forecaster and CRAFTModel backend

Now the rest of the process is pretty straightforward! We simply need to:
1. Initialize a backend `ForecasterModel` for the `Forecaster` to use; in this case, we choose ConvoKit's implementation of CRAFT.
2. Initialize a `Forecaster` instance to wrap that `ForecasterModel` in a generic fit-transform API.

In [4]:
craft = CRAFTModel("craft-cmv-pretrained", torch_device="cuda")

Downloading craft-cmv-pretrained to /home/jonathan/.convokit/saved-models/craft-cmv-pretrained
Downloading craft-cmv-pretrained/craft_pretrained.tar from https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/craft_pretrained.tar (974.6MB)... Done
Downloading craft-cmv-pretrained/index2word.json from https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/index2word.json (1.0MB)... Done
Downloading craft-cmv-pretrained/word2index.json from https://zissou.infosci.cornell.edu/convokit/models/craft_cmv/word2index.json (928.0KB)... Done


In [5]:
craft_forecaster = Forecaster(craft, "has_removed_comment")

## Fine-tune the model using Forecaster.fit

And now, just as with any other ConvoKit Transformer, we train the model simply by calling `fit` (note how we pass in the selectors we previously defined!)...

In [6]:
craft_forecaster.fit(corpus, 
                     partial(generic_fit_selector, split="train"), 
                     val_context_selector=partial(generic_fit_selector, split="val"))

Processed 4106 context tuples for model training
Processed 1368 context tuples for model validation
Loading saved parameters...
Building encoders, decoder, and classifier...
Models built and ready to go!
Building optimizers...
Starting Training!
Will train for 1920 iterations
Initializing ...
Training...
Iteration: 10; Percent complete: 0.5%; Average loss: 0.6933
Iteration: 20; Percent complete: 1.0%; Average loss: 0.6924
Iteration: 30; Percent complete: 1.6%; Average loss: 0.6929
Iteration: 40; Percent complete: 2.1%; Average loss: 0.6927
Iteration: 50; Percent complete: 2.6%; Average loss: 0.6929
Iteration: 60; Percent complete: 3.1%; Average loss: 0.6923
Validating!
Iteration: 1; Percent complete: 4.5%
Iteration: 2; Percent complete: 9.1%
Iteration: 3; Percent complete: 13.6%
Iteration: 4; Percent complete: 18.2%
Iteration: 5; Percent complete: 22.7%
Iteration: 6; Percent complete: 27.3%
Iteration: 7; Percent complete: 31.8%
Iteration: 8; Percent complete: 36.4%
...

<convokit.forecaster.forecaster.Forecaster at 0x7fdd1bfe90c0>

## Run the fitted model on the test set and perform evaluation

...and inference is done simply by calling `transform`! (again, note the selector)

In [7]:
corpus = craft_forecaster.transform(corpus, transform_selector)

Processed 7098 context tuples for model evaluation
Loading saved parameters...
Building encoders, decoder, and classifier...
Models built and ready to go!
Iteration: 1; Percent complete: 0.9%
Iteration: 2; Percent complete: 1.8%
Iteration: 3; Percent complete: 2.7%
Iteration: 4; Percent complete: 3.6%
Iteration: 5; Percent complete: 4.5%
Iteration: 6; Percent complete: 5.4%
Iteration: 7; Percent complete: 6.3%
Iteration: 8; Percent complete: 7.2%
Iteration: 9; Percent complete: 8.1%
Iteration: 10; Percent complete: 9.0%
Iteration: 11; Percent complete: 9.9%
Iteration: 12; Percent complete: 10.8%
Iteration: 13; Percent complete: 11.7%
Iteration: 14; Percent complete: 12.6%
Iteration: 15; Percent complete: 13.5%
Iteration: 16; Percent complete: 14.4%
Iteration: 17; Percent complete: 15.3%
Iteration: 18; Percent complete: 16.2%
Iteration: 19; Percent complete: 17.1%
Iteration: 20; Percent complete: 18.0%
Iteration: 21; Percent complete: 18.9%
Iteration: 22; Percent complete: 19.8%
...

Finally, to get a human-readable interpretation of model performance, we can use `summarize` to generate a table of standard performance metrics. It also returns a table of conversation-level predictions in case you want to do more complex analysis!

In [8]:
craft_forecaster.summarize(corpus, lambda c: c.meta['split'] == "test")

Accuracy     0.632310
Precision    0.605110
Recall       0.761696
FPR          0.497076
F1           0.674434
dtype: float64


(                 label     score  forecast
 conversation_id                           
 cus26gy              1  0.714089       1.0
 cus37h0              1  0.785188       1.0
 cus142u              0  0.664932       1.0
 cus19ml              0  0.440758       0.0
 cusxft0              1  0.388766       0.0
 ...                ...       ...       ...
 e8qli0i              0  0.357186       0.0
 e8qm4aj              0  0.383017       0.0
 e8ql8ii              0  0.886789       1.0
 e8qzjei              1  0.971277       1.0
 e8r00ko              0  0.902596       1.0
 
 [1368 rows x 3 columns],
 {'Accuracy': 0.6323099415204678,
  'Precision': 0.6051103368176539,
  'Recall': 0.7616959064327485,
  'FPR': 0.49707602339181284,
  'F1': 0.6744336569579288})
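
For readers who want to relate the metric names to the `label` and `forecast` columns, the sketch below computes them from confusion counts using the standard textbook definitions. It is an assumption that `summarize` uses exactly these formulas, though the reported F1 is consistent with the harmonic mean of the reported precision and recall. `binary_metrics` is a hypothetical helper, and the toy inputs are just the first five rows visible in the table above, so the resulting numbers are not the test-set results.

```python
def binary_metrics(labels, forecasts):
    """Standard binary classification metrics from 0/1 labels and forecasts."""
    pairs = list(zip(labels, forecasts))
    tp = sum(1 for y, f in pairs if y == 1 and f == 1)  # true positives
    fp = sum(1 for y, f in pairs if y == 0 and f == 1)  # false positives
    tn = sum(1 for y, f in pairs if y == 0 and f == 0)  # true negatives
    fn = sum(1 for y, f in pairs if y == 1 and f == 0)  # false negatives

    def ratio(num, den):
        # Return 0.0 instead of raising when a denominator is empty.
        return num / den if den else 0.0

    return {
        "Accuracy": ratio(tp + tn, len(pairs)),
        "Precision": ratio(tp, tp + fp),
        "Recall": ratio(tp, tp + fn),
        "FPR": ratio(fp, fp + tn),
        "F1": ratio(2 * tp, 2 * tp + fp + fn),
    }

# First five visible rows of the table above (illustration only).
labels = [1, 1, 0, 0, 1]
forecasts = [1, 1, 1, 0, 0]
metrics = binary_metrics(labels, forecasts)
print(metrics)
```

On these five rows there are two true positives, one false positive, one true negative, and one false negative, giving an accuracy of 0.6 and a false positive rate of 0.5.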