## Changes
- Event types are now detected during mention detection: instead of 17 entity types + 1 trigger type there are 17 + 49 = 66 entity_types. The prototype_embeddings variables in the model were adjusted accordingly; the list is loaded from trigger_entity_types.json.
- Events and triples are output.
- For events, the feasible roles are checked (see below), and only argument roles that make sense for the event type are output.
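
The feasible-role check from the last bullet can be sketched in isolation. This is a minimal illustration, not the notebook's implementation: the dictionary below is a hypothetical excerpt standing in for data/feasible_roles.json, and `filter_arguments` is a name introduced only for this example.

```python
# Hypothetical excerpt of feasible_roles.json: event type -> roles that make sense for it.
feasible_roles = {
    "Life.Injure.Unspecified": ["Victim", "Injurer", "Place"],
    "Conflict.Attack.DetonateExplode": ["Attacker", "Target", "Place"],
}


def filter_arguments(event_type, predicted):
    """Keep only (entity_id, role) pairs whose role is feasible for event_type."""
    allowed = set(feasible_roles.get(event_type, []))
    return [(ent, role) for ent, role in predicted if role in allowed]


# Made-up predictions for one trigger; "Attacker" is not feasible for an Injure event.
predicted = [("T1", "Victim"), ("T2", "Attacker"), ("T3", "Place")]
kept = filter_arguments("Life.Injure.Unspecified", predicted)
print(kept)  # [('T1', 'Victim'), ('T3', 'Place')]
```

Using `.get(event_type, [])` means an unknown event type yields no arguments instead of raising a KeyError; whether that is the desired behavior depends on how complete the JSON file is.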

## Problems
- batch_text is the back-conversion of tokens into text, so the spans match this text including the subword and [CLS] offsets. Problem: as a result, the start and end spans of my output events do not match those in the WikiEvents dataset. In addition, only the first subword of each span is output as text. See the end of the notebook.
- At inference time I have no entity_types, entity_spans, or relation_labels. How do I still call my model.forward() method? Would it work, at the marked place below after mention detection, to build the same structure that is otherwise provided by entity_types and entity_spans and then continue as before?
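
One possible way to address the first problem is to keep a subword-to-word index map while tokenizing, join all subwords of a span when producing the text, and map subword spans back to word-level spans before comparing with WikiEvents. The sketch below assumes word-level gold offsets and uses a toy splitter (`toy_wordpiece`) in place of the real tokenizer so that it runs on its own; all function names are hypothetical.

```python
def toy_wordpiece(word):
    """Stand-in for a real subword tokenizer: split a word into 4-character pieces."""
    pieces = [word[i:i + 4] for i in range(0, len(word), 4)]
    return [pieces[0]] + ["##" + p for p in pieces[1:]]


def tokenize_with_map(words):
    """Return subword tokens (with [CLS]/[SEP]) and, per subword, its word index."""
    tokens, sub2word = ["[CLS]"], [None]  # special tokens map to no word
    for w_idx, word in enumerate(words):
        for piece in toy_wordpiece(word):
            tokens.append(piece)
            sub2word.append(w_idx)
    tokens.append("[SEP]")
    sub2word.append(None)
    return tokens, sub2word


def span_text(tokens, start, end):
    """Join all subwords in [start, end) instead of returning only the first one."""
    text = ""
    for tok in tokens[start:end]:
        if tok.startswith("##"):
            text += tok[2:]
        else:
            text += (" " if text else "") + tok
    return text


def to_word_span(sub2word, start, end):
    """Map a subword span [start, end) back to a word-level span [w_start, w_end)."""
    return sub2word[start], sub2word[end - 1] + 1


words = ["bomb", "was", "detonated", "nearby"]
tokens, sub2word = tokenize_with_map(words)
print(tokens[3:6])                     # ['deto', '##nate', '##d']
print(span_text(tokens, 3, 6))         # detonated
print(to_word_span(sub2word, 3, 6))    # (2, 3)
```

With a Hugging Face tokenizer, `tokenizer.convert_tokens_to_string(tokens[start:end])` should give a similar joined text; the index map still has to be built during preprocessing, for example from the token_map that the data loader already returns.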

In [1]:
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoModel, BertConfig
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.data import parse_file, collate_fn
import tqdm
import json
from transformers.optimization import AdamW
import numpy as np
import torch.nn.functional as F
from src.losses import ATLoss
from src.util import process_long_input
from transformers import BertConfig, RobertaConfig, DistilBertConfig, XLMRobertaConfig
from itertools import groupby
from collections import defaultdict



In [3]:
#########################
## Model Configuration ##
#########################
language_model = 'bert-base-uncased'
lm_config = AutoConfig.from_pretrained(
    language_model,
    num_labels=10,
)
lm_model = AutoModel.from_pretrained(
    language_model,
    from_tf=False,
    config=lm_config,
)
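# NOTE: this loads the cased tokenizer while the model above is bert-base-uncased;
# the two vocabularies differ, so the tokenizer and model checkpoints should match.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")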


with open("data/roles.json") as f:
    relation_types = json.load(f)
with open("data/trigger_entity_types.json") as f:
    trigger_entity_types = json.load(f)

max_n = 9
dev_loader = DataLoader(
    parse_file("data/WikiEvents/train_docred_format_small.json",
    tokenizer=tokenizer,
    relation_types=relation_types,
    max_candidate_length=max_n),
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Parsing & generating candidates (n=9): 100%|███| 35/35 [00:00<00:00, 121.31it/s]


In [4]:
class Encoder(nn.Module):
    def __init__(self, config, model, cls_token_id, sep_token_id):
        super().__init__()
        
        self.config = config
        self.model = model

        self.entity_anchor = nn.Parameter(torch.zeros((66, 768)))
        torch.nn.init.uniform_(self.entity_anchor, a=-1.0, b=1.0)
        
        self.relation_embeddings = nn.Parameter(torch.zeros((57,3*768)))
        torch.nn.init.uniform_(self.relation_embeddings, a=-1.0, b=1.0)            
        self.nota_embeddings = nn.Parameter(torch.zeros((20,3*768)))
        torch.nn.init.uniform_(self.nota_embeddings, a=-1.0, b=1.0)


        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.at_loss = ATLoss()

        self.k_mentions = 50
                
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

        
        with open("data/roles.json") as f:
            self.relation_types = json.load(f)
        with open("data/trigger_entity_types.json") as f:
            self.types = json.load(f)
        with open("data/feasible_roles.json") as f:
            self.feasible_roles = json.load(f)
        
    def encode(self, input_ids, attention_mask):
        config = self.config
        if type(config) == BertConfig or type(config) == DistilBertConfig:
            start_tokens = [self.cls_token_id]
            end_tokens = [self.sep_token_id]
        elif type(config) == RobertaConfig or type(config) == XLMRobertaConfig:
            start_tokens = [self.cls_token_id]
            end_tokens = [self.sep_token_id, self.sep_token_id]
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention


   
        
    def forward(self, input_ids, attention_mask, candidate_spans, relation_labels, entity_spans, entity_types, entity_ids):
        sequence_output, attention = self.encode(input_ids, attention_mask)
        loss = torch.zeros((1)).to(sequence_output)
        mention_loss = torch.zeros((1)).to(sequence_output)
        counter = 0
        batch_triples = []
        batch_events = []
        batch_text = []
        
        for batch_i in range(sequence_output.size(0)):
            text_i = self.tokenizer.convert_ids_to_tokens(input_ids[batch_i])
            batch_text.append(text_i)


            # ARGUMENT ROLE LABELING


            # ---------- Pooling Entity Embeddings and Attentions ------------
            entity_embeddings = []
            entity_attentions = []
            for ent in entity_spans[batch_i]:
                ent_embedding = torch.mean(sequence_output[batch_i, ent[0][0]:ent[0][1],:],0)
                entity_embeddings.append(ent_embedding)
                ent_attention = torch.mean(attention[batch_i,:,ent[0][0]:ent[0][1],:],1)
                entity_attentions.append(ent_attention)
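            if len(entity_embeddings) == 0:
                # Keep the outputs aligned with the documents in the batch.
                batch_triples.append([])
                batch_events.append([])
                continue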
            entity_embeddings = torch.stack(entity_embeddings)
            entity_attentions = torch.stack(entity_attentions)

                
            # ---------- Localized Context Pooling ------------
            relation_candidates = []
            localized_context = []
            concat_embs = []
            triggers = []
            for s in range(entity_embeddings.shape[0]):
                if entity_types[batch_i][s].split(".")[-1] != "TRIGGER":
                    continue
                triggers.append(s)
                for o in range(entity_embeddings.shape[0]):
                    if s != o:

                        relation_candidates.append((s,o))

                        A_s = entity_attentions[s,:,:]
                        A_o = entity_attentions[o,:,:]
                        A = torch.mul(A_o,A_s)
                        q = torch.sum(A,0)
                        a = q / q.sum()
                        H_T = sequence_output[batch_i].T
                        c = torch.matmul(H_T,a)
                        localized_context.append(c)

                        concat_emb = torch.cat((entity_embeddings[s],entity_embeddings[o],c),0)
                        concat_embs.append(concat_emb)
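            if len(localized_context) == 0:
                # No trigger/argument pairs: emit empty results so outputs stay aligned.
                batch_triples.append([])
                batch_events.append([])
                continue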
            localized_context = torch.stack(localized_context)
            embs = torch.stack(concat_embs)
            
            triggers = list(set(triggers))
            # ---------- Pairwise Comparisons and Predictions ------------

            scores = torch.matmul(embs,self.relation_embeddings.T)
            nota_scores = torch.matmul(embs,self.nota_embeddings.T)
            nota_scores = nota_scores.max(dim=-1,keepdim=True)[0]
            scores = torch.cat((nota_scores, scores), dim=-1)
            predictions = torch.argmax(scores, dim=-1, keepdim=False)
            # Note: NOTA is placed at index 0
            
            if self.training:
            # ---------- ATLoss with one-hot encoding for true labels ------------
                targets = []
                for r in relation_candidates:
                    onehot = torch.zeros(len(self.relation_types))
                    if r in relation_labels[batch_i]:
                        onehot[relation_labels[batch_i][r]] = 1.0
                    targets.append(onehot)
                targets = torch.stack(targets).to(self.model.device)
                loss += self.at_loss(scores,targets)
                counter += 1
            
            # ---------- Inference ------------
            triples = []
            for idx,pair in enumerate(relation_candidates):
                triple = {
                    pair:self.relation_types[predictions[idx]]
                }
                triples.append(triple)
            batch_triples.append(triples)
                
            events = []
            for t,v in groupby(triples,key=lambda x:next(iter(x.keys()))[0]):
                t_word = entity_types[batch_i][t]
                t_start = entity_spans[batch_i][t][0][0]
                t_end = entity_spans[batch_i][t][0][1]
                event_type = t_word.split(".TRIGGER")[0]
                event_id = entity_ids[batch_i][t]

                arguments = []
                for d in v:
                    dic = next(iter(d.items()))
                    o = dic[0][1]
                    r = dic[1]

                    if r in self.feasible_roles[event_type]:
                        a_start = entity_spans[batch_i][o][0][0]
                        a_end = entity_spans[batch_i][o][0][1]
                        argument = {
                            'entity_id':entity_ids[batch_i][o],
                            'role':r,
                            'text':batch_text[batch_i][a_start:a_end][0],
                            'start':a_start,
                            'end':a_end,
                        }
                        arguments.append(argument)

                event = {
                    'id': entity_ids[batch_i][t],
                    'event_type':event_type,
                    'trigger': {'start':t_start ,'end':t_end, 'text':batch_text[batch_i][t_start:t_end][0]},
                    'arguments':arguments
                }
                events.append(event)
            batch_events.append(events)
        if counter == 0:
            # No document contributed a loss; return a zero loss that still supports backward().
            return loss.requires_grad_(True), batch_triples, batch_events
        else:
            return (mention_loss + loss) / counter, batch_triples, batch_events
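
Regarding the inference question: the forward pass above reads `entity_spans[batch_i][k][0][0]` and `[0][1]`, `entity_types[batch_i][k]`, and `entity_ids[batch_i][k]`, and it only touches `relation_labels` when `self.training` is true. So one option is to build these three lists from the mention-detection output in the same nesting. The following is only a sketch under that assumption; `build_inference_inputs`, the `(start, end, type_index)` prediction format, and the generated ids are hypothetical.

```python
def build_inference_inputs(predicted_mentions, type_names, doc_id="doc"):
    """Turn predicted (start, end, type_index) mentions of one document into the
    entity_spans / entity_types / entity_ids structure used during training."""
    entity_spans, entity_types, entity_ids = [], [], []
    n_events = n_entities = 0
    for start, end, type_idx in sorted(predicted_mentions):
        type_name = type_names[type_idx]
        # One mention per entity, same nesting as in training: [[(start, end)], ...]
        entity_spans.append([(start, end)])
        entity_types.append(type_name)
        if type_name.endswith(".TRIGGER"):
            n_events += 1
            entity_ids.append(f"{doc_id}-E{n_events}")
        else:
            n_entities += 1
            entity_ids.append(f"{doc_id}-T{n_entities}")
    return entity_spans, entity_types, entity_ids


# Made-up type list and predictions for illustration.
type_names = ["PER", "Life.Injure.Unspecified.TRIGGER"]
predicted = [(7, 8, 1), (3, 5, 0)]
spans, types, ids = build_inference_inputs(predicted, type_names, doc_id="d0")
print(spans)  # [[(3, 5)], [(7, 8)]]
print(types)  # ['PER', 'Life.Injure.Unspecified.TRIGGER']
print(ids)    # ['d0-T1', 'd0-E1']
```

With the model in `eval()` mode, an empty dict per document could then be passed for `relation_labels`, since the loss branch is skipped.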

In [5]:
mymodel = Encoder(lm_config,lm_model, cls_token_id=tokenizer.convert_tokens_to_ids(tokenizer.cls_token), 
                                            sep_token_id=tokenizer.convert_tokens_to_ids(tokenizer.sep_token))
optimizer = AdamW(mymodel.parameters(), lr=1e-5, eps=1e-6)




In [7]:
doc_ids_list = []
event_list = []
losses = []
with tqdm.tqdm(dev_loader) as progress_bar:
    for sample in progress_bar:

        token_ids, input_mask, entity_spans, entity_types, entity_ids, relation_labels, text, token_map, candidate_spans, doc_ids = sample
        
        
        mymodel.train()
        loss ,triples, events = mymodel(token_ids, input_mask, candidate_spans, relation_labels, entity_spans, entity_types, entity_ids)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        
        losses.append(loss.item())

        progress_bar.set_postfix({"L":f"{loss.item():.2f}"})
        print(loss.item())
        
        for e in events:
            event_list.append(e)
        for d in doc_ids:
            doc_ids_list.append(d)


  6%|██                                  | 1/18 [00:00<00:09,  1.80it/s, L=8.29]
8.288229942321777
 11%|████                                | 2/18 [00:01<00:09,  1.64it/s, L=8.96]
8.961225509643555
 17%|██████                              | 3/18 [00:01<00:09,  1.55it/s, L=4.44]
4.4380598068237305
 22%|████████                            | 4/18 [00:02<00:10,  1.38it/s, L=8.60]
8.596845626831055
 28%|██████████                          | 5/18 [00:03<00:10,  1.18it/s, L=3.42]
3.42008113861084
 33%|████████████                        | 6/18 [00:04<00:11,  1.05it/s, L=3.87]
3.8684604167938232
 39%|██████████████                      | 7/18 [00:06<00:12,  1.16s/it, L=2.69]
2.689852476119995
 44%|████████████████                    | 8/18 [00:07<00:11,  1.12s/it, L=4.76]
4.759007930755615
 50%|██████████████████                  | 9/18 [00:08<00:10,  1.14s/it, L=3.26]
3.2639245986938477
 56%|███████████████████▍               | 10/18 [00:09<00:09,  1.15s/it, L=2.02]
2.015265941619873
 61%|█████████████████████▍             | 11/18 [00:11<00:08,  1.17s/it, L=3.67]
3.67482328414917
 67%|███████████████████████▎           | 12/18 [00:12<00:07,  1.28s/it, L=0.82]
0.8188276290893555
 72%|█████████████████████████▎         | 13/18 [00:15<00:08,  1.74s/it, L=1.95]
1.9539991617202759
 78%|███████████████████████████▏       | 14/18 [00:17<00:07,  1.90s/it, L=1.99]
1.9919514656066895
 83%|█████████████████████████████▏     | 15/18 [00:19<00:05,  1.92s/it, L=3.59]
3.5947961807250977
 89%|███████████████████████████████    | 16/18 [00:21<00:03,  1.90s/it, L=2.70]
2.699817657470703
 94%|█████████████████████████████████  | 17/18 [00:23<00:01,  1.84s/it, L=3.37]
3.368267297744751
100%|███████████████████████████████████| 18/18 [00:24<00:00,  1.36s/it, L=5.26]
5.2552714347839355

In [16]:
event_list

[[{'id': 'K0C03N4OM-E1',
   'event_type': 'Justice.ArrestJailDetain.Unspecified',
   'trigger': {'start': 4, 'end': 5, 'text': 'custody'},
   'arguments': []},
  {'id': 'K0C03N4OM-E2',
   'event_type': 'Movement.Transportation.Evacuation',
   'trigger': {'start': 10, 'end': 11, 'text': 'evacuation'},
   'arguments': []}],
 [{'id': 'suicide_ied_103-E1',
   'event_type': 'Conflict.Attack.DetonateExplode',
   'trigger': {'start': 9, 'end': 12, 'text': 'de'},
   'arguments': []},
  {'id': 'suicide_ied_103-E2',
   'event_type': 'Life.Die.Unspecified',
   'trigger': {'start': 42, 'end': 43, 'text': 'killed'},
   'arguments': []},
  {'id': 'suicide_ied_103-E3',
   'event_type': 'Life.Injure.Unspecified',
   'trigger': {'start': 47, 'end': 48, 'text': 'injured'},
   'arguments': []}],
 [{'id': 'suicide_ied_111-E1',
   'event_type': 'Conflict.Attack.DetonateExplode',
   'trigger': {'start': 5, 'end': 6, 'text': 'occurred'},
   'arguments': []},
  {'id': 'suicide_ied_111-E2',
   'event_type': 'L

## Evaluation


In [12]:
from collections import defaultdict
import json
import pandas as pd
df_we = pd.read_json('data/WikiEvents/train.json').set_index('doc_id')
df = pd.read_json("data/WikiEvents/train_coref.json").set_index('doc_key')

coref_mapping = defaultdict(dict)
for doc_id, row in df.iterrows():
    for cluster in row.clusters:
        for item in cluster:
            coref_mapping[doc_id][item] = cluster
for doc_id, row in df_we.iterrows():
    for event in row.event_mentions:
        for arg in event['arguments']:
            if arg['entity_id'] not in coref_mapping[doc_id].keys():
                # Store a singleton list so the later membership test compares whole ids, not substrings.
                coref_mapping[doc_id][arg['entity_id']] = [arg['entity_id']]

with open("data/WikiEvents/coref_mapping.json", "w") as f:
    json.dump(coref_mapping,f)

In [152]:
idf_pred, idf_gold, idf_h_matched, idf_c_matched = 0,0,0,0
clf_pred, clf_gold, clf_h_matched, clf_c_matched = 0,0,0,0
for doc_id, events in zip(doc_ids_list,event_list):
    print("\033[1m"+f"doc_id:{doc_id}")
    print("\033[0m")
    gold_events = df_we.loc[doc_id].event_mentions

    
    for ge,e in zip(gold_events,events):
        for g_arg in ge['arguments']:
            for arg in e['arguments']:
                #----- Head Matches -----
                if g_arg['entity_id'] == arg['entity_id']:
                    idf_h_matched += 1
                    print(f"Headword Identified: doc_id: {doc_id}, entity: {arg['entity_id']}, gold_entity: {g_arg['entity_id']}")
                    if g_arg['role'] == arg['role']:
                        clf_h_matched += 1
                    
                #----- Coref Matches -----
                #if g_arg['entity_id'] in coref_mapping[doc_id].keys():  # only if a cluster exists at all
                if arg['entity_id'] in coref_mapping[doc_id][g_arg['entity_id']]:
                    idf_c_matched += 1
                    print(f"Corefence Identified: doc_id: {doc_id}, entity: {arg['entity_id']}, gold_entity: {g_arg['entity_id']}")
                    # if the roles match
                    if g_arg['role'] == arg['role']:
                        clf_c_matched += 1

            idf_gold += 1
        # count all predicted arguments
        for arg in e['arguments']:
            idf_pred += 1
    
            
                

doc_id:K0C03N4OM
doc_id:suicide_ied_103
doc_id:suicide_ied_111
doc_id:suicide_ied_110
doc_id:road_ied_1
Corefence Identified: doc_id: road_ied_1, entity: road_ied_1-T11, gold_entity: road_ied_1-T5
doc_id:suicide_ied_119
doc_id:scenario_en_kairos_0
doc_id:scenario_en_kairos_1
doc_id:suicide_ied_109
doc_id:wiki_ied_bombings_3_news_5
doc_id:road_ied_2
Headword Identified: doc_id: road_ied_2, entity: road_ied_2-T41, gold_entity: road_ied_2-T41
Corefence Identified: doc_id: road_ied_2, entity: road_ied_2-T41, gold_entity: road_ied_2-T41
doc_id:suicide_ied_104
doc_id:suicide_ied_118
doc_id:road_ied_11
doc_id:suicide_ied_102
doc_id:K0C041NHW
doc_id:suicide_ied_106
doc_id:scenario_en_kairos_11
doc_id:scenario_en_kairos_13
doc_id:road_ied_3
Corefence Identified: doc_id: road_ied_3, entity: road_ied_3-T4, gold_entity: 

In [147]:
print(f"Head idf|clf Matches: {idf_h_matched} | {clf_h_matched}")
print(f"Coref idf|clf  Matches: {idf_c_matched} | {clf_c_matched}")

Head idf|clf Matches: 2 | 0
Coref idf|clf  Matches: 6 | 2


In [148]:
def safe_div(num, denom):
    if denom > 0:
        return num / denom
    else:
        return 0

def compute_f1(predicted, gold, matched):
    precision = safe_div(matched, predicted)
    recall = safe_div(matched, gold)
    f1 = safe_div(2 * precision * recall, precision + recall)
    return precision, recall, f1
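
The counting loop above pairs gold and predicted events by position via `zip`, which can compare unrelated events when the two lists differ in order or length. A variant that pairs them by event id and counts each predicted argument at most once might look like the sketch below. It is not the notebook's evaluation; the function names, the `coref_clusters` format (entity id -> list of coreferent ids), and the sample data are assumptions made for illustration.

```python
def precision_recall_f1(predicted, gold, matched):
    """Same computation as compute_f1 above, written compactly."""
    p = matched / predicted if predicted > 0 else 0
    r = matched / gold if gold > 0 else 0
    f1 = 2 * p * r / (p + r) if (p + r) > 0 else 0
    return p, r, f1


def count_matches(gold_events, pred_events, coref_clusters):
    """Pair events by id and count head and coref matches for argument identification."""
    gold_by_id = {e["id"]: e for e in gold_events}
    n_pred = sum(len(e["arguments"]) for e in pred_events)
    n_gold = sum(len(e["arguments"]) for e in gold_events)
    head = coref = 0
    for pe in pred_events:
        ge = gold_by_id.get(pe["id"])
        if ge is None:
            continue  # predicted event without a gold counterpart
        gold_ids = [a["entity_id"] for a in ge["arguments"]]
        for arg in pe["arguments"]:
            if arg["entity_id"] in gold_ids:
                head += 1
            # An entity without a cluster only matches itself.
            if any(arg["entity_id"] in coref_clusters.get(g, [g]) for g in gold_ids):
                coref += 1
    return n_pred, n_gold, head, coref


# Made-up example: T5 is coreferent with the gold argument T2.
gold = [{"id": "E1", "arguments": [{"entity_id": "T1"}, {"entity_id": "T2"}]}]
pred = [{"id": "E1", "arguments": [{"entity_id": "T1"}, {"entity_id": "T5"}]}]
clusters = {"T2": ["T2", "T5"]}

n_pred, n_gold, head, coref = count_matches(gold, pred, clusters)
print(precision_recall_f1(n_pred, n_gold, head))   # (0.5, 0.5, 0.5)
print(precision_recall_f1(n_pred, n_gold, coref))  # (1.0, 1.0, 1.0)
```

Role classification counts could be added in the same loop by also comparing `arg["role"]` with the role of the matched gold argument.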

In [149]:
idf_h_p, idf_h_r, idf_h_f1 = compute_f1(idf_pred, idf_gold, idf_h_matched)
idf_c_p, idf_c_r, idf_c_f1 = compute_f1(idf_pred, idf_gold, idf_c_matched)

In [153]:
print('Head Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(
            idf_h_p , idf_h_r , idf_h_f1 ))
print('Coref Role identification: P: {:.2f}, R: {:.2f}, F: {:.2f}'.format(
            idf_c_p , idf_c_r , idf_c_f1 ))

Head Role identification: P: 0.03, R: 0.01, F: 0.01
Coref Role identification: P: 0.10, R: 0.02, F: 0.03


In [5]:
import pandas as pd
df = pd.read_json("data/WikiEvents/train_docred_format.json")
df_we = pd.read_json("data/WikiEvents/train.json")

In [4]:
df.vertexSet[0]

[[{'pos': [166, 167],
   'type': 'Cognitive.IdentifyCategorize.Unspecified.TRIGGER',
   'sent_id': 3,
   'name': 'discovered',
   'id': 'scenario_en_kairos_14-E1'}],
 [{'pos': [88, 89],
   'type': 'Cognitive.Inspection.SensoryObserve.TRIGGER',
   'sent_id': 1,
   'name': 'reviewing',
   'id': 'scenario_en_kairos_14-E2'}],
 [{'pos': [30, 31],
   'type': 'Cognitive.IdentifyCategorize.Unspecified.TRIGGER',
   'sent_id': 0,
   'name': 'searching',
   'id': 'scenario_en_kairos_14-E3'}]]

In [8]:
df_we.event_mentions[1]

[{'id': 'scenario_en_kairos_65-E1',
  'event_type': 'Conflict.Attack.Unspecified',
  'trigger': {'start': 50, 'end': 51, 'text': 'attack', 'sent_idx': 1},
  'arguments': []},
 {'id': 'scenario_en_kairos_65-E2',
  'event_type': 'Life.Injure.Unspecified',
  'trigger': {'start': 62, 'end': 63, 'text': 'injured', 'sent_idx': 2},
  'arguments': [{'entity_id': 'scenario_en_kairos_65-T13',
    'role': 'Victim',
    'text': 'Terry Duffield'}]},
 {'id': 'scenario_en_kairos_65-E3',
  'event_type': 'Conflict.Attack.DetonateExplode',
  'trigger': {'start': 65, 'end': 66, 'text': 'detonated', 'sent_idx': 2},
  'arguments': []},
 {'id': 'scenario_en_kairos_65-E4',
  'event_type': 'Conflict.Attack.DetonateExplode',
  'trigger': {'start': 417, 'end': 419, 'text': 'went off', 'sent_idx': 12},
  'arguments': []},
 {'id': 'scenario_en_kairos_65-E5',
  'event_type': 'Conflict.Attack.DetonateExplode',
  'trigger': {'start': 433, 'end': 434, 'text': 'detonated', 'sent_idx': 14},
  'arguments': []}]

In [13]:
lm_model

NameError: name 'lm_model' is not defined

# TODO
## 1 Finish and evaluate the EAE model
### Move it into a script, adapting the folder structure
### Run on GPU with the entire train split
### Expose hyperparameters
### Write sh scripts and prepare a systematic evaluation with Wandb

## 2 Build the E2E model

### What can inference look like here? -> Start and end must be output, and then the evaluation is run.

## 3 Combine the models with 2 forward methods into 1 and evaluate with it
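
For the "expose hyperparameters" item, one common approach is an `argparse` interface whose defaults mirror the values currently hard-coded in this notebook (learning rate 1e-5, epsilon 1e-6, batch size 2, candidate length 9, 50 mentions). The flag names below are hypothetical and only sketch what such a script entry point could look like.

```python
import argparse


def build_parser():
    """Command-line options; defaults are the values hard-coded in the notebook."""
    parser = argparse.ArgumentParser(description="Train the EAE model (sketch).")
    parser.add_argument("--language_model", default="bert-base-uncased")
    parser.add_argument("--train_file", default="data/WikiEvents/train_docred_format_small.json")
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--eps", type=float, default=1e-6)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--max_candidate_length", type=int, default=9)
    parser.add_argument("--k_mentions", type=int, default=50)
    return parser


# Parsing an explicit list keeps the example runnable inside a notebook.
args = build_parser().parse_args(["--lr", "3e-5", "--batch_size", "4"])
print(args.lr, args.batch_size, args.max_candidate_length)  # 3e-05 4 9
```

The parsed namespace can be passed to Wandb as the run configuration, for example via `vars(args)`, so that every run records the values it was started with.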