In [1]:
def calculate_vsi(training_data, out_of_sample_data, feature_column: str, num_bins: int = 10):
    """
    Calculate the Variable Stability Index (VSI) of one feature.

    The feature is cut into ``num_bins`` equal-width bins spanning the
    [min, max] range observed in ``training_data``. The share of rows falling
    in each bin is computed for both datasets and the per-bin terms
    ``(p_oos - p_train) * ln(p_oos / p_train)`` are summed.

    Relies on ``f`` being in scope (presumably ``pyspark.sql.functions``,
    imported elsewhere in the file).

    Parameters:
    - training_data (DataFrame): Reference dataset; defines the bin edges.
    - out_of_sample_data (DataFrame): Dataset compared against the reference.
    - feature_column (str): Name of the column to evaluate.
    - num_bins (int): Number of equal-width bins (must be >= 1). Default 10.

    Returns:
    - vsi (float | None): Sum of the per-bin terms. ``None`` when no training
      bin is also populated in the out-of-sample data (e.g. the out-of-sample
      dataset is empty), because Spark's ``sum`` over only NULLs is NULL.

    Raises:
    - ValueError: If ``num_bins`` is below 1, or if ``feature_column`` has no
      non-null value in ``training_data`` (no range to build bins from).

    Notes:
    - Bins are joined with a left join from the training distribution, so
      out-of-sample rows that fall outside the training range (or are NULL)
      count towards the out-of-sample total but add no term of their own.
    - A training bin with no out-of-sample rows yields a NULL term, which the
      final sum ignores; i.e. such bins contribute nothing to the result.
    """
    if num_bins < 1:
        raise ValueError(f"num_bins must be >= 1, got {num_bins}")

    # Fetch both range bounds in a single aggregation (one Spark job).
    bounds = training_data.agg(
        f.min(feature_column), f.max(feature_column)
    ).collect()[0]
    min_value, max_value = bounds[0], bounds[1]
    if min_value is None or max_value is None:
        raise ValueError(
            f"Column '{feature_column}' has no non-null values in training_data"
        )
    bin_size = (max_value - min_value) / num_bins

    feature = f.col(feature_column)
    if bin_size == 0:
        # Constant feature: there is no width to divide by. Values different
        # from the constant get a NULL bin, which is what a division by zero
        # evaluates to in non-ANSI Spark, without failing under ANSI mode.
        regular_bin = f.lit(None).cast("long")
    else:
        regular_bin = f.floor((feature - min_value) / bin_size)

    # The maximum would otherwise land in bin `num_bins`; fold it into the
    # last bin so the top edge is inclusive. The expression is unbound, so the
    # same one is applied to both datasets.
    bin_expr = f.when(feature == max_value, num_bins - 1).otherwise(regular_bin)

    train_binned = training_data.select(feature_column).withColumn("bin", bin_expr)
    oos_binned = out_of_sample_data.select(feature_column).withColumn("bin", bin_expr)

    # Totals include every row (NULL and out-of-range values as well).
    total_count_training = train_binned.count()
    total_count_out_of_sample = oos_binned.count()

    # Share of observations per bin in each dataset.
    train_dist = (
        train_binned.groupBy("bin")
        .agg(f.count("*").alias("count_bin_train"))
        .withColumn(
            "percent_bin_train", f.col("count_bin_train") / total_count_training
        )
    )
    oos_dist = (
        oos_binned.groupBy("bin")
        .agg(f.count("*").alias("count_bin_oos"))
        .withColumn(
            "percent_bin_oos", f.col("count_bin_oos") / total_count_out_of_sample
        )
    )

    # Left join: only bins present in the training data are kept.
    joined = train_dist.join(oos_dist, on="bin", how="left")

    # Per-bin stability term. Bins missing on the out-of-sample side carry
    # NULL percentages, so their term is NULL and is skipped by the sum.
    joined = joined.withColumn(
        "psi",
        f.when(
            (f.col("percent_bin_train") == 0) | (f.col("percent_bin_oos") == 0), 0
        ).otherwise(
            (f.col("percent_bin_oos") - f.col("percent_bin_train"))
            * f.log(f.col("percent_bin_oos") / f.col("percent_bin_train"))
        ),
    )

    return joined.select(f.sum("psi")).collect()[0][0]
