In [42]:
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.functions import avg, col, count, desc

# =========================================
# 0. START SPARK SESSION
# =========================================

# Configuration
project_id = "dejadsgl"
bq_dataset = "netflix"
temp_bucket = "netflix-group5-temp_gl"
gcs_data_bucket = "netflix_data_25"

# Spark configuration
sparkConf = SparkConf()
sparkConf.setMaster("spark://spark-master:7077")
sparkConf.setAppName("SparkCleanDataset")
sparkConf.set("spark.driver.memory", "2g")
sparkConf.set("spark.executor.cores", "1")
sparkConf.set("spark.driver.cores", "1")

# Create the Spark session
spark = SparkSession.builder.config(conf=sparkConf).getOrCreate()

# Use the Cloud Storage bucket for temporary BigQuery export data used by the connector
spark.conf.set('temporaryGcsBucket', temp_bucket)

# Setup hadoop fs configuration for schema gs://
conf = spark.sparkContext._jsc.hadoopConfiguration()
conf.set("fs.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFileSystem")
conf.set("fs.AbstractFileSystem.gs.impl", "com.google.cloud.hadoop.fs.gcs.GoogleHadoopFS")

print("Spark session started.")

# =========================================
# 1. LOAD ALL TABLES
# =========================================

# Load data from BigQuery
tables = {}
titles = [
    "Movies",
    "Users",
    "Watch_history",
    "Reviews"
]

for title in titles:
    df = spark.read \
            .format("bigquery") \
            .load(f"{project_id}.{bq_dataset}.{title}")

    df.cache()
    tables[title] = df   # store in dictionary

    print(f"\nLoaded table: {title}")
    df.printSchema()

print("Done.")


Spark session started.

Loaded table: Movies
root
 |-- movie_id: string (nullable = true)
 |-- title: string (nullable = true)
 |-- content_type: string (nullable = true)
 |-- genre_primary: string (nullable = true)
 |-- genre_secondary: string (nullable = true)
 |-- release_year: long (nullable = true)
 |-- duration_minutes: double (nullable = true)
 |-- rating: string (nullable = true)
 |-- language: string (nullable = true)
 |-- country_of_origin: string (nullable = true)
 |-- imdb_rating: double (nullable = true)
 |-- production_budget: double (nullable = true)
 |-- box_office_revenue: double (nullable = true)
 |-- number_of_seasons: double (nullable = true)
 |-- number_of_episodes: double (nullable = true)
 |-- is_netflix_original: boolean (nullable = true)
 |-- added_to_platform: date (nullable = true)


Loaded table: Users
root
 |-- user_id: string (nullable = true)
 |-- email: string (nullable = true)
 |-- first_name: string (nullable = true)
 |-- last_name: string (nullable = 

In [45]:
from pyspark.sql import functions as F
from pyspark.sql.types import (IntegerType, LongType, FloatType, DoubleType, DecimalType, StringType)
from pyspark.ml.feature import Imputer


def choose_duplicate_keys(df):
    """
    Try to infer a reasonable key for dropDuplicates.
    """
    candidates = [
        ("CustomerID", "InvoiceDate"),            # retail example
        ("user_id", "timestamp"),
        ("userId", "timestamp"),
        ("userId", "movieId", "timestamp"),
    ]
    for keys in candidates:
        if all(k in df.columns for k in keys):
            return list(keys)
    return None


def remove_outliers_iqr(df, cols, k=1.5):
    for c in cols:
        print(f"   - Processing outliers for numeric column '{c}'")

        # skip if column is all nulls
        non_null = df.select(F.count(F.col(c))).first()[0]
        if non_null == 0:
            print(f"     Skipping '{c}' (no non-null values).")
            continue

        try:
            q1, q3 = df.approxQuantile(c, [0.25, 0.75], 0.01)
        except Exception as e:
            print(f"     Skipping '{c}' (approxQuantile error: {e})")
            continue

        iqr = q3 - q1
        lower = q1 - k * iqr
        upper = q3 + k * iqr

        before = df.count()
        df = df.filter((F.col(c) >= lower) & (F.col(c) <= upper))
        after = df.count()

        print(f"     Removed {before - after} outliers from '{c}'")
        print(f"     New row count: {after}")

    return df


def clean_table(name, df):
    print(f"\n========== Cleaning table: {name} ==========\n")
    print("STEP 0: Starting data-cleaning pipeline...")
    print("Initial row count:", df.count())

    # Detect numeric and categorical columns for this table
    numeric_cols = [
        f.name for f in df.schema.fields
        if isinstance(f.dataType, (IntegerType, LongType, FloatType, DoubleType, DecimalType))
    ]
    categorical_cols = [
        f.name for f in df.schema.fields
        if isinstance(f.dataType, StringType)
    ]

    # =========================================
    # 1. Handling missing values (imputation)
    # =========================================
    print("\nSTEP 1: Handling missing values")

    # a) Numerical columns
    if numeric_cols:
        print(" - Imputing numerical columns:", numeric_cols)

        imputer = Imputer(
            inputCols=numeric_cols,
            outputCols=[c + "_imputed" for c in numeric_cols]
        ).setStrategy("median")

        df = imputer.fit(df).transform(df)

        for c in numeric_cols:
            df = df.drop(c).withColumnRenamed(c + "_imputed", c)

        print("   Completed numerical imputation.")
        print("   Row count after numeric imputation:", df.count())
    else:
        print(" - No numerical columns found for imputation.")

    # b) Categorical columns
    if categorical_cols:
        print(" - Imputing categorical columns:", categorical_cols)

        for c in categorical_cols:
            mode_row = (
                df.groupBy(c)
                  .count()
                  .orderBy(F.desc("count"))
                  .first()
            )
            mode_value = mode_row[0] if mode_row else None
            if mode_value is not None:
                df = df.fillna({c: mode_value})
                print(f"   Filled missing values in '{c}' with mode='{mode_value}'")

        print("   Completed categorical imputation.")
        print("   Row count after categorical imputation:", df.count())
    else:
        print(" - No categorical (string) columns found for imputation.")

    # =========================================
    # 2. Removing duplicates
    # =========================================
    print("\nSTEP 2: Removing duplicates")

    dup_keys = choose_duplicate_keys(df)
    if dup_keys:
        print(f" - Using keys for duplicates: {dup_keys}")
        before = df.count()
        df = df.dropDuplicates(dup_keys)
        after = df.count()
        print(f"   Removed {before - after} duplicates")
        print("   Row count after deduplication:", after)
    else:
        print(" - No suitable duplicate keys found, skipping deduplication.")

    # =========================================
    # 3. Outlier filtering
    # =========================================
    print("\nSTEP 3: Filtering outliers using IQR method")

    if numeric_cols:
        df = remove_outliers_iqr(df, numeric_cols)
    else:
        print(" - No numeric columns, skipping outlier filtering.")

    # =========================================
    # 4. Final output
    # =========================================
    print("\nSTEP 4: Pipeline completed for table:", name)
    print("Final schema:")
    df.printSchema()

    # print("\nFinal preview:")
    # df.show(10, truncate=False)
    return df

# ---------------------------------------------------
# Run cleaning for all opened tables in `tables`
# ---------------------------------------------------
cleaned_tables = {}
for table_name, table_df in tables.items():
    cleaned_tables[table_name] = clean_table(table_name, table_df)
print("Done.\n")




STEP 0: Starting data-cleaning pipeline...
Initial row count: 1040

STEP 1: Handling missing values
 - Imputing numerical columns: ['release_year', 'duration_minutes', 'imdb_rating', 'production_budget', 'box_office_revenue', 'number_of_seasons', 'number_of_episodes']
   Completed numerical imputation.
   Row count after numeric imputation: 1040
 - Imputing categorical columns: ['movie_id', 'title', 'content_type', 'genre_primary', 'genre_secondary', 'rating', 'language', 'country_of_origin']
   Filled missing values in 'movie_id' with mode='movie_0823'
   Filled missing values in 'title' with mode='A Adventure'
   Filled missing values in 'content_type' with mode='Movie'
   Filled missing values in 'genre_primary' with mode='Adventure'
   Filled missing values in 'rating' with mode='TV-Y'
   Filled missing values in 'language' with mode='English'
   Filled missing values in 'country_of_origin' with mode='USA'
   Completed categorical imputation.
   Row count after categorical imputa

In [56]:
# Write aggregated results to BigQuery
for cleaned_table, df in cleaned_tables.items():
    print(f"Writing '{cleaned_table}_cleaned' aggregation to BigQuery...")
    
    df.write.format('bigquery') \
        .option('table', f'{project_id}.{bq_dataset}.{cleaned_table}_cleaned') \
        .mode("overwrite") \
        .save()

    print(f"\nLoaded table: {cleaned_table}_cleaned")
    df.printSchema()

print("\nAll aggregations written to BigQuery successfully!")

Writing 'Movies_cleaned' aggregation to BigQuery...

Loaded table: Movies_cleaned
root
 |-- movie_id: string (nullable = false)
 |-- title: string (nullable = false)
 |-- content_type: string (nullable = false)
 |-- genre_primary: string (nullable = false)
 |-- genre_secondary: string (nullable = true)
 |-- rating: string (nullable = false)
 |-- language: string (nullable = false)
 |-- country_of_origin: string (nullable = false)
 |-- is_netflix_original: boolean (nullable = true)
 |-- added_to_platform: date (nullable = true)
 |-- release_year: long (nullable = true)
 |-- duration_minutes: double (nullable = true)
 |-- imdb_rating: double (nullable = true)
 |-- production_budget: double (nullable = true)
 |-- box_office_revenue: double (nullable = true)
 |-- number_of_seasons: double (nullable = true)
 |-- number_of_episodes: double (nullable = true)

Writing 'Users_cleaned' aggregation to BigQuery...

Loaded table: Users_cleaned
root
 |-- user_id: string (nullable = false)
 |-- email

In [57]:
spark.stop()