In [0]:
%pip install habanero networkx

In [0]:
import pyspark.sql.functions as F
import graphframes as gf
import networkx as nx
import matplotlib.pyplot as plt
import requests
import json
from habanero import Crossref
from pyspark.sql.types import StringType, IntegerType, ArrayType
import time


# Set the number of Spark SQL shuffle partitions to the SparkContext's default
# parallelism. `spark` and `sc` are not defined in this notebook; presumably the
# notebook environment provides them.
spark.conf.set("spark.sql.shuffle.partitions", sc.defaultParallelism)

In [0]:
# Read the scientific-publications Delta table and preview its first 25 rows.
raw_df = (
    spark.read
    .format("delta")
    .load("/user/hive/warehouse/scientific_publications")
)
display(raw_df.limit(25))

In [0]:
# Keep only the records whose title splits into more than one
# space-separated token; the token count is kept as `title_word_count`.
filtered_df = (
    raw_df
    .withColumn("title_word_count", F.size(F.split("title", " ")))
    .filter("title_word_count > 1")
)
display(filtered_df.limit(25))

In [0]:
# Explode the `authors` array of the first 10000 publications into one row per
# (publication, author). `rank` is the 1-based position of the author in the
# array, and author names are normalised with initcap.
authors_exp = (filtered_df.limit(10000)
            .select("doi", "authors", "title", F.posexplode(F.col("authors")).alias("rank", "authors_exp"))
            .withColumn("rank", F.col("rank") + 1)
            .select("authors_exp.*","*")
            .select("rank", "name", "title", "org")
            .withColumn("name", F.initcap(F.col("name")))
            )

# One vertex id per distinct author name: "A" followed by the generated number,
# zero-padded to at least 8 digits.
# format_string is used instead of lpad because lpad TRUNCATES values longer
# than the target length: monotonically_increasing_id() stores the partition id
# in the upper bits, so ids outside partition 0 have 10+ digits and lpad(…, 8)
# collapsed them onto the same 8-character prefix (duplicate vertex ids).
authors_id = (authors_exp.select("name").dropDuplicates()
           .withColumn("id", F.format_string("A%08d", F.monotonically_increasing_id()))
          )

# Attach the generated id to every (publication, author) row.
authors = authors_exp.join(authors_id, ["name"])
display(authors.limit(100))

In [0]:
# Author vertices: id and name of every distinct author, tagged type "author".
vertices = (
    authors_id
    .select("id", "name")
    .withColumn("type", F.lit("author"))
)

display(vertices.limit(100))

In [0]:
# Co-author edges: self-join the (author id, title) pairs on the title so that
# every two authors of the same title end up on one row.
df1 = authors.select("id", "title")
left_side, right_side = df1.alias("a"), df1.alias("b")
df2 = left_side.join(right_side, ["title"])
# Requiring a.id < b.id removes self-pairs and keeps each pair only once.
df3 = df2.where(F.col("a.id") < F.col("b.id"))
co_authors = df3.select(
    F.col("a.id").alias("src"),
    F.col("b.id").alias("dst"),
    "title",
)

display(co_authors.limit(100))

In [0]:
# Mark the vertex and edge frames for caching, then build the co-authorship
# graph from them and preview 100 (src, edge, dst) triplets.
vertices.cache()
co_authors.cache()

co_authors_graph = gf.GraphFrame(vertices, co_authors)

display(co_authors_graph.triplets.limit(100))

In [0]:
def get_organization(name, country):
    """Look up an organization through the ROR API and return "name;city;country".

    Sends ``name`` as the ``query`` parameter to
    ``https://api.ror.org/organizations`` and returns the first returned item
    whose country is acceptable.

    Args:
        name: Free-text organization name to search for.
        country: Country name the item must have; an empty string accepts the
            first item regardless of its country.

    Returns:
        ``"<name>;<city>;<country>"`` built from the first acceptable item, or
        the sentinel string ``"No results"`` when there is no acceptable item
        or the lookup fails for any reason (network error, bad payload, ...).
        The function always returns a string.
    """
    if not name:
        return "No results"
    try:
        # Let requests build the query string so the whole name is URL-encoded,
        # and bound the wait so a stalled connection cannot hang the caller.
        r = requests.get(url="https://api.ror.org/organizations",
                         params={"query": name},
                         timeout=10)
        data = r.json()
        # Walk the items actually present in the response; `number_of_results`
        # can be larger than the number of items in this page.
        items = data.get("items") or []
    except Exception:
        return "No results"

    for item in items:
        try:
            found_country = item["country"]["country_name"]
            if country == "" or country == found_country:
                return item["name"] + ";" + item["addresses"][0]["city"] + ";" + found_country
        except (KeyError, IndexError, TypeError):
            # Skip an item that lacks the expected fields, keep looking.
            continue
    return "No results"

In [0]:
# Regex locating a country name that is preceded by a comma/whitespace and
# followed by a comma/whitespace/end of string; group 2 is the country itself.
# Raw string so that `\s` reaches the regex engine as written.
countries = r"(,|\s)(Afghanistan|Albania|Algeria|Andorra|Angola|Antigua and Barbuda|Argentina|Armenia|Australia|Austria|Azerbaijan|Bahamas|Bahrain|Bangladesh|Barbados|Belarus|Belgium|Belize|Benin|Bhutan|Bolivia|Bosnia and Herzegovina|Botswana|Brazil|Brunei|Bulgaria|Burkina Faso|Burundi|Cabo Verde|Cambodia|Cameroon|Canada|Central African Republic|Chad|Chile|China|Colombia|Comoros|Democratic Republic of the Congo|Republic of the Congo|Costa Rica|Cote d'Ivoire|Croatia|Cuba|Cyprus|Czech Republic|Denmark|Djibouti|Dominica|Dominican Republic|Ecuador|Egypt|El Salvador|Equatorial Guinea|Eritrea|Estonia|Ethiopia|Fiji|Finland|France|Gabon|Gambia|Georgia|Germany|Ghana|Greece|Grenada|Guatemala|Guinea|Guinea-Bissau|Guyana|Haiti|Honduras|Hungary|Iceland|India|Indonesia|Iran|Iraq|Ireland|Israel|Italy|Jamaica|Japan|Jordan|Kazakhstan|Kenya|Kiribati|Kosovo|Kuwait|Kyrgyzstan|Laos|Latvia|Lebanon|Lesotho|Liberia|Libya|Liechtenstein|Lithuania|Luxembourg|North Macedonia|Madagascar|Malawi|Malaysia|Maldives|Mali|Malta|Marshall Islands|Mauritania|Mauritius|Mexico|Micronesia|Moldova|Monaco|Mongolia|Montenegro|Morocco|Mozambique|Myanmar|Namibia|Nauru|Nepal|Netherlands|New Zealand|Nicaragua|Niger|Nigeria|North Korea|Norway|Oman|Pakistan|Palau|Palestine|Panama|Papua New Guinea|Paraguay|Peru|Philippines|Poland|Portugal|Qatar|Romania|Russia|Rwanda|Saint Kitts and Nevis|Saint Lucia|Saint Vincent and the Grenadines|Samoa|San Marino|Sao Tome and Principe|Saudi Arabia|Senegal|Serbia|Seychelles|Sierra Leone|Singapore|Slovakia|Slovenia|Solomon Islands|Somalia|South Africa|South Korea|South Sudan|Spain|Sri Lanka|Sudan|Suriname|Swaziland|Sweden|Switzerland|Syria|Taiwan|Tajikistan|Tanzania|Thailand|Timor-Leste|Togo|Tonga|Trinidad and Tobago|Tunisia|Turkey|Turkmenistan|Tuvalu|Uganda|Ukraine|United Arab Emirates|UAE|United Kingdom|UK|United States of America|USA|United States|US|Uruguay|Uzbekistan|Vanuatu|Vatican City|Venezuela|Vietnam|Yemen|Zambia|Zimbabwe)(,|\s|$)"


# One row per distinct (author, org) pair of the first 100 publications, with
# a cleaned organization string for the API query and the country found in the
# affiliation text (abbreviations expanded to the full country name).
organization = (filtered_df.limit(100).select("authors", "title", "year", F.explode(F.col("authors")).alias("authors_exp"))
               .select("authors_exp.*","*")
               .withColumn("author", F.col("name"))
               .select("author", "org")
               .filter(F.col("org").isNotNull())
               .dropDuplicates()
               # Blank out the query-syntax characters + - = # & | > < ! ( ) { } [ ] ^ " ~ * ? : \ /
               # The hyphen is escaped on purpose: an unescaped `+-=` inside a
               # character class is the RANGE '+'..'=' and also removed digits,
               # commas, periods and semicolons from the organization name.
               .withColumn("strip_org", F.regexp_replace(F.col("org"), r'[\+\-=#&\|><!\(\)\{\}\[\]\^"~\*\?:\\/]', " "))
               .withColumn("country", F.regexp_extract(F.col("org"), countries, 2))
               .withColumn("country", F.regexp_replace("country", "United States of America|USA|US", "United States"))
               .withColumn("country", F.regexp_replace("country", "UK", "United Kingdom"))
               .withColumn("country", F.regexp_replace("country", "UAE", "United Arab Emirates"))
               )

# Call the organization lookup once per row; the result is "name;city;country"
# or a sentinel string.
organization_rdd = organization.rdd.map(
    lambda row: (row["author"], row["org"], row["strip_org"], row["country"],
                 get_organization(row["strip_org"], row["country"]))
)
organization_raw = (organization_rdd.toDF(["author", "org", "strip_org", "country", "api_org"])
                    .withColumn("api_name", F.regexp_extract(F.col("api_org"), r"^(.+);.+;.+$", 1))
                    # `match` holds the API name when the affiliation text literally
                    # contains it, else "". A literal containment test is used because
                    # the API name is data, not a pattern: feeding it to regexp_extract
                    # mis-handles names containing regex metacharacters such as ( ) + ?.
                    .withColumn("match", F.when((F.col("api_name") != "") & F.col("org").contains(F.col("api_name")),
                                                F.col("api_name"))
                               .otherwise(F.lit("")))
                    .withColumn("name", F.when(F.col("match") == "", F.col("org"))
                               .otherwise(F.col("api_name")))
                    .withColumn("city", F.when(F.col("match") == "", "")
                               .otherwise(F.regexp_extract(F.col("api_org"), r"^.+;(.+);.+$", 1)))
                    .withColumn("country", F.when(F.col("match") == "", F.col("country"))
                               .otherwise(F.regexp_extract(F.col("api_org"), r"^.+;.+;(.+)$", 1)))
                    # Only organizations confirmed by the API (non-empty city) are kept.
                    .filter(F.col("city") != "")
                   )

# One vertex id per distinct organization: "O" + number zero-padded to at least
# 8 digits. format_string never truncates, unlike lpad, which shortened the
# 10+ digit values of monotonically_increasing_id() to colliding 8-char prefixes.
organization = (organization_raw.select("name", "city", "country")
                .dropDuplicates()
                .withColumn("id", F.format_string("O%08d", F.monotonically_increasing_id()))
               )

# NOTE: this extends `vertices` in place, so re-running the cell without
# re-running the cell that defines `vertices` appends the organizations again.
vertices = vertices.unionByName(organization.withColumn("type", F.lit("organization")), allowMissingColumns=True)

organization_full = organization_raw.join(organization, ["name", "city", "country"])

# works_for edges: author id -> organization id.
auth = authors.select(F.col("id").alias("author_id"), F.col("name").alias("author"))
works_for = (organization_full.join(auth, ["author"])
             .select(F.col("author_id").alias("src"), F.col("id").alias("dst"))
             .dropDuplicates()
            )

display(works_for)

In [0]:
# Mark the vertex and edge frames for caching, then build the
# author -> organization graph and show its triplets.
vertices.cache()
works_for.cache()

works_for_graph = gf.GraphFrame(vertices, works_for)

display(works_for_graph.triplets)

## Authorship

In [0]:
def check_return_data(check_type, data, cur):
    """Return ``data[check_type]`` when available, otherwise the fallback ``cur``.

    Args:
        check_type: Key to look up.
        data: Mapping to read from, or ``None``.
        cur: Value returned when ``data`` is ``None`` or lacks ``check_type``.
    """
    if data is None or check_type not in data:
        return cur
    return data[check_type]
    
def update_df(df, doi_list, data_list, data_str, is_int=False):
    """Write per-DOI values into column ``data_str`` of ``df``.

    Args:
        df: Spark DataFrame with a ``doi`` column.
        doi_list: Keys, positionally paired with ``data_list``. When a key
            occurs more than once the last paired value wins.
        data_list: Values to write.
        data_str: Name of the column to add or overwrite.
        is_int: Declare the column as IntegerType instead of StringType.

    Returns:
        ``df`` with ``data_str`` set to the value paired with each row's
        ``doi``. A ``doi`` that is not among ``doi_list`` raises ``KeyError``
        inside the UDF when the frame is evaluated.
    """
    data_dict = dict(zip(doi_list, data_list))
    return_type = IntegerType() if is_int else StringType()
    # F.udf rather than a bare `udf`: the notebook imports the functions module
    # as F and never imports the name `udf` itself.
    update_data = F.udf(lambda x: data_dict[x], return_type)
    return df.withColumn(data_str, update_data(F.col('doi')))

def get_publication_data(df):
    """Enrich publications with citation count, volume and series from Crossref.

    For each row the DOI is taken from ``doi`` or, when that is empty, from the
    part after ``"org/"`` of the first ``url`` entry if that entry contains
    ``"doi"``. The DOI is looked up at ``https://api.crossref.org/works/<doi>``.

    Args:
        df: Spark DataFrame with ``doi``, ``url``, ``n_citation`` and ``volume``
            columns (``url[0]`` is read, so ``url`` is indexed like a sequence).

    Returns:
        ``df`` with ``n_citation`` (integer), ``volume`` and ``series`` updated
        from Crossref where available, keeping the existing value otherwise
        (``series`` falls back to null), and ``doi`` replaced by the DOI that
        was successfully looked up, or null when the lookup failed.

    Note:
        Values are written back keyed by the original ``doi`` (see
        ``update_df``), so rows sharing a ``doi`` value, including several rows
        with an empty one, receive the same values.
    """
    # Collect all needed columns in ONE action so the values of a row stay
    # together; separate collects are not guaranteed to return rows in the
    # same order.
    rows = df.select("doi", "url", "n_citation", "volume").collect()

    doi_list = [row["doi"] for row in rows]
    new_volume_list = []
    n_citation_list = []
    series_list = []
    new_doi_list = []
    for row in rows:
        doi, urls = row["doi"], row["url"]
        try:
            if doi == "" or doi is None:
                if "doi" in urls[0]:
                    doi_req = urls[0].split("org/")[-1]
                else:
                    raise ValueError("no DOI available for this row")
            else:
                doi_req = doi

            response = requests.get(f"https://api.crossref.org/works/{doi_req}", timeout=10)
            data = response.json()['message']
        except Exception:
            # Any failure (no DOI, network error, non-JSON answer) means
            # "no Crossref data" for this row.
            doi_req = None
            data = None

        # Exactly one entry per row, appended after the lookup outcome is known,
        # so this list can never get out of step with `doi_list`.
        new_doi_list.append(doi_req)

        citations = check_return_data('is-referenced-by-count', data, row["n_citation"])
        n_citation_list.append(None if citations is None else int(citations))
        new_volume_list.append(check_return_data('volume', data, row["volume"]))
        temp = check_return_data('container-title', data, None)
        series_list.append(None if temp is None or len(temp) == 0 else temp[0])

        # Small pause between requests to the public API.
        time.sleep(0.05)

    df = update_df(df, doi_list, n_citation_list, 'n_citation', is_int=True)
    df = update_df(df, doi_list, new_volume_list, 'volume')
    df = update_df(df, doi_list, series_list, 'series')
    # Must stay last: it overwrites the `doi` column the other updates key on.
    df = update_df(df, doi_list, new_doi_list, 'doi')

    return df

# Input for the Crossref enrichment: the first 100 filtered publications.
publication_raw = filtered_df.limit(100).select("_id", "title", "volume", "n_citation", "doi", "url")

# Enrich with Crossref data and keep publications cited more than twice.
publications_raw = get_publication_data(publication_raw).filter("n_citation > 2")

# One vertex id per distinct publication: "P" + number zero-padded to at least
# 8 digits. format_string never truncates, unlike lpad, which shortened the
# 10+ digit values of monotonically_increasing_id() to colliding 8-char prefixes.
publications = (publications_raw.select("title", "volume", "series", "n_citation")
            .dropDuplicates()
            .withColumn("id", F.format_string("P%08d", F.monotonically_increasing_id()))
          )

display(publications)

In [0]:
# Vertices of the authorship graph: authors plus publications.
# The union starts from the author vertices built here (not from the shared
# `vertices` frame) and publications are tagged with their own type.
v_authorship = (authors_id.select("id", "name")
                .withColumn("type", F.lit("author"))
                .unionByName(publications.withColumn("type", F.lit("publication")), allowMissingColumns=True)
               )

display(v_authorship.limit(100))

In [0]:
# Attach the publication vertex id to every enriched publication row.
# The key columns `volume` and `series` can be null, and a plain equality join
# never matches null with null, which silently dropped those publications.
# eqNullSafe treats two nulls as equal.
key_cols = ["title", "volume", "series", "n_citation"]
publication_full = (publications_raw.alias("raw")
                    .join(publications.alias("pub"),
                          [F.col("raw." + c).eqNullSafe(F.col("pub." + c)) for c in key_cols])
                    .select("raw.*", F.col("pub.id").alias("id"))
                   )

# Author lookup including the title, so authors can be joined to publications.
auth = authors.select(F.col("id").alias("author_id"), F.col("name").alias("author"), F.col("title"))

# Authorship edges: author id -> publication id.
e_authorship = (publication_full.join(auth, ["title"])
             .select(F.col("author_id").alias("src"), F.col("id").alias("dst"))
             .dropDuplicates()
            )

display(e_authorship.limit(100))


In [0]:
# Mark the vertex and edge frames for caching, then build the authorship
# graph from all vertices and at most 100 edges.
v_authorship.cache()
e_authorship.cache()

authorship = gf.GraphFrame(v_authorship, e_authorship.limit(100))

display(authorship.triplets)

## Cites

In [0]:
# Explode the `references` array of the first 10000 publications into one row
# per (publication, reference); `rank` is the 1-based position in the array.
references_exp = (filtered_df.limit(10000).select("doi", "references", "title", F.posexplode(F.col("references")).alias("rank", "reference"))
            .withColumn("rank", F.col("rank") + 1)
            .select("rank", "reference", "title")
            )

# One vertex id per distinct reference: "R" + number zero-padded to at least
# 8 digits. format_string never truncates, unlike lpad, which shortened the
# 10+ digit values of monotonically_increasing_id() to colliding 8-char prefixes.
references_id = (references_exp.select("reference").dropDuplicates()
           .withColumn("id", F.format_string("R%08d", F.monotonically_increasing_id()))
          )

references = references_exp.join(references_id, ["reference"])

display(references.limit(100))

In [0]:
# Reference vertices: id and reference value, tagged type "reference".
v_cites = (
    references_id
    .select("id", "reference")
    .withColumn("type", F.lit("reference"))
)

display(v_cites.limit(100))

In [0]:
# Edges between references: self-join the (reference id, title) pairs on the
# title, so two references listed under the same title end up on one row.
df1 = references.select("id", "title")
left_side, right_side = df1.alias("a"), df1.alias("b")
df2 = left_side.join(right_side, ["title"])
# Requiring a.id < b.id removes self-pairs and keeps each pair only once.
df3 = df2.where(F.col("a.id") < F.col("b.id"))
e_cites = df3.select(
    F.col("a.id").alias("src"),
    F.col("b.id").alias("dst"),
    "title",
)

display(e_cites.limit(100))

In [0]:
# Mark the vertex and edge frames for caching, then build the reference graph
# and show its triplets.
v_cites.cache()
e_cites.cache()

cites = gf.GraphFrame(v_cites, e_cites)

display(cites.triplets)

## Specialisations

In [0]:
# Research domains according to
# https://images.webofknowledge.com/images/help/WOK/hs_research_domains.html
# The order matters: the first domain of this list that matches is the one used.
domains_list = [
    'Architecture', 'Art', 'Asian Studies', 'Classics', 'Dance',
    'Film, Radio & Television', 'History', 'History & Philosophy of Science',
    'Literature', 'Music', 'Philosophy', 'Religion', 'Theater', 'Archaeology',
    'Area Studies', 'Biomedical Social Sciences', 'Business & Economics',
    'Communication', 'Criminology & Penology', 'Cultural Studies', 'Demography',
    'Education & Educational Research', 'Ethnic Studies', 'Family Studies',
    'Geography', 'Government & Law', 'International Relations', 'Linguistics',
    'Mathematical Methods In Social Sciences', 'Psychology',
    'Public Administration', 'Social Issues', 'Social Work', 'Sociology',
    'Urban Studies', "Women's Studies", 'Agriculture', 'Allergy',
    'Anatomy & Morphology', 'Anesthesiology', 'Anthropology',
    'Behavioral Sciences', 'Biochemistry & Molecular Biology',
    'Biodiversity & Conservation', 'Biophysics',
    'Biotechnology & Applied Microbiology',
    'Cardiovascular System & Cardiology', 'Cell Biology',
    'Critical Care Medicine', 'Dentistry, Oral Surgery & Medicine',
    'Dermatology', 'Developmental Biology', 'Emergency Medicine',
    'Endocrinology & Metabolism', 'Entomology',
    'Environmental Sciences & Ecology', 'Evolutionary Biology', 'Fisheries',
    'Food Science & Technology', 'Forestry', 'Gastroenterology & Hepatology',
    'General & Internal Medicine', 'Genetics & Heredity',
    'Geriatrics & Gerontology', 'Health Care Sciences & Services',
    'Hematology', 'Immunology', 'Infectious Diseases',
    'Integrative & Complementary Medicine', 'Legal Medicine',
    'Marine & Freshwater Biology', 'Mathematical & Computational Biology',
    'Medical Ethics', 'Medical Informatics', 'Medical Laboratory Technology',
    'Microbiology', 'Mycology', 'Neurosciences & Neurology', 'Nursing',
    'Nutrition & Dietetics', 'Obstetrics & Gynecology', 'Oncology',
    'Ophthalmology', 'Orthopedics', 'Otorhinolaryngology', 'Paleontology',
    'Parasitology', 'Pathology', 'Pediatrics', 'Pharmacology & Pharmacy',
    'Physiology', 'Plant Sciences', 'Psychiatry',
    # This entry was split across two physical lines inside its quotes, which
    # is not a valid string literal; it is one domain name.
    'Public, Environmental & Occupational Health',
    'Radiology, Nuclear Medicine & Medical Imaging', 'Rehabilitation',
    'Reproductive Biology', 'Research & Experimental Medicine',
    'Respiratory System', 'Rheumatology', 'Sport Sciences', 'Substance Abuse',
    'Surgery', 'Toxicology', 'Transplantation', 'Tropical Medicine',
    'Urology & Nephrology', 'Veterinary Sciences', 'Virology', 'Zoology',
    'Astronomy & Astrophysics', 'Chemistry', 'Crystallography',
    'Electrochemistry', 'Geochemistry & Geophysics', 'Geology', 'Mathematics',
    'Meteorology & Atmospheric Sciences', 'Mineralogy',
    'Mining & Mineral Processing', 'Oceanography', 'Optics',
    'Physical Geography', 'Physics', 'Polymer Science', 'Thermodynamics',
    'Water Resources', 'Acoustics', 'Automation & Control Systems',
    'Computer Science', 'Construction & Building Technology', 'Energy & Fuels',
    'Engineering', 'Imaging Science & Photographic Technology',
    'Information Science & Library Science', 'Instruments & Instrumentation',
    'Materials Science', 'Mechanics', 'Metallurgy & Metallurgical Engineering',
    'Microscopy', 'Nuclear Science & Technology',
    'Operations Research & Management Science', 'Remote Sensing', 'Robotics',
    'Spectroscopy', 'Telecommunications', 'Transportation',
]

# Shared habanero Crossref client; getDomain reads it as a global.
cr = Crossref()

def getDomain(doi, fos):
    """Pick a research domain for a publication.

    Args:
        doi: DOI looked up through the global Crossref client ``cr``; when it
            is empty or ``None`` no lookup is made.
        fos: Values already attached to the publication, tested with ``in``
            against every entry of ``domains_list``; may be ``None`` or empty.

    Returns:
        The first entry of ``domains_list`` found in ``fos`` or in the Crossref
        ``subject`` field; otherwise the first Crossref
        ``short-container-title``; otherwise ``None``.
    """
    result = [] if not fos else fos
    subj, pos_dom = [], None
    if doi:
        try:
            query = cr.works(ids = doi)['message']
        except Exception:
            # Best effort: a failed lookup just means no Crossref information.
            query = None
        if isinstance(query, dict):
            # Read both fields independently so that a record without
            # `subject` can still provide the container-title fallback.
            subj = query.get('subject') or []
            short_titles = query.get('short-container-title') or []
            if short_titles:
                pos_dom = short_titles[0]
    for dom in domains_list:
        if dom in result or dom in subj:
            return dom
    return pos_dom if pos_dom else None

# Spark UDF wrapper around getDomain returning a string column.
# F.udf rather than a bare `udf`: the notebook imports the functions module as
# F and never imports the name `udf` itself.
getDomainUDF = F.udf(getDomain, StringType())

In [0]:
# Domain of each of the first 100 filtered publications.
# Cached because the UDF performs a Crossref lookup per row and this frame
# feeds several downstream frames; without it each action repeats the lookups.
domains_df = (filtered_df
              .limit(100)
              .select('title', "doi", 'fos')
              .withColumn("domain", getDomainUDF(F.col("doi"), F.col("fos")))
              .cache()
             )

# One vertex id per distinct domain: "D" + number zero-padded to at least
# 8 digits. format_string never truncates, unlike lpad, which shortened the
# 10+ digit values of monotonically_increasing_id() to colliding 8-char prefixes.
domains_id = (domains_df.select("domain").dropDuplicates()
           .withColumn("id", F.format_string("D%08d", F.monotonically_increasing_id()))
          )

In [0]:
# Attach the domain vertex id to each (title, domain) row.
domains = (
    domains_df
    .join(domains_id, ["domain"])
    .select('id', 'domain', 'title')
)

# Vertices of the specialisation graph: authors followed by domains.
v_specs = (
    authors_id
    .select("id", "name")
    .withColumn("type", F.lit("author"))
    .unionByName(domains_id.withColumn("type", F.lit("domain")), allowMissingColumns=True)
)
display(v_specs.limit(100))

In [0]:
# Specialisation edges: author id -> domain id, linked through the title.
e_specs = (
    domains
    .join(auth, ["title"])
    .select(
        F.col("author_id").alias("src"),
        F.col("id").alias("dst"),
    )
    .dropDuplicates()
)

display(e_specs.limit(100))

In [0]:
# Mark the vertex and edge frames for caching, then build the
# author -> domain graph and show its triplets.
v_specs.cache()
e_specs.cache()

specialisations_graph = gf.GraphFrame(v_specs, e_specs)

display(specialisations_graph.triplets)

## Publishes

Counts how many papers each author published in each year.

In [0]:
# (title, year) of every filtered publication.
years_df = filtered_df.select('title', "year")

# One vertex per distinct year. The id is the year left-padded with "Y" to
# 5 characters, e.g. a 4-character year 2010 becomes "Y2010".
years_id = (
    years_df
    .select("year")
    .dropDuplicates()
    .withColumn("id", F.lpad(F.col('year'), 5, "Y"))
)

In [0]:
# Attach the year vertex id to each (title, year) row.
years = (
    years_df
    .join(years_id, ["year"])
    .select('id', 'year', 'title')
)

# Vertices of the publishes graph: the shared `vertices` frame plus the years.
year_vertices = years_id.withColumn("type", F.lit("year"))
v_publishes = vertices.unionByName(year_vertices, allowMissingColumns=True)

display(v_publishes)

In [0]:
# Publishes edges: author id -> year id, with `count` = number of joined
# (author, title) rows for that author and year.
e_publishes = (
    years
    .join(auth, ["title"])
    .select(
        F.col("author_id").alias("src"),
        F.col("id").alias("dst"),
    )
    .groupBy('src', 'dst')
    .agg(F.count('src').alias('count'))
)

display(e_publishes)

In [0]:
# Mark the vertex and edge frames for caching, then build the
# author -> year graph and show its triplets.
v_publishes.cache()
e_publishes.cache()

publishes_graph = gf.GraphFrame(v_publishes, e_publishes)

display(publishes_graph.triplets)

## Graph

In [0]:
import numpy as np

def PlotGraph(edge_list, labels, parms, colors=None):
    """Draw up to 1000 edges of a Spark edge DataFrame with networkx.

    Args:
        edge_list: Spark DataFrame with ``src`` and ``dst`` columns.
        labels: Dict mapping node id to the text drawn on the node. Entries
            for nodes that are not drawn are ignored; drawn nodes without an
            entry get no label.
        parms: Sequence ``[fig_width, fig_height, dpi, font_size, node_size,
            k_scale]``; ``k_scale / sqrt(number of nodes)`` is passed as ``k``
            to the spring layout.
        colors: Optional dict mapping node id to a node colour; drawn nodes
            without an entry are painted grey.

    Returns:
        None. Nothing is drawn when ``edge_list`` yields no edges.
    """
    Gplot = nx.Graph()
    for row in edge_list.select('src','dst').take(1000):
        Gplot.add_edge(row['src'], row['dst'])

    node_count = Gplot.number_of_nodes()
    if node_count == 0:
        # Nothing to draw; also avoids dividing by sqrt(0) below.
        return

    plt.figure(1, figsize=(parms[0], parms[1]), dpi=parms[2])
    pos = nx.spring_layout(Gplot, k=parms[5]*1/np.sqrt(node_count), iterations=20)
    plt.subplot(111)

    # Only the first 1000 edges are drawn, so `labels` may mention nodes that
    # have no position; passing those to networkx raises KeyError.
    node_labels = {node: labels[node] for node in Gplot.nodes if node in labels}
    draw_kwargs = dict(with_labels=True, font_size=parms[3], node_size=parms[4],
                       font_weight='normal', labels=node_labels, pos=pos)
    if colors is not None:
        draw_kwargs["node_color"] = [colors.get(node, "#999999") for node in Gplot.nodes]
    nx.draw(Gplot, **draw_kwargs)

### Queries

In [0]:
# Co-authorship network around a given author: every directed path of exactly
# three co-author edges that starts at that author.
name = "Peter B. Luh"
co_authors_3_df = (co_authors_graph
                       .find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d)")
                       # Filter on the `name` variable so that changing the author
                       # above is enough (the literal used to be repeated here).
                       .filter(F.col("a.name") == name)
                       .filter("b.type == 'author'").filter("c.type == 'author'").filter("d.type == 'author'")
                      )

# Distinct vertices appearing anywhere on the matched paths.
res_vrtx = (co_authors_3_df
              .select("a")
              .union(co_authors_3_df.select("b"))
              .union(co_authors_3_df.select("c"))
              .union(co_authors_3_df.select("d"))
              .distinct()
              .select("a.*", "*")
              .select("id", "name", "type")
             )

# Distinct edges appearing anywhere on the matched paths.
res_edges = (co_authors_3_df
              .select("e1")
              .union(co_authors_3_df.select("e2"))
              .union(co_authors_3_df.select("e3"))
              .distinct()
              .select("e1.*", "*")
              .select("src", "dst", "title")
             )

# Node labels for the plot: vertex id -> author name.
df = res_vrtx.toPandas()
labels = dict(zip(df["id"], df["name"]))

co_authors_3_graph = gf.GraphFrame(res_vrtx, res_edges)
PlotGraph(co_authors_3_graph.edges, labels, [100, 50, 30, 50, 1500, 0.8])

In [0]:
# All author -> organization edges of the works_for graph.
works_for_query_df = (
    works_for_graph
    .find("(a)-[e1]->(b)")
    #.filter("a.name == 'Peter B. Luh'")
    .filter("a.type == 'author'")
    .filter("b.type == 'organization'")
)

# Distinct vertices on either end of the matched edges.
matched_ends = works_for_query_df.select("a").union(works_for_query_df.select("b"))
res_vrtx2 = (
    matched_ends
    .distinct()
    .select("a.*", "*")
    .select("id", "name", "type")
)

# Distinct matched edges.
res_edges2 = (
    works_for_query_df
    .select("e1")
    .distinct()
    .select("e1.*", "*")
    .select("src", "dst")
)

# Node labels for the plot: vertex id -> name.
df2 = res_vrtx2.toPandas()
labels2 = {row.id: row["name"] for _, row in df2.iterrows()}

# Node colours by vertex type; other types get no colour entry.
type_colors = {"author": "#1f78b4", "organization": "#ff0000"}
res_vrtx2_collect = res_vrtx2.collect()
colors2 = {
    row.id: type_colors[row.type]
    for row in res_vrtx2_collect
    if row.type in type_colors
}

works_for_query = gf.GraphFrame(res_vrtx2, res_edges2)
PlotGraph(works_for_query.edges, labels2, [200, 100, 60, 100, 3000, 0.9], colors2)

In [0]:
# Reference network around a given reference: every directed path of exactly
# three edges of the `cites` graph that starts at that reference.
reference = "53e9bbadb7602d97047f860b"
cites_3_df = (cites
                       .find("(a)-[e1]->(b); (b)-[e2]->(c); (c)-[e3]->(d)")
                       # Filter on the `reference` variable so that changing it
                       # above is enough (the literal used to be repeated here).
                       .filter(F.col("a.reference") == reference)
                       .filter("b.type == 'reference'").filter("c.type == 'reference'").filter("d.type == 'reference'")
                      )

# Distinct vertices appearing anywhere on the matched paths.
res_vrtx = (cites_3_df
              .select("a")
              .union(cites_3_df.select("b"))
              .union(cites_3_df.select("c"))
              .union(cites_3_df.select("d"))
              .distinct()
              .select("a.*", "*")
              .select("id", "reference", "type")
             )

# Distinct edges appearing anywhere on the matched paths.
res_edges = (cites_3_df
              .select("e1")
              .union(cites_3_df.select("e2"))
              .union(cites_3_df.select("e3"))
              .distinct()
              .select("e1.*", "*")
              .select("src", "dst", "title")
             )

# Node labels for the plot: vertex id -> reference value.
df = res_vrtx.toPandas()
labels = dict(zip(df["id"], df["reference"]))

cite_3_graph = gf.GraphFrame(res_vrtx, res_edges)
PlotGraph(cite_3_graph.edges, labels, [200, 100, 60, 100, 3000, 1.5])

### PageRank

In [0]:
# Run PageRank on the `cites` graph with reset probability 0.15 for
# 10 iterations.
cite_page_rank = cites.pageRank(resetProbability=0.15, maxIter=10)

In [0]:
# Show the id and PageRank score of up to 100 vertices.
display(
    cite_page_rank.vertices
    .select("id", "pagerank")
    .limit(100)
)

### Graph analytics

In [0]:
# Strongly connected components of the co-authorship graph (10 iterations),
# listed by component id.
result = co_authors_graph.stronglyConnectedComponents(maxIter=10)
display(
    result
    .select("id", "component")
    .orderBy("component")
)