In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, col, lower, trim, year, to_date, current_date, datediff, regexp_extract, when
from pyspark.sql.types import StringType

# Step 1: Initialize Spark Session
spark = SparkSession.builder \
    .appName("CompanyDataMart") \
    .getOrCreate()

# Load the joined dataset. SparkSession has no `read_csv` method; CSV files
# are read through the DataFrameReader exposed as `spark.read`.
# header=True takes the column names from the first row, which the later
# transformations rely on ("Entity Status", "Entity Start Date", "industry",
# "url", ...). Without it the columns would be named _c0, _c1, ...
joined_df = spark.read.csv('/joined_df.csv', header=True)
# Transformations 1 & 2: keep only rows whose "Entity Status" is exactly
# "Active", then derive "Start Year" from "Entity Start Date" (parsed with
# the yyyy-MM-dd pattern) for trend analysis.
active_companies_df = (
    joined_df
    .where(col("Entity Status") == "Active")
    .withColumn(
        "Start Year",
        year(to_date(col("Entity Start Date"), "yyyy-MM-dd")),
    )
)

# Transformation 3: number of active companies per state, largest first
state_counts_df = (
    active_companies_df
    .groupBy("Entity State")
    .count()
    .sort(col("count").desc())
)

# Transformation 4: number of active companies per industry, largest first
industry_counts_df = (
    active_companies_df
    .groupBy("industry")
    .count()
    .sort(col("count").desc())
)

# Transformation 5: Add Domain Extension (com, org, etc.) as a New Column
# The URL is lower-cased before matching so upper/mixed-case hosts are
# handled. The captured label may be any length >= 2, so extensions such as
# "info" or "tech" are kept rather than yielding an empty string. An optional
# numeric port, a query string or a fragment may follow the extension, in
# addition to "/" or the end of the URL.
# regexp_extract yields "" when nothing matches and null when the URL is null.
active_companies_df = active_companies_df.withColumn(
    "domain_extension",
    regexp_extract(
        lower(col("url")),
        r"\.([a-z]{2,})(?::\d+)?(?:[/?#]|$)",
        1,
    )
)

# Transformation 6: Company age and the "older than 10 years" subset
# Age is the number of days between today and the parsed "Entity Start Date".
age_in_days = datediff(
    current_date(),
    to_date(col("Entity Start Date"), "yyyy-MM-dd"),
)
active_companies_df = active_companies_df.withColumn("Company Age (days)", age_in_days)

# 3650 days stands in for ten years (leap days are not accounted for).
older_companies_df = active_companies_df.where(col("Company Age (days)") > 3650)

# Transformation 7: Normalize Industry Labels
# Matching runs on a lower-cased copy of the label so mixed-case values such
# as "Software" or "IT Services" are classified too.
# "it" is anchored with word boundaries: as a bare substring it also matched
# unrelated labels such as "hospitality", "utilities" or "credit" and tagged
# them as Technology.
# Rules are evaluated in order and the first match wins; null or unmatched
# labels fall through to "Other".
industry_lc = lower(col("industry"))
active_companies_df = active_companies_df.withColumn(
    "normalized_industry",
    when(industry_lc.rlike(r"tech|software|\bit\b"), "Technology")
    .when(industry_lc.rlike("consult"), "Consulting")
    .when(industry_lc.rlike("finance|bank|fund"), "Finance")
    .otherwise("Other")
)

# Preview the enriched active-company rows (first 20, the default for
# show()) without truncating long cell values.
active_companies_df.show(n=20, truncate=False)

# Step 8: Persist each result as a Delta table in the data mart, replacing
# any existing contents. Tables are written in the order listed below.
# NOTE(review): several columns contain spaces or parentheses (e.g.
# "Start Year", "Company Age (days)"). Delta may reject such names unless
# column mapping is enabled on the table -- confirm against the target
# environment.
for result_df, table_name in (
    (active_companies_df, "data_mart.active_companies"),  # enriched active companies
    (state_counts_df, "data_mart.state_counts"),          # company count per state
    (industry_counts_df, "data_mart.industry_counts"),    # company count per industry
    (older_companies_df, "data_mart.older_companies"),    # age above 3650 days
):
    result_df.write.format("delta").mode("overwrite").saveAsTable(table_name)
