In [0]:
# Load the raw listings table into a Spark DataFrame; later cells in this notebook read from raw_df.
# NOTE(review): `spark` and `display` are not defined in this notebook — presumably injected by the
# notebook runtime (e.g. Databricks); confirm before running elsewhere.
raw_df = spark.read.table('airbnb.raw.listings')
# Render the DataFrame with the notebook's display helper for a quick visual check.
display(raw_df)

In [0]:
# Imports
from pyspark.sql import functions as F, types as T
from pyspark.sql import DataFrame
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from typing import List, Tuple

In [0]:
def is_missing_expr(col_name, dtype, include_empty_str):
    """Build a Spark Column that is 1 where the value counts as missing, else 0.

    A value is always missing when it is NULL. In addition:
      * FloatType / DoubleType columns also treat NaN as missing.
      * StringType columns also treat strings that are empty after trimming
        as missing, but only when ``include_empty_str`` is truthy.

    Args:
        col_name: Name of the column to inspect.
        dtype: The column's data type object; it is checked against the
            ``pyspark.sql.types`` classes listed above.
        include_empty_str: Whether blank strings count as missing.

    Returns:
        A 0/1 Column expression (unaliased).
    """
    column = F.col(col_name)
    missing = column.isNull()

    if isinstance(dtype, (T.FloatType, T.DoubleType)):
        # NaN is a regular (non-NULL) floating point value, so test it explicitly.
        missing = missing | F.isnan(column)
    elif isinstance(dtype, T.StringType):
        if include_empty_str:
            blank = F.length(F.trim(column)) == 0
        else:
            blank = F.lit(False)
        missing = missing | blank

    return F.when(missing, 1).otherwise(0)

def missing_summary(df, columns=None, include_empty_str=True):
    """Summarise present vs. missing values per column as a pandas DataFrame.

    "Missing" is defined by ``is_missing_expr`` (NULL, plus NaN for float/double
    columns and, optionally, blank strings for string columns).

    Args:
        df: Spark DataFrame to inspect.
        columns: Column names to summarise; ``None`` means every column of ``df``.
            Names that are not columns of ``df`` are dropped silently.
        include_empty_str: Whether blank strings count as missing.

    Returns:
        pandas DataFrame with one row per column and the fields ``column``,
        ``present_count``, ``null_count`` and ``null_pct`` (0-100, rounded to
        4 decimals), sorted by ``null_pct`` descending.

    Raises:
        ValueError: If none of the requested columns exist in ``df``.
    """
    if columns is None:
        columns = df.columns
    columns = [c for c in columns if c in df.columns]
    if not columns:
        raise ValueError("No valid columns were given")

    schema_map = {f.name: f.dataType for f in df.schema.fields}
    missing_sums = [
        F.sum(is_missing_expr(c, schema_map[c], include_empty_str)).alias(c)
        for c in columns
    ]

    # Fetch the row total and every per-column missing count in ONE Spark job
    # (previously df.count() was a separate job). Values are read back by
    # position so the helper count column can never clash with a real column name.
    agg_row = df.select(F.count(F.lit(1)), *missing_sums).collect()[0]
    total_rows = int(agg_row[0])

    rows = []
    for c, raw_nulls in zip(columns, agg_row[1:]):
        # SUM over zero rows is NULL (None in Python); treat that as 0 so an
        # empty DataFrame produces an all-zero summary instead of a TypeError.
        nulls = int(raw_nulls or 0)
        present = total_rows - nulls
        null_pct = (nulls / total_rows * 100.0) if total_rows else 0.0
        rows.append({"column": c, "present_count": present, "null_count": nulls, "null_pct": round(null_pct, 4)})

    return pd.DataFrame(rows).sort_values("null_pct", ascending=False, ignore_index=True)

def missing_matrix(df, columns=None, max_rows=5000, include_empty_str=True):
    """Return a 0/1 missingness matrix for up to ``max_rows`` rows as a pandas DataFrame.

    Each cell is 1 when the value counts as missing according to
    ``is_missing_expr`` and 0 otherwise.

    Args:
        df: Spark DataFrame to inspect.
        columns: Column names to include; ``None`` (the default, matching
            ``missing_summary``) means every column of ``df``. Names that are
            not columns of ``df`` are dropped silently.
        max_rows: Maximum number of rows pulled to the driver.
        include_empty_str: Whether blank strings count as missing.

    Returns:
        pandas DataFrame with one column per requested column, holding 0/1 flags.

    Raises:
        ValueError: If none of the requested columns exist in ``df``.
    """
    if columns is None:
        columns = df.columns
    columns = [c for c in columns if c in df.columns]
    if not columns:
        raise ValueError("No valid columns given")

    # Tag rows with a generated id and keep the max_rows smallest ids.
    # NOTE(review): the ids come from monotonically_increasing_id, so "first rows"
    # presumably follows the DataFrame's partition order, not any business key.
    with_id = df.select(F.monotonically_increasing_id().alias("_rid"), *columns).orderBy("_rid").limit(max_rows)

    schema_map = {f.name: f.dataType for f in df.schema.fields}
    miss_cols = [is_missing_expr(c, schema_map[c], include_empty_str).alias(c) for c in columns]

    # Sort again after the projection so the collected rows are in id order,
    # then drop the helper id before converting to pandas.
    bin_df = with_id.select("_rid", *miss_cols).orderBy("_rid").drop("_rid")
    return bin_df.toPandas()

def visualize_missing(df, columns=None, max_rows=2000, include_empty_str=True):
    """
    Report and plot missing data for a Spark DataFrame.

    Produces, in order:
      1. a heatmap of where values are missing (green = present, red = missing),
         limited to columns that have at least one missing value and to at most
         ``max_rows`` rows;
      2. a bar chart of the missing percentage for those same columns;
      3. the full per-column summary table, shown with the notebook's ``display``.

    If no column has missing values, only a message and the summary table are shown.

    Args:
        df: Spark DataFrame to inspect.
        columns: Column names to analyse; ``None`` means every column of ``df``.
        max_rows: Maximum number of rows drawn in the heatmap.
        include_empty_str: Whether blank strings count as missing.

    Returns:
        The pandas summary DataFrame produced by ``missing_summary``.

    NOTE(review): relies on the notebook-provided globals ``spark`` and ``display``.
    """
    summary_pdf = missing_summary(df, columns=columns, include_empty_str=include_empty_str)
    # Keep only the columns that have some missing values
    cols_with_nulls = summary_pdf.loc[summary_pdf["null_count"] > 0, "column"].tolist()

    if not cols_with_nulls:
        print("No missing values detected")
        display(spark.createDataFrame(summary_pdf))
        return summary_pdf

    # Plot the heatmap: green - data is present, red - data is missing
    mat_pdf = missing_matrix(df, columns=cols_with_nulls, max_rows=max_rows, include_empty_str=include_empty_str)
    if not mat_pdf.empty:
        mat = mat_pdf.values.astype(np.int8)
        fig1, ax1 = plt.subplots(figsize=(12, 10))
        cmap = ListedColormap(["#3ec245", "#f21818"])
        # Two bins: [-0.5, 0.5) maps 0 to the first colour, [0.5, 1.5) maps 1 to the second.
        norm = BoundaryNorm([-0.5, 0.5, 1.5], cmap.N)
        ax1.imshow(mat, aspect="auto", interpolation="nearest", cmap=cmap, norm=norm)
        ax1.set_title("Missingness heatmap")
        ax1.set_xlabel("Columns")
        ax1.set_ylabel("Row number")
        ax1.set_xticks(range(len(mat_pdf.columns)))
        ax1.set_xticklabels(mat_pdf.columns, rotation=90)
        fig1.tight_layout()
        plt.show()

    # Plot a bar chart with the missing % per column
    filtered_summary = summary_pdf[summary_pdf["column"].isin(cols_with_nulls)]
    if not filtered_summary.empty:
        fig2, ax2 = plt.subplots(figsize=(12, 8))
        ax2.bar(filtered_summary["column"], filtered_summary["null_pct"])
        ax2.set_title("Null % per column")
        ax2.set_xlabel("Column")
        ax2.set_ylabel("Null percentage (%)")
        # The bar call already labels the categorical axis; only rotate the labels.
        # Calling set_xticklabels here without fixed ticks makes matplotlib warn.
        ax2.tick_params(axis="x", labelrotation=90)
        fig2.tight_layout()
        plt.show()

    display(spark.createDataFrame(summary_pdf))
    return summary_pdf

In [0]:
# Run the missing-data report on every column (columns=None), drawing up to 15000 rows
# in the heatmap and counting blank strings as missing; keep the summary table for reuse.
summary = visualize_missing(raw_df, columns=None, max_rows=15000, include_empty_str=True)

In [0]:
# Total number of rows in raw_df, shown as the cell's output.
raw_df.count()  

In [0]:
def missing_rate_expr(col: str):
    """Aggregate expression for the share of rows where ``col`` is NULL.

    Only NULL is counted (NaN and blank strings are not). The resulting
    Column is aliased with the column's own name.
    """
    null_rows = F.count(F.when(F.col(col).isNull(), 1))
    all_rows = F.count(F.lit(1))
    return (null_rows / all_rows).alias(col)

def missing_indicator(col: str):
    """Row-level 0/1 flag Column: 1 where ``col`` is NULL, 0 otherwise.

    The flag is aliased as ``<col>_miss``.
    """
    is_null = F.col(col).isNull()
    flag = F.when(is_null, 1).otherwise(0)
    return flag.alias(col + "_miss")

def to_pandas(df, limit=None):
    """Convert ``df`` to pandas via ``toPandas()``, optionally truncating first.

    ``df.limit(limit)`` is applied only when ``limit`` is an ``int``; any other
    value (including ``None``) converts the whole DataFrame.
    """
    if isinstance(limit, int):
        df = df.limit(limit)
    return df.toPandas()

# Fraction of rows kept when the missing-indicator frame is sampled (without replacement)
# before being converted to pandas.
SAMPLE_FRAC = 0.20  
# Seed passed to DataFrame.sample together with SAMPLE_FRAC.
SEED = 42

In [0]:
schema = raw_df.schema

# Bucket column names by Spark data type. ByteType (tinyint) is listed with the
# other numeric types so such columns are not silently left out of the summary.
type_groups = {
    "numeric": [f.name for f in schema if isinstance(f.dataType, (T.ByteType, T.ShortType, T.IntegerType, T.LongType, T.FloatType, T.DoubleType, T.DecimalType))],
    "string":  [f.name for f in schema if isinstance(f.dataType, T.StringType)],
    "boolean": [f.name for f in schema if isinstance(f.dataType, T.BooleanType)],
    "date_ts": [f.name for f in schema if isinstance(f.dataType, (T.DateType, T.TimestampType))]
}

# Compute the NULL rate of every grouped column in ONE aggregation instead of
# launching a separate Spark job per type group.
typed_cols = [c for cols in type_groups.values() for c in cols]
missing_rates = (
    raw_df.agg(*[missing_rate_expr(c) for c in typed_cols]).collect()[0].asDict()
    if typed_cols
    else {}
)

# One (type, average NULL rate across its columns, number of columns) tuple per non-empty group.
stats_rows: List[Tuple[str, float, int]] = []
for tname, cols in type_groups.items():
    if not cols:
        continue
    # A rate can come back as None (e.g. raw_df has no rows); count it as 0.0
    # rather than letting sum() fail on None.
    mean_missing = sum((missing_rates[c] or 0.0) for c in cols) / len(cols)
    stats_rows.append((tname, float(mean_missing), len(cols)))

missing_by_type_df = spark.createDataFrame(stats_rows, ["type", "avg_missing_rate", "n_cols"])
display(missing_by_type_df.orderBy(F.desc("avg_missing_rate")))


In [0]:
# Bar chart of the average missing rate per feature type, highest first.
pdf = to_pandas(missing_by_type_df.orderBy(F.desc("avg_missing_rate")))

fig, ax = plt.subplots(figsize=(6, 4))
ax.bar(pdf["type"], pdf["avg_missing_rate"])
ax.set_title("Average Missing Rate by Feature Type")
ax.set_xlabel("Type")
ax.set_ylabel("Avg Missing Rate")
# Rates are fractions, so pin the axis to the full 0-1 range.
ax.set_ylim(0, 1)
plt.show()


In [0]:
# Count NULLs per column in a single pass over raw_df.
null_count_exprs = [
    F.count(F.when(F.col(name).isNull(), 1)).alias(name)
    for name in raw_df.columns
]
null_counts = raw_df.select(null_count_exprs).collect()[0].asDict()

# Only columns with at least one NULL take part in the co-missingness analysis.
cols_with_nulls = [name for name, n_null in null_counts.items() if (n_null or 0) > 0]

# Row-level 0/1 NULL flags (<col>_miss), sampled before being pulled into pandas.
miss_df = raw_df.select([missing_indicator(name) for name in cols_with_nulls])
miss_df_sample = miss_df.sample(withReplacement=False, fraction=SAMPLE_FRAC, seed=SEED)

pdf = miss_df_sample.toPandas()

# Pairwise correlation between the NULL flags; both axes are labelled "col".
corr = pdf.corr(numeric_only=True).rename_axis(index="col", columns="col")


In [0]:
N = 30

# Rank the flag columns by their mean (= NULL rate within the sample) and keep the N highest.
rate_exprs = [F.mean(name).alias(name) for name in miss_df_sample.columns]
miss_rates = miss_df_sample.agg(*rate_exprs).collect()[0].asDict()
ranked = sorted(miss_rates.items(), key=lambda kv: kv[1], reverse=True)
top_cols = [name for name, _ in ranked[:N]]

corr_top = corr.loc[top_cols, top_cols].to_numpy()

# Heatmap of the correlation sub-matrix for the selected columns.
fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(corr_top, aspect='auto')
fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
tick_positions = np.arange(len(top_cols))
ax.set_xticks(tick_positions)
ax.set_xticklabels(top_cols, rotation=90)
ax.set_yticks(tick_positions)
ax.set_yticklabels(top_cols)
ax.set_title("Co-missingness Correlation (Top N columns)")
fig.tight_layout()
plt.show()


In [0]:
# Group by "area" when that column exists, otherwise by "neighbourhood_cleansed".
if "area" in raw_df.columns:
    geo_col = "area"
else:
    geo_col = "neighbourhood_cleansed"

# Per-group NULL share of the two columns of interest, built with the shared
# missing_rate_expr helper and re-aliased to the output column names.
license_rate = missing_rate_expr("license").alias("license_missing_rate")
review_rate = missing_rate_expr("review_scores_rating").alias("review_missing_rate")
group_size = F.count(F.lit(1)).alias("n")

geo_missing = (
    raw_df.groupBy(geo_col)
    .agg(license_rate, review_rate, group_size)
    # Drop the group whose geography key itself is NULL.
    .filter(F.col(geo_col).isNotNull())
)

# The 20 groups with the most rows.
top_geo = geo_missing.orderBy(F.desc("n")).limit(20)
display(top_geo)


In [0]:
pdf_geo = to_pandas(top_geo.orderBy(F.desc("n")))
x = range(len(pdf_geo))

# One bar chart per missing-rate column; both use the same group order (largest n first).
for rate_col, chart_title in (
    ("license_missing_rate", "License Missing Rate by Area (Top 20 by volume)"),
    ("review_missing_rate", "Review Score Missing Rate by Area (Top 20 by volume)"),
):
    plt.figure(figsize=(10,4))
    plt.bar(x, pdf_geo[rate_col], width=0.4)
    plt.xticks(x, pdf_geo[geo_col], rotation=90)
    plt.title(chart_title)
    plt.ylabel("Missing Rate")
    plt.tight_layout()
    plt.show()
