In [1]:
import pandas as pd
import numpy as np
import pickle
from collections import defaultdict, deque
from itertools import combinations
from tqdm.notebook import tqdm

In [2]:
COLUMNS_TO_DROP = ['n_citation', 'keywords', 'fos', 'title', 'abstract', 'venue']
NUM_PARTS = 3
RANDOM_STATE = 42
TRAIN_SIZE = 0.8

def get_data(file_path):
    data = pd.read_json(file_path).drop(COLUMNS_TO_DROP, axis=1)
    return data

def save_pickle(obj, filename):
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)

def load_pickle(filename):
    with open(filename, 'rb') as f:
        return pickle.load(f)

In [3]:
# articles = pd.concat(get_data(f'data/part_{i+1}_clean.json') for i in range(NUM_PARTS))
# articles.reset_index(drop=True, inplace=True)
# articles.to_json('articles_tts.json')
articles = pd.read_json('articles_tts.json')
N = articles.shape[0]
articles.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 1084405 entries, 0 to 1084404
Data columns (total 4 columns):
 #   Column      Non-Null Count    Dtype 
---  ------      --------------    ----- 
 0   _id         1084405 non-null  object
 1   year        1084405 non-null  int64 
 2   references  1084405 non-null  object
 3   authors     1084405 non-null  object
dtypes: int64(1), object(3)
memory usage: 41.4+ MB


In [4]:
articles.sort_values('year', inplace=True)
articles

Unnamed: 0,_id,year,references,authors
58561,53e99ad7b7602d9702357302,0,"[53e9a53ab7602d9702e5fda9, 53e9b350b7602d9703e...","[{'_id': '53f431b5dabfaee4dc75054f', 'name': '..."
1070925,5c2c7a6717c44a4e7cf30f13,0,"[558c6cb4e4b0cfb70a1d9a18, 599c785a601a182cd25...","[{'_id': '53f42bffdabfaeb22f3f31e4', 'name': '..."
1071185,5c38b5463a55acc9e54e1b8a,0,"[558b6436e4b037c0875c9f57, 53e99fe9b7602d97028...","[{'_id': '53f437aedabfaedce553aec1', 'name': '..."
1072180,5c5ce4fd17c44a400fc380e6,0,"[53e99b43b7602d97023e4719, 53e99de2b7602d97026...","[{'_id': '53f42bffdabfaeb22f3f31e4', 'name': '..."
152205,53e99eaeb7602d9702777569,0,"[53e9bdd5b7602d9704a86144, 53e9af19b7602d97039...","[{'_id': '53f3aaebdabfae4b34af69b9', 'name': '..."
...,...,...,...,...
567655,53e9b2f5b7602d9703db387c,2300,"[53e9a37ab7602d9702c8997f, 53e9ad63b7602d97037...","[{'_id': '56cb189bc35f4f3c65654efd', 'name': '..."
152867,53e99eb5b7602d970277e3f2,2300,"[53e99f78b7602d970284e318, 53e99837b7602d97020...","[{'_id': '53f434dcdabfaeb22f463a4d', 'name': '..."
275647,53e9a4b2b7602d9702dd3641,2300,"[53e999c3b7602d97022073a4, 5c78dbc94895d9cbc6f...","[{'_id': '53f444fadabfaedf435c6f95', 'name': '..."
139098,53e99e28b7602d97026ebec5,2300,"[557d34e66feeaa8086da8181, 558ac89fe4b0b32fcb3...","[{'_id': '53f43265dabfaee1c0a74818', 'name': '..."


In [5]:
# article_to_authors_id, author_to_articles_id = defaultdict(set), defaultdict(set)

# for _, article in tqdm(articles.iterrows(), total=articles.shape[0]):
#     for author in article.authors:
#         article_to_authors_id[article._id].add(author['_id'])
#         author_to_articles_id[author['_id']].add(article._id)

# save_pickle(article_to_authors_id, 'article_to_authors_id.pickle')
# save_pickle(author_to_articles_id, 'author_to_articles_id.pickle')

In [6]:
article_to_authors_id = load_pickle('article_to_authors_id.pickle')
author_to_articles_id = load_pickle('author_to_articles_id.pickle')

In [7]:
assert len(article_to_authors_id) == N
len(author_to_articles_id)

1199286

# 1) Search for connected components by authors using BFS

## Starting from atricle, iterate over all authors, add their papers to deque and to current connected component

In [20]:
# this approach doesn't work because there exist only one connected component in space of articles
def bfs1(article_id: str) -> set[str]:
    """
    return: all article_ids from connected componend with given @article_id
    """
    connected_component = set([article_id])
    deq = deque([article_id])
    while deq:
        _article_id = deq.popleft()
        connected_component.add(_article_id)
        for author_id in article_to_authors_id[_article_id]:
            for article in author_to_articles_id[author_id]:
                if article not in deq and article not in connected_component:
                    deq.append(article)

    return connected_component

## This doesn't work because bfs1 above iterates over all articles. Seems that whole graph is single component.

## Next we populating deque by articles, that include two authors of current article.

In [9]:
def bfs2(article_id: str, num_common_articles) -> set[str]:
    """
    for current article we add all other articles with more than one common authors
    return: all article_ids from connected componend with given @article_id
    """
    connected_component = set()
    visited = set()
    deq = deque([article_id])
    i = 0
    while deq:
        _article_id = deq.popleft()
        visited.add(_article_id)
        connected_component.add(_article_id)
        authors = sorted(article_to_authors_id[_article_id])
        for author1, author2 in combinations(authors, 2):
            if (author1, author2) not in visited:
                visited.add((author1, author2))
                atricles_intersection = author_to_articles_id[author1] & author_to_articles_id[author2]
                if len(atricles_intersection) >= num_common_articles:
                    for article in atricles_intersection:
                        if article not in deq and article not in visited:
                            deq.append(article)

    return connected_component

In [10]:
%%time
train_article_ids2 = set()

for _, article in tqdm(articles.iterrows(), total=articles.shape[0]):
    if article._id not in train_article_ids2:
        cc = bfs2(article._id, num_common_articles=2)
        train_article_ids2 |= cc
        if len(train_article_ids2) > TRAIN_SIZE * N:
            break

test_article_ids2 = set(articles._id.values) - train_article_ids2
assert len(train_article_ids2) + len(test_article_ids2) == N

save_pickle(train_article_ids2, 'train_article_ids2.pickle')
save_pickle(test_article_ids2, 'test_article_ids2.pickle')

  0%|          | 0/1084405 [00:00<?, ?it/s]

CPU times: total: 5min 39s
Wall time: 5min 38s


In [12]:
len(train_article_ids2), len(test_article_ids2)

(867525, 216880)

In [15]:
cc = bfs2(articles.iloc[0]._id, num_common_articles=2)
s = articles.loc[articles._id.isin(cc)]
print(len(cc))
for _, row in s.iterrows():
    print('\t'.join(author['name'] for author in row.authors))

3
Klaus Didrich	Carola Gerke	Wolfgang Grieskamp	Christian Maeder	Peter Pepper
Klaus Didrich	Andreas Fett	Carola Gerke	Wolfgang Grieskamp	Peter Pepper
Klaus Didrich	Carola Gerke	Wolfgang Grieskamp	Christian Maeder	Peter Pepper


### Let set `num_common_articles=3` in `bfs2` so that for each two authors we will select those articles, that lie in intersection of articles of these authors and size of intersection is at least 3.

In [11]:
%%time
train_article_ids3 = set()

for _, article in tqdm(articles.iterrows(), total=articles.shape[0]):
    if article._id not in train_article_ids3:
        cc = bfs2(article._id, num_common_articles=3)
        train_article_ids3 |= cc
        if len(train_article_ids3) > TRAIN_SIZE * N:
            break

test_article_ids3 = set(articles._id.values) - train_article_ids3
assert len(train_article_ids3) + len(test_article_ids3) == N

save_pickle(train_article_ids3, 'train_article_ids3.pickle')
save_pickle(test_article_ids3, 'test_article_ids3.pickle')

  0%|          | 0/1084405 [00:00<?, ?it/s]

CPU times: total: 2min 38s
Wall time: 2min 38s


In [13]:
len(train_article_ids3), len(test_article_ids3)

(867525, 216880)

In [16]:
cc = bfs2(articles.iloc[0]._id, num_common_articles=3)
s = articles.loc[articles._id.isin(cc)]
print(len(cc))
for _, row in s.iterrows():
    print('\t'.join(author['name'] for author in row.authors))

3
Klaus Didrich	Carola Gerke	Wolfgang Grieskamp	Christian Maeder	Peter Pepper
Klaus Didrich	Andreas Fett	Carola Gerke	Wolfgang Grieskamp	Peter Pepper
Klaus Didrich	Carola Gerke	Wolfgang Grieskamp	Christian Maeder	Peter Pepper


### Now let iterate over triplets of authors of current article and populate deque from common articles.

In [17]:
def bfs3(article_id: str, num_common_articles=2) -> set[str]:
    """
    for current article we add all other articles with more than one common authors
    return: all article_ids from connected componend with given @article_id
    """
    connected_component = set()
    visited = set()
    deq = deque([article_id])
    while deq:
        _article_id = deq.popleft()
        visited.add(_article_id)
        connected_component.add(_article_id)
        authors = sorted(article_to_authors_id[_article_id])
        for author1, author2, author3 in combinations(authors, 3):
            if (author1, author2, author3) not in visited:
                visited.add((author1, author2, author3))
                atricles_intersection = author_to_articles_id[author1] & author_to_articles_id[author2] & author_to_articles_id[author3]
                if len(atricles_intersection) >= num_common_articles:
                    for article in atricles_intersection:
                        if article not in deq and article not in visited:
                            deq.append(article)

    return connected_component

In [18]:
train_article_ids_by_triplets = set()

for _, article in tqdm(articles.iterrows(), total=articles.shape[0]):
    if article._id not in train_article_ids_by_triplets:
        cc = bfs3(article._id)
        train_article_ids_by_triplets |= cc
        if len(train_article_ids_by_triplets) > TRAIN_SIZE * N:
            break

  0%|          | 0/1084405 [00:00<?, ?it/s]

In [116]:
len(train_article_ids_by_triplets)

867525

In [103]:
len(train_article_ids_by_triplets) / N

0.8000009221646894

In [202]:
cc = bfs3(articles.iloc[33]._id, num_common_articles=2)
s = articles.loc[articles._id.isin(cc)]
print(len(cc), articles.iloc[33]._id)
for _, row in s.iterrows():
    print('\t'.join(author['name'] for author in row.authors))

9 5c2c7a6717c44a4e7cf30b18
Zoe L. Jiang	Jin Yabin	Wang Xuan	Fang Junbin
Liu Xiaoyan	Zoe L. Jiang	Wang Xuan	Li Ye	Liu Zechao	Jin Yabin	Fang Junbin
Wang Xuan	Fang Junbin	Li Jin	Jin Yabin	Huang Jiajun
Liu Zechao	Zoe L. Jiang	Wang Xuan	Zhang Chunkai	Zhao Xiaomeng
Li Ye	Zoe L. Jiang	Wang Xuan	Liao Qing
Li Ye	Zoe L. Jiang	Wang Xuan
Li Ye	Zoe L. Jiang	Wang Xuan	Fang Junbin
Liu Meng	Wang Xuan	Yang Chi	Jiang Zoe Lin	Li Ye
Li Ye	Zoe L. Jiang	Wang Xuan


# 2) Search for connected components by references

In [174]:
def bfs_by_ref(article_id: str) -> set[str]:
    """
    return: all article_ids from connected componend with given @article_id
    """
    connected_component = set([article_id])
    deq = deque([article_id])
    i = 0
    while deq:
        _article_id = deq.popleft()
        connected_component.add(_article_id)
        references = articles.query(f"_id == '{_article_id}'").references.values.tolist()
        if references:
            for ref in references[0]:
                if ref not in connected_component:
                    deq.append(ref)
        i += 1
        if i % 100 == 0:
            print(f'len(cc)={len(connected_component)}, len(deq)={len(deq)}')
        if i > 1e3:
            break
    return connected_component

In [175]:
cc = bfs_by_ref(articles.iloc[0]._id)

len(cc)=78, len(deq)=104
len(cc)=147, len(deq)=190
len(cc)=220, len(deq)=232
len(cc)=271, len(deq)=289
len(cc)=339, len(deq)=317
len(cc)=371, len(deq)=338
len(cc)=428, len(deq)=375
len(cc)=494, len(deq)=362
len(cc)=525, len(deq)=383
len(cc)=569, len(deq)=373
