In [9]:
import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from transformers import BertTokenizer, TFBertModel

In [10]:
def parse_jsonl(path: str) -> pd.DataFrame:
    """
    Flatten a JSON-lines file of conversations into one row per message.

    Each non-blank line is parsed as a JSON object. Every key except
    ``"players"`` becomes a column, and the key's list value is appended to
    that column. ``"game_id"`` is repeated once per entry of ``"messages"``,
    and a ``"conversation_id"`` column holding the zero-based line index of
    the conversation is added.

    Args:
        path: Location of the ``.jsonl`` file.

    Returns:
        A DataFrame with one column per collected key. A file without any
        records yields an empty DataFrame.
    """
    columns = {}
    # Read as UTF-8 explicitly: the platform default encoding is not
    # guaranteed to decode non-ASCII message text such as emoji.
    with open(path, 'r', encoding='utf-8') as f:
        for conversation_id, raw_line in enumerate(f):
            if not raw_line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            record = json.loads(raw_line)
            n_messages = len(record["messages"])
            record["conversation_id"] = [conversation_id] * n_messages
            for key, values in record.items():
                if key == "players":
                    continue
                if key == "game_id":
                    # Scalar per conversation: broadcast to one value per message.
                    values = [values] * n_messages
                # extend() grows the column in place; rebuilding the list with
                # `+` for every line made loading quadratic in the file size.
                columns.setdefault(key, []).extend(values)
    return pd.DataFrame(columns)

In [11]:
def clean_data(df: pd.DataFrame) -> pd.DataFrame:
    """
    Turn ``receiver_labels`` into a nullable boolean column.

    Entries equal to the string ``"NOANNOTATION"`` become ``pd.NA``; the
    column is then cast to pandas' nullable ``"boolean"`` dtype. A plain
    ``astype(bool)`` cannot be used here because it cannot represent the
    missing values that were just inserted.

    Args:
        df: Frame with a ``receiver_labels`` column. It is modified in place.

    Returns:
        The same ``df`` object, with ``receiver_labels`` converted.
    """
    labels = df["receiver_labels"]
    df["receiver_labels"] = labels.mask(labels == "NOANNOTATION", pd.NA).astype("boolean")
    return df

In [12]:
def pipe(*funcs):
    """
    Thread a value through a sequence of callables.

    The first positional argument is the starting value; each remaining
    argument is called with the result of the previous step. Returns the
    final result (the starting value itself when no callables are given).
    """
    value, stages = funcs[0], funcs[1:]
    for stage in stages:
        value = stage(value)
    return value

In [13]:
def docs_to_vocab_ids(tokenized_texts_list):
    """
    Convert tokenized documents into fixed-width lists of vocabulary ids.

    Every token is looked up in the module-level ``vocab_dict``; tokens that
    are not present map to the id of ``'[UNK]'``. Each document is truncated
    to ``MAX_SEQUENCE_LENGTH`` ids, and shorter documents are padded with the
    ``'[UNK]'`` id up to that length.

    Args:
        tokenized_texts_list: Iterable of tokenized documents. Each item must
            expose ``.numpy()`` returning a sequence of ``bytes`` tokens
            (presumably TensorFlow string tensors - confirm against the caller).

    Returns:
        ``np.ndarray`` built from one id list per document.
    """
    # Loop-invariant: the id used both for unknown tokens and for padding.
    unk_id = vocab_dict['[UNK]']
    texts_vocab_ids = []
    for i, token_list in enumerate(tokenized_texts_list):

        # Get the vocab id for each token in this doc ([UNK] if not in vocab)
        vocab_ids = [
            vocab_dict.get(token.decode('utf-8', errors='ignore'), unk_id)
            for token in token_list.numpy()
        ]

        # Truncate text to max length, add padding up to max length
        vocab_ids = vocab_ids[:MAX_SEQUENCE_LENGTH]
        n_padding = MAX_SEQUENCE_LENGTH - len(vocab_ids)
        # For simplicity in this model, we'll just pad with unknown tokens
        vocab_ids += [unk_id] * n_padding
        # Add this example to the list of converted docs
        texts_vocab_ids.append(vocab_ids)

        if i % 5000 == 0:
            print('Examples processed: ', i)

    # Report the number of documents, not the last loop index (which was one
    # short and undefined for empty input).
    print('Total examples: ', len(texts_vocab_ids))
    return np.array(texts_vocab_ids)

In [14]:
# Fixed length (in tokens) that docs_to_vocab_ids truncates or pads every example to.
# NOTE(review): the BERT model builders in this notebook also use this value as
# their input width; BERT checkpoints are commonly limited to 512 positions, so
# 10000 is likely too large for them - confirm before training.
MAX_SEQUENCE_LENGTH = 10000

In [15]:
# Load the tokenizer and the TF encoder from the same pretrained checkpoint.
BERT_CHECKPOINT = 'bert-base-cased'
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = TFBertModel.from_pretrained(BERT_CHECKPOINT)
# TODO: encode the train / validation / test texts with the tokenizer, e.g.
#   bert_tokenizer(texts, truncation=True, padding=True, max_length=max_length, return_tensors='tf')
# The text lists and `max_length` must be defined before this step is enabled.

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

# Data exploration

In [16]:
# Load the raw training file as-is: one row per conversation, with list-valued columns.
train_df = pd.read_json("../data/train.jsonl", lines=True)
train_df

Unnamed: 0,messages,sender_labels,receiver_labels,speakers,receivers,absolute_message_index,relative_message_index,seasons,years,game_score,game_score_delta,players,game_id
0,[Germany!\n\nJust the person I want to speak w...,"[True, True, True, True, True, True, True, Tru...","[True, True, True, True, NOANNOTATION, NOANNOT...","[italy, germany, italy, germany, italy, italy,...","[germany, italy, germany, italy, germany, germ...","[74, 76, 86, 87, 89, 92, 97, 117, 119, 121, 12...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Spring, Spring, Sprin...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[italy, germany]",1
1,[Hello there! What's your general plan for thi...,"[True, False, True, False, True, True, True, T...","[True, True, True, True, True, NOANNOTATION, T...","[austria, italy, austria, italy, italy, austri...","[italy, austria, italy, austria, austria, ital...","[1, 67, 71, 73, 98, 99, 101, 179, 181, 185, 18...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Spring, Spring, Sprin...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 5, 4, 4, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, 1, -1, -...","[italy, austria]",1
2,[Buongiorno! \nBe kinda nice to know if you're...,"[True, True, False, True, True, True, True, Tr...","[True, False, True, False, True, True, NOANNOT...","[russia, italy, russia, italy, russia, italy, ...","[italy, russia, italy, russia, italy, russia, ...","[11, 50, 52, 57, 61, 66, 77, 85, 96, 102, 116,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Spring, Spring, Sprin...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[4, 3, 4, 3, 4, 3, 4, 3, 3, 3, 4, 3, 3, 4, 4, ...","[1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1...","[italy, russia]",1
3,[Hey italy! good luck this game. I'm guessing ...,"[True, False, True, True, True, True, True, Tr...","[NOANNOTATION, True, True, False, True, True, ...","[england, italy, england, england, england, it...","[italy, england, italy, italy, italy, england,...","[32, 95, 106, 107, 108, 110, 113, 125, 126, 12...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Spring, Spring, Sprin...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[italy, england]",1
4,[Hello Italy what’s up what are your thoughts ...,"[True, False, False, True, True, True, True, T...","[NOANNOTATION, True, True, True, True, True, N...","[turkey, italy, italy, italy, turkey, italy, t...","[italy, turkey, turkey, turkey, italy, turkey,...","[45, 94, 103, 150, 154, 178, 192, 194, 195, 19...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Spring, Fall, Fall, F...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 5, ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 1...","[italy, turkey]",1
...,...,...,...,...,...,...,...,...,...,...,...,...,...
184,[Greetings Sultan!\n\nAs your neighbor I would...,"[False, True, False, True, True, True, True, T...","[True, True, True, True, True, True, True, Tru...","[russia, turkey, russia, russia, russia, turke...","[turkey, russia, turkey, turkey, turkey, russi...","[78, 107, 145, 370, 371, 374, 415, 420, 495, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]","[Spring, Spring, Spring, Spring, Spring, Sprin...","[1901, 1901, 1901, 1902, 1902, 1902, 1902, 190...","[4, 3, 4, 5, 5, 4, 5, 4, 5, 3, 7]","[1, -1, 1, 1, 1, -1, 1, -1, 2, -2, 7]","[russia, turkey]",10
185,[Greetings My Good Frenchman! \n\nHow are your...,"[True, True, True, True, True, True, True, Fal...","[True, True, False, True, True, True, True, Tr...","[russia, france, russia, russia, france, franc...","[france, russia, france, france, russia, russi...","[75, 115, 147, 176, 205, 206, 254, 285, 306, 5...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Fall, Fall, Fall, Fal...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[4, 3, 4, 4, 3, 3, 4, 4, 5, 5, 3, 5, 5, 3, 7, 1]","[1, -1, 1, 1, -1, -1, 1, -1, 1, 2, -2, 2, 2, -...","[russia, france]",10
186,"[Hey, Hello]","[True, True]","[True, True]","[england, turkey]","[turkey, england]","[7, 61]","[0, 1]","[Spring, Spring]","[1901, 1901]","[3, 3]","[0, 0]","[england, turkey]",10
187,"[Hello France, world you like to discuss worki...","[False, True, True, True, True, True, True, Tr...","[True, True, True, True, True, True, False, Tr...","[england, france, england, england, england, f...","[france, england, france, france, france, engl...","[0, 5, 6, 14, 17, 19, 45, 49, 50, 51, 52, 380,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...","[Spring, Spring, Spring, Spring, Spring, Sprin...","[1901, 1901, 1901, 1901, 1901, 1901, 1901, 190...","[3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 5, 4, 3]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 1, -...","[england, france]",10


In [17]:
# Flatten every split with parse_jsonl; train_df is rebound to the flattened frame.
train_df = parse_jsonl("../data/train.jsonl")
test_df = parse_jsonl("../data/test.jsonl")
validation_df = parse_jsonl("../data/validation.jsonl")

In [20]:
# Inspect the flattened training frame returned by parse_jsonl.
train_df

Unnamed: 0,messages,sender_labels,receiver_labels,speakers,receivers,absolute_message_index,relative_message_index,seasons,years,game_score,game_score_delta,game_id,conversation_id
0,Germany!\n\nJust the person I want to speak wi...,True,True,italy,germany,74,0,Spring,1901,3,0,1,0
1,"You've whet my appetite, Italy. What's the sug...",True,True,germany,italy,76,1,Spring,1901,3,0,1,0
2,👍,True,True,italy,germany,86,2,Spring,1901,3,0,1,0
3,It seems like there are a lot of ways that cou...,True,True,germany,italy,87,3,Spring,1901,3,0,1,0
4,"Yeah, I can’t say I’ve tried it and it works, ...",True,NOANNOTATION,italy,germany,89,4,Spring,1901,3,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
13127,Is there any way of me actually ending this co...,True,True,france,england,380,11,Fall,1902,4,-1,10,187
13128,Can we agree on peace? What are your demands?,True,True,france,england,433,12,Fall,1902,4,-1,10,187
13129,"Neutrality in exchange for current holdings, S...",True,False,england,france,434,13,Fall,1902,5,1,10,187
13130,"Thats a bit too much, can I keep Spain and i h...",True,True,france,england,437,14,Fall,1902,4,-1,10,187


We can see that some `receiver_labels` entries have the value `NOANNOTATION`, which is not a boolean value.

In [16]:
# Percentage of rows whose receiver label is the "NOANNOTATION" placeholder.
# Scale to percent *before* rounding: multiplying an already-rounded fraction
# by 100 can reintroduce floating-point noise (e.g. 0.57 * 100 == 56.99999999999999).
round(100 * len(train_df[train_df["receiver_labels"] == "NOANNOTATION"].index) / len(train_df.index), 2)

8.43

About 8.43% of the rows have `NOANNOTATION`

This is a small fraction of rows that we could remove, but the process may not use the `receiver_labels` column at all.

The columns that may be used are `messages` and `sender_labels`.

# Test 1

In [None]:
def create_bert_multiclass_model(checkpoint = 'bert-base-cased',
                                 num_classes = 20,
                                 hidden_size = 201,
                                 dropout=0.3,
                                 learning_rate=0.00005,
                                 max_sequence_length=None):
    """
    Build and compile a multi-class classifier on top of a pretrained BERT encoder.

    The encoder's second output (the pooler output) feeds a ReLU hidden
    layer, a dropout layer and a softmax layer with ``num_classes`` units.

    Args:
        checkpoint: Name or path passed to ``TFBertModel.from_pretrained``.
        num_classes: Number of units in the softmax output layer.
        hidden_size: Number of units in the hidden dense layer.
        dropout: Dropout rate applied after the hidden layer.
        learning_rate: Learning rate of the Adam optimizer.
        max_sequence_length: Width of the three integer inputs. ``None``
            (the default) falls back to the notebook-level
            ``MAX_SEQUENCE_LENGTH`` at call time, which is the previous
            behaviour.

    Returns:
        A compiled ``tf.keras.Model`` taking
        ``[input_ids, token_type_ids, attention_mask]``. It is compiled with
        sparse categorical cross-entropy, so labels are expected as integer
        class ids.
    """
    if max_sequence_length is None:
        # Resolved at call time so later changes to the constant are honoured.
        # NOTE(review): BERT checkpoints are commonly limited to 512 positions;
        # confirm the chosen width is supported by `checkpoint`.
        max_sequence_length = MAX_SEQUENCE_LENGTH

    input_shape = (max_sequence_length,)
    input_ids = tf.keras.layers.Input(shape=input_shape, dtype=tf.int64, name='input_ids_layer')
    token_type_ids = tf.keras.layers.Input(shape=input_shape, dtype=tf.int64, name='token_type_ids_layer')
    attention_mask = tf.keras.layers.Input(shape=input_shape, dtype=tf.int64, name='attention_mask_layer')

    # Layer names are runtime identifiers and are kept exactly as they were
    # (including the 'ouput_layer' spelling).
    classification_output = pipe(
        {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask
        },                                                                              # input
        TFBertModel.from_pretrained(checkpoint, trainable=True),                        # bert model
        lambda x: x[1],                                                                 # pooler output
        tf.keras.layers.Dense(hidden_size, activation='relu', name='hidden_layer'),     # dense layer
        tf.keras.layers.Dropout(dropout, name="dropout_layer"),                         # dropout layer
        tf.keras.layers.Dense(num_classes, activation='softmax', name='ouput_layer')    # output layer
    )

    classification_model = tf.keras.Model(inputs=[input_ids, token_type_ids, attention_mask], outputs=classification_output)

    # Compile the model with sparse categorical cross-entropy loss and an Adam optimizer
    classification_model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=['accuracy']
    )

    return classification_model

In [None]:
def create_bert_cls_model(bert_base_model,
                          max_sequence_length=MAX_SEQUENCE_LENGTH,
                          hidden_size = 100,
                          dropout=0.3,
                          learning_rate=0.00005):
    """
    Build and compile a binary classifier on top of the given BERT encoder.

    The first output of the encoder is sliced at position 0 of every sequence
    (the CLS token output) and fed through a ReLU hidden layer, dropout and a
    single sigmoid unit.

    Args:
        bert_base_model: Encoder called with a dict of ``input_ids``,
            ``token_type_ids`` and ``attention_mask``. It is marked trainable
            in place.
        max_sequence_length: Width of the three integer inputs.
        hidden_size: Number of units in the hidden dense layer.
        dropout: Dropout rate applied after the hidden layer.
        learning_rate: Learning rate of the Adam optimizer.

    Returns:
        A ``tf.keras.Model`` compiled with binary cross-entropy and accuracy.
    """
    bert_base_model.trainable = True

    input_ids = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int64, name='input_ids_layer')
    token_type_ids = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int64, name='token_type_ids_layer')
    attention_mask = tf.keras.layers.Input(shape=(max_sequence_length,), dtype=tf.int64, name='attention_mask_layer')

    bert_inputs = {'input_ids': input_ids,
                   'token_type_ids': token_type_ids,
                   'attention_mask': attention_mask}

    # Use the encoder that was passed in. The previous code called the global
    # `bert_model` here, silently ignoring the `bert_base_model` argument.
    cls_token_output = bert_base_model(bert_inputs)[0][:, 0, :]

    classification_output = pipe(
        cls_token_output,
        tf.keras.layers.Dense(hidden_size, activation='relu', name='hidden_layer'),
        tf.keras.layers.Dropout(dropout),
        tf.keras.layers.Dense(1, activation='sigmoid', name='classification_layer')
    )

    classification_model = tf.keras.Model(inputs=[input_ids, token_type_ids, attention_mask], outputs=[classification_output])

    # Compile the model
    classification_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                    loss='binary_crossentropy',
                    metrics=['accuracy'])

    return classification_model

# Test 2

In [None]:
def create_lstm_model():
    model = keras.Sequential([
        keras.layers.Embedding()
    ])