In [2]:
from config.config_utils import set_seed
from config.parse_config import ConfigParser
from experiment.runner.nc_run import run, init_data_loader


In [4]:
from config.configuration import Configuration
from config.default_config import arch_default_config

# Build the experiment configuration from the architecture's default config;
# the dict passed to ConfigParser is presumably a set of overrides for the
# architecture / optimizer / trainer sections -- TODO confirm in ConfigParser.
arch_type = "BiAttentionClassifyModel"
kwargs = {"arch_config": arch_default_config(arch_type)}
config_parser = ConfigParser(Configuration(**kwargs),
                             {"arch_config": {"head_num": 100}, "optimizer_config": {"lr": 1e-4},
                              "trainer_config": {"epochs": 10}})
config = config_parser.config
config.update_sub_config("data_config", name="MIND15/keep_all", max_length=512)
config.set("seed", 42)
# Seed as soon as the seed is known, i.e. before the data loader is built, so
# that any random draws made while constructing it are reproducible too
# (assumes init_data_loader may consume random numbers -- TODO confirm).
set_seed(config.seed)
data_loader = init_data_loader(config_parser)
# Summary of the run: architecture type, seed, optional variant name and the
# number of entries in the data loader's word dictionary.
log = {"arch_type": config.arch_config["type"], "seed": config.seed,
       "variant_name": config.arch_config.get("variant_name", None), "#Voc": len(data_loader.word_dict)}
# Re-seed right before training so the run starts from a freshly seeded RNG
# state regardless of how many random numbers data loading consumed.
set_seed(log["seed"])
trainer = run(config_parser, data_loader)

BiAttentionClassifyModel(
  (embedding): Embedding(251810, 300)
  (classifier): Linear(in_features=300, out_features=15, bias=True)
  (final): Linear(in_features=300, out_features=300, bias=True)
  (topic_layer): Sequential(
    (0): Linear(in_features=300, out_features=2000, bias=True)
    (1): Tanh()
    (2): Linear(in_features=2000, out_features=100, bias=True)
  )
  (projection): AttLayer(
    (attention): Sequential(
      (0): Linear(in_features=300, out_features=128, bias=True)
      (1): Tanh()
      (2): Linear(in_features=128, out_features=1, bias=True)
      (3): Flatten(start_dim=1, end_dim=-1)
      (4): Softmax(dim=-1)
    )
  )
)
Trainable params: 76,478,572
Freeze params: 0
load device cuda:0


Train Epoch: 1 Loss: 1.0754542350769043: 100%|██████████| 3208/3208 [03:29<00:00, 15.30it/s] 
100%|██████████| 401/401 [00:05<00:00, 66.94it/s]


    epoch          : 1
    loss           : 1.073741
    accuracy       : 0.679257
    macro_f        : 0.396104
    doc_entropy    : 3.618213
    val_loss       : 0.800217
    val_accuracy   : 0.738659
    val_macro_f    : 0.51498
    val_doc_entropy: 4.015196
    monitor_best   : 0.738659
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel


Train Epoch: 2 Loss: 0.8115770220756531: 100%|██████████| 3208/3208 [03:34<00:00, 14.96it/s] 
100%|██████████| 401/401 [00:06<00:00, 64.96it/s]


    epoch          : 2
    loss           : 0.722682
    accuracy       : 0.762809
    macro_f        : 0.560846
    doc_entropy    : 4.134907
    val_loss       : 0.699033
    val_accuracy   : 0.766485
    val_macro_f    : 0.577824
    val_doc_entropy: 4.268788
    monitor_best   : 0.766485
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel


Train Epoch: 3 Loss: 0.7625419497489929: 100%|██████████| 3208/3208 [03:35<00:00, 14.85it/s] 
100%|██████████| 401/401 [00:07<00:00, 55.44it/s]


    epoch          : 3
    loss           : 0.63884
    accuracy       : 0.786016
    macro_f        : 0.597991
    doc_entropy    : 4.177661
    val_loss       : 0.656384
    val_accuracy   : 0.779501
    val_macro_f    : 0.601163
    val_doc_entropy: 4.085182
    monitor_best   : 0.779501
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel


Train Epoch: 4 Loss: 0.5409938097000122: 100%|██████████| 3208/3208 [03:36<00:00, 14.80it/s] 
100%|██████████| 401/401 [00:06<00:00, 64.59it/s]


    epoch          : 4
    loss           : 0.581721
    accuracy       : 0.804205
    macro_f        : 0.63022
    doc_entropy    : 4.076777
    val_loss       : 0.629332
    val_accuracy   : 0.787841
    val_macro_f    : 0.611674
    val_doc_entropy: 4.066385
    monitor_best   : 0.787841
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel


Train Epoch: 5 Loss: 0.5164729356765747: 100%|██████████| 3208/3208 [03:36<00:00, 14.84it/s] 
100%|██████████| 401/401 [00:06<00:00, 65.86it/s]


    epoch          : 5
    loss           : 0.530042
    accuracy       : 0.821147
    macro_f        : 0.658462
    doc_entropy    : 3.95445
    val_loss       : 0.61431
    val_accuracy   : 0.794232
    val_macro_f    : 0.620323
    val_doc_entropy: 3.952786
    monitor_best   : 0.794232
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel


Train Epoch: 6 Loss: 0.3570220172405243: 100%|██████████| 3208/3208 [03:36<00:00, 14.85it/s] 
100%|██████████| 401/401 [00:06<00:00, 65.38it/s]


    epoch          : 6
    loss           : 0.484072
    accuracy       : 0.836345
    macro_f        : 0.683934
    doc_entropy    : 3.819045
    val_loss       : 0.620367
    val_accuracy   : 0.794934
    val_macro_f    : 0.62788
    val_doc_entropy: 3.789298
    monitor_best   : 0.794934
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel


Train Epoch: 7 Loss: 0.15787436068058014: 100%|██████████| 3208/3208 [03:35<00:00, 14.86it/s]
100%|██████████| 401/401 [00:06<00:00, 65.45it/s]

    epoch          : 7
    loss           : 0.442091
    accuracy       : 0.850375
    macro_f        : 0.707219
    doc_entropy    : 3.686661
    val_loss       : 0.62613
    val_accuracy   : 0.793998
    val_macro_f    : 0.626108
    val_doc_entropy: 3.695529
    monitor_best   : 0.794934
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel



Train Epoch: 8 Loss: 0.5448232889175415: 100%|██████████| 3208/3208 [03:36<00:00, 14.85it/s] 
100%|██████████| 401/401 [00:06<00:00, 64.84it/s]

    epoch          : 8
    loss           : 0.399342
    accuracy       : 0.864589
    macro_f        : 0.732997
    doc_entropy    : 3.559433
    val_loss       : 0.636995
    val_accuracy   : 0.794232
    val_macro_f    : 0.625663
    val_doc_entropy: 3.526581
    monitor_best   : 0.794934
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel



Train Epoch: 9 Loss: 0.4767472743988037: 100%|██████████| 3208/3208 [03:35<00:00, 14.89it/s] 
100%|██████████| 401/401 [00:06<00:00, 63.98it/s]

    epoch          : 9
    loss           : 0.362917
    accuracy       : 0.878433
    macro_f        : 0.754278
    doc_entropy    : 3.466064
    val_loss       : 0.65797
    val_accuracy   : 0.792907
    val_macro_f    : 0.622639
    val_doc_entropy: 3.46579
    monitor_best   : 0.794934
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel



Train Epoch: 10 Loss: 0.5627778768539429: 100%|██████████| 3208/3208 [03:37<00:00, 14.75it/s] 
100%|██████████| 401/401 [00:06<00:00, 64.50it/s]

    epoch          : 10
    loss           : 0.32689
    accuracy       : 0.890484
    macro_f        : 0.776173
    doc_entropy    : 3.371288
    val_loss       : 0.679327
    val_accuracy   : 0.792673
    val_macro_f    : 0.632346
    val_doc_entropy: 3.381072
    monitor_best   : 0.794934
    seed           : 42
                   : None
    run_name       : News26/keep_all/BiAttentionClassifyModel
Validation performance did not improve for 3 epochs. Training stops.





In [5]:
from utils import get_topic_dist

# Hand every value stored in word_dict to get_topic_dist together with the
# trainer, then swap the first two axes of the result
# (presumably so that each row belongs to one word -- TODO confirm).
topic_dist = get_topic_dist(
    trainer,
    [*data_loader.word_dict.values()],
).transpose(1, 0)

In [6]:
# Reverse lookup for word_dict: each value of word_dict becomes a key that
# maps back to the original key (a later duplicate value overrides an earlier one).
id2word = dict(zip(data_loader.word_dict.values(), data_loader.word_dict.keys()))

In [5]:
from utils import load_docs, filter_tokens
# The data config name has the form "<dataset>/<method>"; it must contain
# exactly one "/" or the unpacking below raises ValueError.
dataset_name, method = config.data_config["name"].split("/")
# Load the reference documents and filter their tokens
# (meaning of the 30 / 0.5 arguments is defined by filter_tokens -- TODO document).
ref_texts = load_docs(dataset_name, method)
topic_dict = filter_tokens(ref_texts, 30, 0.5)
# Keep only the tokens that also exist in the data loader's word dictionary
# and translate them to their word_dict values.
# NOTE(review): the saved output of this cell ends in KeyError: '@'. The
# membership test below guards this lookup, so the error presumably comes from
# load_docs / filter_tokens -- confirm with the full traceback.
word_index = [
    data_loader.word_dict[token]
    for token in topic_dict.values()
    if token in data_loader.word_dict
]


adding document #0 to Dictionary(0 unique tokens: [])
adding document #10000 to Dictionary(76864 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #20000 to Dictionary(103270 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #30000 to Dictionary(122908 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #40000 to Dictionary(139266 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #50000 to Dictionary(153076 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #60000 to Dictionary(165969 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #70000 to Dictionary(177578 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #80000 to Dictionary(188413 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #90000 to Dictionary(198849 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #100000 to Dictionary(208342 unique tokens: ["'", '(', ')', ',', '-']...)
adding document #110000 to Dictionary(217766 unique toke

KeyError: '@'

In [16]:
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Collect index pairs (i, k), k >= i, whose rows in topic_dist have a cosine
# similarity strictly above `threshold`.
word_graph = []
threshold = 0.8
# The scan starts at row 1, so row 0 never takes part
# (presumably a padding entry -- TODO confirm).
for i in range(1, topic_dist.shape[0]):
    # Similarities of row i against rows i..end. Offset 0 is row i itself, so
    # the pair (i, i) is recorded whenever its self-similarity passes the threshold.
    cos_simi = cosine_similarity(topic_dist[i].reshape(1, -1), topic_dist[i:])[0]
    # Threshold the whole similarity vector at once instead of testing each
    # element in a Python-level loop; offsets come back in ascending order.
    hits = np.flatnonzero(cos_simi > threshold)
    word_graph.extend((i, i + int(offset)) for offset in hits)



KeyboardInterrupt



In [12]:
from pathlib import Path
from utils import get_project_root, read_json
from utils.graph_untils import load_entities

# Load the project's entities via load_entities() and read the default
# configuration file <project root>/config/mind_rs_default.json.
entity2id = load_entities()
default_config = read_json(
    Path(get_project_root()).joinpath("config", "mind_rs_default.json")
)
from collections import defaultdict
import json


def load_news_entity_from_file(news_file):
    """Collect the entities referenced in a tab-separated news file.

    Every non-blank line must contain exactly eight tab-separated columns:
    news id, category, subcategory, title, abstract, url, title entities and
    abstract entities. The two entity columns hold JSON arrays of objects;
    an empty entity column is treated as an empty array. Blank lines are
    skipped.

    Args:
        news_file: Path of the news file; it is read as UTF-8 text.

    Returns:
        A ``defaultdict`` created without a default factory (so missing keys
        still raise ``KeyError``) that maps each entity's ``WikidataId`` to its
        ``SurfaceForms``. When an id occurs several times, the last
        occurrence wins.

    Raises:
        ValueError: If a non-blank line does not have exactly eight columns,
            or if an entity column is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        KeyError: If an entity object lacks ``WikidataId`` or ``SurfaceForms``.
    """
    expected_columns = 8
    news_entity = defaultdict()
    with open(news_file, "r", encoding="utf-8") as rd:
        for line_no, text in enumerate(rd, start=1):
            row = text.strip("\n")
            if not row:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            columns = row.split("\t")
            if len(columns) != expected_columns:
                raise ValueError(
                    "{}: line {} has {} tab-separated columns, expected {}".format(
                        news_file, line_no, len(columns), expected_columns
                    )
                )
            # Title entities first, then abstract entities, so that on a
            # duplicate id the abstract entry overrides the title entry.
            for raw_entities in columns[6:8]:
                if not raw_entities.strip():
                    # An empty entity column means "no entities"; json.loads
                    # would reject the empty string.
                    continue
                for entity in json.loads(raw_entities):
                    news_entity[entity["WikidataId"]] = entity["SurfaceForms"]
    return news_entity


# Merge the entities of every split whose news file exists into one mapping;
# the data directory comes from the default config's data_config.data_path.
root_path = Path(default_config["data_config"]["data_path"])
news_entities = defaultdict()
phases = ["train", "valid", "test"]
mind_type = "large"
for phase in phases:
    file = root_path.joinpath(mind_type, phase, "news.tsv")
    if not file.exists():
        # This split has no news file on disk; skip it.
        continue
    news_entities.update(load_news_entity_from_file(file))