In [1]:
import json
import pickle
from os.path import join
from random import sample
import numpy as np
import pandas as pd
import shap
from s2and.consts import CONFIG
from s2and.data import PDData
from s2and.featurizer import featurize, FeaturizationInfo, many_pairs_featurize
from s2and.model import PairwiseModeler, Clusterer, FastCluster
from s2and.eval import pairwise_eval, cluster_eval, facet_eval
from s2and.consts import FEATURIZER_VERSION, DEFAULT_CHUNK_SIZE, PROJECT_ROOT_PATH
from s2and.file_cache import cached_path
from sklearn.model_selection import GroupKFold, train_test_split
from sklearn.metrics import average_precision_score

import matplotlib.pyplot as plt
import seaborn as sns
# select seaborn's 'talk' plotting context for any figures drawn later
sns.set(context='talk')

  from .autonotebook import tqdm as notebook_tqdm


RuntimeError: module compiled against API version 0xe but this version of numpy is 0xd



In [13]:
# Read the v2 papers and clusters JSON dumps into memory.
papers_v2_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/papers.v2.json'
clusters_v2_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/clusters.v2.json'

with open(papers_v2_path, 'r') as papers_file:
    papers = json.load(papers_file)

with open(clusters_v2_path, 'r') as clusters_file:
    clusters = json.load(clusters_file)

In [3]:
# show the first (key, value) entry of each mapping to see the record layout
first_paper_item = next(iter(papers.items()))
first_cluster_item = next(iter(clusters.items()))
print(first_paper_item, first_cluster_item)

('43808411', {'authors': [{'position': 0, 'first': 'Ursula', 'middle': [], 'last': 'Hofer', 'suffix': None, 'affiliations': [], 'email': None}], 'abstract': None, 'references': [], 'paper_id': 43808411, 'source': 'Medline', 'doi': '10.1038/nrmicro3092', 'pmid': 23872943, 'title': 'Viral evolution: Variation in the gut virome.', 'year': 2013, 'venue': 'Nature reviews. Microbiology', 'publicationtypes': ['LettersAndComments', 'JournalArticle'], 'fieldsofstudy': ['Medicine'], 'journal_name': '\n          596\n        ', 'block': 'viralevolution', 'corpus_paper_id': 19963}) ('thefboxproteinslmbrestrictstheactiv_6551263', {'cluster_id': 'thefboxproteinslmbrestrictstheactiv_6551263', 'paper_ids': [3673658920, 3503591660, 2051795599], 'model_version': -1})


In [16]:
# which Python types occur among the keys of `papers`?
{type(paper_key) for paper_key in papers}

{str}

In [14]:
# step 1: a cluster id starts with its block name (the text before the first '_').
# Any paper whose own 'block' field differs from that prefix is dropped from the cluster.
# The lookup is keyed by both the str and the int form of every paper id, so it works
# whichever of the two types a cluster's paper list holds.
blocks = {}
for key_type in (str, int):
    blocks.update((key_type(pid), record['block']) for pid, record in papers.items())

times_trimmed = 0  # number of clusters that lost at least one paper
for cluster_id, cluster in clusters.items():
    expected_block = cluster_id.split('_')[0]
    original_ids = cluster['paper_ids']
    kept_ids = [pid for pid in original_ids if blocks[pid] == expected_block]
    cluster['paper_ids'] = kept_ids
    if len(kept_ids) < len(original_ids):
        times_trimmed += 1
print(times_trimmed)

190


In [15]:
# step 2: some papers are listed both in a regular cluster and in an orphans cluster.
# Start by indexing, for every paper id, the set of cluster ids that list it.
from collections import defaultdict

papers_to_clusters = defaultdict(set)
for cluster_id in clusters:
    for member_id in clusters[cluster_id]['paper_ids']:
        papers_to_clusters[member_id].add(cluster_id)

In [25]:
# For every paper listed by more than one cluster: if it sits in a '<block>_orphans'
# cluster and a non-orphan cluster of the same block also lists it, remove it from
# the orphans cluster. `total_removed` counts the (paper, orphans-cluster) removals.
total_removed = 0
for paper_id, cluster_ids in papers_to_clusters.items():
    if len(cluster_ids) < 2:
        continue
    nonorphan_blocks = {c.split('_')[0] for c in cluster_ids if not c.endswith('_orphans')}
    to_remove = [
        c for c in cluster_ids
        if c.endswith('_orphans') and c.split('_')[0] in nonorphan_blocks
    ]
    total_removed += len(to_remove)
    for cluster_id in to_remove:
        # `paper_id` is the exact element that was read out of this cluster's list when
        # papers_to_clusters was built, so it can be removed as-is: no str/int guessing
        # and no bare `except` that would also hide unrelated errors. A ValueError here
        # means papers_to_clusters is stale relative to `clusters` (e.g. this cell was
        # re-run without rebuilding the index).
        clusters[cluster_id]['paper_ids'].remove(paper_id)
total_removed

11092

In [28]:
# step 3: some papers are still listed by several clusters.
# Rebuild the paper id -> cluster ids index from the current state of `clusters`.
papers_to_clusters = defaultdict(set)
for cluster_id in clusters:
    for member_id in clusters[cluster_id]['paper_ids']:
        papers_to_clusters[member_id].add(cluster_id)

# ids listed by 2+ clusters, stored in both int and str form so that membership
# tests succeed whichever type a cluster's paper list holds
multi_cluster_ids = [pid for pid, owners in papers_to_clusters.items() if len(owners) > 1]
papers_in_multiple_clusters = {int(pid) for pid in multi_cluster_ids} | {str(pid) for pid in multi_cluster_ids}

In [31]:
# strip every id found in papers_in_multiple_clusters out of each cluster's paper list
for cluster in clusters.values():
    cluster['paper_ids'] = [pid for pid in cluster['paper_ids'] if pid not in papers_in_multiple_clusters]

In [35]:
# Gather (as str) every paper id still referenced by some cluster, then drop all
# other papers. The assert checks that every referenced id has a paper record.
papers_in_clusters = {
    str(pid)
    for cluster in clusters.values()
    for pid in cluster['paper_ids']
}

papers = {pid: record for pid, record in papers.items() if str(pid) in papers_in_clusters}
assert papers_in_clusters == set(papers.keys())

In [38]:
# Write the cleaned papers and clusters out as the v3 JSON files.
papers_v3_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/papers.v3.json'
clusters_v3_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/clusters.v3.json'

with open(papers_v3_path, 'w') as papers_file:
    json.dump(papers, papers_file)

with open(clusters_v3_path, 'w') as clusters_file:
    json.dump(clusters, clusters_file)

In [2]:
# Next stage: downsample the blocks that have 100% accuracy.
# Begin by reading the v3 papers and clusters JSON files back in.
papers_v3_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/papers.v3.json'
clusters_v3_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/clusters.v3.json'

with open(papers_v3_path, 'r') as papers_file:
    papers = json.load(papers_file)

with open(clusters_v3_path, 'r') as clusters_file:
    clusters = json.load(clusters_file)

In [3]:
# Split the candidate blocks by accuracy: "easy" means accuracy is exactly 1.0,
# everything else counts as not easy.
candidates_csv = join(PROJECT_ROOT_PATH, "data", "block_removal_candidates.csv")
df = pd.read_csv(candidates_csv)
easy_blocks = df.loc[df.accuracy == 1.0, 'block'].values
not_easy_blocks = df.loc[df.accuracy != 1.0, 'block'].values

In [4]:
# Choose a random half of the easy blocks and keep all the hard blocks.
# The draw uses a dedicated, seeded RandomState so that re-running the notebook
# selects the same blocks and therefore regenerates the same downsampled dataset.
EASY_BLOCK_SAMPLE_SEED = 0
easy_block_rng = np.random.RandomState(EASY_BLOCK_SAMPLE_SEED)
keep_blocks = set(easy_block_rng.choice(easy_blocks, size=len(easy_blocks)//2, replace=False))
print(len(keep_blocks))
# every block whose accuracy is not exactly 1.0 is kept
keep_blocks.update(not_easy_blocks)
print(len(keep_blocks))
# blocks with a positive keep_count are kept regardless of the random draw
keep_blocks.update(df['block'][df.keep_count > 0].values)
print(len(keep_blocks))
# blocks with a positive remove_count are dropped last, so removal wins over keeping
keep_blocks -= set(df['block'][df.remove_count > 0].values)
print(len(keep_blocks))

34204
47147
47334
47121


In [5]:
# keep only the papers in the blocks we want to keep
papers_sub = {k: v for k, v in papers.items() if v['block'] in keep_blocks}

# keep only clusters whose papers are all present in papers_sub
# NOTE(review): all() is True for an empty paper list, so a cluster with no papers
# is kept here whatever its block -- confirm that is intended.
clusters_sub = {k: v for k, v in clusters.items() if all(str(p) in papers_sub for p in v['paper_ids'])}

# collect (as str) every paper id referenced by a kept cluster; read from
# clusters_sub itself rather than reaching back into the unfiltered `clusters`
papers_in_clusters = set()
for cluster in clusters_sub.values():
    papers_in_clusters.update(str(i) for i in cluster['paper_ids'])

# sanity check: the kept papers and the papers referenced by kept clusters coincide
assert papers_in_clusters == set(papers_sub.keys()), \
    "papers_sub and the papers referenced by clusters_sub are out of sync"

In [6]:
# Write the downsampled papers and clusters out as the v3_hard JSON files.
papers_hard_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/papers.v3_hard.json'
clusters_hard_path = '/net/nfs2.s2-research/shaurya/projects/s2pac/final_formatting/data/clusters.v3_hard.json'

with open(papers_hard_path, 'w') as papers_file:
    json.dump(papers_sub, papers_file)

with open(clusters_hard_path, 'w') as clusters_file:
    json.dump(clusters_sub, clusters_file)