In [3]:

import os
import sys
import pandas as pd
import numpy as np
import pickle


sys.path.append("pipeline_src/")

from metrics.metrics import get_all_metrics, get_hypernyms
from dataset.dataset import HypernymDataset
from transformers import AutoTokenizer

from dataset.prompt_schemas import (
    hypo_term_hyper,
    predict_child_from_2_parents,
    predict_child_from_parent,
    predict_child_with_parent_and_grandparent,
    predict_children_with_parent_and_brothers,
    predict_parent_from_child_granparent,
)

In [4]:
# Pickled test split; a later cell loads it with pd.read_pickle.
test_path = "babel_datasets/wnet_test_en_babel.pickle"
# Pickled model generations; a later cell opens this path in binary mode and
# unpickles it into `all_preds`.
# NOTE(review): absolute, machine-specific location — consider making it configurable.
saving_path = os.path.join(
    "/raid/rabikov/model_outputs/",
    "eachadea-vicuna-13b-1.1remove_all_from_labels_",
)

In [5]:
# Load the evaluation samples from the pickle at `test_path`.
# NOTE(review): a later cell iterates this object and reads elem['case'], so it is
# presumably a sequence of dict-like records rather than a DataFrame — confirm.
# Unpickling can execute arbitrary code; only load files from a trusted source.
df = pd.read_pickle(test_path)

In [6]:
# Maps a sample's 'case' label to the prompt-schema function applied to that
# sample; the grouping cell unpacks each call's result as (term, target).
transforms = dict(
    only_child_leaf=predict_parent_from_child_granparent,
    only_leafs_all=predict_child_from_parent,
    only_leafs_divided=predict_children_with_parent_and_brothers,
    leafs_and_no_leafs=predict_child_from_parent,
    simple_triplet_grandparent=predict_parent_from_child_granparent,
    simple_triplet_2parent=predict_child_from_2_parents,
)

In [7]:
# Gold targets and prompt terms for every sample that has a prediction,
# kept in dataset order.
all_labels, all_terms = [], []

# Samples grouped by 'case': case -> {'pred': [...], 'label': [...], 'term': [...]}.
cased = {}

# NOTE(review): pickle.load can execute arbitrary code; `saving_path` must be trusted.
with open(saving_path, "rb") as fp:
    all_preds = pickle.load(fp)

for i, elem in enumerate(df):
    # Samples the prediction file has no entry for are skipped entirely.
    try:
        pred_i = all_preds[i]
    except IndexError:
        continue

    case = elem['case']
    processed_term, target = transforms[case](elem)
    all_labels.append(target)
    all_terms.append(processed_term)

    # Create the per-case bucket on first sight of this case, then extend it.
    bucket = cased.setdefault(case, {'pred': [], 'label': [], 'term': []})
    bucket['pred'].append(pred_i)
    bucket['label'].append(target)
    bucket['term'].append(processed_term)

In [12]:
from metrics.metrics import *

def get_one_sample_metric(goldline, predline, limit=15):
    """Score one prediction line against its gold line.

    Both lines are parsed with ``get_hypernyms`` (gold with ``is_gold=True``,
    prediction with ``is_gold=False``). A binary relevance vector with
    ``limit`` slots is then built, where slot ``j`` is 1 when the j-th parsed
    prediction is one of the gold items.

    Args:
        goldline: Gold answer in the format ``get_hypernyms`` expects.
        predline: Model output in the format ``get_hypernyms`` expects.
        limit: Forwarded to ``get_hypernyms`` and used as the length of the
            relevance vector.

    Returns:
        A ``(res, intersection)`` tuple. ``res`` maps "MRR", "MAP", "P@1",
        "P@3", "P@5" and "P@15" (in that order) to the values returned by the
        ranking helpers. ``intersection`` is ``get_intersect`` applied to the
        parsed gold and predicted items.
    """
    gold_hyps = get_hypernyms(goldline, is_gold=True, limit=limit)
    pred_hyps = get_hypernyms(predline, is_gold=False, limit=limit)
    gold_hyps_n = len(gold_hyps)

    intersection = get_intersect(gold_hyps, pred_hyps)

    # The vector is handed to the metric helpers as-is, so it keeps exactly
    # `limit` slots regardless of how many predictions were parsed.
    r = [0] * limit
    # Pairing with range(limit) ignores predictions beyond `limit`, so an
    # over-long prediction list cannot index past the end of `r`.
    for rank, pred_hyp in zip(range(limit), pred_hyps):
        if pred_hyp in gold_hyps:
            r[rank] = 1

    res = {
        "MRR": mean_reciprocal_rank(r),
        "MAP": mean_average_precision(r, gold_hyps_n),
        "P@1": precision_at_k(r, 1, gold_hyps_n),
        "P@3": precision_at_k(r, 3, gold_hyps_n),
        "P@5": precision_at_k(r, 5, gold_hyps_n),
        "P@15": precision_at_k(r, 15, gold_hyps_n),
    }

    return res, intersection

def get_intersect(gold, pred):
    """Return the items of ``pred`` that also occur in ``gold``.

    Each shared item appears once, in the order of its first occurrence in
    ``pred``. This makes the result reproducible: iterating a set intersection
    of strings can give a different order from run to run, because string
    hashing is randomized per process.

    Args:
        gold: Iterable of hashable gold items.
        pred: Iterable of hashable predicted items.

    Returns:
        A list of the distinct items present in both inputs.
    """
    gold_items = set(gold)
    # dict.fromkeys removes duplicates while keeping first-seen order.
    return list(dict.fromkeys(item for item in pred if item in gold_items))

In [26]:
# Show which 'case' groups the grouping cell collected.
cased.keys()

dict_keys(['only_leafs_divided', 'leafs_and_no_leafs', 'simple_triplet_grandparent', 'only_child_leaf', 'simple_triplet_2parent', 'only_leafs_all'])

In [27]:
# NOTE(review): within the visible notebook, `key` is only assigned by the
# later `for key in cased.keys()` export loop, so this cell depends on leftover
# kernel state and raises NameError on a fresh top-to-bottom run.
key

'P@15'

In [29]:
# For every case, write a human-readable report with one entry per sample:
# prompt term, prediction, gold answer, shared items and per-sample metrics.
for key in cased.keys():
    samples = cased[key]

    # Collect the entries in a list and join once, instead of growing one
    # string with += on every iteration.
    entries = []
    for pred, gold, term in zip(samples['pred'], samples['label'], samples['term']):
        res, intersect = get_one_sample_metric(gold, pred, limit=50)

        # Rendered as " <name> <value>" for each metric, in dict order.
        res_str = "".join(
            " " + str(metric_key) + " " + str(res[metric_key]) for metric_key in res
        )

        entries.append(
            term + "\n\n" +
            "predicted: " + pred + "\n \n" +
            "true: " + gold + "\n \n" +
            "intersection: " + ",".join(intersect) + "\n\n" +
            "metrics: " + res_str + " \n\n" +
            "="*10 + "\n"
        )

    total_str = "".join(entries)

    file_name = "babel_datasets/example_" + str(key) + "_vicuna_13b" + ".txt"
    # Explicit UTF-8: the terms and predictions can contain non-ASCII text, and
    # the platform default encoding is not guaranteed to be able to encode it.
    with open(file_name, "w", encoding="utf-8") as f:
        f.write(total_str)

In [17]:
# Spot-check one sample: (prediction, gold label, prompt term) at position 27.
idx = 27
all_preds[idx], all_labels[idx], all_terms[idx]

('bungalow, chalet, cottage, dacha, farmhouse, mansion, manse, palace, ranch house,',
 'sod house, beach house, bungalow, chalet, chapterhouse, detached house, dollhouse, duplex house, farmhouse, guesthouse, hacienda, lodge, maisonette, ranch house, safe house, saltbox, solar house, tract house, villa',
 'hypernym: house.n.1 | hyponyms:')

In [24]:
# Sanity check: the number of predictions and the number of gold labels should match.
len(all_preds), len(all_labels)

(796, 796)

In [13]:
# Trim the labels so there is at most one per prediction.
all_labels = all_labels[:len(all_preds)]
# NOTE(review): slicing all_preds by its own length never drops anything; this
# line only rebinds the name to a shallow copy.
all_preds = all_preds[:len(all_preds)]

In [18]:
# Aggregate metrics over all samples. `limit=50` is forwarded to
# get_all_metrics; its exact meaning is defined in metrics.metrics (not visible here).
metrics = get_all_metrics(all_labels, all_preds, limit=50)

In [19]:
# Display the overall metric dictionary computed in the previous cell.
metrics

{'MRR': 0.06706948188080267,
 'MAP': 0.06506133260517465,
 'P@1': 0.042767295597484274,
 'P@3': 0.05199161425576522,
 'P@5': 0.06062893081761008,
 'P@15': 0.0667036492508191}

In [46]:
# Module-level aliases for the predictions and gold labels.
# NOTE(review): the function defined right below takes its own `golds`/`preds`
# parameters, so it does not read these aliases.
preds = all_preds
golds = all_labels


def get_mean_first_pred_rank(golds, preds, limit=30):
    """Mean reciprocal rank of the first correct prediction over all samples.

    Each gold/prediction pair is parsed with ``get_hypernyms``. A sample
    scores ``1 / rank`` of its first predicted item that is also a gold item
    (ranks start at 1, only the first ``limit`` predictions count), or 0 when
    none of them match.

    Args:
        golds: Indexable sequence of gold lines.
        preds: Indexable sequence of prediction lines, aligned with ``golds``.
        limit: Forwarded to ``get_hypernyms`` and used as the number of
            leading predictions inspected. Defaults to 30, the value that was
            previously hard-coded.

    Returns:
        ``np.mean`` of the per-sample scores. With empty ``golds`` this is
        NaN, and NumPy emits a warning.

    Raises:
        IndexError: If ``preds`` has fewer entries than ``golds``.
    """
    rank_first = []

    for i, goldline in enumerate(golds):
        predline = preds[i]

        gold_hyps = get_hypernyms(goldline, is_gold=True, limit=limit)
        pred_hyps = get_hypernyms(predline, is_gold=False, limit=limit)

        reciprocal = 0
        # Pairing with range(1, limit + 1) caps the scan at `limit` items even
        # if the parser returned more, and yields 1-based ranks directly.
        for rank, pred_hyp in zip(range(1, limit + 1), pred_hyps):
            if pred_hyp in gold_hyps:
                reciprocal = 1 / rank
                break
        rank_first.append(reciprocal)

    return np.mean(rank_first)

In [52]:
# Scratch relevance vector; no visible cell reads it afterwards
# (presumably left over from trying out the ranking helpers).
r = [0,0,1,0,1]

In [22]:
# Per-case aggregate metrics: case -> result of get_all_metrics.
metrics_cased = {}
# NOTE(review): initialised here but never filled in this cell.
mrr_cased = {}

for case, bucket in cased.items():
    metric = get_all_metrics(bucket['label'], bucket['pred'], limit=50)
    metrics_cased[case] = metric


In [23]:
# Display the per-case metric dictionaries.
metrics_cased

{'only_leafs_divided': {'MRR': 0.09227582846003898,
  'MAP': 0.09198648963122646,
  'P@1': 0.07894736842105263,
  'P@3': 0.05701754385964912,
  'P@5': 0.06900584795321638,
  'P@15': 0.09594820384294067},
 'leafs_and_no_leafs': {'MRR': 0.13693552356343056,
  'MAP': 0.10310735442405565,
  'P@1': 0.05426356589147287,
  'P@3': 0.08914728682170546,
  'P@5': 0.09599483204134364,
  'P@15': 0.1061772207121044},
 'simple_triplet_grandparent': {'MRR': 0.015625,
  'MAP': 0.015625,
  'P@1': 0.015625,
  'P@3': 0.015625,
  'P@5': 0.015625,
  'P@15': 0.015625},
 'only_child_leaf': {'MRR': 0.006807511737089202,
  'MAP': 0.013427230046948357,
  'P@1': 0.004694835680751174,
  'P@3': 0.004694835680751174,
  'P@5': 0.014084507042253521,
  'P@15': 0.014084507042253521},
 'simple_triplet_2parent': {'MRR': 0.12384259259259262,
  'MAP': 0.20333333333333334,
  'P@1': 0.06944444444444445,
  'P@3': 0.16666666666666666,
  'P@5': 0.20833333333333334,
  'P@15': 0.20833333333333334},
 'only_leafs_all': {'MRR': 0.156

In [51]:
# NOTE(review): the only visible assignment is `mrr_cased = {}`; no visible
# cell fills it, so any values shown here come from code that is not in this notebook.
mrr_cased

{'only_leafs_divided': 0.06721921224630202,
 'leafs_and_no_leafs': 0.09295701385304668,
 'simple_triplet_grandparent': 0.030415034391155137,
 'only_child_leaf': 0.03619989590595837,
 'simple_triplet_2parent': 0.06252850725889941,
 'only_leafs_all': 0.0678095238095238}

In [35]:
# Browse 20 consecutive (prediction, gold label, prompt term) triples.
# NOTE(review): the length-check cell reports 796 predictions; with that data
# this slice starts past the end and the result is an empty list.
idx = 3900
list(zip(all_preds[idx:idx+20], all_labels[idx:idx+20], all_terms[idx:idx+20]))

[('библиографическая запись, аудиозапись, запись на магнитофонном носителе, звукозаписывающая студия, фонограмма, видеозап',
  'лексическая единица, nolle prosequi, ноутбук запись',
  "Predict hyponyms for the word 'за́пись.n.1'.  Answer:"),
 ('ацетилглицеролы (синтетический полиненасыщенный углеводный эфир) в таблетках и капсулах формата',
  'бутырин',
  "Predict the hypernym for the word 'трибутирин.n.1' which is hyponyms for the word 'глицеринового эфира.n.1' at the same time. Answer:"),
 ('Orchis virginica variegata fimbriata uvariae-floridae, Orchis macrocarpa, orchis papilionacea, орх',
  'Ванда (растение)',
  "Predict the hypernym for the word 'blue orchid.n.1' which is hyponyms for the word 'орхидея.n.1' at the same time. Answer:"),
 ('хризолиты (хрусталь) изумруды и бриллианты из хризобериллитовой группы минералов и самоцветами',
  'авантюрин',
  "Predict the hypernym for the word 'авантюриновое стекло.n.1' which is hyponyms for the word 'прозрачные жемчужины.n.1' at the same 

In [36]:

# Browse 20 consecutive (prediction, gold label, prompt term) triples from 3950.
idx = 3950
list(zip(all_preds[idx:idx+20], all_labels[idx:idx+20], all_terms[idx:idx+20]))

[('копейка, ливийский динар, фунт Ливии, пшеничный тулуп, шиллинг либерийских денежных единиц,',
  'Ливийский динар, ливийская дирхам',
  "Predict hyponyms for the word 'ливийская денежной единицы.n.1'.  Answer:"),
 ('негра́жанин США́йской нацио́наль\xadной па́зиции (США) (национальный состав) человек по происхожд',
  'испаноамериканцы',
  "Predict the hypernym for the word 'criollo людей.n.1' which is hyponyms for the word 'американцы США.n.1' at the same time. Answer:"),
 ('кхмерский фунт тайланда (квота) и тайский тайваньский донг (доллар США) равны по отношению к',
  'бат',
  "сатангов.n.1are hyponyms for the word 'тайский денежной единицы.n.1'. Predict other hyponyms for the word 'тайский денежной единицы.n.1'. Answer:"),
 ('кость (биология) животных твёрдой оболочки обонятельных тканей человека и древесины твердых оболочек мозга',
  'дентин',
  "Predict the hypernym for the word 'слоновая кость.n.1' which is hyponyms for the word 'материал животного происхождения.n.1' at the same

In [10]:
# Compare the parsed gold items with the parsed predicted items for sample 3954.
get_hypernyms(all_labels[3954]), get_hypernyms(all_preds[3954], is_gold=False)
# "азрег" is not a real word
# "куру" does not exist either; there is only the disease

(['азрег', 'куру', 'турецкая лира'],
 ['фунт стерлингов',
  'турецкая лира',
  'динар (турецкий денежный знак)',
  'дирхам',
  'тюркские денежные знаки'])

In [19]:
# Print the prompt for sample 3957, then show its parsed gold and predicted items.
print(all_terms[3957])
get_hypernyms(all_labels[3957]), get_hypernyms(all_preds[3957], is_gold=False)


зарегистрируйтесь языке.n.1are hyponyms for the word 'тональный язык.n.1'. Predict other hyponyms for the word 'тональный язык.n.1'. Answer:


(['контур языке'],
 ['англи́йский тональ́ный язы́к (языки программирования) фреймворк-кода языка c++ с открытым исходным'])

In [13]:

# Browse 20 consecutive (prediction, gold label, prompt term) triples from 1550.
idx = 1550
list(zip(all_preds[idx:idx+20], all_labels[idx:idx+20], all_terms[idx:idx+20]))

[('flexibility, girdle, hip, knee, leg, neck, shoulder, stomach, torso, waist, wrist, foot, hand, head,',
  'address',
  "attention.n.6, erectness.n.1are hyponyms for the word 'stance.n.1'. Predict other hyponyms for the word 'stance.n.1'. Answer:"),
 ('potato salad dish with tomato sauce and mozzarella di bufala di pane raffermo fritto in olio di semi di girasole',
  'kabob',
  "Predict the hypernym for the word 'souvlaki.n.1' which is hyponyms for the word 'dish.n.2' at the same time. Answer:"),
 ('citronade, cranberry juice, ice-cream, kool-aid, mocha, soda, vanilla ice cream, watermelon ju',
  'lemonade',
  "limeade.n.1, orangeade.n.1are hyponyms for the word 'fruit drink.n.1'. Predict other hyponyms for the word 'fruit drink.n.1'. Answer:"),
 ("vajrayana sage of India's Mahayana tradition of Buddhism and Hinduism of the 11th and 12th centuries CE to the 6th century CE CEZone",
  'common sage',
  "Predict the hypernym for the word 'Bharadvaja.n.1' which is hyponyms for the word 'sa

In [25]:
# Print the prompt for sample `idx`, then show its parsed gold and predicted items.
idx = 1505

print(all_terms[idx])
get_hypernyms(all_labels[idx]), get_hypernyms(all_preds[idx], is_gold=False)


Predict common hyponyms for the words 'writer.n.1' and 'performer.n.1'. Answer:


(['matthew richter'],
 ['paul shaffer (musician)\nbutch hartman (musical director) (born august 23',
  '1948) is an american musician',
  'singer-songwriter',
  'bass guitarist',
  ''])

In [18]:
# Print the prompt for sample `idx`, then show its parsed gold and predicted items.
idx = 1552

print(all_terms[idx])
get_hypernyms(all_labels[idx]), get_hypernyms(all_preds[idx], is_gold=False)


limeade.n.1, orangeade.n.1are hyponyms for the word 'fruit drink.n.1'. Predict other hyponyms for the word 'fruit drink.n.1'. Answer:


(['lemonade'],
 ['citronade',
  'cranberry juice',
  'ice-cream',
  'kool-aid',
  'mocha',
  'soda',
  'vanilla ice cream',
  'watermelon ju'])