In [44]:
domain = 'laptop'

In [45]:
import os
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed
from glob import glob
from operator import itemgetter
from typing import List, Tuple

import fasttext
import nltk
import numpy as np
import torch
from nltk.corpus import stopwords
from sklearn.metrics.pairwise import cosine_similarity
from tag_utils import Annotate
from tqdm import tqdm

nltk.download('stopwords')
model = fasttext.load_model('../cc.en.300.bin')
def process_ann(idx: int, sentence: str):
    result = {}
    for score, mention, entity_title, entity_id, uri in Annotate(sentence, theta=0.05).values():
        if entity_title in result:
            if len(result[entity_title]) < len(mention):
                result[entity_title] = mention
        else:
            result[entity_title] = mention
    return idx, result



entities_tuple_list = []
lines = []
# get all sentences in a domain
for file in glob(os.path.join("../data", f"{domain}.*.txt")):
    lines.extend([line.split("***") for line in open(file).read().splitlines()])
with ThreadPoolExecutor(max_workers=256) as t:
    for future in tqdm(as_completed(
        [t.submit(process_ann, idx, line[0]) for idx, line in enumerate(lines)]),
                       total=len(lines),
                       desc=f"Extracting eitities in {domain} domain"):
        entities_tuple_list.append(future.result())
# entities_tuple_list = sorted(entities_tuple_list, key=lambda item: item[0])
entities_dict_list = [entities_dict for _, entities_dict in entities_tuple_list]
# remove stopwords
sets = stopwords.words('english')
entities = [word for entities_dict in entities_dict_list for word in entities_dict.values() if word not in sets]
counters = Counter(entities)
sorted_entities: List[Tuple[str, int]] = sorted(filter(lambda item: item[1] > 0, counters.items()),
                                                key=lambda item: item[1],
                                                reverse=True)
vec_dict = {}
for entity in tqdm(sorted_entities, total=len(counters)):
    e = entity[0]
    vec_dict[e] = model.get_word_vector(e)
getter = itemgetter(*[entity for entity, _ in sorted_entities[:10]])
mean_vec = np.average(getter(vec_dict),
                      axis=0,
                      weights=[count for _, count in sorted_entities[:10]])
res = np.array([
        cosine_similarity(mean_vec.reshape(1, -1), vec_dict[k].reshape(1, -1)) for k in vec_dict
    ]).squeeze()

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
Extracting eitities in laptop domain: 100%|██████████| 3845/3845 [00:42<00:00, 91.38it/s] 
100%|██████████| 3685/3685 [00:00<00:00, 38780.70it/s]


In [56]:
aspects = []
for line in lines:
    text, labels = line
    tokens = text.split()
    labels = labels.split()
    aspect = []
    for idx, label in enumerate(labels):
        if label != 'O':
            aspect.append(tokens[idx])
            continue
        if aspect:
            aspects.append(' '.join(aspect))
            aspect.clear()
len(set(aspects))

1275

In [57]:
res.shape[0]

3685

In [55]:
topk = torch.topk(torch.from_numpy(res), int(0.6 * res.shape[0])).indices.tolist()
k = len([sim for sim in itemgetter(*topk)(res) if sim > 0.1])
topk = torch.topk(torch.from_numpy(res), k).indices.tolist()
len(set(aspects).intersection(set([item[0] for item in itemgetter(*topk)(sorted_entities)]))) / len(
    set(aspects))

0.37176470588235294

(0.13363008265910653,
 0.1336102307362878,
 0.13357236294255548,
 0.13354637634459182,
 0.13354420182916996,
 0.13353221020005662,
 0.1335151878377501,
 0.13340458987022863,
 0.13337646690644944,
 0.13328786582825397,
 0.13316617281089727,
 0.1330690857457708,
 0.13300306251251612,
 0.13296010141207815,
 0.13295992325112432,
 0.13291151311877294,
 0.13288638409244163,
 0.13288321483997764,
 0.13271120979978165,
 0.1326476574912433,
 0.1325895697942457,
 0.13253645800368508,
 0.13249097367669269,
 0.1324701927441511,
 0.13243493946764562,
 0.1323926639258604,
 0.13235951970584223,
 0.13212959008538655,
 0.132072589003778,
 0.13206533705279075,
 0.132023038897007,
 0.1319693406335541,
 0.13192982158114153,
 0.13188123941858815,
 0.13187393095376168,
 0.13186263672039683,
 0.13182613452411665,
 0.13181609581451298,
 0.13170296035181173,
 0.1316353726835832,
 0.13160814433666035,
 0.13160472829731387,
 0.13152143551994322,
 0.1314968746247112,
 0.1313857076582732,
 0.13127572653388025,
 0.

In [53]:
len(set(aspects).intersection(set(entities)))/len(set(aspects))

0.49411764705882355

In [19]:
data = open(f"../data/{domain}.test.txt", "r").read().splitlines()
with open("/root/autodl-tmp/out/bert_base/laptop-rest/predict.txt", "r") as f:
    pre_aspects = []
    for i, line in enumerate(f):
        _, pre, labels = line.split("***")
        tokens = data[i].split("***")[0]
        pres = pre.split()
        labels = labels.strip().split()
        pre_aspect = []
        for idx, label in enumerate(pres):
            if label != 'O':
                pre_aspect.append(tokens[idx])
                continue
            if pre_aspect:
                pre_aspects.append(' '.join(pre_aspect))
                pre_aspect.clear()
len(set(pre_aspects))

282

In [21]:
len(set(pre_aspects).intersection(set(entities)))/len(set(pre_aspects))

0.0