In [1]:
from src.recom_search.model.beam_node_reverse import ReverseNode
from transformers import AutoTokenizer, AutoModel

import flatten_lattice as fl
import torch
from bert_models import LinearLatticeBert, LinearPOSBert
from encoding_utils import *
import pickle
import toy_helper as thelp

import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoConfig
from latmask_bert_models import LatticeBertModel
import json


# Prefer GPU index 2, but fall back to the default CUDA device and then to CPU so the
# notebook still runs on machines with fewer than three GPUs (a bare 'cuda:2' fails there).
PREFERRED_CUDA_INDEX = 2
if torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
    device = torch.device(f'cuda:{PREFERRED_CUDA_INDEX}')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

from mask_utils import *
from encoding_utils import *


# BERT tokenizer: tokenizes candidate strings in prepare_dataset and decodes lattice token ids.
# Same checkpoint name as the config LinearPOSBertV1 builds its encoder from.
bert_tok = AutoTokenizer.from_pretrained("bert-base-cased")
# mBART-50 tokenizer: handed to thelp.create_toy_graph when the toy lattice is built.
mbart_tok = AutoTokenizer.from_pretrained("facebook/mbart-large-50-many-to-one-mmt")

2022-08-31 07:33:07.956023: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-08-31 07:33:07.956046: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
# Model Wrapper
class LinearPOSBertV1(nn.Module):
    """Linear POS probe on top of a frozen LatticeBertModel encoder.

    The encoder is built from the 'bert-base-cased' config; only the linear
    ``probe`` layer is meant to be trained (see ``parameters``).
    """

    def __init__(self, num_labels):
        super().__init__()
        self.bert = LatticeBertModel(AutoConfig.from_pretrained('bert-base-cased'))
        self.probe = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.to(device)

    def parameters(self, recurse=True):
        """Return only the probe's parameters, so optimizers never update BERT.

        Accepts ``recurse`` like ``nn.Module.parameters`` so generic callers
        that pass it keep working.
        """
        return self.probe.parameters(recurse)

    def forward(self, sentences, pos_ids=None, attmasks=None):
        """Run the frozen encoder and return per-token label scores from the probe.

        Args:
            sentences: batch of token ids.
            pos_ids: optional position ids; None uses the encoder's defaults.
            attmasks: optional mask, passed as both ``attention_mask`` and
                ``encoder_attention_mask``; None uses the encoder's defaults.
                It is now applied independently of ``pos_ids``. Previously it
                was silently ignored whenever ``pos_ids`` was None.

        Returns:
            ``self.probe`` applied to the encoder's token representations.
        """
        with torch.no_grad():  # no training of BERT parameters
            word_rep, _ = self.bert(
                sentences,
                position_ids=pos_ids,
                encoder_attention_mask=attmasks,
                attention_mask=attmasks,
                return_dict=False,
            )
        return self.probe(word_rep)
    
def prepare_dataset(resset, model=None, max_len=None):
    """Tokenize groups of candidate strings and label each group with the POS model.

    Every string is passed through ``clean_expanded`` and ``bert_tok``, then
    right-padded with token id 0 (the [PAD] id of bert-base-cased) or
    truncated to exactly ``max_len`` ids. The rows of one group are stacked
    into a LongTensor on ``device`` and run through ``model``.

    Args:
        resset: iterable of groups; each group is an iterable of candidate strings.
        model: callable mapping a batch of token ids to label scores; defaults to
            the global ``posbmodel`` (looked up at call time, since it is defined
            after this function).
        max_len: length to pad/truncate to; defaults to the global ``MAX_LEN``.

    Returns:
        ``(x, y)``: lists with one entry per group. ``x[k]`` is the stacked
        token-id tensor and ``y[k]`` is ``model(x[k])``.

    Raises:
        RuntimeError: if no candidate in a group could be tokenized.
    """
    if model is None:
        model = posbmodel
    if max_len is None:
        max_len = MAX_LEN
    x = []
    y = []
    for group_idx, res in enumerate(resset):
        curinps = []
        for r in res:
            try:
                toktmp = torch.tensor(bert_tok(clean_expanded(r)).input_ids)
            except Exception as err:
                # Best effort: skip a candidate that fails to clean/tokenize, but report why.
                print(f"weird error happened: {err!r} (candidate: {r!r})")
                continue
            if toktmp.shape[0] < max_len:
                # pad in the same integer dtype so the row never becomes a float tensor
                pad = torch.zeros(max_len - toktmp.shape[0], dtype=toktmp.dtype)
                toktmp = torch.cat([toktmp, pad])
            else:
                toktmp = toktmp[:max_len]
            curinps.append(toktmp)
        print(len(curinps))
        if not curinps:
            # torch.stack([]) would fail with an opaque message; say what actually happened
            raise RuntimeError(
                f"prepare_dataset: no candidate in group {group_idx} could be tokenized"
            )
        tinp = torch.stack(curinps).long().to(device)
        print(tinp.shape)
        y.append(model(tinp))
        x.append(tinp)
    return x, y

def check_accuracy(setpred, setlabels):
    """Token-level accuracy of predicted label scores against gold label rows.

    Tokens are skipped when their gold row is all zeros (no gold label) or
    when the gold argmax is label index 0. Presumably index 0 is a
    padding/special label; confirm against lab_vocab.json.

    Args:
        setpred: sequence of examples; each is a sequence of per-token score tensors.
        setlabels: gold rows in the same layout as ``setpred`` (indexed in parallel).

    Returns:
        Fraction of counted tokens whose predicted argmax equals the gold
        argmax. Returns ``float('nan')`` when no token is counted, instead of
        raising ZeroDivisionError.
    """
    cor = 0
    tot = 0
    for i, ex in enumerate(setpred):
        exlabels = setlabels[i]
        for j in range(len(ex)):
            gold = exlabels[j]
            if sum(gold) == 0:
                continue
            gold_idx = torch.argmax(gold)
            if gold_idx == 0:
                continue
            tot += 1
            if torch.argmax(ex[j]) == gold_idx:
                cor += 1
    if tot == 0:
        return float("nan")
    return cor / tot

# correct posids
def mod_posids(pids):
    """Return a copy of ``pids`` in which each 0 entry is replaced by its column index.

    Args:
        pids: 2-D tensor (or list of lists) of position ids, one row per sequence.

    Returns:
        A new tensor/list of the same kind. ``pids`` itself is left unchanged.
    """
    # Work on a real copy. Mutating the caller's `posids` in place would let one
    # ablation run (e.g. def_posids) leak into every later call on the same object.
    if isinstance(pids, torch.Tensor):
        cop = pids.clone()
    else:
        import copy
        cop = copy.deepcopy(pids)
    for p in cop:
        for i in range(len(p)):
            if p[i] == 0:
                p[i] = i
    return cop

# set posids to default
def def_posids(pids):
    """Return a copy of ``pids`` with every row reset to 0, 1, 2, ... (default positions).

    Args:
        pids: 2-D tensor (or list of lists) of position ids, one row per sequence.

    Returns:
        A new tensor/list of the same kind. ``pids`` itself is left unchanged.
    """
    # Copy first, so that using this for an ablation does not destroy the
    # caller's lattice position ids for subsequent runs.
    if isinstance(pids, torch.Tensor):
        cop = pids.clone()
    else:
        import copy
        cop = copy.deepcopy(pids)
    for p in cop:
        for i in range(len(p)):
            p[i] = i
    return cop

# Load the label vocabulary and the trained POS probe checkpoint.
with open('./a3distrib/lab_vocab.json') as json_file:
    # dict whose keys are the label names; their order is used as the probe's class order
    labels = json.load(json_file)
posbmodel = LinearPOSBertV1(len(labels))
# map_location lets a checkpoint saved on one device load onto whatever `device` was selected
# (including CPU-only machines). NOTE: torch.load unpickles, so only load trusted checkpoints.
tmp = torch.load("./a3distrib/ckpt/posbert.pth", map_location=device)
posbmodel.load_state_dict(tmp)
posbmodel.eval()
torch.cuda.empty_cache()
# Report GPU memory on the device actually in use rather than a hard-coded 'cuda:2'.
if device.type == 'cuda':
    print(torch.cuda.memory_allocated(device))

AttributeError: 'collections.OrderedDict' object has no attribute 'to'

In [None]:
# Two toy inputs that share the prefix "The Fed raises interest" but differ in the final word.
s1 = "The Fed raises interest rates"
s2 = "The Fed raises interest him"

# Build the toy lattice from the two strings, in the same data-structure format used for the real examples.
# NOTE(review): s2 is passed before s1 here; confirm this is the intended argument order.
toygraph = thelp.create_toy_graph(s2, s1, mbart_tok)

# "Explode" the lattice into its candidate strings, presumably with the same routine used for the reported numbers.
exploded = fl.get_all_possible_candidates(toygraph)

# Flatten the toy lattice into a single sequence (same method as for the real examples).
flat_toy = fl.flatten_lattice(toygraph)

# Connectivity mask for the flattened lattice (same method as for the real examples).
mask = connect_mat(flat_toy)
# Experiment: converting the 0 entries to -inf made no observable difference, so it stays disabled.
#mask[mask==0] = -float('inf')

In [None]:

# Tokenize each exploded candidate and label it with the POS probe; these per-candidate
# predictions serve as the reference ("gold") labels for the lattice.
dsetx, dsety = prepare_dataset([exploded])

# Exactly one group was passed in, so exactly one batch should come back.
assert len(dsetx)==1

# From encoding_utils: the token ids and position ids for the flattened lattice.
sents, posids = create_inputs([flat_toy])

# Gold-label dictionaries for the lattice tokens, based on averages of the per-token labels in dsety.
_ , tmaps = lattice_pos_goldlabels(dsetx, dsety, sents)

# Turn the token -> label maps into gold label rows, presumably aligned with the tokens in `sents`.
latposylabels = tmap_pos_goldlabels(tmaps, sents)

# Predicted labels for the flattened lattice.
# Full call: posbmodel(sents.to(device), mod_posids(posids).to(device), torch.stack([mask]).to(device)).
# Ablations: use def_posids instead of mod_posids for default position ids, or pass None for
# posids/mask to fall back to the model defaults.
pred = posbmodel(sents.to(device), mod_posids(posids).to(device), torch.stack([mask]).to(device))


In [None]:
# Token-level accuracy of the lattice predictions against the derived gold labels.
# Caveat: the "gold" labels are the POS probe's own outputs on each exploded candidate
# (via prepare_dataset), so they have not been independently verified.
check_accuracy(pred, latposylabels)

In [None]:
# Label names in vocabulary order; an argmax index into the probe output maps to lablist[index].
lablist = list(labels.keys())
# Number of leading positions shown when printing predicted vs. gold labels.
CUTOFF = 10

def show_labels(pred):
    """Return, for each row of class scores in ``pred``, the name of its highest-scoring label.

    Each row's argmax index is looked up in the global ``lablist``.
    """
    return [lablist[torch.argmax(row)] for row in pred]

# Sanity check: decode the flattened lattice's token ids back to text.
# NOTE(review): the toy graph was built with mbart_tok but is decoded with bert_tok here;
# confirm that fl.get_toklist returns BERT token ids.
p = flat_toy
tlist = fl.get_toklist(p)
decstr = bert_tok.decode(tlist)

# Number of lattice tokens passed to the model, followed by their decoded text.
print(len(tlist))
print(decstr)

# First CUTOFF predicted vs. gold labels for the single lattice example.
print("PREDICTED")
print(show_labels(pred[0])[:CUTOFF])
print("GOLD")
print(show_labels(latposylabels[0])[:CUTOFF])


In [None]:
# Re-run the POS probe on the individual exploded candidates. In eval mode this matches
# dsety[0], which prepare_dataset already computed from the same tensor.
indivlabs = posbmodel(dsetx[0])

# Show labels for s1, s2 when run through individually.
# NOTE(review): this assumes candidate 0 is s1 and candidate 1 is s2. create_toy_graph was called
# with (s2, s1), and prepare_dataset may drop candidates that fail to tokenize, so confirm the order.
# TODO: [:8] is a fixed display width; confirm it covers these sentences plus special tokens.
print(s1)
print(show_labels(indivlabs[0])[:8])
print(s2)
print(show_labels(indivlabs[1])[:8])