# NLP2022 - Homework 2


This notebook contains code for fast data processing and experiment execution for the second homework of the course Natural Language Processing 2022. It has been written entirely by Dennis Rotondi 1834864, using the methodologies learned throughout the course.

In [None]:
# imports and deterministic stuff
import os, sys
sys.path.append(os.path.join("..")) #to access hw2 functions
sys.path.append(os.path.join("../..")) #to access model folder
os.environ['WANDB_NOTEBOOK_NAME'] = './nlp_hw2.ipynb' # to avoid a wandb warning
os.environ['TOKENIZERS_PARALLELISM'] = "false" # to avoid a tokenizer deadlock at training time

import torch
import numpy as np
import random
import pytorch_lightning as pl
from collections import Counter
import matplotlib.pyplot as plt
from utils import read_dataset
import wandb
from pytorch_lightning.loggers.wandb import WandbLogger

np.random.seed(0)
random.seed(0)
torch.cuda.manual_seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True  # Note that this Deterministic mode can have a performance impact
torch.backends.cudnn.benchmark = False
_ = pl.seed_everything(0)

# for a smoother notebook workflow, see https://stackoverflow.com/questions/5364050/reloading-submodules-in-ipython
# these commands reload the imported .py modules whenever they change, instead of re-importing everything every time.
%load_ext autoreload
%autoreload 2

## Dataset Analysis

As for the bonus exercise and hw1, I want to start by looking at the data, to better understand how to proceed with the pre-processing. I've read that some (sentence, ground truth) pairs are problematic; since we are not allowed to change them, I'll directly discard them during the training phase if needed. I'll do my analysis mostly on the English dataset, since it is mandatory and larger (the experiments can be reproduced in the other languages just by changing the `language` parameter below from "EN" to "ES" or "FR").
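The exact nature of the problematic pairs is not described here, so the following is only a minimal sketch of one plausible check, assuming a broken pair shows up as token-level fields of different lengths. The helper `is_consistent` and the toy samples are hypothetical; the field names (`words`, `lemmas`, `pos_tags`, `predicates`, `roles`) are the ones used by the cells below once sentences and labels are merged.

```python
from typing import Any, Dict


def is_consistent(sample: Dict[str, Any]) -> bool:
    """Return True if every token-level field of a merged sample has the same length.

    Hypothetical check: it assumes that a 'broken' (sentence, ground truth) pair
    manifests as a length mismatch between the token-level fields.
    """
    n = len(sample["words"])
    token_fields = ["lemmas", "pos_tags", "predicates"]
    if any(len(sample[f]) != n for f in token_fields if f in sample):
        return False
    # each predicate index maps to one role label per token of the sentence
    return all(len(r) == n for r in sample.get("roles", {}).values())


# toy samples, invented for illustration (not taken from the dataset)
good = {"words": ["He", "runs"], "lemmas": ["he", "run"], "pos_tags": ["PRON", "VERB"],
        "predicates": ["_", "MOVE-ONESELF"], "roles": {1: ["agent", "_"]}}
bad = {"words": ["He", "runs"], "lemmas": ["he"], "pos_tags": ["PRON", "VERB"],
       "predicates": ["_", "MOVE-ONESELF"], "roles": {1: ["agent", "_"]}}

dataset = {"s1": good, "s2": bad}
clean = {k: v for k, v in dataset.items() if is_consistent(v)}
print(sorted(clean))  # ['s1']
```

Filtering with a dict comprehension keeps the original sentence ids, so the surviving samples can still be matched with their labels.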

In [None]:
language = "EN" # "ES" or "FR" if you want
data_file = f"../../data/{language}/train.json"

sentences, labels = read_dataset(data_file)
print(f"Number of training sentences ({language}): "+ str(len(sentences.keys())))
# let's inspect the fields of one sample (the first sentence_id) to understand our data.
sentence_id = list(sentences.keys())[0]
print("## SENTENCE {} ##".format(sentence_id))
for key in sentences[sentence_id]:
    print(key)
    print(sentences[sentence_id][key])
print("## LABEL ##")
for key in labels[sentence_id]:
    print(key)
    print(labels[sentence_id][key])

# let's check and count the different frames and roles
verbatlas_frames = Counter()
predicate_roles = Counter()
pos_tags = Counter()

for k in labels:
    verbatlas_frames.update(labels[k]['predicates'])
    pos_tags.update(sentences[k]['pos_tags'])
    for idx in labels[k]['roles']:
        predicate_roles.update(labels[k]['roles'][idx])

In [None]:
print("## VF ##")
print(verbatlas_frames)
# list of frames in the training dataset
l_vf = list(verbatlas_frames.keys())
print(l_vf)
print(len(l_vf))
print("## RL ##")
print(predicate_roles)
p_r = list(predicate_roles.keys())
print(p_r)
print(len(p_r))
print("## PT ##")
ppl = list(pos_tags.keys())
print(ppl)
print(len(ppl))

In [None]:
del(predicate_roles["_"])
# to place it in a different position and not have overlapping labels
predicate_roles["experiencer"] = predicate_roles.pop("experiencer")
plt.figure(figsize=(26,10))
_ = plt.bar(predicate_roles.keys(), predicate_roles.values()) 
plt.title("Bar Plot of role frequency (without '_' label, 95% of them)") 
plt.show()

We are clearly not using all the 466 [VerbAtlas](https://verbatlas.org/) frames, but less than 3/4 of them: 303. Working with fewer frames should make the task easier, because the system only has to discriminate among a subset of them. In the next code cell I check whether the dev set contains frames that do not appear in the training set.

In [None]:
dev_sentences, dev_labels = read_dataset(f"../../data/{language}/dev.json")
print(f"Number of dev sentences ({language}): "+ str(len(dev_sentences.keys())))
for k in dev_labels:
    verbatlas_frames.update(dev_labels[k]['predicates'])
    for idx in dev_labels[k]['roles']:
        predicate_roles.update(dev_labels[k]['roles'][idx])

l_vf_dev = list(verbatlas_frames.keys())
print(len(l_vf_dev))

So the dev set contains only 4 frames that do not appear in the training set; this information will be useful later, when I deal with the optional part of this homework.
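The cell above updates `verbatlas_frames` in place, so it only tells how many frames the union contains. To list *which* frames are new, a set difference between two separate counters is enough. A minimal sketch with toy counters (the helper `unseen_labels` is hypothetical; in the notebook the two counters would be built separately from `labels[k]['predicates']` and `dev_labels[k]['predicates']`):

```python
from collections import Counter
from typing import Set


def unseen_labels(train_counts: Counter, dev_counts: Counter) -> Set[str]:
    """Labels that occur in the dev split but never in the training split."""
    return set(dev_counts) - set(train_counts)


# toy counters, invented for illustration
train_counts_toy = Counter(["_", "_", "SPEAK", "GIVE_GIFT"])
dev_counts_toy = Counter(["_", "SPEAK", "COOL"])
print(unseen_labels(train_counts_toy, dev_counts_toy))  # {'COOL'}
```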

Now that I'm starting to understand the samples, it's clear that our dataset does not need much pre-processing, since we already have word tokens and the associated lemmas for each sentence. Other useful statistics are how long the sentences are on average, how many predicates they contain, and how the distribution of POS tags correlates with roles and predicates. I'll quickly compute them in what follows.
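Raw counts per POS tag are dominated by the most frequent tags. A normalized view, i.e. for every POS tag the fraction of its tokens that are predicates, can make the correlation easier to read. This is a minimal sketch with my own hypothetical helper `predicate_rate`, assuming the merged sample format (`pos_tags`, `predicates`, `'_'` meaning "no predicate") used in the cells below:

```python
from collections import Counter
from typing import Dict, List


def predicate_rate(samples: Dict[str, Dict[str, List[str]]]) -> Dict[str, float]:
    """For each POS tag, the fraction of its tokens annotated as predicates ('_' = no predicate)."""
    tag_total: Counter = Counter()
    tag_pred: Counter = Counter()
    for sample in samples.values():
        for pos, frame in zip(sample["pos_tags"], sample["predicates"]):
            tag_total[pos] += 1
            if frame != "_":
                tag_pred[pos] += 1
    # Counter returns 0 for tags that never carry a predicate
    return {pos: tag_pred[pos] / tag_total[pos] for pos in tag_total}


# toy sentences, invented for illustration
toy_samples = {
    "a": {"pos_tags": ["PRON", "VERB", "NOUN"], "predicates": ["_", "SPEAK", "_"]},
    "b": {"pos_tags": ["NOUN", "AUX", "VERB"], "predicates": ["_", "_", "_"]},
}
print(predicate_rate(toy_samples))  # {'PRON': 0.0, 'VERB': 0.5, 'NOUN': 0.0, 'AUX': 0.0}
```

The same loop works for arguments by iterating over `sample["roles"].values()` instead of `sample["predicates"]`.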

In [None]:
from mergedeep import merge
sentences_length = list()
predicates_counter = list()
sentences = merge(sentences, labels)
pos_pre_corr = Counter()
pos_role_corr = Counter()

In [None]:
# initialize every POS tag with a zero count, so that all the tags appear in the plots
for pt in pos_tags:
    pos_pre_corr.update({pt:0})
    pos_role_corr.update({pt:0})

for s in sentences:
    s_l = len(sentences[s]["lemmas"])
    roles = sentences[s]["roles"].keys()
    p_c = len(roles)
    sentences_length.append(s_l)
    predicates_counter.append(p_c)
    for pos, predicate in zip(sentences[s]["pos_tags"], sentences[s]["predicates"]):
        if predicate != "_":
            pos_pre_corr.update({pos:1})
    for r in roles:
        for pos, role in zip(sentences[s]["pos_tags"],sentences[s]["roles"][r]):
            if role != "_":
                pos_role_corr.update({pos:1})
    
sl_np=np.asarray(sentences_length)
pc_np=np.asarray(predicates_counter)

print("Sentences Length")
print("mean", sl_np.mean())
print("std", sl_np.std())
print("min", sl_np.min())
print("max", sl_np.max())

print("Predicates Counter")
print("mean", pc_np.mean())
print("std", pc_np.std())
print("min", pc_np.min())
print("max", pc_np.max())

plt.figure(figsize=(8,8)) 
_ = plt.hist(sl_np, bins = 'auto') 
plt.title("Histogram of sentence lengths") 
plt.show()

plt.figure(figsize=(8,8)) 
_ = plt.hist(pc_np, bins = 5) 
plt.title("Histogram of predicate counts for each sentence") 
plt.show()

It's interesting to notice that (for our EN dataset) some sentences have 0 predicates and some have 10, even though the average is slightly above 2.


In [None]:
plt.figure(figsize=(26,10))
_ = plt.bar(pos_pre_corr.keys(), pos_pre_corr.values()) 
plt.title("Number of predicates for each POS tag") 
plt.show()

print(pos_pre_corr)

A picture is worth a thousand words: if we are able to identify the POS tag, it is also very easy to tell whether a token is a verb, and therefore a likely predicate (the holy grail for task 1, predicate identification).
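Following this observation, the simplest possible predicate identifier is the rule "a token is a predicate iff its POS tag is VERB". A minimal sketch to measure how far this rule alone gets on task 1; the helper `verb_baseline_prf` and the toy data are hypothetical, and this is a baseline, not the model used in this homework:

```python
from typing import Dict, List, Tuple


def verb_baseline_prf(samples: Dict[str, Dict[str, List[str]]]) -> Tuple[float, float, float]:
    """Precision, recall and F1 of the rule 'predicate iff POS tag == VERB'."""
    tp = fp = fn = 0
    for sample in samples.values():
        for pos, frame in zip(sample["pos_tags"], sample["predicates"]):
            predicted = pos == "VERB"
            gold = frame != "_"
            tp += int(predicted and gold)
            fp += int(predicted and not gold)
            fn += int(gold and not predicted)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# toy sentences, invented: one verb predicate (tp), one AUX predicate (fn), one non-predicate verb (fp)
toy_samples = {
    "a": {"pos_tags": ["PRON", "VERB", "NOUN"], "predicates": ["_", "SPEAK", "_"]},
    "b": {"pos_tags": ["NOUN", "AUX", "VERB"], "predicates": ["_", "POSSESS", "_"]},
}
print(verb_baseline_prf(toy_samples))  # (0.5, 0.5, 0.5)
```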

In [None]:
plt.figure(figsize=(26,10))
_ = plt.bar(pos_role_corr.keys(), pos_role_corr.values()) 
plt.title("Number of argument (non-'_' role) tokens for each POS tag") 
plt.show()

print(pos_role_corr)

Also in this case it's clear that some POS tags give more information about which tokens are arguments, so I'll feed them to my model if time allows.

Now that we have understood the importance of POS-tag information, and taking into account that the input sentences on which my work will be evaluated do not come with POS tags, I have to compute them with an external library. It is therefore important to understand how accurate this library is.

In [None]:
import spacy
from spacy.tokens import Doc
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
taggers = {"EN":"en_core_web_sm", "ES":"es_core_news_sm", "FR":"fr_core_news_sm"}
nlp = spacy.load(taggers[language])

def compute_metrics_postag(field: str):
    p_labels = list()
    predictions = list()
    ppl = list()  # gold tags of the VERB tokens only
    pv = list()   # predicted tags for the same tokens
    for s in tqdm(sentences):
        # build the Doc from the dataset tokens instead of re-tokenizing the joined string:
        # spaCy's tokenizer may split tokens differently, and zip() would then misalign
        # predictions and gold tags
        doc = Doc(nlp.vocab, words=sentences[s][field])
        for _, proc in nlp.pipeline:
            doc = proc(doc)
        for token, pos in zip(doc, sentences[s]["pos_tags"]):
            predictions.append(token.pos_)
            p_labels.append(pos)
            if pos == "VERB":
                pv.append(token.pos_)
                ppl.append(pos)
    # token-level metrics (seqeval's f1 is meant for BIO chunk tags, not for POS tags)
    acc = accuracy_score(p_labels, predictions)
    f = f1_score(p_labels, predictions, average="macro")
    print("Accuracy, macro-F1 on all the tokens")
    print(acc, f)
    # but, as seen, to solve task 1 we are more interested in identifying verb tokens!
    # (on gold VERB tokens, accuracy equals the recall of the VERB class)
    accv = accuracy_score(ppl, pv)
    print("Accuracy on VERB tokens")
    print(accv)

In [None]:
# now we repeat the same experiment computing the pos_tag from the lemmas and NOT the words
compute_metrics_postag("lemmas")

We can conclude that, since the accuracy on verbs is much higher when tagging the words than when tagging the lemmas, it is better to compute the POS tags from the words. To be precise, we are not using the most accurate spaCy model for POS tagging; still, it is the most efficient one (30 MB instead of 400 MB), and for this homework I'm not aiming for a "top score" but for completing different pipelines in a reasonable time, since working with transformers requires a lot of resources in terms of memory and time.

## Training

Now it's time to train our model. PyTorch Lightning makes it easy to modularize everything and to train all the different models with a few lines of code. Moreover, using wandb as logger, the training evolution is automatically plotted in high-quality plots, and the training history of the different trials is saved. This will be very useful for comparing the experiments in the report.
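The `EarlyStopping(monitor="f1", mode="max", patience=5, min_delta=0.0)` callback used in the training cell stops the run once the monitored metric has not improved for `patience` consecutive validation checks. To make that rule concrete, here is a simplified pure-Python re-implementation (my own sketch, not Lightning's source, which also handles stopping thresholds, non-finite values and distributed training):

```python
from typing import List, Optional


def early_stop_epoch(history: List[float], patience: int = 5, min_delta: float = 0.0) -> Optional[int]:
    """Return the 0-based epoch at which training stops when maximizing a metric, or None.

    An epoch counts as an improvement only if the metric beats the best value so far
    by more than min_delta; after `patience` epochs without improvement we stop.
    """
    best = float("-inf")
    wait = 0
    for epoch, value in enumerate(history):
        if value - best > min_delta:
            best, wait = value, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


f1_history = [0.70, 0.75, 0.78, 0.78, 0.77, 0.779, 0.776, 0.775]
print(early_stop_epoch(f1_history, patience=5))  # 7
```

Note that an equal value (epoch 3) does not count as an improvement with `min_delta=0.0`, so the patience counter already starts there.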

In [None]:
from datasets_srl import SRL_DataModule
from implementation import HParams, SRL_34, SRL_234, SRL_1234
from dataclasses import dataclass, asdict
from pprint import pprint
from utils import read_dataset, evaluate_argument_classification, evaluate_argument_identification
from mergedeep import merge
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint
# these parameters make the training modular, as I said: we need to store the hyperparameters of the model (lr, wd, ...), the language
# and the task on which we want to perform the training.
hparams = HParams()
languages = ["EN", "ES", "FR"]
tasks = ["34", "234", "1234"]
models = {"34": SRL_34, "234": SRL_234, "1234": SRL_1234}

language = languages[0]
task = tasks[2]
epochs = 100
SRL_Model = models[task]
hparams.language = language
hparams.task = task
hparams = asdict(hparams)
pprint(hparams)
# after reading the dataset I merge the two dicts (sentences and labels), since they have a field in common (predicates)
# and keeping two copies of it in memory is only a waste of space.
sentences = merge(*read_dataset("../../data/"+language+"/train.json"))
sentences_test = merge(*read_dataset("../../data/"+language+"/dev.json"))

data = SRL_DataModule(hparams, task, language, sentences, sentences_test)


In [None]:
model = SRL_Model(hparams=hparams, sentences_for_evaluation=sentences_test)
# Define the logger
# https://www.wandb.com/articles/pytorch-lightning-with-weights-biases.
# NOTE: to use wandb properly you need to login in wandb (need an account) 
# or use a different logger eg. TensorBoard, I'm used to this one so I'll go for it.
# to login: https://docs.wandb.ai/ref/cli/wandb-login
wandb.require("service")
wandb_logger = WandbLogger(project="SRL_"+task, log_model = True) # note not language in the project name, so we can compare different languages
wandb_logger.experiment.watch(model, log = None, log_freq = 100000) # log=None: no gradient/parameter logging
# Define the trainer
metric_to_monitor =  "f1" # f1 of argument classification, also possible eg 'avg_val_loss'
mode = "max" # do you want to maximize or minimize the metric?
# we employ early stopping to avoid hours of useless training, pl gives it for free
early_stop_callback = EarlyStopping(monitor = metric_to_monitor, min_delta = 0.00, patience = 5, verbose = True, mode = mode)
# it is also useful to keep track of the best model during the epochs (if you remember, I did all this manually in the last hw):
# we have a callback even for this.
checkpoint_callback = ModelCheckpoint(
                        save_top_k = 1,
                        monitor = metric_to_monitor,
                        mode = mode,
                        dirpath = "../../model",
                        filename = "SRL_"+task+"_"+language+"-{epoch:02d}-{f1:.4f}",
                        verbose = True
                    )
# the trainer collects all the information needed so far for the training
trainer = pl.Trainer(logger = wandb_logger,
                    max_epochs = epochs, 
                    gpus = 1,
                    callbacks = [early_stop_callback, checkpoint_callback])    

save_ckpt_file = "../../model/SRL_{}_{}_last.ckpt".format(task, language)

Start the training from scratch (randomly initialized weights); if you want to start from a saved checkpoint instead, skip this cell and run the next one.

In [None]:
trainer.fit(model, data)
trainer.save_checkpoint(save_ckpt_file)

To continue the training, it's enough to increase the number of epochs and create a new trainer that resumes from the checkpoint; it's also possible to fine-tune the model on another language this way.

In [None]:
resume_ckpt = save_ckpt_file # or use another language / pre-trained model.
epochs += 10 # increase the maximum number of epochs
trainer = pl.Trainer(logger = wandb_logger,
                    max_epochs = epochs, 
                    gpus = 1,
                    callbacks = [early_stop_callback, checkpoint_callback],
                    resume_from_checkpoint = resume_ckpt)    
trainer.fit(model, data)
trainer.save_checkpoint(save_ckpt_file)

In [None]:
# if you have trained for enough epochs you can now finish logging with wandb to have your plot.
wandb.finish()

In [None]:
# to upload the run online after finishing (you will get a command like the one below; just execute it in a terminal or here)
!wandb sync /home/dennis/Desktop/nlp2022-hw2/hw2/stud/wandb/offline-run-20220712_105118-2uum41pe

### Confusion Matrix Analysis for argument classification

Now that we have our trained model, it's interesting to understand where it performs well and where it does not. Recall that, from the English dataset analysis, these were the counts of the different roles:

{'_': 437392, 'agent': 7581, 'theme': 6593, 'patient': 2907, 'goal': 1463, 'topic': 1403, 'recipient': 837, 'beneficiary': 590, 'result': 577, 'stimulus': 367, 'experiencer': 319, 'attribute': 294, 'destination': 276, 'co-theme': 253, 'source': 229, 'location': 198, 'co-agent': 145, 'product': 95, 'instrument': 70, 'co-patient': 60, 'extent': 54, 'cause': 51, 'value': 45, 'time': 35, 'asset': 28, 'purpose': 25, 'material': 11}

By inspection, it should be possible to see a correlation between the number of samples of a role and the results on that role.
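Beyond visual inspection, the correlation can be summarized with a single number, e.g. Spearman's rank correlation between per-role support and per-role recall. A minimal sketch (the helper `spearman` is mine; unlike `scipy.stats.spearmanr` it does not average tied ranks), run here on made-up numbers that are NOT results of the model:

```python
from typing import Sequence

import numpy as np


def spearman(x: Sequence[float], y: Sequence[float]) -> float:
    """Spearman rank correlation, computed as the Pearson correlation of the ranks.

    Simplified: ties get distinct (arbitrary) ranks instead of averaged ones.
    """
    rank_x = np.argsort(np.argsort(x))
    rank_y = np.argsort(np.argsort(y))
    return float(np.corrcoef(rank_x, rank_y)[0, 1])


# made-up numbers for illustration: support per role and per-role recall
support = [100, 50, 10, 5]
recall = [0.9, 0.8, 0.85, 0.3]
print(spearman(support, recall))  # close to 0.8
```

In the cell below, since the confusion matrix is computed with `normalize="true"` (each row divided by the number of gold tokens of that role), `np.diag(cm)` would give the per-role recall in the order of `all_labels`, and the support could be taken from the role counts listed above.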

In [None]:
# import and utils for this evaluation
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from typing import Dict, List, Any
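from utils import evaluate_argument_classification, evaluate_argument_identification

def flat_dict_roles(sentences: Dict[str, Dict[str, Any]]) -> List[str]:
    # flatten the role sequences of all the predicates of all the sentences into one list of token labels.
    # NOTE: gold and predicted dicts must share the same sentence and predicate order,
    # otherwise the two flattened lists are not aligned token by token.
    list_tokens = list()
    for s in sentences:
        for l in sentences[s]["roles"]:
            list_tokens.extend(sentences[s]["roles"][l])
    return list_tokens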

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

In [None]:
best_ckpt = "../../model/SRL_34_EN-epoch=55-f1=0.8581.ckpt"
# we load the weights in a non-strict fashion, since I made some adjustments to the models along the way
model = SRL_34.load_from_checkpoint(best_ckpt, strict=False)

In [None]:
predict = model.to(device).predict(sentences_test, require_ids=True)

In [None]:
print("AI")
print(evaluate_argument_identification(sentences_test, predict))
print("AC")
print(evaluate_argument_classification(sentences_test, predict))

flat_labels_s = flat_dict_roles(sentences_test)
flat_predictions_s = flat_dict_roles(predict)

all_labels = ['_', 'agent', 'theme', 'beneficiary', 'patient', 'topic', 'goal', 'recipient', 'co-theme', 'result',
    'stimulus', 'experiencer', 'destination', 'value', 'attribute', 'location', 'source', 'cause', 'co-agent',
    'time', 'co-patient', 'product', 'purpose', 'instrument', 'extent', 'asset', 'material']

cm = confusion_matrix(flat_labels_s, flat_predictions_s, labels=all_labels, normalize="true")
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=all_labels)
fig, ax = plt.subplots(figsize=(36,36))
disp.plot(ax=ax) #7162

As expected, the fewer the samples of a role, the worse the performance on it.

### Task 234 - AMuSE-WSD extra

Since for task 1234 and task 34 I've basically implemented a very robust architecture that for some time held state-of-the-art results, repeating something similar for 234 would most likely lead to the best results. Instead of doing a boring, almost copy-paste version of the other models (where this time we could replace the POS tag with the pre-computed WSD predictions or simply with the predicate identification), I've decided to first try a different approach, using the suggested architecture AMuSE-WSD, which currently holds the state of the art for WSD in multiple languages.

What I decided to do is indeed a fun but sub-optimal procedure: I start by predicting the frames with AMuSE-WSD, which uses more frames than the ones in the English dataset (the one I'll use as reference); then I refine the predictions, also to remove the OOV ones, by mapping each predicted frame to the gold frame it most commonly corresponds to in the training dataset. In a certain sense, we are fine-tuning a model by applying a "filter layer" that improves its predictions!

To do so I'll use the information from all the different datasets, since this model is intended to be used on different languages.
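The "filter layer" described above is essentially a majority-vote mapping learned on the training set. A minimal sketch of the idea on toy data, using `Counter.most_common` (ties are resolved in favour of the gold frame encountered first); the helpers `fit_filter` and `apply_filter` and the `OOV_FRAME` label are hypothetical, and falling back to `'_'` for predicted frames never seen while fitting is my assumption (the notebook's `mod.filter` may handle them differently):

```python
from collections import Counter, defaultdict
from typing import Dict, List


def fit_filter(pred_frames: List[List[str]], gold_frames: List[List[str]]) -> Dict[str, str]:
    """Map every predicted frame to the gold frame it co-occurs with most often."""
    cooc: Dict[str, Counter] = defaultdict(Counter)
    for pred_seq, gold_seq in zip(pred_frames, gold_frames):
        for p, g in zip(pred_seq, gold_seq):
            cooc[p][g] += 1
    return {p: counts.most_common(1)[0][0] for p, counts in cooc.items()}


def apply_filter(pred_frames: List[List[str]], mapping: Dict[str, str], default: str = "_") -> List[List[str]]:
    """Relabel the predictions; frames never seen while fitting fall back to `default`."""
    return [[mapping.get(p, default) for p in seq] for seq in pred_frames]


# toy predictions/gold, invented for illustration
toy_pred = [["_", "SPEAK", "OOV_FRAME"], ["OOV_FRAME", "_"]]
toy_gold = [["_", "SPEAK", "SPEAK"], ["SPEAK", "_"]]
toy_mapping = fit_filter(toy_pred, toy_gold)
print(toy_mapping)  # {'_': '_', 'SPEAK': 'SPEAK', 'OOV_FRAME': 'SPEAK'}
print(apply_filter([["OOV_FRAME", "NEVER_SEEN"]], toy_mapping))  # [['SPEAK', '_']]
```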

In [None]:
from amuse import AMuSE_WSD_online
from utils import evaluate_predicate_disambiguation

# extracted above from english dataset analysis
frames_in_dataset = ['_', 'ASK_REQUEST', 'BENEFIT_EXPLOIT', 'PLAN_SCHEDULE', 'CARRY-OUT-ACTION', 'ESTABLISH', 'SIMPLIFY', 'PROPOSE', 'TAKE-INTO-ACCOUNT_CONSIDER', 'BEGIN', 'CIRCULATE_SPREAD_DISTRIBUTE', 'REFER', 'SHOW', 'PRECLUDE_FORBID_EXPEL', 'VIOLATE', 'VERIFY', 'CAUSE-SMT', 'ABSTAIN_AVOID_REFRAIN', 'TRANSMIT', 'SEE', 'SUMMON', 'GUARANTEE_ENSURE_PROMISE', 'RECEIVE', 'INCREASE_ENLARGE_MULTIPLY', 'DECREE_DECLARE', 'PAY', 'CAUSE-MENTAL-STATE', 'CAGE_IMPRISON', 'HURT_HARM_ACHE', 'MOVE-BACK', 'EXIST_LIVE', 'CALCULATE_ESTIMATE', 'ATTRACT_SUCK', 'EXIST-WITH-FEATURE', 'INFORM', 'EXPLAIN', 'SPEAK', 'SEEM', 'MISS_OMIT_LACK', 'DECIDE_DETERMINE', 'ASSIGN-SMT-TO-SMN', 'FOLLOW_SUPPORT_SPONSOR_FUND', 'MOVE-ONESELF', 'WORSEN', 'AMELIORATE', 'AGREE_ACCEPT', 'MOVE-SOMETHING', 'PUT_APPLY_PLACE_PAVE', 'ADJUST_CORRECT', 'INCLUDE-AS', 'CONTINUE', 'SPEED-UP', 'LOAD_PROVIDE_CHARGE_FURNISH', 'REMEMBER', 'FINISH_CONCLUDE_END', 'REPEAT', 'HELP_HEAL_CARE_CURE', 'IMPLY', 'OPPOSE_REBEL_DISSENT', 'STRENGTHEN_MAKE-RESISTANT', 'AROUSE_WAKE_ENLIVEN', 'RECORD', 'INCITE_INDUCE', 'GIVE_GIFT', 'DESTROY', 'REQUIRE_NEED_WANT_HOPE', 'ANALYZE', 'COME-AFTER_FOLLOW-IN-TIME', 'BELIEVE', 'GO-FORWARD', 'CANCEL_ELIMINATE', 'RECOGNIZE_ADMIT_IDENTIFY', 'CHOOSE', 'REPRESENT', 'TREAT', 'OBLIGE_FORCE', 'STOP', 'REACT', 'HAPPEN_OCCUR', 'OVERCOME_SURPASS', 'AFFECT', 'CREATE_MATERIALIZE', 'ALLY_ASSOCIATE_MARRY', 'MANAGE', 'OPEN', 'ORIENT', 'ANSWER', 'INFLUENCE', 'COMBINE_MIX_UNITE', 'LEAD_GOVERN', 'STAY_DWELL', 'WELCOME', 'AMASS', 'PREPARE', 'ORGANIZE', 'HAVE-A-FUNCTION_SERVE', 'GIVE-UP_ABOLISH_ABANDON', 'SORT_CLASSIFY_ARRANGE', 'GIVE-BIRTH', 'PUBLISH', 'USE', 'POSSESS', 'BEHAVE', 'WORK', 'SUBJECTIVE-JUDGING', 'APPROVE_PRAISE', 'ATTEND', 'LEAVE_DEPART_RUN-AWAY', 'CATCH', 'OBEY', 'SATISFY_FULFILL', 'UNDERSTAND', 'ACHIEVE', 'TRY', 'ATTACH', 'INTERPRET', 'DELAY', 'REDUCE_DIMINISH', 'UNDERGO-EXPERIENCE', 'RETAIN_KEEP_SAVE-MONEY', 'ARRIVE', 'REFUSE', 'IMAGINE', 'HARMONIZE', 'PARTICIPATE', 'HIRE', 'RESULT_CONSEQUENCE', 
'FOCUS', 'CONTAIN', 'MOUNT_ASSEMBLE_PRODUCE', 'PROVE', 'WRITE', 'RESTRAIN', 'TOLERATE', 'ACCOMPANY', 'DISCUSS', 'RESTORE-TO-PREVIOUS/INITIAL-STATE_UNDO_UNWIND', 'TEACH', 'CHANGE-APPEARANCE/STATE', 'INVERT_REVERSE', 'RELY', 'SIGNAL_INDICATE', 'LEARN', 'ACCUSE', 'PERFORM', 'AFFIRM', 'REMOVE_TAKE-AWAY_KIDNAP', 'WATCH_LOOK-OUT', 'GROUND_BASE_FOUND', 'LEAVE-BEHIND', 'FACE_CHALLENGE', 'CHANGE_SWITCH', 'SHARE', 'APPLY', 'ARGUE-IN-DEFENSE', 'DIRECT_AIM_MANEUVER', 'WAIT', 'HEAR_LISTEN', 'CONSIDER', 'LIKE', 'FIGHT', 'PROTECT', 'AUTHORIZE_ADMIT', 'DIVERSIFY', 'PRESERVE', 'LOCATE-IN-TIME_DATE', 'SEND', 'ORDER', 'SEARCH', 'REGRET_SORRY', 'EMPHASIZE', 'CELEBRATE_PARTY', 'TAKE-SHELTER', 'HOST_MEAL_INVITE', 'REPLACE', 'THINK', 'MEET', 'PERCEIVE', 'BREAK_DETERIORATE', 'JOIN_CONNECT', 'BORDER', 'FIND', 'KNOW', 'KILL', 'CHARGE', 'FAIL_LOSE', 'CRITICIZE', 'CITE', 'HIT', 'LIBERATE_ALLOW_AFFORD', 'BRING', 'DERIVE', 'JUSTIFY_EXCUSE', 'PERSUADE', 'REVEAL', 'DRIVE-BACK', 'TAKE', 'OBTAIN', 'LOSE', 'ADD', 'MATCH', 'CONSUME_SPEND', 'COMPARE', 'BEFRIEND', 'NAME', 'BE-LOCATED_BASE', 'OFFER', 'OVERLAP', 'CARRY_TRANSPORT', 'REACH', 'FILL', 'ENCLOSE_WRAP', 'DISBAND_BREAK-UP', 'COUNT', 'DEFEAT', 'CO-OPT', 'ENDANGER', 'PUNISH', 'TRANSLATE', 'SECURE_FASTEN_TIE', 'INSERT', 'REMAIN', 'BUY', 'STEAL_DEPRIVE', 'SETTLE_CONCILIATE', 'EXTEND', 'SUMMARIZE', 'PUBLICIZE', 'CORRELATE', 'SEPARATE_FILTER_DETACH', 'GROUP', 'COST', 'ATTACK_BOMB', 'WARN', 'NEGOTIATE', 'ENTER', 'LIE', 'SPEND-TIME_PASS-TIME', 'EMPTY_UNLOAD', 'INVERT_REVERSE-', 'EMIT', 'TURN_CHANGE-DIRECTION', 'SELL', 'GUESS', 'DISCARD', 'CONTRACT-AN-ILLNESS_INFECT', 'WASH_CLEAN', 'DROP', 'OPERATE', 'SHARPEN', 'REFLECT', 'COMPENSATE', 'ASCRIBE', 'LOWER', 'COPY', 'DEBASE_ADULTERATE', 'DISMISS_FIRE-SMN', 'COVER_SPREAD_SURMOUNT', 'MEASURE_EVALUATE', 'RESIGN_RETIRE', 'READ', 'DISTINGUISH_DIFFER', 'TRAVEL', 'RESIST', 'SHOOT_LAUNCH_PROPEL', 'BURDEN_BEAR', 'SOLVE', 'WIN', 'APPEAR', 'FOLLOW-IN-SPACE', 'PULL', 'PAINT', 'COME-FROM', 'VISIT', 'COOL', 
'DOWNPLAY_HUMILIATE', 'CHASE', 'EMBELLISH', 'EARN', 'RAISE', 'PROMOTE', 'MEAN', 'EXHAUST', 'ABSORB', 'PRESS_PUSH_FOLD', 'LEND', 'SHAPE', 'PRINT', 'REPAIR_REMEDY', 'GROW_PLOW', 'QUARREL_POLEMICIZE', 'TAKE-A-SERVICE_RENT', 'COMPETE', 'DIVIDE', 'COMMUNICATE_CONTACT', 'FIT', 'EXEMPT', 'SLOW-DOWN', 'FLOW', 'RISK', 'METEOROLOGICAL', 'NOURISH_FEED', 'STABILIZE_SUPPORT-PHYSICALLY']
mod = AMuSE_WSD_online(language, filter_layer=False)

In [None]:

filter = dict()  # maps each AMuSE-WSD frame to its most frequent gold frame (note: shadows the builtin filter)
# uncomment if you have slow internet like me:
# it is the result of mod.predict(sentences, require_ids=True) using words.
# The online pipeline works better with words than with lemmas, because it is designed as a full end-to-end
# state-of-the-art multilingual pretrained model.
# pred = torch.load("../../model/amuse/predictions_words") 

# comment if you have slow internet
pred = mod.predict(sentences, require_ids=True)

print(evaluate_predicate_disambiguation(pred, sentences))
# count, for every predicted frame, how many times it corresponds to each gold frame
possibilities = dict()
for p in pred:
    for pred_frame, real in zip(pred[p]["predicates"], sentences[p]["predicates"]):
        possibilities[pred_frame] = possibilities.get(pred_frame, {})
        possibilities[pred_frame][real] = possibilities[pred_frame].get(real, 0) + 1
# the filter maps each predicted frame to its most frequent gold frame (ties: first one encountered)
for k in possibilities:
    filter[k] = max(possibilities[k], key=possibilities[k].get)
# apply the filter to the cached predictions, to avoid recomputing everything with mod.predict
new_pred = dict()
for k in pred:
    new_pred[k] = {"predicates" : [filter[s] for s in pred[k]["predicates"]]}
print(evaluate_predicate_disambiguation(new_pred, sentences))
"""
RESULTS:
{'true_positives': 8838, 'false_positives': 3724, 'false_negatives': 2878, 
'precision': 0.7035503900652762, 'recall': 0.7543530215090475, 'f1': 0.7280665623197956}
{'true_positives': 9229, 'false_positives': 3333, 'false_negatives': 2487, 
'precision': 0.7346760070052539, 'recall': 0.7877261864117446, 'f1': 0.7602767938050911}
"""

We conclude this section by repeating the experiment on the dev set, with and without the filter; then we save all these predictions and the filter, so that we can later work with a "fine-tuned" AMuSE-WSD.
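The precision/recall/F1 values in the training-set RESULTS string above follow directly from the raw counts. A quick sketch recomputing them (the helper `prf` is mine, not the one in `utils`) also shows that tp + fp (12562) and tp + fn (11716) are identical with and without the filter: true positives grow by 391 while false positives and false negatives each drop by 391, consistent with the filter only relabelling the predicted frames.

```python
from typing import Tuple


def prf(tp: int, fp: int, fn: int) -> Tuple[float, float, float]:
    """Precision, recall and F1 from raw true positive, false positive and false negative counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# counts copied from the training-set RESULTS string (without / with the filter)
without_filter = prf(8838, 3724, 2878)
with_filter = prf(9229, 3333, 2487)
print([round(v, 4) for v in without_filter])  # [0.7036, 0.7544, 0.7281]
print([round(v, 4) for v in with_filter])     # [0.7347, 0.7877, 0.7603]
```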

In [None]:
pred_test = mod.predict(sentences_test, require_ids=True)
print(evaluate_predicate_disambiguation(pred_test, sentences_test))
mod.filter = filter
new_pred_test = mod.predict(sentences_test, require_ids=True)
print(evaluate_predicate_disambiguation(new_pred_test, sentences_test))
"""
RESULTS:
{'true_positives': 1782, 'false_positives': 771, 'false_negatives': 546, 
'precision': 0.6980023501762632, 'recall': 0.7654639175257731, 'f1': 0.7301782421634909}
{'true_positives': 1841, 'false_positives': 712, 'false_negatives': 487, 
'precision': 0.7211124167645907, 'recall': 0.790807560137457, 'f1': 0.7543536160622824}
"""

In [None]:
# note the suffix "2" with respect to the files I uploaded: I didn't want to overwrite them
torch.save(pred, "../../model/amuse/prediction_words2")
torch.save(pred_test, "../../model/amuse/prediction_words_dev2")
torch.save(new_pred, "../../model/amuse/prediction_words_new2")
torch.save(new_pred_test, "../../model/amuse/prediction_words_dev_new2")
torch.save(filter, "../../model/amuse/filter_layer2")

## TOREMOVE

In [None]:
filt = torch.load("../../model/SRL_234_layer.ckpt")
print(len(frames_in_dataset))
print(filt)
print(len(filt.keys()))
mod.filter = filt

In [None]:
input = {"words": sentences["1996/a/50/18_supp__323:5"]["words"], "lemmas": sentences["1996/a/50/18_supp__323:5"]["lemmas"]}

In [None]:
model.predict(input)

In [None]:
from datasets_srl import SRL_DataModule
from implementation import HParams, SRL_1234, SRL_34
from dataclasses import dataclass, asdict
from pprint import pprint
from utils import read_dataset, evaluate_argument_classification, evaluate_argument_identification
from mergedeep import merge
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

# NOTE: this loop never stops by itself: it keeps launching new training runs until the kernel is interrupted.
while True:
    # these parameters make the training modular, as I said: we need to store the hyperparameters of the model (lr, wd, ...), the language
    # and the task on which we want to perform the training.
    hparams = HParams()
    languages = ["EN", "ES", "FR"]
    tasks = ["34", "234", "1234"]
    models = {"34": SRL_34, "1234": SRL_1234}

    language = languages[0]
    task = tasks[2]
    epochs = 100
    SRL_Model = models[task]
    hparams.language = language
    hparams.task = task
    hparams = asdict(hparams)
    pprint(hparams)
    # after reading the dataset I merge the two dicts (sentences and labels) since there is a field in common (predicate)
    # and it's only a waste of space keeping it in memory 2 copies of it.
    sentences = merge(*read_dataset("../../data/"+language+"/train.json"))
    sentences_test = merge(*read_dataset("../../data/"+language+"/dev.json"))

    data = SRL_DataModule(hparams, task, language, sentences, sentences_test)
    model = SRL_Model(hparams=hparams, sentences_for_evaluation=sentences_test)
    # Define the logger
    # https://www.wandb.com/articles/pytorch-lightning-with-weights-biases.
    # NOTE: to use wandb properly you need to login in wandb (need an account) 
    # or use a different logger eg. TensorBoard, I'm used to this one so I'll go for it.
    # to login: https://docs.wandb.ai/ref/cli/wandb-login
    wandb.require("service")
    wandb_logger = WandbLogger(project="SRL_"+task, log_model = True) # note not language in the project name, so we can compare different languages
    wandb_logger.experiment.watch(model, log = None, log_freq = 100000) # log=None: no gradient/parameter logging
    # Define the trainer
    metric_to_monitor =  "f1" # f1 of argument classification, also possible eg 'avg_val_loss'
    mode = "max" # do you want to maximize or minimize the metric?
    # we employ early stopping to avoid hours of useless training, pl gives it for free
    early_stop_callback = EarlyStopping(monitor = metric_to_monitor, min_delta = 0.00, patience = 15, verbose = True, mode = mode)
    # it is also useful to keep track of the best model during the epochs (if you remember, I did all this manually in the last hw):
    # we have a callback even for this.
    checkpoint_callback = ModelCheckpoint(
                            save_top_k = 1,
                            monitor = metric_to_monitor,
                            mode = mode,
                            dirpath = "../../model",
                            filename = "SRL_"+task+"_"+language+"-{epoch:02d}-{f1:.4f}",
                            verbose = True
                        )
    # the trainer collects all the information needed so far for the training
    trainer = pl.Trainer(logger = wandb_logger,
                        max_epochs = epochs, 
                        gpus = 1,
                        callbacks = [early_stop_callback, checkpoint_callback])    

    save_ckpt_file = "../../model/SRL_{}_{}_last.ckpt".format(task, language)
    trainer.fit(model, data)
    wandb.finish()


In [None]:
a = torch.load("../../model/srl_34_EN.ckpt")

In [None]:
cm = confusion_matrix(flat_labels_s, flat_predictions_s, labels=all_labels, normalize="true")
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=all_labels)
fig, ax = plt.subplots(figsize=(36,36)) # (400,400) inches would render a ~40000x40000 px image
disp.plot(ax=ax) #7162