In [1]:
import pandas as pd
import random
from statistics import mean
from transformers import T5Tokenizer
from src.tfr_decoding.pairwise_modeling import T5BinaryClassifier, validate
import numpy as np
import torch
from itertools import combinations
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, kendalltau
from collections import defaultdict

In [2]:
# Re-import edited project modules (e.g. src.tfr_decoding.pairwise_modeling) automatically
# before each cell runs, so helper edits are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Tokenizer used to measure hypothesis token lengths and to cut hypotheses into
# fixed-length token prefixes when the dataset is rebuilt.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

In [None]:
# Load the prefix-scoring classifier (a T5BinaryClassifier restored from a checkpoint
# under lightning_logs/) onto device "cuda:1", plus the SteamSHP tokenizer.
# `qpref` is the model queried by getprobs.
# NOTE(review): the GPU index is hard-coded; adjust on machines with fewer GPUs.
pfmod_path = "./lightning_logs/multimodel/checkpoints/epoch=3-step=86946.ckpt"
pfname = 'stanfordnlp/SteamSHP-flan-t5-large'
# get prefix model
qpref = T5BinaryClassifier.load_from_checkpoint(pfmod_path).to("cuda:1")
preftok = T5Tokenizer.from_pretrained(pfname)

In [4]:
def make_toklen_lists(indf, tok):
    """Return, for every row, the token count of each of its hypotheses.

    Args:
        indf: frame-like object whose ``hyps`` attribute iterates over rows;
            each row is a list of hypothesis strings.
        tok: callable tokenizer whose result exposes ``input_ids``.

    Returns:
        list[list[int]]: ``result[r][k] == len(tok(indf.hyps[r][k]).input_ids)``.
    """
    return [
        [len(tok(text).input_ids) for text in row_hyps]
        for row_hyps in indf.hyps
    ]

def make_prefs(indf, tok, pflen):
    """Return, for every row, the decoded ``pflen``-token prefix of each hypothesis.

    Each hypothesis is tokenized, truncated to its first ``pflen`` token ids and
    decoded back to text with special tokens dropped.

    Args:
        indf: frame-like object whose ``hyps`` attribute iterates over rows;
            each row is a list of hypothesis strings.
        tok: tokenizer that is callable (result exposes ``input_ids``) and has
            a ``decode(ids, skip_special_tokens=...)`` method.
        pflen: number of leading token ids to keep.

    Returns:
        list[list[str]]: prefixes laid out like ``indf.hyps``.
    """
    def _prefix(text):
        head_ids = tok(text).input_ids[:pflen]
        return tok.decode(head_ids, skip_special_tokens=True)

    return [list(map(_prefix, row_hyps)) for row_hyps in indf.hyps]

def adaptbase(inpdf, thresh, mrange):
    """Simulate a random-order adaptive best-of-n baseline with early stopping.

    For every row, the 8 hypotheses are visited in a freshly shuffled order.
    Each visited hypothesis adds its token length to the budget, and the search
    stops at the first hypothesis whose score exceeds ``thresh``. If none does
    within the first ``mrange`` visited, the best score seen is used.

    Args:
        inpdf: DataFrame indexed 0..len-1 with list-valued columns ``scos``
            (score per hypothesis) and ``toklens`` (token length per
            hypothesis); assumes 8 hypotheses per row.
        thresh: score above which the search stops early.
        mrange: maximum number of hypotheses to visit per row. Values above 8
            are capped at 8 instead of raising IndexError.

    Returns:
        tuple: (mean selected score, mean token budget) over all rows.

    Raises:
        ValueError: if ``mrange < 1`` (no hypothesis could ever be selected).
    """
    if mrange < 1:
        raise ValueError("mrange must be at least 1")
    # One permutation list is reshuffled in place for every row (same random
    # draws as before, so results under a fixed seed are unchanged).
    inds = list(range(8))
    fscos = []
    budgets = []
    for ind in range(len(inpdf)):
        slist = inpdf.scos[ind]
        blist = inpdf.toklens[ind]
        random.shuffle(inds)
        seen = []
        budget = 0
        for i in inds[:mrange]:
            budget += blist[i]
            if slist[i] > thresh:
                fscos.append(slist[i])
                break
            seen.append(slist[i])
        else:
            # No visited hypothesis cleared the threshold: keep the best one seen.
            fscos.append(max(seen))
        budgets.append(budget)
    return mean(fscos), mean(budgets)

def samprer(inpdf, mrange):
    """Simulate plain random best-of-``mrange`` sampling (no early stopping).

    For every row, the 8 hypotheses are shuffled. The first ``mrange`` of them
    are all fully generated: their best score is kept and their token lengths
    are summed as the budget.

    Args:
        inpdf: DataFrame indexed 0..len-1 with list-valued columns ``scos`` and
            ``toklens``; assumes 8 hypotheses per row.
        mrange: number of hypotheses sampled per row.

    Returns:
        tuple: (mean best score, mean token budget) over all rows.
    """
    order = list(range(8))
    picked_scores, spent = [], []
    for row in range(len(inpdf)):
        scores = inpdf.scos[row]
        lengths = inpdf.toklens[row]
        random.shuffle(order)
        chosen = order[:mrange]
        picked_scores.append(max(scores[k] for k in chosen))
        spent.append(sum(lengths[k] for k in chosen))
    return mean(picked_scores), mean(spent)

def adaptive_prefsort(inpdf, thresh, mrange, pf, hstop):
    """Simulate adaptive selection that orders candidates by a prefix score.

    For each row, ``mrange`` of the 8 hypotheses are drawn in random order and
    scored from their ``pf``-token prefixes. The score is the last entry of
    each per-hypothesis list in column ``"probs" + str(pf)``. Candidates are
    then completed in descending prefix-score order, at most ``hstop`` of them.
    The search stops early at the first whose final score exceeds ``thresh``;
    if none does, the best completed score is used.

    Budget: every drawn candidate is charged ``pf`` prefix tokens up front.
    Each completed candidate then adds its full token length minus those ``pf``
    tokens. Only hypotheses actually drawn are charged (at most 8, even when
    ``mrange`` is larger).

    Args:
        inpdf: DataFrame indexed 0..len-1 with list-valued columns ``scos``,
            ``toklens`` and ``"probs" + str(pf)``.
        thresh: score above which the search stops early.
        mrange: number of hypotheses to draw and prefix-score per row.
        pf: prefix length in tokens; also selects the probability column.
        hstop: maximum number of ranked candidates to complete. Values larger
            than the number drawn are capped instead of raising IndexError.

    Returns:
        tuple: (mean selected score, mean token budget) over all rows.

    Raises:
        ValueError: if ``mrange < 1`` or ``hstop < 1``.
    """
    if mrange < 1 or hstop < 1:
        raise ValueError("mrange and hstop must both be at least 1")
    inds = list(range(8))
    fscos = []
    budgets = []
    probcol = "probs" + str(pf)
    for ind in range(len(inpdf)):
        slist = inpdf.scos[ind]
        blist = inpdf.toklens[ind]
        # Prefix-model score per hypothesis: the last (final-class) entry.
        plist = [a[-1] for a in inpdf[probcol][ind]]
        random.shuffle(inds)
        drawn = inds[:mrange]
        nplist = [plist[p] for p in drawn]
        # argsort indexes into `drawn`; map back to hypothesis ids, best first.
        sortps = [drawn[pl] for pl in np.argsort(nplist)][::-1]
        # NOTE(review): hypotheses shorter than pf tokens are still charged pf
        # here; confirm whether min(pf, blist[i]) is the intended prefix cost.
        budget = pf * len(drawn)
        seen = []
        for i in sortps[:hstop]:
            # Full generation cost minus the prefix tokens already charged.
            budget += blist[i] - pf
            if slist[i] > thresh:
                fscos.append(slist[i])
                break
            seen.append(slist[i])
        else:
            # No completed candidate cleared the threshold: keep the best one.
            fscos.append(max(seen))
        budgets.append(budget)
    return mean(fscos), mean(budgets)

def adaptive_ranksort(inpdf, thresh, mrange, pf, hstop):
    """Simulate adaptive selection that orders candidates by pairwise prefix votes.

    For each row, ``mrange`` of the 8 hypotheses are drawn in random order and
    ranked with ``rank_pairwise``. The ranking uses the row's pairwise-prediction
    dict (column ``"pairdict" + str(pf)``, keyed by ``"i_j"`` pair ids). Ties
    keep the random draw order. Candidates are then completed best-first, at
    most ``hstop`` of them. The search stops early at the first whose final
    score exceeds ``thresh``; if none does, the best completed score is used.

    Budget: every drawn candidate is charged ``pf`` prefix tokens up front.
    Each completed candidate then adds its full token length minus those ``pf``
    tokens. Only hypotheses actually drawn are charged (at most 8, even when
    ``mrange`` is larger).

    Args:
        inpdf: DataFrame indexed 0..len-1 with list-valued columns ``scos`` and
            ``toklens`` and a dict-valued column ``"pairdict" + str(pf)``.
        thresh: score above which the search stops early.
        mrange: number of hypotheses to draw and rank per row.
        pf: prefix length in tokens; also selects the pairwise-dict column.
        hstop: maximum number of ranked candidates to complete. Values larger
            than the number drawn are capped instead of raising IndexError.

    Returns:
        tuple: (mean selected score, mean token budget) over all rows.

    Raises:
        ValueError: if ``mrange < 1`` or ``hstop < 1``.
    """
    if mrange < 1 or hstop < 1:
        raise ValueError("mrange and hstop must both be at least 1")
    inds = list(range(8))
    fscos = []
    budgets = []
    dictcol = "pairdict" + str(pf)
    for ind in range(len(inpdf)):
        slist = inpdf.scos[ind]
        blist = inpdf.toklens[ind]
        # Pairwise predictions between this row's hypothesis prefixes.
        pdict = inpdf[dictcol][ind]
        random.shuffle(inds)
        drawn = inds[:mrange]
        sortps, _ = rank_pairwise(pdict, drawn)
        # NOTE(review): hypotheses shorter than pf tokens are still charged pf
        # here; confirm whether min(pf, blist[i]) is the intended prefix cost.
        budget = pf * len(drawn)
        seen = []
        for i in sortps[:hstop]:
            # Full generation cost minus the prefix tokens already charged.
            budget += blist[i] - pf
            if slist[i] > thresh:
                fscos.append(slist[i])
                break
            seen.append(slist[i])
        else:
            # No completed candidate cleared the threshold: keep the best one.
            fscos.append(max(seen))
        budgets.append(budget)
    return mean(fscos), mean(budgets)


In [None]:
# Scratch check: np.argsort returns indices in ascending-value order (here [3, 2, 0, 1]);
# adaptive_prefsort reverses this order to visit the highest-scoring prefix first.
np.argsort([3, 4, 2, 1])

In [5]:
# Load the cached exploration dataset, or rebuild it from the two raw test-set dumps.
# The rebuild concatenates their hypotheses and scores, then adds token lengths
# and 5/10/15/20-token prefixes. Set REBUILD_ADAPT_DATA = True to rebuild.
REBUILD_ADAPT_DATA = False
PREFIX_LENGTHS = (5, 10, 15, 20)

if not REBUILD_ADAPT_DATA:
    fulld = pd.read_json("output/adapt_explore.jsonl", orient="records", lines=True)
else:
    # load in data
    dropcols = ["stats", 'ver', 'pref']
    d1 = pd.read_json("output/testset1.jsonl", orient="records", lines=True).drop(columns=dropcols)
    d2 = pd.read_json("output/testset2.jsonl", orient="records", lines=True).drop(columns=dropcols)
    # NOTE(review): rows are paired by position; assumes both files list the
    # same inputs in the same order — TODO confirm.
    fulld = pd.DataFrame({
        'inp': d1.inp,
        'hyps': [d1['hyps'][i] + d2['hyps'][i] for i in range(len(d1))],
        'scos': [d1['scos'][i] + d2['scos'][i] for i in range(len(d1))],
    })
    fulld['toklens'] = make_toklen_lists(fulld, tokenizer)
    for pflen in PREFIX_LENGTHS:
        fulld['pf' + str(pflen)] = make_prefs(fulld, tokenizer, pflen)
    fulld.to_json("output/adapt_explore.jsonl", orient='records', lines=True)

In [6]:
def construct_pair_dset(row, hcol):
    """Expand one input row into a DataFrame with one line per hypothesis pair.

    Every unordered index pair ``(a, b)`` with ``a < b`` from ``row[hcol]``
    becomes a line.

    Args:
        row: mapping/Series with ``'inp'``, ``'scos'`` and ``hcol`` entries;
            ``row[hcol]`` and ``row['scos']`` are parallel lists.
        hcol: name of the column holding the hypotheses (or their prefixes).

    Returns:
        pandas.DataFrame with columns ``inp``, ``hyp_pairs`` (tuple of the two
        hypotheses), ``score_a``, ``score_b``, ``label`` (True when the first
        strictly outscores the second; ties give False) and ``pair_id``
        (``"a_b"``).
    """
    hyps, scos = row[hcol], row['scos']
    index_pairs = list(combinations(range(len(hyps)), 2))

    hyp_pairs, score_a, score_b, label, pair_id = [], [], [], [], []
    for a, b in index_pairs:
        hyp_pairs.append((hyps[a], hyps[b]))
        score_a.append(scos[a])
        score_b.append(scos[b])
        label.append(scos[a] > scos[b])
        pair_id.append(f"{a}_{b}")

    return pd.DataFrame({'inp': row['inp'], 'hyp_pairs': hyp_pairs,
                         'score_a': score_a, 'score_b': score_b,
                         'label': label, 'pair_id': pair_id})

In [7]:
# One row per pair of 20-token hypothesis prefixes ('pf20'), across all inputs in fulld;
# the 'hyp_pairs' tuples are then split into separate 'hyp_a' / 'hyp_b' columns.
pairdf = pd.concat([construct_pair_dset(row, 'pf20') for _,row in fulld.iterrows()]).reset_index(drop=True)
pairdf[['hyp_a', 'hyp_b']] = pd.DataFrame(pairdf['hyp_pairs'].tolist(), index=pairdf.index)
pairdf = pairdf.drop('hyp_pairs', axis=1)

In [9]:
# Target token ids for the pairwise classifier. CLASSES[0] is used when the first
# hypothesis of a pair scores higher, CLASSES[1] otherwise.
# NOTE(review): presumably the T5 token ids of "A" / "B" — confirm with the tokenizer.
CLASSES = [71, 272]


def pairlabel(val):
    """Map a boolean / 0-1 pair label to its class token id.

    Returns ``CLASSES[0]`` when ``val == 1`` (e.g. ``True``), else ``CLASSES[1]``.
    """
    return CLASSES[0] if val == 1 else CLASSES[1]

# Keep the boolean label in 'numlab' and replace 'label' with the class token id
# (via pairlabel), presumably the form `validate` compares predictions against.
pairdf['numlab'] = pairdf.label
pairdf['label'] = [pairlabel(p) for p in pairdf.numlab]

In [10]:
# get labels, probs for all pairs
# Run the pairwise comparison checkpoint over every prefix pair. `validate` is a project
# helper; its preds/probs are assigned below as pairdf columns, so they have one entry per row.
preds, labels, probs = validate(pairdf, "lightning_logs/balanced_compmodel/checkpoints/epoch=4-step=36732.ckpt")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Restoring states from the checkpoint path at lightning_logs/balanced_compmodel/checkpoints/epoch=4-step=36732.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at lightning_logs/balanced_compmodel/checkpoints/epoch=4-step=36732.ckpt
  rank_zero_warn(


Validation: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy          0.7033439881976887
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


In [11]:
# Attach the pairwise-model predictions and probabilities for the 20-token prefixes.
pairdf['pred20'] = preds
pairdf['prob20'] = probs

In [12]:
# Cache the scored pairs to disk so the validation pass need not be re-run.
pairdf.to_json("output/adapt2_pairwise20.jsonl", orient='records', lines=True)

In [13]:
# Absolute gold-score gap for each pair, and the largest entry of each pair's
# probability list (presumably the classifier's confidence).
# NOTE(review): `selprobs` is not used anywhere visible in this notebook.
dists = (pairdf.score_a - pairdf.score_b).abs()
selprobs = [max(m) for m in pairdf.prob20]

In [14]:
# Store the gold-score gap so pairs can be filtered by how far apart their scores are.
pairdf['dists'] = dists

In [20]:
# Pairwise accuracy restricted to pairs whose gold scores differ by more than 0.1.
bdist = pairdf[pairdf.dists>0.1]
(bdist['pred20']==bdist['label']).mean()

0.8151694594935114

In [21]:
# For every input in fulld, build a dict mapping pair_id "i_j" -> 1 when the pairwise
# model predicts CLASSES[0] (hypothesis i preferred) and 0 for CLASSES[1].
# CLASSES is reset here because a later cell rebinds it to four token ids.
CLASSES = [71, 272]
# Group once instead of re-filtering pairdf for every input (was O(inputs x pairs)).
pairs_by_inp = {inp: grp for inp, grp in pairdf.groupby('inp', sort=False)}
g_dicts = []
for inp in fulld.inp:
    grp = pairs_by_inp.get(inp)
    if grp is None:
        # No pairs for this input: same result as an empty filter.
        g_dicts.append({})
        continue
    # Distinct names here: the original reused `i` for both the outer and inner loop.
    g_dicts.append({
        pid: 1 - CLASSES.index(pred)
        for pid, pred in zip(grp['pair_id'], grp['pred20'])
    })
fulld['pairdict20'] = g_dicts

In [22]:
def getprobs(indf, ex):
    """Score each of the 8 hypothesis prefixes of row ``ex`` with the prefix model ``qpref``.

    For every prefix, ``qpref.predsingle`` (a project helper) is called once. The
    token at ``out.sequences[0][1]`` is then overwritten with each id in the global
    ``CLASSES`` in turn. ``compute_transition_scores(..., normalize_logits=True)``
    is recomputed for each forced token, and the exponentiated first transition
    score is kept.

    Args:
        indf: DataFrame with an ``inp`` column and a ``pf20`` column holding a
            list of (at least 8) prefix strings per row.
        ex: row label in ``indf``.

    Returns:
        list[list[float]]: ``result[k][c]`` corresponds to ``CLASSES[c]`` for
        hypothesis prefix ``k``. Its length per hypothesis follows whatever
        ``CLASSES`` is at call time (it is rebound in a later cell).

    NOTE(review): assumes ``sequences[0][1]`` is the first generated token
    (index 0 being the decoder start token) and that predsingle's third
    argument requests generation scores — TODO confirm.
    """
    allprobs = []
    for num in range(8):
        out = qpref.predsingle(indf['inp'][ex], indf['pf20'][ex][num], True)
        cprobs = []
        for c in CLASSES:
            # Force class token c into the sequence (mutates `out` in place).
            out.sequences[0][1] = c
            transition_scores = qpref.model.compute_transition_scores(
                out.sequences, out.scores, normalize_logits=True
            )
            # exp of the first transition score -> probability of the forced token.
            cprobs.append(float(np.exp(transition_scores[0][0].cpu())))
        allprobs.append(cprobs)
    return allprobs

In [33]:
def rank_pairwise(compdict, include=(0, 1, 2, 3, 4, 5, 6, 7)):
    """Rank hypothesis indices by how many pairwise comparisons they win.

    Args:
        compdict: mapping from pair ids ``"i_j"`` to a label; ``label == 1``
            counts as a win for ``i``, any other label as a win for ``j``.
        include: indices to rank (ints or digit strings). Only comparisons
            between two included indices are counted. Defaults to 0..7; a
            tuple is used to avoid a mutable default argument.

    Returns:
        tuple[list[int], list[int]]: indices sorted by descending win count,
        and the matching counts. Ties keep the order in which indices appear
        in ``include`` (the sort is stable).
    """
    # Every included index starts at zero votes; dict keeps ``include`` order.
    vote_count = dict.fromkeys((str(s) for s in include), 0)

    for pair, label in compdict.items():
        parts = pair.split("_")
        first, second = parts[0], parts[1]
        if first in vote_count and second in vote_count:
            winner = first if label == 1 else second
            vote_count[winner] += 1

    ranked = sorted(vote_count.items(), key=lambda kv: kv[1], reverse=True)
    indices = [int(ind) for ind, _ in ranked]
    counts = [count for _, count in ranked]
    return indices, counts

def pred_best_n(n, df=None, pairdicts=None):
    """Compare oracle, pairwise-ranked and random best-of-``n`` selection.

    For each input row, hypotheses are ranked with ``rank_pairwise`` over that
    row's pairwise-prediction dict. The best gold score among the top ``n``
    ranked hypotheses is compared with the oracle best score and with the best
    of ``n`` hypotheses taken in random order.

    Args:
        n: number of hypotheses to keep per row (must be >= 1).
        df: DataFrame with a ``scos`` column of per-hypothesis scores; defaults
            to the notebook-global ``fulld``.
        pairdicts: per-row pairwise-prediction dicts aligned with ``df``;
            defaults to the notebook-global ``g_dicts``.

    Returns:
        tuple: (mean oracle score, mean pairwise-ranked score, mean random score).

    Raises:
        ValueError: if ``n < 1`` (the previous ``goldranks[-1*n:]`` slice
            returned the wrong elements for n <= 0).
    """
    if n < 1:
        raise ValueError("n must be at least 1")
    if df is None:
        df = fulld
    if pairdicts is None:
        pairdicts = g_dicts
    allpreds, allgolds, allrands = [], [], []
    for ind in range(len(df)):
        inds, _ = rank_pairwise(pairdicts[ind])
        scovals = df.scos.iloc[ind]
        # Best among the n highest-scoring hypotheses is simply the overall best.
        allgolds.append(max(scovals))
        allpreds.append(max(scovals[s] for s in inds[:n]))
        random.shuffle(inds)
        allrands.append(max(scovals[s] for s in inds[:n]))
    return mean(allgolds), mean(allpreds), mean(allrands)

In [None]:
# Scratch check: a[-1:] gives the last element as a one-item list ([3]).
a = [1, 2, 3]
a[-1:]

In [32]:
# Best-of-2: (oracle, pairwise-ranked, random) mean scores.
pred_best_n(2)

(0.9002139297487092, 0.8834123045165233, 0.8287699305666093)

In [None]:
# Inspect the assembled per-input dataset.
fulld

In [None]:
# Inspect the scored prefix pairs belonging to the first input.
pairdf[pairdf.inp==fulld.inp.iloc[0]]

In [None]:
# Pairwise ranking of the first input's hypotheses: (indices best-first, win counts).
rank_pairwise(g_dicts[0])

In [None]:
# Prefix-model class probabilities for every row of fulld (8 predsingle calls per row,
# via getprobs), with gradients disabled; prints the row index every 10 rows as progress.
with torch.no_grad():
    allprobs = []
    for val in range(len(fulld)):
        if val%10==0:
            print(val)
        allprobs.append(getprobs(fulld, val))

In [None]:
# Store the prefix probabilities and overwrite the cached dataset with the new column.
fulld['probs20'] = allprobs
fulld.to_json("output/adapt_explore.jsonl", orient='records', lines=True)

In [None]:
# Gold scores of the first input's hypotheses.
fulld.scos[0]

In [None]:
# Prefix-model class probabilities for the first input's hypotheses.
getprobs(fulld, 0)

In [None]:
# Inspect the prefix model's first transition score and top first-step token for one
# example. `out` is otherwise only assigned inside getprobs, so the original cell
# depended on leftover kernel state and failed with NameError on a fresh kernel;
# regenerate it here for the first prefix of the first input.
with torch.no_grad():
    out = qpref.predsingle(fulld['inp'][0], fulld['pf20'][0][0], True)
transition_scores = qpref.model.compute_transition_scores(
    out.sequences, out.scores, normalize_logits=True
)
print(np.exp(transition_scores[0][0].cpu()))
print(np.argmax(out.scores[0].cpu()))

In [25]:
# NOTE(review): rebinds the global CLASSES (read by pairlabel and getprobs) to four token
# ids; getprobs will then return four probabilities per hypothesis. The pair-dict cell
# resets CLASSES to two ids, so results depend on the order cells are executed.
CLASSES = [71, 272, 205, 309]

In [35]:
## run adaptive baseline
for j in range(2, 9):
    scos, buds = [], []
    for i in range(1000):
        s, b = adaptive_ranksort(fulld, .85, j, 20, 1)
        # s, b = samprer(fulld,  j)
        scos.append(s)
        buds.append(b)
    print(mean(scos), " ", mean(buds))

0.8041931372659045   94.44742685025818
0.8229735549667913   116.10885197934596
0.8347407183849952   137.5811308089501
0.8429356585925766   158.49361617900172
0.8494724683492553   179.53176936316694
0.854479042185624   200.31748192771084
0.8587141485635704   220.94034767641998


In [None]:
# Oracle: mean over inputs of the best gold score among all hypotheses
# (an upper bound for any strategy that selects one of these hypotheses).
mean([max(m) for m in fulld.scos])

In [None]:
# Results pasted from an earlier run; they appear to be (mean score, mean token budget)
# lines as printed by the sweep above. NOTE(review): the configuration that produced
# them is not recorded. Kept as comments because the bare numbers were a SyntaxError.
# 0.8267750212205238   119.82409810671257
# 0.8447314879783238   139.1265800344234
# 0.8548433872520032   158.27455421686747
# 0.8619431309783976   177.52356798623063
# 0.8674830498737955   196.6789586919105
# 0.8720366012693377   215.81833734939758
# 0.8760026756320878   234.8388760757315