In [None]:
# Notebook dedicated to computing various intrinsic metrics to see whether
# they correspond to downstream performance and could be useful for RLHF.

In [None]:
from transformers import AutoTokenizer
import pandas as pd
from rlhfutils.eval_utils import getapfsft, tok_dist, proctmp
import matplotlib.pyplot as plt
from rlhfutils.debug_utils import load_rm, progress_rm, load_all_rmdfs, load_all_hackdfs, highlight_differences
from statistics import mean, stdev, median
from scipy.stats import pearsonr, kendalltau, spearmanr
import math
from rouge_score import rouge_scorer
from rlhfutils.eval_utils import oai_kwargs, load_alldfs, annotate_apfarm, apf_format, load_wgpt, filter_and_sort_df
import pandas as pd
from statistics import mean
import matplotlib.pyplot as plt
import re
from transformers import AutoTokenizer
from datasets import load_dataset
import openai
from rlhfutils.data import qaform
import os
import numpy as np
import random

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# replace all wgptouts with corresponding stack QA format (RM input format)
def setall(l):
    """Convert every element of ``l`` with ``getapfsft(item, True)``.

    Args:
        l: iterable of raw outputs belonging to one row.

    Returns:
        A list with one converted entry per element of ``l``, or ``None`` if
        iterating ``l`` or converting any element raises. ``procall`` relies
        on the ``None`` marker to drop such rows via ``dropna``.
    """
    try:
        return [getapfsft(ind, True) for ind in l]
    except Exception:
        # Best-effort conversion: a row that cannot be converted is flagged
        # with None. Catch Exception (not a bare ``except``) so that
        # KeyboardInterrupt / SystemExit can still stop a long-running cell.
        return None

def splitall(l):
    """Extract the text segment that follows the first "Answer:" marker.

    Args:
        l: iterable of strings, each expected to contain "Answer:".

    Returns:
        ``[s.split("Answer:")[1] for s in l]``, i.e. for each string the piece
        between the first "Answer:" marker and the next one (or the end of the
        string). Returns ``None`` if any element has no marker or ``l`` cannot
        be processed, so callers can drop the row.
    """
    try:
        return [s.split("Answer:")[1] for s in l]
    except Exception:
        # Catch Exception rather than everything so that KeyboardInterrupt /
        # SystemExit are not silently turned into a None result.
        return None

def getfulldist(lcol):
    """Flatten a column of per-row sequences into a single flat list.

    Args:
        lcol: iterable whose elements are themselves iterable.

    Returns:
        One list containing the elements of every inner sequence, in order.
    """
    return [value for seq in lcol for value in seq]

def compdist(lcol, slen):
    """Group ``lcol`` into consecutive chunks of ``slen`` elements.

    Elements are read by position (``lcol[idx]``). A trailing group that
    never reaches ``slen`` elements is discarded.

    Args:
        lcol: indexable sequence to be chunked.
        slen: chunk size.

    Returns:
        List of complete chunks, each a list of ``slen`` elements.
    """
    chunks = []
    current = []
    for idx in range(len(lcol)):
        current.append(lcol[idx])
        if len(current) % slen:
            # chunk is not full yet
            continue
        chunks.append(current)
        current = []
    return chunks
    
def procall(indf, toker, needset=True):
    """Prepare a frame of model outputs for the length / reward analyses.

    Steps:
      1. (optional) convert every entry of ``response`` with ``setall``;
      2. drop rows containing missing values (including failed conversions);
      3. add ``answers`` via ``splitall`` and drop rows where that failed;
      4. add ``atoks`` / ``ttoks`` by applying ``tok_dist`` with ``toker`` to
         ``answers`` / ``response`` respectively.

    Args:
        indf: DataFrame with a ``response`` column (one list of outputs per row).
        toker: tokenizer forwarded to ``tok_dist``. It is now actually used;
            the previous version ignored it and always read the global ``wgtok``.
        needset: when True, run ``setall`` over ``response`` first.

    Returns:
        A new DataFrame; the caller's ``indf`` is left unmodified.
    """
    # ``assign`` returns a new frame each time, so the caller's frame is not
    # half-mutated and no column is written onto a ``dropna`` result.
    if needset:
        indf = indf.assign(response=[setall(s) for s in indf['response']])
    indf = indf.dropna()
    indf = indf.assign(answers=[splitall(s) for s in indf['response']])
    indf = indf.dropna()
    return indf.assign(
        atoks=[tok_dist(s, toker) for s in indf['answers']],
        ttoks=[tok_dist(s, toker) for s in indf['response']],
    )

# take rouge between all pairs. High rouge should bigger gaps
def rpaircorr(row, scat=False):
    """Relate pairwise ROUGE-1 similarity of answers to squared reward gaps.

    For every unordered pair of answers in the row, the ROUGE-1 F-measure
    between the two answers (via the global ``scorer``) and the squared
    difference of their rewards are recorded.

    Args:
        row: mapping / Series with parallel ``answers`` and ``reward`` sequences.
        scat: when True, return the raw lists (e.g. for a scatter plot).

    Returns:
        ``(rouges, diffs)`` if ``scat`` is True; otherwise the Pearson
        correlation between the two lists, or NaN when there are fewer than
        two pairs (callers in this notebook filter NaNs out).
    """
    answers = row['answers']
    rewards = row['reward']
    # Iterate over the answers: ``len(row)`` would count the row's fields,
    # not the number of answers it holds.
    n = len(answers)
    rouges = []
    diffs = []
    for i in range(n):
        for j in range(i + 1, n):
            pair_score = scorer.score(answers[i], answers[j])
            rouges.append(pair_score['rouge1'].fmeasure)
            diffs.append(math.pow(rewards[i] - rewards[j], 2))
    if scat:
        return rouges, diffs
    if len(rouges) < 2:
        # a correlation needs at least two points
        return math.nan
    return pearsonr(rouges, diffs).statistic

def getcorr(row, pearson=False):
    """Correlate a row's answer token counts with its rewards.

    Args:
        row: mapping / Series with parallel ``atoks`` and ``reward`` sequences.
        pearson: use Pearson's r when truthy, Kendall's tau otherwise.

    Returns:
        The ``statistic`` of the selected correlation test.
    """
    corr_fn = pearsonr if pearson else kendalltau
    return corr_fn(row['atoks'], row['reward']).statistic

In [None]:
# ROUGE-1 scorer (with stemming); read as a global by rpaircorr.
scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
# Tokenizers loaded from local SFT checkpoint directories (paths are relative
# to the notebook's working directory).
stacktok = AutoTokenizer.from_pretrained("../stack-llama/models/sft/")
wgtok = AutoTokenizer.from_pretrained("../webgpt-llama/models/sft10k/")

In [None]:
# Load the reward-model output frames from ../trl-general/rmouts; the result is
# indexed by name further down (e.g. ndfs['stack_rewardda']).
ndfs = load_all_rmdfs("../trl-general/rmouts//")

In [None]:
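# NOTE: these loaders are disabled, but later cells still reference `shuffdfs`
# and `attdfs` (and `rmdfs` before it is reassigned); re-enable the relevant
# lines before running those cells on a fresh kernel.
# shuffdfs = load_all_rmdfs("../trl-general/rmshuffs/")
# rmdfs = load_all_rmdfs("../trl-general/rmouts/")
# moredfs = load_all_rmdfs("../trl-general/morermouts/")
# attdfs = load_all_hackdfs("../trl-general/fullattacks/")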

In [None]:
def _qa_frame(questions, responses, toker):
    """Build one question/response/rtoks frame and pass it through ``proctmp``."""
    frame = pd.DataFrame({
        'question': questions,
        'response': responses,
        'rtoks': [len(toker(r).input_ids) for r in responses]
    })
    return proctmp(frame)


def process_qa(statdf, label, lim=1000, toker=None, seed=None):
    """Build length-controlled comparison sets from per-question answer lists.

    For every row, three answers are picked: the one with the most answer
    tokens (``atoks``), the one with the fewest, and a random one that is
    neither of those. Each selection is collected into its own frame with
    ``question``, ``response`` and ``rtoks`` (token count of the response)
    columns, which is then passed through ``proctmp``.

    Args:
        statdf: DataFrame with ``question``, ``answers`` and ``atoks`` columns,
            where ``answers`` / ``atoks`` are parallel per-row sequences.
        label: prefix for the keys of the returned dict.
        lim: upper bound used in the final ``.loc[:lim]`` slice. This is a
            label-based slice, so it is inclusive of label ``lim``.
            NOTE(review): how many rows that keeps depends on the index
            ``proctmp`` returns; confirm.
        toker: tokenizer used to compute ``rtoks``; defaults to the global
            ``wgtok`` (the previous hard-coded behavior).
        seed: optional seed for the random pick, for reproducible runs. When
            None, the module-level ``random`` state is used as before.

    Returns:
        Dict with keys ``label + '_maxouts'``, ``label + '_minouts'`` and
        ``label + '_sampouts'`` mapping to the three frames.
    """
    if toker is None:
        toker = wgtok
    chooser = random if seed is None else random.Random(seed)

    questions = []
    minouts = []
    maxouts = []
    sampouts = []
    for _, row in statdf.iterrows():
        atoks = row['atoks']
        if len(atoks) == 0:
            # nothing to select from
            continue
        topind = np.argmax(atoks)
        botind = np.argmin(atoks)
        # don't re-use outputs
        others = [x for x in range(len(atoks)) if x not in (topind, botind)]
        if not others:
            # No third candidate left for the random pick; skip the row so
            # the three output sets stay aligned (random.choice would raise
            # IndexError on an empty list).
            continue
        rind = chooser.choice(others)
        questions.append(row['question'])
        minouts.append(row['answers'][botind])
        maxouts.append(row['answers'][topind])
        sampouts.append(row['answers'][rind])

    # Same construction order as before: mins, maxes, then samples.
    mins = _qa_frame(questions, minouts, toker)
    maxes = _qa_frame(questions, maxouts, toker)
    samps = _qa_frame(questions, sampouts, toker)
    return {
        label+'_maxouts':maxes.loc[:lim],
        label+'_minouts':mins.loc[:lim],
        label+'_sampouts':samps.loc[:lim]
    }

In [None]:
# Inspect the loaded reward-model frames.
ndfs

In [None]:
# Build the max / min / random-sample answer sets for the 'stack_rewardda'
# outputs, keyed with the 'rlcd' prefix, with lim=200.
pq = process_qa(ndfs['stack_rewardda'], 'rlcd', 200)

In [None]:
# process_qa keys its output as label + '_maxouts' / '_minouts' / '_sampouts'.
pq.keys()

In [None]:
# Mean response token count (rtoks) of the random-sample set, over index
# labels up to and including 200 (.loc slices are label-based and inclusive).
pq['rlcd_sampouts'].loc[:200].rtoks.mean()

In [None]:
# Mean len() of the random-sample responses (character count, assuming the
# responses are still plain strings after proctmp; TODO confirm).
mean([len(w) for w in pq['rlcd_sampouts']['response']])

In [None]:
# Same statistic for the longest-answer set, for comparison with the cell above.
mean([len(w) for w in pq['rlcd_maxouts']['response']])

In [None]:
# APEVAL call (be cautious)
# Annotates 'rlcd_sampouts' against 'rlcd_maxouts' using the kwargs from
# oai_kwargs(). NOTE(review): the 0 and len(...) arguments presumably bound the
# range of rows annotated; confirm against annotate_apfarm.
lenannot = annotate_apfarm(pq, "rlcd_sampouts", "rlcd_maxouts", 0, len(pq['rlcd_maxouts']), oai_kwargs())

In [None]:
# Tabulate the annotations and print the mean of their `preference` column.
tmp = pd.DataFrame(lenannot)
print(tmp.preference.mean())

In [None]:
# Inspect the annotation table.
tmp

In [None]:
# process_qa prefixes its keys with the label, so the shortest-answer set
# lives under 'rlcd_minouts' (plain 'minouts' raises KeyError).
pq['rlcd_minouts'][:200]

In [None]:
# Inspect the frame that the 'rlcd' sets above were built from.
ndfs['stack_rewardda']

In [None]:
# Show the differences between the original sequence and the last entry of
# `bestseqs` for one 'wgptda' row.
# NOTE(review): `attdfs` is only assigned in the commented-out loader cell
# earlier in the notebook, so this fails on a fresh kernel.
ind = 10
r = attdfs['wgptda'].iloc[ind]
highlight_differences(r['origseq'], r['bestseqs'][-1])

In [None]:
# Maps the short keys used by `attdfs` to the corresponding keys of `rmdfs`
# (used below as rmdfs[kmap[k]] for k in attdfs).
# NOTE(review): 'wgptda' maps to 'wgpt_rewardrandda' whereas 'stackda' maps to
# 'stack_rewardda'; confirm this asymmetry is intended.
kmap = {
    'stackrandaug':'stack_rewardrandaug',
    'stackda':'stack_rewardda',
    'stackmix':'stack_rewardmixed',
    'wgptda':'wgpt_rewardrandda',
    'wgptorig':'wgpt_rewardmodel',
    'stacksanity':'stack_rewardsanity'
}

In [None]:
# NOTE(review): at this point `rmdfs` is only defined if the commented-out
# loader cell was run; the `rmdfs = ndfs` assignment comes later in the notebook.
rmdfs.keys()

In [None]:
# Key of the frame to inspect; note that several loops below overwrite `keyval`.
keyval = "wgpt_rewardmodel"

In [None]:
# For every attack set, print the standard deviation of all rewards (pooled
# across rows) of the matching reward-model frame.
# NOTE(review): depends on `attdfs` / `rmdfs` from the commented-out loader cell.
for k in attdfs.keys():
    norm = stdev(getfulldist(rmdfs[kmap[k]].reward))
    print(k)
    # print(attdfs[k]['diff'].mean())
    print(norm)

In [None]:
# For every shuffled set, pair its pooled rewards with those of the unshuffled
# frame (same key with "shuff" removed) and print the stdev of the originals.
# NOTE(review): depends on `shuffdfs` from the commented-out loader cell.
for k in shuffdfs.keys():
    tmpa = pd.DataFrame({
        'or':getfulldist(rmdfs[k.replace("shuff", "")].reward),
        'shuff':getfulldist(shuffdfs[k].reward),
    })
    norm = stdev(list(tmpa['or']))
    print(k)
    # print((tmpa['or']-tmpa['shuff']).abs().mean()/norm)
    print(norm)
    #print(spearmanr(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))
    #print(kendalltau(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))

In [None]:
# From here on, `rmdfs` refers to the frames loaded into `ndfs` (same object,
# not a copy).
rmdfs = ndfs

In [None]:
# Per-row Pearson (sps) and Kendall-tau (kts) correlations between answer token
# counts and rewards, averaged over rows (NaNs skipped), for every frame.
# Leaves `keyval` set to the last key iterated.
for k in rmdfs.keys():
    print(k)
    keyval = k
    sps = [getcorr(r, True) for _, r in rmdfs[k].iterrows()]
    kts = [getcorr(r, False) for _, r in rmdfs[k].iterrows()]
    print(mean([s for s in sps if not math.isnan(s)]))
    print(mean([s for s in kts if not math.isnan(s)]))
    #print(spearmanr(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))
    #print(kendalltau(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))

In [None]:
# For every frame: mean within-row reward stdev, normalized by the stdev of all
# rewards pooled across rows (NaNs skipped). Leaves `keyval` at the last key.
for k in rmdfs.keys():
    print(k)
    keyval = k
    norm = stdev(getfulldist(rmdfs[k].reward))
    # results = [(max(s)-median(s))/norm for s in rmdfs[k].reward]
    results = [(stdev(s))/norm for s in rmdfs[k].reward]
    #rouges = [rpaircorr(r) for _, r in rmdfs[k].iterrows()]
    print(mean([r for r in results if not math.isnan(r)]))
    #print(mean([r for r in rouges if not math.isnan(r)]))


In [None]:
# Histogram of answer token counts pooled over all rows of the frame currently
# selected by `keyval` (whatever the last loop left it at).
plt.hist(getfulldist(rmdfs[keyval].atoks))

In [None]:
# Pearson correlation between reward and answer token count, pooled over all
# rows of each frame. Leaves `keyval` at the last key.
for k in rmdfs.keys():
    print(k)
    keyval = k
    print(pearsonr(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))
    #print(spearmanr(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))
    # print(kendalltau(getfulldist(rmdfs[keyval].reward), getfulldist(rmdfs[keyval].atoks)))

In [None]:
# Rebinds `rmdfs` to the frames under ../trl-general/fullattacks/.
# NOTE(review): `keyval` still holds a key of the previous `rmdfs`; confirm it
# also exists in this dict before running the next cell.
rmdfs = load_all_rmdfs("../trl-general/fullattacks/")

In [None]:
# Ratio of the mean within-row reward stdev to the pooled reward stdev for the
# frame selected by `keyval`, followed by a histogram of the within-row stdevs.
print(mean([stdev(s) for s in rmdfs[keyval].reward])/stdev(getfulldist(rmdfs[keyval].reward)))
plt.hist([stdev(s) for s in rmdfs[keyval].reward])

In [None]:
# NOTE(review): the path "../tr" looks truncated, and the cells below use
# `stackorigrm` rather than `wgptorigrm`; confirm which model is intended.
wgptorigrm = load_rm("../tr")

In [None]:
# Flatten every response of `stackouts` into one list.
# NOTE(review): `stackouts` is not defined in any visible cell.
allresps = getfulldist(stackouts.response)

In [None]:
# Score the first 100 responses with the reward model.
# NOTE(review): `stackorigrm` and `kwargs` are not defined in any visible cell.
allscos = progress_rm(allresps[:100], stackorigrm, kwargs)

In [None]:
# Regroup the flat scores into consecutive groups of 8 (an incomplete trailing
# group is dropped by compdist).
compdist([a[0]['score'] for a in allscos], 8)

In [None]:
# Histogram of answer token counts pooled over all rows of `stackouts`.
# NOTE(review): `stackouts` is not defined in any visible cell.
plt.hist(getfulldist(stackouts.atoks))