In [None]:
import functools
import os

from pyspark import SparkContext, SQLContext
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import StringType, FloatType, ArrayType

In [None]:
class py_or_udf:
    """Decorator factory: one function usable as plain Python or as a Spark UDF.

    The decorated function is called directly when none of its positional or
    keyword arguments is a Spark ``Column``. If at least one argument is a
    ``Column``, the function is wrapped with ``F.udf`` (using the configured
    return type) and applied to the given arguments, so the call yields a
    ``Column`` expression instead of a Python value.

    Parameters
    ----------
    returnType : pyspark.sql.types.DataType, optional
        Return type declared for the generated UDF. Defaults to ``StringType()``.
    """

    def __init__(self, returnType=StringType()):
        self.spark_udf_type = returnType

    def __call__(self, func):
        """Return a wrapper around ``func`` that dispatches on argument types."""

        # functools.wraps keeps func's __name__ / __doc__ on the wrapper so the
        # decorated helpers stay introspectable (help(), tracebacks, ...).
        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            has_column = any(
                isinstance(value, F.Column)
                for value in (*args, *kwargs.values())
            )
            if has_column:
                # Arguments are forwarded to the UDF call unchanged.
                # NOTE(review): presumably Spark rejects plain Python values
                # mixed with Columns here -- confirm before relying on it.
                return F.udf(func, self.spark_udf_type)(*args, **kwargs)
            return func(*args, **kwargs)

        return wrapped_func

def format_name(author):
    """Build a display name "first middle... last" for one author record.

    Parameters
    ----------
    author : mapping-like
        Object supporting ``author[key]`` for the keys ``'first'``,
        ``'middle'`` (an iterable of middle names, or None) and ``'last'``.

    Returns
    -------
    str
        The non-empty name parts joined by single spaces. Empty or None
        parts are skipped, so a missing first or last name no longer leaves
        a leading, trailing or doubled space in the result.
    """
    middle_names = author['middle'] or []
    parts = [author['first'], *middle_names, author['last']]
    return " ".join(part for part in parts if part)

def format_affiliation(affiliation):
    """Render an affiliation record as "institution, location parts...".

    Parameters
    ----------
    affiliation : dict, object with ``asDict()``, or None
        Record that may carry an ``'institution'`` value and a ``'location'``
        mapping. Objects exposing ``asDict`` (e.g. a pyspark ``Row``, which
        has no dict-style ``.get``) are converted to nested dicts first.

    Returns
    -------
    str
        The institution followed by the location values, joined by ", ".
        Empty or None parts are skipped; an empty or missing affiliation
        gives an empty string.
    """
    if not affiliation:
        return ""

    # Normalise Row-like inputs so the dict accessors below work for them too.
    if hasattr(affiliation, 'asDict'):
        affiliation = affiliation.asDict(recursive=True)

    parts = []

    institution = affiliation.get('institution')
    if institution:
        parts.append(institution)

    location = affiliation.get('location')
    if location:
        # Skip unset location fields so join() never sees None.
        parts.extend(value for value in location.values() if value)

    return ", ".join(parts)

@py_or_udf()
def format_authors(authors, with_affiliation=False):
    """Join the formatted names of all authors into one comma-separated string.

    Parameters
    ----------
    authors : iterable of author records, or None
        Each record is passed to ``format_name``; with ``with_affiliation``
        its ``'affiliation'`` entry is passed to ``format_affiliation``.
        None is treated like an empty list.
    with_affiliation : bool, optional
        When true, a non-empty affiliation is appended to the author's name
        in parentheses. Defaults to False.

    Returns
    -------
    str
        The author entries joined by ", "; empty string when there are none.
    """
    name_ls = []

    # `or []` keeps a missing author list from raising on iteration.
    for author in authors or []:
        name = format_name(author)
        affiliation = format_affiliation(author['affiliation']) if with_affiliation else ""
        name_ls.append(f"{name} ({affiliation})" if affiliation else name)

    return ", ".join(name_ls)

@py_or_udf()
def format_body(body_text):
    """Merge paragraph records into one text block grouped by section.

    Parameters
    ----------
    body_text : iterable of paragraph records, or None
        Each record must support ``record['section']`` and ``record['text']``.
        None is treated like an empty list.

    Returns
    -------
    str
        For every section, in order of first appearance: the section title,
        a blank line, the section's concatenated text and another blank line.
        Empty string when there are no paragraphs.

    NOTE(review): paragraphs belonging to the same section are concatenated
    with no separator, so adjacent paragraphs run together -- confirm this is
    intended.
    """
    # dicts preserve insertion order, so sections stay in first-seen order.
    paragraphs_by_section = {}
    for paragraph in body_text or []:
        paragraphs_by_section.setdefault(paragraph['section'], []).append(paragraph['text'])

    # Build the result with join() instead of repeated string +=.
    return "".join(
        section + "\n\n" + "".join(paragraphs) + "\n\n"
        for section, paragraphs in paragraphs_by_section.items()
    )

@py_or_udf()
def format_bib(bibs):
    """Format bibliography entries as "title, authors, venue, year; ...".

    Parameters
    ----------
    bibs : dict, iterable of entries, or None
        For a dict, its values are used as the entries; any other iterable is
        iterated directly. Each entry must support subscripting with the keys
        ``'title'``, ``'authors'``, ``'venue'`` and ``'year'``.

    Returns
    -------
    str
        One ", "-joined string per entry, with entries joined by "; ".
        Fields that are None or empty are left out, so a missing venue or
        year no longer shows up as the literal text "None". Empty string
        when ``bibs`` is None or empty.
    """
    if not bibs:
        return ""
    if isinstance(bibs, dict):
        bibs = list(bibs.values())

    formatted = []
    for bib in bibs:
        # Entries themselves may be None (presumably when the input is a Spark
        # struct whose field is unset for this document -- confirm).
        if bib is None:
            continue
        fields = [
            bib['title'],
            format_authors(bib['authors'], False),
            bib['venue'],
            bib['year'],
        ]
        formatted.append(", ".join(str(field) for field in fields if field not in (None, "")))

    return "; ".join(formatted)

In [None]:
# Create (or reuse) a Spark session with master 'local[*]'.
spark = (
    SparkSession.builder
    .master('local[*]')
    .getOrCreate()
)

# Handle to the session's underlying SparkContext.
sc = spark.sparkContext

In [None]:
# Directory holding the input JSON documents.
# NOTE(review): hard-coded absolute path -- adjust for your environment.
files_path_bio = "/data/biorxiv_medrxiv/biorxiv_medrxiv/"
# Names of the entries in that directory (not used by the cells shown below).
list_files = os.listdir(files_path_bio)

In [None]:
# Load the JSON files under files_path_bio into a DataFrame; the "multiline"
# option is set so a single JSON record may span several lines.
input_files = spark.read.option("multiline",True).json(files_path_bio)

In [None]:
# Print the schema of the loaded DataFrame.
input_files.printSchema()

In [None]:
# Number of rows loaded (presumably one per JSON document -- confirm).
input_files.count()

In [None]:
# describe() only returns a summary DataFrame; show() is needed to actually
# print the statistics instead of just the DataFrame's repr.
input_files.describe().show()

In [None]:
# Top-level column names of the raw DataFrame.
input_files.columns

In [None]:
# Flatten the nested document structure into plain string columns:
#   - body_text / abstract: paragraph records merged by format_body
#   - title / authors: pulled out of the metadata struct
# Formatting of bib_entries via format_bib is not applied; that column is
# dropped together with the other nested columns.
input_files_processed = (
    input_files
    .withColumn('body_text', format_body(F.col('body_text')))
    .withColumn('abstract', format_body(F.col('abstract')))
    .withColumn('title', F.col('metadata.title'))
    .withColumn('authors', format_authors(F.col('metadata.authors')))
    .drop('back_matter', 'metadata', 'ref_entries', 'bib_entries')
)

In [None]:
# Column names remaining after the transformation.
input_files_processed.columns

In [None]:
# Print the schema of the processed DataFrame.
input_files_processed.printSchema()

In [None]:
# Print the first rows of the processed DataFrame (show() with its defaults).
input_files_processed.show()