# In [None]:
# Libraries
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql import Row

# 1.1: Calculate proportions for the original dataset
def calculate_proportions_org(regitra_transeksta_final):
    """
    Summarise record counts per vehicle type and explanation as shares of the whole dataset.

    :param regitra_transeksta_final: PySpark DataFrame with vehicle data.
    :return: DataFrame with one row per ('transporto_priemones_tipas',
             'transporto_priemones_paaiskinimas') pair, holding 'total_count',
             'count_with_rida', 'proportion_total' and 'proportion_with_rida'
             (percentages rounded to 2 decimals).
    """
    has_mileage = F.col('rida_per_metus').isNotNull()

    # Dataset-wide denominators: every row, and every row that carries an annual mileage
    n_all = regitra_transeksta_final.count()
    n_with_mileage = regitra_transeksta_final.filter(has_mileage).count()

    group_keys = ['transporto_priemones_tipas', 'transporto_priemones_paaiskinimas']
    per_group = regitra_transeksta_final.groupBy(*group_keys).agg(
        F.count('*').alias('total_count'),
        # count() ignores NULLs, so only rows with a mileage value are counted here
        F.count(F.when(has_mileage, 1)).alias('count_with_rida'),
    )

    share_of_all = F.round((F.col('total_count') / n_all) * 100, 2)
    share_of_mileage = F.round((F.col('count_with_rida') / n_with_mileage) * 100, 2)

    return (
        per_group
        .withColumn('proportion_total', share_of_all)
        .withColumn('proportion_with_rida', share_of_mileage)
    )

# 1.2: Extract unique combinations of bus makes and models
def autobusu_markes_org(regitra_transeksta_final):
    """
    List the make/model combinations found among vehicle type 'K4' (buses).

    :param regitra_transeksta_final: PySpark DataFrame with vehicle data.
    :return: DataFrame with 'cleaned_marke', 'cleaned_modelis', 'count' and 'percentage'
             (share of all 'K4' rows, rounded to 2 decimals), ordered by 'count' descending.
    """
    buses = regitra_transeksta_final.filter(F.col('transporto_priemones_tipas') == 'K4')

    # Denominator for the percentage column: every 'K4' record
    n_buses = buses.count()
    share = F.round((F.col('count') / n_buses) * 100, 2)

    return (
        buses.groupBy('cleaned_marke', 'cleaned_modelis')
        .agg(F.count('*').alias('count'))
        .withColumn('percentage', share)
        .orderBy(F.col('count').desc())
    )



# 1.3: Calculate proportions based on fuel type
def calculate_proportions_with_fuel_org(regitra_transeksta_final):
    """
    Break record counts down by vehicle type, explanation and combined fuel label.

    Rows without a 'transporto_priemones_tipas' are ignored. The fuel label is
    'degalai' alone, or "<degalai>, <papildomi_degalai_1>" when the additional fuel
    is present. One extra summary row per fuel label is appended with vehicle type
    'K0' and explanation 'Iš viso'.

    :param regitra_transeksta_final: PySpark DataFrame with vehicle data.
    :return: DataFrame with 'total_count', 'count_with_rida', 'proportion_total' and
             'proportion_with_rida' (percentages rounded to 4 decimals), ordered by
             vehicle type and fuel label.
    """
    has_mileage = F.col('rida_per_metus').isNotNull()
    extra_fuel = F.col('papildomi_degalai_1')

    fuel_label = F.when(
        extra_fuel.isNotNull(),
        F.concat(F.col('degalai'), F.lit(', '), extra_fuel),
    ).otherwise(F.col('degalai'))

    vehicles = (
        regitra_transeksta_final
        .filter(F.col('transporto_priemones_tipas').isNotNull())
        .withColumn('fuel_combined', fuel_label)
    )

    # Denominators are taken after the NULL vehicle types have been removed
    n_all = vehicles.count()
    n_with_mileage = vehicles.filter(has_mileage).count()

    counts = [
        F.count('*').alias('total_count'),
        F.count(F.when(has_mileage, 1)).alias('count_with_rida'),
    ]

    def add_shares(frame):
        # Express both counters as a percentage of the dataset-wide totals
        return frame.withColumn(
            'proportion_total',
            F.round((F.col('total_count') / n_all) * 100, 4),
        ).withColumn(
            'proportion_with_rida',
            F.round((F.col('count_with_rida') / n_with_mileage) * 100, 4),
        )

    detail = add_shares(
        vehicles.groupBy(
            'transporto_priemones_tipas',
            'transporto_priemones_paaiskinimas',
            'fuel_combined',
        ).agg(*counts)
    )

    # Per-fuel totals across all vehicle types, tagged with a synthetic 'K0' type
    summary = add_shares(
        vehicles.groupBy('fuel_combined')
        .agg(*counts)
        .withColumn('transporto_priemones_tipas', F.lit('K0'))
        .withColumn('transporto_priemones_paaiskinimas', F.lit('Iš viso'))
    )

    return detail.unionByName(summary).orderBy('transporto_priemones_tipas', 'fuel_combined')


# 2.0: Outlier detection using the 1.5 * IQR method
def outlier_detection(regitra_transeksta_final):
    """
    Flag 'rida_per_metus' outliers per 'tp_pavadinimas' group with the 1.5 * IQR rule.

    Rows whose 'rida_per_metus' is NULL are dropped from the result. The helper
    columns 'Q1', 'Q3', 'IQR', 'lower_bound' and 'upper_bound' stay in the output.
    Quartiles come from Spark's approximate percentile function.

    :param regitra_transeksta_final: PySpark DataFrame with vehicle data and mileage information.
    :return: DataFrame with an additional column 'outlier_rida_per_tipa'
             (1 when the mileage lies outside the bounds, 0 otherwise).
    """
    mileage = F.col('rida_per_metus')
    per_type = Window.partitionBy('tp_pavadinimas')

    # Approximate quartiles, evaluated over all rows of the same 'tp_pavadinimas'
    first_quartile = F.expr('percentile_approx(rida_per_metus, 0.25)').over(per_type)
    third_quartile = F.expr('percentile_approx(rida_per_metus, 0.75)').over(per_type)

    with_bounds = (
        regitra_transeksta_final
        .filter(mileage.isNotNull())
        .withColumn('Q1', first_quartile)
        .withColumn('Q3', third_quartile)
        .withColumn('IQR', F.col('Q3') - F.col('Q1'))
        .withColumn('lower_bound', F.col('Q1') - 1.5 * F.col('IQR'))
        .withColumn('upper_bound', F.col('Q3') + 1.5 * F.col('IQR'))
    )

    outside_bounds = (mileage < F.col('lower_bound')) | (mileage > F.col('upper_bound'))

    return with_bounds.withColumn(
        'outlier_rida_per_tipa',
        F.when(outside_bounds, 1).otherwise(0),
    )


# 2.1 Count outliers grouped by vehicle type
def outliers_results(outlier_detection):
    """
    Count outlier and non-outlier rows for every 'tp_pavadinimas'.

    :param outlier_detection: DataFrame with outlier detection results.
    :return: DataFrame with 'tp_pavadinimas', 'outlier_rida_per_tipa' and 'count',
             ordered by the two grouping columns.
    """
    keys = ['tp_pavadinimas', 'outlier_rida_per_tipa']

    tally = outlier_detection.groupBy(*keys).agg(
        F.count('outlier_rida_per_tipa').alias('count')
    )

    return tally.orderBy(*keys)



# 3.0 Set outliers to NULL
def set_outliers_to_null(outlier_detection):
    """
    Blank out the mileage columns on rows flagged as outliers.

    :param outlier_detection: DataFrame with outlier detection results.
    :return: DataFrame in which 'rida_per_diena' and 'rida_per_metus' are NULL
             wherever 'outlier_rida_per_tipa' equals 1; other rows are unchanged.
    """
    is_outlier = F.col('outlier_rida_per_tipa') == 1

    cleaned = outlier_detection
    for mileage_col in ('rida_per_diena', 'rida_per_metus'):
        # Overwrite the column in place: NULL for outliers, original value otherwise
        cleaned = cleaned.withColumn(
            mileage_col,
            F.when(is_outlier, None).otherwise(F.col(mileage_col)),
        )

    return cleaned


# 3.1 Calculate proportions after removing outliers
def calculate_proportions(set_outliers_to_null):
    """
    Compute per-group record shares on the data whose outlier mileages were set to NULL.

    :param set_outliers_to_null: DataFrame with outlier values set to NULL.
    :return: DataFrame grouped by 'transporto_priemones_tipas' and
             'transporto_priemones_paaiskinimas' with 'total_count', 'count_with_rida',
             'proportion_total' and 'proportion_with_rida' (percentages, 2 decimals).
    """
    mileage_present = F.col('rida_per_metus').isNotNull()

    # Overall totals used as denominators for the two percentage columns
    overall_rows = set_outliers_to_null.count()
    overall_rows_with_mileage = set_outliers_to_null.filter(mileage_present).count()

    grouped = set_outliers_to_null.groupBy(
        'transporto_priemones_tipas',
        'transporto_priemones_paaiskinimas',
    )
    counts = grouped.agg(
        F.count('*').alias('total_count'),
        F.count(F.when(mileage_present, 1)).alias('count_with_rida'),
    )

    with_total_share = counts.withColumn(
        'proportion_total',
        F.round((F.col('total_count') / overall_rows) * 100, 2),
    )
    with_both_shares = with_total_share.withColumn(
        'proportion_with_rida',
        F.round((F.col('count_with_rida') / overall_rows_with_mileage) * 100, 2),
    )

    return with_both_shares


# 3.2 Extract bus makes and models with mileage
def autobusu_markes_su_rida(set_outliers_to_null):
    """
    List make/model/fuel combinations for vehicle type 'K4' (buses).

    :param set_outliers_to_null: DataFrame with outlier values set to NULL.
    :return: DataFrame with 'cleaned_marke', 'cleaned_modelis', 'degalai', 'count' and
             'percentage' (share of all 'K4' rows, 2 decimals), ordered by 'count' descending.
    """
    buses = set_outliers_to_null.filter(F.col('transporto_priemones_tipas') == 'K4')

    # All 'K4' rows form the denominator of the percentage
    n_buses = buses.count()

    combos = buses.groupBy('cleaned_marke', 'cleaned_modelis', 'degalai').agg(
        F.count('*').alias('count')
    )
    combos = combos.withColumn(
        'percentage',
        F.round((F.col('count') / n_buses) * 100, 2),
    )

    return combos.orderBy(F.col('count').desc())

# 3.3  Trolleybuses unique models and makers with fuel type and record counts and percentages
def troleibusu_markes_su_rida(set_outliers_to_null):
    """
    List make/model/fuel combinations for vehicle type 'K5' (trolleybuses).

    :param set_outliers_to_null: PySpark DataFrame with processed vehicle data.
    :return: DataFrame with 'cleaned_marke', 'cleaned_modelis', 'degalai', 'count' and
             'percentage' (share of all 'K5' rows, 2 decimals), ordered by 'count' descending.
    """
    trolleybuses = set_outliers_to_null.filter(F.col('transporto_priemones_tipas') == 'K5')

    # All 'K5' rows form the denominator of the percentage
    n_trolleybuses = trolleybuses.count()
    share = F.round((F.col('count') / n_trolleybuses) * 100, 2)

    return (
        trolleybuses.groupBy('cleaned_marke', 'cleaned_modelis', 'degalai')
        .agg(F.count('*').alias('count'))
        .withColumn('percentage', share)
        .orderBy(F.col('count').desc())
    )

#3.4 Proportions for records based on vehicle type, fuel type

def calculate_proportions_with_fuel(set_outliers_to_null):
    """
    Break record counts down by vehicle type, explanation and combined fuel label.

    Rows without a 'transporto_priemones_tipas' are ignored. The fuel label is
    'degalai' alone, or "<degalai>, <papildomi_degalai_1>" when the additional fuel
    is present. One summary row per fuel label is appended with vehicle type 'K0'
    and explanation 'Total'.

    :param set_outliers_to_null: PySpark DataFrame with processed vehicle data.
    :return: DataFrame with 'total_count', 'count_with_rida', 'proportion_total' and
             'proportion_with_rida' (percentages rounded to 4 decimals), ordered by
             vehicle type and fuel label.
    """
    mileage_present = F.col('rida_per_metus').isNotNull()
    detail_keys = [
        'transporto_priemones_tipas',
        'transporto_priemones_paaiskinimas',
        'fuel_combined',
    ]

    typed_rows = set_outliers_to_null.filter(F.col('transporto_priemones_tipas').isNotNull())
    labelled = typed_rows.withColumn(
        'fuel_combined',
        F.when(
            F.col('papildomi_degalai_1').isNotNull(),
            F.concat(F.col('degalai'), F.lit(', '), F.col('papildomi_degalai_1')),
        ).otherwise(F.col('degalai')),
    )

    # Denominators are taken after the NULL vehicle types have been removed
    overall_rows = labelled.count()
    overall_rows_with_mileage = labelled.filter(mileage_present).count()

    def count_rows(grouped):
        # Same two counters for the detailed and the summary grouping
        return grouped.agg(
            F.count('*').alias('total_count'),
            F.count(F.when(mileage_present, 1)).alias('count_with_rida'),
        )

    def as_percentages(frame):
        with_total = frame.withColumn(
            'proportion_total',
            F.round((F.col('total_count') / overall_rows) * 100, 4),
        )
        return with_total.withColumn(
            'proportion_with_rida',
            F.round((F.col('count_with_rida') / overall_rows_with_mileage) * 100, 4),
        )

    detailed = as_percentages(count_rows(labelled.groupBy(*detail_keys)))

    # Per-fuel totals across all vehicle types, tagged with a synthetic 'K0' type
    per_fuel = count_rows(labelled.groupBy('fuel_combined'))
    per_fuel = per_fuel.withColumn('transporto_priemones_tipas', F.lit('K0'))
    per_fuel = per_fuel.withColumn('transporto_priemones_paaiskinimas', F.lit('Total'))
    per_fuel = as_percentages(per_fuel)

    combined = detailed.unionByName(per_fuel)

    return combined.orderBy('transporto_priemones_tipas', 'fuel_combined')


# 3.5 Check for missing values
def check_na_values(set_outliers_to_null):
    """
    Report how many NULLs each selected column contains, as a count and a percentage.

    Columns from the checklist that are absent from the input are silently skipped.

    :param set_outliers_to_null: PySpark DataFrame with processed vehicle data after handling outliers.
    :return: Single-row DataFrame with '<column>_na_count' and '<column>_na_percentage'
             (2 decimals) for every checked column that exists in the input.
    """
    checklist = [
        'cleaned_marke', 'cleaned_modelis', 'transporto_priemones_tipas', 'variklio_turis',
        'galia', 'degalai', 'nuosava_mase', 'rida_per_metus',
    ]
    available = set_outliers_to_null.columns
    present = [name for name in checklist if name in available]

    n_rows = set_outliers_to_null.count()

    def null_tally(name):
        # count() skips NULLs, so only rows where the column IS NULL produce a counted value
        return F.count(F.when(F.col(name).isNull(), name))

    summary_exprs = []
    for name in present:
        summary_exprs.extend([
            null_tally(name).alias(f'{name}_na_count'),
            (F.round((null_tally(name) / n_rows) * 100, 2)).alias(f'{name}_na_percentage'),
        ])

    return set_outliers_to_null.select(*summary_exprs)

# 3.6 Remove rows with missing data in specified columns
def delete_na_values(set_outliers_to_null):
    """
    Keep only rows that have a value in every required column.

    :param set_outliers_to_null: PySpark DataFrame with processed vehicle data after handling outliers.
    :return: DataFrame without the rows that miss a value in any of the required columns.
    """
    required = [
        'cleaned_marke', 'cleaned_modelis', 'transporto_priemones_tipas', 'variklio_turis',
        'galia', 'degalai', 'nuosava_mase', 'rida_per_metus', 'pag_metai',
    ]

    # A row is dropped as soon as any one of the required columns is missing
    return set_outliers_to_null.na.drop(subset=required)

# 3.7 Correlation analysis of mileage and vehicle owner age
def calculate_correlation_amzius_rida(set_outliers_to_null):
    """
    Compute the correlation between 'rida_per_metus' and 'amzius' over rows that have both.

    :param set_outliers_to_null: PySpark DataFrame with processed vehicle data after handling outliers.
    :return: Single-row Spark DataFrame with the column 'correlation_rida_amzius'.
    """
    both_present = F.col('rida_per_metus').isNotNull() & F.col('amzius').isNotNull()
    complete_rows = set_outliers_to_null.filter(both_present)

    # stat.corr is an action: it runs a job and returns a plain Python float
    coefficient = complete_rows.stat.corr('rida_per_metus', 'amzius')

    session = set_outliers_to_null.sparkSession
    summary_row = Row(correlation_rida_amzius=coefficient)

    return session.createDataFrame([summary_row])


# 3.8 Correlation analysis of mileage, age for each municipality
def calculate_correlation_with_savivaldybe(set_outliers_to_null):
    """
    Calculates the correlation between 'rida_per_metus' (mileage per year) and 'amzius' (age)
    separately for each municipality ('savivaldybe'). Rows missing either value are ignored.

    NOTE(review): a function with this same name is defined again further down in this
    file (taking 'categorize_amzius' as input). If both live in one module, the later
    definition replaces this one - confirm which is intended.

    :param set_outliers_to_null: PySpark DataFrame with vehicle data.
    :return: DataFrame with columns 'savivaldybe' and 'correlation_rida_per_metus_amzius',
             one row per distinct municipality value (empty when there is no usable data).
    """
    # Imported locally so the explicit result schema does not depend on file-level imports
    from pyspark.sql.types import DoubleType, StructField, StructType

    corr_columns = ['rida_per_metus', 'amzius']
    savivaldybe_col = 'savivaldybe'

    # Keep only rows where both correlated columns are present
    filtered_df = set_outliers_to_null.filter(
        F.col(corr_columns[0]).isNotNull() & F.col(corr_columns[1]).isNotNull()
    )

    unique_savivaldybes = [
        row[savivaldybe_col]
        for row in filtered_df.select(savivaldybe_col).distinct().collect()
    ]

    correlation_results = []
    for savivaldybe_name in unique_savivaldybes:
        # Null-safe equality: a plain '==' against a NULL key evaluates to NULL and would
        # select no rows, so the NULL group would be correlated over an empty set.
        municipality_df = filtered_df.filter(
            F.col(savivaldybe_col).eqNullSafe(F.lit(savivaldybe_name))
        )
        # Each call is a Spark action, i.e. one job per municipality
        correlation_value = municipality_df.stat.corr(corr_columns[0], corr_columns[1])
        correlation_results.append((savivaldybe_name, correlation_value))

    # Explicit schema: inference from data fails on an empty result list
    result_schema = StructType([
        StructField(savivaldybe_col, filtered_df.schema[savivaldybe_col].dataType, True),
        StructField('correlation_rida_per_metus_amzius', DoubleType(), True),
    ])

    return set_outliers_to_null.sparkSession.createDataFrame(correlation_results, schema=result_schema)

#3.6.1
def categorize_savivaldybe(delete_na_values):
    """
    Categorizes 'savivaldybe' (municipality) into three categories: 'Rajono' (rural),
    'Miesto' (city), and 'Didmiesčio' (metropolitan).

    Rules, checked in this order:
      * name contains 'R. SAV.'                          -> 'Rajono'
      * name is 'VILNIAUS M. SAV.' or 'KAUNO M. SAV.'    -> 'Didmiesčio'
      * anything else                                    -> 'Miesto'
    A NULL 'savivaldybe' yields a NULL category.

    :param delete_na_values: PySpark DataFrame with a 'savivaldybe' column.
    :return: DataFrame with an additional column 'savivaldybe_category'.
    """
    municipality = F.col('savivaldybe')

    # Native column expressions instead of a Python UDF: no dependency on a types
    # module alias and no failure when the municipality is NULL.
    category = (
        F.when(municipality.isNull(), F.lit(None).cast('string'))
        .when(municipality.contains('R. SAV.'), 'Rajono')
        .when(municipality.isin('VILNIAUS M. SAV.', 'KAUNO M. SAV.'), 'Didmiesčio')
        .otherwise('Miesto')
    )

    return delete_na_values.withColumn('savivaldybe_category', category)



# 3.6.1.1 Correlation analysis by municipality category
def calculate_correlation_with_savivaldybe_category(categorize_savivaldybe):
    """
    Calculates the correlation between 'rida_per_metus' and 'amzius' (age),
    separately for each municipality category. Rows missing either value are ignored.

    :param categorize_savivaldybe: PySpark DataFrame with categorized municipality data.
    :return: DataFrame with columns 'savivaldybe_category' and
             'correlation_rida_per_metus_amzius', one row per distinct category
             (empty when there is no usable data).
    """
    # Imported locally so the explicit result schema does not depend on file-level imports
    from pyspark.sql.types import DoubleType, StructField, StructType

    corr_columns = ['rida_per_metus', 'amzius']
    savivaldybe_col = 'savivaldybe_category'

    # Keep only rows where both correlated columns are present
    filtered_df = categorize_savivaldybe.filter(
        F.col(corr_columns[0]).isNotNull() & F.col(corr_columns[1]).isNotNull()
    )

    unique_categories = [
        row[savivaldybe_col]
        for row in filtered_df.select(savivaldybe_col).distinct().collect()
    ]

    correlation_results = []
    for category_name in unique_categories:
        # Null-safe equality: a plain '==' against a NULL key evaluates to NULL and would
        # select no rows, so the NULL group would be correlated over an empty set.
        category_df = filtered_df.filter(
            F.col(savivaldybe_col).eqNullSafe(F.lit(category_name))
        )
        # Each call is a Spark action, i.e. one job per category
        correlation_value = category_df.stat.corr(corr_columns[0], corr_columns[1])
        correlation_results.append((category_name, correlation_value))

    # Explicit schema: inference from data fails on an empty result list
    result_schema = StructType([
        StructField(savivaldybe_col, filtered_df.schema[savivaldybe_col].dataType, True),
        StructField('correlation_rida_per_metus_amzius', DoubleType(), True),
    ])

    return categorize_savivaldybe.sparkSession.createDataFrame(correlation_results, schema=result_schema)

# 3.6.1.2 Categorize age
def categorize_amzius(categorize_savivaldybe):
    """
    Categorize the 'amzius' column into age groups, add a numerical code for each
    group and a short explanation of the group.

    Bands are lower-bound inclusive and upper-bound exclusive; the last band is
    open-ended and starts at 70. An age that is NULL or below 18 matches no band,
    so all three derived columns are NULL for it (it is no longer lumped into the
    oldest group).

    :param categorize_savivaldybe: Input PySpark DataFrame with an 'amzius' column.
    :return: Updated PySpark DataFrame with 'amzius_categories', 'amzius_numerical',
             and 'amzius_categories_paaiskinimas' columns.
    """
    age = F.col('amzius')

    # (lower bound, upper bound or None for open-ended, label, numeric code, explanation)
    bands = [
        (18, 25, '18 to 25', 1, 'Young Adults'),
        (25, 30, '25 to 30', 2, 'Adults starting career'),
        (30, 40, '30 to 40', 3, 'Established Adults'),
        (40, 50, '40 to 50', 4, 'Middle-aged Adults'),
        (50, 60, '50 to 60', 5, 'Pre-retirement Adults'),
        (60, 70, '60 to 70', 6, 'Early Retirement'),
        (70, None, 'more than 70', 7, 'Seniors'),
    ]

    def band_case(value_index):
        # Chain one WHEN per band; without an otherwise() unmatched ages stay NULL
        case_expr = None
        for band in bands:
            lower, upper = band[0], band[1]
            in_band = (age >= lower) if upper is None else ((age >= lower) & (age < upper))
            value = band[value_index]
            case_expr = F.when(in_band, value) if case_expr is None else case_expr.when(in_band, value)
        return case_expr

    return (
        categorize_savivaldybe
        .withColumn('amzius_categories', band_case(2))
        .withColumn('amzius_numerical', band_case(3))
        .withColumn('amzius_categories_paaiskinimas', band_case(4))
    )

# 3.6.1.3 Correlation analysis of mileage and age
def calculate_correlation_with_savivaldybe(categorize_amzius):
    """
    Calculate correlation between 'rida_per_metus' and 'amzius' for each municipality.
    Rows missing either value are ignored.

    NOTE(review): a function with this same name is also defined earlier in this file
    (taking 'set_outliers_to_null' as input). If both live in one module, this later
    definition replaces the earlier one - confirm which is intended.

    :param categorize_amzius: PySpark DataFrame containing 'rida_per_metus', 'amzius', and 'savivaldybe'.
    :return: DataFrame with columns 'savivaldybe' and 'correlation_rida_per_metus_amzius',
             one row per distinct municipality value (empty when there is no usable data).
    """
    # Imported locally so the explicit result schema does not depend on file-level imports
    from pyspark.sql.types import DoubleType, StructField, StructType

    corr_columns = ['rida_per_metus', 'amzius']
    savivaldybe_col = 'savivaldybe'

    # Keep only rows where both correlated columns are present
    filtered_df = categorize_amzius.filter(
        F.col(corr_columns[0]).isNotNull() & F.col(corr_columns[1]).isNotNull()
    )

    unique_savivaldybes = [
        row[savivaldybe_col]
        for row in filtered_df.select(savivaldybe_col).distinct().collect()
    ]

    correlation_results = []
    for savivaldybe_name in unique_savivaldybes:
        # Null-safe equality: a plain '==' against a NULL key evaluates to NULL and would
        # select no rows, so the NULL group would be correlated over an empty set.
        municipality_df = filtered_df.filter(
            F.col(savivaldybe_col).eqNullSafe(F.lit(savivaldybe_name))
        )
        # Each call is a Spark action, i.e. one job per municipality
        correlation_value = municipality_df.stat.corr(corr_columns[0], corr_columns[1])
        correlation_results.append((savivaldybe_name, correlation_value))

    # Explicit schema: inference from data fails on an empty result list
    result_schema = StructType([
        StructField(savivaldybe_col, filtered_df.schema[savivaldybe_col].dataType, True),
        StructField('correlation_rida_per_metus_amzius', DoubleType(), True),
    ])

    return categorize_amzius.sparkSession.createDataFrame(correlation_results, schema=result_schema)



# 3.6.1.4 Categorizing ages into groups
def categorize_age(categorize_savivaldybe):
    """
    Categorize 'amzius' into age groups and add numerical categories.

    Bands are lower-bound inclusive and upper-bound exclusive; the last band is
    open-ended and starts at 70. An age that is NULL or below 18 matches no band,
    so both derived columns are NULL for it (it is no longer lumped into the
    oldest group).

    :param categorize_savivaldybe: PySpark DataFrame containing 'amzius'.
    :return: Updated DataFrame with 'amzius_categories' and 'amzius_numerical'.
    """
    age = F.col('amzius')

    # (lower bound, upper bound or None for open-ended, label, numeric code)
    bands = [
        (18, 25, '18 to 25', 1),
        (25, 30, '25 to 30', 2),
        (30, 40, '30 to 40', 3),
        (40, 50, '40 to 50', 4),
        (50, 60, '50 to 60', 5),
        (60, 70, '60 to 70', 6),
        (70, None, '70 and above', 7),
    ]

    def band_case(value_index):
        # Chain one WHEN per band; without an otherwise() unmatched ages stay NULL
        case_expr = None
        for band in bands:
            lower, upper = band[0], band[1]
            in_band = (age >= lower) if upper is None else ((age >= lower) & (age < upper))
            value = band[value_index]
            case_expr = F.when(in_band, value) if case_expr is None else case_expr.when(in_band, value)
        return case_expr

    return (
        categorize_savivaldybe
        .withColumn('amzius_categories', band_case(2))
        .withColumn('amzius_numerical', band_case(3))
    )


# Clustering method:
# 4.1. K-means clustering on mileage grouped by vehicle type and municipality category, to get silhouette scores for each k:
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.sql import functions as F
from pyspark.ml import Pipeline
from pyspark.sql.window import Window
from pyspark.sql import SparkSession

def train_kmeans_on_grouped_ridas(sample_dataset_final, k_min=2, k_max=10):
    """
    Fit K-means for every k in [k_min, k_max] on per-group averages and report the
    Silhouette score of each fit.

    The input is reduced to one row per ('transporto_priemones_tipas',
    'savivaldybe_category') pair holding the mean 'rida_per_metus' and mean 'amzius';
    those two averages are standardised and used as clustering features. The best
    score and its k are printed.

    :param sample_dataset_final: PySpark DataFrame with columns 'rida_per_metus', 'amzius',
                                  'transporto_priemones_tipas' and 'savivaldybe_category'.
    :param k_min: Minimum number of clusters (default 2).
    :param k_max: Maximum number of clusters (default 10).
    :return: PySpark DataFrame with columns 'k' and 'silhouette_score', one row per tested k.
    :raises ValueError: if a required column is missing or if k_max is smaller than k_min.
    """
    # Fail early with a clear message instead of a schema-inference error on an empty result
    if k_max < k_min:
        raise ValueError(f"k_max ({k_max}) must not be smaller than k_min ({k_min}).")

    # Check if required columns are present
    required_columns = ['rida_per_metus', 'amzius', 'transporto_priemones_tipas', 'savivaldybe_category']
    for column_name in required_columns:
        if column_name not in sample_dataset_final.columns:
            raise ValueError(f"Column '{column_name}' is missing from the dataset.")

    # Filter rows with missing values in required columns
    complete_rows = sample_dataset_final.filter(
        (F.col('rida_per_metus').isNotNull()) &
        (F.col('amzius').isNotNull()) &
        (F.col('transporto_priemones_tipas').isNotNull()) &
        (F.col('savivaldybe_category').isNotNull())
    )

    # One row per vehicle type / municipality category with the clustering features
    grouped_data = complete_rows.groupBy('transporto_priemones_tipas', 'savivaldybe_category')\
        .agg(F.mean('rida_per_metus').alias('avg_rida_per_metus'),
             F.mean('amzius').alias('avg_amzius'))

    assembler = VectorAssembler(inputCols=['avg_rida_per_metus', 'avg_amzius'], outputCol='features', handleInvalid='skip')
    assembled = assembler.transform(grouped_data)

    # Scaling does not depend on k, so fit and apply it once instead of on every iteration
    scaler_model = StandardScaler(inputCol='features', outputCol='scaled_features').fit(assembled)
    scaled = scaler_model.transform(assembled)

    evaluator = ClusteringEvaluator(predictionCol='cluster', featuresCol='scaled_features')

    silhouette_scores = []
    best_silhouette = float('-inf')
    best_k = None

    for k in range(k_min, k_max + 1):
        kmeans = KMeans(k=k, seed=1, featuresCol='scaled_features', predictionCol='cluster')
        predictions = kmeans.fit(scaled).transform(scaled)

        silhouette = evaluator.evaluate(predictions)
        silhouette_scores.append((k, silhouette))

        # Strict '>' keeps the smallest k on ties
        if silhouette > best_silhouette:
            best_silhouette = silhouette
            best_k = k

    silhouette_df = sample_dataset_final.sparkSession.createDataFrame(
        silhouette_scores, schema=["k", "silhouette_score"]
    )

    print(f"Best Silhouette Score: {best_silhouette} with {best_k} clusters.")
    return silhouette_df


from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler, MinMaxScaler
from pyspark.ml.regression import RandomForestRegressor
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

# Reuse the active Spark session if one already exists; otherwise create one
spark = SparkSession.builder.getOrCreate()

# 5.1. Calculate fuel consumption
def calculate_fuel_consumption(sample_dataset_final_with_predictions):
    """
    Derive total fuel consumption per record from annual mileage and the predicted consumption per 100 km.

    The model output column 'prediction' is renamed to 'pred_kuro_sunaudojimas_l100km', and a new
    column 'kuro_suvartojimas_lt' is added as ('rida_per_metus' / 100) * 'pred_kuro_sunaudojimas_l100km'.

    :param sample_dataset_final_with_predictions: PySpark DataFrame containing the 'prediction'
        and 'rida_per_metus' columns.
    :return: PySpark DataFrame with 'prediction' renamed to 'pred_kuro_sunaudojimas_l100km'
             and the added 'kuro_suvartojimas_lt' column.
    """
    per_100km_col = 'pred_kuro_sunaudojimas_l100km'

    # Give the raw model output a descriptive name before using it
    with_named_prediction = sample_dataset_final_with_predictions.withColumnRenamed(
        'prediction', per_100km_col
    )

    # Mileage is divided by 100 because the prediction is expressed per 100 km
    hundreds_of_km = F.col('rida_per_metus') / 100
    return with_named_prediction.withColumn(
        'kuro_suvartojimas_lt',
        hundreds_of_km * F.col(per_100km_col)
    )

# 5.1.1 Total fuel consumption on the sample that has all the required information
def Results_sample(calculate_fuel_consumption):
    """
    Aggregate total fuel consumption by vehicle type, vehicle description, and fuel type,
    sorted by the numeric part of the vehicle type and then by fuel type.

    :param calculate_fuel_consumption: PySpark DataFrame with calculated fuel consumption
        ('kuro_suvartojimas_lt') and the columns 'transporto_priemones_tipas',
        'transporto_priemones_paaiskinimas' and 'degalai'. Note that the parameter is a
        DataFrame, even though it shares its name with the calculate_fuel_consumption function.
    :return: Aggregated PySpark DataFrame with total fuel consumption (rounded to zero decimals,
             column 'total_kuro_suvartojimas_lt') per 'transporto_priemones_tipas',
             'transporto_priemones_paaiskinimas', and 'degalai'.
    """
    sort_key = 'transporto_priemones_tipas_skaitmenine'

    # Sum fuel consumption per vehicle type, description, and fuel type, rounded to whole units
    agreguoti_rezultatai = calculate_fuel_consumption.groupBy(
        'transporto_priemones_tipas', 'transporto_priemones_paaiskinimas', 'degalai'
    ).agg(
        F.round(F.sum('kuro_suvartojimas_lt'), 0).alias('total_kuro_suvartojimas_lt')
    )

    # Sort on the first run of digits in the type code cast to an integer, so that a code
    # such as 'K10' would come after 'K2' (a plain string sort would put it before).
    # The key depends only on the grouping column, so it is computed on the aggregate directly.
    agreguoti_rezultatai = agreguoti_rezultatai.withColumn(
        sort_key,
        F.regexp_extract(F.col('transporto_priemones_tipas'), r'\d+', 0).cast(T.IntegerType())
    ).orderBy(sort_key, 'degalai')

    # The helper column is only needed for ordering
    return agreguoti_rezultatai.drop(sort_key)

# 5.1.2 Results on the sample with additional information
from pyspark.sql import functions as F
from pyspark.sql import types as T

def results_with_unit_avg_mileage(calculate_fuel_consumption):
    """
    Aggregate total fuel consumption by vehicle type, description, and fuel type, including additional metrics:
    number of vehicles, average predicted fuel consumption per 100 km, and average mileage per year.

    :param calculate_fuel_consumption: PySpark DataFrame with calculated fuel consumption and predictions
        (columns 'kuro_suvartojimas_lt', 'pred_kuro_sunaudojimas_l100km', 'rida_per_metus',
        'transporto_priemones_tipas', 'transporto_priemones_paaiskinimas', 'degalai').
        Note that the parameter is a DataFrame, even though it shares its name with the
        calculate_fuel_consumption function.
    :return: Aggregated PySpark DataFrame with total fuel consumption ('total_kuro_suvartojimas_lt'),
             vehicle count ('transporto_priemoniu_kiekis'), average fuel consumption
             ('vid_kuro_sunaudojimas_100km'), and average mileage ('vid_rida_per_metus'),
             sorted by the numeric part of the vehicle type and then by fuel type.
    """
    sort_key = 'transporto_priemones_tipas_skaitmenine'

    # Aggregate fuel consumption, vehicle count, and averages; rounding is applied
    # to the aggregated values for readability
    agreguoti_rezultatai = calculate_fuel_consumption.groupBy(
        'transporto_priemones_tipas', 'transporto_priemones_paaiskinimas', 'degalai'
    ).agg(
        F.round(F.sum('kuro_suvartojimas_lt'), 0).alias('total_kuro_suvartojimas_lt'),
        # Counts rows whose vehicle type is not NULL
        F.count('transporto_priemones_tipas').alias('transporto_priemoniu_kiekis'),
        F.round(F.avg('pred_kuro_sunaudojimas_l100km'), 2).alias('vid_kuro_sunaudojimas_100km'),
        F.round(F.avg('rida_per_metus'), 0).alias('vid_rida_per_metus')
    )

    # Sort on the first run of digits in the type code cast to an integer, so that a code
    # such as 'K10' would come after 'K2' (a plain string sort would put it before).
    # The key depends only on the grouping column, so it is computed on the aggregate directly.
    agreguoti_rezultatai = agreguoti_rezultatai.withColumn(
        sort_key,
        F.regexp_extract(F.col('transporto_priemones_tipas'), r'\d+', 0).cast(T.IntegerType())
    ).orderBy(sort_key, 'degalai')

    # The helper column is only needed for ordering
    return agreguoti_rezultatai.drop(sort_key)

def proportions_on_fuel(regitra_transeksta_final):
    """
    Compute, for the whole dataset, the share of records per vehicle type / description / fuel type,
    plus one aggregate 'K0' row per fuel type.

    Two percentages are produced for every row: 'proportion_total' (relative to all records with a
    non-NULL vehicle type) and 'proportion_with_rida' (relative to those records that also have a
    non-NULL 'rida_per_metus'). Both are rounded to four decimal places.

    :param regitra_transeksta_final: PySpark DataFrame containing vehicle data.
    :return: PySpark DataFrame with counts and proportions for each fuel type and vehicle type,
             sorted by 'transporto_priemones_tipas' and 'degalai'.
    """
    has_mileage = F.col('rida_per_metus').isNotNull()

    # Records without a vehicle type are excluded from every count below
    typed_vehicles = regitra_transeksta_final.filter(
        F.col('transporto_priemones_tipas').isNotNull()
    )

    # Denominators: all remaining records, and those with a known annual mileage
    n_all = typed_vehicles.count()
    n_with_mileage = typed_vehicles.filter(has_mileage).count()

    # The same pair of counts is used for the detailed rows and the per-fuel totals
    count_exprs = [
        F.count('*').alias('total_count'),
        F.count(F.when(has_mileage, 1)).alias('count_with_rida'),
    ]

    def add_percentages(df):
        # Turn the two count columns into percentages of the dataset-wide denominators
        return (
            df
            .withColumn(
                'proportion_total',
                F.round((F.col('total_count') / n_all) * 100, 4)
            )
            .withColumn(
                'proportion_with_rida',
                F.round((F.col('count_with_rida') / n_with_mileage) * 100, 4)
            )
        )

    # One row per vehicle type, vehicle description, and fuel type
    per_type = add_percentages(
        typed_vehicles
        .groupBy('transporto_priemones_tipas', 'transporto_priemones_paaiskinimas', 'degalai')
        .agg(*count_exprs)
    )

    # One aggregate row per fuel type, labelled with the synthetic type 'K0'
    per_fuel_totals = add_percentages(
        typed_vehicles
        .groupBy('degalai')
        .agg(*count_exprs)
        .withColumn('transporto_priemones_tipas', F.lit('K0'))
        .withColumn('transporto_priemones_paaiskinimas', F.lit('Iš viso'))
    )

    # Columns are matched by name, so the differing column order of the two frames does not matter
    combined = per_type.unionByName(per_fuel_totals)
    return combined.orderBy('transporto_priemones_tipas', 'degalai')

# Final fuel consumption for the Lithuanian fleet
from pyspark.sql import functions as F

def calculate_final_fuel_consumption(results_with_unit_avg_mileage, proportions_on_fuel):
    """
    Adjusts fuel consumption sample proportions to match population proportions and calculates
    the total fuel consumption for the population.

    :param results_with_unit_avg_mileage: PySpark DataFrame with sample data including fuel consumption
        ('total_kuro_suvartojimas_lt') and vehicle counts ('transporto_priemoniu_kiekis') per
        'transporto_priemones_tipas' and 'degalai'. Note that the parameter is a DataFrame, even
        though it shares its name with the results_with_unit_avg_mileage function.
    :param proportions_on_fuel: PySpark DataFrame with population proportions for vehicle and fuel types
        (columns 'transporto_priemones_tipas', 'degalai', 'total_count', 'proportion_total', with
        'proportion_total' expressed as a percentage). Likewise a DataFrame, not the function of the same name.
    :return: PySpark DataFrame with adjusted fuel consumption for the population, with columns
             'transporto_priemones_tipas', 'degalai', 'koreguotas_transporto_priemoniu_kiekis',
             'koreguotas_kuro_suvartojimas', 'padidinimo_koeficientas' and
             'galutinis_kuro_suvartojimas_populiacijai'.
    """

    # Select relevant columns, exclude the 'K0' aggregate rows,
    # and convert population proportions from percent to a range of 0-1
    population = proportions_on_fuel.select(
        'transporto_priemones_tipas', 'degalai', 'total_count', 'proportion_total'
    ).filter(
        F.col('transporto_priemones_tipas') != 'K0'
    ).withColumn(
        'proportion_total', F.col('proportion_total') / 100
    )

    # Calculate the total number of vehicles in the sample (collected to the driver as a scalar)
    imtis_sum = results_with_unit_avg_mileage.agg(
        F.sum('transporto_priemoniu_kiekis').alias('total_imties_kiekis')
    ).collect()[0][0]

    # Calculate sample proportions
    imties_proporcijos = results_with_unit_avg_mileage.withColumn(
        'imties_proporcija',
        F.col('transporto_priemoniu_kiekis') / imtis_sum
    )

    # Join sample proportions with population proportions; the right join keeps every
    # population row, including those with no matching sample group
    sujungti_duomenys = imties_proporcijos.join(
        population,
        on=['transporto_priemones_tipas', 'degalai'],
        how='right'
    )

    # Fill NULL values (e.g. from unmatched population rows) with defaults
    sujungti_duomenys = sujungti_duomenys.fillna({
        'transporto_priemoniu_kiekis': 0,
        'total_kuro_suvartojimas_lt': 0,
        'imties_proporcija': 0
    })

    # Keep only rows backed by at least one sample vehicle; this also guarantees that
    # 'imties_proporcija' is non-zero where it is used as a divisor below
    sujungti_duomenys = sujungti_duomenys.filter(F.col('transporto_priemoniu_kiekis') > 0)

    # Adjust vehicle counts to match population proportions
    sujungti_duomenys = sujungti_duomenys.withColumn(
        'koreguotas_transporto_priemoniu_kiekis',
        F.col('proportion_total') * imtis_sum
    )

    # Adjust fuel consumption to align with new proportions
    sujungti_duomenys = sujungti_duomenys.withColumn(
        'koreguotas_kuro_suvartojimas',
        F.col('total_kuro_suvartojimas_lt') * (F.col('proportion_total') / F.col('imties_proporcija'))
    )

    # Calculate scaling factor to match population totals
    sujungti_duomenys = sujungti_duomenys.withColumn(
        'padidinimo_koeficientas',
        F.col('total_count') / F.col('koreguotas_transporto_priemoniu_kiekis')
    )

    # Adjust fuel consumption by scaling factor
    sujungti_duomenys = sujungti_duomenys.withColumn(
        'galutinis_kuro_suvartojimas_populiacijai',
        F.col('koreguotas_kuro_suvartojimas') * F.col('padidinimo_koeficientas')
    )

    # Select final columns for the output
    galutiniai_rezultatai = sujungti_duomenys.select(
        'transporto_priemones_tipas',
        'degalai',
        'koreguotas_transporto_priemoniu_kiekis',
        'koreguotas_kuro_suvartojimas',
        'padidinimo_koeficientas',
        'galutinis_kuro_suvartojimas_populiacijai'
    )

    return galutiniai_rezultatai
