In [1]:
import re
import json
import argparse
import os
import sys
import torch
import pandas as pd
import numpy as np

from tqdm import tqdm
from nltk.stem import PorterStemmer
# from transformers import GPT2Tokenizer, GPT2Model
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

from swisscom_ai.research_keyphrase.preprocessing.postagging import PosTaggingCoreNLP
from swisscom_ai.research_keyphrase.model.input_representation import InputTextObj
from swisscom_ai.research_keyphrase.model.extractor import extract_candidates

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
host = 'localhost'
port = 9000
pos_tagger = PosTaggingCoreNLP(host, port)

# load stopwords
stopwords = []
with open('UGIR_stopwords.txt', "r") as f:
    for line in f:
        if line:
            stopwords.append(line.replace('\n', ''))

stemmer = PorterStemmer()

def read_jsonl(path):
    data = []
    with open(path, 'r') as f:
        for line in f:
            item = json.loads(line.strip())
            data.append(item)
    return data

In [2]:
def get_candidates(core_nlp, text):
    tagged = core_nlp.pos_tag_raw_text(text)
    text_obj = InputTextObj(tagged, 'en')
    candidates = extract_candidates(text_obj)
    return candidates

def get_phrase_indices(text_tokens, phrase, prefix):
    text_tokens = [t.replace(prefix, '') for t in text_tokens]

    phrase = phrase.replace(' ', '')

    matched_indices = []
    matched_index = []
    target = phrase
    for i in range(len(text_tokens)):
        cur_token = text_tokens[i]
        sub_len = min(len(cur_token), len(phrase))
        if cur_token[:sub_len].lower() == target[:sub_len]:
            matched_index.append(i)
            target = target[sub_len:]
            if len(target) == 0:
                matched_indices.append([matched_index[0], matched_index[-1] + 1])
                target = phrase
        else:
            matched_index = []
            target = phrase
            if cur_token[:sub_len].lower() == target[:sub_len]:
                matched_index.append(i)
                target = target[sub_len:]
                if len(target) == 0:
                    matched_indices.append([matched_index[0], matched_index[-1] + 1])
                    target = phrase

    return matched_indices

def remove_repeated_sub_word(candidates_pos_dict):
    for phrase in candidates_pos_dict.keys():
        split_phrase = re.split(r'\s+|-', phrase)
        split_phrase = list(filter(None, split_phrase))
        if len(split_phrase) > 1:
            for word in split_phrase:
                if word in candidates_pos_dict:
                    single_word_positions = candidates_pos_dict[word]
                    phrase_positions = candidates_pos_dict[phrase]
                    single_word_alone_positions = [pos for pos in single_word_positions if not any(
                        pos[0] >= phrase_pos[0] and pos[1] <= phrase_pos[1] for phrase_pos in phrase_positions)]
                    candidates_pos_dict[word] = single_word_alone_positions

    return candidates_pos_dict

def get_all_indices(candidates_pos_dict, window_end):
    all_indices = []
    for phrase in candidates_pos_dict.values():
        for element in phrase:
            start_index = element[0]
            end_index = element[1]
            if(start_index>window_end):
                all_indices.extend(range(start_index, end_index))
    all_indices = sorted(all_indices)
    all_indices = list(dict.fromkeys(all_indices))

    return all_indices

def aggregate_phrase_scores(index_list, tokens_scores):
    total_score = 0.0

    for p_index in index_list:
        part_sum = tokens_scores[p_index[0]:p_index[1]].sum()
        total_score += part_sum

    return total_score

def get_score_full(candidates, references, maxDepth=15):
    precision = []
    recall = []
    reference_set = set(references)
    referencelen = len(reference_set)
    true_positive = 0
    for i in range(maxDepth):
        if len(candidates) > i:
            kp_pred = candidates[i]
            if kp_pred in reference_set:
                true_positive += 1
            precision.append(true_positive / float(i + 1))
            recall.append(true_positive / float(referencelen))
        else:
            precision.append(true_positive / float(len(candidates)))
            recall.append(true_positive / float(referencelen))
    return precision, recall


def evaluate_document(candidates, ground_truth):
    results = {}
    precision_scores, recall_scores, f1_scores = {5: [], 10: [], 15: []}, \
                                                 {5: [], 10: [], 15: []}, \
                                                 {5: [], 10: [], 15: []}
    for candidate, gt in zip(candidates, ground_truth):
        p, r = get_score_full(candidate, gt)
        for i in [5, 10, 15]:
            precision = p[i - 1]
            recall = r[i - 1]
            if precision + recall > 0:
                f1_scores[i].append((2 * (precision * recall)) / (precision + recall))
            else:
                f1_scores[i].append(0)
            precision_scores[i].append(precision)
            recall_scores[i].append(recall)

    print("########################\nMetrics")
    for i in precision_scores:
        print("@{}".format(i))
        print("F1:{}".format(np.mean(f1_scores[i])))
        print("P:{}".format(np.mean(precision_scores[i])))
        print("R:{}".format(np.mean(recall_scores[i])))

        top_n_p = 'precision@' + str(i)
        top_n_r = 'recall@' + str(i)
        top_n_f1 = 'f1@' + str(i)
        results[top_n_p] = np.mean(precision_scores[i])
        results[top_n_r] = np.mean(recall_scores[i])
        results[top_n_f1] = np.mean(f1_scores[i])
    print("#########################")

    return results

def evaluate_dataset(predicted_top, dataset, score_type, dataset_name):
    experiment_results = []
    gt_keyphrase_list = []
    predicted_keyphrase_list = []

    for i in range(len(dataset)):
        predicted_keyphrase = predicted_top[i]
        predicted_keyphrase = [phrase.lower() for phrase in predicted_keyphrase]
        predicted_keyphrase_list.append(predicted_keyphrase)

        gt_keyphrase = [key.lower() for key in dataset[i]['keyphrases']]
        gt_keyphrase_list.append(gt_keyphrase)

    total_score = evaluate_document(predicted_keyphrase_list, gt_keyphrase_list)
    experiment_results.append(total_score)

    df = pd.DataFrame(experiment_results)

    path = f'experiment_results/{dataset_name}/'
    os.makedirs(path, exist_ok=True)
    df.to_csv(f'{path}score_type_{score_type}.csv', index=False)

    top3_f1_5 = df.nlargest(3, 'f1@5').reset_index(drop=True)
    top3_f1_10 = df.nlargest(3, 'f1@10').reset_index(drop=True)
    top3_f1_15 = df.nlargest(3, 'f1@15').reset_index(drop=True)

    return top3_f1_5, top3_f1_10, top3_f1_15

<img src="hugging_token.png" width=1000px></img>

In [4]:
# You need to request permission to use the LLAMA 3 model with your Huggingface account
login(token="YOUR_TOKEN_HERE")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to C:\Users\user0\.cache\huggingface\token
Login successful


In [5]:
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", attn_implementation="eager", output_attentions=True)
prefix = 'Ġ'

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:39<00:00,  9.85s/it]


In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"device: {device}")

device: cuda


<h2>DATASET: INSPEC</h2>

In [7]:
dataset_name = "inspec"
dataset = read_jsonl("KEYWORD_DATA/{}.jsonl".format(dataset_name))

SAMRANK BASE

In [9]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)

        ###############################################################
        # ATTENTION MEASSUREMENT
        attentions = sum(attentions)/len(attentions)
        attentions = attentions.squeeze(0)
        att_scores = attentions.mean(0).sum(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)
            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "SAMRANK_BASE", dataset_name)

100%|██████████| 500/500 [01:52<00:00,  4.45it/s]

########################
Metrics
@5
F1:0.3424774238136697
P:0.4784
R:0.29396447932697767
@10
F1:0.3818415536037738
P:0.38055952380952385
R:0.4318405058729329
@15
F1:0.381147979499465
P:0.32708260073260065
R:0.5150753282616488
#########################





BASE + RELEVANCE $S^{lh}$

In [10]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S", dataset_name)

100%|██████████| 500/500 [02:47<00:00,  2.98it/s]

########################
Metrics
@5
F1:0.3498333697425259
P:0.4888
R:0.30044397737152145
@10
F1:0.3929725131711594
P:0.3909595238095238
R:0.4453400418288929
@15
F1:0.38994231332502227
P:0.33468260073260075
R:0.5273217643447722
#########################





BASE + RELEVANCE $R^{lh}$

In [12]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_R", dataset_name)

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [02:57<00:00,  2.82it/s]

########################
Metrics
@5
F1:0.34436109229542
P:0.48079999999999995
R:0.29599482024210116
@10
F1:0.38538516030017866
P:0.3837595238095238
R:0.43613373026991953
@15
F1:0.3823815661162578
P:0.32828260073260074
R:0.5166368833732039
#########################





BASE + $S^{lh}$ FILTERING

In [14]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                attentions += crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S_FILT", dataset_name)

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [02:39<00:00,  3.13it/s]

########################
Metrics
@5
F1:0.3479964132361266
P:0.4852
R:0.2993481196632496
@10
F1:0.39542657022266414
P:0.3927595238095238
R:0.44848639637733734
@15
F1:0.3908234559524791
P:0.3356159340659341
R:0.5280488252985057
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$

In [15]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR", dataset_name)

100%|██████████| 500/500 [03:14<00:00,  2.57it/s]

########################
Metrics
@5
F1:0.3506669211059857
P:0.488
R:0.30176278832791825
@10
F1:0.3956080429173188
P:0.39355952380952386
R:0.44838545588049894
@15
F1:0.3904686107651664
P:0.33508260073260077
R:0.5282595081511222
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$ + $S^{lh}$ FILTERING

In [16]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR_S_FILT", dataset_name)

100%|██████████| 500/500 [03:48<00:00,  2.18it/s]

########################
Metrics
@5
F1:0.3523239070864082
P:0.48960000000000004
R:0.30404242617373256
@10
F1:0.4021599576060729
P:0.39935952380952383
R:0.4561671422998015
@15
F1:0.3948088224297586
P:0.33934926739926735
R:0.5334289439326322
#########################





FINAL ATTENTIONSEEKER

In [17]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        # LHC-SEEKER
        f_att_scores = torch.zeros_like(att_scores)
        f_att_scores[mask] = att_scores[mask]
        # NEW ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,f_att_scores)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "ATTENTION_SEEKER", dataset_name)

100%|██████████| 500/500 [05:31<00:00,  1.51it/s]

########################
Metrics
@5
F1:0.35490195392637314
P:0.49440000000000006
R:0.30500898385626474
@10
F1:0.40136623714102004
P:0.3987595238095238
R:0.45497844248364167
@15
F1:0.39220319615918514
P:0.33681593406593413
R:0.5299187073369938
#########################





<h2>DATASET: SEMEVAL 2017</h2>

In [18]:
dataset_name = "semeval2017"
dataset = read_jsonl("KEYWORD_DATA/{}.jsonl".format(dataset_name))

SAMRANK BASE

In [19]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)

        ###############################################################
        # ATTENTION MEASSUREMENT
        attentions = sum(attentions)/len(attentions)
        attentions = attentions.squeeze(0)
        att_scores = attentions.mean(0).sum(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)
            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "SAMRANK_BASE", dataset_name)

100%|██████████| 493/493 [02:26<00:00,  3.36it/s]

########################
Metrics
@5
F1:0.24742282955536773
P:0.5168356997971603
R:0.1700399499230244
@10
F1:0.3350504558633974
P:0.4480730223123732
R:0.2851300385057317
@15
F1:0.3701308557000764
P:0.3989408495493688
R:0.3718292969857578
#########################





BASE + RELEVANCE $S^{lh}$

In [20]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S", dataset_name)

100%|██████████| 493/493 [03:27<00:00,  2.37it/s]

########################
Metrics
@5
F1:0.25020787276333684
P:0.5237322515212981
R:0.1716076394829049
@10
F1:0.34384281706024733
P:0.459026369168357
R:0.29326315596933095
@15
F1:0.38135343215622675
P:0.4105703289273269
R:0.383159126693229
#########################





BASE + RELEVANCE $R^{lh}$

In [21]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_R", dataset_name)

  0%|          | 0/493 [00:00<?, ?it/s]

100%|██████████| 493/493 [03:42<00:00,  2.22it/s]

########################
Metrics
@5
F1:0.24854724252405688
P:0.5204868154158215
R:0.17053791835533402
@10
F1:0.33788381149682867
P:0.45152129817444225
R:0.2877299963257747
@15
F1:0.37105733717899136
P:0.3994817555669482
R:0.373022554882714
#########################





BASE + $S^{lh}$ FILTERING

In [22]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                attentions += crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S_FILT", dataset_name)

  0%|          | 0/493 [00:00<?, ?it/s]

100%|██████████| 493/493 [03:23<00:00,  2.42it/s]

########################
Metrics
@5
F1:0.24998754161570974
P:0.5221095334685598
R:0.17157882654792983
@10
F1:0.34298342658445324
P:0.45862068965517244
R:0.2919147327160832
@15
F1:0.37628145434871557
P:0.40516126875153247
R:0.3782410630340131
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$

In [23]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR", dataset_name)

  0%|          | 0/493 [00:00<?, ?it/s]

100%|██████████| 493/493 [03:51<00:00,  2.13it/s]

########################
Metrics
@5
F1:0.2525391606695459
P:0.5286004056795133
R:0.17325493270907552
@10
F1:0.34574394921860424
P:0.46227180527383366
R:0.2944868684142313
@15
F1:0.3816535964336744
P:0.41084078193611656
R:0.3835753747327744
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$ + $S^{lh}$ FILTERING

In [24]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR_S_FILT", dataset_name)

100%|██████████| 493/493 [04:18<00:00,  1.91it/s]

########################
Metrics
@5
F1:0.2526717018835288
P:0.5286004056795133
R:0.1734006263042143
@10
F1:0.347774377022352
P:0.4645030425963489
R:0.2959501223321395
@15
F1:0.3832744924906041
P:0.4125987264932498
R:0.3853198784209179
#########################





FINAL ATTENTIONSEEKER

In [25]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["text"]

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        # ATTENTION SEEKER (LHC-SEEKER)
        f_att_scores = torch.zeros_like(att_scores)
        f_att_scores[mask] = att_scores[mask]
        # NEW ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,f_att_scores)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "ATTENTION_SEEKER", dataset_name)

100%|██████████| 493/493 [06:09<00:00,  1.33it/s]

########################
Metrics
@5
F1:0.25397632889673977
P:0.5330628803245435
R:0.1741208725412833
@10
F1:0.345326367943574
P:0.46186612576064906
R:0.2942509747499267
@15
F1:0.3849912207748741
P:0.4150328035723573
R:0.38667026905240437
#########################





<h2>DATASET: SEMEVAL 2010</h2>

In [30]:
def evaluate_dataset(predicted_top, dataset, score_type, dataset_name):
    experiment_results = []
    gt_keyphrase_list = []
    predicted_keyphrase_list = []

    for i in range(len(dataset)):
        predicted_keyphrase = predicted_top[i]
        predicted_keyphrase = [phrase.lower() for phrase in predicted_keyphrase]
        predicted_keyphrase_list.append(predicted_keyphrase)

        stemmed_gt_keyphrases = [" ".join(stemmer.stem(word) for word in phrase.split()) for phrase in dataset[i]['keywords'].split(";")]
        gt_keyphrase = [key.lower() for key in stemmed_gt_keyphrases]
        gt_keyphrase_f = list(dict.fromkeys(gt_keyphrase))
        gt_keyphrase_list.append(gt_keyphrase_f)

    total_score = evaluate_document(predicted_keyphrase_list, gt_keyphrase_list)
    experiment_results.append(total_score)

    df = pd.DataFrame(experiment_results)

    path = f'experiment_results/{dataset_name}/'
    os.makedirs(path, exist_ok=True)
    df.to_csv(f'{path}score_type_{score_type}.csv', index=False)

    top3_f1_5 = df.nlargest(3, 'f1@5').reset_index(drop=True)
    top3_f1_10 = df.nlargest(3, 'f1@10').reset_index(drop=True)
    top3_f1_15 = df.nlargest(3, 'f1@15').reset_index(drop=True)

    return top3_f1_5, top3_f1_10, top3_f1_15

In [27]:
dataset_name = "semeval_test"
dataset = read_jsonl("KEYWORD_DATA/{}.json".format(dataset_name))

SAMRANK BASE

In [31]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)

        ###############################################################
        # ATTENTION MEASSUREMENT
        attentions = sum(attentions)/len(attentions)
        attentions = attentions.squeeze(0)
        att_scores = attentions.mean(0).sum(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)
            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "SAMRANK_BASE", dataset_name)

100%|██████████| 100/100 [00:27<00:00,  3.62it/s]

########################
Metrics
@5
F1:0.16804899041141827
P:0.336
R:0.11333652014382292
@10
F1:0.20259379173698572
P:0.255
R:0.1714168653644571
@15
F1:0.20942092450444147
P:0.212
R:0.21200038582863637
#########################





BASE + RELEVANCE $S^{lh}$

In [32]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S", dataset_name)

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:40<00:00,  2.48it/s]

########################
Metrics
@5
F1:0.15770701531527706
P:0.31400000000000006
R:0.10651850827715621
@10
F1:0.20369017365807
P:0.256
R:0.17241906637981608
@15
F1:0.21432769066704746
P:0.21733333333333335
R:0.21672575528676447
#########################





BASE + RELEVANCE $R^{lh}$

In [33]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_R", dataset_name)

100%|██████████| 100/100 [00:43<00:00,  2.32it/s]

########################
Metrics
@5
F1:0.1678945892570171
P:0.336
R:0.11319719133537648
@10
F1:0.206035820722493
P:0.259
R:0.17443884338643514
@15
F1:0.21259616283383803
P:0.21533333333333335
R:0.21514075997354773
#########################





BASE + $S^{lh}$ FILTERING

In [34]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                attentions += crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S_FILT", dataset_name)

100%|██████████| 100/100 [00:39<00:00,  2.52it/s]

########################
Metrics
@5
F1:0.15902666582300806
P:0.31600000000000006
R:0.10749493522997637
@10
F1:0.19803418204244352
P:0.24900000000000003
R:0.16761874784729935
@15
F1:0.21003373021537095
P:0.2126666666666667
R:0.21280103433970443
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$

In [35]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR", dataset_name)

100%|██████████| 100/100 [00:45<00:00,  2.21it/s]

########################
Metrics
@5
F1:0.16634798367821343
P:0.332
R:0.11222278499459083
@10
F1:0.20407429046718678
P:0.257
R:0.1726155325235439
@15
F1:0.21361158046817866
P:0.21666666666666667
R:0.2159473670083762
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$ + $S^{lh}$ FILTERING

In [37]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR_S_FILT", dataset_name)

100%|██████████| 100/100 [00:50<00:00,  1.97it/s]

########################
Metrics
@5
F1:0.1582586832335152
P:0.31400000000000006
R:0.10702061216741802
@10
F1:0.20404853688011532
P:0.256
R:0.17290295309218237
@15
F1:0.21523352216200578
P:0.218
R:0.2179959278964401
#########################





FINAL ATTENTIONSEEKER

In [39]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + ". " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        # ATTENTION SEEKER (LHC-SEEKER)
        f_att_scores = torch.zeros_like(att_scores)
        f_att_scores[mask] = att_scores[mask]
        # NEW ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,f_att_scores)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "ATTENTION_SEEKER", dataset_name)

100%|██████████| 100/100 [01:14<00:00,  1.34it/s]

########################
Metrics
@5
F1:0.16740832744073106
P:0.332
R:0.1131913534058837
@10
F1:0.2025983409213736
P:0.255
R:0.17138670111243687
@15
F1:0.21728691363138247
P:0.22
R:0.22018888532215616
#########################





<h2>DATASET: KRAPIVIN</h2>

In [41]:
dataset_name = "krapivin_test"
dataset = read_jsonl("KEYWORD_DATA/{}.json".format(dataset_name))

SAMRANK BASE

In [42]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)

        ###############################################################
        # ATTENTION MEASSUREMENT
        attentions = sum(attentions)/len(attentions)
        attentions = attentions.squeeze(0)
        att_scores = attentions.mean(0).sum(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)
            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "SAMRANK_BASE", dataset_name)

100%|██████████| 460/460 [02:02<00:00,  3.75it/s]

########################
Metrics
@5
F1:0.16375814823871937
P:0.1782608695652174
R:0.1752103975895031
@10
F1:0.16440910422311278
P:0.13726363008971706
R:0.24998524105865272
@15
F1:0.1478211358912224
P:0.1101181306616089
R:0.2826974604044197
#########################





BASE + RELEVANCE $S^{lh}$

In [43]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S", dataset_name)

  0%|          | 0/460 [00:00<?, ?it/s]

100%|██████████| 460/460 [03:11<00:00,  2.41it/s]


########################
Metrics
@5
F1:0.1669871489645905
P:0.18130434782608698
R:0.17880221426955512
@10
F1:0.1655392655643261
P:0.13878536922015186
R:0.2506132239625139
@15
F1:0.15072540189562078
P:0.11229204370508718
R:0.2880556867789335
#########################


BASE + RELEVANCE $R^{lh}$

In [44]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_R", dataset_name)

  0%|          | 0/460 [00:00<?, ?it/s]

100%|██████████| 460/460 [03:10<00:00,  2.42it/s]


########################
Metrics
@5
F1:0.169861965918624
P:0.18478260869565216
R:0.18121893011977477
@10
F1:0.16475416194317052
P:0.1376984126984127
R:0.25033163982896456
@15
F1:0.14910143006310697
P:0.11098769587900022
R:0.2851054126397376
#########################


BASE + $S^{lh}$ FILTERING

In [45]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens,len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                attentions += crrn_att_map
        att_scores = attentions.mean(0)
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_S_FILT", dataset_name)

100%|██████████| 460/460 [02:55<00:00,  2.62it/s]

########################
Metrics
@5
F1:0.16587551892437125
P:0.18130434782608698
R:0.1769954060592942
@10
F1:0.1667251893112521
P:0.13965493443754312
R:0.2525436628047175
@15
F1:0.15244591514409064
P:0.11374131906740603
R:0.2904122775565769
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$

In [46]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR", dataset_name)

  0%|          | 0/460 [00:00<?, ?it/s]

100%|██████████| 460/460 [03:24<00:00,  2.25it/s]

########################
Metrics
@5
F1:0.16919854130335382
P:0.1839130434782609
R:0.18099855657459304
@10
F1:0.1657912395423231
P:0.13900276052449967
R:0.2512477524774772
@15
F1:0.15249178937857105
P:0.11345146399494226
R:0.2917047551019149
#########################





BASE + RELEVANCES $S^{lh}*R^{lh}$ + $S^{lh}$ FILTERING

In [47]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        #ATTENTION-SEEKER
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "BASE_SR_S_FILT", dataset_name)

  0%|          | 0/460 [00:00<?, ?it/s]

100%|██████████| 460/460 [03:48<00:00,  2.01it/s]

########################
Metrics
@5
F1:0.16635494097052358
P:0.18173913043478263
R:0.1776042616710102
@10
F1:0.16724658727444183
P:0.14052449965493444
R:0.25259002512981854
@15
F1:0.15453340498692653
P:0.11548044950218862
R:0.2936556284086234
#########################





FINAL ATTENTIONSEEKER

In [48]:
model.to(device)
model.eval()

dataset_att_scores_overall = []

for data in tqdm(dataset):
    with torch.no_grad():
        input_text = data["title"] + " " + data["abstract"] 

        tokenized_content = tokenizer(input_text, return_tensors='pt')
        outputs = model(**tokenized_content.to(device))
        content_tokens = tokenizer.convert_ids_to_tokens(tokenized_content['input_ids'].squeeze(0))

        all_attentions = outputs.attentions
        del outputs

        candidates = get_candidates(pos_tagger, input_text)
        candidates = [phrase for phrase in candidates if phrase.split(' ')[0] not in stopwords]
        candidates_indices = {}
        for phrase in candidates:
            matched_indices = get_phrase_indices(content_tokens, phrase, prefix)
            if len(matched_indices) == 0:
                continue
            candidates_indices[phrase] = matched_indices
        candidates_indices = remove_repeated_sub_word(candidates_indices)
        all_indices = get_all_indices(candidates_indices,0)

        ###############################################################
        len_t_tokens = all_attentions[0].squeeze(0)[0].shape[0]
        all_indices_tensor = torch.arange(len_t_tokens)
        mask = torch.isin(all_indices_tensor, torch.tensor(all_indices)).to(device)
        mask_1 = mask*1.0
        # ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,mask_1)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        # ATTENTION SEEKER (LHC-SEEKER)
        f_att_scores = torch.zeros_like(att_scores)
        f_att_scores[mask] = att_scores[mask]
        # NEW ATTENTION MEASSUREMENT
        attentions = torch.zeros(len_t_tokens).to(device)
        for layer in range(len(all_attentions)):
            for head in range(32):
                crrn_att_map = all_attentions[layer].squeeze(0)[head].clone()
                crrn_att_map[~mask] = 0
                lh_weight = torch.matmul(crrn_att_map,f_att_scores)
                attentions += lh_weight.mean(0)*torch.matmul(lh_weight,crrn_att_map)
        att_scores = attentions
        att_scores[0] = 0
        ###############################################################

        phrase_score_dict = {}
        for phrase in candidates_indices.keys():
            try:
                phrase_indices = candidates_indices[phrase]
                if len(phrase_indices) == 0:
                    continue
            except KeyError:
                continue

            final_phrase_score = aggregate_phrase_scores(phrase_indices, att_scores)

            if len(phrase.split()) == 1:
                final_phrase_score = final_phrase_score / len(phrase_indices)

            phrase_score_dict[phrase] = final_phrase_score

    sorted_scores_att_o_s = sorted(phrase_score_dict.items(), key=lambda item: item[1], reverse=True)
    stemmed_sorted_scores_att_o_s = [(" ".join(stemmer.stem(word) for word in phrase.split()), score) for
                                phrase, score in sorted_scores_att_o_s]

    set_stemmed_scores_list_att_o_s = []
    for phrase, score in stemmed_sorted_scores_att_o_s:
        if phrase not in set_stemmed_scores_list_att_o_s:
            set_stemmed_scores_list_att_o_s.append(phrase)

    pred_stemmed_phrases_att_o_s = set_stemmed_scores_list_att_o_s[:15]
    dataset_att_scores_overall.append(pred_stemmed_phrases_att_o_s)

att_o_s_top3_f1_5, att_o_s_top3_f1_10, att_o_s_top3_f1_15 = evaluate_dataset(dataset_att_scores_overall, dataset, "ATTENTION_SEEKER", dataset_name)

  0%|          | 0/460 [00:00<?, ?it/s]

100%|██████████| 460/460 [05:31<00:00,  1.39it/s]

########################
Metrics
@5
F1:0.1714048900614592
P:0.18739130434782614
R:0.1826724472040127
@10
F1:0.16890060982697894
P:0.1416114561766736
R:0.2555353248085278
@15
F1:0.15270308531482443
P:0.11374131906740603
R:0.29166143276639306
#########################



