# **Libraries and Data**

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import round
from pyspark.sql.functions import col, count, isnan, isnull, countDistinct, sum, when

spark = SparkSession.builder.getOrCreate()


def _read_tsv(path):
    """Load a tab-delimited text file that has a header row, letting Spark infer column types."""
    reader = (
        spark.read
        .option("header", True)
        .option("delimiter", "\t")
        .option("inferSchema", True)
    )
    return reader.csv(path)


num = _read_tsv("dbfs:/FileStore/Data/2024q1/num.txt")

tag = _read_tsv("dbfs:/FileStore/Data/2024q1/tag.txt")

pre = _read_tsv("dbfs:/FileStore/Data/2024q1/pre.txt")

sub = _read_tsv("dbfs:/FileStore/Data/2024q1/sub.txt")


## Defining the generate_df_info function for df summary generation

### Resulting DF contains
* **`cols`**

  * **Description**: The name of each column in the original DataFrame.

* **`data_types`**

  * **Description**: The Spark‐inferred data type for that column (e.g., “string,” “int,” “timestamp”).

* **`null_percentage`**

  * **Description**: The percentage of rows in which that column is null (empty).
  * **Importance**: Shows how incomplete a column is; if it’s mostly null, we might drop it or treat it differently in our pipeline.

* **`distinct_count`**

  * **Description**: The number of unique, non‐null values that column contains.
  * **Importance**: Helps us understand cardinality. A low distinct count often means a small set of categories (good for dimension tables).

* **`top1_value`**

  * **Description**: The single most frequent (non‐null) value that appears in that column.
  * **Importance**: Quickly tells you if one category dominates. For example, if “USP” appears 95% of the time, that column may not be very informative for analysis.

* **`top1_value_count`**

  * **Description**: How many times the `top1_value` appears.
  * **Importance**: Paired with `distinct_count`, it shows us how skewed the distribution is. A high count relative to total rows indicates low variability.

* **`top1_value_percentage`**

  * **Description**: The `top1_value_count` expressed as a percentage of total rows.
  * **Importance**: Gives immediate context for how dominant the top value is (e.g., 80% means four out of five rows share that same value, suggesting low diversity in that column).


In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *


def generate_df_info(df):
    """
    Build a per-column profile of a Spark DataFrame.

    Given any Spark DataFrame `df`, returns a summary DataFrame with one row per
    column of `df`, in `df.columns` order, containing:
      - cols                  : column name
      - data_types            : Spark type string as reported by `df.dtypes`
      - null_percentage       : % of rows in which the column is null
      - distinct_count        : number of distinct non-null values (bigint)
      - top1_value            : most frequent non-null value (string columns only)
      - top1_value_count      : occurrences of `top1_value` (bigint, string columns only)
      - top1_value_percentage : `top1_value_count` as a % of all rows (string columns only)

    The three `top1_*` fields are null for non-string columns. A string column
    with no non-null values gets (None, 0, 0.0). Ties for the most frequent value
    are broken by the smallest value, so the result is reproducible. For an empty
    DataFrame every percentage is 0.0 instead of raising ZeroDivisionError.

    Uses the notebook-level `spark` session to build the result.

    Parameters:
        df (DataFrame): any Spark DataFrame.

    Returns:
        DataFrame: the per-column summary described above.
    """
    schema = StructType([
        StructField("cols", StringType(), nullable=False),
        StructField("data_types", StringType(), nullable=False),
        StructField("null_percentage", DoubleType(), nullable=False),
        # Spark counts are bigint and can exceed the 32-bit int range on large tables.
        StructField("distinct_count", LongType(), nullable=False),
        StructField("top1_value", StringType(), nullable=True),
        StructField("top1_value_count", LongType(), nullable=True),
        StructField("top1_value_percentage", DoubleType(), nullable=True),
    ])

    columns = df.columns
    if not columns:
        return spark.createDataFrame([], schema=schema)

    # 1. Row count and per-column null counts in a single pass. Positional aliases
    #    (`__null_0`, ...) cannot collide with, or be mangled back into, real column names.
    null_exprs = [
        sum(when(col(c).isNull(), 1).otherwise(0)).alias(f"__null_{i}")
        for i, c in enumerate(columns)
    ]
    null_row = df.agg(count(lit(1)).alias("__total_rows"), *null_exprs).collect()[0]
    total_rows = null_row["__total_rows"]

    def pct(part):
        """Return `part` as a percentage of all rows (0.0 for an empty DataFrame)."""
        return float(part) * 100.0 / total_rows if total_rows else 0.0

    # 2. Distinct (non-null) count per column.
    distinct_row = df.agg(*[
        countDistinct(col(c)).alias(f"__distinct_{i}")
        for i, c in enumerate(columns)
    ]).collect()[0]

    # 3. Assemble one row per column; string columns also get their top-1 value.
    rows = []
    for i, (name, dtype) in enumerate(df.dtypes):
        top1_value = top1_count = top1_pct = None
        if dtype == "string":
            top = (
                df.filter(col(name).isNotNull())
                  .groupBy(name)
                  # Explicit alias: `.count()` would clash with a column named "count".
                  .agg(count(lit(1)).alias("__count"))
                  # Secondary sort on the value makes ties deterministic.
                  .orderBy(desc("__count"), col(name))
                  .limit(1)
                  .collect()
            )
            if top:
                top1_value, top1_count = top[0][0], top[0][1]
                top1_pct = pct(top1_count)
            else:
                top1_value, top1_count, top1_pct = None, 0, 0.0
        rows.append((
            name,
            dtype,
            pct(null_row[f"__null_{i}"]),
            distinct_row[f"__distinct_{i}"],
            top1_value,
            top1_count,
            top1_pct,
        ))

    return spark.createDataFrame(rows, schema=schema)

# Understanding the Data

## 1. Submissions DF

In [0]:
sub_row_total = sub.count()
print("sub DF row count = ", sub_row_total)

##### INSPECT SCHEMA & DATA TYPES

In [0]:
schema_banner = "---- Schema of sub ----"
print(schema_banner)
sub.printSchema()

##### Sub_Info Summary table

In [0]:
# Per-column profile of SUB: type, null %, cardinality and dominant value.
sub_info = generate_df_info(df=sub)
sub_info.display(truncate=False)

sub_file:
"adsh", "cik", "name", "sic", "countryba", "stprba", "cityba", "zipba", "baph",
"fye", "form", "period", "fy", "fp", "filed", "delay_days", "accepted", "instance"

1. **Column Selection for Gold**

   * If a column’s `distinct_count` is extremely high and we only need a few metrics, we might drop that column from the Gold table (too many distinct values can bloat our fact tables).
   * If a column’s `null_percentage` is above 50%, it may be too sparse to include as a Gold attribute; we should either fill it or drop it.

2. **Data‐Type Corrections**

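   * Use the `data_types` column to spot columns stored with the wrong type (e.g. dates held as integers or strings that should become `date`, flags that should become `boolean`). Note that `generate_df_info` does not currently produce a `note` or `is_good` column. Flagging “Convert to datetime” / “Convert to bool” that way, and setting `is_good = True` once a column is converted, would need to be added.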

3. **Dimension Strategy**

   * Low‐cardinality string columns (distinct\_count < 50) are good candidates for dimension tables.

4. **Automated Alerts**

   * If `top1_value_percentage` is above 90% for a column (i.e., one value dominates 90% of records), we might want to examine whether the column is actually useful or should be dropped.


## 2. Numbers DF

In [0]:
# Row count of the NUM table (label previously said "sub" by copy-paste).
print("num DF row count = ", num.count())

In [0]:
# Schema of the NUM table (banner previously said "sub" by copy-paste).
print("---- Schema of num ----")
num.printSchema()

In [0]:
# Per-column profile of NUM: type, null %, cardinality and dominant value.
num_info = generate_df_info(df=num)
num_info.display(truncate=False)

num_file:
adsh
tag
version
ddate
qtrs
uom
value


In [0]:
# All NUM facts reported by one example filing.
single_filing = num.filter(col("adsh") == "0001161697-24-000084")
single_filing.display()

### One to many connection of SUB and NUM
Each SUB row is like a “folder” for one filing, and that folder can contain **hundreds or even thousands** of individual numbers. For example, a single 10-K report might tag and report:

* **Balance sheet line items**: cash, receivables, inventory, property, debt…
* **Income statement items**: revenue, cost of goods sold, operating expenses, net income…
* **Cash flow items**: operating cash flow, investing cash flow, financing cash flow…
* **Footnotes and segments**: interest expense by segment, tax footnotes, etc.

Each of those tagged numbers becomes **one row** in NUM.

#### **Quarter**

In [0]:
qtrs_counts = num.groupBy("qtrs").count()
qtrs_counts.sort(col("count").desc()).show()

- Values are mostly recorded as point-in-time (0) or full-year (4).
- In very rare cases we see values other than 0, 1, 2, 3 and 4, which might be an anomaly.

#### Units of measurement

In [0]:
# Evaluated once on the driver; used as the denominator of the coverage percentage.
total_num_rows = num.count()
uom_counts = (
    num.groupBy("uom")
    .agg(
        count("*").alias("count"),
        (count("*") * 100 / total_num_rows).alias("percentage rows covered"),
    )
    .orderBy(col("count").desc())
)
uom_counts.show()

Description for each of the units 

* **USD**: United States dollars (the standard currency unit for most financial facts).
* **shares**: Number of equity shares outstanding or issued.
* **pure**: A unitless, “pure” numeric value (no currency or other unit).
* **CAD**: Canadian dollars.
* **EUR**: Euros (the common currency of the Eurozone).
* **GBP**: British pounds sterling.
* **CNY**: Chinese yuan renminbi.
* **BRL**: Brazilian reais.
* **CHF**: Swiss francs.
* **CLP**: Chilean pesos.
* **COP**: Colombian pesos.
* **Rate**: A proportion or percentage rate (e.g., interest or growth rate).
* **AUD**: Australian dollars.
* **JPY**: Japanese yen.
* **DKK**: Danish kroner.
* **SEK**: Swedish kronor.
* **PHP**: Philippine pesos.
* **HKD**: Hong Kong dollars.
* **ILS**: Israeli new shekels.
* **SGD**: Singapore dollars.


In [0]:
summary_table = num.summary()
value_stats = summary_table.select("summary", "value")
value_stats.display()

We see both positive and negative values. The negative sign is used to record amounts in the opposite direction (e.g. credits, losses or decreases) for the respective company.

### Inferences on NUM

1. **Tag cardinality vs. “hot” tags**

   * **`tag`** has **65 705** distinct values—an extremely long tail of seldom-used tags. The single most common tag, **`StockholdersEquity`**, appears \~102 770 times (3 %).
   * **Inference**: You’ll want to **filter** in Silver to a curated list of “Gold” tags (e.g., top 20 by frequency) rather than ingest every one of the \~65 000 tags into your main fact table.

2. **Quarter vs. date granularity**

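   * **`qtrs`** shows 67 distinct values, more than the expected {0,1,2,3,4}. `qtrs` is the number of quarters a value spans (0 = point in time), so values above 4 represent multi-year durations. They are rare and fall outside the standard quarterly/annual periods we report on.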
   * **Inference**: We have to implement a cleaning transformation step in Silver to enforce that `qtrs ∈ {0,1,2,3,4}` and that `ddate` falls on end-of-quarter.

3. **Unit of measure concentration**

   * **`uom`** has 89 distinct units, but **84 %** of rows are **`USD`**; the rest include “shares,” “pure,” and dozens of other currencies.
   * **Inference**:

     * **Filter** out non-USD rows if our client only cares about dollar amounts.

4. **Sparse dimension columns**

   * **`segments`** is \~45 % null.
   * **`coreg`** is \~99 % null (only a handful of co-registrant cases).
   * **`footnote`** is \~99.8 % null.
   * **Inference**:

     * Treat `coreg` and `footnote` as **“detail”** fields that belong in a secondary table (e.g., only join when the user asks for footnotes).

5. **Value coverage and cleaning**

   * **`value`** (cast to `double`) has a small null rate (3.6 % of rows failed to convert).
   * **Inference**:

     * Add a Silver step to **clean** the `value` field (strip non-numeric characters, catch “NM” or “—” cases) before final cast.
     * We may consider dropping rows where `value` remains null after cleaning.

## A Look on NULLS in num

### Nulls vs Tags

In [0]:
# Rows of NUM whose `value` is null
null_vals = num.filter(col("value").isNull())

# Total number of null-value rows
total_nulls = null_vals.count()
print(f"Total null-value rows: {total_nulls}")

# Null-value rows per tag
null_by_tag = (
    null_vals.groupBy("tag")
             .agg(count("*").alias("null_count"))
)

# All rows per tag, so each tag's null count can be expressed as a share of its rows
total_by_tag = num.groupBy("tag").agg(count("*").alias("total_count"))

tag_null_summary = (
    null_by_tag
      .join(total_by_tag, on="tag", how="inner")
      .withColumn("null_pct_of_tag", round(col("null_count") / col("total_count") * 100, 2))
      .orderBy(col("null_count").desc())
)

print("=================== Top 10 tags by # of null values ==================")
# Cap the output explicitly with limit(10); do not rely on a positional argument
# to display() to restrict the number of rows shown.
tag_null_summary.select("tag", "null_count", "total_count", "null_pct_of_tag") \
                .limit(10) \
                .display(truncate=False)

#### Taking tags which have more than 50 percent nulls and ordering them by null count (desc)

In [0]:
# Tags where at least half of the rows have a null value, biggest offenders first.
high_null_tags = (
    tag_null_summary
    .select("tag", "null_count", "total_count", "null_pct_of_tag")
    .where(col("null_pct_of_tag") >= 50)
    .orderBy(col("null_count").desc())
)
high_null_tags.display()

### Nulls vs Units of Measure

In [0]:
# Null-value rows per unit of measure
null_by_uom = (
    null_vals.groupBy("uom")
             .agg(count("*").alias("null_count"))
)
# All rows per unit of measure, used as the denominator
total_by_uom = num.groupBy("uom").agg(count("*").alias("total_count"))

uom_null_summary = (
    null_by_uom
      .join(total_by_uom, on="uom", how="inner")
      .withColumn("null_pct_of_uom",
                  round(col("null_count") / col("total_count") * 100, 2))
      .orderBy(col("null_count").desc())
)

print("================ Top 10 UOMs by # of null values =============")
# Cap the output explicitly with limit(10); do not rely on a positional argument
# to display() to restrict the number of rows shown.
uom_null_summary.select("uom", "null_count", "total_count", "null_pct_of_uom") \
                .limit(10) \
                .display(truncate=False)


### Nulls vs Quarters

In [0]:

# Null-value rows vs. all rows for each qtrs value.
null_by_qtrs = null_vals.groupBy("qtrs").agg(count("*").alias("null_count"))
total_by_qtrs = num.groupBy("qtrs").agg(count("*").alias("total_count"))

null_pct_expr = round(col("null_count") / col("total_count") * 100, 2)
qtrs_null_summary = (
    null_by_qtrs.join(total_by_qtrs, on="qtrs", how="inner")
    .withColumn("null_pct_of_qtrs", null_pct_expr)
    .orderBy(col("null_count").desc())
)

print("================ Null-value breakdown by qtrs =================")
qtrs_null_summary.select(
    "qtrs", "null_count", "total_count", "null_pct_of_qtrs"
).display(truncate=False)


## 3. Tag DF

In [0]:
# Row count of the TAG table (label previously said "sub" by copy-paste).
print("tag DF row count = ", tag.count())

In [0]:
# Schema of the TAG table (banner previously said "sub" by copy-paste).
print("---- Schema of tag ----")
tag.printSchema()

In [0]:
# Per-column profile of TAG: type, null %, cardinality and dominant value.
tag_info = generate_df_info(df=tag)
tag_info.display(truncate=False)

tag_file:
tag_id
tag
version
datatype
iord
tlabel
doc

### Custom tag proportion in TAG table VS. Custom tags usage in NUM table

In [0]:
# Share of TAG definitions per `custom` flag value.
total_tags = tag.count()
tag.groupby("custom").agg(
    count("*").alias("count"),
    (count("*") * 100 / total_tags).alias("percentage"),
).display()

- There are 79993 non-custom tags in the TAG table.

In [0]:
# Attach each NUM fact's tag metadata (custom / abstract flags) on (tag, version).
tag_flags = tag.select("tag", "version", "custom", "abstract")
num_tag_joined = num.join(tag_flags, on=["tag", "version"], how="left")

joined_rows = num_tag_joined.count()
custom_usage = num_tag_joined.groupby("custom").agg(
    count("*").alias("count"),
    (count("*") * 100 / joined_rows).alias("percentage"),
)
custom_usage.display()

Although custom tags make up 92.4 percent of the tag definitions, they account for only 8.8 percent of the rows in the NUM table.

## 4. Presentation DF

In [0]:
# Row count of the PRE table (label previously said "sub" by copy-paste).
print("pre DF row count = ", pre.count())

In [0]:
# Schema of the PRE table (banner previously said "sub" by copy-paste).
print("---- Schema of pre ----")
pre.printSchema()

In [0]:
# Per-column profile of PRE: type, null %, cardinality and dominant value.
pre_info = generate_df_info(df=pre)
pre_info.display(truncate=False)

pre_file:
adsh
report
line
stmt
tag
version
plabel


In [0]:
# Unique filing form types present in SUB.
sub.select("form").dropDuplicates().show()

In [0]:
datatype_counts = tag.groupBy("datatype").count()
datatype_counts.display()

In [0]:
# Note: `!=` against a null datatype evaluates to null, so tags with a null
# datatype are excluded as well.
non_monetary_tags = tag.filter(col("datatype") != "monetary")
non_monetary_tags.display()

## Sub nulls

#### sic and countryba
- Per-CIK mode: filling nulls with each company’s most common value for that field.

- Global fallback mode: for any CIK that has no non-null values, we fill with the overall most frequent value in the entire table.

In [0]:
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, when, row_number

def impute_sub_categorical_modes(df_sub: DataFrame) -> DataFrame:
    """
    Impute nulls in the 'sic' and 'countryba' columns of a SUB DataFrame by:
      1. Per-CIK mode (most frequent value) fill
      2. Global mode fallback for any remaining nulls

    When two values are equally frequent, the smallest value wins, so repeated
    runs produce the same result. The global mode is computed after step 1. A
    column that has no non-null values at all is left unchanged (there is
    nothing to fall back to).

    Parameters:
        df_sub (DataFrame): Cleaned SUB DataFrame containing at least 'cik', 'sic', and 'countryba'.

    Returns:
        DataFrame: A new DataFrame with no nulls in 'sic' or 'countryba'
        (unless a column was entirely null on input).
    """
    # Local import: these helpers are not part of this cell's import list.
    from pyspark.sql.functions import coalesce, lit

    df = df_sub
    to_impute = ["sic", "countryba"]

    # 1) Per-CIK mode fill
    for c in to_impute:
        mode_col = f"{c}_cik_mode"
        mode_per_cik = (
            df
              .filter(col(c).isNotNull())
              .groupBy("cik", c)
              .count()
              .withColumn(
                  "rn",
                  row_number().over(
                      # Secondary sort on the value makes ties deterministic.
                      Window.partitionBy("cik")
                            .orderBy(col("count").desc(), col(c))
                  )
              )
              .filter(col("rn") == 1)
              .select("cik", col(c).alias(mode_col))
        )
        df = (
            df
              .join(mode_per_cik, on="cik", how="left")
              .withColumn(c, coalesce(col(c), col(mode_col)))
              .drop(mode_col)
        )

    # 2) Global mode fallback for any remaining nulls (e.g. CIKs with no observed value)
    for c in to_impute:
        top = (
            df
              .filter(col(c).isNotNull())
              .groupBy(c)
              .count()
              .orderBy(col("count").desc(), col(c))
              .limit(1)
              .collect()
        )
        if not top:
            # Column is entirely null: no global mode exists, so leave it as-is
            # instead of failing with IndexError.
            continue
        df = df.withColumn(c, coalesce(col(c), lit(top[0][c])))

    return df

#### fye, fy, period, fp

##### Fallback defaults (used when no per-CIK value is available)

- fye → 1231 (Dec 31)

- fy → the filing’s calendar year (extracted from the filed date)

- period → fiscal-year end date in YYYYMMDD form, i.e. fy * 10000 + fye

- fp → "FY"

In [0]:
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import (
    col, when, row_number,
    percentile_approx, to_date,
    year, lit
)

def impute_sub_date_fields(df_sub: DataFrame) -> DataFrame:
    """
    Impute nulls in the SUB DataFrame for the following columns:
      - fy, fp       : categorical (mode per CIK, then fallback)
      - period, fye  : numeric dates (median per CIK, then fallback)

    Fallback defaults, applied in this order:
      - fye    -> 1231
      - fy     -> year of the `filed` date
      - period -> fy * 10000 + fye
      - fp     -> "FY"

    Mode ties are broken by the smallest value, so repeated runs give the same result.

    Parameters:
        df_sub (DataFrame): Cleaned SUB DataFrame containing at least
                            'cik', 'filed', 'fy', 'fp', 'period', and 'fye'.

    Returns:
        DataFrame: New DataFrame with no nulls in 'fy', 'fp', 'period', or 'fye'.
    """
    # Local import: not part of this cell's import list.
    from pyspark.sql.functions import coalesce

    # Parse 'filed' (int YYYYMMDD) into a date so its year can seed the fy default.
    df = df_sub.withColumn(
        "filed_dt",
        to_date(col("filed").cast("string"), "yyyyMMdd")
    )

    # 1) Impute categorical 'fy' and 'fp' by mode per CIK
    for c in ["fy", "fp"]:
        mode_col = f"{c}_mode"
        mode_df = (
            df.filter(col(c).isNotNull())
              .groupBy("cik", c)
              .count()
              .withColumn(
                  "rn",
                  row_number().over(
                      # Secondary sort on the value makes ties deterministic.
                      Window.partitionBy("cik")
                            .orderBy(col("count").desc(), col(c))
                  )
              )
              .filter(col("rn") == 1)
              .select("cik", col(c).alias(mode_col))
        )
        df = (
            df.join(mode_df, on="cik", how="left")
              .withColumn(c, coalesce(col(c), col(mode_col)))
              .drop(mode_col)
        )

    # 2) Impute numeric 'period' and 'fye' by (approximate) median per CIK
    for c in ["period", "fye"]:
        med_col = f"{c}_med"
        med_df = (
            df.filter(col(c).isNotNull())
              .groupBy("cik")
              .agg(percentile_approx(col(c), 0.5).alias(med_col))
        )
        df = (
            df.join(med_df, on="cik", how="left")
              .withColumn(c, coalesce(col(c), col(med_col)))
              .drop(med_col)
        )

    # 3) Final fallback defaults. Order matters: `period` is derived from the
    #    already-filled `fy` and `fye`.
    return (
        df
          .withColumn("fye", coalesce(col("fye"), lit(1231)))
          .withColumn("fy", coalesce(col("fy"), year(col("filed_dt"))))
          .withColumn("period", coalesce(col("period"), col("fy") * 10000 + col("fye")))
          .withColumn("fp", coalesce(col("fp"), lit("FY")))
          .drop("filed_dt")
    )

We pick **mode** for the categorical fields (`fy` and `fp`) and **median** for the numeric‐date fields (`period` and `fye`) because of the nature of those columns:

1. **fy (Fiscal Year) and fp (Fiscal Period)**

   * Those are **labels**, not numbers you’d average.
   * You want the single year or period code a company most often uses, so filling with the **most frequent** value (the mode) makes sense.
   * Example: if Acme Corp. has five filings and four of them say `fp = "Q2"`, then any missing `fp` should almost certainly be `"Q2"`, not `"Q1"` or `"FY"`.

2. **period (Reporting Date) and fye (Fiscal Year-End Month/Day)**

   * These are **numeric dates** stored as integers. Averaging them can produce nonsensical halfway dates (e.g. a “.5” day), and means are easily skewed by one outlier.
   * The **median** finds the exact middle of a company’s historical dates, so if one quarter was filed very late or early, it won’t drag your fill date to that extreme.
   * Example: if your company usually files around June 30 but had one odd March 31 filing, the median will stay at June 30 rather than shift toward March.

In short:

* **Mode** → best for categorical or discrete codes you want “most common.”
* **Median** → best for continuous or ordered data (like dates) where you need a robust center point.


## Num file 

In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, when, count, percentile_approx

def impute_num_value_medians(df_num: DataFrame, df_sub: DataFrame) -> DataFrame:
    """
    Fill null NUM `value`s with the median of the same company/tag/unit group.

    Steps:
      1. Left-join NUM to SUB on `adsh` to attach `cik` (every NUM row is kept).
      2. For each (cik, tag, uom) group, compute over non-null values:
           - grp_count  : how many non-null values exist
           - grp_median : approximate median of those values
      3. Mark rows whose group has no non-null value with `is_incomplete = True`.
      4. Replace a null `value` with its group's median when the group has one.

    Parameters:
        df_num (DataFrame): Cleaned NUM DataFrame containing `adsh`, `tag`, `uom`, and `value`.
        df_sub (DataFrame): Cleaned SUB DataFrame containing `adsh` and `cik`.

    Returns:
        DataFrame: NUM rows plus a `cik` column and a boolean `is_incomplete`
        column, with null `value`s imputed wherever a group median exists.
    """
    keys = ["cik", "tag", "uom"]

    enriched = df_num.join(df_sub.select("adsh", "cik"), "adsh", "left")

    observed = enriched.where(col("value").isNotNull())
    stats = observed.groupBy(*keys).agg(
        count(col("value")).alias("grp_count"),
        percentile_approx(col("value"), 0.5).alias("grp_median"),
    )

    # A row's group "has stats" when at least one non-null value exists for it.
    has_stats = col("grp_count").isNotNull()
    needs_fill = col("value").isNull() & has_stats

    return (
        enriched.join(stats, keys, "left")
        .withColumn("is_incomplete", when(has_stats, False).otherwise(True))
        .withColumn("value", when(needs_fill, col("grp_median")).otherwise(col("value")))
        .drop("grp_count", "grp_median")
    )


# PRE

#### stmt 
1. Compute each tag’s most common stmt (its mode) across all rows.

2. Join that back onto PRE and fill any null stmt with the tag’s mode.

3. Fallback to "UN" (Unclassifiable) for any tag that never appeared with a non-null stmt.

In [0]:
from pyspark.sql import DataFrame, Window
from pyspark.sql.functions import col, when, row_number

def impute_pre_stmt(df_pre: DataFrame) -> DataFrame:
    """
    Impute nulls in the PRE DataFrame’s `stmt` column by:
      1. Computing each tag’s most frequent statement type (mode) across filings.
      2. Filling null `stmt` values with that per-tag mode.
      3. As a final fallback, assigning 'UN' (Unclassifiable) to any remaining nulls.

    When two statement types are equally frequent for a tag, the smallest value
    is chosen, so repeated runs produce the same result.

    Parameters:
        df_pre (DataFrame): Cleaned PRE DataFrame containing at least 'tag' and 'stmt'.

    Returns:
        DataFrame: A new DataFrame with no nulls in `stmt`.
    """
    # Local import: not part of this cell's import list.
    from pyspark.sql.functions import coalesce, lit

    # 1) Per-tag mode lookup for stmt (one row per tag)
    tag_mode_stmt = (
        df_pre
          .filter(col("stmt").isNotNull())
          .groupBy("tag", "stmt")
          .count()
          .withColumn(
              "rn",
              row_number().over(
                  # Secondary sort on the value makes ties deterministic.
                  Window.partitionBy("tag")
                        .orderBy(col("count").desc(), col("stmt"))
              )
          )
          .filter(col("rn") == 1)
          .select("tag", col("stmt").alias("stmt_mode"))
    )

    # 2) + 3) Keep the observed stmt, else the tag's mode, else 'UN'
    return (
        df_pre
          .join(tag_mode_stmt, on="tag", how="left")
          .withColumn("stmt", coalesce(col("stmt"), col("stmt_mode"), lit("UN")))
          .drop("stmt_mode")
    )


In [0]:
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, when, regexp_replace, initcap

def impute_plabel_from_tag(df_pre: DataFrame) -> DataFrame:
    """
    Fill null 'plabel' values in a PRE DataFrame with a label derived from 'tag'.

    The tag is split into words at CamelCase boundaries and its first letter is
    upper-cased, for example:
      - "NetIncomeLoss"          -> "Net Income Loss"
      - "OperatingLeaseROUAsset" -> "Operating Lease ROU Asset"
      - "EBITDA"                 -> "EBITDA"
    Unlike `initcap`, the remaining letters keep their original case, so
    acronyms in tags are not lowered ("EBITDA" would otherwise become "Ebitda").

    Parameters:
        df_pre (DataFrame): Cleaned PRE DataFrame containing 'tag' and 'plabel'.

    Returns:
        DataFrame: A new DataFrame where every null 'plabel' with a non-null
        'tag' has been filled. Existing labels are unchanged; rows whose 'tag'
        is also null keep a null 'plabel'.
    """
    # Local import: not part of this cell's import list.
    from pyspark.sql.functions import concat, length, lit, upper

    # Lowercase letter or digit followed by an uppercase letter: "NetIncome" -> "Net Income"
    spaced = regexp_replace(col("tag"), "([a-z0-9])([A-Z])", "$1 $2")
    # Acronym followed by a capitalised word: "ROUAsset" -> "ROU Asset"
    spaced = regexp_replace(spaced, "([A-Z]+)([A-Z][a-z])", "$1 $2")
    # Upper-case only the first character; leave the rest untouched.
    derived_label = concat(upper(spaced.substr(1, 1)), spaced.substr(lit(2), length(spaced)))

    return df_pre.withColumn(
        "plabel",
        when(col("plabel").isNull(), derived_label).otherwise(col("plabel"))
    )
