#Defining reusable code for Clinical Trial datasets
--------------------------------------------------
#Instructions: Upload the clinical trial and pharma zip file. 
#Change the year in the fileroot (in cmd 2) to the uploaded clinicaltrial file name year

In [0]:
#Defining reusable code for Clinical Trial datasets
fileroot = "clinicaltrial_2023"

In [0]:
#Extracting the year of clinical trial
fileName = fileroot.split("_")
trial_Year = fileName[1]
trial_Year

In [0]:
#Copying the file to /tmp directory on driver node
dbutils.fs.cp("/FileStore/tables/" + fileroot + ".zip", "file:/tmp/")

In [0]:
#Checking the fileroot.zip in DBFS directory
dbutils.fs.ls("/FileStore/tables/")

In [0]:
#Checking the fileroot.zip in local tmp directory
dbutils.fs.ls("file:/tmp/")

In [0]:
#Making 'fileroot' accessible by the command line
import os
os.environ['fileroot'] = fileroot

In [0]:
%sh
# Use the '-o' option to overwrite existing files without prompting
unzip -o -d /tmp /tmp/$fileroot.zip

In [0]:
#Moving the unzipped file to the DBFS directory
dbutils.fs.mv("file:/tmp/" + fileroot +".csv" , "/FileStore/tables/"+ fileroot +".csv", True )

In [0]:
#Checking the unzipped file in DBFS directory
dbutils.fs.ls("/FileStore/tables/")

In [0]:
#Checking the DBFS directory
dbutils.fs.ls ("/FileStore/tables/" )

In [0]:
file_path = "/FileStore/tables/"+ fileroot +".csv"
print(dbutils.fs.head(file_path))

In [0]:
dbutils.fs.head("/FileStore/tables/"+fileroot+".csv")

In [0]:
from pyspark import SparkContext
from pyspark.sql import SparkSession

# Obtain the Spark session (an already-running one is reused by getOrCreate)
spark = (
    SparkSession.builder
    .appName("Create RDD from CSV")
    .getOrCreate()
)

# The SparkContext is what exposes the low-level RDD API
sc = spark.sparkContext

# DBFS location of the unzipped clinical trial CSV
file_path = "/FileStore/tables/" + fileroot + ".csv"

# Show the beginning of the raw file so the delimiter and header can be inspected
file_preview = dbutils.fs.head(file_path)
print("Preview of the CSV file:")
print(file_preview)

# Load the file as an RDD with one element per text line
clinicaltrial_rdd = sc.textFile(file_path)

# Print a small sample of the RDD
print("First few lines of the RDD:")
for line in clinicaltrial_rdd.take(5):
    print(line)

# You can continue with further processing of the RDD


In [0]:
from pyspark import SparkContext
from pyspark.sql.functions import trim, regexp_replace
import re

# Assuming sc is your SparkContext and the RDD is already created
# sc = SparkContext.getOrCreate()

# Function to clean the first field in a line
def clean_first_field(line):
    """
    Tidy the first tab-separated field of a line.

    One leading and one trailing double quote are removed from the first
    field, then surrounding whitespace is trimmed. Everything after the first
    tab is passed through untouched.

    Args:
        line (str): A tab-delimited record.

    Returns:
        str: The record with its first field cleaned.
    """
    # partition() isolates the first field without touching the rest of the record
    head, tab, remainder = line.partition('\t')
    head = re.sub(r'^"|"$', '', head).strip()
    return head + tab + remainder

# Function to clean the 14th field in a line
def clean_14th_field(line):
    """
    Tidy the 14th tab-separated field (index 13) of a line.

    Any run of trailing double quotes and commas is removed from that field
    and the result is whitespace-trimmed. Records with fewer than 14 fields
    are returned exactly as received.

    Args:
        line (str): A tab-delimited record.

    Returns:
        str: The record with its 14th field cleaned, or the original record
        when it is too short.
    """
    parts = line.split('\t')

    if len(parts) >= 14:
        # Index 13 is the 14th column because indexing is zero-based
        parts[13] = re.sub(r'("|,)+$', '', parts[13]).strip()
        line = '\t'.join(parts)

    return line

# Perform operations based on the value of fileroot
def process_clinicaltrial_rdd(clinicaltrial_rdd, fileroot):
    """
    Clean the raw clinical trial RDD, drop its header and preview the result.

    For "clinicaltrial_2023" the first and 14th tab-separated fields are
    cleaned with clean_first_field / clean_14th_field and the first line is
    dropped. For "clinicaltrial_2020" / "clinicaltrial_2021" every line equal
    to the first line (the header) is dropped.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the processing branch.

    Returns:
        RDD: The header-less (and, for 2023, cleaned) RDD. Previously the
        result was discarded, so the cleaning could not be reused.

    Raises:
        ValueError: If fileroot is not a supported dataset name.
    """
    if fileroot == "clinicaltrial_2023":
        # Clean the first and the 14th field of every line
        cleaned_rdd = clinicaltrial_rdd.map(clean_first_field).map(clean_14th_field)

        # Discard the header (the line with index 0)
        no_header_rdd = (
            cleaned_rdd.zipWithIndex()
            .filter(lambda line_index: line_index[1] != 0)
            .map(lambda line_index: line_index[0])
        )

    # NOTE: `x == "a" or "b"` is always truthy; a membership test is required
    elif fileroot in ("clinicaltrial_2020", "clinicaltrial_2021"):
        # Remove every line identical to the header row
        header = clinicaltrial_rdd.first()
        no_header_rdd = clinicaltrial_rdd.filter(lambda line: line != header)

    else:
        raise ValueError(f"Unsupported fileroot: {fileroot!r}")

    # Display the first few lines without header (for demonstration purposes)
    for line in no_header_rdd.take(5):
        print(line)

    return no_header_rdd


process_clinicaltrial_rdd(clinicaltrial_rdd, fileroot)


#Question 1.The number of studies in the dataset. You must ensure that you explicitly check distinct studies.

In [0]:
from pyspark import SparkContext
from pyspark.sql import SQLContext

# Assuming sc is your SparkContext and the RDDs are already created
# sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

def process_rdd(clinicaltrial_rdd, fileroot):
    """
    Print the number of distinct studies in the dataset.

    The first line of the RDD is treated as the header and dropped. For
    "clinicaltrial_2023" (tab-delimited) distinct values of the first and
    second columns ('Id' and 'Study Title') are counted; for
    "clinicaltrial_2020" / "clinicaltrial_2021" (pipe-delimited) distinct
    values of the first column are counted.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the parsing rules.

    Returns:
        None

    Raises:
        ValueError: If fileroot is not a supported dataset name.
    """
    is_2023 = fileroot == "clinicaltrial_2023"
    # NOTE: `x == "a" or "b"` is always truthy; a membership test is required
    if not is_2023 and fileroot not in ("clinicaltrial_2020", "clinicaltrial_2021"):
        raise ValueError(f"Unsupported fileroot: {fileroot!r}")

    delimiter = '\t' if is_2023 else "|"

    # Discard the header line (index 0) from the RDD
    rows_rdd = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] != 0)
        .map(lambda line_index: line_index[0])
    )

    # Count the distinct studies based on the first column
    distinct_studies_id = rows_rdd.map(lambda line: line.split(delimiter)[0]).distinct().count()

    if is_2023:
        # Count distinct titles; lines without a second column are skipped
        # instead of failing the whole job with an IndexError
        distinct_studies_title = (
            rows_rdd.map(lambda line: line.split(delimiter))
            .filter(lambda fields: len(fields) > 1)
            .map(lambda fields: fields[1])
            .distinct()
            .count()
        )

        print(f"Number of distinct studies based on 'Id': {distinct_studies_id}")
        print(f"Number of distinct studies based on 'Study Title': {distinct_studies_title}")
    else:
        print(f"The number of distinct studies in {fileroot} dataset is {distinct_studies_id}.")

process_rdd(clinicaltrial_rdd, fileroot)


In [0]:
from pyspark import SparkContext

def process_rdd(clinicaltrial_rdd, fileroot, trial_year=None, limit=20):
    """
    Print a sample of the distinct study identifiers (and titles for 2023).

    The first line of the RDD is treated as the header and dropped, so the
    column names are no longer listed as if they were studies.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the parsing rules.
        trial_year (int or str, optional): Year shown in the 2020/2021 message.
            Defaults to the part of fileroot after the last underscore.
        limit (int, optional): Number of distinct values to list. Defaults to 20.

    Returns:
        None

    Raises:
        ValueError: If fileroot is not a supported dataset name.
    """
    is_2023 = fileroot == "clinicaltrial_2023"
    # NOTE: `x == "a" or "b"` is always truthy; a membership test is required
    if not is_2023 and fileroot not in ("clinicaltrial_2020", "clinicaltrial_2021"):
        raise ValueError(f"Unsupported fileroot: {fileroot!r}")

    # Discard the header line (index 0) from the RDD
    rows_rdd = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] != 0)
        .map(lambda line_index: line_index[0])
    )

    if is_2023:
        # Distinct values of the first column ('Id')
        distinct_ids = rows_rdd.map(lambda line: line.split('\t')[0]).distinct().take(limit)

        print(f"First {limit} distinct Ids:")
        for study_id in distinct_ids:
            print(study_id)

        # Distinct values of the second column ('Study Title'); short lines are skipped
        distinct_titles = (
            rows_rdd.map(lambda line: line.split('\t'))
            .filter(lambda fields: len(fields) > 1)
            .map(lambda fields: fields[1])
            .distinct()
            .take(limit)
        )

        print(f"First {limit} distinct study titles:")
        for title in distinct_titles:
            print(title)

    else:
        if trial_year is None:
            # Fall back to the year embedded in the dataset name
            trial_year = fileroot.split("_")[-1]

        distinct_studies = rows_rdd.map(lambda line: line.split("|")[0]).distinct().take(limit)

        # The message reports the number actually requested (it used to claim 100)
        print(f"The ID of first {limit} distinct studies in the clinical trial year {trial_year}:")
        for study in distinct_studies:
            print(study)


# For clinicaltrial_2023
process_rdd(clinicaltrial_rdd, fileroot)

In [0]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
import matplotlib.pyplot as plt

# Assuming sc is your SparkContext and the RDDs are already created
# sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

def process_rdd(clinicaltrial_rdd, fileroot):
    """
    Print and plot the number of distinct studies in the dataset.

    The first line of the RDD is treated as the header and dropped. For
    "clinicaltrial_2023" the distinct counts of the first two tab-separated
    columns are shown as an annotated bar chart; for "clinicaltrial_2020" /
    "clinicaltrial_2021" the distinct count of the first pipe-separated
    column is shown as a single bar.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the parsing rules.

    Returns:
        None

    Raises:
        ValueError: If fileroot is not a supported dataset name.
    """
    is_2023 = fileroot == "clinicaltrial_2023"
    # NOTE: `x == "a" or "b"` is always truthy; a membership test is required
    if not is_2023 and fileroot not in ("clinicaltrial_2020", "clinicaltrial_2021"):
        raise ValueError(f"Unsupported fileroot: {fileroot!r}")

    # Discard the header line (index 0) from the RDD
    rows_rdd = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] != 0)
        .map(lambda line_index: line_index[0])
    )

    if is_2023:
        # Count the distinct number of studies based on the 'Id' column
        distinct_studies_id = rows_rdd.map(lambda line: line.split('\t')[0]).distinct().count()

        # Count the distinct number of studies based on the 'Study Title' column
        distinct_studies_title = rows_rdd.map(lambda line: line.split('\t')[1]).distinct().count()

        print(f"Number of distinct studies based on 'Id': {distinct_studies_id}")
        print(f"Number of distinct studies based on 'Study Title': {distinct_studies_title}")

        categories = ('Distinct Studies (Id)', 'Distinct Studies (Study Title)')
        counts = (distinct_studies_id, distinct_studies_title)

        plt.figure(figsize=(10, 6))
        bars = plt.bar(categories, counts, color=['blue', 'green'])

        plt.xlabel('Category')
        plt.ylabel('Count')
        plt.title('Distinct Number of Studies Based on "Id" and "Study Title"')

        # Write each bar's value just above it (static labels, not hover tooltips)
        for bar in bars:
            yval = bar.get_height()
            plt.text(bar.get_x() + bar.get_width() / 2, yval, round(yval),
                     ha='center', va='bottom')

        plt.show()

    else:
        # Count the distinct studies by extracting the first element from each line
        count = rows_rdd.map(lambda line: line.split("|")[0]).distinct().count()

        print(f"The number of distinct studies in {fileroot} dataset is {count}.")

        plt.figure(figsize=(10, 6))
        # x must be a sequence with one label per bar, not a bare string
        plt.bar([fileroot], [count], color='blue')
        plt.xlabel('Dataset')
        plt.ylabel('Count')
        plt.title(f'Distinct Studies in {fileroot} Dataset')
        plt.show()


# Example usage of the process_rdd function
process_rdd(clinicaltrial_rdd, fileroot)

#Question 2. You should list all the types (as contained in the Type column) of studies in the dataset along with the frequencies of each type. These should be ordered from most frequent to least frequent.

In [0]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
import matplotlib.pyplot as plt
import plotly.graph_objects as go

# Initialize SparkContext and SQLContext
# sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

def process_rdd(clinicaltrial_rdd, fileroot):
    """
    List every study type with its frequency, most frequent first, and plot it.

    The first line of the RDD is treated as the header and dropped. The type
    is read from column index 10 of the tab-delimited "clinicaltrial_2023"
    file, or column index 5 of the pipe-delimited "clinicaltrial_2020" /
    "clinicaltrial_2021" files. Lines too short to contain that column are
    ignored.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the parsing rules.

    Returns:
        None

    Raises:
        ValueError: If fileroot is not a supported dataset name.
    """
    if fileroot == "clinicaltrial_2023":
        delimiter, type_index = '\t', 10
    # NOTE: `x == "a" or "b"` is always truthy; a membership test is required
    elif fileroot in ("clinicaltrial_2020", "clinicaltrial_2021"):
        delimiter, type_index = "|", 5
    else:
        raise ValueError(f"Unsupported fileroot: {fileroot!r}")

    # Discard the header line (index 0) from the RDD
    rows_rdd = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] > 0)
        .map(lambda line_index: line_index[0])
    )

    # Extract the type column, skipping lines that do not have enough fields
    type_rdd = (
        rows_rdd.map(lambda line: line.split(delimiter))
        .filter(lambda fields: len(fields) > type_index)
        .map(lambda fields: fields[type_index])
    )

    # Count per type and order from most to least frequent. All types are
    # collected: the question asks for every type, not only the top 10.
    results = (
        type_rdd.map(lambda type_: (type_, 1))
        .reduceByKey(lambda a, b: a + b)
        .sortBy(lambda x: -x[1])
        .collect()
    )

    type_names = [type_ for type_, _ in results]
    type_counts = [count for _, count in results]

    if fileroot == "clinicaltrial_2023":
        print("Type frequencies ordered from most frequent to least frequent:")
        for type_, count in results:
            print(f"Type: {type_}, Count: {count}")

        plt.figure(figsize=(12, 8))
        plt.bar(type_names, type_counts, color='steelblue')
        plt.xlabel('Type')
        plt.ylabel('Frequency')
        plt.title('Type Frequencies Ordered from Most Frequent to Least Frequent')
        plt.xticks(rotation=45, ha='right')  # Rotate x-axis labels for readability
        plt.show()

    else:
        print("All types with their frequencies:")
        for type_, count in results:
            print(type_, count)

        # Horizontal bar chart with a custom hover tooltip
        fig = go.Figure(go.Bar(
            x=type_counts,
            y=type_names,
            orientation='h',
            marker=dict(color='skyblue')
        ))

        fig.update_traces(hovertemplate='<b>Type</b>: %{y}<br><b>Frequency</b>: %{x}')

        fig.update_layout(
            title='Frequency of Each Type',
            xaxis_title='Frequency',
            yaxis_title='Types',
            height=600,
            width=800
        )

        fig.show()

# Process the RDD based on the file root
process_rdd(clinicaltrial_rdd, fileroot)

#Question 3. The top 5 conditions (from Conditions) with their frequencies.

In [0]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import pandas as pd

# Initialize SparkContext and SQLContext
# sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

def process_conditions(clinicaltrial_rdd, fileroot):
    """
    Print and plot the five most frequent values of the "Conditions" column.

    The first line of the RDD is treated as the header and dropped.

    * "clinicaltrial_2023": column index 4 of the tab-delimited file is
      counted as a whole string (multi-condition entries are NOT split here)
      and shown as a matplotlib scatter plot.
    * "clinicaltrial_2020" / "clinicaltrial_2021": column index 7 of the
      pipe-delimited file is split on commas, each condition is trimmed,
      empty entries are dropped, and the result is shown as a Plotly bar chart.

    Lines too short to contain the column are ignored.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the parsing rules.

    Returns:
        None

    Raises:
        ValueError: If fileroot is not a supported dataset name.
    """
    is_2023 = fileroot == "clinicaltrial_2023"
    # NOTE: `x == "a" or "b"` is always truthy; a membership test is required
    if not is_2023 and fileroot not in ("clinicaltrial_2020", "clinicaltrial_2021"):
        raise ValueError(f"Unsupported fileroot: {fileroot!r}")

    # Discard the header line (index 0) from the RDD
    rows_rdd = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] > 0)
        .map(lambda line_index: line_index[0])
    )

    if is_2023:
        # Extract the "Conditions" column; None marks a missing or empty value
        def extract_condition(line):
            fields = line.split('\t')
            if len(fields) > 4 and fields[4]:
                return fields[4]
            return None

        condition_rdd = rows_rdd.map(extract_condition).filter(lambda x: x is not None)

        # Count occurrences per condition and order from most to least frequent
        top_5_conditions = (
            condition_rdd.map(lambda condition: (condition, 1))
            .reduceByKey(lambda a, b: a + b)
            .sortBy(lambda x: -x[1])
            .take(5)
        )

        print("Top 5 conditions and their frequencies:")
        for condition, count in top_5_conditions:
            print(f"Condition: {condition}, Count: {count}")

        data = pd.DataFrame(top_5_conditions, columns=['Condition', 'Count'])

        plt.figure(figsize=(10, 6))
        plt.scatter(data['Condition'], data['Count'], color='blue', alpha=0.7)

        plt.xlabel('Condition')
        plt.ylabel('Frequency')
        plt.title('Top 5 Conditions and Their Frequencies')

        # Label each point with its count
        for condition, count in top_5_conditions:
            plt.text(condition, count, f'{count}', ha='right', va='bottom')

        plt.show()

    else:
        # Extract the conditions column, skipping lines without enough fields
        condition_rdd = (
            rows_rdd.map(lambda row: row.split("|"))
            .filter(lambda fields: len(fields) > 7)
            .map(lambda fields: fields[7])
        )

        # One (condition, 1) pair per comma-separated condition; empties are dropped
        condition_frequency = condition_rdd.flatMap(
            lambda conditions: [(condition.strip(), 1) for condition in conditions.split(",")]
        ).filter(lambda x: x[0] != "")

        # Add up the frequencies for each condition and get the top 5 conditions
        condition_top5 = condition_frequency.reduceByKey(lambda a, b: a + b).takeOrdered(5, key=lambda x: -x[1])

        print("Top 5 conditions and their frequencies:")
        for condition, count in condition_top5:
            print(condition, count)

        conditions = [condition for condition, _ in condition_top5]
        counts = [count for _, count in condition_top5]

        fig = go.Figure(go.Bar(
            x=conditions,
            y=counts,
            marker=dict(color='skyblue'),
            hoverinfo='x+y',  # Show both x and y on hover
        ))

        fig.update_layout(
            title='Top 5 Conditions and Their Frequencies',
            xaxis_title='Conditions',
            yaxis_title='Frequency',
            height=600,
        )

        fig.show()


# Process the RDD based on the file root
process_conditions(clinicaltrial_rdd, fileroot)

In [0]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
import plotly.express as px

# Initialize SparkContext and SQLContext
# sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

def process_conditions(clinicaltrial_rdd, fileroot):
    """
    Print and plot the five most frequent individual conditions (2023 only).

    For "clinicaltrial_2023" the first line of the RDD is dropped as the
    header, column index 4 of each tab-delimited line is split on '|', each
    condition is trimmed, empty entries are discarded, and the five most
    frequent conditions are printed and shown as a horizontal Plotly bar
    chart. For any other fileroot a "not applicable" message is printed and
    the RDD is not touched.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines, header first.
        fileroot (str): Dataset name selecting the behaviour.

    Returns:
        None
    """
    if fileroot != "clinicaltrial_2023":
        # The original `elif ... or "clinicaltrial_2021"` was always true, so
        # this message was already printed for every non-2023 fileroot.
        print(f"This operation is not applicable to {fileroot} file.")
        return

    def split_conditions(line):
        """Return the non-empty, trimmed conditions found in one record."""
        fields = line.split('\t')
        if len(fields) <= 4:
            return []
        # An empty Conditions field would otherwise be counted as a condition ''
        return [part.strip() for part in fields[4].split('|') if part.strip()]

    # Discard the header line (index 0), then explode the conditions
    rows_rdd = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] > 0)
        .map(lambda line_index: line_index[0])
    )
    exploded_conditions_rdd = rows_rdd.flatMap(split_conditions)

    # Count per condition, order by descending count and keep the top 5
    top_5_conditions = (
        exploded_conditions_rdd.map(lambda condition: (condition, 1))
        .reduceByKey(lambda a, b: a + b)
        .sortBy(lambda x: -x[1])
        .take(5)
    )

    print("Top 5 conditions and their frequencies:")
    for condition, count in top_5_conditions:
        print(f"Condition: {condition}, Count: {count}")

    data = {
        'Condition': [condition for condition, _ in top_5_conditions],
        'Count': [count for _, count in top_5_conditions]
    }

    # Horizontal bar chart with one colour per condition
    fig = px.bar(
        data,
        x='Count',
        y='Condition',
        orientation='h',
        color='Condition',
        color_discrete_sequence=px.colors.qualitative.Plotly,
        labels={'Count': 'Frequency', 'Condition': 'Condition'},
        title='Top 5 Conditions and Their Frequencies',
        hover_data={'Condition': True, 'Count': True}
    )

    # Invert y-axis to display conditions from top to bottom
    fig.update_yaxes(autorange='reversed')

    fig.show()


# Process the RDD based on the file root
process_conditions(clinicaltrial_rdd, fileroot)

In [0]:
pip install mplcursors

In [0]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
import re
import matplotlib.pyplot as plt
import mplcursors

# Initialize SparkContext and SparkSession
sc = SparkContext.getOrCreate()
spark = SparkSession.builder.appName("ClinicalTrialAnalysis").getOrCreate()

def process_clinicaltrial_data(file_path, fileroot):
    """
    Print and plot the top 5 cleaned conditions of the 2023 dataset.

    For "clinicaltrial_2023" the file is loaded, its first line is dropped as
    the header, and column index 4 of each tab-delimited line is cleaned:
    quotes, square brackets and parentheses are removed, the text is split on
    '|' and ',', each piece is trimmed, and empty pieces or pieces containing
    'e.g.' are discarded. The five most frequent conditions are printed and
    drawn as a stacked horizontal ("waterfall") bar chart. For any other
    fileroot a "not applicable" message is printed and nothing is loaded.

    Args:
        file_path (str): The file path to the clinical trial data file.
        fileroot (str): Dataset name selecting the behaviour.

    Returns:
        None
    """
    if fileroot != "clinicaltrial_2023":
        # The original `elif ... or "clinicaltrial_2021"` was always true, so
        # this message was already printed for every non-2023 fileroot.
        print(f"This operation is not applicable to {fileroot} file.")
        return

    def clean_conditions(line):
        """Return the cleaned list of conditions found in one record."""
        fields = line.split('\t')
        if len(fields) <= 4 or not fields[4]:
            return []
        # Remove quotes, then square brackets and parentheses
        conditions = fields[4].replace('"', '').strip()
        conditions = re.sub(r'[\[\]\(\)]', '', conditions)
        # Split on pipe or comma and trim every piece
        pieces = [piece.strip() for piece in re.split(r'[|,]', conditions)]
        # Drop empty pieces (e.g. from "a,,b") and anything mentioning 'e.g.'
        return [piece for piece in pieces if piece and 'e.g.' not in piece.lower()]

    # Load the data file into an RDD and discard the header line (index 0)
    clinicaltrial_rdd = sc.textFile(file_path)
    clinicaltrial_no_header = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] > 0)
        .map(lambda line_index: line_index[0])
    )

    # Count per condition, order by descending count and keep the top 5
    top_5_conditions = (
        clinicaltrial_no_header.flatMap(clean_conditions)
        .map(lambda condition: (condition, 1))
        .reduceByKey(lambda a, b: a + b)
        .sortBy(lambda x: -x[1])
        .take(5)
    )

    print("Top 5 conditions and their frequencies:")
    for condition, count in top_5_conditions:
        print(f"Condition: {condition}, Count: {count}")

    fig, ax = plt.subplots(figsize=(10, 6))

    # Each bar starts where the previous one ended
    previous_value = 0
    for condition, count in top_5_conditions:
        # Counts are sums of 1s, so they are always positive
        bar = ax.barh(condition, count, left=previous_value, color='steelblue', edgecolor='black')
        previous_value += count

        # Bind the count as a default argument so the tooltip shows this bar's
        # frequency rather than a cursor coordinate
        mplcursors.cursor(bar).connect(
            "add", lambda sel, count=count: sel.annotation.set_text(f"Count: {count}")
        )

    # Mark the cumulative total with a vertical dashed line
    ax.axvline(previous_value, color='grey', linestyle='--', linewidth=1.5)

    ax.set_xlabel('Frequency')
    ax.set_title('Top 5 Conditions and Their Frequencies (Waterfall Chart)')

    plt.show()


process_clinicaltrial_data(file_path, fileroot)


In [0]:
fileroot1= "pharma"

In [0]:
#Copying the file to /tmp directory on driver node
dbutils.fs.cp("/FileStore/tables/" + fileroot1 + ".zip", "file:/tmp/")

In [0]:
#Checking the fileroot.zip in DBFS directory
dbutils.fs.ls("/FileStore/tables/")

In [0]:
#Checking the fileroot.zip in local tmp directory
dbutils.fs.ls("file:/tmp/")

In [0]:
#Making 'fileroot1' accessible by the command line
import os
os.environ['fileroot1'] = fileroot1

In [0]:
#Unzipping the file in the local directory

In [0]:
%sh
unzip -o -d /tmp /tmp/$fileroot1.zip

In [0]:
#Moving the unzipped file to the DBFS directory
dbutils.fs.mv("file:/tmp/" + fileroot1 +".csv" , "/FileStore/tables/" , True )

In [0]:
#Checking the unzipped file in DBFS directory
dbutils.fs.ls("/FileStore/tables/")

In [0]:
#Checking the DBFS directory
dbutils.fs.ls ("/FileStore/tables/" )

In [0]:
#Listing the contents of the csv file
dbutils.fs.head("/FileStore/tables/" + fileroot1 + ".csv" )

In [0]:
# Read the CSV file and create an RDD
import re
pharma_RDD = sc.textFile("/FileStore/tables/" + fileroot1 + ".csv")

# The first line holds the column names; drop every line identical to it
header = pharma_RDD.first()

# Split each remaining row on commas that are not enclosed in double quotes
pharma_RDD = (
    pharma_RDD
    .filter(lambda record: record != header)
    .map(lambda record: re.split(',(?=(?:[^"]*"[^"]*")*[^"]*$)', record))
)

In [0]:
# Keep only the second column (index 1) of every split pharma row; per the
# original note this is the parent-company field.
pharma_Parent_RDD = pharma_RDD.map(lambda x: x[1])

In [0]:
# Collect the distinct parent-company values to the driver as a Python set
pharma_companies_group = set(pharma_Parent_RDD.distinct().collect())

In [0]:
# Strip every double-quote character from the names (the regex split above
# leaves the surrounding quotes in place), then show the cleaned set
parent_companies_group = {company.replace('"', '') for company in pharma_companies_group}
print(parent_companies_group)

In [0]:
# Keep the cleaned parent-company names as a driver-side frozenset.
# Despite its name this variable is NOT an RDD: it is referenced inside RDD
# filter lambdas (`x not in pharma_Parent_RDD`), which requires a plain Python
# collection. The previous parallelize(...).collect() round trip only turned
# the set into a list; a frozenset holds the same names and gives O(1)
# membership tests for every sponsor checked.
pharma_Parent_RDD = frozenset(parent_companies_group)

#Question 4. Find the 10 most common sponsors that are not pharmaceutical companies, along with the number of clinical trials they have sponsored. Hint: For a basic implementation, you can assume that the Parent Company column contains all possible pharmaceutical companies.

In [0]:
from pyspark import SparkContext
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Initialize SparkContext
sc = SparkContext.getOrCreate()

def _top_non_pharma_sponsors(rows_rdd, delimiter, sponsor_index, pharma_companies, limit=10):
    """Return the `limit` most frequent sponsors that are not pharmaceutical companies.

    Args:
        rows_rdd (RDD): RDD of raw text lines (header already removed).
        delimiter (str): Field delimiter of the lines.
        sponsor_index (int): Zero-based position of the sponsor field.
        pharma_companies: Driver-side collection of pharmaceutical company
            names; only membership (`in`) is used.
        limit (int): Number of sponsors to return.

    Returns:
        list[tuple[str, int]]: (sponsor, trial count) pairs, highest count first.
    """
    def sponsor_of(line):
        fields = line.split(delimiter)
        # Rows too short to contain a sponsor field are skipped instead of
        # aborting the whole job with an IndexError.
        return [fields[sponsor_index]] if len(fields) > sponsor_index else []

    return (
        rows_rdd.flatMap(sponsor_of)
        # Keep only the sponsors that are NOT known pharmaceutical companies
        .filter(lambda sponsor: sponsor not in pharma_companies)
        .map(lambda sponsor: (sponsor, 1))
        .reduceByKey(lambda x, y: x + y)
        # Negated count => takeOrdered returns the largest counts first
        .takeOrdered(limit, key=lambda pair: -pair[1])
    )


def _print_sponsor_table(title, top_sponsors):
    """Print `title` followed by a two-column Sponsor/Count table."""
    print(title)
    print("+---------------------------------------+-----+")
    print("|{:40s}|{:5s}|".format("Sponsor", "Count"))
    print("+---------------------------------------+-----+")
    for sponsor, count in top_sponsors:
        print("|{:40s}|{:5d}|".format(sponsor, count))
    print("+---------------------------------------+-----+")


def process_clinicaltrial_data(file_path, fileroot, pharma_Parent_RDD):
    """
    Report the 10 most common sponsors that are not pharmaceutical companies.

    The data is read from the notebook-level `clinicaltrial_rdd`; the first
    line of that RDD is treated as the header and dropped.

    Args:
        file_path (str): Path of the clinical trial file. Kept for interface
            compatibility; it is not read by this function.
        fileroot (str): Selects the file layout: "clinicaltrial_2023" is
            tab-delimited with the sponsor at index 6 (table + heatmap);
            "clinicaltrial_2020"/"clinicaltrial_2021" are pipe-delimited with
            the sponsor at index 1 (table only). Any other value does nothing.
        pharma_Parent_RDD: Driver-side collection (list/set) of pharmaceutical
            parent-company names. Despite the name it must not be an RDD,
            because it is used inside an RDD filter via `in`.

    Returns:
        None
    """
    # Drop the header line (index 0) using zipWithIndex
    clinicaltrial_no_header = (
        clinicaltrial_rdd.zipWithIndex()
        .filter(lambda line_index: line_index[1] > 0)
        .map(lambda line_index: line_index[0])
    )

    if fileroot == "clinicaltrial_2023":
        sponsor_non_pharma_top10_RDD = _top_non_pharma_sponsors(
            clinicaltrial_no_header, '\t', 6, pharma_Parent_RDD
        )
        _print_sponsor_table(
            "Top 10 Non-Pharma Sponsors by Number of Clinical Trials Sponsored (2023):",
            sponsor_non_pharma_top10_RDD,
        )

        # Nothing to draw: a heatmap over an empty array would fail
        if not sponsor_non_pharma_top10_RDD:
            return

        # Heatmap of the top sponsors and their trial counts
        sponsors = [item[0] for item in sponsor_non_pharma_top10_RDD]
        counts = [item[1] for item in sponsor_non_pharma_top10_RDD]

        # seaborn expects 2-D data: a single row holding every count
        counts_array = np.array(counts).reshape(1, -1)

        plt.figure(figsize=(12, 3))  # Adjust the size as per your preference
        sns.heatmap(counts_array, annot=True, cmap='coolwarm', fmt='d', xticklabels=sponsors, yticklabels=['Counts'])

        plt.title('Top 10 Non-Pharma Sponsors by Number of Clinical Trials Sponsored (2023)')
        plt.xlabel('Sponsor')
        plt.ylabel('')

        plt.show()

    elif fileroot == "clinicaltrial_2020" or fileroot == "clinicaltrial_2021":
        sponsor_non_pharma_top10_RDD = _top_non_pharma_sponsors(
            clinicaltrial_no_header, '|', 1, pharma_Parent_RDD
        )

        # The table title is just the year taken from the fileroot string
        trial_Year2 = fileroot.split("_")[1]
        _print_sponsor_table(f"{trial_Year2}:", sponsor_non_pharma_top10_RDD)

# Process the data
process_clinicaltrial_data(file_path, fileroot, pharma_Parent_RDD)


#Question 5. Plot the number of completed studies for each month in 2023. You need to include your visualization as well as a table of all the values you have plotted for each month.

In [0]:
import re
from collections import defaultdict
from datetime import datetime
import matplotlib.pyplot as plt
import mplcursors

# Function to process and plot data for clinicaltrial_2023
def process_clinicaltrial_2023(clinicaltrial_rdd):
    """Tabulate and plot the number of studies completed in each month of 2023.

    Lines are tab-delimited. A study is counted when it has more than 13
    fields, field 3 equals 'COMPLETED' and field 13 (the completion date)
    starts with '2023'. The month is the second '-'-separated component of
    that date. Prints a month/count table, then shows a bar chart and a line
    chart of the same values.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines of the 2023 file.

    Returns:
        None
    """
    def completion_month(row):
        # Keep only digits and '-' so stray quotes/commas do not end up in the month
        cleaned = re.sub(r"[^0-9-]", "", row[13])
        parts = cleaned.split('-')
        # Dates without a month component (e.g. just "2023") are skipped
        # instead of raising IndexError.
        return [parts[1]] if len(parts) > 1 and parts[1] else []

    # Count per month on the executors; only the small month -> count mapping
    # is brought back to the driver (instead of collecting every matching row).
    completed_month_status = (
        clinicaltrial_rdd.map(lambda line: line.split('\t'))
        .filter(lambda row: len(row) > 13 and row[3] == 'COMPLETED' and row[13].startswith('2023'))
        .flatMap(completion_month)
        .countByValue()
    )

    # Sort by numeric month value
    sorted_completed_month_status = sorted(completed_month_status.items(), key=lambda x: int(x[0]))

    # Print results in a tabular format in ascending order of months
    print("+---------------+--------------------------------")
    print("| Month        | Number of completed studies of the month  |")
    print("+---------------+--------------------------------")
    for month, count in sorted_completed_month_status:
        print(f"|  {month:<10}  | {count:<25}  |")
    print("+---------------+--------------------------------")

    months = [item[0] for item in sorted_completed_month_status]
    counts = [item[1] for item in sorted_completed_month_status]

    # max()/min() below would raise on an empty list, so stop here
    if not counts:
        print("No completed studies found for 2023.")
        return

    def show_hover_text(sel):
        # For a Line2D mplcursors reports a fractional index (position along
        # the segment), so round to the nearest data point before indexing.
        point = int(round(sel.index))
        sel.annotation.set_text(f"Month: {months[point]}\nCount: {counts[point]}")

    # ---- Bar chart ----
    plt.figure(figsize=(10, 6))

    # Colour each bar by its count relative to the largest count
    cmap = plt.get_cmap('viridis')
    colors = cmap([count / max(counts) for count in counts])

    bars = plt.bar(months, counts, color=colors)

    plt.xlabel('Month')
    plt.ylabel('Number of Completed Studies')
    plt.title('Number of Completed Studies per Month in 2023')

    # Colour bar for the hue scale. The mappable is not attached to any Axes,
    # so the Axes to take space from is passed explicitly (recent Matplotlib
    # versions raise an error otherwise).
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=min(counts), vmax=max(counts)))
    sm.set_array([])
    plt.colorbar(sm, ax=plt.gca(), label='Count')

    # Use mplcursors to display hover-over data
    cursor = mplcursors.cursor(bars, hover=True)
    cursor.connect("add", show_hover_text)

    plt.show()

    # ---- Line graph ----
    plt.figure(figsize=(10, 6))

    line, = plt.plot(months, counts, color='blue', linewidth=2, marker='o', markersize=8, linestyle='-', label='Completed Studies')

    plt.xlabel('Month')
    plt.ylabel('Number of Completed Studies')
    plt.title('Number of Completed Studies per Month in 2023')
    plt.legend()

    # Use mplcursors to display hover-over data
    cursor = mplcursors.cursor([line], hover=True)
    cursor.connect("add", show_hover_text)

    plt.show()

# Function to process and plot data for clinicaltrial_2020 or clinicaltrial_2021
def process_clinicaltrial_2020_2021(clinicaltrial_rdd, trial_year):
    """Tabulate and plot completed studies per month for a historical dataset.

    Lines are pipe-delimited. A study is kept when field 2 contains
    "Completed" and field 4 (completion date, parsed as "%b %Y") contains
    `trial_year`. Prints a Month/Count listing, then shows a line plot and a
    bar plot of the same values.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines.
        trial_year (str): Four-digit year to report on.

    Returns:
        None
    """
    def completed_in_trial_year(line):
        fields = line.split("|")
        return "Completed" in fields[2] and trial_year in fields[4]

    def month_and_year(line):
        # Completion date -> (abbreviated month name, four-digit year)
        completed_on = datetime.strptime(line.split("|")[4], "%b %Y")
        return completed_on.strftime("%b"), completed_on.strftime("%Y")

    def calendar_position(entry):
        # entry is ((month_abbr, year), count); order by calendar month
        return datetime.strptime(entry[0][0], "%b").month

    # (month, year) -> number of completed studies
    tally = clinicaltrial_rdd.filter(completed_in_trial_year).map(month_and_year).countByValue()
    ordered_tally = sorted(tally.items(), key=calendar_position)

    # Keep only the entries whose parsed year equals the requested year
    months = []
    num_counts = []
    for (month_name, year), total in ordered_tally:
        if year == trial_year:
            months.append(month_name)
            num_counts.append(total)

    # Print months and counts in two columns
    print("Month   Count")
    for position, month_name in enumerate(months):
        print(f"{month_name:<6} {num_counts[position]:<6}")

    def decorate_and_show():
        # Titles, axis labels and layout shared by both charts
        plt.title(f'Completed Studies Month-wise (Year: {trial_year})')
        plt.xlabel(f'Months (Year: {trial_year})')
        plt.ylabel('Count')
        plt.xticks(rotation=45)
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    # Line plot
    plt.figure(figsize=(10, 6))
    plt.plot(months, num_counts, marker='o', linestyle='-')
    decorate_and_show()

    # Bar plot
    plt.figure(figsize=(10, 6))
    plt.bar(months, num_counts, color='blue')
    decorate_and_show()

# Main code to process files based on the fileroot value
def process_files(fileroot, clinicaltrial_rdd):
    """Run the completed-studies-per-month analysis matching `fileroot`.

    Args:
        fileroot (str): Dataset name of the form "<name>_<year>"; the part
            after the first underscore is used as the year.
        clinicaltrial_rdd (RDD): RDD of raw text lines of the dataset.

    Returns:
        None. Unrecognised fileroot values are ignored.
    """
    # The year is extracted up front, before the dataset is recognised
    year_suffix = fileroot.split("_")[1]

    if fileroot == "clinicaltrial_2023":
        process_clinicaltrial_2023(clinicaltrial_rdd)
        return

    if fileroot in ("clinicaltrial_2020", "clinicaltrial_2021"):
        process_clinicaltrial_2020_2021(clinicaltrial_rdd, year_suffix)
        

process_files(fileroot, clinicaltrial_rdd)


#Additional Analyses on Clinicaltrial datasets

#1. Top 5 studies with the highest enrollments (2023)  and  Distribution of study statuses (Historical datasets)

In [0]:
import matplotlib.pyplot as plt
import mplcursors
import re
from collections import defaultdict
from datetime import datetime

# Function to process and visualize data for clinicaltrial_2023
def process_clinicaltrial_2023(clinicaltrial_rdd):
    """Print and chart the 5 studies with the highest enrollment.

    Lines are tab-delimited; field 0 is printed as the study ID, field 1 as
    the study title and field 8 is parsed as the enrollment. Rows whose
    enrollment is missing, empty, non-numeric or not finite are ignored
    (this also discards the header line).

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines of the 2023 file.

    Returns:
        None
    """
    import math

    def parse_enrollment(row):
        # Missing or empty enrollment field
        if len(row) < 9 or not row[8]:
            return None
        try:
            enrollment_float = float(row[8])
        except ValueError:
            # Non-numeric text
            return None
        # float() also accepts "nan"/"inf"; those would corrupt the ranking
        if not math.isfinite(enrollment_float):
            return None
        # (study ID, study title, enrollment as text, enrollment as float)
        return (row[0], row[1], row[8], enrollment_float)

    enrollment_rdd = (
        clinicaltrial_rdd.map(lambda line: line.split('\t'))
        .map(parse_enrollment)
        .filter(lambda x: x is not None)
    )

    # takeOrdered keeps only 5 candidates per partition instead of sorting
    # (shuffling) the whole dataset; the negated key yields the largest first.
    top_5_enrollments = enrollment_rdd.takeOrdered(5, key=lambda study: -study[3])

    print("Top 5 studies with highest enrollments:")
    for study in top_5_enrollments:
        print(f"Study ID: {study[0]}, Study Title: {study[1]}, Enrollment: {study[2]}")

    # No parsable enrollment at all: nothing to draw
    if not top_5_enrollments:
        return

    study_titles = [study[1] for study in top_5_enrollments]
    enrollments = [study[3] for study in top_5_enrollments]

    # Total of the five enrollments, used to turn pie percentages back into counts
    total_enrollment = sum(enrollments)

    plt.figure(figsize=(10, 8))

    # Each wedge is labelled with its percentage and the absolute enrollment
    plt.pie(enrollments, labels=study_titles, autopct=lambda p: f'{p:.2f}%\n({p * total_enrollment / 100:.0f})', startangle=140)

    plt.title('Top 5 Studies with Highest Enrollments')

    plt.show()

# Function to process and visualize data for clinicaltrial_2020 or clinicaltrial_2021
def process_clinicaltrial_2020_2021(clinicaltrial_rdd, trial_year):
    """Print and chart how many studies have each status for `trial_year`.

    Lines are pipe-delimited. A study is included when field 4 contains
    `trial_year`; its status is field 2. Prints the (status, count) pairs,
    most frequent first, then shows a pie chart and a horizontal bar chart.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines.
        trial_year (str): Year to report on.

    Returns:
        None
    """
    def status_if_in_year(line):
        # Split once per line; rows without field 4 are skipped
        fields = line.split("|")
        if len(fields) > 4 and trial_year in fields[4]:
            return [(fields[2], 1)]
        return []

    # Run the Spark job once and order the handful of statuses on the driver,
    # instead of re-running the whole lineage for every collect().
    status_distribution = sorted(
        clinicaltrial_rdd.flatMap(status_if_in_year).reduceByKey(lambda x, y: x + y).collect(),
        key=lambda pair: pair[1],
        reverse=True,
    )

    print(f"Distribution of study statuses for year {trial_year}:")
    print(status_distribution)

    # Nothing matched the year: there is nothing to plot
    if not status_distribution:
        return

    labels = [status for status, _ in status_distribution]
    counts = [count for _, count in status_distribution]

    # Pie chart
    plt.figure(figsize=(8, 8))
    plt.pie(counts, labels=labels, autopct='%1.1f%%', startangle=140)
    plt.title(f'\nDistribution of Study Statuses (Year: {trial_year})\n', pad=20)  # Add padding to the title
    plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
    plt.tight_layout()  # Adjust layout to avoid overlap
    plt.show()

    # Horizontal bar chart
    plt.figure(figsize=(10, 6))  # Set figure size
    plt.barh(labels, counts, color=['blue', 'green', 'red', 'gold', 'coral', 'lightskyblue', 'lightgreen', 'salmon', 'gold', 'lightcoral'])
    plt.xlabel('Count')
    plt.ylabel('Study Status')
    plt.title(f'Distribution of Study Statuses (Year: {trial_year})')
    plt.tight_layout()  # Adjust layout to prevent overlapping elements
    plt.show()

# Main code to process files based on the fileroot value
def process_files(fileroot, clinicaltrial_rdd):
    """Run the enrollment / status-distribution analysis matching `fileroot`.

    Args:
        fileroot (str): Dataset name, e.g. "clinicaltrial_2023".
        clinicaltrial_rdd (RDD): RDD of raw text lines of the dataset.

    Returns:
        None. Unrecognised fileroot values are ignored.
    """
    historical_roots = ("clinicaltrial_2020", "clinicaltrial_2021")

    if fileroot == "clinicaltrial_2023":
        # 2023 layout: top-5 enrollments
        process_clinicaltrial_2023(clinicaltrial_rdd)
    elif fileroot in historical_roots:
        # Historical layout: the year is the part after the underscore
        process_clinicaltrial_2020_2021(clinicaltrial_rdd, fileroot.split('_')[1])


process_files(fileroot, clinicaltrial_rdd)


#2. Comparing the sponsorship of Mayo Clinic and Massachusetts General Hospital (2023) and Top 5 sponsors with the most completed studies

In [0]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Define a function to process and visualize data for clinicaltrial_2023
def process_clinicaltrial_2023(clinicaltrial_rdd):
    """Compare how many studies Mayo Clinic and Massachusetts General Hospital sponsor.

    Lines are tab-delimited; field 6 (whitespace-stripped) is compared with
    the two institution names. Prints both counts and shows them as a heatmap.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines of the 2023 file.

    Returns:
        None
    """
    institutions = ('Mayo Clinic', 'Massachusetts General Hospital')

    # One pass over the data counts both institutions (instead of one full
    # scan per institution). Rows too short to have field 6 are skipped
    # rather than raising IndexError.
    sponsor_counts = (
        clinicaltrial_rdd.map(lambda line: line.split('\t'))
        .filter(lambda row: len(row) > 6)
        .map(lambda row: row[6].strip())
        .filter(lambda sponsor: sponsor in institutions)
        .countByValue()
    )

    # An institution with no studies is absent from the mapping -> count 0
    num_Mayo_Clinic_studies = sponsor_counts.get('Mayo Clinic', 0)
    num_M_G_studies = sponsor_counts.get('Massachusetts General Hospital', 0)

    print(f"Number of studies sponsored by Mayo Clinic: {num_Mayo_Clinic_studies}")
    print(f"Number of studies sponsored by Massachusetts General Hospital: {num_M_G_studies}")

    # One row per institution, indexed by institution name
    data = {
        'Institution': ['Mayo Clinic', 'Massachusetts General Hospital'],
        'Number of Studies': [num_Mayo_Clinic_studies, num_M_G_studies]
    }
    df = pd.DataFrame(data)
    df.set_index('Institution', inplace=True)

    # Create a heatmap
    plt.figure(figsize=(8, 6))
    sns.heatmap(df, annot=True, fmt='d', cmap='coolwarm', cbar=True, linewidths=.5, linecolor='black')

    # Add title and labels
    plt.title('Heatmap of Number of Studies Sponsored by Each Institution')
    plt.xlabel('')
    plt.ylabel('Institution')

    # Display the heatmap
    plt.show()

# Define a function to process and visualize data for clinicaltrial_2020 or clinicaltrial_2021
def process_clinicaltrial_2020_2021(clinicaltrial_rdd, trial_year):
    """Print and chart the 5 sponsors with the most completed studies.

    Lines are pipe-delimited; field 1 is the sponsor and field 2 the status.
    A study counts as completed when its status field contains "Completed".

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines.
        trial_year (str): Year shown in the printed heading and chart title
            (it is not used for filtering).

    Returns:
        None
    """
    def sponsor_if_completed(line):
        fields = line.split("|")
        # Test the status column only: searching the whole line would also
        # match "Completed" appearing in any other field.
        if len(fields) > 2 and "Completed" in fields[2]:
            return [(fields[1], 1)]
        return []

    # takeOrdered avoids sorting the full sponsor list just to keep 5 entries;
    # the negated key yields the largest counts first.
    top_sponsors_completed = (
        clinicaltrial_rdd.flatMap(sponsor_if_completed)
        .reduceByKey(lambda x, y: x + y)
        .takeOrdered(5, key=lambda pair: -pair[1])
    )

    print(f"Top 5 sponsors with the most completed studies (Year: {trial_year}):")
    print(top_sponsors_completed)

    # Extract sponsor names and their respective counts
    sponsors = [item[0] for item in top_sponsors_completed]
    counts = [item[1] for item in top_sponsors_completed]

    # Plotting the bar chart
    plt.figure(figsize=(10, 6))  # Set figure size
    plt.bar(sponsors, counts, color='skyblue')
    plt.xlabel('Sponsor')
    plt.ylabel('Number of Completed Studies')
    plt.title(f'Top 5 Sponsors with the Most Completed Studies (Year: {trial_year})')
    plt.xticks(rotation=45)  # Rotate x-axis labels for better readability
    plt.tight_layout()  # Adjust layout to prevent overlapping elements
    plt.show()

# Define the main function to process the clinical trial data based on fileroot
def process_files(fileroot, clinicaltrial_rdd):
    """Run the sponsorship analysis matching `fileroot`.

    Args:
        fileroot (str): Dataset name, e.g. "clinicaltrial_2023".
        clinicaltrial_rdd (RDD): RDD of raw text lines of the dataset.

    Returns:
        None. Unrecognised fileroot values are ignored.
    """
    if fileroot == "clinicaltrial_2023":
        # 2023 layout: Mayo Clinic vs Massachusetts General Hospital
        process_clinicaltrial_2023(clinicaltrial_rdd)
        return

    is_historical = fileroot == "clinicaltrial_2020" or fileroot == "clinicaltrial_2021"
    if is_historical:
        # Year is the text after the underscore, e.g. "2020"
        dataset_year = fileroot.split('_')[1]
        process_clinicaltrial_2020_2021(clinicaltrial_rdd, dataset_year)


process_files(fileroot, clinicaltrial_rdd)


#3. Count of studies where the condition contains the keywords "heart", "alcohol", or "liver" (2023), and list of studies started in a specific year (historical datasets)

In [0]:
import matplotlib.pyplot as plt

def count_studies_with_keywords(clinicaltrial_rdd):
    """Count the studies whose condition field mentions heart, alcohol or liver.

    Lines are tab-delimited; field 4 is lower-cased and searched for each
    keyword as a substring. A study mentioning several keywords is counted
    once for each of them. Prints one count per keyword and then plots the
    counts with `plot_study_keyword_counts`.

    Args:
        clinicaltrial_rdd (RDD): RDD of raw text lines of the 2023 file.

    Returns:
        None
    """
    # Order here fixes the order of the printed lines and of the bars
    keywords = ("heart", "alcohol", "liver")

    def matched_keywords(line):
        row = line.split('\t')
        # Rows without a condition field cannot match anything
        if len(row) <= 4:
            return []
        condition_text = row[4].lower()
        return [keyword for keyword in keywords if keyword in condition_text]

    # A single pass over the data counts all keywords at once
    # (instead of one full scan per keyword).
    matches = clinicaltrial_rdd.flatMap(matched_keywords).countByValue()

    # Keywords that never occur are absent from `matches`: report them as 0
    keyword_counts = {keyword: matches.get(keyword, 0) for keyword in keywords}

    # Print the count of studies where the condition contains each of the keywords
    for keyword_name, count in keyword_counts.items():
        print(f"Count of studies where the condition contains '{keyword_name}': {count}")

    # Plot the counts as a bar chart
    plot_study_keyword_counts(keyword_counts)


def plot_study_keyword_counts(keyword_counts):
    """Show a bar chart of study counts per keyword, each bar labelled with its value.

    Args:
        keyword_counts (dict): Maps a keyword to the number of matching studies.

    Returns:
        None
    """
    keyword_names = list(keyword_counts)
    totals = [keyword_counts[name] for name in keyword_names]

    plt.figure(figsize=(10, 6))
    drawn_bars = plt.bar(keyword_names, totals, color=['red', 'orange', 'green'])

    plt.xlabel('Condition')
    plt.ylabel('Count')
    plt.title('Number of Studies with Conditions Containing "Heart," "Alcohol," or "Liver"')

    # Write each count just above its bar, centred horizontally
    for rect in drawn_bars:
        bar_top = rect.get_height()
        bar_centre = rect.get_x() + rect.get_width() / 2
        plt.text(bar_centre, bar_top, f'{int(bar_top)}',
                 ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.show()


# Define a function to list studies started in a particular year
def list_studies_by_year(year, clinicaltrial_rdd):
    """Print study ID, sponsor and start date of every study started in `year`.

    Lines are pipe-delimited; field 0 is printed as the study ID, field 1 as
    the sponsor and field 3 as the start date. A study is listed when its
    start date contains `year`.

    Args:
        year (str): Year to look for in the start-date field.
        clinicaltrial_rdd (RDD): RDD of raw text lines.

    Returns:
        None
    """
    # Split each line once, keep the rows whose start date mentions the year,
    # then project the three columns that are printed.
    studies_year_info = (
        clinicaltrial_rdd.map(lambda line: line.split("|"))
        .filter(lambda fields: year in fields[3])
        .map(lambda fields: (fields[0], fields[1], fields[3]))
    )

    # Header and separator line
    print("Study ID\t\t Sponsor\t\t Start Date")
    print("-" * 40)

    # One line per matching study
    for record in studies_year_info.collect():
        print(f"{record[0]}\t {record[1]}\t {record[2]}")


# Main function to process the clinical trial data based on fileroot
def process_files(fileroot, clinicaltrial_rdd):
    """Run the keyword-count or studies-by-year analysis matching `fileroot`.

    Args:
        fileroot (str): Dataset name, e.g. "clinicaltrial_2023".
        clinicaltrial_rdd (RDD): RDD of raw text lines of the dataset.

    Returns:
        None. Unrecognised fileroot values are ignored.
    """
    if fileroot == "clinicaltrial_2023":
        # 2023 layout: count conditions mentioning heart/alcohol/liver
        count_studies_with_keywords(clinicaltrial_rdd)
    elif fileroot in ("clinicaltrial_2020", "clinicaltrial_2021"):
        # Historical layout: list the studies started in the dataset's year,
        # which is the text after the underscore ('2020' or '2021')
        list_studies_by_year(fileroot.split('_')[1], clinicaltrial_rdd)


process_files(fileroot, clinicaltrial_rdd)

In [0]:
# Removing the uploaded zip files from the DBFS directory, if present

In [0]:
# Delete the two uploaded archives (clinical trial and pharma) from DBFS
dbutils.fs.rm("/FileStore/tables/" + fileroot + ".zip")
dbutils.fs.rm("/FileStore/tables/" + fileroot1 + ".zip")

In [0]:
# Define the directory path
directory_path = "/FileStore/tables/"

# Remove only the CSV files this notebook extracted. Recursively deleting
# directory_path itself would also wipe every other upload kept in this shared
# DBFS folder (for example the archives of the other trial years).
for extracted_name in (fileroot + ".csv", fileroot1 + ".csv"):
    extracted_path = directory_path + extracted_name
    # Report the outcome per file instead of claiming success unconditionally.
    # NOTE(review): relies on dbutils.fs.rm returning False when nothing was
    # removed - confirm against the dbutils reference.
    if dbutils.fs.rm(extracted_path):
        print(f"Deleted {extracted_path}")
    else:
        print(f"Nothing to delete at {extracted_path}")