## TODO

- offsets anpassen, sodass end-2-end model ohne die entity_ids ausgewertet werden kann und texte stimmen

In [1]:
from transformers import AutoTokenizer
from transformers import AutoConfig, AutoModel, BertConfig
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.data import parse_file, collate_fn
import tqdm
import json
from transformers.optimization import AdamW
import numpy as np
import torch.nn.functional as F
from src.losses import ATLoss
from src.util import process_long_input
from transformers import BertConfig, RobertaConfig, DistilBertConfig, XLMRobertaConfig

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Encoder(nn.Module):
    def __init__(self, config, model, cls_token_id, sep_token_id, soft_mention = True):
        super().__init__()
        self.soft_mention = soft_mention

        self.config = config
        self.model = model

        self.entity_anchor = nn.Parameter(torch.zeros((66, 768)))
        torch.nn.init.uniform_(self.entity_anchor, a=-1.0, b=1.0)
        
        self.relation_embeddings = nn.Parameter(torch.zeros((57,3*768)))
        torch.nn.init.uniform_(self.relation_embeddings, a=-1.0, b=1.0)            
        self.nota_embeddings = nn.Parameter(torch.zeros((20,3*768)))
        torch.nn.init.uniform_(self.nota_embeddings, a=-1.0, b=1.0)


        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.ce_loss = nn.CrossEntropyLoss()
        self.at_loss = ATLoss()
        

        self.k_mentions = 50
                
        self.cls_token_id = cls_token_id
        self.sep_token_id = sep_token_id
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

        
        with open("data/roles.json") as f:
            self.relation_types = json.load(f)
        with open("data/trigger_entity_types.json") as f:
            self.te_types = json.load(f)
        with open("data/feasible_roles.json") as f:
            self.feasible_roles = json.load(f)
        
    def encode(self, input_ids, attention_mask):
        config = self.config
        if type(config) == BertConfig or type(config) == DistilBertConfig:
            start_tokens = [self.cls_token_id]
            end_tokens = [self.sep_token_id]
        elif type(config) == RobertaConfig or type(config) == XLMRobertaConfig:
            start_tokens = [self.cls_token_id]
            end_tokens = [self.sep_token_id, self.sep_token_id]
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

    def forward(self, input_ids, attention_mask, candidate_spans, relation_labels = [[]], entity_spans= [[]], entity_types= [[]], entity_ids= [[]]):
        sequence_output, attention = self.encode(input_ids, attention_mask)

        loss = torch.zeros((1)).to(sequence_output)
        mention_loss = torch.zeros((1)).to(sequence_output)
        counter = 0
        batch_triples = []
        batch_events = []
        batch_text = []

        if not self.training:
            entity_spans = [[]*sequence_output.shape[0]]
            entity_types = [[]*sequence_output.shape[0]]
            entity_ids = [[]*sequence_output.shape[0]]
            relation_labels = [[]*sequence_output.shape[0]]
        
        for batch_i in range(sequence_output.size(0)):
            #entity_spans = entity_spans[batch_i]
            text_i = self.tokenizer.convert_ids_to_tokens(input_ids[batch_i])
            batch_text.append(text_i)

            # MENTION DETECTION

            # ---------- Candidate span embeddings ------------
            mention_candidates = []
            candidates_attentions = []
            for span in candidate_spans[batch_i]:
                mention_embedding = torch.mean(sequence_output[batch_i, span[0]:span[1]+1,:], 0)
                mention_attention = torch.mean(attention[batch_i, span[0]:span[1]+1,:], 0)
                mention_candidates.append(mention_embedding)
                candidates_attentions.append(mention_attention)
            embs = torch.stack(mention_candidates)
            atts = torch.stack(candidates_attentions)

            # ---------- mention detection (scores) ------------
            span_scores = embs.unsqueeze(1) * self.entity_anchor.unsqueeze(0)
            span_scores = torch.sum(span_scores, dim=-1)
            span_scores_max, class_for_span = torch.max(span_scores, dim=-1)
            scores_for_max, max_spans = torch.topk(span_scores_max.view(-1), min(self.k_mentions, embs.size(0)), dim=0)
            class_for_max_span = class_for_span[max_spans]

            if self.training:
                # ---------- Mention Loss and adding true spans during training ------------

                if self.soft_mention:
                    spans_for_type = {}

                    for span, rtype in zip(entity_spans[batch_i], entity_types[batch_i]):
                        if rtype not in spans_for_type.keys():
                            spans_for_type[rtype] = []
                        spans_for_type[rtype].append(span[0])

                    anchors, positives, negatives = [], [], []

                    for rtype, positive_examples in spans_for_type.items():

                        # add negative examples from entity spans
                        for pos in positive_examples:
                            for rtype2, negative_examples in spans_for_type.items():
                                if rtype2 == rtype:
                                    continue
                                for neg in negative_examples:
                                    anchors.append(self.entity_anchor[self.te_types.index(rtype),:])
                                    positives.append(torch.mean(sequence_output[batch_i, pos[0]:pos[1]+1,:], 0))
                                    negatives.append(torch.mean(sequence_output[batch_i, neg[0]:neg[1]+1,:], 0))

                        # add negative examples from candidate spans
                        for pos in positive_examples:
                            for neg in [x for x in candidate_spans[batch_i] if x not in entity_spans[batch_i]]:
                                anchors.append(self.entity_anchor[self.te_types.index(rtype),:])
                                positives.append(torch.mean(sequence_output[batch_i, pos[0]:pos[1]+1,:], 0))
                                negatives.append(torch.mean(sequence_output[batch_i, neg[0]:neg[1]+1,:], 0))


                    mention_loss += self.triplet_loss(torch.stack(anchors), torch.stack(positives), torch.stack(negatives))

                else:
                    #TODO: Cross Entropy Loss
                    #one-hot encoding for annotated mention labels 
                    mention_labels = torch.zeros(len(candidate_spans[batch_i]),len(self.te_types))
                    for idx,c in enumerate(candidate_spans[batch_i]):
                        for ent,t in zip(entity_spans[batch_i],entity_types[batch_i]):
                            if [c] == ent:
                                mention_labels[idx][self.te_types.index(t)] = 1
                    mention_loss += self.ce_loss(mention_labels, span_scores)


            # ARGUMENT ROLE LABELING

            if self.training:
                # ---------- Pooling Entity Embeddings and Attentions ------------
                entity_embeddings = []
                entity_attentions = []
                for ent in entity_spans[batch_i]:
                    ent_embedding = torch.mean(sequence_output[batch_i, ent[0][0]:ent[0][1],:],0)
                    entity_embeddings.append(ent_embedding)
                    ent_attention = torch.mean(attention[batch_i,:,ent[0][0]:ent[0][1],:],1)
                    entity_attentions.append(ent_attention)
                if(len(entity_embeddings) == 0):
                    continue
                entity_embeddings = torch.stack(entity_embeddings)
                entity_attentions = torch.stack(entity_attentions)
            else:
                entity_embeddings = embs[max_spans]
                entity_attentions = atts[max_spans]

                for c in class_for_max_span:
                    entity_types[batch_i].append(self.te_types[c])
                for s in max_spans:
                    entity_spans[batch_i].append([candidate_spans[batch_i][s]])


                
                
            # ---------- Localized Context Pooling ------------
            relation_candidates = []
            localized_context = []
            concat_embs = []
            triggers = []
            for s in range(entity_embeddings.shape[0]):
                if entity_types[batch_i][s].split(".")[-1] != "TRIGGER":
                    continue
                triggers.append(s)
                for o in range(entity_embeddings.shape[0]):
                    if s != o:

                        relation_candidates.append((s,o))

                        A_s = entity_attentions[s,:,:]
                        A_o = entity_attentions[o,:,:]
                        A = torch.mul(A_o,A_s)
                        q = torch.sum(A,0)
                        a = q / q.sum()
                        H_T = sequence_output[batch_i].T
                        c = torch.matmul(H_T,a)
                        localized_context.append(c)

                        concat_emb = torch.cat((entity_embeddings[s],entity_embeddings[o],c),0)
                        concat_embs.append(concat_emb)
            if(len(localized_context) == 0):
                continue
            localized_context = torch.stack(localized_context)
            embs = torch.stack(concat_embs)
            
            triggers = list(set(triggers))
            # ---------- Pairwise Comparisons and Predictions ------------

            scores = torch.matmul(embs,self.relation_embeddings.T)
            nota_scores = torch.matmul(embs,self.nota_embeddings.T)
            nota_scores = nota_scores.max(dim=-1,keepdim=True)[0]
            scores = torch.cat((nota_scores, scores), dim=-1)
            predictions = torch.argmax(scores, dim=-1, keepdim=False)
            #Achtung: NOTA wird an 0. Stelle gesetzt
            
            if self.training:
            # ---------- ATLoss with one-hot encoding for true labels ------------
                targets = []
                for r in relation_candidates:
                    onehot = torch.zeros(len(self.relation_types))
                    if r in relation_labels[batch_i]:
                        onehot[relation_labels[batch_i][r]] = 1.0
                    targets.append(onehot)
                targets = torch.stack(targets).to(self.model.device)
                loss += self.at_loss(scores,targets)
                counter += 1
            
            # ---------- Inference ------------
            triples = []
            for idx,pair in enumerate(relation_candidates):
                triple = {
                    pair:self.relation_types[predictions[idx]]
                }
                triples.append(triple)
            batch_triples.append(triples)
                
            events = []
            for trigger in triggers:
                trigger_word = entity_types[batch_i][trigger]
                event_type = trigger_word.split(".TRIGGER")[0]
                trigger_span = entity_spans[batch_i][trigger][0]
                t_start = trigger_span[0]
                t_end = trigger_span[1]
                arguments = []
                for t in triples:
                    #Triples ist Liste von dicts, gibt es eine geschicktere Möglichkeit an s,o,r zu kommen?
                    for i in t.items():
                        s = i[0][0]
                        o = i[0][1]
                        r = i[1] 
                        
                        if s == trigger:
                            if r in self.feasible_roles[event_type]:
                                a_span = entity_spans[batch_i][o][0]
                                a_start = a_span[0]
                                a_end = a_span[1]

                                #TODO:
                                # Mapping von spans zu altem Text.
                                argument = {
                                    
                                    'role':r,
                                    'start':a_start,
                                    'end':a_end,
                                    'text':batch_text[batch_i][a_start:a_end][0]
                                }
                                arguments.append(argument)
                
                event = {
                    'event_type':event_type,
                    'trigger': {'start':t_start ,'end':t_end, 'text':batch_text[batch_i][t_start:t_end][0]},
                    'arguments':arguments
                }
                events.append(event)
            batch_events.append(events)
        if(counter == 0):
                return torch.autograd.Variable(loss,requires_grad=True)
        else:
            return (mention_loss+loss)/counter, batch_triples, batch_events
        #TODO: Irgendetwas sollte implizit berücksichtigt werden?



In [3]:
#########################
## Model Configuration ##
#########################
language_model = 'bert-base-uncased'
lm_config = AutoConfig.from_pretrained(
    language_model,
    num_labels=10,
)
lm_model = AutoModel.from_pretrained(
    language_model,
    from_tf=False,
    config=lm_config,
)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


with open("data/roles.json") as f:
    relation_types = json.load(f)

max_n = 9
test_loader = DataLoader(

    parse_file("data/WikiEvents/train_docred_format_all.json",
    tokenizer=tokenizer,
    relation_types=relation_types,
    max_candidate_length=max_n),
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn)

mymodel = Encoder(lm_config,
                lm_model,
                cls_token_id=tokenizer.convert_tokens_to_ids(tokenizer.cls_token), 
                sep_token_id=tokenizer.convert_tokens_to_ids(tokenizer.sep_token),
                soft_mention = True
                )
mymodel.eval()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Parsing & generating candidates (n=9): 100%|██| 206/206 [00:12<00:00, 16.79it/s]


Encoder(
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
   

In [4]:
with tqdm.tqdm(test_loader) as progress_bar:
    for sample in progress_bar:

        input_ids, attention_mask, entity_spans, entity_types, entity_ids, relation_labels, text, token_map, candidate_spans, doc_ids = sample
        
        

100%|█████████████████████████████████████████| 52/52 [00:00<00:00, 2746.69it/s]


### Input_ids

In [5]:
import pandas as pd
df_we = pd.read_json("data/WikiEvents/train.json")
df_we.iloc[205].text

'German police are searching for a suspect after an abandoned bag loaded with potential explosives was discovered Monday at the central train station in Bonn, Spiegel Online reported Wednesday.The German daily Frankfurter Allegemeine Zeitung, citing unnamed high level sources, reported that authorities determined the bomb was extremely dangerous, comparable to the 2004 attack in Madrid, in which 10 bombs placed on regional trains claimed 191 lives.The station was temporarily closed.Two Somali-born men, detained Tuesday in connection with the find, were later released for lack of evidence, Cologne police said, according to Spiegel.The first suspect, identified as Omar D., was arrested on Tuesday around 1:30 p.m. in a Bonn Internet café along with his companion Abdifatah W., Spiegel Online reported Tuesday.According to Spiegel, Omar D. was detained in 2008 at the Cologne/Bonn airport after boarding a plane to Amsterdam, along with another Somali native, both suspected of planning to join

In [6]:
# docId = 205 in WikiEvents train
batch_i = 1
re_text = tokenizer.convert_ids_to_tokens(input_ids[batch_i])
re_text

['[CLS]',
 'German',
 'police',
 'are',
 'searching',
 'for',
 'a',
 'suspect',
 'after',
 'an',
 'abandoned',
 'bag',
 'loaded',
 'with',
 'potential',
 'explosives',
 'was',
 'discovered',
 'Monday',
 'at',
 'the',
 'central',
 'train',
 'station',
 'in',
 'Bonn',
 ',',
 'S',
 '##pie',
 '##gel',
 'Online',
 'reported',
 'Wednesday',
 '.',
 'The',
 'German',
 'daily',
 'Frankfurt',
 '##er',
 'All',
 '##ege',
 '##mei',
 '##ne',
 'Z',
 '##eit',
 '##ung',
 ',',
 'citing',
 'unnamed',
 'high',
 'level',
 'sources',
 ',',
 'reported',
 'that',
 'authorities',
 'determined',
 'the',
 'bomb',
 'was',
 'extremely',
 'dangerous',
 ',',
 'comparable',
 'to',
 'the',
 '2004',
 'attack',
 'in',
 'Madrid',
 ',',
 'in',
 'which',
 '10',
 'bombs',
 'placed',
 'on',
 'regional',
 'trains',
 'claimed',
 '191',
 'lives',
 '.',
 'The',
 'station',
 'was',
 'temporarily',
 'closed',
 '.',
 'Two',
 'Somali',
 '-',
 'born',
 'men',
 ',',
 'detained',
 'Tuesday',
 'in',
 'connection',
 'with',
 'the',
 'fin

In [7]:
for c in candidate_spans[batch_i]:
    s = c[0]
    e = c[1]
    ##Hier auf Methode generate_naive achten: len+1
    print(re_text[s:e+1])

['German']
['German', 'police']
['German', 'police', 'are']
['German', 'police', 'are', 'searching']
['German', 'police', 'are', 'searching', 'for']
['German', 'police', 'are', 'searching', 'for', 'a']
['German', 'police', 'are', 'searching', 'for', 'a', 'suspect']
['German', 'police', 'are', 'searching', 'for', 'a', 'suspect', 'after']
['German', 'police', 'are', 'searching', 'for', 'a', 'suspect', 'after', 'an']
['police']
['police', 'are']
['police', 'are', 'searching']
['police', 'are', 'searching', 'for']
['police', 'are', 'searching', 'for', 'a']
['police', 'are', 'searching', 'for', 'a', 'suspect']
['police', 'are', 'searching', 'for', 'a', 'suspect', 'after']
['police', 'are', 'searching', 'for', 'a', 'suspect', 'after', 'an']
['police', 'are', 'searching', 'for', 'a', 'suspect', 'after', 'an', 'abandoned']
['are']
['are', 'searching']
['are', 'searching', 'for']
['are', 'searching', 'for', 'a']
['are', 'searching', 'for', 'a', 'suspect']
['are', 'searching', 'for', 'a', 'suspe

### Entity Spans

In [10]:
for e in entity_spans[batch_i]:
    s = token_map[batch_i][e[0][0]]
    e = token_map[batch_i][e[0][1]]
    #print(re_text[s:e+1])

IndexError: list index out of range

In [11]:
len(df_we.iloc[205].entity_mentions) + len(df_we.iloc[205].event_mentions)

64

In [12]:
len(entity_spans[batch_i])

64

In [14]:
# Problem: Bei candidate spans stimmen die offsets bereits nicht und man bekommt out-of-range except.*
token_maps = []
re_texts = []
candidates_list = []
with tqdm.tqdm(test_loader) as progress_bar:
    for sample in progress_bar:

        input_ids, attention_mask, entity_spans, entity_types, entity_ids, relation_labels, text, token_map, candidate_spans, doc_ids = sample
        
        for batch_i in range(input_ids.size(0)):
            token_maps.append(token_map[batch_i])
            re_texts.append(tokenizer.convert_ids_to_tokens(input_ids[batch_i]))
            candidates_list.append(candidate_spans[batch_i])
    
    
    
        '''for batch_i in range(input_ids.size(0)):
            for e in entity_spans[batch_i]:
                print(e)
                s = token_map[batch_i][e[0][0]]
                e = token_map[batch_i][e[0][1]]
                print(text[batch_i][s:e+1])
                print(s,e)
                print(text[batch_i])'''
            

100%|███████████████████████████████████████████| 52/52 [00:00<00:00, 81.83it/s]


In [19]:
for cands, text, t_map in zip(candidates_list,re_texts,token_maps):
    for c in cands:
        s = t_map[c[0]]
        print(e)
        e = t_map[c[1]]
        text[s:e]


231
1
2
3
4
5
6
7
8
9
2
3
4
5
6
7
8
9
10
3
4
5
6
7
8
9
10
11
4
5
6
7
8
9
10
11
12
5
6
7
8
9
10
11
12
17
6
7
8
9
10
11
12
17
19
7
8
9
10
11
12
17
19
20
8
9
10
11
12
17
19
20
22
9
10
11
12
17
19
20
22
23
10
11
12
17
19
20
22
23
24
11
12
17
19
20
22
23
24
25
12
17
19
20
22
23
24
25
26
17
19
20
22
23
24
25
26
27
19
20
22
23
24
25
26
27
28
20
22
23
24
25
26
27
28
29
22
23
24
25
26
27
28
29
30
23
24
25
26
27
28
29
30
31
24
25
26
27
28
29
30
31
32
25
26
27
28
29
30
31
32
33
26
27
28
29
30
31
32
33
34
27
28
29
30
31
32
33
34
35
28
29
30
31
32
33
34
35
36
29
30
31
32
33
34
35
36
37
30
31
32
33
34
35
36
37
38
31
32
33
34
35
36
37
38
39
32
33
34
35
36
37
38
39
40
33
34
35
36
37
38
39
40
41
34
35
36
37
38
39
40
41
42
35
36
37
38
39
40
41
42
43
36
37
38
39
40
41
42
43
44
37
38
39
40
41
42
43
44
45
38
39
40
41
42
43
44
45
46
39
40
41
42
43
44
45
46
47
40
41
42
43
44
45
46
47
48
41
42
43
44
45
46
47
48
49
42
43
44
45
46
47
48
49
50
43
44
45
46
47
48
49
50
51
44
45
46
47
48
49
50
51
52
45
46
47
48
49


IndexError: list index out of range

In [None]:
for c in candidate_spans[batch_i]:
    s = token_map[batch_i][c[0]]
    e = token_map[batch_i][c[1]]
    print(s,e)

In [None]:
#Sanity Check for Candidate Span Offsets
batch_i = 1
for c in candidate_spans[batch_i]:
    s = token_map[batch_i][c[0]]
    e = token_map[batch_i][c[1]]

    print(text[batch_i][s:e])

In [None]:
df_we.iloc[205].event_mentions

In [57]:
re_text[token_map[batch_i][158]+1:token_map[batch_i][159]+1]

['detained']

In [50]:
token_map[batch_i][196]
token_map[batch_i][197]

215

In [59]:
for em in df_we.iloc[205].event_mentions:
    s = token_map[batch_i][em['trigger']['start']]+1
    e = token_map[batch_i][em['trigger']['end']]+1
    print(re_text[s:e])

['attack']
['claimed']
['detained']
['arrested']
['released']
['detained']
['released']
['arrest']


In [70]:
trigger_texts = []
for idx,(text,t_map) in enumerate(zip(re_texts,token_maps)):
    for em in df_we.iloc[idx].event_mentions:
        s = token_map[batch_i][em['trigger']['start']]
        e = token_map[batch_i][em['trigger']['end']]
        
        if(re_text[s:e] != em['trigger']['text']):
            print(re_text[s:e])
            print(em['trigger']['text'])
            print(idx,em)

['Bonn']
discovered
0 {'id': 'scenario_en_kairos_14-E1', 'event_type': 'Cognitive.IdentifyCategorize.Unspecified', 'trigger': {'start': 166, 'end': 167, 'text': 'discovered', 'sent_idx': 3}, 'arguments': []}
['Tuesday']
reviewing
0 {'id': 'scenario_en_kairos_14-E2', 'event_type': 'Cognitive.Inspection.SensoryObserve', 'trigger': {'start': 88, 'end': 89, 'text': 'reviewing', 'sent_idx': 1}, 'arguments': []}
['Wednesday']
searching
0 {'id': 'scenario_en_kairos_14-E3', 'event_type': 'Cognitive.IdentifyCategorize.Unspecified', 'trigger': {'start': 30, 'end': 31, 'text': 'searching', 'sent_idx': 0}, 'arguments': []}
['bomb']
attack
1 {'id': 'scenario_en_kairos_65-E1', 'event_type': 'Conflict.Attack.Unspecified', 'trigger': {'start': 50, 'end': 51, 'text': 'attack', 'sent_idx': 1}, 'arguments': []}
[',']
injured
1 {'id': 'scenario_en_kairos_65-E2', 'event_type': 'Life.Injure.Unspecified', 'trigger': {'start': 62, 'end': 63, 'text': 'injured', 'sent_idx': 2}, 'arguments': [{'entity_id': 'scen

IndexError: list index out of range

In [78]:
for em in df_we.iloc[1].event_mentions:
    print(len(token_map[1]))
    s = token_maps[1][em['trigger']['start']]+1
    e = token_maps[1][em['trigger']['end']]+1
    print(re_texts[1][s:e])

302
['##IS', 'attack']
302
['injured']
302
['de']
302
['went', 'off']
302
['de']


In [74]:
len(token_maps[1])

660

In [72]:
df_we.iloc[1].event_mentions

[{'id': 'scenario_en_kairos_65-E1',
  'event_type': 'Conflict.Attack.Unspecified',
  'trigger': {'start': 50, 'end': 51, 'text': 'attack', 'sent_idx': 1},
  'arguments': []},
 {'id': 'scenario_en_kairos_65-E2',
  'event_type': 'Life.Injure.Unspecified',
  'trigger': {'start': 62, 'end': 63, 'text': 'injured', 'sent_idx': 2},
  'arguments': [{'entity_id': 'scenario_en_kairos_65-T13',
    'role': 'Victim',
    'text': 'Terry Duffield'}]},
 {'id': 'scenario_en_kairos_65-E3',
  'event_type': 'Conflict.Attack.DetonateExplode',
  'trigger': {'start': 65, 'end': 66, 'text': 'detonated', 'sent_idx': 2},
  'arguments': []},
 {'id': 'scenario_en_kairos_65-E4',
  'event_type': 'Conflict.Attack.DetonateExplode',
  'trigger': {'start': 417, 'end': 419, 'text': 'went off', 'sent_idx': 12},
  'arguments': []},
 {'id': 'scenario_en_kairos_65-E5',
  'event_type': 'Conflict.Attack.DetonateExplode',
  'trigger': {'start': 433, 'end': 434, 'text': 'detonated', 'sent_idx': 14},
  'arguments': []}]

In [89]:
out = tokenizer("Smithers was shot when Cowley was in Farah",return_offsets_mapping=True)

In [90]:
tokenizer.convert_ids_to_tokens(out.input_ids)

['[CLS]',
 'Smith',
 '##ers',
 'was',
 'shot',
 'when',
 'Co',
 '##wley',
 'was',
 'in',
 'Far',
 '##ah',
 '[SEP]']

In [92]:
out.offset_mapping

[(0, 0),
 (0, 5),
 (5, 8),
 (9, 12),
 (13, 17),
 (18, 22),
 (23, 25),
 (25, 29),
 (30, 33),
 (34, 36),
 (37, 40),
 (40, 42),
 (0, 0)]

In [91]:
tokenizer("The quick brown fox",return_offsets_mapping=True)

{'input_ids': [101, 1109, 3613, 3058, 17594, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 3), (4, 9), (10, 15), (16, 19), (0, 0)]}

In [None]:
(input_ids,attention_mask,attention_mask)