In [None]:
import flatten_lattice as fl
import torch
from bert_models import LinearPOSBertV1
from encoding_utils import *
from transformers import AutoTokenizer
import pickle
from mask_utils import *
import json
import os
import numpy as np
from distill_comet import XLMCometEmbeds, XLMCometRegressor

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
xlm_tok = fl.bert_tok


In [None]:
#V12 first attempt with inputs prepended
#V14 final pipeline first attempt
VNUM = 14
MOD_NAME = 'bertonewayv1.pth'

# specifies files for pre-loading
LOADED = {
    'amasks': 'attmasksallv'+str(VNUM)+'.pt',
    'tmaps': 'tmapsmaskedv'+str(VNUM)+'/'
}

In [None]:
def prepend_input(pgraph, inp):
    
    inptoks = xlm_tok(inp).input_ids
    # add in the <s> token
    # we don't need the </s> token by how the mask works
    inptoks.append(0)
    posadd = len(inptoks)
    inpflat = []
    ind = 0
    for i in range(len(inptoks)):
        nl = []
        inp = inptoks[i]
        if i<(len(inptoks)-1):
            nl.append(str(inptoks[i+1])+" "+str(ind+1))
        inpflat.append({
            'token_idx':inp, 
            'pos':ind,
            'id': str(inp)+" "+str(ind),
            'nexts':nl,
            'score':0,
        })
        ind+=1
    inpflat[-1]['nexts'].append(pgraph[0]['id'].split()[0]+" "+str(posadd))
    
    inpflat.extend(pgraph)
    for i in range(posadd, len(inpflat)):
        extok = inpflat[i]
        extok['pos']+=posadd
        extok['id']= str(extok['token_idx'])+" "+str(extok['pos'])
        for j in range(len(extok['nexts'])):
            newpos = int(extok['nexts'][j].split()[1])+posadd
            extok['nexts'][j] = extok['nexts'][j].split()[0]+" "+str(newpos)
    return inpflat, posadd

In [None]:
STOPS = -1

# Get examples (just use the normal lattice examples ig?)
processedgraphs, inps, refs = fl.get_processed_graph_data(fl.frenbase, -1, STOPS)

# get exploded candidates to generate gold labels
resarrs = [fl.get_cover_paths(p)[0] for p in processedgraphs]

# extra step for greedy 
if STOPS==1:
    processedgraphs = filter_greedy(processedgraphs)
    

# ensure no empty examples
clean_empty(resarrs, processedgraphs)

ppinput = [prepend_input(processedgraphs[i], inps[i]) for i in range(len(processedgraphs))]
processedgraphs = [p[0] for p in ppinput]
posadds = [p[1] for p in ppinput]


In [None]:
def causal_mask (pgraph, padd):
    start = connect_mat(processedgraphs[0])
    start[:, :padd] = 1
    start[:padd, padd:] = 0 
    start[padd:, padd:] = torch.tril(start[padd:, padd:])
    return start

def get_allamasks():
    attmasks = []
    for i in range(len(posadds)):
        if i%10==0:
            print(i)
        attmasks.append(causal_mask(processedgraphs[i], posadds[i]-1))
    return attmasks

In [None]:
processedgraphs[0]

In [None]:
def get_validnext (pos, nlist):
    retval = ""
    for n in nlist:
        if pos< int(n.split()[1]):
            retval = n
            if "2 " not in n:
                return retval
    if len(retval)>0:
        return retval
    print("no valid")
    print(pos)
    print(nlist)
    return ""

def p_wnext(pgraph):
    nid = '0 0'
    for tokd in pgraph:
        if tokd['id']==nid:
            print(xlm_tok.decode(tokd['token_idx']))
            nid = get_validnext(tokd['pos'], tokd['nexts'])
            
            #print(tokd)
    print(nid)
            
p_wnext(processedgraphs[0])

In [None]:
# Attention mask code (TODO needs some updating)
if os.path.exists('./torchsaved/'+LOADED['amasks']):
        print("using loaded masks")
        attmasks = torch.load('./torchsaved/'+LOADED['amasks']).to(device)
else:
    print("creating new masks")
    masktmp = get_allamasks()
    attmasks = torch.stack(masktmp).to(device)
    torch.save(attmasks, './torchsaved/'+LOADED['amasks'])

In [None]:
# get tokenized inputs with posids (TODO needs an update for src/tgt format)
sents, posids = create_inputs(processedgraphs)

In [None]:
sents.to(device)
posids.to(device)
attmasks.to(device)

In [None]:
import torch.nn as nn
from transformers import AutoModel
# Returns Token Scores For Each (Sum becomes Regression)
class XLMCometEmbeds(nn.Module):
    
    def __init__(self, drop_rate=0.1):
        # TODO should we be freezing layers?
        super().__init__()
        
        self.xlmroberta = AutoModel.from_pretrained('xlm-roberta-base')
        # Num labels 1 should just indicate regression (?)
        self.regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(self.xlmroberta.config.hidden_size, 1), 
        )
        self.to(device)
        
    def forward(self, input_ids, positions, attention_masks):
        # don't finetune xlmroberta model
        #with torch.no_grad():
        word_rep, sentence_rep = self.xlmroberta(input_ids, position_ids = positions, attention_mask=attention_masks, encoder_attention_mask=attention_masks, return_dict=False)
        # use the first <s> token as a CLS token, TODO experiment with using the sum of 
        # ensure padding not factored in
        #word_rep = word_rep*(input_ids>0).unsqueeze(-1)
        res = []
        for w in word_rep:
            res.append(self.regressor(w))
            
        word_rep = word_rep*(input_ids>0).unsqueeze(-1)
        #outputs = self.regressor(torch.sum(word_rep, 1))
        #print("Shape: ", outputs.shape)
        return word_rep


In [None]:
del model

In [None]:
model = XLMCometEmbeds(drop_rate=0.1)
model.load_state_dict(torch.load("./torchsaved/maskedcont3.pt"))
model.eval()

In [None]:
model.regressor[1].weight

In [None]:
model.regressor[1].bias

In [None]:
with torch.no_grad():
    outs = model(sents, posids, attmasks)

In [None]:
scores = torch.inner(outs, model.regressor[1].weight).squeeze(-1)

torch.Size([500])

In [125]:
# set scores computed for each token by the model
def set_pgscores(pgraphs, scores):
    #idlist = get_idlist(pgraph)
    for p in range(len(pgraphs)):
        pgraph = pgraphs[p]
        for i in range(min(len(pgraph), 500)):
            pgraph[i]['score'] = scores[p][i]
            if pgraph[i]['token_idx']>0:
                if pgraph[i]['score']==0:
                    print(i)
                    print(p)
    return pgraphs

# topological sort the graphs, make sure that nodes that are next always come next in the list
def topo_sort_pgraph(pgraph):
    # reverse ordering
    pgraph.reverse()
    # for all tokens
    i = 0
    while i < min(len(pgraph), 500):
        ns = pgraph[i]
        # check if any tokens that come after actually should be before
        for j in range(i, min(len(pgraph), 500)):
            # if so, re-insert right before in list
            if pgraph[j]['id'] in ns:
                tmp = pgraph[j]
                del pgraph[j]
                pgraph.insert(i, tmp)
                i+=1
        i+=1
        
def prepare_pgraphs(pgraphs, scores):
    res = []
    # make a deep copy of processed graphs
    for p in pgraphs:
        res.append([x for x in p])
    # set scores for stuff
    set_pgscores(res, scores)
    # do topological sorting
    for r in res:
        topo_sort_pgraph(r)
    return res
    
    

prepared_pgraphs = prepare_pgraphs(processedgraphs, scores)    

498
55
499
55
498
57
499
57
498
82
499
82
498
90
499
90


In [170]:
# given a list of sub-scores (topological flattening of the graph), use dp to get the highest scoring path
# idlist has the corresponding graph ids for 
# would need to do a sort on pgrapaps that makes sure that no next node is before in the linear ordering
# reverse since we're using nexts
# TODO simplify code to not need so many data structures
def dp_best_path(pgraphs, graph):
    bplist = []
    bsco_list =[]
    idlist = get_idlist(pgraphs)
    for i in range(len(idlist)):
        bpath = []
        cur = pgraphs[i]
            
        # get the highest prev from ahead to use
        mval = -10
        maxnext = None
        for n in cur['nexts']:
            try:
                if graph[n]['bestsco']>mval:
                    mval = graph[n]['bestsco']
                    maxnext = graph[n]
            except:
                ""

        # add in scores / path from that prev
        if maxnext==None:
            bpath.append(i)
            bplist.append(bpath)
            bsco_list.append(cur['score'])
            # check if this is how things work in python
            graph[cur['id']]['bestsco'] = cur['score']
            graph[cur['id']]['plist'] = bpath
            continue
        bpath.extend(maxnext['plist']+[i])
        bplist.append(bpath)
        bsco_list.append(cur['score']+mval)
        graph[cur['id']]['bestsco'] = cur['score']+mval
        graph[cur['id']]['plist'] = bpath
        #print(bpath)
    return bplist[-1], bsco_list[-1]

def get_idlist(pgraph):
    return [p['id'] for p in pgraph]

def dp_pgraph(pgraph):
    graph = {}
    for p in pgraph:
        # TODO check if scores are negative number compatible
        p['bestsco'] = 0
        p['plist'] = []
        graph[p['id']] = p
    bestpath, bso = dp_best_path(pgraph, graph)
    print(bsco)
    bestpath.reverse()
    return [pgraph[x]['token_idx'] for x in bestpath]
        
    
bp = dp_pgraph(prepared_pgraphs[0])

tensor(0.1371, device='cuda:1', grad_fn=<AddBackward0>)


In [183]:
examplenum = 21
print(xlm_tok.decode(dp_pgraph(prepared_pgraphs[examplenum])))
print(refs[examplenum])

tensor(0.1371, device='cuda:1', grad_fn=<AddBackward0>)
<s> Mais le coût du nouveau vaccin devrait être bien inférieur car il transforme les cellules du foie en usines à anticorps.</s><s> But the new vaccine should cost much less because it turns liver cells into antibody factories.
But the cost of the new vaccine is likely to be far lower, because it turns liver cells into antibody factories.


In [160]:
cnt = 0
for t in processedgraphs[0]:
    if t['token_idx']==0:
        print(t)
cnt

{'token_idx': 0, 'pos': 0, 'id': '0 0', 'nexts': ['636 1'], 'score': tensor(0., device='cuda:1', grad_fn=<SelectBackward0>), 'bestsco': tensor(0., device='cuda:1', grad_fn=<SelectBackward0>), 'plist': [364]}
{'token_idx': 0, 'pos': 49, 'id': '0 49', 'nexts': ['581 50'], 'score': tensor(0., device='cuda:1', grad_fn=<SelectBackward0>), 'bestsco': tensor(0., device='cuda:1', grad_fn=<SelectBackward0>), 'plist': [315]}


0

In [None]:
bert_tok.decode(0)

In [None]:
processedgraphs[0][0]

In [None]:
bert_tok("<s>")

In [None]:
' </s> en_XX The US President was to receive Iraq i Prime Minister No uri Al Malik i Friday , November 1 , 2013 in an effort to se ek US assistance in fighting the worst wa ve of violence in five years'