In [None]:
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import avg, col, count, desc
from pyspark.sql import functions as F
from pyspark.sql.types import (IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType)
from pyspark.ml.feature import Imputer
import os

# =========================================
# START SPARK SESSION
# =========================================

# Runtime settings; each one can be overridden through an environment variable.
project_id = os.getenv("PROJECT_ID", "dejadsgl")
bq_dataset = os.getenv("BQ_DATASET", "netflix")
temp_bucket = os.getenv("TEMP_BUCKET", "netflix-group5-temp_gl")
gcs_data_bucket = os.getenv("GCS_DATA_BUCKET", "netflix_data_25")

# Spark configuration. SparkConf setters return the conf itself, so they chain.
# NOTE(review): spark.driver.memory only takes effect if the driver JVM is not
# already running (e.g. it is ignored when getOrCreate() returns an existing
# session, or when launched via spark-submit in client mode) — confirm.
sparkConf = (
    SparkConf()
    .setMaster(os.getenv("SPARK_MASTER", "local[*]"))
    .setAppName("CleanDataset")
    .setAll([
        ("spark.driver.memory", "2g"),
        ("spark.executor.cores", "1"),
        ("spark.driver.cores", "1"),
    ])
)

# Build (or reuse) the session from the configuration above.
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()

# Cloud Storage bucket the BigQuery connector uses to stage temporary data.
spark.conf.set('temporaryGcsBucket', temp_bucket)

# Register the Cloud Storage Hadoop filesystem for the gs:// scheme.
conf = spark.sparkContext._jsc.hadoopConfiguration()
conf.set("fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem")
conf.set("fs.AbstractFileSystem.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS")

print("Spark session started.")

# =========================================
# LOAD TABLES
# =========================================

# Load the raw CSV files from Cloud Storage into `tables`, keyed by a display
# name derived from the file name (e.g. "watch_history.csv" -> "Watch_history").
tables = {}
titles = [
    "movies.csv",
    "users.csv",
    "watch_history.csv",
    "reviews.csv"
]

for title in titles:
    # Google Cloud Storage file path
    gcs_file_path = f"gs://{gcs_data_bucket}/{title}"
    print(f"Importing CSV from: {gcs_file_path}")

    # inferSchema lets Spark type numeric columns as numbers. Without it every
    # CSV column is read as a string, so clean_table() finds no numeric columns
    # and silently skips median imputation and IQR outlier filtering.
    df = (
        spark.read.format("csv")
        .option("header", "true")
        .option("delimiter", ",")
        .option("inferSchema", "true")
        .load(gcs_file_path)
    )

    # The table is read repeatedly by the cleaning steps, so keep it cached.
    df.cache()

    # Display name: capitalized and without the .csv extension.
    table_name = title.replace(".csv", "").capitalize()
    tables[table_name] = df

    print(f"\nLoaded table: {table_name}")
    df.printSchema()

print("DONE: loading tables from CSV.")

# =========================================
# HELPERS
# =========================================

def choose_duplicate_keys(df, candidates=None):
    """
    Pick the columns to pass to ``dropDuplicates`` for ``df``.

    Candidate key sets are tried in order. The first set whose columns all
    exist in ``df.columns`` is returned. The default list puts each more
    specific key set before its subsets. A table with user, movie and
    timestamp columns is therefore deduplicated on all three. Otherwise two
    different movies for the same user at the same timestamp would be
    collapsed into one row.

    Args:
        df: any object with a ``columns`` attribute (e.g. a Spark DataFrame).
        candidates: optional iterable of column-name tuples that replaces the
            default heuristics.

    Returns:
        The chosen key columns as a list, or ``None`` if no candidate matches.
    """
    if candidates is None:
        candidates = [
            ("CustomerID", "InvoiceDate"),            # retail example
            ("user_id", "movie_id", "timestamp"),
            ("user_id", "timestamp"),
            ("userId", "movieId", "timestamp"),
            ("userId", "timestamp"),
        ]
    available = set(df.columns)
    for keys in candidates:
        if available.issuperset(keys):
            return list(keys)
    return None


def remove_outliers_iqr(df, cols, k=1.5, rel_error=0.01):
    """
    Drop rows whose value in any of ``cols`` lies outside the Tukey IQR fences.

    For each column the kept range is ``[Q1 - k*IQR, Q3 + k*IQR]`` (inclusive).
    Q1 and Q3 come from ``DataFrame.approxQuantile``. Columns are processed one
    after another, so a column's quartiles are computed on the rows that
    survived the filters of the columns before it.

    Columns are skipped (left unfiltered) when they have no non-null values,
    when ``approxQuantile`` fails, or when the IQR is zero.

    Args:
        df: Spark DataFrame to filter.
        cols: names of numeric columns to check.
        k: fence multiplier (1.5 is Tukey's conventional value).
        rel_error: relative error passed to ``approxQuantile`` (0 = exact,
            but more expensive).

    Returns:
        The filtered DataFrame. Rows with a null in a checked column are kept,
        because a missing value is not an outlier.
    """
    for c in cols:
        print(f"   - Processing outliers for numeric column '{c}'")

        # One aggregation yields both the total row count and the non-null count.
        total, non_null = df.agg(F.count(F.lit(1)), F.count(F.col(c))).first()
        if non_null == 0:
            print(f"     Skipping '{c}' (no non-null values).")
            continue

        try:
            q1, q3 = df.approxQuantile(c, [0.25, 0.75], rel_error)
        except Exception as e:
            print(f"     Skipping '{c}' (approxQuantile error: {e})")
            continue

        iqr = q3 - q1
        if iqr == 0:
            # Degenerate spread (e.g. a mostly-constant or 0/1 flag column).
            # The fences collapse onto one value and would discard every row
            # that differs from it, so leave the column alone.
            print(f"     Skipping '{c}' (IQR is zero).")
            continue

        lower = q1 - k * iqr
        upper = q3 + k * iqr

        # A comparison with null is null (not true), so a plain range filter
        # would also drop null rows. Keep them explicitly.
        df = df.filter(F.col(c).isNull() | F.col(c).between(lower, upper))
        after = df.count()

        print(f"     Removed {total - after} outliers from '{c}'")
        print(f"     New row count: {after}")

    return df


def _impute_numeric(df, numeric_cols):
    """Replace missing values in each numeric column with the column's median, keeping column names."""
    print(" - Imputing numerical columns:", numeric_cols)

    imputer = Imputer(
        inputCols=numeric_cols,
        outputCols=[c + "_imputed" for c in numeric_cols]
    ).setStrategy("median")

    df = imputer.fit(df).transform(df)

    # Swap each imputed column in under the original name. The column moves to
    # the end of the schema.
    for c in numeric_cols:
        df = df.drop(c).withColumnRenamed(c + "_imputed", c)

    print("   Completed numerical imputation.")
    print("   Row count after numeric imputation:", df.count())
    return df


def _impute_categorical(df, categorical_cols):
    """Fill nulls in each string column with that column's most frequent non-null value."""
    print(" - Imputing categorical columns:", categorical_cols)

    # NOTE(review): mode-filling free-text columns (titles, review text, ids)
    # is questionable. Confirm which string columns should really be filled.
    fill_values = {}
    for c in categorical_cols:
        # Exclude nulls before counting. Otherwise, when null is the most
        # frequent group, the "mode" is None and the column is never filled.
        # Ties are broken by value so the chosen mode is deterministic.
        mode_row = (
            df.where(F.col(c).isNotNull())
              .groupBy(c)
              .count()
              .orderBy(F.desc("count"), F.asc(c))
              .first()
        )
        if mode_row is None:
            continue  # no non-null values, so there is nothing to fill with
        mode_value = mode_row[0]
        fill_values[c] = mode_value
        print(f"   Filled missing values in '{c}' with mode='{mode_value}'")

    # One fillna for all columns keeps the plan short. Each column's mode
    # depends only on that column, so batching does not change the result.
    if fill_values:
        df = df.fillna(fill_values)

    print("   Completed categorical imputation.")
    print("   Row count after categorical imputation:", df.count())
    return df


def _drop_duplicate_rows(df):
    """Deduplicate on the key chosen by choose_duplicate_keys, or return df unchanged."""
    dup_keys = choose_duplicate_keys(df)
    if not dup_keys:
        print(" - No suitable duplicate keys found, skipping deduplication.")
        return df

    print(f" - Using keys for duplicates: {dup_keys}")
    before = df.count()
    df = df.dropDuplicates(dup_keys)
    after = df.count()
    print(f"   Removed {before - after} duplicates")
    print("   Row count after deduplication:", after)
    return df


def clean_table(name, df):
    """
    Run the cleaning pipeline on one table and return the cleaned DataFrame.

    Steps:
      1. Median-impute numeric columns and mode-impute string columns.
      2. Drop duplicates on a key inferred by ``choose_duplicate_keys``.
      3. Drop IQR outliers from the numeric columns (``remove_outliers_iqr``).

    Progress and row counts are printed along the way. Each count triggers a
    Spark job.

    Args:
        name: table name, used only in log output.
        df: Spark DataFrame to clean.

    Returns:
        The cleaned DataFrame.
    """
    print(f"\n========== Cleaning table: {name} ==========\n")
    print("STEP 0: Starting data-cleaning pipeline...")
    print("Initial row count:", df.count())

    # Column types are decided once, from the incoming schema.
    numeric_types = (IntegerType, LongType, FloatType, DoubleType, DecimalType)
    numeric_cols = [
        f.name for f in df.schema.fields if isinstance(f.dataType, numeric_types)
    ]
    categorical_cols = [
        f.name for f in df.schema.fields if isinstance(f.dataType, StringType)
    ]

    # =========================================
    # Handling missing values (imputation)
    # =========================================
    print("\nSTEP 1: Handling missing values")

    if numeric_cols:
        df = _impute_numeric(df, numeric_cols)
    else:
        print(" - No numerical columns found for imputation.")

    if categorical_cols:
        df = _impute_categorical(df, categorical_cols)
    else:
        print(" - No categorical (string) columns found for imputation.")

    # =========================================
    # Removing duplicates
    # =========================================
    print("\nSTEP 2: Removing duplicates")
    df = _drop_duplicate_rows(df)

    # =========================================
    # Outlier filtering
    # =========================================
    print("\nSTEP 3: Filtering outliers using IQR method")

    if numeric_cols:
        df = remove_outliers_iqr(df, numeric_cols)
    else:
        print(" - No numeric columns, skipping outlier filtering.")

    # =========================================
    # Final output
    # =========================================
    print("\nSTEP 4: Pipeline completed for table:", name)
    print("Final schema:")
    df.printSchema()
    return df

# =========================================
# Run cleaning for every table loaded into `tables`
# =========================================
# Tables are cleaned in insertion order; each result is stored under the same key.
cleaned_tables = {
    table_name: clean_table(table_name, table_df)
    for table_name, table_df in tables.items()
}
print("DONE: cleaning tables.\n")

# =========================================
# Save the cleaned tables to BigQuery
# =========================================
for cleaned_table, df in cleaned_tables.items():
    # Target id in project:dataset.table form; with the default config,
    # "Movies" becomes dejadsgl:netflix.movies_cleaned.
    bq_table_name = f"{cleaned_table.lower()}_cleaned"
    full_table_id = f"{project_id}:{bq_dataset}.{bq_table_name}"

    print(f"Writing '{cleaned_table}' to BigQuery table {full_table_id} ...")

    # Replace any existing table with this cleaned DataFrame.
    writer = df.write.format("bigquery").option("table", full_table_id).mode("overwrite")
    writer.save()

    print(f"Loaded table: {bq_table_name}")
    df.printSchema()
    print("-" * 60)

print("DONE: writing tables to BigQuery.\n")

In [None]:
# Shut down the SparkSession (and its underlying SparkContext) now that all tables are written.
spark.stop()