In [58]:
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from transformers import AutoTokenizer
from transformers import AutoModel
import torch

import numpy as np

import networkx as nx
import spacy

import pandas as pd

import ast

import pprint

In [15]:
#layers = [-4, -3, -2, -1]
layers = [-1]
model = AutoModel.from_pretrained('bert-base-cased', output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
nlp = spacy.load("en_core_web_sm")

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [18]:
device = torch.device('cpu')


In [16]:
def get_hidden_states(encoded, model, layers):
    with torch.no_grad():
         output = model(**encoded)
    # Get all hidden states
    states = output.hidden_states
    # Stack and sum all requested layers
    output = torch.stack([states[i] for i in layers]).sum(0).squeeze()

    return output

def get_words_vector(sent, tokenizer, model, layers):
    encoded = tokenizer.encode_plus(sent, return_tensors="pt")
    # get all token idxs that belong to the word of interest
    #token_ids_word = np.where(np.array(encoded.word_ids()) == idx)

    return get_hidden_states(encoded, model, layers)

In [17]:
corpus = open("dev_clean_format.txt").readlines()
corpus = [t.replace('\n', '') for t in corpus]

relations = pd.read_csv("dev_relations.tsv", delimiter='\t', header=None)
relations.fillna('<NONE>')

entities = pd.read_csv('dev_entities.tsv', delimiter='\t', header=0)
column_sentence = entities.columns[0]
column_surface_form = entities.columns[1]
column_pos = entities.columns[2]


In [49]:
list_of_networks = []

for enum, sentence in enumerate(corpus[:5]):
    print(enum, "/", len(corpus), end='\r')
    try:
        network = nx.Graph()
        edge_list = []
        id_sentence = enum

        specific_rel = relations.iloc[enum]    
        specific_ent = entities.loc[entities[column_sentence] == enum]
        
        doc = nlp(sentence)
        tokens = [t for t in doc]

        sent_embeddings = get_words_vector(sentence, tokenizer, model, layers)

        id_token = 0
        for enum, t in enumerate(tokens):
            tokens_bert = tokenizer.tokenize(t.text, add_special_tokens=False)
            token_idx = tokenizer.encode(t.text, add_special_tokens=False)
            token_embeddings = []
            for token_id in token_idx:
                token_embeddings.append(sent_embeddings[id_token])
                id_token += 1

            if len(token_embeddings) > 1:
                token_embeddings = torch.stack(token_embeddings).to(device)
                token_embeddings = torch.mean(token_embeddings, -2).tolist()
            else:
                token_embeddings = token_embeddings[0].tolist()

            edge = (t.i, t.head.i, t.dep_)
            edge_list.append(edge)
            network.add_node(t.i, embedding=token_embeddings, label=t.text, type='token')

        for edge in edge_list:
            network.add_edge(edge[0], edge[1], label=edge[2])

        rel_label = specific_rel[0]
        
        subj_label = specific_ent.loc[specific_ent[column_pos] == specific_rel[1]][column_surface_form]
        obj_label = specific_ent.loc[specific_ent[column_pos] == specific_rel[2]][column_surface_form]

        
        rel_subj = [i-1 for i in ast.literal_eval(specific_rel[1])]
        rel_obj = [i-1 for i in ast.literal_eval(specific_rel[2])]

        nodesubj = enum+1
        nodeobj = enum+2

        embeddings_subj = []
        embeddings_obj = []

        for n in range(rel_subj[0], rel_subj[1]+1):
            #test this vs random initialization?
            embeddings_subj.append(network.nodes[n]['embedding'])
            network.add_edge(n, nodesubj, label="in_entity")

        for n in range(rel_obj[0], rel_obj[1]+1):
            embeddings_obj.append(network.nodes[n]['embedding'])
            network.add_edge(n, nodeobj, label="in_entity")

    
        embeddings_subj = torch.stack(embeddings_subj).to(device)
        embeddings_obj = torch.stack(embeddings_obj).to(device)

        network.nodes[nodesubj]['embedding'] = torch.mean(embeddings_subj, -2).tolist()
        network.nodes[nodeobj]['embedding'] = torch.mean(embeddings_obj, -2).tolist()

        network.nodes[nodesubj]['label'] = subj_label
        network.nodes[nodeobj]['label'] = obj_label

        network.nodes[nodesubj]['type'] = 'subjEntity'
        network.nodes[nodeobj]['type'] = 'objEntity'
        
        #network.add_edge(nodesubj, nodeobj, label=rel_label)


        list_of_networks.append(network)
        
    except:
        pass

4 / 1714

In [51]:
for enum, network in enumerate(list_of_networks):
    nx.write_gpickle(network, f'dev_test_embeddings_sentence_{enum}.pickle')

In [59]:
for node in network.nodes():
    pprint.pprint(network.nodes()[node])

{'embedding': [0.39194196462631226,
               -0.021351080387830734,
               0.05150236561894417,
               0.16482554376125336,
               -0.3031255900859833,
               0.017282968387007713,
               -0.05545925721526146,
               -0.17744414508342743,
               -0.0034507019445300102,
               -1.081575632095337,
               -0.2698577642440796,
               0.06727661192417145,
               -0.058423977345228195,
               -0.20837141573429108,
               -0.22354884445667267,
               -0.16161976754665375,
               0.09785101562738419,
               -0.17580853402614594,
               -0.011503915302455425,
               0.02218370884656906,
               0.14136207103729248,
               0.019204245880246162,
               0.45536646246910095,
               -0.2767193913459778,
               0.4267374277114868,
               -0.011106511577963829,
               0.06215640902519226,
           

               0.019990185275673866,
               -0.23627230525016785,
               -0.3771041929721832,
               -2.79923939704895,
               0.30351659655570984,
               0.12701039016246796,
               -0.07197537273168564,
               -0.27643513679504395,
               -0.17402470111846924,
               0.20669052004814148,
               -0.03951617330312729,
               0.5492315888404846,
               -0.1319860816001892,
               0.5835739374160767,
               -0.1807684451341629,
               0.3305374085903168,
               -0.001956629566848278,
               -0.09703775495290756,
               0.25593793392181396,
               0.07767974585294724,
               -0.0333997868001461,
               0.24745799601078033,
               -0.3420375883579254,
               -0.018941253423690796,
               0.24343736469745636,
               0.13227468729019165,
               -0.27639517188072205,
               -0.522

               0.25316137075424194,
               -0.4488154649734497,
               0.1324242651462555,
               -0.19223588705062866,
               0.13781267404556274,
               0.4628254473209381,
               0.19766610860824585,
               -0.4282374680042267,
               0.25761884450912476,
               -0.7289301753044128,
               0.11172566562891006,
               0.11395470798015594,
               -0.2668493092060089,
               0.190031960606575,
               0.26353898644447327,
               -0.05588875338435173,
               0.3935449421405792,
               -0.36373990774154663,
               -0.02485961653292179,
               0.005766584537923336,
               -0.5118984580039978,
               -0.5576788187026978,
               -0.3784717917442322,
               0.09086829423904419,
               0.3817165195941925,
               0.2722310721874237,
               -0.2808632254600525,
               -0.425278484821

               0.05105408653616905,
               0.5561125874519348,
               0.47784924507141113,
               0.32802534103393555,
               -0.33507290482521057,
               0.022581318393349648,
               -0.3977358639240265,
               0.27011677622795105,
               -0.20085369050502777,
               1.0836822986602783,
               -0.5225558876991272,
               0.5701486468315125,
               0.004181129392236471,
               -0.5824798941612244,
               -0.5710626244544983,
               -0.22668007016181946,
               0.07271118462085724,
               0.1794930100440979,
               0.5146499872207642,
               0.14011700451374054,
               -0.3088413178920746,
               -0.10884607583284378,
               -0.12248067557811737,
               0.010098470374941826,
               0.06629649549722672,
               0.9670218825340271,
               -0.16787929832935333,
               0.27306431

               0.34183913469314575,
               -0.31366264820098877,
               -0.4288724958896637,
               0.08195947855710983,
               0.06116098538041115,
               0.06222376972436905,
               0.3720608651638031,
               -0.09372273832559586,
               0.07412656396627426,
               -0.03717101737856865,
               -0.16165074706077576,
               0.31262147426605225,
               -0.351041704416275,
               0.07141965627670288,
               -0.6515282392501831,
               -0.31232884526252747,
               0.18483223021030426,
               -0.07469368726015091,
               -0.16338568925857544,
               -0.36688727140426636,
               -0.18893185257911682,
               0.4760788679122925,
               -0.01647009886801243,
               -0.48637768626213074,
               -0.5691742300987244,
               0.09621168673038483,
               -0.18750305473804474,
               0.45

               -0.31201863288879395,
               0.35123756527900696,
               0.5568476319313049,
               -0.3550277352333069,
               0.4030081331729889,
               -0.3093535006046295,
               -0.37753990292549133,
               0.12405320256948471,
               -0.40624675154685974,
               -0.3027859330177307,
               -0.18635007739067078,
               -0.280856728553772,
               -0.32469531893730164,
               0.512814998626709,
               0.19534002244472504,
               0.10239658504724503,
               -0.6763244867324829,
               0.012097864411771297,
               -0.1812322437763214,
               0.7277418971061707,
               -0.593099057674408,
               0.022160224616527557,
               -0.21485181152820587,
               -0.239424929022789,
               0.6354714632034302,
               -0.17251640558242798,
               -0.43247073888778687,
               0.4205845594

               -0.14219389855861664,
               -0.36443331837654114,
               -0.6418518424034119,
               0.0795857235789299,
               -0.25172364711761475,
               0.6301581859588623,
               0.30047401785850525,
               0.29681092500686646,
               0.4207507371902466,
               -0.07258524745702744,
               -0.0742538720369339,
               -0.14165933430194855,
               -0.01786797121167183,
               -0.07184360176324844,
               0.28616538643836975,
               -0.03704552352428436,
               -0.1565600335597992,
               -0.03786328434944153,
               0.8235569000244141,
               -0.08738787472248077,
               -0.10354935377836227,
               0.008542094379663467,
               -0.33575475215911865,
               0.01698531024158001,
               -0.33389124274253845,
               -0.2470361739397049,
               0.6343410015106201,
               -0.3

               0.2416950762271881,
               -0.02252296172082424,
               -0.026999738067388535,
               -0.08328106254339218,
               0.3101417124271393,
               -0.1380205750465393,
               -0.26401758193969727,
               0.05159098282456398,
               -0.4460228979587555,
               -0.2754266858100891,
               0.21334367990493774,
               -0.14081265032291412,
               -0.16472753882408142,
               0.7160155773162842,
               0.2731054127216339,
               -0.24238619208335876,
               0.0050783236511051655],
 'label': 'and',
 'type': 'token'}
{'embedding': [0.055573198944330215,
               -0.6533055305480957,
               -0.8201769590377808,
               -0.17105641961097717,
               0.5109385251998901,
               -0.34653255343437195,
               -0.20696914196014404,
               -0.5187040567398071,
               -0.15875792503356934,
               0.1

               0.12124883383512497,
               1.1344186067581177,
               0.260549396276474,
               0.6517025828361511,
               0.10891655087471008,
               -0.05006555840373039,
               -0.12310942262411118,
               -0.4307292103767395,
               0.8610190749168396,
               0.6206980347633362,
               -0.3585297167301178,
               -0.19493111968040466,
               -0.16396646201610565,
               -0.23660950362682343,
               -0.026576237753033638,
               0.22153502702713013,
               0.18120624125003815,
               -0.31774020195007324,
               -0.2636891007423401,
               -0.780871570110321,
               0.5711717009544373,
               -0.26314446330070496,
               -0.3545282185077667,
               -0.48114949464797974,
               0.27510982751846313,
               0.654908299446106,
               0.01298349816352129,
               -0.1513244360

               -0.2816949188709259,
               -0.22712655365467072,
               -0.40377160906791687,
               0.1775580197572708,
               -0.025448553264141083,
               0.3622649610042572,
               -0.2771345376968384,
               0.18611720204353333,
               0.48166322708129883,
               -0.09626629948616028,
               -0.21262980997562408,
               0.10322930663824081,
               0.04878479614853859,
               -0.35582235455513,
               0.23957277834415436,
               -0.48693785071372986,
               0.03403779864311218,
               -0.38770991563796997,
               -0.3754638433456421,
               0.027731018140912056,
               0.26262691617012024,
               -0.09739069640636444,
               0.17915433645248413,
               0.02519953064620495,
               -0.4135172963142395,
               0.25714707374572754,
               -0.03507233038544655,
               0.2944

               -0.012677283957600594,
               0.1005711555480957,
               -0.2074679136276245,
               0.09625961631536484,
               -0.19693723320960999,
               0.39805418252944946,
               0.11991723626852036,
               0.8283587694168091,
               0.08176112920045853,
               0.1439792960882187,
               0.5800865888595581,
               -0.08828970044851303,
               -0.11159110069274902,
               -0.2918369770050049,
               -0.31372174620628357,
               0.31723591685295105,
               -0.2529270648956299,
               0.34918200969696045,
               -0.9519758224487305,
               -0.17105771601200104,
               -0.11600420624017715,
               -0.813572108745575,
               0.38865992426872253,
               0.5013378262519836,
               -0.07518357038497925,
               0.16820751130580902,
               0.3364822268486023,
               -0.15486882

               0.23592449724674225,
               0.44766077399253845,
               -0.04376574233174324,
               -0.47420534491539,
               -0.06017596274614334,
               -0.4466721713542938,
               -0.274444580078125,
               0.1405183970928192,
               0.42703983187675476,
               0.5170071721076965,
               -0.06901340186595917,
               0.1124114990234375,
               0.6911598443984985,
               -0.014495356939733028,
               -0.5857446193695068,
               -0.23277437686920166,
               1.090188980102539,
               -0.12700895965099335,
               0.4384259283542633,
               -0.6131542325019836,
               -0.7356797456741333,
               -0.15092381834983826,
               -0.45723649859428406,
               0.8790286183357239,
               0.21644996106624603,
               0.059473004192113876,
               0.7028496861457825,
               0.2271448522806

               -0.5132522583007812,
               -0.18448090553283691,
               0.13218851387500763,
               0.01178787276148796,
               -0.369861364364624,
               -0.46379679441452026,
               -0.34748321771621704,
               0.2298252433538437,
               0.24375873804092407,
               0.19459766149520874,
               -0.18215292692184448,
               0.1335693597793579,
               0.4052704870700836,
               0.8016082644462585,
               -0.0316232331097126,
               -0.3564591705799103,
               0.7028185725212097,
               -0.0969512015581131,
               0.5279228687286377,
               -0.020857203751802444,
               0.2920624911785126,
               -0.4058748781681061,
               -0.046835511922836304,
               -0.21229100227355957,
               0.4308573007583618,
               0.04920881614089012,
               0.20123718678951263,
               0.64923644065

               -0.40026363730430603,
               -0.5910652279853821,
               -0.6003663539886475,
               -0.6325125098228455,
               0.5048944354057312,
               0.21630467474460602,
               0.5210700631141663,
               0.13250082731246948,
               -0.2693667709827423,
               0.24338611960411072,
               0.16670891642570496,
               0.2921102046966553,
               -0.3650311231613159,
               -0.06511622667312622,
               -0.22429326176643372,
               -0.5491403937339783,
               0.44204244017601013,
               0.8128686547279358,
               0.01577984169125557,
               -0.3460463285446167,
               -0.16858310997486115,
               -0.47999656200408936,
               -0.0609663762152195,
               0.8539027571678162,
               0.10461737960577011,
               -0.35104554891586304,
               -0.37094324827194214,
               -0.64647859

               -0.07576975226402283,
               -0.675157368183136,
               -0.43899911642074585,
               -0.1641295850276947,
               0.2744769752025604,
               1.1151690483093262,
               0.8350476026535034,
               -0.19856004416942596,
               -0.15395447611808777,
               -0.23330725729465485,
               0.36957502365112305,
               -0.6085110902786255,
               0.3656136095523834,
               0.7947342991828918,
               0.15954631567001343,
               -0.2238384634256363,
               0.2887592613697052,
               -0.11968378722667694,
               0.21624445915222168,
               -0.8580787181854248,
               -0.07770521938800812,
               0.2616744339466095,
               0.5031290054321289,
               -0.6357696652412415,
               0.20578519999980927,
               -0.3146752715110779,
               0.09093809872865677,
               -0.292545676231

               -0.45042097568511963,
               0.8172866702079773,
               -0.5476997494697571,
               0.0918930172920227,
               -1.0507622957229614,
               0.46668097376823425,
               0.6982491612434387,
               -0.33201971650123596,
               1.004198670387268,
               0.20285072922706604,
               0.935715913772583,
               -1.0858532190322876,
               -0.01772080920636654,
               0.6402831077575684,
               -0.40726742148399353,
               0.124561607837677,
               0.3154285252094269,
               -0.4451071619987488,
               0.07954257726669312,
               -0.1875189244747162,
               0.6189901232719421,
               0.3060384690761566,
               0.7441754341125488,
               0.7267709374427795,
               -0.41535791754722595,
               0.9426965117454529,
               -0.39742034673690796,
               0.5622731447219849,
   

               -0.021884232759475708,
               0.23919174075126648,
               -0.1397913098335266,
               -0.12607190012931824,
               0.4392608106136322,
               -0.3227083683013916,
               0.2589443027973175,
               -0.5361841320991516,
               -0.6478542685508728,
               0.14433686435222626,
               0.07310016453266144,
               0.3868565559387207,
               -0.27359431982040405,
               -0.12812553346157074,
               -0.42087092995643616,
               0.4068351984024048,
               -0.23917178809642792,
               -0.4940826892852783,
               -0.43261319398880005,
               -0.34765830636024475,
               0.39217159152030945,
               -0.06793602555990219,
               -0.1680760532617569,
               0.29298627376556396,
               0.008042626082897186,
               -0.3756818175315857,
               0.009642870165407658,
               0.031

In [63]:
network.edges

EdgeView([(0, 1), (1, 3), (1, 27), (2, 3), (3, 3), (3, 4), (3, 26), (4, 6), (5, 6), (6, 7), (7, 10), (8, 10), (9, 10), (10, 11), (10, 15), (12, 15), (13, 15), (14, 15), (15, 16), (15, 20), (16, 19), (17, 19), (18, 19), (20, 22), (21, 22), (21, 28), (22, 23), (23, 25), (24, 25)])