In [None]:
import pandas as pd
import ast
import itertools
from collections import Counter
import matplotlib.pyplot as plt
from collections import defaultdict
from utils import calculate_bertscore_f1, get_synonyms, preprocess_text
import numpy as np

In [None]:
gt_labels = pd.read_csv("postprocessed_labels.csv")
# The synonyms column is stored as a stringified Python literal; parse it safely (no eval).
gt_labels["most_granular_concept_synonyms"] = gt_labels["most_granular_concept_synonyms"].apply(ast.literal_eval)
# Flatten to one label per non-empty column entry, keeping the first listed synonym.
merged = [
    column_synonyms[0]
    for table_synonyms in gt_labels["most_granular_concept_synonyms"]
    for column_synonyms in table_synonyms
    if column_synonyms
]

In [None]:
# Count column entries whose synonym list is non-empty (should match len(merged)).
total_columns = sum(
    1
    for table in gt_labels["most_granular_concept_synonyms"]
    for column in table
    if len(column) > 0
)
print(total_columns)

In [None]:
def identify_common_by_bertscore(current_word: str, all_words: list[str]) -> str | None:
    bertscores = calculate_bertscore_f1(current_word, all_words)
    max_index = np.argmax(bertscores)
    if bertscores[max_index] > 0.8:
        return all_words[max_index]
    return None

def identify_common_by_nltk(current_word: str, all_words: list[str]) -> str | None:
    """Find the first single-word candidate that shares a synonym with ``current_word``.

    Each word is normalized with ``preprocess_text(word, True)`` and expanded with
    ``get_synonyms`` (presumably returning a set, since ``.intersection`` is used).
    Candidates containing a space are skipped.

    Args:
        current_word: Label to match.
        all_words: Candidate root words, checked in order.

    Returns:
        The first candidate whose synonym overlap with ``current_word`` contains at
        least one truthy element, or ``None`` if no candidate qualifies.
    """
    print(f"attempting to match {current_word} with {all_words}")
    target_synonyms = get_synonyms(preprocess_text(current_word, True))
    # Lazily scan candidates so synonyms are only computed until the first hit.
    return next(
        (
            candidate
            for candidate in all_words
            if len(candidate.split(" ")) == 1
            and any(target_synonyms.intersection(get_synonyms(preprocess_text(candidate, True))))
        ),
        None,
    )
    
def group_common_words(labels: list[str]) -> dict[str, list[str]]:
    """Greedily group labels under "root" labels by synonymy or semantic similarity.

    Labels are processed in order; the first label becomes a root. Each later
    label is attached to an existing root chosen by, in order:

    1. exact match: a label identical to an existing root joins that root directly,
       so repeated labels are never split across groups and skip model calls;
    2. for single-word labels, ``identify_common_by_nltk`` (shared synonyms);
    3. ``identify_common_by_bertscore`` (BERTScore F1 above its threshold).

    If none of these match, the label starts a new group as its own root.

    Args:
        labels: Labels to group. Processing order determines which labels become roots.

    Returns:
        A ``defaultdict(list)`` mapping each root label to all labels assigned to it,
        with the root itself as the first member.
    """
    tracker = defaultdict(list)
    for label in labels:
        if not tracker or label in tracker:
            # First label, or an exact duplicate of an existing root: no similarity search needed.
            # (The empty-tracker case also avoids calling the matchers with no candidates.)
            tracker[label].append(label)
            continue
        current_root_concepts = list(tracker.keys())
        if len(label.split(" ")) == 1 and (
            most_similar_word := identify_common_by_nltk(label, current_root_concepts)
        ):
            print(f"identify match using nltk {label} with {most_similar_word}")
            tracker[most_similar_word].append(label)
        elif most_similar_word := identify_common_by_bertscore(label, current_root_concepts):
            print(f"identify match using bertscore {label} with {most_similar_word}")
            tracker[most_similar_word].append(label)
        else:
            # No sufficiently similar root exists: start a new group rooted at this label.
            tracker[label].append(label)
    return tracker

In [None]:
# Cluster the flattened labels into root concepts (slow: may call BERTScore per label).
tracker = group_common_words(labels=merged)

In [None]:
TOP_N = 20  # number of most-populated concept groups to print and plot

# Order root concepts by group size, largest first; ties keep first-seen order
# (Counter.most_common is a stable sort, matching the previous sorted(key=-count)).
sorted_concept_counts = Counter({root: len(members) for root, members in tracker.items()}).most_common()
# Guard the unpack: zip(*[]) yields nothing, which would fail on an empty tracker.
keys, values = zip(*sorted_concept_counts) if sorted_concept_counts else ((), ())

# Never plot more bars than there are groups (bar positions and heights must match in length).
top_n = min(TOP_N, len(keys))
print(keys[:top_n], values[:top_n])
ticks = range(top_n)
fig, ax = plt.subplots()
ax.bar(ticks, values[:top_n], align="center")
ax.set_xticks(ticks)
ax.set_xticklabels(keys[:top_n], rotation=90)
ax.set(title=f"Top {top_n} concept groups by size", xlabel="Root concept", ylabel="Number of labels")
plt.show()

In [None]:
# Number of distinct root concepts found by the grouping.
len(sorted_concept_counts)

In [None]:
# Inspect the members of the 20 largest groups instead of dumping every root in the tracker.
dict(sorted(tracker.items(), key=lambda item: len(item[1]), reverse=True)[:20])