# `02_transformer_model.ipynb`

*Attribution note: portions of code in this notebook are borrowed from [another notebook](https://github.com/disinfo-detectors/tweet-turing-test/blob/main/src/05_BERT_fine_tuner.ipynb), which was a notebook written by one of our team members (Justin Minnion) for another class (DSCI 591/592).*

## 0.1 - Setup

### 0.1.1 - Package Imports

In [133]:
# imports from python standard library
import re
from pathlib import Path

# data science packages
import numpy as np
import pandas as pd

# huggingface packages
import evaluate
from datasets import Dataset, DatasetDict, ClassLabel
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import TrainingArguments, Trainer

### 0.1.2 - Constants

In [123]:
# file locations
DATA_DIR = Path("./data")
DATA_DIR_PROCESSED = DATA_DIR / "processed"
PROCESSED_DATA = DATA_DIR_PROCESSED / "script_data_processed.csv"

MODEL_DIR = DATA_DIR / "models"

### 0.1.3 - Options

In [2]:
pd.set_option('display.max_colwidth', None)

## 0.2 - Load Data

In [8]:
script_df = pd.read_csv(
    filepath_or_buffer=PROCESSED_DATA,
    header=0,
    index_col=0,
    encoding='utf-8'
)

In [9]:
script_df.head(3)

| | season | episode | title | scene | speaker | line | directed_by | written_by | writer1 | writer2 | writer3 |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 1 | Pilot | 1 | michael | All right Jim. Your quarterlies look very good. How are things at the library? | Ken Kwapis | Ricky Gervais & Stephen Merchant and Greg Daniels | Ricky Gervais | Stephen Merchant | Greg Daniels |
| 1 | 1 | 1 | Pilot | 1 | jim | Oh, I told you. I couldn't close it. So... | Ken Kwapis | Ricky Gervais & Stephen Merchant and Greg Daniels | Ricky Gervais | Stephen Merchant | Greg Daniels |
| 2 | 1 | 1 | Pilot | 1 | michael | So you've come to the master for guidance? Is this what you're saying, grasshopper? | Ken Kwapis | Ricky Gervais & Stephen Merchant and Greg Daniels | Ricky Gervais | Stephen Merchant | Greg Daniels |


In [13]:
# examine numeric fields
script_df.describe()

| | season | episode | scene |
|---|---|---|---|
| count | 54267.0 | 54267.0 | 54267.0 |
| mean | 5.538099 | 12.490003 | 4190.521606 |
| std | 2.349106 | 7.286262 | 2294.821819 |
| min | 1.0 | 1.0 | 1.0 |
| 25% | 3.0 | 6.0 | 2325.0 |
| 50% | 6.0 | 12.0 | 4215.0 |
| 75% | 8.0 | 18.0 | 6153.0 |
| max | 9.0 | 28.0 | 8157.0 |


In [12]:
script_df.info(memory_usage='deep')

<class 'pandas.core.frame.DataFrame'>
Int64Index: 54267 entries, 0 to 54266
Data columns (total 11 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   season       54267 non-null  int64 
 1   episode      54267 non-null  int64 
 2   title        54267 non-null  object
 3   scene        54267 non-null  int64 
 4   speaker      54267 non-null  object
 5   line         54267 non-null  object
 6   directed_by  54267 non-null  object
 7   written_by   54267 non-null  object
 8   writer1      54267 non-null  object
 9   writer2      9816 non-null   object
 10  writer3      699 non-null    object
dtypes: int64(3), object(8)
memory usage: 29.1 MB


While the dataset isn't particularly large, we can shrink its memory footprint (and may speed up some operations) by being more prescriptive with `dtype` settings. At a minimum, we should aim for no `object`-type columns.

In [14]:
dtype_mapping = {
    'season': 'int8',
    'episode': 'int8',
    'title': 'string',
    'scene': 'int16',
    'speaker': 'string',    # could be category if we limit to top 10 speakers
    'line': 'string',
    'directed_by': 'category',
    'written_by': 'string',
    'writer1': 'category',
    'writer2': 'category',
    'writer3': 'category',
}

script_df = script_df.astype(dtype_mapping)

script_df.info(memory_usage='deep')

<class 'pandas.core.frame.DataFrame'>
Int64Index: 54267 entries, 0 to 54266
Data columns (total 11 columns):
 #   Column       Non-Null Count  Dtype   
---  ------       --------------  -----   
 0   season       54267 non-null  int8    
 1   episode      54267 non-null  int8    
 2   title        54267 non-null  string  
 3   scene        54267 non-null  int16   
 4   speaker      54267 non-null  string  
 5   line         54267 non-null  string  
 6   directed_by  54267 non-null  category
 7   written_by   54267 non-null  string  
 8   writer1      54267 non-null  category
 9   writer2      9816 non-null   category
 10  writer3      699 non-null    category
dtypes: category(4), int16(1), int8(2), string(4)
memory usage: 17.4 MB
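Most of that drop likely comes from the `category` columns, which store each distinct value once plus a small integer code per row, with the downcast integer columns contributing the rest. A minimal sketch of both effects on toy data (the repeated `director_*` strings are hypothetical placeholders, not rows from the dataset); it also checks that the maximum values reported by `describe()` above fit in the chosen integer types:

```python
import numpy as np
import pandas as pd

# Largest values reported by describe() for each downcast integer column
column_max = {"season": (9, "int8"), "episode": (28, "int8"), "scene": (8157, "int16")}
for col, (max_val, dtype) in column_max.items():
    # np.iinfo gives the representable range of an integer dtype
    print(f"{col:8} max={max_val:5}  fits in {dtype}: {max_val <= np.iinfo(dtype).max}")

# Hypothetical low-cardinality column, similar in spirit to `directed_by`
names = ["director_a", "director_b", "director_c"] * 10_000
as_object = pd.Series(names, dtype="object")
as_category = pd.Series(names, dtype="category")

# deep=True includes the memory held by the Python string objects themselves
object_bytes = as_object.memory_usage(deep=True)
category_bytes = as_category.memory_usage(deep=True)
print(f"object:   {object_bytes:,} bytes")
print(f"category: {category_bytes:,} bytes (codes stored as {as_category.cat.codes.dtype})")
```

With only three distinct values, pandas stores the category codes as `int8`, so each row costs one byte instead of a pointer plus a Python string.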


# 1 - Basic Transformer

Here we attempt a basic transformer model, with little customization, to establish a performance baseline among transformer-type models.

**Task**: Sequence Classification (Binary)

**Classes**: 
 - Positive (1): "Dwight" - a line is spoken by the character Dwight K. Schrute (played by Rainn Wilson).
 - Negative (0): "Not Dwight" - a line is spoken by any character other than Dwight.

**Data**:
 - `speaker` as the precursor to the class label, limited to the 10 most frequent speakers by number of lines in the dataset.
 - `line` as sequence text.

**Encoding**:
 - Tokenizer: DistilBertTokenizerFast
 - Max Sequence Length: 128
 - Padding: True (`padding='longest'`, i.e. pad to the longest sequence)
 - Truncate: True

**Training**:
 - Train/Test/Validation Split: 50/25/25

**Notes**:
 - Class imbalance is present (positive: 6,752; negative: 32,668; about `1:4.8` imbalance ratio).
 - Vocabulary: no modifications made to pretrained transformer's vocabulary.
 - Secondary data: no inclusion of secondary data (director/writer credits).
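Because of this imbalance, accuracy on its own is a weak signal: a trivial model that always predicts "Not Dwight" would already be correct on roughly 83% of lines. A quick check of that arithmetic using the class counts above (the variable names are illustrative):

```python
# Class counts for the top-10-speaker subset
n_negative = 32_668  # "not_dwight"
n_positive = 6_752   # "dwight"
n_total = n_negative + n_positive

imbalance_ratio = n_negative / n_positive
# Accuracy of a trivial classifier that always predicts "not_dwight"
majority_baseline_accuracy = n_negative / n_total

print(f"total lines:        {n_total}")
print(f"imbalance ratio:    1:{imbalance_ratio:.1f}")
print(f"majority baseline:  {majority_baseline_accuracy:.3f}")
```

Any accuracy near 0.83 should therefore be read alongside precision, recall, and F1 for the Dwight class.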

## 1.1 - Dataset - Convert `pandas` -> 🤗 `dataset`

In [36]:
# limit to top 10 most frequent speakers
top_10_speaker_list = script_df['speaker'].value_counts(normalize=True).nlargest(10).index.tolist()
columns_to_keep = ['speaker', 'line']

script_df_subset = script_df.loc[script_df['speaker'].isin(top_10_speaker_list), columns_to_keep]

script_df_subset

| | speaker | line |
|---|---|---|
| 0 | michael | All right Jim. Your quarterlies look very good. How are things at the library? |
| 1 | jim | Oh, I told you. I couldn't close it. So... |
| 2 | michael | So you've come to the master for guidance? Is this what you're saying, grasshopper? |
| 3 | jim | Actually, you called me in here, but yeah. |
| 4 | michael | All right. Well, let me show you how it's done. |
| ... | ... | ... |
| 54257 | kevin | No, but maybe the reason... |
| 54258 | oscar | You're not gay. |
| 54260 | erin | How did you do it? How did you capture what it was really like? How we felt and how made each other laugh and how we got through the day? How did you do it? Also, how do cameras work? |
| 54265 | jim | I sold paper at this company for 12 years. My job was to speak to clients on the phone about quantities and types of copier paper. Even if I didn't love every minute of it, everything I have, I owe to this job. This stupid...wonderful...boring...amazing job. |


In [47]:
# rename the 'line' column to be 'text'
script_df_subset = script_df_subset.rename(columns={'line': 'text'})

In [56]:
# create class label column
dwight_mask = (script_df_subset['speaker'] == 'dwight')

# new column of zeros
script_df_subset['label'] = 0

# apply the Dwight mask (as seen in the CPR scene of S05E14 "Stress Relief")
script_df_subset.loc[dwight_mask, 'label'] = 1

# adjust dtype
#   (would love to use 'category', but it's not implemented in 🤗 datasets)
script_df_subset['label'] = script_df_subset['label'].astype('int8')

# check results
script_df_subset['label'].value_counts()

0    32668
1     6752
Name: label, dtype: int64

In [57]:
script_df_subset.info(memory_usage='deep')

<class 'pandas.core.frame.DataFrame'>
Int64Index: 39420 entries, 0 to 54266
Data columns (total 3 columns):
 #   Column   Non-Null Count  Dtype 
---  ------   --------------  ----- 
 0   speaker  39420 non-null  string
 1   text     39420 non-null  string
 2   label    39420 non-null  int8  
dtypes: int8(1), string(2)
memory usage: 7.0 MB


In [75]:
# finally, convert to 🤗 dataset object
#   drop 'speaker' by way of not including it
dataset_full: Dataset = Dataset.from_pandas(script_df_subset[['text', 'label']].reset_index(drop=False)) \
                    .cast_column('label', ClassLabel(names=['not_dwight', 'dwight']))

# make sure we got the class labels mapped correctly
assert (dataset_full.features['label'].str2int('dwight') == 1)


In [76]:
dataset_full

Dataset({
    features: ['index', 'text', 'label'],
    num_rows: 39420
})

## 1.2 - Train/Test/Val Split

As of v2.12.0, the 🤗 Datasets `train_test_split` method outputs only **two** splits (train/test), so we perform the split twice to obtain train, test, and validation splits.
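The fraction passed to the second call is the validation share of what remains after the first split: 0.25 / (0.25 + 0.25) = 0.5. A sketch of the resulting row counts, assuming each call rounds the held-out part up and gives the remainder to the first output split (the convention scikit-learn uses); `split_sizes` is a hypothetical helper, not part of 🤗 Datasets:

```python
import math

def split_sizes(n_rows: int, train_size: float, test_size: float, valid_size: float):
    """Approximate row counts for the two-stage train/test/valid split.

    Assumption: each train_test_split call takes ceil() of the held-out
    fraction and leaves the remainder in its first ('train') output.
    """
    # stage 1: hold out everything that is not train
    held_out = math.ceil(n_rows * (1.0 - train_size))
    n_train = n_rows - held_out
    # stage 2: split the held-out rows into test ('train' output) and valid ('test' output)
    second_fraction = valid_size / (test_size + valid_size)
    n_valid = math.ceil(held_out * second_fraction)
    n_test = held_out - n_valid
    return n_train, n_test, n_valid

print(split_sizes(39_420, 0.50, 0.25, 0.25))
```

For the 39,420-line subset this gives 19,710 / 9,855 / 9,855 rows.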

In [79]:
# set parameters
train_size = 0.50
test_size = 0.25
valid_size = 0.25

# compare with a tolerance, since float sums are not always exact
assert np.isclose(train_size + test_size + valid_size, 1.0)

split_random_seed = 27  # for Weird Al fans

first_split = dataset_full.train_test_split(
    test_size=(1.0 - train_size),
    shuffle=True,
    seed=split_random_seed,
    stratify_by_column='label'
)

second_split = first_split['test'].train_test_split(
    test_size=((valid_size) / (test_size + valid_size)),
    shuffle=True,
    seed=split_random_seed,
    stratify_by_column='label'
)

ds_dict = DatasetDict({
    'train': first_split['train'],
    'test': second_split['train'],
    'valid': second_split['test']
})

ds_dict

DatasetDict({
    train: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 19710
    })
    test: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 9855
    })
    valid: Dataset({
        features: ['index', 'text', 'label'],
        num_rows: 9855
    })
})

In [86]:
# confirm stratified sample (class counts in the training split)
train_counts = ds_dict['train'].to_pandas()['label'].value_counts()
num_negative = train_counts[0]
num_positive = train_counts[1]

print(f"ratio positive/negative is:\t1 to {num_negative/num_positive:0.1f}")

ratio positive/negative is:	1 to 4.8


## 1.3 - Tokenize and Encode

In [88]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

In [102]:
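# tokenizer function
#   returns input_ids and attention_mask for each line of text
def tokenize_function(examples):
    return tokenizer(examples['text'],
                     padding='longest',   # pad to the longest sequence in the batch
                     truncation=True,     # cut sequences longer than max_length
                     max_length=128)

# batch_size=None passes each split to tokenize_function as a single batch,
#   so 'longest' pads every row to the longest (truncated) line in that split
ds_tokenized = ds_dict.map(
    tokenize_function,
    batched=True,
    batch_size=None)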

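Because `batch_size=None` makes `map` hand each split to `tokenize_function` as one batch, `padding='longest'` pads every line to the longest (truncated) line in its split, at most 128 tokens. That is why the short line inspected below is followed by a long run of `[PAD]` tokens. A toy, dependency-free sketch of what padding, truncation, and the attention mask do; `pad_and_truncate` is hypothetical, it skips word-piece tokenization entirely, and the special-token ids are an assumption based on the `bert-base-uncased` vocabulary that DistilBERT reuses:

```python
from typing import Dict, List

# Assumed special-token ids ([CLS], [SEP], [PAD]) from the bert-base-uncased vocab
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def pad_and_truncate(batch: List[List[int]], max_length: int) -> Dict[str, List[List[int]]]:
    """Toy version of tokenizer(..., padding='longest', truncation=True, max_length=...).

    Each inner list holds token ids *without* special tokens.
    """
    # truncate on the right so that [CLS] + tokens + [SEP] fits in max_length
    wrapped = [[CLS_ID] + ids[: max_length - 2] + [SEP_ID] for ids in batch]
    # pad every sequence to the longest one in the batch
    longest = max(len(ids) for ids in wrapped)
    input_ids = [ids + [PAD_ID] * (longest - len(ids)) for ids in wrapped]
    # attention mask: 1 for real tokens, 0 for padding
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in wrapped]
    return {"input_ids": input_ids, "attention_mask": attention_mask}

toy_batch = [[7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8]]
out = pad_and_truncate(toy_batch, max_length=6)
print(out["input_ids"])       # [[101, 7, 8, 9, 102, 0], [101, 1, 2, 3, 4, 102]]
print(out["attention_mask"])  # [[1, 1, 1, 1, 1, 0], [1, 1, 1, 1, 1, 1]]
```

The second toy sequence is cut to fit `max_length`, and the first is padded up to it; the model then ignores positions whose attention mask is 0.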

In [120]:
def inspect_tokens(tokenizer, encoded_text: dict):
    '''Prints the provided encoded text as its original text and as its tokenized form.
        - tokenizer is an instantiated huggingface tokenizer (sub-subclass of PreTrainedTokenizerBase)
        - encoded_text is the dict created from one element of a huggingface dataset
        '''
    vocab = tokenizer.get_vocab()
    inverse_vocab = {v: k for (k, v) in vocab.items()}

    tokens_list = [inverse_vocab[i] for i in encoded_text['input_ids']]
    tokens_list_attention = [tokens_list[i] for i in range(len(tokens_list)) if (encoded_text['attention_mask'][i] == 1)]

    print("-"*50)
    print(f"Original text:\n\t{encoded_text['text']}", end="\n\n")
    print(f"Label:\t{encoded_text['label']}", end="\n\n")
    print(f"Tokenized form:\n\t{' '.join(tokens_list)}", end="\n\n")
    print(f"Tokens as a list:\n\t{tokens_list}", end="\n\n")
    print(f"Tokens as a list, attention mask applied:\n\t{tokens_list_attention}", end="\n\n")

In [121]:
inspect_tokens(tokenizer, ds_tokenized['train'][27])
inspect_tokens(tokenizer, ds_tokenized['test'][42])

--------------------------------------------------
Original text:
	 Birthday time is over! Now go make up for all the work you missed when you were taking your nap.  Many happy returns. 

Label:	1

Tokenized form:
	[CLS] birthday time is over ! now go make up for all the work you missed when you were taking your nap . many happy returns . [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]

Tokens as a list:
	['[CLS]', 'birthday', 'time', 'is

## 1.4 - Model

Create the model from a pretrained 🤗 transformer checkpoint.

In [122]:
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=2)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_transform.weight', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias', 'pre_classifier

Set up training arguments:

In [131]:
start_time = pd.Timestamp.now().strftime(r'%Y%m%d_%H%M%S')  # yyyymmdd_hhmmss
run_name = f"basic_distilbert_{start_time}"

training_args = TrainingArguments(
    # model output
    run_name=run_name,
    output_dir=MODEL_DIR / run_name,
    save_strategy='epoch',
    save_total_limit=3,
    # training hyperparams
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    #gradient_accumulation_steps=4,
    #gradient_checkpointing=True,
    weight_decay=0.01,
    # evaluation during training
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    log_level='warning',
)
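These arguments fix the number of optimizer steps and the learning-rate schedule that show up in the training log. A back-of-the-envelope check, assuming a single device, no gradient accumulation, and the `TrainingArguments` defaults for the learning rate (5e-5) and scheduler (linear decay to zero, no warmup); `linear_lr` is a hypothetical helper written for this illustration:

```python
import math

n_train = 19_710    # rows in ds_tokenized['train']
n_eval = 9_855      # rows in ds_tokenized['test']
batch_size = 32     # per_device_*_batch_size, assuming one device
num_epochs = 5
base_lr = 5e-5      # assumed: TrainingArguments default learning_rate

steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch is kept
total_steps = steps_per_epoch * num_epochs
eval_steps = math.ceil(n_eval / batch_size)

def linear_lr(step: int) -> float:
    """Linear decay from base_lr to zero with no warmup (assumed default scheduler)."""
    return base_lr * max(0.0, (total_steps - step) / total_steps)

print(steps_per_epoch, total_steps, eval_steps)
for epoch in range(1, num_epochs + 1):
    print(epoch, f"{linear_lr(epoch * steps_per_epoch):.0e}")
```

This yields 616 steps per epoch, 3,080 training steps, 308 evaluation batches, and a learning rate that falls by 1e-5 per epoch, consistent with the `learning_rate` values logged at the end of each epoch.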

Establish evaluation metrics:

In [134]:
# setup training / evaluation metric
#   Docs: https://huggingface.co/docs/evaluate/package_reference/main_classes#evaluate.combine
#   Each of these metrics corresponds to a script from huggingface, below are the links for each script.
#       accuracy:       https://huggingface.co/spaces/evaluate-metric/accuracy
#       f1:             https://huggingface.co/spaces/evaluate-metric/f1
#       precision:      https://huggingface.co/spaces/evaluate-metric/precision
#       recall:         https://huggingface.co/spaces/evaluate-metric/recall
metric_list = ['accuracy', 'f1', 'precision', 'recall']

metric = evaluate.combine(evaluations=metric_list)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

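By default (see the metric cards linked above), `f1`, `precision`, and `recall` score the positive class (label 1, "dwight") in a binary setting, so they measure how well the model finds Dwight's lines, while accuracy also rewards the far more common "Not Dwight" lines. A NumPy-only sketch of the same quantities; the `binary_metrics` helper is hypothetical and only mirrors, rather than reproduces, the 🤗 `evaluate` scripts:

```python
import numpy as np

def binary_metrics(logits: np.ndarray, labels: np.ndarray, pos_label: int = 1) -> dict:
    """Accuracy, plus precision/recall/F1 for the positive ("dwight") class."""
    predictions = np.argmax(logits, axis=-1)  # same step as compute_metrics
    tp = np.sum((predictions == pos_label) & (labels == pos_label))
    fp = np.sum((predictions == pos_label) & (labels != pos_label))
    fn = np.sum((predictions != pos_label) & (labels == pos_label))
    accuracy = np.mean(predictions == labels)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": float(accuracy), "precision": float(precision),
            "recall": float(recall), "f1": float(f1)}

# toy logits for 4 lines: column 0 = not_dwight, column 1 = dwight
toy_logits = np.array([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.0, 1.0]])
toy_labels = np.array([0, 1, 1, 1])
print(binary_metrics(toy_logits, toy_labels))
```

Here the model never raises a false alarm (precision 1.0) but misses one of three Dwight lines, so recall is 2/3 and F1 is 0.8.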

Finally, set up the 🤗 Trainer:

In [137]:
time_training_start = pd.Timestamp.now()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_tokenized['train'],
    eval_dataset=ds_tokenized['test'],
    compute_metrics=compute_metrics
)

result = trainer.train()

time_training_stop = pd.Timestamp.now()
time_training = time_training_stop - time_training_start

print("\nTraining duration:", str(time_training))




{'loss': 0.4413, 'learning_rate': 4e-05, 'epoch': 1.0}



{'eval_loss': 0.42805027961730957, 'eval_accuracy': 0.8305428716387621, 'eval_f1': 0.0324449594438007, 'eval_precision': 0.7368421052631579, 'eval_recall': 0.016587677725118485, 'eval_runtime': 17.4104, 'eval_samples_per_second': 566.04, 'eval_steps_per_second': 17.691, 'epoch': 1.0}
{'loss': 0.373, 'learning_rate': 3e-05, 'epoch': 2.0}



{'eval_loss': 0.446800172328949, 'eval_accuracy': 0.8272957889396245, 'eval_f1': 0.3304484657749803, 'eval_precision': 0.4918032786885246, 'eval_recall': 0.24881516587677724, 'eval_runtime': 17.4796, 'eval_samples_per_second': 563.802, 'eval_steps_per_second': 17.621, 'epoch': 2.0}
{'loss': 0.277, 'learning_rate': 2e-05, 'epoch': 3.0}



{'eval_loss': 0.5202162861824036, 'eval_accuracy': 0.8321664129883308, 'eval_f1': 0.26292335115864524, 'eval_precision': 0.5305755395683454, 'eval_recall': 0.17476303317535544, 'eval_runtime': 17.5137, 'eval_samples_per_second': 562.703, 'eval_steps_per_second': 17.586, 'epoch': 3.0}
{'loss': 0.2014, 'learning_rate': 1e-05, 'epoch': 4.0}



{'eval_loss': 0.6362824440002441, 'eval_accuracy': 0.8148148148148148, 'eval_f1': 0.33756805807622503, 'eval_precision': 0.43580131208997186, 'eval_recall': 0.2754739336492891, 'eval_runtime': 17.5254, 'eval_samples_per_second': 562.326, 'eval_steps_per_second': 17.574, 'epoch': 4.0}
{'loss': 0.1495, 'learning_rate': 0.0, 'epoch': 5.0}



{'eval_loss': 0.7720897197723389, 'eval_accuracy': 0.8077118214104515, 'eval_f1': 0.34041072050121823, 'eval_precision': 0.41265822784810124, 'eval_recall': 0.2896919431279621, 'eval_runtime': 17.5274, 'eval_samples_per_second': 562.261, 'eval_steps_per_second': 17.572, 'epoch': 5.0}
{'train_runtime': 588.4075, 'train_samples_per_second': 167.486, 'train_steps_per_second': 5.234, 'train_loss': 0.2884544446870878, 'epoch': 5.0}

Training duration: 0 days 00:09:48.664755
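The log shows training loss falling every epoch while evaluation loss rises after epoch 1, which suggests the model starts to overfit early, even though eval F1 is highest at the final epoch. Since `load_best_model_at_end` is left at its default (`False`), the saved model is the one from the final epoch. A small sketch that tabulates the logged values (copied from the output above, rounded to four decimals) and picks the best epoch by each criterion:

```python
import pandas as pd

# Per-epoch values copied from the Trainer log (rounded to 4 decimals)
log = pd.DataFrame({
    "epoch":      [1, 2, 3, 4, 5],
    "train_loss": [0.4413, 0.3730, 0.2770, 0.2014, 0.1495],
    "eval_loss":  [0.4281, 0.4468, 0.5202, 0.6363, 0.7721],
    "eval_f1":    [0.0324, 0.3304, 0.2629, 0.3376, 0.3404],
}).set_index("epoch")

# idxmax / idxmin return the index label, i.e. the epoch number
best_f1_epoch = int(log["eval_f1"].idxmax())
best_loss_epoch = int(log["eval_loss"].idxmin())

print(log)
print("best eval_f1 at epoch:", best_f1_epoch)
print("lowest eval_loss at epoch:", best_loss_epoch)
```

The two criteria disagree (epoch 5 by F1, epoch 1 by loss), so which checkpoint counts as "best" depends on the metric chosen.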


Save the trained model:

In [139]:
trainer.save_model()    # saves to self.args.output_dir

## 1.5 - Evaluate

In [144]:
final_metrics = {}
final_metrics['train'] = trainer.evaluate(eval_dataset=ds_tokenized['train'], metric_key_prefix='final_train')
final_metrics['test'] = trainer.evaluate(eval_dataset=ds_tokenized['test'], metric_key_prefix='final_test')
final_metrics['valid'] = trainer.evaluate(eval_dataset=ds_tokenized['valid'], metric_key_prefix='validation')

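An accuracy of about 0.81 next to an F1 of about 0.34 is easier to interpret as raw counts. The epoch-5 test-split metrics from the training log, plus the split sizes, are enough to back out the confusion matrix. A worked sketch, assuming the stratified splits divided the 6,752 Dwight lines evenly, so the 9,855-line test split holds 1,688 of them:

```python
# Epoch-5 evaluation on the test split, copied from the training log
precision = 0.41265822784810124
recall = 0.2896919431279621
accuracy = 0.8077118214104515
n_test = 9_855
# Assumption: stratified 50/25/25 splits put 6,752 / 4 "dwight" lines in the test split
n_positive = 6_752 // 4

tp = round(recall * n_positive)             # recall = TP / (TP + FN)
predicted_positive = round(tp / precision)  # precision = TP / (TP + FP)
fp = predicted_positive - tp
fn = n_positive - tp
tn = n_test - tp - fp - fn

print(f"TP={tp} FP={fp} FN={fn} TN={tn}")
# sanity check: the reconstructed counts should reproduce the logged accuracy
print(f"accuracy from counts: {(tp + tn) / n_test:.4f} (logged: {accuracy:.4f})")
```

Under that assumption the counts are 489 TP, 696 FP, 1,199 FN, and 7,471 TN, which reproduce the logged accuracy: the model flags 1,185 lines as Dwight's, is right on fewer than half of them, and catches fewer than a third of his 1,688 lines.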

In [162]:
for split in final_metrics:
    print(f"\n{split.upper():->10}{'-'*15}")
    for k, v in final_metrics[split].items():
        print(f"{v:>10.3f} - {k}")
    print("-"*25)


-----TRAIN---------------
     0.112 - final_train_loss
     0.962 - final_train_accuracy
     0.877 - final_train_f1
     0.972 - final_train_precision
     0.799 - final_train_recall
    34.720 - final_train_runtime
   567.687 - final_train_samples_per_second
    17.742 - final_train_steps_per_second
     5.000 - epoch
-------------------------

------TEST---------------
     0.772 - final_test_loss
     0.808 - final_test_accuracy
     0.340 - final_test_f1
     0.413 - final_test_precision
     0.290 - final_test_recall
    17.456 - final_test_runtime
   564.549 - final_test_samples_per_second
    17.644 - final_test_steps_per_second
     5.000 - epoch
-------------------------

-----VALID---------------
     0.793 - validation_loss
     0.803 - validation_accuracy
     0.331 - validation_f1
     0.396 - validation_precision
     0.284 - validation_recall
    17.408 - validation_runtime
   566.124 - validation_samples_per_second
    17.693 - validation_steps_per_second
     5.000 