In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pip install pyspark



In [2]:
# Environment setup
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("opioid").getOrCreate()

In [None]:
# Load data
carrier = spark.read.option("header", "true") \
                    .option("delimiter", "|") \
                    .csv("/content/drive/MyDrive/carrier.csv", header=True)


pde = spark.read.option("header", "true") \
                .option("delimiter", "|") \
                .csv("/content/drive/MyDrive/pde.csv")

In [None]:
carrier.write.mode("overwrite").parquet("/content/drive/MyDrive/carrier_parquet")
pde.write.mode("overwrite").parquet("/content/drive/MyDrive/pde_parquet")

carrier_p     = spark.read.parquet("/content/drive/MyDrive/carrier_parquet")
pde_p         = spark.read.parquet("/content/drive/MyDrive/pde_parquet")

In [4]:
# Standardize IDs and parse dates
def normalize(df):
    for c in df.columns:
        df = df.withColumnRenamed(c, c.lower())
    return df

carrier_p     = normalize(carrier_p)
pde_p         = normalize(pde_p)

In [5]:
from pyspark.sql.functions import col, regexp_replace, to_date

carrier_silver = carrier_p.select("BENE_ID", "CLM_ID", "PRNCPAL_DGNS_CD", "PRF_PHYSN_NPI", "CLM_FROM_DT", "CLM_THRU_DT")
carrier_silver = (carrier_silver.withColumn("BENE_ID", regexp_replace(col("BENE_ID").cast("string"), "-", ""))
                                .withColumn("CLM_ID",  regexp_replace(col("CLM_ID").cast("string"), "-", ""))
                                .withColumn("CLM_FROM_DT", to_date(col("CLM_FROM_DT"), "dd-MMM-yyyy"))
                                .withColumn("CLM_THRU_DT", to_date(col("CLM_THRU_DT"), "dd-MMM-yyyy")))
carrier_silver.show(5)

+--------------+--------------+---------------+-------------+-----------+-----------+
|       BENE_ID|        CLM_ID|PRNCPAL_DGNS_CD|PRF_PHYSN_NPI|CLM_FROM_DT|CLM_THRU_DT|
+--------------+--------------+---------------+-------------+-----------+-----------+
|10000010273042|10000930854995|           Z733|   9999992693| 2019-10-22| 2019-10-22|
|10000010273042|10000930854995|           Z733|   9999992693| 2019-10-22| 2019-10-22|
|10000010273042|10000930854995|           Z733|   9999992693| 2019-10-22| 2019-10-22|
|10000010273042|10000930854995|           Z733|   9999992693| 2019-10-22| 2019-10-22|
|10000010273042|10000930854995|           Z733|   9999992693| 2019-10-22| 2019-10-22|
+--------------+--------------+---------------+-------------+-----------+-----------+
only showing top 5 rows


In [6]:
pde_silver = pde_p.select("PDE_ID", "BENE_ID", "SRVC_DT", "PROD_SRVC_ID", "QTY_DSPNSD_NUM", "DAYS_SUPLY_NUM", "BRND_GNRC_CD", "PRSCRBR_ID")
pde_silver = (pde_silver.withColumn("PDE_ID", regexp_replace(col("PDE_ID").cast("string"),"-", ""))
                        .withColumn("BENE_ID", regexp_replace(col("BENE_ID").cast("string"), "-", ""))
                        .withColumn("SRVC_DT", to_date(col("SRVC_DT"), "dd-MMM-yyyy")))
pde_silver.show(5)

+-----------+--------------+----------+------------+--------------+--------------+------------+----------+
|     PDE_ID|       BENE_ID|   SRVC_DT|PROD_SRVC_ID|QTY_DSPNSD_NUM|DAYS_SUPLY_NUM|BRND_GNRC_CD|PRSCRBR_ID|
+-----------+--------------+----------+------------+--------------+--------------+------------+----------+
|10602140347|10000010254618|2015-03-25| 68115025030|            63|            63|           G|9999987089|
|10602140348|10000010254618|2016-05-27| 53978010903|             7|             7|           B|9999999569|
|10602140349|10000010254618|2016-10-03| 55154010000|            90|            90|           B|9999997109|
|10602140350|10000010254618|2017-09-20| 13107021199|            40|            10|           B|9999999569|
|10602140351|10000010254618|2017-09-30| 13107021199|            40|            10|           G|9999999569|
+-----------+--------------+----------+------------+--------------+--------------+------------+----------+
only showing top 5 rows


In [7]:
opioid_ndc = (spark.read.parquet("/content/drive/MyDrive/Parquet")
              .withColumn("ndc", regexp_replace(col("ndc"), "'", "")).select("ndc").distinct())

In [8]:
# Flag opioid fills
from pyspark.sql.functions import lit

opioid_flag = opioid_ndc.withColumn("is_opioid", lit(1))

pde_flagged = (pde_silver.join(opioid_flag, pde_silver.PROD_SRVC_ID == opioid_flag.ndc, "left")
               .drop(opioid_flag.ndc)
               .fillna({"is_opioid": 0}))
pde_flagged.show(5)

+-----------+--------------+----------+------------+--------------+--------------+------------+----------+---------+
|     PDE_ID|       BENE_ID|   SRVC_DT|PROD_SRVC_ID|QTY_DSPNSD_NUM|DAYS_SUPLY_NUM|BRND_GNRC_CD|PRSCRBR_ID|is_opioid|
+-----------+--------------+----------+------------+--------------+--------------+------------+----------+---------+
|10602140347|10000010254618|2015-03-25| 68115025030|            63|            63|           G|9999987089|        0|
|10602140348|10000010254618|2016-05-27| 53978010903|             7|             7|           B|9999999569|        0|
|10602140349|10000010254618|2016-10-03| 55154010000|            90|            90|           B|9999997109|        0|
|10602140350|10000010254618|2017-09-20| 13107021199|            40|            10|           B|9999999569|        0|
|10602140351|10000010254618|2017-09-30| 13107021199|            40|            10|           G|9999999569|        0|
+-----------+--------------+----------+------------+------------

In [9]:
import pandas as pd

a = pd.read_excel("/content/drive/MyDrive/section-111_excluded_icd9_october2025.xlsx")
a1 = spark.createDataFrame(a)

b = pd.read_excel("/content/drive/MyDrive/section111_excluded_icd10_october2025.xlsx")
b1 = spark.createDataFrame(b)

c = pd.read_excel("/content/drive/MyDrive/section111_valid_icd10_october2025.xlsx")
c1 = spark.createDataFrame(c)

d = pd.read_excel("/content/drive/MyDrive/section111_valid_icd9_october2025.xlsx")
d1 = spark.createDataFrame(d)

In [10]:
from pyspark.sql.functions import col, lit, trim, upper, regexp_replace

def norm_code(c):
    return upper(regexp_replace(trim(c), r"[.\s]", ""))

# Excluded ICD-9
icd9_excl_std = (
    a1.select(
        norm_code(col("CODE")).alias("icd_code"),
        trim(col("LONG DESCRIPTION (EXCLUDED FY2026 ICD-9 ALL TYPES E, L, D)")).alias("icd_desc_long"))
    .withColumn("icd_desc_short", lit(None).cast("string"))
    .withColumn("nf_excl", lit(None).cast("string"))
    .withColumn("icd_version", lit("ICD9"))
    .withColumn("is_valid", lit(0)))

# Excluded ICD-10
icd10_excl_std = (
    b1.select(
        norm_code(col("CODE")).alias("icd_code"),
        trim(col("LONG DESCRIPTION (EXCLUDED FY2026 ICD-10 ALL TYPES E, L, D)")).alias("icd_desc_long"))
    .withColumn("icd_desc_short", lit(None).cast("string"))
    .withColumn("nf_excl", lit(None).cast("string"))
    .withColumn("icd_version", lit("ICD10"))
    .withColumn("is_valid", lit(0)))

# Valid ICD-10
icd10_val_std = (
    c1.select(
        norm_code(col("CODE")).alias("icd_code"),
        trim(col("SHORT DESCRIPTION (VALID ICD-10 FY2026)")).alias("icd_desc_short"),
        trim(col("LONG DESCRIPTION (VALID ICD-10 FY2026)")).alias("icd_desc_long"),
        trim(col("NF EXCL")).alias("nf_excl"))
    .withColumn("icd_version", lit("ICD10"))
    .withColumn("is_valid", lit(1)))

# Valid ICD-9
icd9_val_std = (
    d1.select(
        norm_code(col("CODE")).alias("icd_code"),
        trim(col("LONG DESCRIPTION (VALID ICD-9 FY2026)")).alias("icd_desc_long"),
        trim(col("NF EXCL")).alias("nf_excl"))
    .withColumn("icd_desc_short", lit(None).cast("string"))
    .withColumn("icd_version", lit("ICD9"))
    .withColumn("is_valid", lit(1)))

In [11]:
# ICD descriptions
icd_ref = (icd9_val_std
           .unionByName(icd9_excl_std)
           .unionByName(icd10_val_std)
           .unionByName(icd10_excl_std)
           .dropDuplicates(["icd_code", "icd_version", "is_valid"]))

icd_ref.show(5, truncate=False)

+--------+-------------------------------------------+-------+--------------+-----------+--------+
|icd_code|icd_desc_long                              |nf_excl|icd_desc_short|icd_version|is_valid|
+--------+-------------------------------------------+-------+--------------+-----------+--------+
|0022    |Paratyphoid fever B                        |NaN    |NULL          |ICD9       |1       |
|0023    |Paratyphoid fever C                        |NaN    |NULL          |ICD9       |1       |
|0030    |Salmonella gastroenteritis                 |NaN    |NULL          |ICD9       |1       |
|00320   |Localized salmonella infection, unspecified|NaN    |NULL          |ICD9       |1       |
|00324   |Salmonella osteomyelitis                   |NaN    |NULL          |ICD9       |1       |
+--------+-------------------------------------------+-------+--------------+-----------+--------+
only showing top 5 rows


In [12]:
icd_ref.write.mode("overwrite").parquet("/content/drive/MyDrive/icd_reference_parquet")

In [13]:
from pyspark.sql.functions import col, upper, lower, regexp_replace, trim

# Standardize diagnosis codes to improve join match rate with ICD reference
carrier_clean = (carrier_silver.withColumn("PRNCPAL_DGNS_CD", upper(regexp_replace(trim(col("PRNCPAL_DGNS_CD")), r"[.\s]", ""))))

# Attach ICD long description for interpretability
final = (carrier_clean
         .join(icd_ref, carrier_clean.PRNCPAL_DGNS_CD == icd_ref.icd_code, "left")
         .select("BENE_ID", "CLM_ID", "PRNCPAL_DGNS_CD", "icd_desc_long", "PRF_PHYSN_NPI", "CLM_FROM_DT", "CLM_THRU_DT")
         .withColumnRenamed("icd_desc_long", "ICD_Description"))

# Define pain-related cohort using keywords in ICD description
final_data = final.filter(lower(col("ICD_Description")).rlike(r"\b(back pain|neck pain|sprain|headache|headaches)\b"))
final_data.show(5, truncate=False)

+--------------+--------------+---------------+----------------------------------------------------------+-------------+-----------+-----------+
|BENE_ID       |CLM_ID        |PRNCPAL_DGNS_CD|ICD_Description                                           |PRF_PHYSN_NPI|CLM_FROM_DT|CLM_THRU_DT|
+--------------+--------------+---------------+----------------------------------------------------------+-------------+-----------+-----------+
|10000010260050|10000930282679|G43C1          |Periodic headache syndromes in child or adult, intractable|9999952390   |2022-08-18 |2022-08-18 |
|10000010260050|10000930282679|G43C1          |Periodic headache syndromes in child or adult, intractable|9999952390   |2022-08-18 |2022-08-18 |
|10000010260050|10000930282679|G43C1          |Periodic headache syndromes in child or adult, intractable|9999952390   |2022-08-18 |2022-08-18 |
|10000010260050|10000930282679|G43C1          |Periodic headache syndromes in child or adult, intractable|9999952390   |2022-08-18

In [14]:
from pyspark.sql.functions import col, date_add

# Remove duplicate claim rows to avoid over-counting when joining to PDEs
final_data_d = final_data.dropDuplicates(["BENE_ID", "CLM_ID", "CLM_FROM_DT", "CLM_THRU_DT", "PRNCPAL_DGNS_CD"])

c = final_data_d.alias("c")
p = pde_flagged.alias("p")

# Link fills within claim window through +7 days after claim end
df = (c.join(p,
        (col("c.BENE_ID") == col("p.BENE_ID")) &
        (col("p.SRVC_DT").between(col("c.CLM_FROM_DT"), date_add(col("c.CLM_THRU_DT"), 7))), "left"))

total_claims = final_data.count()
matched_rows = df.filter(col("PDE_ID").isNotNull()).count()
opioid_matches = df.filter((col("PDE_ID").isNotNull()) & (col("is_opioid") == 1)).count()

print("Total carrier claims:", total_claims)
print("Claims with >=1 PDE match in 0-7 days:", matched_rows)
print("Opioid matches (within matched rows):", opioid_matches)

df.show(10, truncate=False)

Total carrier claims: 25463
Claims with >=1 PDE match in 0-7 days: 5906
Opioid matches (within matched rows): 769
+--------------+--------------+---------------+--------------------------+-------------+-----------+-----------+-----------+--------------+----------+------------+--------------+--------------+------------+----------+---------+
|BENE_ID       |CLM_ID        |PRNCPAL_DGNS_CD|ICD_Description           |PRF_PHYSN_NPI|CLM_FROM_DT|CLM_THRU_DT|PDE_ID     |BENE_ID       |SRVC_DT   |PROD_SRVC_ID|QTY_DSPNSD_NUM|DAYS_SUPLY_NUM|BRND_GNRC_CD|PRSCRBR_ID|is_opioid|
+--------------+--------------+---------------+--------------------------+-------------+-----------+-----------+-----------+--------------+----------+------------+--------------+--------------+------------+----------+---------+
|10000010254667|10000930038319|M5450          |Low back pain, unspecified|9999968891   |2015-03-27 |2015-03-27 |10602140490|10000010254667|2015-03-27|60951091005 |3             |10            |B        

In [15]:
from pyspark.sql.functions import col, max as spark_max, countDistinct, sum as spark_sum

# Collapse the join output (many rows per claim due to multiple PDE fills) into 1 row per claim
claim_level = (df                          # One row per unique carrier claim
    .groupBy("c.BENE_ID", "c.CLM_ID", "c.CLM_FROM_DT", "c.CLM_THRU_DT", "c.PRNCPAL_DGNS_CD", "c.ICD_Description", "c.PRF_PHYSN_NPI")
    .agg(
        spark_max((col("p.PDE_ID").isNotNull()).cast("int")).alias("has_any_pde_0_7"),
        spark_max((col("p.is_opioid") == 1).cast("int")).alias("has_opioid_pde_0_7"),
        countDistinct("p.PDE_ID").alias("num_pde_fills_0_7"),
        spark_sum((col("p.is_opioid") == 1).cast("int")).alias("num_opioid_fills_0_7"))
    )

claim_level.show(5, truncate=False)

+--------------+--------------+-----------+-----------+---------------+--------------------------+-------------+---------------+------------------+-----------------+--------------------+
|BENE_ID       |CLM_ID        |CLM_FROM_DT|CLM_THRU_DT|PRNCPAL_DGNS_CD|ICD_Description           |PRF_PHYSN_NPI|has_any_pde_0_7|has_opioid_pde_0_7|num_pde_fills_0_7|num_opioid_fills_0_7|
+--------------+--------------+-----------+-----------+---------------+--------------------------+-------------+---------------+------------------+-----------------+--------------------+
|10000010254667|10000930038319|2015-03-27 |2015-03-27 |M5450          |Low back pain, unspecified|9999968891   |1              |0                 |3                |0                   |
|10000010254667|10000930038320|2016-04-01 |2016-04-01 |M5450          |Low back pain, unspecified|9999968891   |1              |1                 |3                |3                   |
|10000010254667|10000930038321|2017-04-07 |2017-04-07 |M5450     

In [16]:
# Provider-level prescribing metrics
from pyspark.sql.functions import count, avg

provider_metrics = (claim_level
    .groupBy("PRF_PHYSN_NPI")
    .agg(
        count("*").alias("pain_claims"),
        spark_sum("has_any_pde_0_7").alias("claims_with_any_pde"), # How many of those claims had any PDE fill within 0–7 days
        spark_sum("has_opioid_pde_0_7").alias("claims_with_opioid"), # How many of those claims had an opioid fill within 0–7 days
        avg("has_opioid_pde_0_7").alias("opioid_rate_per_pain_claim"), # Rate share of pain claims that resulted in an opioid fill within 0–7 days
        spark_sum("num_opioid_fills_0_7").alias("total_opioid_fills_0_7")) # Total opioid fills linked in the window
    )

provider_metrics.orderBy(col("opioid_rate_per_pain_claim").desc()).show(10, truncate=False)

+-------------+-----------+-------------------+------------------+--------------------------+----------------------+
|PRF_PHYSN_NPI|pain_claims|claims_with_any_pde|claims_with_opioid|opioid_rate_per_pain_claim|total_opioid_fills_0_7|
+-------------+-----------+-------------------+------------------+--------------------------+----------------------+
|9999999698   |5          |5                  |5                 |1.0                       |6                     |
|9999974790   |8          |3                  |3                 |1.0                       |4                     |
|9999998898   |2          |2                  |2                 |1.0                       |2                     |
|9999965996   |4          |4                  |4                 |1.0                       |4                     |
|9999974592   |2          |1                  |1                 |1.0                       |3                     |
|9999925396   |7          |1                  |1                

In [17]:
# Top N providers by opioid rate
provider_metrics.orderBy(col("opioid_rate_per_pain_claim").desc()).show(10, truncate=False)

+-------------+-----------+-------------------+------------------+--------------------------+----------------------+
|PRF_PHYSN_NPI|pain_claims|claims_with_any_pde|claims_with_opioid|opioid_rate_per_pain_claim|total_opioid_fills_0_7|
+-------------+-----------+-------------------+------------------+--------------------------+----------------------+
|9999999698   |5          |5                  |5                 |1.0                       |6                     |
|9999974790   |8          |3                  |3                 |1.0                       |4                     |
|9999998898   |2          |2                  |2                 |1.0                       |2                     |
|9999965996   |4          |4                  |4                 |1.0                       |4                     |
|9999974592   |2          |1                  |1                 |1.0                       |3                     |
|9999925396   |7          |1                  |1                

In [18]:
# Z-score outliers

from pyspark.sql.functions import mean, stddev

stats = provider_metrics.select(
    mean("opioid_rate_per_pain_claim").alias("mu"),
    stddev("opioid_rate_per_pain_claim").alias("sigma")
).collect()[0]

mu = stats["mu"]
sigma = stats["sigma"]

provider_outliers = (provider_metrics
    .withColumn("z_opioid_rate", (col("opioid_rate_per_pain_claim") - mu) / sigma)
    .filter(col("pain_claims") >= 30)
    .orderBy(col("z_opioid_rate").desc())
)

provider_outliers.show(10, truncate=False)

+-------------+-----------+-------------------+------------------+--------------------------+----------------------+-------------------+
|PRF_PHYSN_NPI|pain_claims|claims_with_any_pde|claims_with_opioid|opioid_rate_per_pain_claim|total_opioid_fills_0_7|z_opioid_rate      |
+-------------+-----------+-------------------+------------------+--------------------------+----------------------+-------------------+
|9999994897   |35         |21                 |18                |0.8571428571428571        |31                    |1.0399313228114184 |
|9999642397   |54         |46                 |39                |0.8478260869565217        |79                    |1.0169116525506419 |
|9999777292   |46         |34                 |23                |0.6764705882352942        |62                    |0.5935302661857687 |
|9999838995   |93         |91                 |6                 |0.06593406593406594       |23                    |-0.9149699054883872|
|9999996595   |62         |21            

In [26]:
# Compute provider-level opioid fill counts and the % of fills with longer day-supply (>=3/7/10 days),
# then rank providers to highlight potential overprescribing patterns
from pyspark.sql.functions import avg, countDistinct

opioid_fills = df.filter(col("is_opioid") == 1)


provider_opioid_fill_metrics2 = (opioid_fills
    .groupBy("c.PRF_PHYSN_NPI")
    .agg(
        countDistinct("p.PDE_ID").alias("opioid_fills"),
        avg((col("p.DAYS_SUPLY_NUM") >= 3).cast("int")).alias("pct_3plus"),
        avg((col("p.DAYS_SUPLY_NUM") >= 7).cast("int")).alias("pct_7plus"),
        avg((col("p.DAYS_SUPLY_NUM") >= 10).cast("int")).alias("pct_10plus"),)
    )

provider_opioid_fill_metrics2.orderBy(col("opioid_fills").desc(),col("pct_10plus").desc()).show(10, truncate=False)

+-------------+------------+---------+---------+-------------------+
|PRF_PHYSN_NPI|opioid_fills|pct_3plus|pct_7plus|pct_10plus         |
+-------------+------------+---------+---------+-------------------+
|9999642397   |71          |1.0      |1.0      |0.6329113924050633 |
|9999777292   |51          |1.0      |1.0      |0.3548387096774194 |
|9999994897   |23          |1.0      |1.0      |0.5161290322580645 |
|9999939298   |17          |1.0      |1.0      |0.7647058823529411 |
|9999838995   |17          |1.0      |1.0      |0.43478260869565216|
|9999912295   |16          |1.0      |1.0      |0.5625             |
|9999874495   |15          |1.0      |1.0      |0.8666666666666667 |
|9999908897   |14          |1.0      |1.0      |1.0                |
|9999968891   |14          |1.0      |1.0      |1.0                |
|9999970095   |12          |1.0      |1.0      |1.0                |
+-------------+------------+---------+---------+-------------------+
only showing top 10 rows
