In [None]:
import ast
import os
import warnings

import pandas as pd
import numpy as np
from sklearn.manifold import MDS
from geopy.distance import geodesic
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression, Lasso, Ridge, ElasticNet
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.svm import SVR
from xgboost import XGBRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_squared_error, r2_score

# Suppress convergence warnings
# (filters UserWarning-category warnings raised from sklearn's coordinate-descent module)
warnings.filterwarnings("ignore", category=UserWarning, module="sklearn.linear_model._coordinate_descent")

# Define configurations for models and tasks
# Per model family:
#   'sizes'          - size labels substituted into the performance column names (see task_config)
#   'n_parameters'   - numeric size paired position-wise with 'sizes'; used as the 'Model_Size' feature
#   'jaccard_matrix' - CSV path of the similarity matrix loaded for that model
#   'train_data_col' - spreadsheet column holding the pre-training data share
#   'finetune_col'   - (bloomz only) spreadsheet column holding the finetuning data
# Note: bloom and bloomz point at the same Jaccard matrix and train-data column.
model_config = {
    'bloom': {
        'sizes': ['560M', '1b1', '1b7', '3b', '7b1'],
        'n_parameters' : [560, 1100, 1700, 3000, 7100],
        'jaccard_matrix': 'distances/jaccard_similarity/jaccard_matrix_bloom.csv',
        'train_data_col': 'Bloom Train Data Percentage'
    },
    'bloomz': {
        'sizes': ['560M', '1b1', '1b7', '3b', '7b1'],
        'n_parameters' : [560, 1100, 1700, 3000, 7100],
        'jaccard_matrix': 'distances/jaccard_similarity/jaccard_matrix_bloom.csv',
        'train_data_col': 'Bloom Train Data Percentage',
        'finetune_col': 'BLOOMZ Finetune Data'
    },
    'xglm': {
        'sizes': ['564M', '1.7B', '2.9B', '7.5B'],
        'n_parameters' : [564, 1700, 2900, 7500],
        'jaccard_matrix': 'distances/jaccard_similarity/jaccard_matrix_xglm.csv',
        'train_data_col': 'XGLM Train Percentage'
    }
}

# Templates for the spreadsheet column holding each task's score; the two
# placeholders are filled with (model family, size label) in the main loop.
task_config = {
    'zero_shot_classification': 'F1 {}-{}',
    'two_shot_classification': 'F1 {}-{} 2s',
    'zero_shot_generation': '{}-{} scbleu',
    'two_shot_generation': '{}-{} scbleu 2s'
}

# Set current model and task
for current_model in model_config.keys():
    for current_task in task_config.keys():
        # Read the similarity matrix configured for this model. The file is
        # loaded without a header: row 0 / column 0 carry the folder labels and
        # the remaining cells are the numeric similarities.
        jaccard_matrix_path = model_config[current_model]['jaccard_matrix']
        jaccard_matrix_full = pd.read_csv(jaccard_matrix_path, header=None)
        jaccard_values = jaccard_matrix_full.iloc[1:, 1:]
        jaccard_matrix = jaccard_values.astype(float).values

        # Read the per-language metadata sheet
        file_path = 'SIB-200 languages.xlsx'
        data = pd.read_excel(file_path, sheet_name='Sheet1')

        # Normalise folder labels on both sides (trim whitespace, lowercase)
        # so they can be compared directly
        data['Folder Name'] = data['Folder Name'].str.strip().str.lower()
        jaccard_header = jaccard_matrix_full.iloc[0, 1:]
        jaccard_folder_names = jaccard_header.str.strip().str.lower().tolist()

        # For every data row whose folder is present in the matrix header,
        # record the position of that folder in the matrix
        matching_indices = []
        for folder in data['Folder Name']:
            if folder in jaccard_folder_names:
                matching_indices.append(jaccard_folder_names.index(folder))
        if len(matching_indices) == 0:
            raise ValueError("No matching folders found between the Jaccard matrix and the main dataset.")

        # Keep only the rows/columns of the matrix that correspond to matched folders
        filtered_jaccard_matrix = jaccard_matrix[np.ix_(matching_indices, matching_indices)]
        
        # Parse the stringified country collections; missing cells become [].
        # ast.literal_eval only accepts Python literals, so a spreadsheet cell
        # cannot execute arbitrary code the way the previous eval() call could.
        data['countries'] = data['countries'].apply(lambda x: ast.literal_eval(x) if pd.notna(x) else [])
        
        # Create country similarity matrix
        def create_country_similarity(data):
            """Return the pairwise Jaccard similarity of the 'countries' entries.

            Entry (i, j) is |A_i & A_j| / |A_i | A_j| for the country collections
            of rows i and j. The diagonal, and any pair whose union is empty,
            is left at 0.
            """
            country_sets = [set(entry) for entry in data['countries'].tolist()]
            size = len(country_sets)
            similarity = np.zeros((size, size))

            # The measure is symmetric, so each unordered pair is scored once
            # and mirrored.
            for row, left in enumerate(country_sets):
                for col in range(row + 1, size):
                    right = country_sets[col]
                    union = left | right
                    if union:
                        score = len(left & right) / len(union)
                        similarity[row, col] = score
                        similarity[col, row] = score
            return similarity
        
        country_similarity_matrix = create_country_similarity(data)

        # MDS on country similarity matrix (turned into a dissimilarity).
        # random_state is fixed so the embeddings are reproducible, in line
        # with the seeded regressors used further down.
        mds_country = MDS(n_components=10, dissimilarity='precomputed', normalized_stress=False, random_state=42)
        country_mds = mds_country.fit_transform(1 - country_similarity_matrix)
        country_mds_df = pd.DataFrame(
            country_mds,
            index=data.index,
            columns=[f'Country_MDS{i+1}' for i in range(country_mds.shape[1])],
        )

        # Handle geographical coordinates (rows without coordinates are left out)
        geo_df = data[['latitude', 'longitude']].dropna()
        coords = geo_df[['latitude', 'longitude']].values

        # Calculate geographical distance matrix. Distances are symmetric and
        # zero on the diagonal, so only the upper triangle is computed.
        geo_distance_matrix = np.zeros((len(coords), len(coords)))
        for i, coord1 in enumerate(coords):
            for j in range(i + 1, len(coords)):
                distance_km = geodesic(coord1, coords[j]).kilometers
                geo_distance_matrix[i, j] = distance_km
                geo_distance_matrix[j, i] = distance_km

        # MDS on geographical distance matrix.
        # The embedding keeps geo_df's index: after dropna() its rows are no
        # longer positions 0..n-1 of `data`, and a positional index would
        # attach coordinates to the wrong languages in the merge below.
        mds_geo = MDS(n_components=10, dissimilarity='precomputed', normalized_stress=False, random_state=42)
        geo_mds = mds_geo.fit_transform(geo_distance_matrix)
        geo_mds_df = pd.DataFrame(
            geo_mds,
            index=geo_df.index,
            columns=[f'Geo_MDS{i+1}' for i in range(geo_mds.shape[1])],
        )

        # MDS on filtered Jaccard similarity matrix.
        # filtered_jaccard_matrix has one row per data row whose folder name
        # appears in the Jaccard header, in data order; recover those rows'
        # index labels (same membership test as the matching step) so the
        # embedding is merged onto the languages it was computed for.
        in_jaccard = np.array(
            [folder in jaccard_folder_names for folder in data['Folder Name']],
            dtype=bool,
        )
        matched_rows = data.index[in_jaccard]
        mds_jaccard = MDS(n_components=10, dissimilarity='precomputed', normalized_stress=False, random_state=42)
        jaccard_mds = mds_jaccard.fit_transform(1 - filtered_jaccard_matrix)
        jaccard_mds_df = pd.DataFrame(
            jaccard_mds,
            index=matched_rows,
            columns=[f'Jaccard_MDS{i+1}' for i in range(jaccard_mds.shape[1])],
        )

        # Merge the dataframes on the row index; rows lacking an embedding get
        # NaN and are removed later by the dropna on the feature columns.
        data = data.merge(geo_mds_df, left_index=True, right_index=True, how='left')
        data = data.merge(country_mds_df, left_index=True, right_index=True, how='left')
        data = data.merge(jaccard_mds_df, left_index=True, right_index=True, how='left', suffixes=('', '_jaccard'))
        
        # Map the ordered text categories onto integer ranks. In each of these
        # columns the literal string 'None' and missing cells both become -1.
        ordinal_encodings = {
            'Population': {
                '10K to 1 million': 0,
                '1 million to 1 billion': 1,
                '1 billion plus': 2,
                'None': -1
            },
            'Language Vitality': {
                'Extinct': 0,
                'Endangered': 1,
                'Stable': 2,
                'Institutional': 3,
                'None': -1
            },
            'Digital Language Support': {
                'Still': 0,
                'Emerging': 1,
                'Ascending': 2,
                'Vital': 3,
                'Thriving': 4,
                'None': -1
            },
        }
        for ordinal_column, encoding in ordinal_encodings.items():
            encoded = data[ordinal_column].replace(encoding)
            data[ordinal_column] = encoded.fillna(-1)

        # Resource Level uses 0 (not -1) for 'None' and for missing cells
        resource_level = data['Resource Level'].replace('None', 0)
        data['Resource Level'] = resource_level.fillna(0)
        
        # One-hot encode the categorical columns (an extra indicator column is
        # added for missing values via dummy_na=True)
        categorical_features = ['Script (ISO 15924)', 'Language Family']
        data = pd.get_dummies(data, columns=categorical_features, dummy_na=True)

        # Treat a missing pre-training data share as 0
        train_data_col = model_config[current_model]['train_data_col']
        data[train_data_col] = data[train_data_col].fillna(0)

        # Treat missing finetuning data as 0 for models that declare a
        # finetune column (bloomz). The previous check tested for 'bloom',
        # so for bloomz - the only model that uses this column as a feature -
        # the NaNs survived and those rows were dropped before training.
        finetune_col = model_config[current_model].get('finetune_col')
        if finetune_col is not None:
            data[finetune_col] = data[finetune_col].fillna(0)
        
        # Size labels (used in column names) and their numeric counterparts
        model_sizes = model_config[current_model]['sizes']
        numeric_model_sizes = model_config[current_model]['n_parameters']

        # Stack one copy of the language table per model size, tagging each
        # copy with its numeric size and that size's score for the current
        # task. The copies are collected and concatenated once, rather than
        # repeatedly appended to an initially empty DataFrame, which re-copied
        # all earlier rows on every pass.
        per_size_frames = []
        for model_name, size in zip(model_sizes, numeric_model_sizes):
            temp_data = data.copy()
            temp_data['Model_Size'] = size
            task_col = task_config[current_task].format(current_model, model_name)
            temp_data['Performance'] = temp_data[task_col]
            per_size_frames.append(temp_data)
        combined_data = pd.concat(per_size_frames, ignore_index=True)

        # Feature columns: the three 10-dimensional MDS embeddings, the ordinal
        # language attributes, the training-data share, the model size and all
        # one-hot columns produced for script and language family.
        features = (
                [f'Geo_MDS{i+1}' for i in range(10)] +
                [f'Country_MDS{i+1}' for i in range(10)] +
                [f'Jaccard_MDS{i+1}' for i in range(10)] +
                ['Population', 'Language Vitality', 'Digital Language Support', 'Resource Level', train_data_col, 'Model_Size'] +
                list(combined_data.columns[combined_data.columns.str.startswith('Script (ISO 15924)_')]) +
                list(combined_data.columns[combined_data.columns.str.startswith('Language Family_')])
        )

        if current_model == 'bloomz':
            features.append('BLOOMZ Finetune Data')

        # Drop rows with a missing feature or a missing target
        combined_data = combined_data.dropna(subset=features + ['Performance'])
        combined_data.reset_index(drop=True, inplace=True)

        X = combined_data[features]
        y = combined_data['Performance']

        # Hold out 20% of the rows for evaluation (fixed seed)
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
        
        # Regressors compared on the current model/task combination
        models = {
            'Linear Regression': LinearRegression(),
            'Random Forest': RandomForestRegressor(n_estimators=100, random_state=42),
            'Decision Tree': DecisionTreeRegressor(random_state=42),
            'SVR': SVR(),
            'Gradient Boosting': GradientBoostingRegressor(random_state=42),
            'XGBoost': XGBRegressor(random_state=42),
            'K-Nearest Neighbors': KNeighborsRegressor(),
            'Lasso Regression': Lasso(random_state=42),
            'Ridge Regression': Ridge(random_state=42),
            'Elastic Net': ElasticNet(random_state=42)
        }

        # Fit each regressor on the training split and score it on the test split
        results_list = []
        for model_name, model in models.items():
            model.fit(X_train, y_train)
            y_pred = model.predict(X_test)
            r2 = r2_score(y_test, y_pred)
            mse = mean_squared_error(y_test, y_pred)
            results_list.append({'Model': model_name, 'R2 Score': r2, 'MSE': mse})

        # Persist one results table per model/task combination.
        # to_csv does not create missing folders, so make sure the target
        # directory exists first.
        results = pd.DataFrame(results_list)
        os.makedirs('results/regression', exist_ok=True)
        results.to_csv(f'results/regression/{current_model}_{current_task}_results.csv', index=False)
        # (The former bare `results` expression was removed: inside a loop it
        # neither displays nor computes anything.)

In [None]:
        import shap
        import pandas as pd
        import numpy as np
        import matplotlib.pyplot as plt
        from sklearn.linear_model import LinearRegression, Lasso, Ridge, ElasticNet
        from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
        from sklearn.tree import DecisionTreeRegressor
        from sklearn.svm import SVR
        from xgboost import XGBRegressor
        from sklearn.neighbors import KNeighborsRegressor
        from sklearn.inspection import permutation_importance
        import os
        
        # Create the results directory and the nested folders that the tables
        # and plots below are written into; to_csv / savefig do not create
        # missing directories themselves.
        results_dir = 'results/features'
        os.makedirs(results_dir, exist_ok=True)
        for sub_dir in ('importance/table', 'importance/plot', 'shap/table', 'shap/plot'):
            os.makedirs(os.path.join(results_dir, sub_dir), exist_ok=True)
        
        # Function to train the selected regression model and compute feature importances
        def train_and_compute_importances(model_name, model, X_train, y_train, X_test, y_test):
            """Fit `model` and return its per-feature importances, highest first.

            For the tree-based model names the fitted estimator's
            `feature_importances_` attribute is used; for every other name the
            importances are the mean permutation importances measured on the
            test split (10 repeats, seed 42).

            Returns a DataFrame with columns 'Feature' and 'Importance'.
            """
            tree_based = ('XGBoost', 'Random Forest', 'Gradient Boosting', 'Decision Tree')

            model.fit(X_train, y_train)

            if model_name in tree_based:
                scores = model.feature_importances_
            else:
                permutation = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42)
                scores = permutation.importances_mean

            ranking = pd.DataFrame({'Feature': X_train.columns, 'Importance': scores})
            return ranking.sort_values(by='Importance', ascending=False)
        
        # Function to aggregate importances by abstract feature
        def aggregate_importances(importance_df, abstract_features):
            """Sum per-feature importances into the given feature groups.

            `abstract_features` maps a group name to the list of feature names
            belonging to it. Returns a DataFrame indexed by group name with a
            single 'Importance' column, sorted from highest to lowest.
            """
            totals = {
                group: importance_df.loc[importance_df['Feature'].isin(members), 'Importance'].sum()
                for group, members in abstract_features.items()
            }

            summary = pd.DataFrame.from_dict(totals, orient='index', columns=['Importance'])
            return summary.sort_values(by='Importance', ascending=False)
        
        # Candidate regressors, keyed by display name; the loop below picks a
        # subset of these by name. Fresh, unfitted instances are created here
        # (same estimators and settings as in the regression comparison above).
        models = {
            'Linear Regression': LinearRegression(),
            'Random Forest': RandomForestRegressor(n_estimators=100, random_state=42),
            'Decision Tree': DecisionTreeRegressor(random_state=42),
            'SVR': SVR(),
            'Gradient Boosting': GradientBoostingRegressor(random_state=42),
            'XGBoost': XGBRegressor(random_state=42),
            'K-Nearest Neighbors': KNeighborsRegressor(),
            'Lasso Regression': Lasso(random_state=42),
            'Ridge Regression': Ridge(random_state=42),
            'Elastic Net': ElasticNet(random_state=42)
        }
        
        for selected_model_name in ['XGBoost', 'Random Forest', 'Gradient Boosting']:  # Change this to select a different model
            selected_model = models[selected_model_name]

            # Fit the chosen regressor and rank the individual features
            importance_df = train_and_compute_importances(
                selected_model_name, selected_model, X_train, y_train, X_test, y_test
            )

            # Group the individual feature columns into the higher-level
            # ("abstract") features that are reported
            script_columns = [col for col in X_train.columns if col.startswith('Script (ISO 15924)_')]
            family_columns = [col for col in X_train.columns if col.startswith('Language Family_')]
            abstract_features = {
                'Geographical Features': [f'Geo_MDS{k}' for k in range(1, 11)],
                'Country Similarity': [f'Country_MDS{k}' for k in range(1, 11)],
                'Token Similarity': [f'Jaccard_MDS{k}' for k in range(1, 11)],
                'Script': script_columns,
                'Language Family': family_columns,
                'Population': ['Population'],
                'Language Vitality': ['Language Vitality'],
                'Digital Language Support': ['Digital Language Support'],
                'Resource Level': ['Resource Level'],
                'Pre-train Data Percentage': [model_config[current_model]['train_data_col']],
                'Model Size': ['Model_Size']
            }

            # Only bloomz has an instruction-tuning feature
            if current_model == 'bloomz':
                abstract_features['Instruction Tune Data'] = ['BLOOMZ Finetune Data']

            # Sum the importances per group, then turn the group names into a
            # regular 'Abstract Feature' column ordered by importance
            ranked = aggregate_importances(importance_df, abstract_features).reset_index()
            ranked = ranked.sort_values(by='Importance', ascending=False)
            ranked = ranked.rename(columns={'index': 'Abstract Feature'})
            abstract_importance_df = ranked.reset_index(drop=True)

            importance_table_path = os.path.join(results_dir, f'importance/table/{current_model}_{current_task }_{selected_model_name}_abstract_feature_importances.csv')
            abstract_importance_df.to_csv(importance_table_path, index=False)

            # Horizontal bar chart of the grouped importances (largest on top),
            # saved to disk and closed without being shown
            importance_plot_path = os.path.join(results_dir, f'importance/plot/{current_model}_{current_task }_{selected_model_name}_abstract_feature_importances.jpg')
            plt.figure(figsize=(10, 8))
            plt.barh(abstract_importance_df['Abstract Feature'], abstract_importance_df['Importance'])
            plt.xlabel('Importance')
            plt.ylabel('Abstract Feature')
            plt.title(f'Abstract Feature Importances based on {selected_model_name}')
            plt.gca().invert_yaxis()
            plt.savefig(importance_plot_path)
            plt.close()
        
            # SHAP value calculation for the selected model
            def compute_shap_values(model_name, model, X):
                """Build a SHAP explainer for `model` and return its output on X.

                Tree-based model names get `shap.Explainer(model, X)`; all other
                names get a `shap.KernelExplainer` around `model.predict`. In
                both cases X is used as the background data and as the data to
                explain.
                """
                if model_name in ['XGBoost', 'Gradient Boosting', 'Random Forest', 'Decision Tree']:
                    explainer = shap.Explainer(model, X)
                else:
                    # NOTE(review): this branch is not reached by the loop
                    # above (only tree models are selected); confirm that
                    # KernelExplainer can be called like this and accepts
                    # check_additivity before relying on it.
                    explainer = shap.KernelExplainer(model.predict, X)
        
                shap_values = explainer(X, check_additivity=False)
                return shap_values
        
            # Aggregate SHAP values by abstract feature
            def aggregate_shap_and_feature_values(shap_values, feature_values, feature_names, abstract_features):
                """Sum SHAP values and feature values over each feature group.

                `shap_values` and `feature_values` are 2-D arrays whose columns
                follow `feature_names`; `abstract_features` maps group name ->
                member feature names (members absent from `feature_names` are
                ignored). Returns (summed SHAP values, summed feature values,
                group names), with one output column per group in dict order.
                """
                group_names = list(abstract_features)
                n_groups = len(group_names)
                summed_shap = np.zeros((shap_values.shape[0], n_groups))
                summed_features = np.zeros((feature_values.shape[0], n_groups))

                for position, group in enumerate(group_names):
                    columns = [
                        feature_names.index(member)
                        for member in abstract_features[group]
                        if member in feature_names
                    ]
                    summed_shap[:, position] = shap_values[:, columns].sum(axis=1)
                    summed_features[:, position] = feature_values[:, columns].sum(axis=1)

                return summed_shap, summed_features, group_names
        
            # Compute per-feature SHAP values for the selected model on the full feature matrix
            shap_values = compute_shap_values(selected_model_name, selected_model, X)

            # Collapse SHAP values and feature values onto the abstract feature groups
            aggregated_shap_values, aggregated_feature_values, abstract_feature_names = aggregate_shap_and_feature_values(
                shap_values.values, X.values, X.columns.tolist(), abstract_features
            )

            # Wrap the aggregated values in an Explanation object for visualization
            aggregated_shap_values_explanation = shap.Explanation(
                values=aggregated_shap_values,
                base_values=shap_values.base_values,
                data=aggregated_feature_values,
                feature_names=abstract_feature_names
            )

            # Table of mean |SHAP| per abstract feature, highest first
            shap_df = pd.DataFrame({'Abstract Feature': abstract_feature_names, 'SHAP value': np.abs(aggregated_shap_values).mean(0)}).sort_values(by='SHAP value', ascending=False).reset_index(drop=True)
            shap_df.to_csv(os.path.join(results_dir, f'shap/table/{current_model}_{current_task }_{selected_model_name}_shap_values.csv'), index=False)

            # Save the bar summary plot for abstract features
            shap.summary_plot(aggregated_shap_values, feature_names=abstract_feature_names, plot_type='bar', show=False)
            plt.savefig(os.path.join(results_dir, f'shap/plot/{current_model}_{current_task }_{selected_model_name}_shap_summary_plot_bar.jpg'))
            plt.close()

            # Save the beeswarm summary plot. `features` must be the (aggregated)
            # feature values that drive the point colouring; the SHAP values
            # themselves were passed here before, so the colours merely
            # repeated the x-axis.
            shap.summary_plot(aggregated_shap_values, features=aggregated_feature_values, feature_names=abstract_feature_names, show=False)
            plt.savefig(os.path.join(results_dir, f'shap/plot/{current_model}_{current_task }_{selected_model_name}_shap_summary_plot.jpg'))
            plt.close()
        print(f'{current_model}_{current_task} done')

In [None]:
# Load the zero-shot classification SHAP summary
df = pd.read_excel('results/SHAP_results.xlsx', sheet_name='Zero-Classification')

# Order rows by feature name, descending
df_sorted = df.sort_values(by='feature', ascending=False)

# One horizontal bar chart per model, side by side on a shared y-axis
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(18, 6), sharey=True)

panels = (('Bloom', 'b'), ('Bloomz', 'g'), ('XGLM', 'r'))
for ax, (model_column, bar_color) in zip(axes, panels):
    ax.barh(df_sorted['feature'], df_sorted[model_column], color=bar_color)
    ax.set_title(model_column)

# Shared axis captions for the whole figure
fig.text(0.04, 0.5, 'Features', va='center', rotation='vertical')
fig.text(0.5, 0.04, 'SHAP Values', ha='center')

plt.tight_layout(rect=[0.05, 0.05, 1, 1])
plt.show()

In [None]:
# Load the two-shot classification SHAP summary
df = pd.read_excel('results/SHAP_results.xlsx', sheet_name='Two-Classification')

# Order rows by feature name, descending
df_sorted = df.sort_values(by='feature', ascending=False)

# One horizontal bar chart per model, side by side on a shared y-axis
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(15, 5), sharey=True)

panels = (('Bloom', 'b'), ('Bloomz', 'g'), ('XGLM', 'r'))
for ax, (model_column, bar_color) in zip(axes, panels):
    ax.barh(df_sorted['feature'], df_sorted[model_column], color=bar_color)
    ax.set_title(model_column)
    ax.tick_params(axis='y', labelsize=12)  # y-tick label size

# Shared axis captions for the whole figure
fig.text(0.04, 0.5, 'Features', va='center', rotation='vertical', fontsize=12)
fig.text(0.5, 0.04, 'SHAP Values', ha='center', fontsize=12)

plt.tight_layout(rect=[0.05, 0.05, 1, 1])
plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Function to create a plot for a given sheet
def plot_shap_values(sheet_name, row_idx, axes):
    """Draw one row of per-model SHAP bar charts for a sheet of the results workbook.

    Parameters:
        sheet_name: sheet of 'results/SHAP_results.xlsx' to read. It must hold a
            'feature' column and the model columns 'Bloom', 'Bloomz' and 'XGLM'.
        row_idx: row of `axes` to draw into.
        axes: 2-D grid of matplotlib axes with at least three columns.

    The panel titles are derived from the sheet name ('Zero'/'Two' and
    'Classification'/'Generation'); a sheet name containing none of these
    keywords is used as the title suffix verbatim.
    """
    df = pd.read_excel('results/SHAP_results.xlsx', sheet_name=sheet_name)

    # Renaming the features
    rename_dict = {
        "Instruction Tune Data": "Instruction Tuning Data",
        "Geographical Features": "Geographical Proximity"
    }
    df['feature'] = df['feature'].replace(rename_dict)

    # Sorting the DataFrame by feature names in descending order
    df_sorted = df.sort_values(by='feature', ascending=False)

    # Build the title suffix from the sheet name. The parts are collected in a
    # list so that an unrecognised sheet name no longer hits an unbound
    # `title_suffix` (previously an UnboundLocalError).
    title_parts = []
    if 'Zero' in sheet_name:
        title_parts.append('Zero-shot')
    elif 'Two' in sheet_name:
        title_parts.append('Two-shot')

    if 'Classification' in sheet_name:
        title_parts.append('Classification')
    elif 'Generation' in sheet_name:
        title_parts.append('Generation')

    title_suffix = ' '.join(title_parts) or sheet_name

    # Plot each model in its own column of the requested row
    panels = (('Bloom', 'b'), ('Bloomz', 'g'), ('XGLM', 'r'))
    for col_idx, (model_column, bar_color) in enumerate(panels):
        ax = axes[row_idx, col_idx]
        ax.barh(df_sorted['feature'], df_sorted[model_column], color=bar_color)
        ax.set_title(f'{model_column} - {title_suffix}')
        ax.tick_params(axis='y', labelsize=12)

# Sheets of the results workbook, grouped by task type
classification_sheets = ['Zero-Classification', 'Two-Classification']
generation_sheets = ['Zero-Generation', 'Two-Generation']


def _plot_sheet_grid(sheets):
    """Show a grid with one row of per-model SHAP bar charts per sheet.

    Returns the created (figure, axes) pair.
    """
    figure, grid = plt.subplots(nrows=len(sheets), ncols=3, figsize=(18, 10), sharey=True)

    # One sheet per row
    for row, sheet_name in enumerate(sheets):
        plot_shap_values(sheet_name, row, grid)

    # Shared axis captions for the whole figure
    figure.text(0.04, 0.5, 'Features', va='center', rotation='vertical', fontsize=16)
    figure.text(0.5, 0.04, 'SHAP Values', ha='center', fontsize=16)

    plt.tight_layout(rect=[0.05, 0.05, 1, 1])
    plt.show()
    return figure, grid


# Classification figure first, then the generation figure
fig1, axes1 = _plot_sheet_grid(classification_sheets)
fig2, axes2 = _plot_sheet_grid(generation_sheets)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Function to create an enhanced plot for zero-shot and two-shot on the same axes with a split for model and language features
def plot_combined_shap_values(sheet_zero, sheet_two, axes, col_idx):
    """Draw zero-shot and two-shot SHAP bars side by side, one panel per model.

    Parameters:
        sheet_zero: sheet of 'results/SHAP_results.xlsx' with the zero-shot values.
        sheet_two: sheet of the same workbook with the two-shot values.
        axes: sequence of three matplotlib axes (Bloom, BloomZ, XGLM panels).
        col_idx: unused; kept so existing calls keep working.

    Both sheets must contain a 'feature' column plus the model columns 'Bloom',
    'Bloomz' and 'XGLM'. Model-related features are listed first and are
    separated from the language-related features by a dashed line. The
    two-shot rows are paired with the zero-shot rows by position after both
    are sorted the same way, so both sheets are expected to list the same
    features.
    """
    # Read both zero-shot and two-shot data
    df_zero = pd.read_excel('results/SHAP_results.xlsx', sheet_name=sheet_zero)
    df_two = pd.read_excel('results/SHAP_results.xlsx', sheet_name=sheet_two)

    # Display names for some features and for the Bloomz model column
    rename_dict = {
        "Instruction Tune Data": "Instruction Tuning Data",
        "Geographical Features": "Geographical Proximity",
        "Script": "Script Type",
        "Bloomz": "BloomZ"  # Rename the model column
    }
    for frame in (df_zero, df_two):
        frame['feature'] = frame['feature'].replace(rename_dict)
        frame.rename(columns=rename_dict, inplace=True)

    # Define model features and language features
    model_features = ["Model Size", "Pre-train Data Percentage", "Instruction Tuning Data"]
    language_features = ["Country Similarity", "Digital Language Support", "Geographical Proximity",
                         "Language Family", "Script Type", "Language Vitality", "Population", "Resource Level", "Token Similarity"]

    def _split_and_order(frame):
        # Model features first, then language features, each sorted by name (descending)
        model_part = frame[frame['feature'].isin(model_features)].sort_values(by='feature', ascending=False)
        language_part = frame[frame['feature'].isin(language_features)].sort_values(by='feature', ascending=False)
        return model_part, pd.concat([model_part, language_part])

    df_zero_model, df_zero_sorted = _split_and_order(df_zero)
    _, df_two_sorted = _split_and_order(df_two)

    # Colours per panel: saturated for zero-shot, lighter for two-shot
    colors_zero_shot = ['#FF6F61', '#6B5B95', '#88B04B']
    colors_two_shot = ['#FFB3A7', '#B39EB5', '#C3D7A4']

    # Panels are bound to their data by column name rather than by column
    # position, so a reordered sheet cannot put one model's bars under
    # another model's title.
    model_columns = ['Bloom', 'BloomZ', 'XGLM']

    bar_width = 0.35
    indices = np.arange(len(df_zero_sorted['feature']))
    split_index = len(df_zero_model)

    panels = zip(axes, model_columns, colors_zero_shot, colors_two_shot)
    for ax, model_column, zero_color, two_color in panels:
        # Paired bars: zero-shot above two-shot for every feature
        ax.barh(indices + bar_width, df_zero_sorted[model_column], height=bar_width, color=zero_color, edgecolor='black', linewidth=0.5, label='Zero-shot', alpha=0.9)
        ax.barh(indices, df_two_sorted[model_column], height=bar_width, color=two_color, edgecolor='black', linewidth=0.5, label='Two-shot', alpha=0.9)
        ax.set_yticks(indices + bar_width / 2)
        ax.set_yticklabels(df_zero_sorted['feature'], fontsize=12, fontweight='bold')
        ax.set_title(model_column, fontsize=14, fontweight='bold', color='#333333')
        ax.legend(loc='upper right', fontsize=10, shadow=True)
        ax.grid(True, which='major', axis='x', linestyle='--', alpha=0.6)
        ax.set_facecolor('#f9f9f9')

        # Dashed horizontal line between the model and the language features
        ax.axhline(y=split_index - bar_width, color='#333333', linewidth=1.5, linestyle='--', alpha=0.7)

        # Lighter frame and tick styling
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_color('#aaaaaa')
        ax.tick_params(axis='x', colors='#333333')
        ax.tick_params(axis='y', colors='#333333')

# Classification sheets (zero-shot, two-shot)
classification_sheets = ['Zero-Classification', 'Two-Classification']

# Generation sheets (zero-shot, two-shot)
generation_sheets = ['Zero-Generation', 'Two-Generation']

# savefig does not create missing folders, so make sure the output directory exists
os.makedirs('figures', exist_ok=True)

# Plotting the bar charts for classification in a separate figure
fig1, axes1 = plt.subplots(nrows=1, ncols=3, figsize=(15, 8), sharey=True)

# Plot each classification model (sheet names come from the list above
# instead of being repeated as literals)
plot_combined_shap_values(*classification_sheets, axes1, 0)

# Set common y-label and x-label for classification
fig1.text(0.04, 0.5, 'Features', va='center', rotation='vertical', fontsize=16, fontweight='bold', color='#333333')
fig1.text(0.5, 0.04, 'SHAP Values', ha='center', fontsize=16, fontweight='bold', color='#333333')

# Save the classification figure as a high-quality image.
# The layout call is made on the figure itself so it cannot act on a
# different "current" figure.
fig1.tight_layout(rect=[0.05, 0.05, 1, 1])
fig1.savefig('figures/classification.png', dpi=600, bbox_inches='tight')

# Plotting the bar charts for generation in a separate figure
fig2, axes2 = plt.subplots(nrows=1, ncols=3, figsize=(15, 8), sharey=True)

# Plot each generation model
plot_combined_shap_values(*generation_sheets, axes2, 0)

# Set common y-label and x-label for generation
fig2.text(0.04, 0.5, 'Features', va='center', rotation='vertical', fontsize=16, fontweight='bold', color='#333333')
fig2.text(0.5, 0.04, 'SHAP Values', ha='center', fontsize=16, fontweight='bold', color='#333333')

# Save the generation figure as a high-quality image
fig2.tight_layout(rect=[0.05, 0.05, 1, 1])
fig2.savefig('figures/generation.png', dpi=600, bbox_inches='tight')

In [None]:
# Visualize the explanation of a single prediction: entry 712 of `shap_values`,
# i.e. the per-feature SHAP values left over from the feature-importance cell
# (not the aggregated abstract-feature values).
shap.plots.waterfall(shap_values[712])

In [None]:
import matplotlib.pylab as pl
import numpy as np

# Compute SHAP interaction values for the selected model.
# `selected_model` and `X` are leftovers from the feature-importance cell, i.e.
# the model and feature matrix that cell's loops handled last.
# NOTE(review): TreeExplainer presumably needs a fitted tree-based model -
# confirm if the model selection in that cell is ever changed.
shap_interaction_values = shap.TreeExplainer(selected_model).shap_interaction_values(X)

# Aggregate SHAP interaction values and feature values by abstract features
def aggregate_shap_interaction_values(shap_interaction_values, feature_values, feature_names, abstract_features):
    """Aggregate SHAP interaction values and feature values by feature group.

    Parameters:
        shap_interaction_values: array indexed as [sample, feature, feature].
        feature_values: array indexed as [sample, feature].
        feature_names: list of feature names giving the feature-axis order.
        abstract_features: dict mapping group name -> member feature names;
            members that are not in `feature_names` are ignored.

    Returns:
        (interactions, group_values, group_names) where interactions[s, i, j]
        is the sum of all interaction values between the members of group i
        and group j for sample s, group_values[s, i] is the mean feature value
        of group i's members, and group_names lists the groups in dict order.
        A group without any known member contributes 0 everywhere (instead of
        a NaN mean over an empty selection).
    """
    n_samples = shap_interaction_values.shape[0]
    abstract_feature_names = list(abstract_features.keys())
    n_abstract_features = len(abstract_feature_names)

    # Resolve every group's column positions once, up front, instead of
    # repeating the name lookups for each (i, j) pair.
    group_indices = [
        np.array(
            [feature_names.index(sub_feature) for sub_feature in sub_features if sub_feature in feature_names],
            dtype=int,
        )
        for sub_features in abstract_features.values()
    ]

    aggregated_shap_interaction_values = np.zeros((n_samples, n_abstract_features, n_abstract_features))
    aggregated_feature_values = np.zeros((n_samples, n_abstract_features))

    for i, indices_i in enumerate(group_indices):
        rows_i = shap_interaction_values[:, indices_i]
        for j, indices_j in enumerate(group_indices):
            aggregated_shap_interaction_values[:, i, j] = rows_i[:, :, indices_j].sum(axis=(1, 2))

        # Averaging over an empty selection would yield NaN; keep the 0 default
        if indices_i.size:
            aggregated_feature_values[:, i] = feature_values[:, indices_i].mean(axis=1)

    return aggregated_shap_interaction_values, aggregated_feature_values, abstract_feature_names

# Collapse the per-feature interaction values onto the abstract feature groups
aggregated_shap_interaction_values, aggregated_feature_values, abstract_feature_names = aggregate_shap_interaction_values(
    shap_interaction_values, X.values, X.columns.tolist(), abstract_features
)

# Total absolute interaction per group pair over all samples, with the
# self-interactions on the diagonal zeroed out
tmp = np.abs(aggregated_shap_interaction_values).sum(0)
np.fill_diagonal(tmp, 0)

# Order the groups by their overall interaction strength (at most 50 kept)
inds = np.argsort(-tmp.sum(0))[:50]
tmp2 = tmp[np.ix_(inds, inds)]
ordered_labels = [abstract_feature_names[i] for i in inds]
tick_positions = range(tmp2.shape[0])

# Plot the heatmap
pl.figure(figsize=(12, 12))
pl.imshow(tmp2, cmap='viridis')
pl.yticks(tick_positions, ordered_labels, rotation=50.4, horizontalalignment="right")
pl.xticks(tick_positions, ordered_labels, rotation=50.4, horizontalalignment="left")
pl.gca().xaxis.tick_top()
pl.colorbar()
pl.title("SHAP Interaction Values for Abstract Features")
pl.show()