In [None]:
import requests
from xml.etree import ElementTree
import networkx as nx
import matplotlib.pyplot as plt
from pyvis.network import Network
from alive_progress import alive_bar
import colour

In [None]:
def remove_uninformative(graph):
    """Drop leaf nodes that are cited exactly once and cite nothing themselves.

    A node with in-degree 1 and out-degree 0 only adds a dangling edge to
    the citation graph. The graph is modified in place and also returned so
    the call can be chained or reassigned.
    """
    dangling = []
    for candidate in graph.nodes():
        if graph.in_degree(candidate) == 1 and graph.out_degree(candidate) == 0:
            dangling.append(candidate)
    print(f"Removing {len(dangling)} nodes")
    graph.remove_nodes_from(dangling)
    return graph

def remove_orphans(graph):
    """Drop isolated nodes, i.e. nodes with neither incoming nor outgoing edges.

    The graph is modified in place and also returned.
    """
    def _is_isolated(vertex):
        return graph.out_degree(vertex) == 0 and graph.in_degree(vertex) == 0

    isolated = list(filter(_is_isolated, graph.nodes()))
    print(f"Removing {len(isolated)} nodes")
    graph.remove_nodes_from(isolated)
    return graph

In [None]:
# Rank papers in the graph by citation count (in-degree) and return the requested rank(s)
def get_most_cited(g, n=0, all_counts=False):
    """Rank the nodes of ``g`` by in-degree (number of citations received).

    Args:
        g: Directed graph whose ``in_degree()`` yields (node, degree) pairs,
            e.g. a networkx DiGraph.
        n: Rank to return (0 = most cited; negative values count from the
            least cited end), or a list/tuple of such ranks.
        all_counts: If True, ignore ``n`` and return the full ranking.

    Returns:
        A single (node, in_degree) pair when ``n`` is an int; otherwise a list
        of pairs. Ordering is most- to least-cited, ties keep graph order
        (``sorted`` is stable).

    Raises:
        IndexError: if a requested rank is out of range (including an empty graph).
        TypeError: if ``n`` is neither an int nor a list/tuple of ints.
    """
    # Rank the graph that was passed in, never a notebook-global one.
    ranked = sorted(g.in_degree(), key=lambda pair: pair[1], reverse=True)
    if all_counts:
        return ranked
    if isinstance(n, int):
        return ranked[n]
    if isinstance(n, (list, tuple)):
        return [ranked[i] for i in n]
    raise TypeError(f"n must be an int or a list/tuple of ints, got {type(n).__name__}")

In [None]:
def get_papers(keyword, pages=1, perPage=15):
    """Search Europe PMC for ``keyword`` and return the PMIDs on one result page.

    Args:
        keyword: Europe PMC query string. It is URL-encoded by ``requests``,
            so spaces and query operators are passed through safely.
        pages: Page number to request (sent as the ``page`` parameter).
        perPage: Number of results per page (sent as ``pageSize``).

    Returns:
        List of PMID strings in response order. Results that carry no
        ``<pmid>`` element are skipped.

    Raises:
        requests.HTTPError: if Europe PMC answers with an error status.
    """
    # NOTE(review): Europe PMC documents cursorMark-based paging for search;
    # confirm ``page`` is still honoured, otherwise every page returns the
    # same hits and the caller's de-duplication hides it.
    url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
    params = {"query": keyword, "page": pages, "pageSize": perPage, "format": "xml"}
    response = requests.get(url, params=params, timeout=30)
    response.raise_for_status()

    tree = ElementTree.fromstring(response.content)
    pmids = []
    for result in tree.findall('.//result'):
        pmid_elem = result.find('pmid')
        if pmid_elem is not None:
            pmids.append(pmid_elem.text)
    return pmids

In [None]:
def _fetch_cited_pmids(pmid, pages, perPage):
    """Return the MEDLINE ids referenced by one paper ([] if the request fails)."""
    url = f"https://www.ebi.ac.uk/europepmc/webservices/rest/MED/{pmid}/references"
    params = {"page": pages, "pageSize": perPage, "format": "xml"}
    response = requests.get(url, params=params, timeout=30)
    if not response.ok:
        # Report and record "no references" rather than aborting a long batch.
        print(f"Skipping PMID {pmid}: HTTP {response.status_code}")
        return []

    tree = ElementTree.fromstring(response.content)
    cited_pmids = []
    # Each <reference> is read for its <id> and optional <source> children.
    for ref in tree.findall('.//reference'):
        cited_id = ref.findtext('id')
        if not cited_id:
            # Missing/empty <id> would otherwise add None, which networkx rejects as a node.
            continue
        source = ref.findtext('source')
        # Assumes <source> names the id namespace ('MED' = PubMed id) — confirm
        # against the Europe PMC schema. Ids from other sources are not PMIDs.
        if source is not None and source != 'MED':
            continue
        cited_pmids.append(cited_id)
    return cited_pmids


def get_citation_dict(pmids, pages=1, perPage=100):
    """Fetch each paper's reference list and build a citation mapping.

    Args:
        pmids: Sized collection of PMID strings (``len`` sizes the progress bar).
        pages: Page of each reference list to request (``page`` parameter).
        perPage: References per page (``pageSize``). Only this one page is
            fetched, so papers with more than ``perPage`` references are truncated.

    Returns:
        Dict mapping every input PMID to the list of PMIDs it references
        (empty when none were found or the request failed).
    """
    citation_dict = {}
    with alive_bar(len(pmids), force_tty=True) as bar:
        for pmid in pmids:
            citation_dict[pmid] = _fetch_cited_pmids(pmid, pages, perPage)
            bar()
    return citation_dict

In [None]:
def remove_no_refs(citation_dict):
    """Return a new dict keeping only entries whose reference list is truthy (non-empty)."""
    filtered = {}
    for pmid, refs in citation_dict.items():
        if refs:
            filtered[pmid] = refs
    return filtered

In [None]:
# Collect PMIDs for `keyword` across `pages` search-result pages.
keyword = "bioinformatics"
pages = 20
per_page = 100
pmids = []
with alive_bar(pages, force_tty=True) as bar:
    # Pages are numbered from 1 (get_papers' default), so range(1, pages + 1)
    # fetches all `pages` pages and the progress bar sized `pages` completes.
    for page in range(1, pages + 1):
        pmids.extend(get_papers(keyword, pages=page, perPage=per_page))
        bar()
# De-duplicate while keeping first-seen order, so reruns give a stable order.
pmids = list(dict.fromkeys(pmids))

In [None]:
# One Europe PMC references request per PMID collected above.
citation_dict = get_citation_dict(pmids=pmids)

In [None]:
# Dict-of-lists input: each edge points from a paper to a paper it cites.
dg = nx.DiGraph(incoming_graph_data=citation_dict)

In [None]:
# Prune single-citation leaves first; that can leave nodes with no edges,
# which the orphan pass then removes.
for prune in (remove_uninformative, remove_orphans):
    dg = prune(dg)

In [None]:
# The six most-cited papers as (pmid, in-degree) pairs.
get_most_cited(dg, list(range(6)))

In [None]:
# Create a PyVis network; notebook=True is set because this runs in Jupyter.
net = Network(height="750px", width="100%", bgcolor="#222222", font_color="white", directed=True, notebook=True)
net.show_buttons(filter_=['physics'])
net.from_nx(dg)

# In-degree centrality drives the node colour.
indegree_centrality = nx.in_degree_centrality(dg)

# default=0 avoids a ValueError from min()/max() if pruning emptied the graph.
min_indegree = min(indegree_centrality.values(), default=0)
max_indegree = max(indegree_centrality.values(), default=0)


def get_color(indegree):
    """Map a centrality value to an 'rgb(r,20,b)' string.

    Values are min-max scaled over the graph, then linearly interpolated
    from blue (lowest) to red (highest). When all values are equal the
    scale collapses to 0 (pure blue).
    """
    span = max_indegree - min_indegree
    normalized_indegree = (indegree - min_indegree) / span if span != 0 else 0
    r = int(255 * normalized_indegree)
    b = int(255 * (1 - normalized_indegree))
    return f'rgb({r},20,{b})'


# Apply colors to the nodes PyVis created from the graph.
for node_id, indegree in indegree_centrality.items():
    node = net.get_node(node_id)
    if node:  # Check if node exists
        node['color'] = get_color(indegree)

net.show('indegree_gradient_network.html')