# Perform clustering to obtain MiniLongBench

In this notebook, we cluster the IRT-based representations of the LongBench test items and keep one anchor item per cluster; the selected anchor items form MiniLongBench.

## Prepare data

In [1]:
import numpy as np
import pickle
from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics.pairwise import pairwise_distances
from irt import *
from utils import *




In [2]:
# Name of the benchmark to process; later cells use it as the key into
# `scenarios`, `scenarios_position` and `subscenarios_position`.
to_handle_scenario = 'longbench'
# Display the scenario -> sub-scenario mapping as the cell output.
# NOTE(review): `scenarios` is not defined in this notebook; presumably it is
# provided by the star-imports from `irt` / `utils` -- confirm.
scenarios

{'longbench': ['LongBench_2wikimqa',
  'LongBench_dureader',
  'LongBench_gov_report',
  'LongBench_hotpotqa',
  'LongBench_lcc',
  'LongBench_lsht',
  'LongBench_multifieldqa_en',
  'LongBench_multifieldqa_zh',
  'LongBench_multi_news',
  'LongBench_musique',
  'LongBench_narrativeqa',
  'LongBench_passage_count',
  'LongBench_passage_retrieval_en',
  'LongBench_passage_retrieval_zh',
  'LongBench_qasper',
  'LongBench_qmsum',
  'LongBench_repobench-p',
  'LongBench_samsum',
  'LongBench_trec',
  'LongBench_triviaqa',
  'LongBench_vcsum']}

Loading longbench data:

In [3]:
# Read the pickled LongBench results into `data`; it is passed to
# prepare_data() and create_responses() in the next cell.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('data/longbench.pickle', 'rb') as handle:
    data = pickle.load(handle)

In [4]:
# Build the position lookups used by the following cells:
#   scenarios_position[scenario]            -> indexed/len()'d per scenario
#   subscenarios_position[scenario][sub]    -> indexed/len()'d per sub-scenario
# NOTE(review): prepare_data / create_responses are not defined in this
# notebook; presumably they come from the `irt` / `utils` star-imports.
scenarios_position, subscenarios_position = prepare_data(scenarios, data)
# Response matrix; its second dimension is used below as the number of items.
Y = create_responses(scenarios, data)

# Show the matrix dimensions as the cell output.
Y.shape

(40, 4750)

In [5]:
# Start with a weight of 1 for every column of Y, then overwrite the entries
# of each sub-scenario with its balancing weight.
balance_weights = np.ones(Y.shape[1])
# per_scen[i] is the category id (0-5) of the i-th sub-scenario in
# scenarios[to_handle_scenario]; the list is positional, so it must stay in
# the same order as that list.
per_scen = [1, 1, 2, 1, 5, 3, 0, 0, 2, 1, 0, 4, 4, 4, 0, 2, 5, 3, 3, 3, 2]
N = len(scenarios_position[to_handle_scenario])
n_sub = len(scenarios[to_handle_scenario])
# Divisor used per category id: id 4 -> 3, id 5 -> 2, every other id -> 4.
# These match how many times each id occurs in per_scen.
subs_per_category = {4: 3, 5: 2}
for i, sub in enumerate(scenarios[to_handle_scenario]):
    num = subs_per_category.get(per_scen[i], 4)
    sub_positions = subscenarios_position[to_handle_scenario][sub]
    n_i = len(sub_positions)
    # The factor 6 presumably is the number of categories (ids 0-5) -- confirm.
    balance_weights[sub_positions] = N / (num * 6 * n_i)

## Clustering

In [6]:
from scipy.stats import spearmanr, pearsonr, kendalltau

# Task categories and their sub-scenarios (not used in this cell).
# NOTE(review): this dict lists samsum under "Summarization" and multi_news
# under "Few-shot Learning", while `per_scen` in the balance-weights cell
# groups multi_news with gov_report/qmsum/vcsum and samsum with
# trec/lsht/triviaqa. Confirm which grouping is intended before relying on it.
scenario_dict = {"Single-Document QA":["LongBench_narrativeqa", "LongBench_qasper", "LongBench_multifieldqa_en", "LongBench_multifieldqa_zh"],
                "Multi-Document QA":["LongBench_hotpotqa", "LongBench_2wikimqa", "LongBench_musique", "LongBench_dureader"],
                "Summarization":["LongBench_gov_report", "LongBench_qmsum", "LongBench_vcsum", "LongBench_samsum"],
                "Few-shot Learning":["LongBench_trec", "LongBench_lsht", "LongBench_triviaqa", "LongBench_multi_news"],
                "Code Completion":["LongBench_lcc", "LongBench_repobench-p"],
                "Synthetic Task":["LongBench_passage_count", "LongBench_passage_retrieval_en", "LongBench_passage_retrieval_zh"]}

# Item representation: the row(s) of A stacked on top of B (as a single row),
# transposed so that every item is one row of X.
# NOTE(review): A / B are presumably the IRT discrimination / difficulty
# parameters returned by load_irt_parameters -- confirm against `irt`.
A, B, _ = load_irt_parameters('data/irt_model/')
X = np.vstack((A.squeeze(), B.squeeze().reshape((1,-1)))).T
X = X[scenarios_position['longbench']]

# Balance weights of the selected items, normalised to sum to 1.
# Written as a new array (not `/=`) so re-running the cell is idempotent.
norm_balance_weights = balance_weights[scenarios_position['longbench']]
norm_balance_weights = norm_balance_weights / norm_balance_weights.sum()

scenario = 'longbench'
# sub_scenarios_pospos[i] holds the row indices (into X) of category i.
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open('data/sub_scenarios_pospos.pkl', 'rb') as f:
    sub_scenarios_pospos = pickle.load(f)

# Fixed seed so the KMeans initialisation -- and therefore the anchor points
# written to disk -- is identical on every Restart & Run All.
SEED = 0
n_categories = 6

# Keep (1 - ratio) of the items as anchors. The total is taken from X instead
# of a hard-coded constant (expected to be 4750 for the data used here).
ratio = 0.95
n_total_items = X.shape[0]
number_item = int((1-ratio) * n_total_items)

clustering = 'irt' # 'correct.' or 'irt'

anchor_points = {}
anchor_weights = {}

category_anchor_points = []
category_anchor_weights = []
summ = 0
for i in range(n_categories):
    idx = np.array(sub_scenarios_pospos[i])
    # Each category gets a share of the budget proportional to its size; the
    # last category takes the remainder so the total is exactly number_item.
    if i == n_categories - 1:
        num = number_item - summ
    else:
        num = int(len(idx) / n_total_items * number_item)
    summ += num

    # Weighted KMeans on the items of this category only.
    category_weights = norm_balance_weights[idx]
    kmeans = KMeans(n_clusters=num, n_init="auto", random_state=SEED)
    kmeans.fit(X[idx, :], sample_weight=category_weights)

    # Anchor point of a cluster = the real item closest to its centroid,
    # mapped back from category-local positions to row indices of X.
    tmp_points = pairwise_distances(kmeans.cluster_centers_, X[idx, :], metric='euclidean').argmin(axis=1)
    category_anchor_points.append(idx[tmp_points])

    # Anchor weight of a cluster = total normalised weight of its members.
    category_anchor_weights.append(
        np.array([np.sum(category_weights[kmeans.labels_ == c]) for c in range(num)])
    )

anchor_points[scenario] = np.concatenate(category_anchor_points)
anchor_weights[scenario] = np.concatenate(category_anchor_weights)

# NOTE(review): only the anchor points are persisted; anchor_weights is
# computed but not written to disk here.
with open("data/new_anchor.pkl", "wb") as f:
    pickle.dump(anchor_points, f)
       

