In [None]:
from preprocessing.preprocess import Preprocessor, apply_preprocessing_step, apply_correlation_pruning
from raw_data.dataset import Dataset, exclude_indicators_from_dataset, split_dataset
import numpy as np

In [None]:
# --- Input / output locations ------------------------------------------------
# NOTE(review): the pieces are concatenated without separators, so the raw path
# currently evaluates to 'data/datasetPLACEHOLDER_rawh5' (no '.' before 'h5').
# Confirm the intended file name when replacing the placeholders.
raw_dataset_path = 'data/dataset' + 'PLACEHOLDER_raw' + 'h5' # replaceme
output_name = 'data/dataset' + 'PLACEHOLDER' # replaceme

# --- Test-set completeness ---------------------------------------------------
# If True, ensures that the test set is complete and does not contain any
# missing values. When ensuring test set completeness, the countries meant to
# be used in the test set are printed out.
ensure_test_set_complete = False
# Fraction of the countries taken as test-set candidates; only read when
# ensure_test_set_complete is enabled.
test_set_ratio = 0.2

# --- Preprocessing parameters (set to None to skip the step) -----------------
# Removes indicators if any country has more than this percentage of missing
# values.
per_country_missingness_threshold = 0.5
# Removes indicators if the global variance is below this threshold.
global_variance_threshold = 0.05

# Joined parameters: the two values below are consumed together by one step.
# Threshold for variance of indicators per country.
country_variance_threshold = 0.05
# Percentage of countries that must have variance above the threshold.
country_variance_percentage = 0.5

# Threshold for Pearson correlation between indicators.
pearson_correlation_threshold = 0.8

In [None]:
# Read the raw dataset selected in the parameter cell and report its size.
raw_dataset = Dataset.load(raw_dataset_path)
original_indicator_count = raw_dataset.n_indicators()
print(f"Raw dataset contains {original_indicator_count} indicators.")
raw_country_total, raw_year_total = raw_dataset.n_countries(), raw_dataset.n_years()
print(f"Raw dataset contains {raw_country_total} countries and {raw_year_total} years.")

# Start from a copy of the raw data (intended as a deep copy — confirm against
# Dataset.copy) so that every later step is optional: whichever steps run, they
# all read and rebind `preprocessed_dataset`.
preprocessed_dataset: Dataset = raw_dataset.copy()
# Running indicator count handed from one preprocessing step to the next.
prev_indicator_count = original_indicator_count

# Threshold-based indicator filters.
#
# Every step is optional and is skipped only when its parameter is None, as the
# parameter cell documents. The guards compare against None explicitly so that
# a deliberate threshold of 0 / 0.0 still runs the step instead of being
# treated as "skip" by a truthiness test.
#
# Each step rebinds `preprocessed_dataset` and `prev_indicator_count` to the
# pair returned by apply_preprocessing_step, so the steps chain in order.

if per_country_missingness_threshold is not None:
    # Remove indicators with missing values above a threshold
    preprocessed_dataset, prev_indicator_count = apply_preprocessing_step(
        dataset=preprocessed_dataset,
        func=Preprocessor.remove_missing_per_country,
        func_kwargs={"threshold": per_country_missingness_threshold},
        description_template="Removed indicators for missing above a per country threshold of {threshold}.",
        prev_count=prev_indicator_count
    )

if global_variance_threshold is not None:
    # Remove indicators whose global variance is below a threshold
    preprocessed_dataset, prev_indicator_count = apply_preprocessing_step(
        dataset=preprocessed_dataset,
        func=Preprocessor.remove_variance,
        func_kwargs={"threshold": global_variance_threshold},
        description_template="Removed indicators for global variance below {threshold}.",
        prev_count=prev_indicator_count
    )

# This step needs both of its "joined" parameters; if either is None it is skipped.
if country_variance_threshold is not None and country_variance_percentage is not None:
    # Remove indicators with too many constant countries
    preprocessed_dataset, prev_indicator_count = apply_preprocessing_step(
        dataset=preprocessed_dataset,
        func=Preprocessor.remove_constant,
        func_kwargs={"variance_threshold": country_variance_threshold, "country_threshold": country_variance_percentage},
        description_template="Removed indicators with too many constant countries (variance < {variance_threshold} and min active countries < {country_threshold}).",
        prev_count=prev_indicator_count
    )

# Correlation-based pruning. Skipped only when the threshold is None; the
# explicit None check keeps a threshold of 0 / 0.0 from being silently ignored.
if pearson_correlation_threshold is not None:
    # The dataset returned by the pruning helper is used only to find out which
    # indicators were dropped. Those indicators are then excluded from
    # `preprocessed_dataset` itself, so the data carried forward is never
    # replaced by the pruning output (the original author's stated goal is to
    # avoid leakage).
    print("Removing indicators based on correlation...")
    pruned_dataset = apply_correlation_pruning(preprocessed_dataset, threshold=pearson_correlation_threshold)

    # Indicators present before pruning but absent afterwards.
    removed_indicators = set(preprocessed_dataset.indicators) - set(pruned_dataset.indicators)

    print(f"Removed {len(removed_indicators)} indicators based on correlation pruning. Remaining indicators: {pruned_dataset.n_indicators()}")

    preprocessed_dataset = exclude_indicators_from_dataset(preprocessed_dataset, list(removed_indicators))

if ensure_test_set_complete:
    # Rank all countries by their share of missing values; the least-missing
    # ones (a `test_set_ratio` fraction of all countries) are the test-set
    # candidates. Every indicator that has a gap in any of them is then dropped.
    #
    # Assumes extract_country returns an (indicators, years)-shaped array, as
    # implied by the row indices being looked up in
    # `preprocessed_dataset.indicators` below — TODO confirm.

    # country -> (total missing values ratio, number of indicators with missing values)
    country_missing_values = {}
    for country_name in preprocessed_dataset.countries:
        nan_mask = np.isnan(preprocessed_dataset.extract_country(country_name))
        missing_ratio = np.sum(nan_mask) / nan_mask.size
        incomplete_indicator_count = np.sum(nan_mask.any(axis=1))
        country_missing_values[country_name] = (missing_ratio, incomplete_indicator_count)

    n_countries_to_take = int(preprocessed_dataset.n_countries() * test_set_ratio)
    # Ascending by missing ratio (first element of the stored tuple).
    sorted_countries = sorted(country_missing_values.items(), key=lambda entry: entry[1][0])
    least_missing_countries = [entry[0] for entry in sorted_countries[:n_countries_to_take]]
    print(f"Countries with least missing values: {least_missing_countries}")

    # Union of the indicators that have at least one NaN in any selected country.
    indicators_with_missing_values = set()
    for country_name in least_missing_countries:
        row_has_nan = np.isnan(preprocessed_dataset.extract_country(country_name)).any(axis=1)
        incomplete_rows = np.where(row_has_nan)[0]
        indicators_with_missing_values.update(preprocessed_dataset.indicators[incomplete_rows])

    print(f'Excluding {len(indicators_with_missing_values)} indicators with missing values in the least missing countries: {indicators_with_missing_values}')
    preprocessed_dataset = exclude_indicators_from_dataset(preprocessed_dataset, list(indicators_with_missing_values))
    print(f"Remaining indicators after excluding those with missing values in the least missing countries: {preprocessed_dataset.n_indicators()}")

# Persist the final preprocessed dataset under `output_name`, requesting the
# 'hdf5' format from Dataset.save.
preprocessed_dataset.save(output_name, format='hdf5')