In [1]:
import ipywidgets as widgets
from IPython.display import display, clear_output
import numpy as np
import pandas as pd
import pickle
from imblearn.under_sampling import ClusterCentroids, NearMiss, TomekLinks, OneSidedSelection
from imblearn.over_sampling import SMOTE, ADASYN
import warnings
from sklearn.utils import resample

# Silence every warning category for the rest of the session (sklearn/imblearn
# emit many); note this also hides warnings that may indicate real problems.
warnings.filterwarnings(action='ignore')

# Seed shared by all samplers below so repeated runs draw identical samples.
RANDOM_STATE = 42

# Holder for the most recently saved resampled DataFrame. It is reset by
# run_sampling_ui(), filled by its "Save" buttons and read via get_sampled_data().
new_df = None

# Define custom samplers outside of the function
class StratifiedSampler:
    """Stratified resampling of ``(X, y)`` keyed on one or more feature columns.

    Thin wrapper around ``sklearn.utils.resample`` that exposes an
    imblearn-style ``fit_resample`` method, so it can be used (and pickled)
    interchangeably with the imblearn samplers in this notebook.

    Args:
        sampling_strategy: Fraction of ``len(y)`` to draw; the output has
            ``int(sampling_strategy * len(y))`` rows.
        strata_features: Iterable of column names of ``X`` whose (joint)
            values define the strata.
        random_state: Seed forwarded to ``resample``.
    """

    def __init__(self, sampling_strategy, strata_features, random_state=RANDOM_STATE):
        self.sampling_strategy = sampling_strategy
        # Widget selections arrive as tuples; pandas column selection needs a list.
        self.strata_features = list(strata_features)
        self.random_state = random_state

    def fit_resample(self, X, y):
        """Return ``(X_res, y_res)`` sampled so the strata proportions are preserved.

        Args:
            X: pandas DataFrame containing every column in ``strata_features``.
            y: Labels aligned row-for-row with ``X``.

        Raises:
            KeyError: if a strata feature is not a column of ``X``.
            ValueError: propagated from ``resample`` (e.g. a stratum too small
                for the requested sample size).
        """
        strata = X[self.strata_features]
        # One feature -> a 1-D label vector. Several features -> pass the 2-D
        # array unchanged so resample() builds one joint label per row.
        # Flattening it would yield n_rows * n_features labels, which does not
        # match len(X) and makes resample() raise.
        if len(self.strata_features) > 1:
            stratify_labels = strata.values
        else:
            stratify_labels = strata.values.flatten()
        return resample(
            X, y,
            stratify=stratify_labels,
            n_samples=int(self.sampling_strategy * len(y)),
            random_state=self.random_state
        )

    def __repr__(self):
        return f"StratifiedSampler(sampling_strategy={self.sampling_strategy}, strata_features={self.strata_features}, random_state={self.random_state})"


class RandomSampler:
    """Unstratified resampling of ``(X, y)`` via ``sklearn.utils.resample``.

    Exposes an imblearn-style ``fit_resample`` method so it can sit alongside
    the imblearn samplers and be pickled the same way.

    Args:
        sampling_strategy: Fraction of ``len(y)`` to draw.
        random_state: Seed forwarded to ``resample``.
    """

    def __init__(self, sampling_strategy, random_state=RANDOM_STATE):
        self.sampling_strategy, self.random_state = sampling_strategy, random_state

    def fit_resample(self, X, y):
        """Return ``(X_res, y_res)`` with ``int(sampling_strategy * len(y))`` rows.

        All ``resample`` options other than ``n_samples`` and
        ``random_state`` are left at their defaults.
        """
        target_size = int(self.sampling_strategy * len(y))
        return resample(X, y, n_samples=target_size, random_state=self.random_state)

    def __repr__(self):
        return (
            f"RandomSampler(sampling_strategy={self.sampling_strategy}, "
            f"random_state={self.random_state})"
        )

# Getter for the global `new_df`
def get_sampled_data():
    """Return the DataFrame currently held in the module-level ``new_df``.

    Raises:
        ValueError: if nothing has been saved yet (``new_df`` is ``None``).
    """
    # Only reading the module global, so no ``global`` statement is required.
    if new_df is not None:
        return new_df
    raise ValueError("No sampled dataset is available. Please run sampling and save the results first.")
    

# Function to run sampling UI
def run_sampling_ui(X, y):
    """Build and display an interactive ipywidgets UI for resampling ``(X, y)``.

    Four method families are offered: imblearn under-samplers (Cluster
    Centroids, Near Miss, Tomek Links, One-Sided Selection), imblearn
    over-samplers (SMOTE, ADASYN), and the ``StratifiedSampler`` /
    ``RandomSampler`` wrappers. Each successful run appends a row to the
    module-level ``result_log`` table (sizes and class distributions before and
    after resampling) and gets a "Save" button. Clicking Save re-runs that exact
    configuration and stores the result in the module-level ``new_df``
    (features plus a ``target`` column). It also pickles the sampler to
    ``sampler_<method>_<pct>.pkl`` in the current working directory.

    Args:
        X: Feature matrix. Must be a pandas DataFrame for stratified sampling,
            since its column names populate the strata selector.
        y: Class labels aligned row-for-row with ``X``.

    Returns:
        ``new_df`` as of the end of this call. This is always ``None``: it is
        reset here and only filled in later by a Save-button click. Call
        ``get_sampled_data()`` after saving to obtain the DataFrame.
    """
    global new_df, result_log, current_sampling, selected_features
    new_df = None
    selected_features = []
    result_log = pd.DataFrame(columns=["Method", "Sampling %", "Original Size", "Resampled Size",
                                       "Original Dist (Abs)", "Original Dist (%)",
                                       "Resampled Dist (Abs)", "Resampled Dist (%)"])
    # (method_name, sampling_percentage, strata_features) for each result_log
    # row, kept unformatted so a Save button reproduces exactly that run.
    run_settings = []

    # Name of the method family whose sub-options are currently displayed.
    current_sampling = None

    dataset_size_label = widgets.Label(f"Dataset Size: {X.shape[0]} samples, Number of Features: {X.shape[1]}")

    under_sampling_button = widgets.Button(description="Under Sampling", tooltip="Reduce the majority class samples.")
    over_sampling_button = widgets.Button(description="Over Sampling", tooltip="Increase the minority class samples.")
    stratified_sampling_button = widgets.Button(description="Stratified Sampling", tooltip="Sample based on strata features.")
    random_sampling_button = widgets.Button(description="Random Sampling", tooltip="Randomly sample the dataset.")

    # NOTE(review): the slider value is passed straight through as
    # ``sampling_strategy``. For the imblearn samplers a float there is
    # presumably a class ratio rather than a fraction of rows; confirm the
    # label matches the intended semantics.
    sampling_percentage_slider = widgets.FloatSlider(description='Sampling %', min=0.1, max=1.0, step=0.1, value=1.0)

    sub_buttons_box = widgets.VBox([])  # method-specific sub-option widgets
    output_box = widgets.Output()  # status and error messages
    table_output_box = widgets.Output()  # result_log table plus Save buttons

    def show_message(*lines):
        """Replace the contents of ``output_box`` with *lines*, one per print."""
        with output_box:
            clear_output(wait=True)
            for line in lines:
                print(line)

    def get_sampler(method_name, sampling_percentage, selected_features=None):
        """Return a fresh sampler for *method_name*, or ``None`` if unavailable.

        ``sampling_percentage`` is forwarded as ``sampling_strategy`` where the
        sampler takes one. Tomek Links and One-Sided Selection ignore it.
        Stratified sampling also needs a non-empty ``selected_features``.
        """
        if method_name == 'Cluster Centroids':
            return ClusterCentroids(sampling_strategy=sampling_percentage, random_state=RANDOM_STATE)
        elif method_name == 'Near Miss':
            return NearMiss(sampling_strategy=sampling_percentage)
        elif method_name == 'Tomek Links':
            return TomekLinks()
        elif method_name == 'One-Sided Selection':
            return OneSidedSelection(random_state=RANDOM_STATE)
        elif method_name == 'SMOTE':
            return SMOTE(sampling_strategy=sampling_percentage, random_state=RANDOM_STATE)
        elif method_name == 'ADASYN':
            return ADASYN(sampling_strategy=sampling_percentage, random_state=RANDOM_STATE)
        elif method_name == 'Stratified Sampling' and selected_features:
            return StratifiedSampler(sampling_strategy=sampling_percentage, strata_features=selected_features)
        elif method_name == 'Random Sampling':
            return RandomSampler(sampling_strategy=sampling_percentage)
        return None

    def save_sampler(method_name, sampling_percentage, strata_features=None):
        """Re-run *method_name*, publish the result as ``new_df`` and pickle the sampler.

        ``strata_features`` are the features recorded for the logged row. When
        absent, stratified sampling falls back to the current selection.
        """
        global new_df
        if method_name == "Stratified Sampling":
            features = strata_features if strata_features else selected_features
            sampler = get_sampler(method_name, sampling_percentage, selected_features=features)
        else:
            sampler = get_sampler(method_name, sampling_percentage)

        if sampler is None:
            show_message(f"Error: Sampler for {method_name} not found.")
            return

        try:
            X_res, y_res = sampler.fit_resample(X, y)
        except (KeyError, ValueError) as e:
            show_message(f"Error while saving {method_name} sampler: {e}",
                         "Please adjust the settings and try again.")
            return

        # Join positionally. The index labels of X_res and y_res need not
        # match: resample() keeps the original row labels on pandas input,
        # while an ndarray wrapped here gets a fresh RangeIndex. A label-aligned
        # concat would then misplace rows or introduce NaNs.
        features_df = pd.DataFrame(X_res).reset_index(drop=True)
        target = pd.Series(y_res, name='target').reset_index(drop=True)
        new_df = pd.concat([features_df, target], axis=1)

        # round() guards against float artefacts such as 0.29 * 100 -> 28.99...
        save_path = f"sampler_{method_name}_{round(sampling_percentage * 100)}.pkl"
        with open(save_path, "wb") as f:
            pickle.dump(sampler, f)

        show_message(f"Sampler for {method_name} at {sampling_percentage * 100:.0f}% saved as '{save_path}'.",
                     "New sampled dataset 'new_df' created for further use.")

    def clear_sub_buttons():
        """Remove the sub-option widgets while leaving output_box untouched."""
        sub_buttons_box.children = []

    def display_distribution(y_before, y_after, method_name, sampling_percentage, strata_features=None):
        """Log class distributions before/after a run and redraw the table and Save buttons."""
        original_distribution = dict(zip(*np.unique(y_before, return_counts=True)))
        resampled_distribution = dict(zip(*np.unique(y_after, return_counts=True)))

        total_before = sum(original_distribution.values())
        total_after = sum(resampled_distribution.values())

        original_percentages = {k: (v / total_before) * 100 for k, v in original_distribution.items()}
        resampled_percentages = {k: (v / total_after) * 100 for k, v in resampled_distribution.items()}

        result_log.loc[len(result_log)] = [
            method_name,
            f"{sampling_percentage * 100:.0f}%",
            total_before,
            total_after,
            str(original_distribution),
            str({cls: f"{pct:.2f}%" for cls, pct in original_percentages.items()}),
            str(resampled_distribution),
            str({cls: f"{pct:.2f}%" for cls, pct in resampled_percentages.items()})
        ]
        run_settings.append((method_name, sampling_percentage, strata_features))

        with table_output_box:
            clear_output(wait=True)
            display(result_log)

            # One Save button per logged run. Default arguments bind each row's
            # values now rather than at click time (late-binding closures).
            for row_method, row_pct, row_features in run_settings:
                save_button = widgets.Button(description=f"Save {row_method} Sampler ({round(row_pct * 100)}%)")
                save_button.on_click(
                    lambda b, m=row_method, sp=row_pct, sf=row_features: save_sampler(m, sp, sf)
                )
                display(save_button)

    def run_and_log(method_name, sampling_percentage, error_context):
        """Run an imblearn/random sampler and log it, reporting ValueErrors in output_box."""
        try:
            sampler = get_sampler(method_name, sampling_percentage)
            X_res, y_res = sampler.fit_resample(X, y)
        except ValueError as e:
            show_message(f"Error in {error_context}: {str(e)}",
                         "Please adjust the sampling percentage and try again.")
            return
        display_distribution(y, y_res, method_name, sampling_percentage)

    def stratified_sampling_method(sampling_percentage, selected_features):
        """Run stratified resampling on *selected_features* and log the result."""
        try:
            features = list(selected_features)
            strata = X[features]  # raises KeyError for unknown columns
            # 1-D labels for a single feature; a 2-D array lets resample()
            # build one joint label per row for several features.
            stratify_labels = strata.values if len(features) > 1 else strata.values.flatten()
            X_res, y_res = resample(
                X, y, stratify=stratify_labels,
                n_samples=int(sampling_percentage * len(y)), random_state=RANDOM_STATE
            )
        except (KeyError, ValueError) as e:
            # ValueError covers resample() failures such as strata too small to split.
            show_message(f"Error in stratified sampling: {str(e)}",
                         "Please adjust the feature selection and try again.")
            return

        # Print inside output_box: bare prints in a widget callback are not
        # shown in the notebook output.
        with output_box:
            clear_output(wait=True)
            print("Strata Features Confirmed:", strata.columns.tolist())
        display_distribution(y, y_res, "Stratified Sampling", sampling_percentage, strata_features=features)

    def on_main_button_clicked(button):
        """Show the sub-options for the method family named by *button*."""
        global current_sampling
        clear_sub_buttons()
        sampling_percentage_slider.layout.display = 'inline-block'

        if button.description == "Under Sampling":
            current_sampling = "Under Sampling"
            sub_buttons = [
                widgets.Button(description="Cluster Centroids", tooltip="This method reduces the majority class by selecting the most representative samples (centroids)."),
                widgets.Button(description="Near Miss", tooltip="This method reduces the majority class by selecting samples closest to the minority class."),
                widgets.Button(description="Tomek Links", tooltip="This method removes overlapping majority and minority class pairs to clean the data."),
                widgets.Button(description="One-Sided Selection", tooltip="This method removes majority class samples that are easily classified, keeping only boundary cases.")
            ]
            for sub_button in sub_buttons:
                sub_button.on_click(lambda b: run_and_log(
                    b.description, sampling_percentage_slider.value, f"under-sampling with {b.description}"))
            sub_buttons_box.children = sub_buttons
        elif button.description == "Over Sampling":
            current_sampling = "Over Sampling"
            sub_buttons = [
                widgets.Button(description="SMOTE", tooltip="This method creates new synthetic samples for the minority class using k-nearest neighbors."),
                widgets.Button(description="ADASYN", tooltip="This method generates synthetic samples for the minority class, focusing more on difficult cases.")
            ]
            for sub_button in sub_buttons:
                sub_button.on_click(lambda b: run_and_log(
                    b.description, sampling_percentage_slider.value, f"over-sampling with {b.description}"))
            sub_buttons_box.children = sub_buttons
        elif button.description == "Stratified Sampling":
            current_sampling = "Stratified Sampling"
            strata_feature_selection = widgets.SelectMultiple(
                options=X.columns.tolist(),
                description="Select Strata Features"
            )
            finalize_button = widgets.Button(description="Finalize Strata Features")
            finalize_button.on_click(lambda b: finalize_stratified_sampling(strata_feature_selection.value))
            sub_buttons_box.children = [strata_feature_selection, finalize_button]
        elif button.description == "Random Sampling":
            current_sampling = "Random Sampling"
            run_random_button = widgets.Button(description="Run Random Sampling")
            run_random_button.on_click(lambda b: run_and_log(
                "Random Sampling", sampling_percentage_slider.value, "random sampling"))
            sub_buttons_box.children = [run_random_button]

    def finalize_stratified_sampling(features):
        """Store the chosen strata *features* and show the run button."""
        global selected_features
        selected_features = list(features)
        if not selected_features:
            show_message("Please select at least one strata feature.")
            return
        sampling_percentage_slider.layout.display = 'inline-block'
        run_stratified_button = widgets.Button(description="Run Stratified Sampling")
        run_stratified_button.on_click(
            lambda b: stratified_sampling_method(sampling_percentage_slider.value, selected_features))
        sub_buttons_box.children = [run_stratified_button]

    for main_button in (under_sampling_button, over_sampling_button,
                        stratified_sampling_button, random_sampling_button):
        main_button.on_click(on_main_button_clicked)

    display(dataset_size_label, widgets.HBox([
        under_sampling_button, over_sampling_button, stratified_sampling_button, random_sampling_button
    ]), sampling_percentage_slider, sub_buttons_box, output_box, table_output_box)

    # Always None here; populated later by a Save-button click.
    return new_df

