In [1]:
from pathlib import Path

while Path.cwd().name != 'proxy-tuning':
    %cd ..

/mmfs1/gscratch/xlab/alisaliu/proxy-tuning


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [2]:
import torch
import pandas as pd
import numpy as np
from collections import defaultdict
from scipy.stats import ttest_ind
from analysis.gsm_analysis import get_equation_lhs_rhs_indices

  from .autonotebook import tqdm as notebook_tqdm


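# TruthfulQA analysis

Which tokens does the DExperts model boost most, on average, relative to the base model, and in which phrases do they typically appear?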

In [3]:
# Load the precomputed TruthfulQA per-example records (each used below via its
# 'tokens', 'p_base' and 'p_dexperts' entries). torch.load unpickles the file,
# so only point it at analysis artifacts you produced yourself.
truthfulqa_pkl_path = 'analysis/pkl/truthfulqa_analysis.pkl'
all_results = torch.load(truthfulqa_pkl_path)

In [4]:
# For every token string, average (p_dexperts - p_base) over all of its
# occurrences in the TruthfulQA outputs. Tokens seen fewer than MIN_TOKEN_COUNT
# times are dropped so the ranking is not dominated by noisy rare tokens.
MIN_TOKEN_COUNT = 100

# Keep the raw per-occurrence diffs under their own name instead of reusing
# `mean_prob_diff` for two different kinds of object.
prob_diffs_by_token = defaultdict(list)
for results in all_results:
    p_dexperts, p_base = results['p_dexperts'], results['p_base']
    for i, token in enumerate(results['tokens']):
        prob_diffs_by_token[token].append((p_dexperts[i] - p_base[i]).item())

mean_prob_diff = {
    token: np.mean(diffs)
    for token, diffs in prob_diffs_by_token.items()
    if len(diffs) >= MIN_TOKEN_COUNT
}

In [5]:
# (token, mean diff) pairs in ascending order of mean (p_dexperts - p_base);
# the most up-weighted tokens end up at the tail.
sorted_items = list(mean_prob_diff.items())
sorted_items.sort(key=lambda token_and_diff: token_and_diff[1])

In [6]:
from collections import Counter
from nltk.tokenize import word_tokenize
from string import punctuation

def get_ngram(words, at_index, n=4):
    """Return the tokens starting at `at_index` that span `n` non-punctuation words.

    Punctuation tokens met along the way are kept in the returned list but do not
    count toward `n`. Collection stops early at the end of `words`.

    Args:
        words: list of token strings (e.g. the output of nltk's word_tokenize).
        at_index: index in `words` where the window starts.
        n: number of non-punctuation tokens to collect.

    Returns:
        List of tokens (words plus any interleaved punctuation).
    """
    gram = []
    num_words_in_gram = 0
    index = at_index
    while num_words_in_gram < n and index < len(words):
        word = words[index]
        # A token is punctuation only if every character is punctuation.
        # `word not in punctuation` would be a *substring* test against the
        # punctuation string, misclassifying tokens like '...', '--', '``', "''".
        if not all(ch in punctuation for ch in word):
            num_words_in_gram += 1
        gram.append(word)
        index += 1
    return gram

def find_most_common_ngram(words, target_word, n=4):
    """Find the most frequent n-gram window in `words` that contains `target_word`.

    Every start index i contributes the window get_ngram(words, i, n=n) when that
    window contains `target_word`; windows are compared as space-joined strings.

    Args:
        words: list of token strings.
        target_word: token that must appear in a window for it to be counted.
        n: number of non-punctuation tokens per window (passed to get_ngram).

    Returns:
        (ngram_string, count) for the most common window; ties go to the window
        seen first. If no window contains `target_word`, returns ('', 0) rather
        than raising, so callers that unpack the pair keep working.
    """
    counter = Counter()
    for i in range(len(words)):
        # Build each window once and reuse it for the membership test and count.
        gram = get_ngram(words, i, n=n)
        if target_word in gram:
            counter[' '.join(gram)] += 1

    if not counter:
        return '', 0
    return counter.most_common(1)[0]

TOP_K_TOKENS = 20  # how many of the most up-weighted tokens to inspect
NGRAM_N = 4        # non-punctuation words per context window

predictions_df = pd.read_json('results/truthfulqa/dexperts-13B-helpful-prompt/open_results.jsonl', lines=True)
text = '\n'.join(predictions_df.output.tolist())

# Tokenize once: the text is the same for every target word, and running
# word_tokenize over all outputs is the expensive step of this cell.
words = word_tokenize(text)
word_counts = Counter(words)

print("{:<20} {:<40} {:<10}".format('Word', f'{NGRAM_N}-gram', 'Fraction of occurrences'))
print("-" * 85)

# sorted_items is ascending by mean diff, so walk its tail in reverse to list
# the tokens with the largest mean probability increase first.
for target_word, _mean_diff in reversed(sorted_items[-TOP_K_TOKENS:]):
    word_freq = word_counts[target_word]
    gram, occurrences = find_most_common_ngram(words, target_word, n=NGRAM_N)
    print("{:<20} {:<40} {:<10}".format(target_word, gram, f'{occurrences}/{word_freq}'))

Word                 4-gram                                   Fraction of occurrences
-------------------------------------------------------------------------------------
Here                 Here are some of                         7/35      
Additionally         . Additionally , it is important         33/179    
There                There is no scientific                   5/59      
While                . While some people may                  12/206    
several              depending on several factors             4/60      
It                   It 's important to                       265/786   
provide              I can not provide                        165/413   
respect              is important to respect                  48/216    
common               is a common myth                         4/51      
personal             do n't have personal                     50/168    
However              However , it 's important                119/528   
In                   In t

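# GSM analysis

Does the DExperts model shift probability more on the left-hand side (LHS) or the right-hand side (RHS) of the math equations in its outputs?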

In [7]:
# Load the precomputed GSM per-example records (each used below via its
# 'tokens', 'p_base' and 'p_dexperts' entries). torch.load unpickles the file,
# so only point it at analysis artifacts you produced yourself.
gsm_pkl_path = 'analysis/pkl/gsm_analysis.pkl'
all_results = torch.load(gsm_pkl_path)

In [8]:
# Collect per-token probability differences (p_dexperts - p_base), split by
# whether the token lies on the left- or right-hand side of a math equation.
def _prob_diffs_at(ex, indices):
    """Return (p_dexperts[i] - p_base[i]).item() for each i in `indices`, in order."""
    return [(ex['p_dexperts'][i] - ex['p_base'][i]).item() for i in indices]

lhs_diffs = []
rhs_diffs = []
for ex in all_results:
    lhs_idx, rhs_idx = get_equation_lhs_rhs_indices(ex['tokens'])
    lhs_diffs.extend(_prob_diffs_at(ex, lhs_idx))
    rhs_diffs.extend(_prob_diffs_at(ex, rhs_idx))

In [9]:
# Summary table: number of LHS / RHS tokens and their mean probability
# difference, rounded to 3 decimals.
row_format = "{:<5} {:<10} {:<10}"
print(row_format.format('', 'Count', 'Mean diff'))
print("-" * 27)
for side, diffs in (('LHS', lhs_diffs), ('RHS', rhs_diffs)):
    mean_diff = np.round(np.mean(diffs), 3)
    print(row_format.format(side, str(len(diffs)), str(mean_diff)))

      Count      Mean diff 
---------------------------
LHS   14104      0.131     
RHS   16452      0.056     


In [10]:
# Welch's t-test (equal_var=False: unequal variances) comparing LHS vs RHS diffs.
# NOTE(review): tokens from the same example are unlikely to be independent, so
# this p-value is probably optimistic; consider an example-level test as well.
ttest_ind(a=lhs_diffs, b=rhs_diffs, equal_var=False)

TtestResult(statistic=33.08578492661922, pvalue=1.049491530636505e-234, df=23665.983157064817)