In [18]:
'''
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
'''

import os
import re
import torch
import subprocess
from pytorch_transformers import *
import random
from bs4 import BeautifulSoup
from nltk.corpus import wordnet as wn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from sklearn.feature_extraction.text import TfidfVectorizer

pos_converter = {'NOUN':'n', 'PROPN':'n', 'VERB':'v', 'AUX':'v', 'ADJ':'a', 'ADV':'r'}

def generate_key(lemma, pos):
    if pos in pos_converter.keys():
        pos = pos_converter[pos]
    key = '{}+{}'.format(lemma, pos)
    return key

def load_pretrained_model(name):
    if name == 'roberta-base':
        # model = RobertaModel.from_pretrained('roberta-base')
        model = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)
        hdim = 768
    elif name == 'roberta-large':
        # model = RobertaModel.from_pretrained('roberta-large')
        model = RobertaModel.from_pretrained('roberta-large', output_hidden_states=True)
        hdim = 1024
    elif name == 'xlmroberta-base':
        model = AutoModel.from_pretrained("xlm-roberta-base", output_hidden_states=True)
        hdim = 768
    elif name == 'xlmroberta-large':
        model = AutoModel.from_pretrained("xlm-roberta-large", output_hidden_states=True)
        hdim = 1024
    elif name == 'bert-large':
        model = BertModel.from_pretrained('bert-large-cased', output_hidden_states=True)
        hdim = 1024
    else: #bert base
        model = BertModel.from_pretrained('bert-base-cased', output_hidden_states=True)
        hdim = 768
    return model, hdim

def load_tokenizer(name):
    if name == 'roberta-base':
        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    elif name == 'roberta-large':
        tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
    elif name == 'xlmroberta-base':
        tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    elif name == 'xlmroberta-large':
        tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-large")
    elif name == 'bert-large':
        tokenizer = BertTokenizer.from_pretrained('bert-large-cased')
    else: #bert base
        tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    return tokenizer

def load_wn_senses(path):
    wn_senses = {}
    with open(path, 'r', encoding="utf8") as f:
        for line in f:
            line = line.strip().split('\t')
            lemma = line[0]
            pos = line[1]
            senses = line[2:]

            key = generate_key(lemma, pos)
            wn_senses[key] = senses
    return wn_senses

def get_label_space(data):
    #get set of labels from dataset
    labels = set()

    for sent in data:
        for _, _, _, _, label in sent:
            if label != -1:
                labels.add(label)

    labels = list(labels)
    labels.sort()
    labels.append('n/a')

    label_map = {}
    for sent in data:
        for _, lemma, pos, _, label in sent:
            if label != -1:
                key = generate_key(lemma, pos)
                label_idx = labels.index(label)
                if key not in label_map: label_map[key] = set()
                label_map[key].add(label_idx)

    return labels, label_map

def process_encoder_outputs(output, mask, as_tensor=False):
    combined_outputs = []
    position = -1
    avg_arr = []
    for idx, rep in zip(mask, torch.split(output, 1, dim=0)):
        #ignore unlabeled words
        if idx == -1: continue
        #average representations for units in same example
        elif position < idx:
            position=idx
            if len(avg_arr) > 0: combined_outputs.append(torch.mean(torch.stack(avg_arr, dim=-1), dim=-1))
            avg_arr = [rep]
        else:
            assert position == idx
            avg_arr.append(rep)
    #get last example from avg_arr
    if len(avg_arr) > 0: combined_outputs.append(torch.mean(torch.stack(avg_arr, dim=-1), dim=-1))
    if as_tensor: return torch.cat(combined_outputs, dim=0)
    else: return combined_outputs

#run WSD Evaluation Framework scorer within python
def evaluate_output(scorer_path, gold_filepath, out_filepath):
    eval_cmd = ['java','-cp', scorer_path, 'Scorer', gold_filepath, out_filepath]
    output = subprocess.Popen(eval_cmd, stdout=subprocess.PIPE ).communicate()[0]
    output = [x.decode("utf-8") for x in output.splitlines()]
    p, r, f1 = [float(output[i].split('=')[-1].strip()[:-1]) for i in range(3)]
    return p, r, f1

def get_adj_keys():
    key_list = []
    for synset in wn.all_synsets('a'):
        for lemma in synset.lemmas():
            key_list.extend([lemma.key()])
    return key_list

def load_data(datapath, name, train_sent=None):
    if 'wngt' in name:
        name, new_name = name.split('-')
    else:
        name, new_name = name, ''
    text_path = os.path.join(datapath, '{}.data.test.xml'.format(name))
    gold_path = os.path.join(datapath, '{}.gold.key.txt'.format(name))

    #load gold labels
    gold_labels = {}
    with open(gold_path, 'r', encoding="utf8") as f:
        for line in f:
            line = line.strip().split(' ')
            instance = line[0]
            #this means we are ignoring other senses if labeled with more than one
            #(happens at least in SemCor data)
            key = line[1]
            gold_labels[instance] = key

    #load train examples + annotate sense instances with gold labels
    sentences = []
    s = []
    with open(text_path, 'r', encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if line == '</sentence>':
                sentences.append(s)
                s=[]
                if 'semcor' in name and len(sentences) >= train_sent:
                    break

            elif line.startswith('<instance') or line.startswith('<wf'):
                word = re.search('>(.+?)<', line).group(1)
                # print(line)
                try:
                    lemma = re.search('lemma="(.+?)"', line).group(1)
                except AttributeError:
                    lemma = word.lower()
                pos = re.search('pos="(.+?)"', line).group(1)

                #clean up data
                word = re.sub('&apos;', '\'', word)
                lemma = re.sub('&apos;', '\'', lemma).lower()

                sense_inst = -1
                sense_label = -1
                if line.startswith('<instance'):
                    sense_inst = re.search('instance id="(.+?)"', line).group(1)
                    #annotate sense instance with gold label
                    sense_label = gold_labels.get(sense_inst)
                    sense_label = sense_label if sense_label else -1
                s.append((word, lemma, pos, sense_inst, sense_label))
    if new_name and 'semcor' in name:
        # sent_num = 0
        extra_path = os.path.join(datapath, '{}.xml'.format(new_name))
        wngt_corpus = open(extra_path, 'r').read()
        wsd_bs = BeautifulSoup(wngt_corpus, 'xml')
        text_all = wsd_bs.find_all('sentence')
        type2pos = {'j': 'ADJ', 'n': 'NOUN', 'r': 'ADV', 'v': 'VERB'}

        adj_keys = get_adj_keys()
        num = 0
        for sent in tqdm(text_all[:]):
            s = []
            for word in sent.find_all('word'):
                w = word['surface_form'].replace('_', ' ')
                lemma = word['lemma'] if 'lemma' in word.attrs else word['surface_form'].replace('_', ' ')
                pos = type2pos[word['pos'][0].lower()] if word['pos'][0].lower() in type2pos else word['pos']
                key = word['wn30_key'].split(';')[0] if 'wn30_key' in word.attrs else -1
                if key != -1 and key not in adj_keys and '%3:' in key:
                    pos_string = key.split('%')[1][0]
                    replace_string = '35'.replace(key.split('%')[1][0], '')
                    key = key.replace('%' + pos_string + ':', '%' + replace_string + ':')
                sense_inst = 'd0.s%d.t0' % num if key != -1 else -1
                s.append((w, lemma, pos, sense_inst, key))
            num += 1
            sentences.append(s)

    return sentences

#normalize ids list, masks to whatever the passed in length is
def normalize_length(ids, attn_mask, o_mask, max_len, pad_id):
    if max_len == -1:
        return ids, attn_mask, o_mask
    else:
        if len(ids) < max_len:
            while len(ids) < max_len:
                ids.append(torch.tensor([[pad_id]]))
                attn_mask.append(0)
                o_mask.append(-1)
        else:
            ids = ids[:max_len-1]+[ids[-1]]
            attn_mask = attn_mask[:max_len]
            o_mask = o_mask[:max_len]

        assert len(ids) == max_len
        assert len(attn_mask) == max_len
        assert len(o_mask) == max_len

        return ids, attn_mask, o_mask

#filters down training dataset to (up to) k examples per sense 
#for few-shot learning of the model
def filter_k_examples(data, k):
    #shuffle data so we don't only get examples for (common) senses from beginning
    random.shuffle(data)
    #track number of times sense from data is used
    sense_dict = {}
    #store filtered data
    filtered_data = []

    example_count = 0
    for sent in data:
        filtered_sent = []
        for form, lemma, pos, inst, sense in sent:
            #treat unlabeled words normally
            if sense == -1:
                x  = (form, lemma, pos, inst, sense)
            elif sense in sense_dict:
                if sense_dict[sense] < k:
                    #increment sense count and add example to filtered data
                    sense_dict[sense] += 1
                    x = (form, lemma, pos, inst, sense)
                    example_count += 1
                else: #if the data already has k examples of this sense
                    #add example with no instance or sense label to data
                    x = (form, lemma, pos, -1, -1)
            else:
                #add labeled example to filtered data and sense dict
                sense_dict[sense] = 1
                x = (form, lemma, pos, inst, sense)
                example_count += 1
            filtered_sent.append(x)
        filtered_data.append(filtered_sent)

    print("k={}, training on {} sense examples...".format(k, example_count))

    return filtered_data

def tokenize_glosses(encoder_name, gloss_arr, tokenizer, max_len):
    glosses = []
    masks = []
    for gloss_text in gloss_arr:
        if 'xlm' in encoder_name:
            g_ids = [torch.tensor([[x]]) for x in tokenizer.encode(gloss_text)]
        else:
            g_ids = [torch.tensor([[x]]) for x in
                 tokenizer.encode(tokenizer.cls_token) + tokenizer.encode(gloss_text) + tokenizer.encode(
                     tokenizer.sep_token)]
        g_attn_mask = [1]*len(g_ids)
        g_fake_mask = [-1]*len(g_ids)
        if 'xlm' in encoder_name:
            g_ids, g_attn_mask, _ = normalize_length(g_ids, g_attn_mask, g_fake_mask, max_len,
                                                     pad_id=tokenizer.encode(tokenizer.pad_token)[1])
        else:
            g_ids, g_attn_mask, _ = normalize_length(g_ids, g_attn_mask, g_fake_mask, max_len,
                                                 pad_id=tokenizer.encode(tokenizer.pad_token)[0])
        g_ids = torch.cat(g_ids, dim=-1)
        g_attn_mask = torch.tensor(g_attn_mask)
        glosses.append(g_ids)
        masks.append(g_attn_mask)

    return glosses, masks



def load_and_preprocess_glosses(data, tokenizer, wn_senses, max_len=-1):
    sense_glosses = {}

    for sent in data:
        for _, lemma, pos, _, label in sent:
            if label == -1:
                continue  # ignore unlabeled words
            else:
                key = generate_key(lemma, pos)
                if key not in sense_glosses:
                    # get all sensekeys for the lemma/pos pair
                    # get wordnet key
                    sensekey_arr = wn_senses[key]
                    if max_len <= 32:
                        gloss_arr = [wn.lemma_from_key(s).synset().definition() for s in sensekey_arr]
                        # print('gloss_arr:',gloss_arr)
                    else:
                        gloss_arr = [wn.lemma_from_key(s).synset().definition() + ' ' + '. '.join(
                         wn.lemma_from_key(s).synset().examples()) for s in sensekey_arr]

                    # preprocess glosses into tensors
                    gloss_ids, gloss_masks = tokenize_glosses('bert-base', gloss_arr, tokenizer, max_len)
                    gloss_ids = torch.cat(gloss_ids, dim=0)
                    gloss_masks = torch.stack(gloss_masks, dim=0)
                    sense_glosses[key] = (gloss_ids, gloss_masks, sensekey_arr)

                # make sure that gold label is retrieved synset
                assert label in sense_glosses[key][2]

    return sense_glosses

In [2]:
def preprocess_context(encoder_name, context_len, gloss_bsz, context_mode, tokenizer, text_data, gloss_dict=None, bsz=4, max_len=128):
    if max_len == -1: assert bsz==1 #otherwise need max_length for padding

    context_ids = []
    context_attn_masks = []

    example_keys = []

    context_output_masks = []
    instances = []
    labels = []

    #tensorize data
    # print(tokenizer.encode(tokenizer.cls_token), tokenizer.encode(tokenizer.sep_token))
    for sent in (text_data):
        #cls token aka sos token, returns a list with index
        if 'xlm' in encoder_name:
            c_ids = [torch.tensor([tokenizer.encode(tokenizer.cls_token)[1:-1]])]
        else:
            c_ids = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])]
        o_masks = [-1]
        sent_insts = []
        sent_keys = []
        sent_labels = []

        #For each word in sentence...
        key_len = []
        for idx, (word, lemma, pos, inst, label) in enumerate(sent):
            #tensorize word for context ids
            if 'xlm' in encoder_name:
                word_ids = [torch.tensor([[x]]) for x in tokenizer.encode(word.lower())[1:-1]]
            else:
                word_ids = [torch.tensor([[x]]) for x in tokenizer.encode(word.lower())]
            c_ids.extend(word_ids)

            #if word is labeled with WSD sense...
            if label != -1:
                #add word to bert output mask to be labeled
                o_masks.extend([idx]*len(word_ids))
                #track example instance id
                sent_insts.append(inst)
                #track example instance keys to get glosses
                ex_key = generate_key(lemma, pos)
                sent_keys.append(ex_key)
                key_len.append(len(gloss_dict[ex_key][2]))
                sent_labels.append(label)
            else:
                #mask out output of context encoder for WSD task (not labeled)
                o_masks.extend([-1]*len(word_ids))

            #break if we reach max len
            if max_len != -1 and len(c_ids) >= (max_len-1):
                break

        if 'xlm' in encoder_name:
            c_ids.append(torch.tensor([tokenizer.encode(tokenizer.sep_token)[1:-1]])) #aka eos token
        else:
            c_ids.append(torch.tensor([tokenizer.encode(tokenizer.sep_token)]))  # aka eos token
        c_attn_mask = [1]*len(c_ids)
        o_masks.append(-1)
        assert len(c_ids) == len(o_masks)

        #not including examples sentences with no annotated sense data
        if len(sent_insts) > 0:
            context_ids.append(c_ids)
            context_attn_masks.append(c_attn_mask)
            context_output_masks.append(o_masks)
            example_keys.append(sent_keys)
            instances.append(sent_insts)
            labels.append(sent_labels)

    #package data
    context_dict = dict()

    doc_id, doc_seg = [], []
    for index, x in enumerate(instances):
        inst = '.'.join(x[0].split('.')[:-2])
        if inst not in doc_id:
            doc_id.append(inst)
            doc_seg.append(index)
    doc_seg.append(len(instances))
    new_context, new_attn_mask, new_out_mask = [], [], []

    # 针对每个文档
    for seg_index, seg_id in enumerate((doc_seg[:-1])):
        ids_c = context_ids[seg_id: doc_seg[seg_index + 1]]
        attn_masks_c = context_attn_masks[seg_id: doc_seg[seg_index + 1]]
        output_masks_c = context_output_masks[seg_id: doc_seg[seg_index + 1]]
        example_keys_c = example_keys[seg_id: doc_seg[seg_index + 1]]
        instances_c = instances[seg_id: doc_seg[seg_index + 1]]
        valid_instance = [i for i in instances_c[0] if i != -1][0]
        sent_ids = ['.'.join(i[0].split('.')[:-1]) for i in instances_c]
        if len(valid_instance.split('.')[0]) > 2:
            # doc = [' '.join(examp) for examp in example_keys_c]
            doc = [' '.join([i.split('+')[0] for i in examp if i.split('+')[1] in 'nvar']) for examp in example_keys_c]
            vectorizer = TfidfVectorizer()
            doc_mat = vectorizer.fit_transform(doc).toarray()
            for sent_id, vec in enumerate(doc_mat):
                scores = doc_mat[:, doc_mat[sent_id].nonzero()[0]].sum(1)
                
                #context_len控制了tf-idf筛选的数量
                id_score = [j for j in
                            sorted(zip([i for i in range(len(doc_mat))], scores), key=lambda x: x[1], reverse=True) if
                            j[0] != sent_id][:context_len]
                selected = [i[0] for i in id_score]
                
                #context_len控制了上下文窗口大小
                window_id = [i for i in range(len(doc_mat))][
                            max(sent_id - context_len, 0):sent_id + context_len + 1]
                
                pure_neighbor = [i for i in window_id if i != sent_id]
                #
                if context_mode == 'all':
                    ids = sorted(set(selected + [sent_id] + pure_neighbor))
                    # ids = selected + [sent_id] + pure_neighbor
                elif context_mode == 'nonselect':
                    ids = sorted(set([sent_id] + pure_neighbor))
                    # ids = [sent_id] + pure_neighbor
                elif context_mode == 'nonwindow':
                    ids = sorted(set(selected + [sent_id]))
                    # ids = selected + [sent_id]
                else:
                    ids = [sent_id]

                # total_len = len(sum([ids_c[i]for i in ids], []))
                # while total_len > 512:
                #     distance_index = sorted([(abs(s_id-sent_id), s_id) for s_id in ids], reverse=True)
                #     ids.remove(distance_index[0][1])
                #     total_len = len(sum([ids_c[i] for i in ids], []))
                    
                if context_len > 0:
                    new_context.append(sum([ids_c[i]for i in ids], []))
                    new_attn_mask.append(sum([attn_masks_c[i] for i in ids], []))
                    new_out_mask.append(
                        sum([[-1] * len(output_masks_c[i]) if i != sent_id else output_masks_c[i] for i in ids], []))
                    assert len(new_context[-1]) == len(new_attn_mask[-1]) == len(new_out_mask[-1])
                else:
                    new_context.append(ids_c[sent_id])
                    new_attn_mask.append(attn_masks_c[sent_id])
                    new_out_mask.append(output_masks_c[sent_id])
                context_dict[sent_ids[sent_id]] = [sent_ids[i] for i in ids]
        else:
            new_context.extend(ids_c)
            new_attn_mask.extend(attn_masks_c)
            new_out_mask.extend(output_masks_c)

            for sent_id in sent_ids:
                context_dict[sent_id] = [sent_id]

    assert len(context_ids) == len(new_context)

    data = [list(i) for i in
            list(zip(new_context, new_attn_mask, new_out_mask, example_keys, instances, labels))]

    # print('Batching data with gloss length = {}...'.format(args.gloss_bsz))
    batched_data = []
    sent_index, current_list = [0], []
    sent_senses = [sum([len(gloss_dict[ex_key][2]) for ex_key in sent[3]]) for sent in data]
    for index, i in enumerate(sent_senses):
        current_list.append(i)
        if sum(current_list) > gloss_bsz:
            sent_index.append(index)
            current_list = current_list[-1:]
    sent_index.append(len(sent_senses))

    for index, data_index in enumerate(sent_index[:-1]):
        b = data[data_index: sent_index[index + 1]]
        max_len_b = max([len(x[1]) for x in b])
        if context_len > 0:
            max_len = max(max_len_b, max_len)
        for b_index, sent in enumerate(b):
            if 'xlm' in encoder_name:
                b[b_index][0], b[b_index][1], b[b_index][2] = normalize_length(sent[0], sent[1], sent[2], max_len,
                                                                           tokenizer.encode(tokenizer.pad_token)[1])
            else:
                b[b_index][0], b[b_index][1], b[b_index][2] = normalize_length(sent[0], sent[1], sent[2], max_len,
                                                                           tokenizer.encode(tokenizer.pad_token)[0])

        context_ids = torch.cat([torch.cat(x, dim=-1) for x, _, _, _, _, _ in b], dim=0)[:, :max_len_b]
        context_attn_mask = torch.cat([torch.tensor(x).unsqueeze(dim=0) for _, x, _, _, _, _ in b], dim=0)[:,
                            :max_len_b]
        context_output_mask = torch.cat([torch.tensor(x).unsqueeze(dim=0) for _, _, x, _, _, _ in b], dim=0)[:,
                              :max_len_b]
        example_keys = []
        for _, _, _, x, _, _ in b: example_keys.extend(x)
        instances = []
        for _, _, _, _, x, _ in b: instances.extend(x)
        labels = []
        for _, _, _, _, _, x in b: labels.extend(x)
        batched_data.append(
            (context_ids, context_attn_mask, context_output_mask, example_keys, instances, labels))
    # context_dict包含每个句子的相关句集合，窗口大小为2，如句子7的相关句为[5,6,7,8,9]，同时加上tf-idf最大的top2
    # batch是根据gloss_bsz划分的，每个batch中所有单词的义项总数不超过gloss_bsz
    return batched_data, context_dict

In [3]:
'''
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
'''

import torch
import torch.nn as nn
from torch.nn import functional as F
import os

def mask_logits(target, mask, logit=-1e30):
    return target * mask + (1 - mask) * (logit)

def load_projection(path):
    proj_path = os.path.join(path, 'best_probe.ckpt')
    with open(proj_path, 'rb') as f: proj_layer = torch.load(f)
    return proj_layer

class PretrainedClassifier(torch.nn.Module):
    def __init__(self, num_labels, encoder_name, proj_ckpt_path):
        super(PretrainedClassifier, self).__init__()

        self.encoder, self.encoder_hdim = load_pretrained_model(encoder_name)

        if proj_ckpt_path and len(proj_ckpt_path) > 0:
            self.proj_layer = load_projection(proj_ckpt_path)
            #assert to make sure correct dims
            assert self.proj_layer.in_features == self.encoder_hdim
            assert self.proj_layer.out_features == num_labels
        else:
            self.proj_layer = torch.nn.Linear(self.encoder_hdim, num_labels)

    def forward(self, input_ids, input_mask, example_mask):
        output = self.encoder(input_ids, attention_mask=input_mask)[0]

        example_arr = []
        for i in range(output.size(0)):
            example_arr.append(process_encoder_outputs(output[i], example_mask[i], as_tensor=True))
        output = torch.cat(example_arr, dim=0)
        output = self.proj_layer(output)
        return output

class GlossEncoder(torch.nn.Module):
    def __init__(self, encoder_name, freeze_gloss, tied_encoder=None):
        super(GlossEncoder, self).__init__()

        #load pretrained model as base for context encoder and gloss encoder
        if tied_encoder:
            self.gloss_encoder = tied_encoder
            _, self.gloss_hdim = load_pretrained_model(encoder_name)
        else:
            self.gloss_encoder, self.gloss_hdim = load_pretrained_model(encoder_name)
        self.is_frozen = freeze_gloss

    def forward(self, input_ids, attn_mask):
        #encode gloss text
        if self.is_frozen:
            with torch.no_grad():
                gloss_output = self.gloss_encoder(input_ids, attention_mask=attn_mask)[0]
        else:
            gloss_output = self.gloss_encoder(input_ids, attention_mask=attn_mask)[-1][-4:]

        gloss_output = torch.cat([i.unsqueeze(0) for i in gloss_output], dim=0).mean(0)

        #training model to put all sense information on CLS token
        gloss_output = gloss_output[:,:,:].squeeze(dim=1)
        return gloss_output

class ContextEncoder(torch.nn.Module):
    def __init__(self, encoder_name, freeze_context):
        super(ContextEncoder, self).__init__()

        #load pretrained model as base for context encoder and gloss encoder
        self.context_encoder, self.context_hdim = load_pretrained_model(encoder_name)
        self.is_frozen = freeze_context

    def forward(self, input_ids, attn_mask, output_mask):
        #encode context
        if self.is_frozen:
            with torch.no_grad():
                context_output = self.context_encoder(input_ids, attention_mask=attn_mask)[0]
        else:
            context_output = self.context_encoder(input_ids, attention_mask=attn_mask)[-1][-4:]
        context_output = torch.cat([i.unsqueeze(0) for i in context_output], dim=0).mean(0)
        print('context_output1:', context_output.shape)
        #average representations over target word(s)
        example_arr = []
        for i in range(context_output.size(0)):
            example_arr.append(process_encoder_outputs(context_output[i], output_mask[i], as_tensor=True))
        context_output = torch.cat(example_arr, dim=0)
        print('context_output2:', context_output.shape)
        return context_output

class LinearAttention(nn.Module):
    def __init__(self, in_dim=300, mem_dim=300):
        # in dim, the dimension of query vector
        super().__init__()
        self.linear = nn.Linear(in_dim, mem_dim)
        self.fc = nn.Linear(in_dim, in_dim)
        self.leakyrelu = nn.LeakyReLU(1e-2)
        self.linear1 = nn.Linear(in_dim, mem_dim)
        self.linear2 = nn.Linear(in_dim, mem_dim)
        torch.nn.init.xavier_normal_(self.linear.weight.data)
        torch.nn.init.xavier_normal_(self.linear1.weight.data)
        torch.nn.init.xavier_normal_(self.linear2.weight.data)

    def forward(self, feature, aspect_v, dmask, word='word'):
        Q = self.linear(aspect_v.float())
        Q = nn.functional.normalize(Q, dim=1)

        attention_s = torch.mm(Q, Q.T)
        attention_sk = mask_logits(attention_s, dmask, 0)
        # print('attention_sk:', attention_sk.shape)

        if 'word' in word:
            new_feature = self.linear(feature.float())
            new_feature = nn.functional.normalize(new_feature, dim=2)

            feature_reshape = new_feature.reshape(new_feature.shape[0] * new_feature.shape[1], -1)
            # print('feature_reshape:', feature_reshape.shape)
            attention_ww = torch.mm(feature_reshape, feature_reshape.T)
            attention_w = torch.stack(
                torch.stack(attention_ww.split(new_feature.shape[1]), dim=0).mean(1).squeeze(1).split(new_feature.shape[1],
                                                                                                      dim=1), dim=1).mean(2)
            attention_wk = mask_logits(attention_w, dmask, 0)
            # print('attention_wk:', attention_wk.shape)
            
            att_weight = attention_sk + attention_wk
            # print('att_weight:', att_weight.shape)
        else:
            att_weight = attention_sk

        att_weight[att_weight == 0] = -1e30
        attention = F.softmax(att_weight, dim=1)

        new_out = torch.mm(attention.float(), aspect_v.float())
        # print('new_out:', new_out.shape)
        return new_out

class BiEncoderModel(torch.nn.Module):
    def __init__(self, encoder_name, freeze_gloss=False, freeze_context=False, tie_encoders=False, num_heads=6):
        super(BiEncoderModel, self).__init__()

        #tying encoders for ablation
        self.tie_encoders = tie_encoders

        #load pretrained model as base for context encoder and gloss encoder
        self.context_encoder = ContextEncoder(encoder_name, freeze_context)
        if self.tie_encoders:
            self.gloss_encoder = GlossEncoder(encoder_name, freeze_gloss, tied_encoder=self.context_encoder.context_encoder)
        else:
            self.gloss_encoder = GlossEncoder(encoder_name, freeze_gloss)
        assert self.context_encoder.context_hdim == self.gloss_encoder.gloss_hdim
        self.gat = [LinearAttention(self.gloss_encoder.gloss_hdim, self.gloss_encoder.gloss_hdim).cuda() for _ in
                    range(num_heads)]

    def context_forward(self, context_input, context_input_mask, context_example_mask):
        return self.context_encoder.forward(context_input, context_input_mask, context_example_mask)

    def gloss_forward(self, gloss_input, gloss_mask):
        return self.gloss_encoder.forward(gloss_input, gloss_mask)

    def gat_forward(self, gloss_input, gloss_mask, key_len_list, instances, pre_index, context_dict, senses=''):
        gloss_out_all = self.gloss_encoder.forward(gloss_input, gloss_mask)
        print('gloss_out_all:', gloss_out_all.shape)
        if 'sense' in 'sense-pred':
            key_len = sum(key_len_list, [])
            # print('key_len:', len(key_len), key_len)
            adjacency_mat = torch.zeros(sum(key_len), sum(key_len))
            # print('adjacency_mat:', adjacency_mat.shape, adjacency_mat)
            sense_index = [sum(key_len[:i]) for i in range(len(key_len))]
            # print('sense_index:', sense_index)
            if 'pred' in 'sense-pred':
                p_index = [pre_index.get(inst, 0) for inst in instances]
                sense_index = [sense_index[i] + p_index[i] for i in range(len(p_index))]
                # print('p_index:', p_index)
                # print('sense_index_new:', sense_index)
            doc_sent = [('.'.join(i.split('.')[:-2]), int(i.split('.')[-2][1:]), '.'.join(i.split('.')[:-1])) for i in
                        instances]
            # print('doc_sent:', doc_sent)
            adjacency_mat[:, sense_index] = 1
            # print('adjacency_mat:', adjacency_mat.shape, adjacency_mat)
            for i in range(len(instances)):
                index = []
                for s_index, sense in enumerate(sense_index):
                    if True:
                        if doc_sent[s_index][-1] not in context_dict[doc_sent[i][-1]]:
                            index.extend([i for i in range(sum(key_len[:s_index]), sum(key_len[:s_index + 1]))])
                    else:
                        if len(doc_sent[s_index][0]) > 2:
                            if doc_sent[s_index][0] != doc_sent[i][0] or abs(doc_sent[s_index][1] - doc_sent[i][1]) > 0:
                                index.extend([i for i in range(sum(key_len[:s_index]), sum(key_len[:s_index + 1]))])
                        elif abs(doc_sent[s_index][1] - doc_sent[i][1]) > 0:
                            index.extend([i for i in range(sum(key_len[:s_index]), sum(key_len[:s_index + 1]))])
                print(index)
                adjacency_mat[sum(key_len[:i]): sum(key_len[:i + 1]), index] = 0

            for k_index, j in enumerate(key_len):
                start, end = sum(key_len[:k_index]), sum(key_len[:k_index + 1])
                adjacency_mat[start: end, start: end] = 0
            # print('!!!:',torch.nonzero(adjacency_mat[0]))

            adjacency_mat_f = adjacency_mat + torch.eye(sum(key_len))
            
            # print('!!!:',torch.nonzero(adjacency_mat_f[0]))

            att_out = [att.forward(gloss_out_all[:, 1:-1, :], gloss_out_all[:, 0, :], adjacency_mat_f.cuda(),
                                   'word').unsqueeze(1) for att in self.gat]
            att_out = torch.cat(att_out, dim=1)
            # print('att_out1:', att_out.shape)
            att_out = att_out.mean(dim=1)  # (N, D)min(31, gloss_out_all.shape[1]-1)
            print('att_out:', att_out.shape)
            assert len(gloss_out_all) == len(att_out)
            return att_out
        else:
            return gloss_out_all[:, 0, :]

In [4]:
wn_path = os.path.join('./WSD_Evaluation_Framework/', 'Data_Validation/candidatesWN30.txt')
# 获取一个字典，key是lemma+pos，value是该词的所有senses，senses表示方法为wordnet_key
wn_senses =  load_wn_senses(wn_path)
wn_senses["pen+n"]

['pen%1:06:00::',
 'pen%1:06:01::',
 'pen%1:06:03::',
 'pen%1:06:02::',
 'pen%1:05:00::']

In [5]:
len(wn_senses)

155287

In [19]:
data = load_data('./WSD_Evaluation_Framework/Training_Corpora/SemCor', 'semcor', 1000000)

In [7]:
encoder_name = 'bert-base'
tokenizer = load_tokenizer(encoder_name)

In [8]:
model = BiEncoderModel('bert-base').cuda()

In [9]:
gloss_dict = load_and_preprocess_glosses(data, tokenizer, wn_senses, 32)

In [21]:
data[0]

[('How', 'how', 'ADV', -1, -1),
 ('long', 'long', 'ADJ', 'd000.s000.t000', 'long%3:00:02::'),
 ('has', 'have', 'VERB', -1, -1),
 ('it', 'it', 'PRON', -1, -1),
 ('been', 'be', 'VERB', 'd000.s000.t001', 'be%2:42:03::'),
 ('since', 'since', 'ADP', -1, -1),
 ('you', 'you', 'PRON', -1, -1),
 ('reviewed', 'review', 'VERB', 'd000.s000.t002', 'review%2:31:00::'),
 ('the', 'the', 'DET', -1, -1),
 ('objectives', 'objective', 'NOUN', 'd000.s000.t003', 'objective%1:09:00::'),
 ('of', 'of', 'ADP', -1, -1),
 ('your', 'you', 'PRON', -1, -1),
 ('benefit', 'benefit', 'NOUN', 'd000.s000.t004', 'benefit%1:21:00::'),
 ('and', 'and', 'CONJ', -1, -1),
 ('service', 'service', 'NOUN', 'd000.s000.t005', 'service%1:04:07::'),
 ('program', 'program', 'NOUN', 'd000.s000.t006', 'program%1:09:01::'),
 ('?', '?', '.', -1, -1)]

In [148]:
a, b = preprocess_context(encoder_name, 4, 400, 'all', tokenizer, data, gloss_dict, bsz=4, max_len=128)

In [146]:
len(b['d000.s000']),len(b['d000.s001']),len(b['d000.s002']),len(b['d000.s003'])

(7, 8, 9, 8)

In [149]:
b['d000.s000']

['d000.s000',
 'd000.s001',
 'd000.s002',
 'd000.s003',
 'd000.s004',
 'd000.s034',
 'd000.s064',
 'd000.s065',
 'd000.s072']

In [122]:
a[1][0].shape

torch.Size([8, 512])

In [26]:
context_ids, context_attn_mask, context_output_mask, example_keys, instances, labels = a[0]

In [114]:
tokenized_tensor = context_ids[0].tolist()
tokenized_tensor1 = context_ids[1].tolist()
# 使用BertTokenizer将tokenized的tensor还原为token
restored_tokens = tokenizer.convert_ids_to_tokens(tokenized_tensor)
restored_tokens1 = tokenizer.convert_ids_to_tokens(tokenized_tensor1)
# 将token还原为原始文本
restored_text = tokenizer.convert_tokens_to_string(restored_tokens)
restored_text1 = tokenizer.convert_tokens_to_string(restored_tokens1)

# 打印恢复的原文
print("恢复的原文:", restored_text)
print("恢复的原文:", restored_text1)

恢复的原文: [CLS] how long has it been since you reviewed the objectives of your benefit and service program ? [SEP] [CLS] have you permitted it to become a giveaway program rather than one that has the goal of improved employee morale and , consequently , increased productivity ? [SEP] [CLS] what effort do you make to assess results of your program ? [SEP] [CLS] do you measure its relation to reduced absenteeism , turnover , accidents , and grievances , and to improved quality and output ? [SEP] [CLS] have you set specific objectives for your employee publication ? [SEP] [CLS] are your expenses in this area commensurate with the number of employees who benefit from your program ? [SEP] [CLS] if you have an annual or regular physical examination program , is it worth what it is costing you ? [SEP] [CLS] consider what you can afford to spend and what your goals are before setting up or revamping your employee benefit program . [SEP] [CLS] these factors can make the difference between waste a

In [118]:
context_ids[1],context_attn_mask[1], context_output_mask[1]

(tensor([  101,  1293,  1263,  1144,  1122,  1151,  1290,  1128,  7815,  1103,
         11350,  1104,  1240,  5257,  1105,  1555,  1788,   136,   102,   101,
          1138,  1128,  7485,  1122,  1106,  1561,   170,  1660,  7138,  1788,
          1897,  1190,  1141,  1115,  1144,  1103,  2273,  1104,  4725,  7775,
         22407,  1105,   117, 14007,   117,  2569, 18222,   136,   102,   101,
          1184,  3098,  1202,  1128,  1294,  1106, 15187,  2686,  1104,  1240,
          1788,   136,   102,   101,  1202,  1128,  4929,  1157,  6796,  1106,
          3549, 10040,  3051,  1863,   117, 23804,   117, 14705,   117,  1105,
           176,  5997, 24043,  1116,   117,  1105,  1106,  4725,  3068,  1105,
          5964,   136,   102,   101,  1138,  1128,  1383,  2747, 11350,  1111,
          1240,  7775,  4128,   136,   102,   101,  1110,  1122,  3634,  1292,
          2513,   136,   102,   101,  1132,  1240, 11928,  1107,  1142,  1298,
          3254,  2354,  6385,  5498,  1114,  1103,  

In [115]:
b

{'d000.s000': ['d000.s000',
  'd000.s001',
  'd000.s002',
  'd000.s003',
  'd000.s004',
  'd000.s034',
  'd000.s064',
  'd000.s065',
  'd000.s072'],
 'd000.s001': ['d000.s000',
  'd000.s001',
  'd000.s002',
  'd000.s003',
  'd000.s004',
  'd000.s005',
  'd000.s034',
  'd000.s064',
  'd000.s065'],
 'd000.s002': ['d000.s000',
  'd000.s001',
  'd000.s002',
  'd000.s003',
  'd000.s004',
  'd000.s005',
  'd000.s006',
  'd000.s034',
  'd000.s042',
  'd000.s056',
  'd000.s072'],
 'd000.s003': ['d000.s000',
  'd000.s001',
  'd000.s002',
  'd000.s003',
  'd000.s004',
  'd000.s005',
  'd000.s006',
  'd000.s007',
  'd000.s126'],
 'd000.s004': ['d000.s000',
  'd000.s001',
  'd000.s002',
  'd000.s003',
  'd000.s004',
  'd000.s005',
  'd000.s006',
  'd000.s007',
  'd000.s008',
  'd000.s011',
  'd000.s014',
  'd000.s080'],
 'd000.s005': ['d000.s000',
  'd000.s001',
  'd000.s002',
  'd000.s003',
  'd000.s004',
  'd000.s005',
  'd000.s006',
  'd000.s007',
  'd000.s008',
  'd000.s009',
  'd000.s065'],
 

In [None]:
len(context_ids[2]), context_ids[2]

In [None]:
# tokenize时，一些词可能会产生多个token，如tokenizer.encode('absenteeism'.lower()) == [10040, 3051, 1863]
context_output_mask[3],len(context_output_mask[3])

In [38]:
labels[3], example_keys[3], gloss_dict[example_keys[3]][2]

('objective%1:09:00::',
 'objective+n',
 ['objective%1:09:00::', 'objective%1:06:00::'])

In [39]:
sent_id, sent_seg = [], []
key_len_list = []
for in_index, inst in enumerate(instances):
    s_id = '.'.join(inst.split('.')[:-1])
    if s_id not in sent_id:
        sent_id.append(s_id)
        sent_seg.append(in_index)
sent_seg.append(len(instances))
print(sent_seg)
# gloss_dict: [0] definition_tokenize, [1] definition_mask, [2] sense_keys
for seg_index, seg in enumerate(sent_seg[:-1]):
    key_len_list.append([len(gloss_dict[key][2]) for key in example_keys[seg:sent_seg[seg_index + 1]]])
print(key_len_list)
total_sense = sum(sum(key_len_list, []))
print(total_sense)

[0, 7, 20, 25, 34, 39, 41, 46, 52, 58, 67]
[[9, 13, 5, 2, 3, 15, 8], [3, 4, 3, 8, 4, 19, 4, 3, 1, 2, 2, 1, 2], [4, 49, 4, 4, 8], [4, 2, 1, 4, 2, 3, 3, 5, 5], [25, 4, 2, 1, 4], [9, 4], [13, 1, 1, 4, 3], [6, 3, 3, 4, 2, 7], [13, 4, 4, 3, 2, 3], [4, 1, 3, 3, 4, 9, 4, 7, 4]]
373


In [48]:
train_index = {}
key_mat = dict()
loss = 0.
gloss_sz = 0
context_sz = 0
model.eval()
with torch.no_grad():
    context_ids = context_ids.cuda()
    context_attn_mask = context_attn_mask.cuda()
    # 输出一个batch中，所有多义词对应的embeddings
    context_output = model.context_forward(context_ids, context_attn_mask, context_output_mask)
    print('context_output:',context_output.shape)
    max_len_gloss = max(
            sum([[torch.sum(mask_list).item() for mask_list in gloss_dict[key][1]] for key in example_keys],
                []))
    gloss_ids_all = torch.cat([gloss_dict[key][0][:, :max_len_gloss] for key in example_keys])
    gloss_attn_mask_all = torch.cat([gloss_dict[key][1][:, :max_len_gloss] for key in example_keys])
    print(gloss_ids_all.shape)
    print(gloss_attn_mask_all.shape)
    gloss_ids = gloss_ids_all.cuda()
    gloss_attn_mask = gloss_attn_mask_all.cuda()
    gat_out_all = model.gat_forward(gloss_ids, gloss_attn_mask, key_len_list, instances, train_index, b)
    
    for seg_index, seg in enumerate(sent_seg[:-1]):
        current_example_keys = example_keys[seg: sent_seg[seg_index + 1]]
        print('current_example_keys:',current_example_keys)
        current_key_len = key_len_list[seg_index]
        print('current_key_len:',current_key_len)
        current_context_output = context_output[seg: sent_seg[seg_index + 1], :]
        print('current_context_output:',current_context_output.shape)
        current_insts = instances[seg: sent_seg[seg_index + 1]]
        print('current_insts:',current_insts)
        current_labels = labels[seg: sent_seg[seg_index + 1]]
        print('current_labels:',current_labels)
        gat_out = gat_out_all[
                        sum(sum(key_len_list[:seg_index], [])): sum(sum(key_len_list[:seg_index + 1], [])),
                        :]
        print('gat_out:',gat_out.shape) # 55x768
        c_senses = sum([gloss_dict[key][2] for key in current_example_keys], [])
        print('c_senses:',c_senses)
        gat_cpu = gat_out.cpu()
        for k_index, key in enumerate(c_senses):
            key_mat[key] = gat_cpu[k_index:k_index + 1]
        
        gloss_output_pad = torch.cat([F.pad(
                gat_out[sum(current_key_len[:i]): sum(current_key_len[:i+1]), :],
                pad=[0, 0, 0, max(current_key_len) - j]).unsqueeze(0) for i, j in enumerate(current_key_len)], dim=0)
        print('gloss_output_pad:',gloss_output_pad.shape) # gloss_output_pad: torch.Size([7, 15, 768])
        out = torch.bmm(gloss_output_pad, current_context_output.unsqueeze(2)).squeeze(2)
        print('out:',out.shape)
        gloss_sz += gat_out.size(0)
        context_sz += 1
        
        # for j, (key, label) in enumerate(zip(current_example_keys, current_labels)):
        #     idx = gloss_dict[key][2].index(label)
        #     label_tensor = torch.tensor([idx]).cuda()
        #     train_index[current_insts[j]] = out[j:j + 1, :current_key_len[j]].argmax(dim=1).item()
        #     loss += F.cross_entropy(out[j:j + 1, :current_key_len[j]], label_tensor)
        #     all_instance += 1
        #     if out[j:j + 1, :current_key_len[j]].argmax(dim=1).item() == idx: 
        #         pre_instance += 1
        #     if idx == 0:
        #         mfs_instance += 1
        
        break

context_output1: torch.Size([10, 260, 768])
context_output2: torch.Size([67, 768])
context_output: torch.Size([67, 768])
torch.Size([373, 32])
torch.Size([373, 32])
gloss_out_all: torch.Size([373, 32, 768])
[245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372]
[245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 2

In [None]:
b

In [51]:
context_ids

tensor([[ 101, 1293, 1263,  ...,    0,    0,    0],
        [ 101, 1293, 1263,  ...,    0,    0,    0],
        [ 101, 1293, 1263,  ...,    0,    0,    0],
        ...,
        [ 101, 1202, 1128,  ...,    0,    0,    0],
        [ 101, 1138, 1128,  ..., 5913,  119,  102],
        [ 101, 1110, 1122,  ...,    0,    0,    0]], device='cuda:0')

In [50]:
def create_hypergraph_matrix(context_ids, context_attn_mask):
    # 找出所有不重复的单词
    unique_tokens = torch.unique(context_ids)
    
    # 移除101, 102和0
    unique_tokens = unique_tokens[~torch.isin(unique_tokens, torch.tensor([0, 101, 102]))]
    
    # 创建一个映射，将unique token映射到它的位置
    token_to_idx = {token.item(): idx for idx, token in enumerate(unique_tokens)}
    
    num_sentences, _ = context_ids.shape
    num_unique_tokens = len(unique_tokens)
    
    # 初始化超图关联矩阵为0
    H = torch.zeros(num_sentences, num_unique_tokens)
    
    for i in range(num_sentences):
        # 获取句子的有效tokens（由attention mask确定）
        valid_tokens = context_ids[i, context_attn_mask[i].bool()]
        
        # 对于每个有效token，标记其在关联矩阵中的位置为1（只有当token在unique_tokens中时）
        for token in valid_tokens:
            if token.item() in token_to_idx:  # 这里增加了一个条件判断
                H[i, token_to_idx[token.item()]] = 1
    
    return H,token_to_idx

# # 示例
# context_ids = torch.randint(0, 30000, (10, 260))  # 随机的token ids
# context_attn_mask = torch.ones((10, 260))  # 假设所有位置都是有效的

H,token_to_idx = create_hypergraph_matrix(context_ids.cpu(), context_attn_mask.cpu())
print(H.shape)

torch.Size([10, 303])


In [55]:
token_to_idx

{112: 0,
 117: 1,
 118: 2,
 119: 3,
 136: 4,
 169: 5,
 170: 6,
 175: 7,
 176: 8,
 188: 9,
 1103: 10,
 1104: 11,
 1105: 12,
 1106: 13,
 1107: 14,
 1108: 15,
 1110: 16,
 1111: 17,
 1112: 18,
 1113: 19,
 1114: 20,
 1115: 21,
 1116: 22,
 1118: 23,
 1121: 24,
 1122: 25,
 1126: 26,
 1128: 27,
 1129: 28,
 1132: 29,
 1133: 30,
 1134: 31,
 1137: 32,
 1138: 33,
 1141: 34,
 1142: 35,
 1144: 36,
 1146: 37,
 1147: 38,
 1149: 39,
 1150: 40,
 1151: 41,
 1152: 42,
 1155: 43,
 1156: 44,
 1157: 45,
 1158: 46,
 1159: 47,
 1165: 48,
 1168: 49,
 1169: 50,
 1170: 51,
 1172: 52,
 1175: 53,
 1180: 54,
 1183: 55,
 1184: 56,
 1185: 57,
 1190: 58,
 1191: 59,
 1194: 60,
 1196: 61,
 1202: 62,
 1206: 63,
 1207: 64,
 1209: 65,
 1211: 66,
 1216: 67,
 1218: 68,
 1231: 69,
 1240: 70,
 1250: 71,
 1251: 72,
 1256: 73,
 1260: 74,
 1263: 75,
 1269: 76,
 1277: 77,
 1290: 78,
 1292: 79,
 1293: 80,
 1294: 81,
 1295: 82,
 1298: 83,
 1315: 84,
 1321: 85,
 1329: 86,
 1343: 87,
 1383: 88,
 1385: 89,
 1388: 90,
 1389: 91,
 1396: 9