In [None]:
!pip install simpletransformers

## Process data, assign labels
### Run this before running any training scripts

In [None]:
# pandas must be imported here: this cell runs before the processing cell that
# also imports it, so without this line it fails on a fresh kernel.
import pandas as pd

labeled_df = pd.read_csv('/path/to/dataset/data_final.csv')
gap_df = pd.read_csv('/path/to/gap-dataset/gap-test.tsv', sep='\t')

In [None]:
import pandas as pd
import ast

def get_answer_index(ans_list):
    """Return the position of the gold answer within its candidate list.

    Args:
        ans_list: A sequence whose first two values are
            ``(candidates_repr, answer)``. This is either a row Series from
            ``df[['Possible answers', 'Answer']].apply(..., axis=1)`` or a
            plain list/tuple. ``candidates_repr`` is the string form of a
            Python list, e.g. ``"['Herbert', 'Beverly']"``.

    Returns:
        The index of the first occurrence of ``answer`` in the parsed
        candidate list. Returns -1, after printing a notice, if ``answer``
        is not one of the candidates.
    """
    # Read values positionally via list() instead of ans_list[0]/[1]: on a
    # label-indexed row Series, integer keys only work through a positional
    # fallback that pandas 2.x deprecates.
    values = list(ans_list)
    candidates_repr, answer = values[0], values[1]
    # literal_eval safely parses the stringified list stored in the CSV.
    candidates = ast.literal_eval(candidates_repr)
    if answer in candidates:
        return candidates.index(answer)
    print("Answer not found in sentence; will be dropped from dataset")
    return -1

start_size = len(labeled_df)

# Keep only examples with at least two candidate antecedents.
labeled_df = labeled_df.assign(**{
    'Num candidates': labeled_df['Possible answers'].apply(lambda x: len(ast.literal_eval(x)))
})
labeled_df = labeled_df[labeled_df['Num candidates'] > 1]

# Position of the gold answer in the candidate list (-1 = not a candidate;
# get_answer_index prints a notice for each such row).
answer_index = labeled_df[['Possible answers', 'Answer']].apply(get_answer_index, axis=1)
keep = answer_index != -1

# Drop unmatched rows, then collapse positions >= 2 into class 2, giving
# labels 0 (A), 1 (B), 2 (Neither).
# - .assign on the filtered frame avoids writing into a view, so there is
#   no SettingWithCopyWarning.
# - Writing 'Label' directly, instead of creating 'Answer index' and renaming
#   it in place, overwrites 'Label' on a re-run rather than duplicating it.
labeled_df = labeled_df[keep].assign(Label=answer_index[keep].clip(upper=2))
assert labeled_df['Label'].isin([0, 1, 2]).all(), "unexpected label value"

print("Using {} samples out of {}...".format(len(labeled_df), start_size))
#labeled_df.to_csv('/content/drive/My Drive/NLP/data_labeled.csv')
labeled_df

Answer not found in sentence; will be dropped from dataset
Using 6220 samples out of 6330...


Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,Text,Pronoun,Pronoun-offset,Answer,Answer-offset,Possible answers,Num candidates,Label
0,0,0,"After the war, Herbert attended the University...",he,67,Herbert,67,"['Herbert', 'Beverly']",2,0
1,1,1,Herbert was appalled to learn of senator McCar...,he,138,Herbert,138,"['Herbert', 'McCarthy']",2,0
2,2,2,Herbert has attracted a sometimes fanatical fa...,he,96,Herbert,96,"['Herbert', 'Herbert']",2,0
3,3,3,"In 1887, after the death of his brother and a ...",he,119,Julian,119,"['Norris', 'Julian', 'Zola']",3,1
4,4,4,"The publication of Buechner's third novel, The...",he,125,Buechner,125,"['Buechner', 'Buechner']",2,0
...,...,...,...,...,...,...,...,...,...,...
6325,75,75,Then slice the cabbage into small pieces and w...,it,51,cabbage,16,"['cabbage', 'pieces']",2,0
6326,76,76,Add a little bit of salt and give it a rough mix.,it,8,salt,21,"['bit', 'salt', 'mix']",3,1
6327,77,77,Now add required amount of mayonnaise to the b...,it,68,mayonnaise,28,"['amount', 'mayonnaise', 'bowl']",3,1
6328,78,78,After the dry is all mixed up add the eggs and...,it,64,dry,11,"['dry', 'eggs', 'oil']",3,0


## BERT implementation
###   Multiclass classification:
* Use all samples that have at least two answer candidates
* Label is the index of the correct answer in the list of candidates
* Label = 0 (A), 1 (B), or 2 (Neither — also used when the correct answer is the third or later candidate)





#### Train and test on our dataset

In [None]:
from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging

# Seed for the shuffle below, so the train/eval/test split is reproducible.
SEED = 42

alldata_df = labeled_df

# Shuffle (seeded; an unseeded shuffle yields a different split every run)
alldata_df = alldata_df.sample(frac=1, random_state=SEED)

# Create 60%, 20%, 20% split of train, eval, test:
# 1. Hold out the last 20% as test.
# 2. Split the remaining 80% 75/25 into train/eval.
# Both train_df and eval_df are sliced from the same untruncated trainval_df,
# so they are disjoint. Slicing eval from the already-truncated train_df
# would make eval a subset of the training data.
test_start = int(len(alldata_df) * 0.8)
trainval_df = alldata_df[:test_start]
test_df = alldata_df[test_start:]
eval_start = int(len(trainval_df) * 0.75)
train_df = trainval_df[:eval_start]
eval_df = trainval_df[eval_start:]

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# simpletransformers expects a DataFrame with a str 'text' column and an int
# 'labels' column (plural). With any other header it warns and falls back to
# treating column 0 as text and column 1 as labels.
# NOTE(review): only the raw sentence is fed to the model. The pronoun and
# candidate spans are not marked in the input, so the model must infer which
# mention is candidate 0/1 from the text alone.
train_dict = {'text': train_df['Text'], 'labels': train_df['Label']}
train_df_final = pd.DataFrame(train_dict)

eval_dict = {'text': eval_df['Text'], 'labels': eval_df['Label']}
eval_df_final = pd.DataFrame(eval_dict)

# Create a ClassificationModel
model = ClassificationModel('bert', 'bert-base-cased', num_labels=3, args={'reprocess_input_data': True, 'overwrite_output_dir': True, 'num_train_epochs': 20})
print(train_df_final)
# Train the model
model.train_model(train_df_final, output_dir='/content/drive/My Drive/NLP/models/bert-base-cased')

# Evaluate the model on the held-out (disjoint) eval split
result, model_outputs, wrong_predictions = model.eval_model(eval_df_final)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

                                                   text  label
1812  While at Harvard, McPherson studied fiction wr...      0
4786  In a 2008 interview with Charlie, Matthiessen ...      1
2324  After the divorce, the actor was determined to...      0
2766  Authorization for such a promotion "with the a...      0
4358  As the leading publisher of children's books, ...      0
...                                                 ...    ...
4496  Gilbert, 4th Earl of Pembroke , married Marjor...      0
414   "Blue Jay Way" was one of several songs that H...      0
5404  Writing in The New York Times, Anatole claimed...      0
3841  The book includes an anonymous anecdote about ...      1
3151  In 1998, lawyers for Paula released court docu...      1

[3732 rows x 2 columns]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3732.0), HTML(value='')))




HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(HTML(value='Running Epoch 0 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))






HBox(children=(HTML(value='Running Epoch 1 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 2 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 3 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 4 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 5 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 6 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 7 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 8 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 9 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 10 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 11 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 12 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 13 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 14 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 15 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 16 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 17 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 18 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 19 of 20'), FloatProgress(value=0.0, max=467.0), HTML(value='')))





INFO:simpletransformers.classification.classification_model: Training of bert model complete. Saved to /content/drive/My Drive/NLP/models/bert-base-cased.
  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=933.0), HTML(value='')))




HBox(children=(HTML(value='Running Evaluation'), FloatProgress(value=0.0, max=117.0), HTML(value='')))

INFO:simpletransformers.classification.classification_model:{'mcc': 1.0, 'eval_loss': 1.3362288501590758e-05}





In [None]:
from sklearn.metrics import confusion_matrix, classification_report

# Score the held-out test split with the model trained above and report
# the confusion matrix plus per-class precision/recall/F1.
test_texts = test_df['Text'].tolist()
test_gold = test_df['Label'].tolist()

predictions, raw_outputs = model.predict(test_texts)

print("\n{}".format(confusion_matrix(test_gold, predictions)))
print("\n{}".format(classification_report(test_gold, predictions, digits=4)))

INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1245.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=156.0), HTML(value='')))



[[761  81   9]
 [ 69 253   6]
 [ 10  27  29]]

              precision    recall  f1-score   support

           0     0.9060    0.8942    0.9001       851
           1     0.7008    0.7713    0.7344       328
           2     0.6591    0.4394    0.5273        66

    accuracy                         0.8378      1245
   macro avg     0.7553    0.7017    0.7206      1245
weighted avg     0.8388    0.8378    0.8367      1245



#### Train on our dataset, test on GAP Coreference

In [None]:
from simpletransformers.classification import ClassificationModel
import pandas as pd
import logging

# Seed for the shuffle below, so the train/eval split is reproducible.
SEED = 42

alldata_df = labeled_df

# Shuffle (seeded; an unseeded shuffle yields a different split every run)
alldata_df = alldata_df.sample(frac=1, random_state=SEED)

# Create 80% train, 20% eval (GAP is used as the test set afterwards)
eval_start = int(len(alldata_df) * 0.8)
train_df = alldata_df[:eval_start]
eval_df = alldata_df[eval_start:]

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# simpletransformers expects a DataFrame with a str 'text' column and an int
# 'labels' column (plural). With any other header it warns and falls back to
# treating column 0 as text and column 1 as labels.
train_dict = {'text': train_df['Text'], 'labels': train_df['Label']}
train_df_final = pd.DataFrame(train_dict)

eval_dict = {'text': eval_df['Text'], 'labels': eval_df['Label']}
eval_df_final = pd.DataFrame(eval_dict)

# Create a ClassificationModel
model = ClassificationModel('bert', 'bert-base-cased', num_labels=3, args={'reprocess_input_data': True, 'overwrite_output_dir': True, 'num_train_epochs': 10})
print(train_df_final)
# Train the model
model.train_model(train_df_final, output_dir='/content/drive/My Drive/NLP/models/bert-base-cased-gap')

# Evaluate the model on the held-out eval split
result, model_outputs, wrong_predictions = model.eval_model(eval_df_final)


Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

                                                   text  label
5396  The president of the Juilliard School consulte...      0
3904  "Piano": Lucy annoys Schroeder while he is pla...      1
2766  Authorization for such a promotion "with the a...      0
12    Before the March on Rome, De even went so far ...      2
5978  Ralph said he and Sting were not paid for thei...      0
...                                                 ...    ...
4385  Evelyn was educated at home until the age of 1...      0
1685  Rowling has lived a "rags to riches" life in w...      0
4453  When Baudelaire returned from Belgium after hi...      1
2536  In 1996, Peter announced that he and Sharp wou...      0
2291  Rahman is the main judge and he is accompanied...      0

[4976 rows x 2 columns]


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=4976.0), HTML(value='')))




HBox(children=(HTML(value='Epoch'), FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(HTML(value='Running Epoch 0 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))








HBox(children=(HTML(value='Running Epoch 1 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 2 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 3 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 4 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 5 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 6 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 7 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 8 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))




HBox(children=(HTML(value='Running Epoch 9 of 10'), FloatProgress(value=0.0, max=622.0), HTML(value='')))





INFO:simpletransformers.classification.classification_model: Training of bert model complete. Saved to /content/drive/My Drive/NLP/models/bert-base-cased-gap.
  "Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels."
INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1244.0), HTML(value='')))




HBox(children=(HTML(value='Running Evaluation'), FloatProgress(value=0.0, max=156.0), HTML(value='')))

INFO:simpletransformers.classification.classification_model:{'mcc': 0.6414865360038053, 'eval_loss': 1.5143181399368517}





In [None]:
def assign_label(coref_list):
    """Map a GAP row's (A-coref, B-coref) flags to the 3-way class label.

    Args:
        coref_list: A sequence whose first two values are the A-coref and
            B-coref flags. This is either a row Series from
            ``gap_df[['A-coref', 'B-coref']].apply(..., axis=1)`` or a plain
            list/tuple.

    Returns:
        0 if A corefers with the pronoun, else 1 if B does, else 2 (Neither).
    """
    # list() gives positional access without relying on the pandas
    # integer-key fallback on a label-indexed Series (deprecated in pandas 2.x).
    values = list(coref_list)
    a_coref, b_coref = values[0], values[1]
    # `== True` (rather than `is True`) so numpy.bool_ values also match.
    if a_coref == True:
        return 0
    elif b_coref == True:
        return 1
    else:
        return 2

# Encode each GAP example's A/B coreference flags with the same 0/1/2 scheme
# used for our dataset, store it as 'Label', and show the frame.
gap_labels = gap_df[['A-coref', 'B-coref']].apply(assign_label, axis=1)
gap_df['Label'] = gap_labels
gap_df

Unnamed: 0,ID,Text,Pronoun,Pronoun-offset,A,A-offset,A-coref,B,B-offset,B-coref,URL,Label
0,test-1,Upon their acceptance into the Kontinental Hoc...,His,383,Bob Suter,352,False,Dehner,366,True,http://en.wikipedia.org/wiki/Jeremy_Dehner,1
1,test-2,"Between the years 1979-1981, River won four lo...",him,430,Alonso,353,True,Alfredo Di St*fano,390,False,http://en.wikipedia.org/wiki/Norberto_Alonso,0
2,test-3,Though his emigration from the country has aff...,He,312,Ali Aladhadh,256,True,Saddam,295,False,http://en.wikipedia.org/wiki/Aladhadh,0
3,test-4,"At the trial, Pisciotta said: ``Those who have...",his,526,Alliata,377,False,Pisciotta,536,True,http://en.wikipedia.org/wiki/Gaspare_Pisciotta,1
4,test-5,It is about a pair of United States Navy shore...,his,406,Eddie,421,True,Rock Reilly,559,False,http://en.wikipedia.org/wiki/Chasers,0
...,...,...,...,...,...,...,...,...,...,...,...,...
1995,test-1996,"The sole exception was Wimbledon, where she pl...",She,479,Goolagong Cawley,400,True,Peggy Michel,432,False,http://en.wikipedia.org/wiki/Evonne_Goolagong_...,0
1996,test-1997,"According to news reports, both Moore and Fily...",her,338,Esther Sheryl Wood,263,True,Barbara Morgan,404,False,http://en.wikipedia.org/wiki/Hastings_Arthur_Wise,0
1997,test-1998,"In June 2009, due to the popularity of the Sab...",She,328,Kayla,364,True,Natasha Henstridge,412,False,http://en.wikipedia.org/wiki/Raya_Meddine,0
1998,test-1999,She was delivered to the Norwegian passenger s...,she,305,Irma,255,True,Bergen,274,False,http://en.wikipedia.org/wiki/SS_Irma_(1905),0


In [None]:
from sklearn.metrics import confusion_matrix, classification_report
# Predict and get metrics on GAP (cross-dataset evaluation).
# model.predict is documented to take a Python list of str, so convert
# explicitly (as the test-split evaluation cell does) instead of passing a Series.
# NOTE(review): GAP's Label comes from its own A/B columns, but the model only
# sees 'Text' and is never told which names are A and B. Its 0/1 outputs
# reflect our dataset's candidate ordering and may not correspond to GAP's
# A/B. Interpret these metrics with that in mind.
gap_texts = gap_df['Text'].tolist()
gap_gold = gap_df['Label'].tolist()
predictions, raw_outputs = model.predict(gap_texts)
print("\n{}".format(confusion_matrix(gap_gold, predictions)))
print("\n{}".format(classification_report(gap_gold, predictions, digits=4)))

INFO:simpletransformers.classification.classification_model: Converting to features started. Cache is not used.


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2000.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=250.0), HTML(value='')))



[[587 213 118]
 [459 252 144]
 [126  40  61]]

              precision    recall  f1-score   support

           0     0.5009    0.6394    0.5617       918
           1     0.4990    0.2947    0.3706       855
           2     0.1889    0.2687    0.2218       227

    accuracy                         0.4500      2000
   macro avg     0.3962    0.4010    0.3847      2000
weighted avg     0.4647    0.4500    0.4414      2000

