In [0]:
def get_cik_lookup_df():
    """Build a ticker/CIK lookup DataFrame from SEC's ``company_tickers.json``.

    Uses the module-level ``header`` (defined elsewhere in the notebook) as the
    request headers.

    Returns:
        DataFrame with string columns ``cik_str`` (left-padded with "0" to 10
        characters), ``ticker`` and ``title``.

    Raises:
        requests.HTTPError: if the endpoint answers with a non-2xx status, so
            the failure is reported at the request rather than later while
            parsing an error page as JSON.
    """
    cik_map = requests.get(
        "https://www.sec.gov/files/company_tickers.json", headers=header
    )
    # fail loudly on 4xx/5xx instead of trying to decode an error body
    cik_map.raise_for_status()
    # presumably the payload is a dict whose values are {cik_str, ticker, title}
    # records -- only the values are loaded, matched to the schema by field name
    cik_map_df = spark.createDataFrame(
        cik_map.json().values(),
        schema=StructType(
            [
                StructField("cik_str", StringType(), True),
                StructField("ticker", StringType(), True),
                StructField("title", StringType(), True),
            ]
        ),
    )
    # pad to the 10-digit CIK form used in the company-facts URL
    cik_map_df = cik_map_df.withColumn("cik_str", F.lpad(F.col("cik_str"), 10, "0"))
    return cik_map_df


def get_company_facts(cik):
    """Fetch and decode the SEC XBRL company-facts JSON for one registrant.

    Args:
        cik: identifier spliced directly after ``CIK`` in the URL; presumably
            the zero-padded 10-digit CIK -- TODO confirm with callers.

    Returns:
        The response body decoded with ``json.loads``.
    """
    facts_url = f"https://data.sec.gov/api/xbrl/companyfacts/CIK{cik}.json"
    response = requests.get(facts_url, headers=header)
    # one-second pause after every call; presumably throttling for SEC's
    # request-rate limits -- NOTE(review): confirm the intended rate
    time.sleep(1)
    return json.loads(response.text)


def create_stmnts_combined_df(comp_facts_dict, history_years=3):
    """Collect recent quarterly and annual USD revenue facts for one company.

    Scans a fixed list of us-gaap revenue concepts and unions the facts that
    (a) carry ``start``, ``end``, ``fp``, ``val`` and ``filed`` fields,
    (b) start within the last ``history_years * 366`` days, and
    (c) span 2.7-4 months or 11-13 months.

    Args:
        comp_facts_dict: parsed company-facts JSON; facts are read from
            ``comp_facts_dict["facts"]["us-gaap"][concept]["units"]["USD"]``.
        history_years: look-back window in years (each counted as 366 days).

    Returns:
        DataFrame with columns start, end, fp, time_period_months, val, filed,
        orig_stmnt (start/end/filed as dates), or ``None`` when the payload
        has no ``facts``/``us-gaap`` section.
    """
    facts = comp_facts_dict.get("facts", {})
    if "us-gaap" not in facts:
        return None
    us_gaap = facts["us-gaap"]

    stmnts_to_scan = [
        "RevenueFromContractWithCustomerExcludingAssessedTax",
        "RevenueFromContractWithCustomerIncludingAssessedTax",
        # 'RevenueFromContractsWithCustomers',
        "Revenues",
    ]
    required_cols = {"start", "end", "fp", "val", "filed"}
    output_cols = [
        "start",
        "end",
        "fp",
        "time_period_months",
        "val",
        "filed",
        "orig_stmnt",
    ]

    # Seed frame: date columns are DateType to match the casts applied below.
    # A StringType seed left the unioned type of start/end/filed up to Spark's
    # union type coercion instead of it being a date.
    stmnts_combined_df_op = spark.createDataFrame(
        [],
        schema=StructType(
            [
                StructField("end", DateType(), True),
                StructField("fp", StringType(), True),
                StructField("start", DateType(), True),
                StructField("filed", DateType(), True),
                StructField("time_period_months", FloatType(), True),
                StructField("val", LongType(), True),
                StructField("orig_stmnt", StringType(), True),
            ]
        ),
    )
    for stmnt in stmnts_to_scan:
        try:
            usd_facts = us_gaap[stmnt]["units"]["USD"]
        except KeyError:
            # concept not reported by this company, or not reported in USD
            continue
        if not usd_facts:
            # nothing to load; createDataFrame cannot infer a schema from []
            continue
        stmnt_df = spark.createDataFrame(usd_facts)
        # must-have columns -- skip concepts whose facts lack any of them
        if not required_cols.issubset(stmnt_df.columns):
            continue
        stmnt_df = (
            stmnt_df.withColumn("start", F.col("start").cast(DateType()))
            .withColumn("end", F.col("end").cast(DateType()))
            .withColumn("filed", F.col("filed").cast(DateType()))
            .withColumn("val", F.col("val").cast(LongType()))
            .withColumn(
                "time_period_months",
                F.round(F.months_between(F.col("end"), F.col("start")), 2),
            )
            .withColumn("orig_stmnt", F.lit(stmnt))
            # keep only facts starting within the last history_years --IMPORTANT
            .filter(f" start > date_sub(current_date(), {history_years*366}) ")
            # keep only quarter-long or year-long durations
            .filter(
                "  (time_period_months between 2.7 and 4) or (time_period_months between 11 and 13) "
            )
        )
        stmnts_combined_df_op = stmnts_combined_df_op.unionByName(
            stmnt_df.select(*output_cols)
        )
    return stmnts_combined_df_op.select(*output_cols)


def remove_inconsistencies(stmnts_combined_df):
    """Deduplicate and sanity-filter the combined revenue facts.

    Steps, in order:
      1. keep one row per (start, end, fp): the one with the highest ``val``;
      2. drop rows with null ``fp``, "FY" rows shorter than 11 months and
         quarterly (``Q<digit>``) rows longer than 11 months;
      3. aggregate per (start, end, time_period_months, fp) into a row
         ``count``, max ``val`` and the list of ``filed`` dates;
      4. keep only the highest-``count`` rows per (start, end), per
         (start month, end month, fp) and per (start week, end week, fp);
      5. explode the ``filed`` list back into one row per filed date.

    Args:
        stmnts_combined_df: DataFrame with at least start, end, fp,
            time_period_months, val and filed columns.

    Returns:
        DataFrame with columns start, end, time_period_months, fp, val, filed.

    NOTE(review): after step 1 there is exactly one row per (start, end, fp),
    and step 3 groups by those same keys, so every ``count`` is 1. All ranks in
    step 4 then tie at 1 and those filters never remove a row -- confirm
    whether counts were meant to be taken before the dedup in step 1.
    NOTE(review): collect_list skips nulls, so a row whose ``filed`` is null
    gets an empty list and is dropped by the explode in step 5 -- confirm.
    """
    stmnts_combined_df_cleansed = (
        # step 1: highest val per (start, end, fp); max_rev is constant inside
        # each (start, end, fp) partition, so row_number picks one row per key
        stmnts_combined_df.withColumn(
            "max_rev", F.max("val").over(Window.partitionBy("start", "end", "fp"))
        )
        .withColumn(
            "rank_rev",
            F.row_number().over(
                Window.partitionBy("start", "end", "fp", "max_rev").orderBy(
                    F.desc("val")
                )
            ),
        )
        .filter(F.col("rank_rev") == 1)
        .drop("max_rev", "rank_rev")
        # step 2: filtering out ambiguous cases (fp missing, or duration that
        # contradicts the fp label)
        .filter(
            r""" fp is not null and not(time_period_months < 11 and fp == "FY") and not(time_period_months > 11 and fp rlike "^Q\\d{1}$") """
        )
        # step 3: per-period row count, max value and all filed dates
        .groupBy("start", "end", "time_period_months", "fp")
        .agg(
            F.count(F.lit(1)).alias("count"),
            F.max("val").alias("val"),
            F.collect_list(F.col("filed")).alias("list_filed"),
        )
        # step 4a: filtering out rows with less count when start and end are same but fp is different
        .withColumn(
            "rn",
            F.rank().over(Window.partitionBy("start", "end").orderBy(F.desc("count"))),
        )
        .filter(F.col("rn") == 1)
        # step 4b: filtering out rows with less count when fp is same and the respective start and end dates are very close in terms of days(inconsistent data)
        # -- "close" here means same calendar month for both start and end
        .withColumn("start_month", F.trunc(F.col("start"), "month"))
        .withColumn("end_month", F.trunc(F.col("end"), "month"))
        .withColumn(
            "rnk_m",
            F.rank().over(
                Window.partitionBy("start_month", "end_month", "fp").orderBy(
                    F.col("count").desc()
                )
            ),
        )
        .filter(F.col("rnk_m") == 1)
        # step 4c: same idea at week granularity
        .withColumn("week_end", F.date_trunc("week", F.col("end")))
        .withColumn("week_start", F.date_trunc("week", F.col("start")))
        .withColumn(
            "rnk_w",
            F.rank().over(
                Window.partitionBy("week_end", "week_start", "fp").orderBy(
                    F.col("count").desc()
                )
            ),
        )
        .filter(F.col("rnk_w") == 1)
        # step 5: one output row per collected filed date
        .withColumn("filed", F.explode("list_filed"))
        # remove all helper columns
        .drop(
            "rn",
            "rnk_m",
            "rnk_w",
            "week_end",
            "week_start",
            "start_month",
            "end_month",
            "count",
            "list_filed",
        )
    )
    return stmnts_combined_df_cleansed


def save_latest_q1_q3(stmnts_combined_df_cleansed):
    """Return the rows that end after the latest fiscal-year ("FY") end date.

    These are the quarters of the not-yet-closed fiscal year, kept aside so
    they can be unioned back later. If there is no FY row the max is null,
    the comparison is null for every row, and the result is empty.
    """
    fy_only = stmnts_combined_df_cleansed.filter(F.col("fp") == "FY")
    latest_fy_end = fy_only.agg(F.max("end")).head()[0]
    return stmnts_combined_df_cleansed.filter(F.col("end") > latest_fy_end)


def assign_relation_between_q_and_fy(stmnts_combined_df_cleansed):
    """Pair every row with each fiscal year whose date range contains it.

    Each "FY" row defines a range [fy_start, fy_end]. A row is kept once per
    range it lies inside (start >= fy_start and end <= fy_end); rows inside
    no range are dropped. ``q1_q3_fy_rel`` ("<fy_start> <fy_end>") identifies
    the fiscal year the row was paired with.
    """
    fiscal_years = stmnts_combined_df_cleansed.filter(F.col("fp") == "FY").select(
        F.col("start").alias("fy_start"), F.col("end").alias("fy_end")
    )
    inside_fiscal_year = (F.col("start") >= F.col("fy_start")) & (
        F.col("end") <= F.col("fy_end")
    )
    paired = stmnts_combined_df_cleansed.crossJoin(fiscal_years).filter(
        inside_fiscal_year
    )
    return paired.withColumn(
        "q1_q3_fy_rel", F.concat_ws(" ", F.col("fy_start"), F.col("fy_end"))
    )


def get_q4_rows(stmnts_combined_df_with_rel):
    """Derive one synthetic Q4 row per fiscal year.

    Within each fiscal-year group (``q1_q3_fy_rel``):
      * val    = FY value - (sum of all values in the group - FY value),
                 i.e. FY minus the quarters that fall inside the year;
      * start  = day after the second-latest ``end`` (rank 2 by end desc);
      * end    = latest ``end`` (rank 1);
      * filed  = max ``filed`` among the rank-1 rows.
    Groups whose derived Q4 span is not 2.5-4 months are dropped (for example
    when Q3 is missing).

    Every value is computed inside its own ``q1_q3_fy_rel`` window, so the
    result does not depend on physical row order. An unpartitioned, unordered
    running max would pair one year's dates with another year's revenue.

    Args:
        stmnts_combined_df_with_rel: rows with start, end, fp, val, filed and
            q1_q3_fy_rel columns.

    Returns:
        Distinct rows with columns start, end, fp (always "Q4"), val, filed.
    """
    by_fy = Window.partitionBy("q1_q3_fy_rel")
    by_fy_latest_first = by_fy.orderBy(F.col("end").desc())
    is_fy = F.col("fp") == "FY"
    rank_col = F.col("rnk_q1_q3_fy_rel")

    q4_rows_df = (
        stmnts_combined_df_with_rel
        # total of every row (FY + quarters) in the fiscal-year group
        .withColumn("q1_q3_fy_rev", F.sum("val").over(by_fy))
        # on the FY row: FY value minus the quarters' sum
        .withColumn(
            "q4_rev",
            F.when(is_fy, F.col("val") - (F.col("q1_q3_fy_rev") - F.col("val"))),
        )
        # broadcast the FY row's Q4 value to the whole group
        .withColumn("q4_rev", F.max("q4_rev").over(by_fy))
        .withColumn("rnk_q1_q3_fy_rel", F.rank().over(by_fy_latest_first))
        # Q4 boundaries, taken per group rather than from a global running max
        .withColumn(
            "q4_start",
            F.max(F.when(rank_col == 2, F.date_add(F.col("end"), 1))).over(by_fy),
        )
        .withColumn(
            "q4_end", F.max(F.when(rank_col == 1, F.col("end"))).over(by_fy)
        )
        .withColumn(
            "q4_filed", F.max(F.when(rank_col == 1, F.col("filed"))).over(by_fy)
        )
        .withColumn(
            "time_period_months_q4",
            F.round(F.months_between(F.col("q4_end"), F.col("q4_start")), 2),
        )
        .filter(" time_period_months_q4 between 2.5 and 4 ")
        .select(
            F.col("q4_start").alias("start"),
            F.col("q4_end").alias("end"),
            F.lit("Q4").alias("fp"),
            F.col("q4_rev").alias("val"),
            F.col("q4_filed").alias("filed"),
        )
        # every row of a group carries the same Q4 values; keep one
        .distinct()
    )
    return q4_rows_df


def get_fy_all_q_union_df(stmnts_combined_df_with_rel, latest_q1_q3, q4_rows_df):
    """Stack FY/Q1-Q3 rows, derived Q4 rows and post-FY quarters into one frame.

    The first and last inputs are projected to start, end, fp, val, filed;
    ``q4_rows_df`` is unioned as-is (by column name).
    """
    union_cols = ["start", "end", "fp", "val", "filed"]
    fy_and_quarters = stmnts_combined_df_with_rel.select(*union_cols)
    open_year_quarters = latest_q1_q3.select(*union_cols)
    with_q4 = fy_and_quarters.unionByName(q4_rows_df)
    return with_q4.unionByName(open_year_quarters)


def period_corrected(fy_all_q_union_df):
    """Attach a fiscal ``Period`` label such as "FY23" or "1Q23" to each row.

    Rows lying inside an FY range (start >= fy_start and end <= fy_end) get
    the two-digit calendar year of that FY's end date. FY rows become
    "FY<yy>"; other rows use their reversed ``fp`` ("Q1" -> "1Q<yy>"). Rows
    ending after the latest FY end (from ``save_latest_q1_q3``) are labelled
    with the year following the latest FY.

    Args:
        fy_all_q_union_df: rows with start, end, fp, val, filed columns.

    Returns:
        DataFrame with columns start, end, fp, Period, val, filed. When there
        is no FY row at all, only the (empty) in-range part is returned.
    """
    fy_rows = fy_all_q_union_df.filter(F.col("fp") == "FY").selectExpr(
        "start as fy_start", "end as fy_end"
    )
    latest_q_rows = save_latest_q1_q3(fy_all_q_union_df)

    # two-digit year suffix of the FY end date, e.g. 2023 -> "23"
    fy_suffix = F.regexp_extract(
        F.year(F.col("fy_end")).cast(StringType()), r"^\d{2}(\d{2})$", 1
    )
    # "FY" kept as-is; quarter labels reversed, e.g. "Q1" -> "1Q"
    period_prefix = F.when(F.col("fp") == "FY", F.col("fp")).otherwise(
        F.reverse(F.col("fp"))
    )

    upto_latest_fy_pc_df = (
        fy_all_q_union_df.crossJoin(fy_rows)
        .filter(" start >= fy_start and end <= fy_end")
        .withColumn("Period", F.concat_ws("", period_prefix, fy_suffix))
        .select("start", "end", "fp", "Period", "val", "filed")
    )

    # explicit check instead of relying on a TypeError from None[0]
    latest_fy_year = fy_rows.select(F.year(F.max("fy_end"))).first()[0]
    if latest_fy_year is None:
        return upto_latest_fy_pc_df

    # zero-padded and wrapped, so 2008 -> "09" and 2099 -> "00"
    next_fy_suffix = f"{(latest_fy_year + 1) % 100:02d}"
    latest_q_pc_rows = latest_q_rows.withColumn(
        "Period",
        F.concat_ws("", F.reverse(F.col("fp")), F.lit(next_fy_suffix)),
    ).select("start", "end", "fp", "Period", "val", "filed")
    return upto_latest_fy_pc_df.unionByName(latest_q_pc_rows)


def get_raw_table_format(fy_all_q_union_df, Ticker):
    """Shape labelled revenue rows into the raw-data-lake report layout.

    Args:
        fy_all_q_union_df: rows with start, end, filed, fp, Period, val.
        Ticker: ticker symbol; emitted as "<Ticker> US Equity".

    Returns:
        DataFrame with columns Ticker, Period, Period_Startdate,
        Period_Enddate, Period_Reportdate, KPI, Value, Source. Only rows whose
        Period contains "Q" or "F" and whose val is positive are kept. Dates
        use the "M/d/y" format; Value uses the "###,###" pattern.
    """
    fy_all_q_union_raw_table_df = (
        # filter first: "val" is not among the output columns, so filtering
        # after the final select relied on Spark re-resolving a dropped column
        fy_all_q_union_df.filter((F.col("Period").rlike("Q|F")) & (F.col("val") > 0))
        .withColumn("Ticker", F.lit(f"{Ticker} US Equity"))
        .withColumn("Period_Startdate", F.date_format(F.col("start"), "M/d/y"))
        .withColumn("Period_Reportdate", F.date_format(F.col("filed"), "M/d/y"))
        .withColumn("Period_Enddate", F.date_format(F.col("end"), "M/d/y"))
        .withColumn("Value", F.format_number(F.col("val"), "###,###"))
        .withColumn("KPI", F.lit("rev_Topline"))
        # FY and derived Q4 rows come from the annual report, the rest from quarterlies
        .withColumn(
            "Source",
            F.when(
                F.col("fp").isin(["FY", "Q4"]),
                F.concat_ws(" ", F.lit("10-K"), F.col("Period")),
            ).otherwise(F.concat_ws(" ", F.lit("10-Q"), F.col("Period"))),
        )
        .select(
            [
                "Ticker",
                "Period",
                "Period_Startdate",
                "Period_Enddate",
                "Period_Reportdate",
                "KPI",
                "Value",
                "Source",
            ]
        )
    )
    return fy_all_q_union_raw_table_df


def get_not_in_raw_data_lake(fy_all_q_union_df, Ticker):
    """Return the rows of ``fy_all_q_union_df`` not yet stored in the raw data lake.

    A row counts as already stored when the raw-data-lake Delta table has a
    row for ticker "<Ticker> US Equity" with the same period start date, the
    same period end date and the same first two characters of ``period``.

    Args:
        fy_all_q_union_df: rows with Period, Period_Startdate and
            Period_Enddate columns.
        Ticker: ticker symbol used to select existing rows.

    Returns:
        The input rows without an existing match (left anti join). If the
        table cannot be resolved (e.g. it does not exist yet), the input is
        returned unchanged.
    """
    from pyspark.sql.utils import AnalysisException

    # The previous query string was not an f-string and embedded
    # f"{Ticker} US Equity" literally. That is invalid SQL, so it always
    # raised, and the bare except returned the input unfiltered. Filtering via
    # the DataFrame API also avoids building SQL from the ticker text.
    try:
        raw_df = spark.sql(
            "select * from delta.`/Volumes/revenue_benchmarking/sec_gov_api_data/raw_data_lake`"
        ).filter(F.col("ticker") == F.lit(f"{Ticker} US Equity"))
    except AnalysisException:
        # table/path not resolvable yet: nothing stored, keep every row
        return fy_all_q_union_df
    not_in_raw_df = (
        # NOTE(review): Spark generally cannot broadcast the left side of a
        # left_anti join, so this hint is likely ignored -- check the plan
        F.broadcast(
            fy_all_q_union_df.withColumn(
                "period_sub_str", F.col("period").substr(0, 2)
            ).alias("a")
        )
        .join(
            raw_df
            # .filter(F.col("kpi") == "rev_Topline")
            # .filter(F.col("value").isNotNull())
            .select("period", "period_startdate", "period_enddate")
            .withColumn("period_sub_str", F.col("period").substr(0, 2))
            .alias("b"),
            on=(F.col("a.Period_Startdate") == F.col("b.period_startdate"))
            & (F.col("a.Period_Enddate") == F.col("b.period_enddate"))
            & (F.col("a.period_sub_str") == F.col("b.period_sub_str")),
            how="left_anti",
        )
        .drop("period_sub_str")
    )
    return not_in_raw_df