Creating User-Defined Functions

In [0]:
# Defining a function for splitting datasets

def splitting_data(dataset, delimiter):
    """Split every text line of ``dataset`` on ``delimiter``.

    Returns a new RDD whose elements are the lists produced by ``str.split``
    (empty fields, including trailing ones, are kept as empty strings).
    """
    def split_record(record):
        return record.split(delimiter)

    return dataset.map(split_record)

In [0]:
# Defining a function for removing headers 

def removing_header(rdd):
    """Return ``rdd`` without its header line.

    The header is taken to be the first element of the RDD. Every element equal
    to it is filtered out, not only the first occurrence. An empty RDD is
    returned unchanged instead of raising (``rdd.first()`` raises ValueError on
    an empty RDD).
    """
    # take(1) returns [] for an empty RDD, so there is no header to drop.
    firstRows = rdd.take(1)
    if not firstRows:
        return rdd
    header = firstRows[0]
    return rdd.filter(lambda row: row != header)

### Loading Clinical data

In [0]:
# Loading the clinical trial CSV as an RDD of text lines, dropping the header line
# and splitting every remaining line on the "|" delimiter.
clinicaltrialRDD = splitting_data(
    removing_header(sc.textFile("/FileStore/tables/clinicaltrial_2021.csv")),
    "|",
)

In [0]:
from pyspark.sql.types import *

# Schema for the clinical trial data: nine nullable string columns, listed in the
# positional order of the "|"-separated fields of each line.
mySchema = StructType([
    StructField(columnName, StringType(), True)
    for columnName in ("Id", "Sponsor", "Status", "Start", "Completion",
                       "Type", "Submission", "Conditions", "Interventions")
])

In [0]:
# Building the clinical trial DataFrame: the fields of each split line are mapped
# positionally onto the columns of mySchema.
clinicaltrialDF = spark.createDataFrame(data=clinicaltrialRDD, schema=mySchema)

In [0]:
# Printing the schema of the clinical trial DataFrame (the columns defined in mySchema)
clinicaltrialDF.printSchema()

In [0]:
# Displaying the first 10 rows of the clinical trial data.
# The row cap is applied with limit(10); display() has no documented row-count argument,
# so display(10) would not restrict the output to 10 rows.
clinicaltrialDF.limit(10).display()

### Loading Pharma data

In [0]:
# Loading the pharma dataset: the first line supplies the column names and Spark
# infers each column's type from the data.
pharmaDF = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv("/FileStore/tables/pharma.csv")
)

In [0]:
# Printing the schema of the Pharma DataFrame (column types come from inferSchema=True)
pharmaDF.printSchema()

In [0]:
# Displaying the first 10 rows of the Pharma data.
# The row cap is applied with limit(10); display() has no documented row-count argument,
# so display(10) would not restrict the output to 10 rows.
pharmaDF.limit(10).display()

## Question 1

In [0]:
# Counting the distinct study Ids.
# Note: count() is an action, so studiesCountDF holds a plain Python int, not a DataFrame.
studiesCountDF = (
    clinicaltrialDF
    .select("Id")
    .distinct()
    .count()
)

In [0]:
# Reporting the number of distinct studies
print(f"The number of distinct studies conducted were: {studiesCountDF}")

## Question 2

In [0]:
from pyspark.sql.functions import count

# Grouping the DataFrame by the 'Type' column and counting the trials of each type.
# count("*") counts rows, so a group whose Type is null is counted as well
# (count("Type") would report 0 for it); this matches the count("*") aggregations
# used for the conditions and Type/Status counts.
studytypesDF = clinicaltrialDF\
                    .groupBy("Type")\
                    .agg(count("*").alias("Frequency"))

In [0]:
# Most common study types first
studytypesDF = studytypesDF.sort("Frequency", ascending=False)

In [0]:
# Displaying every study type with its trial count, most frequent first
studytypesDF.display()

##### Visualization

In [0]:
import matplotlib.pyplot as plt
import pandas as pd

# Converting the Spark DataFrame (sorted by Frequency, descending) to a pandas DataFrame
studytypesPandasDF = studytypesDF.toPandas()

# Drawing on an explicit Axes so the requested figsize is used: DataFrame.plot without
# ax= opens a new figure and would leave the sized one empty.
fig, ax = plt.subplots(figsize=(10, 5))
studytypesPandasDF.plot(kind='barh', x='Type', y='Frequency', ax=ax, legend=False)

# In a horizontal bar chart the categories run along the y-axis and the counts along the x-axis
ax.set_title("Number of Clinical Trials by Type in 2021")
ax.set_xlabel("Frequency")
ax.set_ylabel("Type")

# barh draws the first row at the bottom; inverting puts the most frequent type on top
ax.invert_yaxis()

# Displaying the bar chart
plt.show()

## Question 3

In [0]:
# Importing necessary functions
from pyspark.sql.functions import split, explode, count

# One row per (trial, condition): the comma-separated Conditions string is split
# into an array and each element is exploded into its own row.
conditionsDF = clinicaltrialDF.withColumn(
    "condition",
    explode(split(clinicaltrialDF["Conditions"], ",")),
)

In [0]:
# Dropping rows whose condition is the empty string (null conditions are dropped too,
# since a comparison with null is not true)
conditionsDF = conditionsDF.filter(conditionsDF["condition"] != "")

In [0]:
# Number of rows per distinct condition, exposed as the "frequency" column
topConditionsDF = (
    conditionsDF
    .groupBy("condition")
    .count()
    .withColumnRenamed("count", "frequency")
)

In [0]:
# Keeping only the five most frequent conditions
topConditionsDF = topConditionsDF.sort(topConditionsDF["frequency"].desc()).limit(5)

In [0]:
# Displaying the five most frequent conditions with their frequencies
topConditionsDF.display()

##### Visualization

In [0]:
import seaborn as sns
import matplotlib.pyplot as plt

# The top-5 result is tiny, so it is brought to the driver as pandas for plotting
topConditionsPandasDF = topConditionsDF.toPandas()

# Horizontal seaborn bar chart drawn on an explicit Axes
sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(10, 5))
sns.barplot(x="frequency", y="condition", data=topConditionsPandasDF, color="b", ax=ax)

ax.set_title("Top 5 Conditions in Clinical Trials in 2021")
ax.set_xlabel("Frequency")
ax.set_ylabel("Condition")

plt.show()


## Question 4

In [0]:
from pyspark.sql.functions import count

# Anti join (not a left join): keeps only the clinical trials whose Sponsor does not
# appear as a Parent_Company in the pharma data; no pharma columns are added.
nonPharmaSponsors = clinicaltrialDF.join(
    pharmaDF,
    on=clinicaltrialDF["Sponsor"] == pharmaDF["Parent_Company"],
    how="left_anti",
)

In [0]:
# Counting trials per non-pharma sponsor and keeping the 10 most frequent.
# take(10) brings back a Python list of Row(Sponsor, sponsored_trials), not a DataFrame.
topSponsors = (
    nonPharmaSponsors
    .groupBy("Sponsor")
    .agg(count("Sponsor").alias("sponsored_trials"))
    .sort("sponsored_trials", ascending=False)
    .take(10)
)

In [0]:
# Printing each sponsor name followed by its number of trials
print("These are the top 10 sponsors that are not pharmaceutical companies:")
for sponsorRow in topSponsors:
    print(sponsorRow[0], sponsorRow[1])

##### Visualization

In [0]:
import pandas as pd
import matplotlib.pyplot as plt

# Converting the list of Rows returned by take(10) to a pandas DataFrame
topSponsorsDF = pd.DataFrame(topSponsors, columns=["Sponsor", "sponsored_trials"])

# Sponsors are unordered categories, so a bar chart is used (a line chart would imply
# a trend between neighbouring sponsors). The wider figure and tight_layout keep long
# rotated sponsor names from being clipped.
fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(topSponsorsDF["Sponsor"], topSponsorsDF["sponsored_trials"])
ax.tick_params(axis="x", labelrotation=90)
ax.set_title("Top 10 Non-Pharmaceutical Sponsors")
ax.set_xlabel("Sponsors")
ax.set_ylabel("Number of sponsored trials")
fig.tight_layout()
plt.show()


## Question 5

In [0]:
# Importing necessary functions
from pyspark.sql.functions import to_date, year, month, date_format

# Keeping trials whose Status is "Completed" and whose Completion date
# (parsed with the "MMM yyyy" pattern) falls in 2021.
isCompleted = clinicaltrialDF.Status == "Completed"
completedIn2021 = year(to_date(clinicaltrialDF.Completion, "MMM yyyy")) == 2021
completedTrialsDF = clinicaltrialDF.filter(isCompleted & completedIn2021)

In [0]:
# Grouping by month of completion and counting the number of trials for each month.
# Rows are ordered by the numeric month. Applying month() to the "month" column does not
# work because that column holds an abbreviation such as "Jan", which is not a parseable
# date. The output keeps the original two columns: month (abbreviation) and count.
completionDate = to_date(completedTrialsDF.Completion, "MMM yyyy")
completedCountsDF = completedTrialsDF \
                    .groupBy(month(completionDate).alias("month_num"),
                             date_format(completionDate, "MMM").alias("month")) \
                    .count() \
                    .orderBy("month_num") \
                    .select("month", "count")

In [0]:
# Mapping each three-letter month abbreviation to its month number (Jan=1 ... Dec=12)
monthDictionary = {
    abbreviation: number
    for number, abbreviation in enumerate(
        ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
         "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"],
        start=1,
    )
}

In [0]:
# Sorting the results using the created month dictionary: the monthly counts are
# collected to the driver and ordered Jan to Dec by month abbreviation.
results = completedCountsDF.collect()
results.sort(key=lambda row: monthDictionary[row[0]])

In [0]:
# Printing the results, one "<month> <count>" line per month.
# The loop variables are deliberately not named `month`/`count`: those names are the
# pyspark.sql.functions imported earlier, and rebinding them to a string/int here
# breaks any later re-run of the cells that call month(...) or count(...).
for (monthName, trialCount) in results:
    print("{:<3} {}".format(monthName, trialCount))

##### Visualization

In [0]:
import matplotlib.pyplot as plt

# Splitting the (month, count) rows into parallel lists for plotting
months = [row[0] for row in results]
counts = [row[1] for row in results]

# Bar chart of completed trials per month
fig, ax = plt.subplots()
ax.bar(months, counts)
ax.set_xlabel("Month")
ax.set_ylabel("Number of Completed Trials")
ax.set_title("Completed Clinical Trials in 2021")

plt.show()


#### Further Analysis 2 (DataFrame)

Investigate the distribution of trial statuses across different types of studies. 

What are the top 10 Type-Status combinations and their respective counts in the clinical trial dataset?

In [0]:
from pyspark.sql.functions import count

# Number of trials for every (Type, Status) pair
statusDF = clinicaltrialDF.groupBy("Type", "Status").agg(count("*").alias("Count"))

# Most frequent combinations first
statusCountDF = statusDF.sort(statusDF["Count"].desc())

# Printing the 10 most frequent Type/Status combinations
statusCountDF.show(n=10)

##### Visualization

In [0]:
import pandas as pd
import matplotlib.pyplot as plt

# Bringing the sorted Type/Status counts to pandas and keeping the first 10 rows
statusCountPandasDF = statusCountDF.toPandas()
top10 = statusCountPandasDF.iloc[:10]

# One row per Type and one column per Status, ready for a grouped bar chart
groupedData = pd.pivot(top10, index="Type", columns="Status", values="Count")

# Grouped bar chart: one cluster of bars per Type, one bar per Status
ax = groupedData.plot.bar()
ax.set_title("Top 10 Type-Status Combinations")
ax.set_xlabel("Type")
ax.set_ylabel("Count")
plt.show()


#### Extra Feature

Column Count Function

In [0]:
from pyspark.sql.functions import count

def columnCount(df, col_name):
    """Count the rows of ``df`` for each distinct value of ``col_name``.

    Returns a DataFrame with the columns ``col_name`` and ``count``, ordered by
    ``count`` from highest to lowest.
    """
    return (
        df.groupBy(col_name)
        .agg(count("*").alias("count"))
        .orderBy("count", ascending=False)
    )


#### Question 1
What is the relationship between the origin of a parent company (HQ_Country_of_Parent) and the number of violations?

In [0]:
# Counting pharma rows per parent-company HQ country with the columnCount helper
# (assumes one row per recorded violation — confirm against the dataset description)
violationsDF = columnCount(df=pharmaDF, col_name="HQ_Country_of_Parent")

# Displaying the results
violationsDF.show()

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('darkgrid')

# Bar chart of the per-country counts. matplotlib is imported here so the cell does not
# depend on an earlier cell, and country names are rotated so long labels do not overlap.
fig, ax = plt.subplots(figsize=(14, 8))
sns.barplot(data=violationsDF.toPandas(), x='HQ_Country_of_Parent', y='count', ax=ax)
ax.set_title('Number of Violations by Country of Parent Company')
ax.set_xlabel('Country of Parent Company')
ax.set_ylabel('Number of Violations')
ax.tick_params(axis='x', labelrotation=90)
fig.tight_layout()
plt.show()


#### Question 2
Investigate how violations are distributed across the years.
Which year experienced the biggest spike in violations?

In [0]:
# Counting pharma rows per Penalty_Year with the columnCount helper (most frequent year first)
yearly_violations = columnCount(df=pharmaDF, col_name="Penalty_Year")
yearly_violations.show()

In [0]:
# Converting the Spark result to pandas for plotting
yearly_violationspd = yearly_violations.toPandas()

# Scatter plot of violations per year with seaborn's fitted linear-regression trend line
fig, ax = plt.subplots(figsize=(12,6))
sns.regplot(data=yearly_violationspd, x='Penalty_Year', y='count', ax=ax)
ax.set_title("Trend of violations made per year")
ax.set_xlabel("Year")
ax.set_ylabel("Number of Violations")
plt.show()
