In [None]:
import os
import io
import json
import time
import math
import string 
import pickle
import datetime
import itertools
import numpy as np
import pandas as pd
from pprint import pprint 
from tqdm.notebook import tqdm
from collections import Counter

import matplotlib.pyplot as plt

from scipy.sparse import csr_matrix as sparse_matrix

from sklearn.svm import SVR
from sklearn.neural_network import MLPRegressor
from sklearn.linear_model import SGDClassifier, SGDRegressor, LogisticRegression, LinearRegression
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer, TfidfTransformer
from sklearn.preprocessing import LabelBinarizer, MultiLabelBinarizer

from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer 
from nltk.tokenize import RegexpTokenizer, word_tokenize  

from gensim.test import utils
from gensim.models import KeyedVectors, nmf
from gensim.corpora.dictionary import Dictionary
from gensim.scripts.glove2word2vec import glove2word2vec
from gensim.parsing.preprocessing import preprocess_documents
from gensim.models.doc2vec import Doc2Vec, TaggedDocument

from sentence_transformers import models, SentenceTransformer

In [None]:
import os

# Order GPUs by PCI bus id and expose only device "3" to CUDA.  (see issue #152)
# NOTE(review): these only take effect if set before CUDA is initialised in this
# process; the top-of-notebook imports already load sentence_transformers — confirm.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "3",
})

In [None]:
# Load talk metadata, user data and the train/test interaction splits.
df_talks = pd.read_csv('../Data/talks_data.csv')
# NOTE(review): every other input lives under ../Data/ — confirm this path is intended.
df_users = pd.read_csv('../users_data.csv')

df_train = pd.read_csv('../Data/TED_train.csv')
df_test  = pd.read_csv('../Data/TED_test.csv')


def _load_inverted_index(path):
    """Load a pickled dict from ``path`` and return it inverted (values become keys).

    The file handle is closed deterministically via ``with``. If values are not
    unique, the last key seen for a value wins.
    Only unpickle trusted files: ``pickle.load`` can execute arbitrary code.
    """
    with open(path, 'rb') as fh:
        mapping = pickle.load(fh)
    return {value: key for key, value in mapping.items()}


# After inversion the keys are the raw ids used in the CSVs (looked up below).
talks_ids = _load_inverted_index('../Data/dict_talks_idx.pickle')
users_ids = _load_inverted_index('../Data/dict_users_idx.pickle')

# Map raw ids to indices; dict lookup raises KeyError on an unknown id (fail loudly).
for df_split in (df_train, df_test):
    df_split['user'] = df_split['user_id'].apply(lambda u: users_ids[u])
    df_split['talk'] = df_split['talk_id'].apply(lambda t: talks_ids[t])

df_train.head()

In [None]:
# Preview the first rows of the talk metadata.
df_talks.head(n=5)

In [None]:
import spacy

In [None]:
# ! pip install --upgrade spacy

In [None]:
# ! python -m spacy download en_core_web_trf

In [None]:
# spaCy English pipeline used for named-entity recognition below.
nlp = spacy.load(name="en_core_web_trf")

In [None]:
# Exploration only: spaCy ships some English example sentences.
from spacy.lang.en.examples import sentences

n_example_sentences = len(sentences)
n_example_sentences

In [None]:
# Display spaCy's bundled English example sentences (exploration only).
sentences

In [None]:
def _talk_text(talk) -> str:
    """Join a talk's title, description and transcript with single spaces.

    Missing (NaN/None) fields are treated as empty strings, so a talk lacking
    e.g. a transcript is still processed instead of raising ``TypeError``.
    """
    fields = (talk['title'], talk['description'], talk['transcript'])
    return ' '.join('' if pd.isna(value) else str(value) for value in fields)


# Run the spaCy pipeline over every talk (expensive; results are pickled in a later cell).
# entities: talk `id` -> list of (entity text, entity label) in document order.
entities = {}
for _, talk in tqdm(df_talks.iterrows(), total=len(df_talks)):
    doc = nlp(_talk_text(talk))
    entities[talk['id']] = [(ent.text, ent.label_) for ent in doc.ents]

In [None]:
# Spot-check the raw NER output for one talk.
sample_talk_id = '062dd0f773cd5999a09714a371e1f8017163e2a1'
entities[sample_talk_id]

In [None]:
# Persist the NER output so the expensive extraction need not be re-run.
# `with` guarantees the file is flushed and closed, even if pickling fails.
with open('entities.pickle', 'wb') as f:
    pickle.dump(entities, f)

In [None]:
# spaCy labels kept as knowledge-graph entities.
# NOTE(review): broader than classic ENAMEX (also FAC, PRODUCT, WORK_OF_ART) — intended?
ENAMEX_TYPES = {'PERSON', 'LOC', 'ORG', 'GPE', 'FAC', 'PRODUCT', 'WORK_OF_ART'}

enamex_entities = {}   # talk id -> set of (entity text, label) pairs with a kept label
all_types = set()      # every label encountered, before filtering
all_entities = []      # one item per kept mention (repeats included) -> feeds entity_counter
# NOTE(review): counts are per mention, not per talk, so an entity mentioned twice in
# a single talk passes a "> 1" threshold — confirm that is the intended frequency.
for talk, mentions in entities.items():
    enamex = set()
    for entity, etype in mentions:
        all_types.add(etype)
        if etype in ENAMEX_TYPES:
            all_entities.append(entity)
            enamex.add((entity, etype))
    # Assigned once per talk (previously inside the inner loop), so talks with no
    # entities still get an (empty) entry instead of being silently missing.
    enamex_entities[talk] = enamex

In [None]:
# Number of kept mentions per entity surface string.
entity_counter = Counter()
entity_counter.update(all_entities)

In [None]:
# most_common() sorts by descending count, so its tail holds the 20 rarest entities.
rarest_entities = entity_counter.most_common()[-20:]
rarest_entities

In [None]:
# Every entity label seen in the NER output, before filtering to the kept types.
all_types

In [None]:
# Spot-check the filtered (kept-label) entities for one talk.
sample_talk_id = '062dd0f773cd5999a09714a371e1f8017163e2a1'
enamex_entities[sample_talk_id]

In [None]:
# Write tab-separated (talk, 'mentions', entity) triples for entities with more
# than one mention overall.
# NOTE(review): despite the "-all" file name, singleton entities are filtered out.
# NOTE(review): entity text containing a tab/newline would corrupt the TSV — verify.
all_min_mentions = 2
all_out_path = '../kg_embeddings/ted-ner-all.txt'
os.makedirs(os.path.dirname(all_out_path), exist_ok=True)

dropped_all = set()  # entities below the threshold, kept for inspection
with open(all_out_path, 'w') as f:
    for talk, mentions in enamex_entities.items():
        # The label is not written, so one surface string tagged with several labels
        # (e.g. ORG and PRODUCT) would produce duplicate triples: dedupe on text.
        # Sorting makes the line order reproducible (set order depends on hash seed).
        for entity in sorted({text for text, _label in mentions}):
            if entity_counter[entity] >= all_min_mentions:
                f.write(f'{talk}\tmentions\t{entity}\n')
            else:
                dropped_all.add(entity)

print(f'{len(dropped_all)} distinct entities below the threshold were not written')

In [None]:
# Write tab-separated (talk, 'mentions', entity) triples for entities with at least
# 10 mentions overall.
# NOTE(review): entity text containing a tab/newline would corrupt the TSV — verify.
min10_min_mentions = 10
min10_out_path = '../kg_embeddings/metadata-interactions-ner/ted-ner-min10.txt'
os.makedirs(os.path.dirname(min10_out_path), exist_ok=True)

dropped_min10 = set()  # entities below the threshold, kept for inspection
with open(min10_out_path, 'w') as f:
    for talk, mentions in enamex_entities.items():
        # The label is not written, so one surface string tagged with several labels
        # would produce duplicate triples: dedupe on text. Sorting makes the line
        # order reproducible (set order depends on the string hash seed).
        for entity in sorted({text for text, _label in mentions}):
            if entity_counter[entity] >= min10_min_mentions:
                f.write(f'{talk}\tmentions\t{entity}\n')
            else:
                dropped_min10.add(entity)

print(f'{len(dropped_min10)} distinct entities below the threshold were not written')