In [2]:
import torch
import numpy as np


In [13]:
def rearrange(energy_scores, candidate_position_idx, true_position_idx):
    tmp = np.array([[x==y for x in candidate_position_idx] for y in true_position_idx]).any(0)
    correct = np.where(tmp)[0]
    incorrect = np.where(~tmp)[0]
    labels = torch.cat((torch.ones(len(correct)), torch.zeros(len(incorrect)))).int()
    energy_scores = torch.cat((energy_scores[correct], energy_scores[incorrect]))
    return energy_scores, labels

In [14]:
import re 

def calculate_ranks_from_distance(all_distances, positive_relations):
    """
    all_distances: a np array
    positive_relations: a list of array indices

    return a list
    """
    # positive_relation_distance = all_distances[positive_relations]
    # negative_relation_distance = np.ma.array(all_distances, mask=False)
    # negative_relation_distance.mask[positive_relations] = True
    # ranks = list((negative_relation_distance < positive_relation_distance[:, np.newaxis]).sum(axis=1) + 1)
    # ranks = list((all_distances < positive_relation_distance[:, np.newaxis]).sum(axis=1) + 1)
    ranks = list(np.argsort(np.argsort(all_distances))[positive_relations]+1)
    return ranks

def obtain_ranks(outputs, targets):
    """ 
    outputs : tensor of size (batch_size, 1), required_grad = False, model predictions
    targets : tensor of size (batch_size, ), required_grad = False, labels
        Assume to be of format [1, 0, ..., 0, 1, 0, ..., 0, ..., 0]
    mode == 0: rank from distance (smaller is preferred)
    mode == 1: rank from similarity (larger is preferred)
    """
    calculate_ranks = calculate_ranks_from_distance
    all_ranks = []
    prediction = outputs.cpu().numpy().squeeze()
    label = targets.cpu().numpy()
    sep = np.array([0, 1], dtype=label.dtype)
    
    # fast way to find subarray indices in a large array, c.f. https://stackoverflow.com/questions/14890216/return-the-indexes-of-a-sub-array-in-an-array
    end_indices = [(m.start() // label.itemsize)+1 for m in re.finditer(sep.tostring(), label.tostring())]
    end_indices.append(len(label)+1)
    start_indices = [0] + end_indices[:-1]
    for start_idx, end_idx in zip(start_indices, end_indices):
        distances = prediction[start_idx: end_idx]
        labels = label[start_idx:end_idx]
        positive_relations = list(np.where(labels == 1)[0])
        ranks = calculate_ranks(distances, positive_relations)
        all_ranks.append(ranks)
    return all_ranks

In [15]:
import itertools

def macro_mr(all_ranks):
    macro_mr = np.array([np.array(all_rank).mean() for all_rank in all_ranks]).mean()
    return macro_mr

def micro_mr(all_ranks):
    micro_mr = np.array(list(itertools.chain(*all_ranks))).mean()
    return micro_mr

def hit_at_1(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 1)
    return 1.0 * hits / len(rank_positions)

def hit_at_3(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 3)
    return 1.0 * hits / len(rank_positions)

def hit_at_5(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 5)
    return 1.0 * hits / len(rank_positions)

def hit_at_10(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 10)
    return 1.0 * hits / len(rank_positions)

def precision_at_1(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 1)
    return 1.0 * hits / len(all_ranks)

def precision_at_3(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 3)
    return 1.0 * hits / (len(all_ranks)*3)

def precision_at_5(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 5)
    return 1.0 * hits / (len(all_ranks)*5)

def precision_at_10(all_ranks):
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    hits = np.sum(rank_positions <= 10)
    return 1.0 * hits / (len(all_ranks)*10)

def mrr_scaled_10(all_ranks):
    """ Scaled MRR score, check eq. (2) in the PinSAGE paper: https://arxiv.org/pdf/1806.01973.pdf
    """
    rank_positions = np.array(list(itertools.chain(*all_ranks)))
    
    scaled_rank_positions = np.ceil(rank_positions / 10)
 #   print(scaled_rank_positions, (1.0 / scaled_rank_positions).mean())
    return (1.0 / scaled_rank_positions).mean()

In [4]:
import pickle

#test_path = '../data/psychology/test_nodes.pickle'
test_path = '../data/noun/test_hypernyms_def.pickle'
with open(test_path, 'rb') as f:
    test = pickle.load(f)

test_wn = []

for elem in test:
    case = elem['case']
    if case == 'predict_hypernym':
        cur_pair = (elem['children'], [elem['parents']])
    else:
        cur_pair = (elem['children'], elem['parents'])

    test_wn.append(cur_pair)
test = test_wn

In [6]:
test[0]

('unconventionality.n.02', ['unorthodoxy.n.03'])

In [7]:
def clean(s):
    return s.split('.')[0]

new_test = []
for child, parents in test:
    temp = []
    for parent in parents:
        temp.append((clean(parent), clean(child)))
    new_test.append(temp)

In [8]:
pred_path = '../../../data/taxonomy/model_outputs/_meta-llama-Llama-2-7b-hfTaxoEnrich_noun_Unified_3beams_40topk_0.8temp_3norepeat_stohastic_'
with open(pred_path, 'rb') as f:
    pred = pickle.load(f)

def get_hypernyms(line):
    clean_line = line.strip().replace("\n", ",").split(",")

    res = []
    for hyp in clean_line:
        if not hyp in ("", " ", ", ", ","):
            res.append(hyp.lower().strip())

    return res

from collections import Counter

def unique_words_by_frequency(words):
    # Count the frequency of each word in the list
    frequency = Counter(words)
    # Sort the words first by frequency, then by the order they appear in the original list
    sorted_words = sorted(set(words), key=lambda x: (-frequency[x], words.index(x)))
    return sorted_words

concat = True

new_pred = []
for elem in pred:
    if concat:
        cur_portion = []
        for line in elem:
            cur_portion.extend(get_hypernyms(line))
        new_pred.append(unique_words_by_frequency(cur_portion))
    else:
        new_pred.append(get_hypernyms(elem[1]))

In [9]:
new_test[:3], new_pred[:3]

([[('unorthodoxy', 'unconventionality')],
  [('brome', 'awnless_bromegrass')],
  [('insufficiency', 'slenderness')]],
 [['unorthodoxiness',
   'unorthodxy',
   'unconformity',
   'heterodoxy',
   'dissidence',
   'disobedience',
   'disloyalty',
   'dis',
   'disaffection',
   'dissent',
   'disagreement'],
  ['bromus',
   'brome',
   'bromine grass',
   'foxtail',
   "fox's tail",
   'caryopsis',
   'cereal grass',
   'c',
   'briza',
   'breeze grass',
   'b'],
  ['weakness',
   'insufficiency',
   'deficiency',
   'shortcoming',
   'incompleteness',
   'imperfection',
   'indefiniteness',
   'insubstantial',
   'inadequacies']])

In [16]:
metric_names = {
    'mrr': mrr_scaled_10,
    'p1': precision_at_1,
    'p5': precision_at_5,
    'r1': hit_at_1,
    'r5': hit_at_5
}

metrics = {}
for name in metric_names.keys():
    metrics[name] = []
for idx in range(len(new_test)):
    #hyps = get_hypernyms(new_pred[idx])
    hyps = new_pred[idx] + [', ']
   # print(hyps)
    gold = new_test[idx]

    child = gold[0][1]
    new_hyps = [(hyp, child) for hyp in hyps]
    scores = torch.arange(len(new_hyps))

    batched_energy_scores, labels = rearrange(scores, new_hyps, gold)

    all_ranks = obtain_ranks(batched_energy_scores, labels)
    for name, func in metric_names.items():
        cur_metric = np.nan_to_num(func(all_ranks))
        metrics[name].append(cur_metric)


  end_indices = [(m.start() // label.itemsize)+1 for m in re.finditer(sep.tostring(), label.tostring())]
  return (1.0 / scaled_rank_positions).mean()
  ret = ret.dtype.type(ret / rcount)
  return 1.0 * hits / len(rank_positions)
  return 1.0 * hits / len(rank_positions)


In [17]:
for name, v in metrics.items():
    print(name, np.mean(v))

mrr 0.48020304568527916
p1 0.38274111675126904
p5 0.09401015228426396
r1 0.3822335025380711
r5 0.46954314720812185


In [20]:
len(metrics['mrr'])

985

In [21]:
(np.array(metrics['mrr']) > 0).sum()

473