In [1]:
# Set the PySpark environment variables before any Spark session is created.
import os

# NOTE(review): absolute, machine-specific install path — adjust per machine.
os.environ["SPARK_HOME"] = r"D:\spark"
# The launcher executable is `jupyter` (was misspelled `jupiter`).
# NOTE(review): these two driver variables presumably only take effect when Spark
# is started through the `pyspark` launcher script rather than from an already
# running kernel — confirm before relying on them.
os.environ["PYSPARK_DRIVER_PYTHON"] = "jupyter"
os.environ["PYSPARK_DRIVER_PYTHON_OPTS"] = "lab"
# Python binary PySpark should use for its Python processes.
os.environ["PYSPARK_PYTHON"] = "python"


In [2]:
#import pyspark
from pyspark.sql import SparkSession

In [3]:
# Start the Spark session (or attach to one that is already running).
spark = (
    SparkSession.builder
    .appName("pyspark-get-started")
    .getOrCreate()
)

In [4]:
# Smoke-test the session: build a tiny in-memory DataFrame and print it.
data = [
    ("Alice", 25),
    ("Bob", 30),
    ("Charlie", 35),
]
df = spark.createDataFrame(data, schema=["Name", "Age"])
df.show()

+-------+---+
|   Name|Age|
+-------+---+
|  Alice| 25|
|    Bob| 30|
|Charlie| 35|
+-------+---+



In [5]:
import os
import zipfile

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lower, trim, collect_set


In [6]:
# Quarterly FAERS ASCII archives to load. Paths are relative to the current
# working directory; archives that do not exist are skipped by the reader.
ZIP_FILES = [
    "faers_ascii_2024Q1.zip",
    "faers_ascii_2024Q2.zip",
    "faers_ascii_2024Q3.zip",
    "faers_ascii_2024Q4.zip",
    "faers_ascii_2025q1 (1).zip",
    "faers_ascii_2025q2 (2).zip",
    "faers_ascii_2025q3 (2).zip",
]

# Drugs of interest. A DRUG row is kept only when its lower-cased, trimmed
# `drugname` equals one of these entries exactly.
# NOTE(review): combination entries such as "Levodopa + Carbidopa" only match if
# FAERS spells the drugname identically — verify against the raw data.
DRUG_LIST = [
    "Lamotrigine","Levetiracetam","Topiramate","Gabapentin","Pregabalin",
    "Oxcarbazepine","Zonisamide","Lacosamide","Clobazam","Phenytoin",
    "Carbamazepine","Phenobarbital","Valproic acid","Sodium valproate",
    "Ethosuximide","Levodopa + Carbidopa","Bromocriptine","Pramipexole",
    "Ropinirole","Rotigotine","Apomorphine","Selegiline","Rasagiline",
    "Safinamide","Entacapone","Tolcapone","Trihexyphenidyl","Benztropine",
    "Amantadine","Donepezil","Rivastigmine","Galantamine","Memantine",
    "Fingolimod","Natalizumab","Ocrelizumab","Baclofen","Modafinil",
    "Sumatriptan","Ibuprofen","Dihydroergotamine"
]

# Normalized form matching how `drugname` is normalized before filtering.
DRUG_LIST_LOWER = [d.lower().strip() for d in DRUG_LIST]


In [7]:
import tempfile
from functools import reduce
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession

# Make sure a Spark session is running. getOrCreate() returns the already active
# session when one exists, so this is only a safeguard for standalone runs.
spark = SparkSession.builder.appName("FAERS Reader").getOrCreate()

def read_faers_tables_spark(table_keyword):
    """Load one FAERS table from every available quarterly zip as one Spark DataFrame.

    Every ``.txt`` member (extension matched case-insensitively) of each archive
    in ``ZIP_FILES`` whose path contains ``table_keyword`` (case-insensitive) is
    extracted to a temporary file and read as a ``$``-delimited CSV with a
    header row and ISO-8859-1 encoding.

    Parameters
    ----------
    table_keyword : str
        Table name to look for in member paths, e.g. ``"DRUG"`` or ``"REAC"``.

    Returns
    -------
    pyspark.sql.DataFrame
        Union (by column name) of all matching files; columns missing from some
        files are filled with nulls. If nothing matches, an empty DataFrame with
        no columns.
    """
    import shutil
    import warnings

    from pyspark.sql.types import StructType

    keyword = table_keyword.upper()
    dfs = []

    for zip_path in ZIP_FILES:
        if not os.path.exists(zip_path):
            # Best effort: skip missing archives, but make the gap visible.
            warnings.warn(f"FAERS archive not found, skipping: {zip_path}")
            continue

        with zipfile.ZipFile(zip_path, "r") as z:
            members = [
                name for name in z.namelist()
                if keyword in name.upper() and name.lower().endswith(".txt")
            ]
            for member in members:
                # Spark reads the file lazily (when an action runs), so the
                # extracted copy must outlive this function: delete=False.
                # Stream the member to disk rather than loading it into memory.
                with z.open(member) as src, \
                        tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as dst:
                    shutil.copyfileobj(src, dst)
                    temp_file_path = dst.name

                df = (
                    spark.read.option("delimiter", "$")
                    .option("header", True)
                    .option("encoding", "ISO-8859-1")
                    .csv(temp_file_path)
                )
                dfs.append(df)

    if not dfs:
        # createDataFrame([]) cannot infer a schema from empty data and raises;
        # an explicit empty StructType yields a valid zero-column DataFrame.
        return spark.createDataFrame([], StructType([]))

    # allowMissingColumns (Spark 3.1+) tolerates files that lack some columns.
    return reduce(
        lambda left, right: left.unionByName(right, allowMissingColumns=True),
        dfs,
    )




In [8]:
# --- 3️⃣ Read each FAERS table ---
# One DataFrame per table, each combined across all available quarters.
drug_raw, reac_raw, demo_raw, outc_raw = (
    read_faers_tables_spark(table) for table in ("DRUG", "REAC", "DEMO", "OUTC")
)


In [9]:
from pyspark.sql import functions as F

def clean_drug_table_spark(df, max_missing_frac=0.6):
    """Filter and tidy the raw FAERS DRUG table.

    Steps:
      1. Drop columns not used downstream (only those that exist).
      2. Lower-case and trim ``drugname``.
      3. Keep rows whose ``drugname`` is in ``DRUG_LIST_LOWER`` and whose
         ``role_cod`` is ``"PS"`` or ``"SS"``.
      4. Keep one row per (``caseid``, ``drugname``).
      5. Replace the literal string ``"Unknown"`` with null in every column.
      6. Drop columns whose null count exceeds ``max_missing_frac`` of the rows.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Raw DRUG table, e.g. from ``read_faers_tables_spark("DRUG")``.
    max_missing_frac : float, default 0.6
        Columns with strictly more than this fraction of nulls are dropped.

    Returns
    -------
    pyspark.sql.DataFrame
    """
    drop_cols = [
        'val_vbm','dose_vbm','cum_dose_chr','cum_dose_unit',
        'dechal','rechal','lot_num','exp_dt','nda_num','primaryid'
    ]
    df = df.drop(*[c for c in drop_cols if c in df.columns])

    df = df.withColumn("drugname", F.lower(F.trim(F.col("drugname"))))

    df = df.filter(
        F.col("drugname").isin(DRUG_LIST_LOWER)
        & F.col("role_cod").isin("PS", "SS")
    )

    # Which row survives for a duplicated key is not specified by Spark.
    df = df.dropDuplicates(["caseid", "drugname"])

    # One projection instead of a withColumn per column keeps the plan shallow.
    df = df.select([
        F.when(F.col(c) == "Unknown", None).otherwise(F.col(c)).alias(c)
        for c in df.columns
    ])

    # Row count and per-column null counts in a single aggregation job.
    stats = df.agg(
        F.count(F.lit(1)).alias("__total__"),
        *[F.count(F.when(F.col(c).isNull(), 1)).alias(c) for c in df.columns]
    ).collect()[0].asDict()
    total_count = stats.pop("__total__")
    missing_thresh = max_missing_frac * total_count

    drop_missing_cols = [c for c, cnt in stats.items() if cnt > missing_thresh]
    return df.drop(*drop_missing_cols)

# --- Usage ---
drug_clean = clean_drug_table_spark(drug_raw)
drug_clean.show(5)


+--------+--------+--------+-------------+--------------------+
|  caseid|drug_seq|role_cod|     drugname|             prod_ai|
+--------+--------+--------+-------------+--------------------+
|10182318|       2|      SS|  pramipexole|PRAMIPEXOLE\PRAMI...|
|10365773|       5|      SS|oxcarbazepine|       OXCARBAZEPINE|
|10462505|       1|      PS|carbamazepine|       CARBAMAZEPINE|
|10497080|       1|      PS|    donepezil|           DONEPEZIL|
|10609166|       3|      SS|   lacosamide|          LACOSAMIDE|
+--------+--------+--------+-------------+--------------------+
only showing top 5 rows


In [10]:
from pyspark.sql import functions as F

def clean_reac_table_spark(df):
    """Collapse the FAERS REAC table to one row per case.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Raw REAC table with at least ``caseid`` and ``pt`` columns.

    Returns
    -------
    pyspark.sql.DataFrame
        Columns ``caseid`` and ``pt``. ``pt`` holds the case's distinct non-null
        reaction terms, sorted ascending and joined with ``", "``. It is the
        empty string when a case has only null terms.
    """
    return (
        df.select("caseid", "pt")
        .groupBy("caseid")
        .agg(
            # Dedupe and sort as an array, then join once. Building a string and
            # splitting it again would break apart any term containing ", ".
            F.array_join(
                F.sort_array(F.array_distinct(F.collect_list(F.col("pt")))),
                ", ",
            ).alias("pt")
        )
    )

# Usage
reac_clean = clean_reac_table_spark(reac_raw)
reac_clean.show(5, truncate=False)


+--------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|caseid  |pt                                                                                                                                                                                                                                                                                                                                                                                           

In [11]:
from pyspark.sql import functions as F

def clean_demo_table_spark(df, min_non_null_frac=0.4):
    """Reduce the raw FAERS DEMO table to one row per case with usable columns.

    Steps:
      1. Keep one row per ``caseid``. When ``caseversion`` exists, keep the
         highest version; otherwise keep an unspecified row.
      2. Drop reporting/administrative columns not used downstream (only those
         that exist).
      3. Drop columns whose non-null count is below ``min_non_null_frac`` of
         the row count.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Raw DEMO table, e.g. from ``read_faers_tables_spark("DEMO")``.
    min_non_null_frac : float, default 0.4
        Minimum fraction of non-null values a column needs to be kept
        (0.4 means columns with more than 60% missing are dropped).

    Returns
    -------
    pyspark.sql.DataFrame
    """
    from pyspark.sql import Window

    # DRUG, REAC and OUTC are each reduced to unique keys per caseid before the
    # caseid joins. DEMO must be too, or repeated caseids multiply joined rows.
    # Repeats are presumably follow-up versions of a case across quarters.
    if "caseversion" in df.columns:
        newest_first = Window.partitionBy("caseid").orderBy(
            F.col("caseversion").cast("int").desc_nulls_last()
        )
        df = (
            df.withColumn("__version_rank__", F.row_number().over(newest_first))
            .filter(F.col("__version_rank__") == 1)
            .drop("__version_rank__")
        )
    else:
        df = df.dropDuplicates(["caseid"])

    drop_cols = [
        'rept_cod','to_mfr','caseversion','i_f_code','event_dt',
        'mfr_dt','init_fda_dt','fda_dt','rept_dt','occp_cod',
        'reporter_country','e_sub',
        'primaryid', 'mfr_num', 'mfr_sndr'
    ]
    df = df.drop(*[c for c in drop_cols if c in df.columns])

    # Row count and per-column non-null counts in a single aggregation job
    # (F.count(col) counts non-null values).
    stats = df.agg(
        F.count(F.lit(1)).alias("__total__"),
        *[F.count(F.col(c)).alias(c) for c in df.columns]
    ).collect()[0].asDict()
    total_count = stats.pop("__total__")
    thresh_count = min_non_null_frac * total_count

    drop_missing_cols = [c for c, cnt in stats.items() if cnt < thresh_count]
    return df.drop(*drop_missing_cols)

# Usage
demo_clean = clean_demo_table_spark(demo_raw)
demo_clean.show(5)


+--------+---+-------+---+------------+
|  caseid|age|age_cod|sex|occr_country|
+--------+---+-------+---+------------+
|10016781| 56|     YR|  F|          CA|
|10028721| 57|     YR|  F|          CA|
|10029366| 32|     YR|  M|          AU|
|10054507| 68|     YR|  F|          US|
|10057621| 57|     YR|  M|          CA|
+--------+---+-------+---+------------+
only showing top 5 rows


In [12]:
from pyspark.sql import functions as F

def clean_outc_table_spark(df):
    """Collapse the FAERS OUTC table to one row per case.

    Output columns are ``caseid`` and ``outc_cod``. ``outc_cod`` holds the
    case's distinct non-null outcome codes, sorted ascending and comma-joined
    without spaces (e.g. ``"HO,OT"``).
    """
    distinct_codes = F.collect_set(F.col("outc_cod"))
    joined_codes = F.array_join(F.sort_array(distinct_codes), ",")
    return (
        df.select("caseid", "outc_cod")
        .groupBy("caseid")
        .agg(joined_codes.alias("outc_cod"))
    )

# Usage
outc_clean = clean_outc_table_spark(outc_raw)
outc_clean.show(5, truncate=False)


+--------+--------+
|caseid  |outc_cod|
+--------+--------+
|10032297|HO,OT   |
|10142803|DE,OT   |
|10167865|DE,OT   |
|10182318|OT      |
|10201017|HO      |
+--------+--------+
only showing top 5 rows


In [13]:
# --- Step 1: make the join key a string in every table ---
from pyspark.sql import functions as F

drug_clean, reac_clean, demo_clean, outc_clean = (
    table.withColumn("caseid", F.col("caseid").cast("string"))
    for table in (drug_clean, reac_clean, demo_clean, outc_clean)
)

# --- Steps 2-4: drug ⋈ reac (inner) ⋈ demo (inner), then outc (left) on caseid ---
# The inner joins keep only cases present in drug, reac and demo. The left join
# keeps those cases even when they have no outcome row.
final_df = (
    drug_clean
    .join(reac_clean, on="caseid", how="inner")
    .join(demo_clean, on="caseid", how="inner")
    .join(outc_clean, on="caseid", how="left")
)

# Show the first 20 rows with full, untruncated text
final_df.show(20, truncate=False)



+--------+--------+--------+-------------+---------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----+-------+----+------------+--------+
|caseid  |drug_seq|role_cod|drugname     |prod_ai                                |pt                                                                                                                                                                                                                                                                                                         

In [14]:
pip install numpy


Note: you may need to restart the kernel to use updated packages.


In [15]:
from pyspark.sql import functions as F

# Build one FP-Growth basket per case: its distinct drugs plus its reaction terms.
df_tx = (
    final_df
    # Split on exactly the ", " separator used to build `pt`, so a term that
    # contains a bare comma is not broken apart.
    .withColumn("pt_list", F.split(F.col("pt"), ", "))
    .groupBy("caseid")
    .agg(
        F.collect_set("drugname").alias("drug_list"),
        # `pt` is identical on every row of a case (REAC was aggregated per caseid).
        F.first("pt_list").alias("pt_list"),
    )
    .withColumn(
        "items",
        # array_remove drops the "" item contributed by a case whose terms were all null.
        F.array_remove(
            F.array_distinct(F.concat(F.col("drug_list"), F.col("pt_list"))),
            "",
        ),
    )
)



In [16]:
# FP-Growth input: only non-null baskets with at least two items, array column only.
transactions = df_tx.where(
    F.col("items").isNotNull() & (F.size("items") > 1)
).select("items")


In [17]:
# Inspect the basket schema, then confirm no basket with fewer than two items remains (expect 0).
transactions.printSchema()
transactions.where(F.size("items") <= 1).count()


root
 |-- items: array (nullable = true)
 |    |-- element: string (containsNull = false)



0

In [18]:
from pyspark.ml.fpm import FPGrowth

# Mine frequent itemsets (support >= 1% of baskets) and rules (confidence >= 10%).
fp = FPGrowth(
    itemsCol="items",
    minConfidence=0.1,
    minSupport=0.01,
)
model = fp.fit(transactions)


In [19]:
rules = model.associationRules

# Baskets mix drug items and reaction terms. A size-2 antecedent may therefore
# contain reaction terms, and the consequent may be a drug. Keep only genuine
# {drug, drug} -> {reaction} rules. Drug items are the lower-cased drugnames kept
# by the DRUG cleaning step, all of which come from DRUG_LIST_LOWER.
# NOTE(review): assumes no reaction term is spelled exactly like a lower-cased drug name.
rules_ddi = (
    rules
    .filter(F.size("antecedent") == 2)
    .filter(F.size("consequent") == 1)
    .filter(
        F.col("antecedent")[0].isin(DRUG_LIST_LOWER)
        & F.col("antecedent")[1].isin(DRUG_LIST_LOWER)
        & ~F.col("consequent")[0].isin(DRUG_LIST_LOWER)
    )
    .select(
        F.col("antecedent")[0].alias("DrugA"),
        F.col("antecedent")[1].alias("DrugB"),
        F.col("consequent")[0].alias("ADR"),
        F.col("lift").alias("Lift_2Drugs"),
        "support"
    )
)


In [20]:
import warnings

# Heuristic cut-offs for the severity label.
SEVERE_MIN_LIFT = 10
SEVERE_MAX_SUPPORT = 0.002
MODERATE_MIN_LIFT = 2

# Every mined rule's itemset is frequent, so its support is >= FPGrowth's
# minSupport. If the "rare" cut-off does not exceed that, "Severe (rare)" can
# never be assigned — surface that instead of silently producing no such rows.
if SEVERE_MAX_SUPPORT <= fp.getMinSupport():
    warnings.warn(
        f"'Severe (rare)' is unreachable: every rule has support >= minSupport "
        f"({fp.getMinSupport()}), but the rare cut-off is {SEVERE_MAX_SUPPORT}. "
        "Lower FPGrowth minSupport or raise SEVERE_MAX_SUPPORT."
    )

rules_ddi = rules_ddi.withColumn(
    "Severity",
    F.when(
        (F.col("Lift_2Drugs") >= SEVERE_MIN_LIFT) & (F.col("support") < SEVERE_MAX_SUPPORT),
        "Severe (rare)",
    )
    .when(F.col("Lift_2Drugs") >= MODERATE_MIN_LIFT, "Moderate")
    .otherwise("Mild")
)


In [24]:
pip install streamlit


Note: you may need to restart the kernel to use updated packages.
