In [None]:
# Mount Google Drive at /content/drive so the result folders under
# /content/drive/MyDrive/... used by later cells are readable (Colab-only module).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import spearmanr
import os

def extract_json_data(directory):
    """
    Recursively load every ``.json`` file under ``directory``.

    Sub-directories and files are visited in sorted (lexicographic) order, so the
    returned list has a stable, reproducible order across machines and runs.
    Callers that pick records by position (e.g. the first one) therefore get the
    same record everywhere instead of whatever order the filesystem reports.

    Files that cannot be opened or parsed are reported with ``print`` and skipped
    (best-effort load).

    Args:
        directory: The path to the directory to search. A path that does not
            exist yields an empty list (``os.walk`` produces nothing for it).

    Returns:
        A list with the parsed content of each JSON file, in traversal order.
    """
    extracted_data = []
    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place fixes the order in which os.walk (topdown=True)
        # descends into sub-directories.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".json"):
                continue
            filepath = os.path.join(root, file)
            try:
                # JSON is UTF-8 by spec; don't depend on the platform's default encoding.
                with open(filepath, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON in {filepath}: {e}")
            except Exception as e:
                print(f"An error occurred while processing {filepath}: {e}")
            else:
                extracted_data.append(data)
    return extracted_data

In [None]:
# Root folder on Google Drive holding one sub-folder of result JSON files per dataset.
# Change this single constant if the project lives elsewhere.
RESULTS_ROOT = "/content/drive/MyDrive/Seminar_Project"

results_adult = extract_json_data(os.path.join(RESULTS_ROOT, "Adult_Files"))
results_ca = extract_json_data(os.path.join(RESULTS_ROOT, "CA_Files"))
results_higgs = extract_json_data(os.path.join(RESULTS_ROOT, "HIGGS_Files"))
results_jannis = extract_json_data(os.path.join(RESULTS_ROOT, "Jannis_Files"))
results_helena = extract_json_data(os.path.join(RESULTS_ROOT, "Helena_Files"))

In [None]:
# Dataset display name -> list of result records loaded by extract_json_data above.
datasets = dict(
    California=results_ca,
    Adult=results_adult,
    Helena=results_helena,
    Jannis=results_jannis,
    Higgs=results_higgs,
)

In [None]:
# Plot styling: reset matplotlib, then apply seaborn's whitegrid theme at "paper" scale.
plt.style.use('default')  # Reset to default style
sns.set_theme(style="whitegrid")  # Use seaborn's whitegrid style
sns.set_context("paper", font_scale=1.2)
plt.rcParams['figure.figsize'] = (10, 8)

# `datasets` (display name -> loaded result records) is defined in the previous cell.
# It is deliberately not redefined here, so there is a single source of truth.

# Model keys as they appear in the result JSON -> display labels used in plots/tables.
model_names = {
    'ft_linear_tuned': 'FT-Linear',
    'ft_piecewise_tuned': 'FT-PLE',
    'sparse_ft_linear_tuned': 'FT-Sparse-Linear',
    'sparse_ft_piecewise_tuned': 'FT-Sparse-PLE'
}

# Function to load the JSON data
def load_data(file_path):
    """
    Read and parse a single JSON file.

    Args:
        file_path: Path to the ``.json`` file.

    Returns:
        The parsed JSON content (dict, list, ...).
    """
    # JSON is UTF-8 by spec; pin the encoding instead of using the platform default.
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Function to extract PFI scores for each model
def extract_pfi_scores(data, models=None):
    """
    Pull each known model's PFI score vector out of the first loaded result record.

    Args:
        data: List of result records as returned by extract_json_data(); only
            ``data[0]`` is read. Each model entry is expected to hold
            ``['correlation_analysis']['pfi_scores']``.
        models: Mapping whose keys are the model keys to extract. Defaults to the
            module-level ``model_names``. Its key order fixes the order of the
            returned dict, and therefore the row/column order of the correlation
            matrix built from it. Every dataset is then laid out identically,
            regardless of the key order inside its JSON file.

    Returns:
        dict of model key -> PFI score sequence, for the models present in ``data[0]``.

    Raises:
        ValueError: if ``data`` is empty (e.g. the results folder held no JSON files).
    """
    if models is None:
        models = model_names
    if not data:
        raise ValueError(
            "No result records to read PFI scores from; check that the results "
            "folder exists and contains JSON files."
        )
    record = data[0]
    return {
        key: record[key]['correlation_analysis']['pfi_scores']
        for key in models
        if key in record
    }

# Function to calculate Spearman rank correlations between models' PFI scores
def calculate_pfi_correlations(pfi_scores):
    """
    Compute the pairwise Spearman rank correlation between every pair of models' PFI scores.

    Args:
        pfi_scores: dict of model key -> PFI score sequence. Every key must appear
            in the module-level ``model_names`` mapping, which supplies the labels.
            Sequences are compared position by position, so they are assumed to
            list features in the same order (TODO confirm against the JSON writer).

    Returns:
        Square pandas DataFrame of correlations. Both axes are labelled with the
        display names from ``model_names``, in the iteration order of ``pfi_scores``.
    """
    def _rho(first, second):
        # spearmanr yields (statistic, p-value); only the statistic is kept.
        statistic, _ = spearmanr(first, second)
        return statistic

    keys = list(pfi_scores)
    size = len(keys)
    matrix = np.array(
        [[_rho(pfi_scores[a], pfi_scores[b]) for b in keys] for a in keys],
        dtype=float,
    ).reshape(size, size)

    labels = [model_names[k] for k in keys]
    return pd.DataFrame(matrix, index=labels, columns=labels)

# Function to create a heatmap for visualizing the correlations
def create_correlation_heatmap(correlation_df, dataset_name, output_dir='heatmaps'):
    """
    Save a lower-triangle heatmap of one dataset's PFI rank-correlation matrix.

    Args:
        correlation_df: Square DataFrame of correlations (e.g. from
            calculate_pfi_correlations()).
        dataset_name: Used in the plot title and, lower-cased, in the file name.
        output_dir: Directory the PNG is written to; created if missing.

    Returns:
        Path of the saved PNG: ``<output_dir>/<dataset_name lower>_pfi_correlation.png``.
    """
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        # np.triu with k=0 masks the diagonal as well as the upper triangle, so only
        # the strictly-lower (unique, off-diagonal) pairs are drawn.
        mask = np.triu(np.ones_like(correlation_df, dtype=bool))

        sns.heatmap(
            correlation_df,
            annot=True,             # Show correlation values
            mask=mask,              # Show only lower triangle
            cmap='coolwarm',        # Blue-red color map
            vmin=-1, vmax=1,        # Correlation ranges from -1 to 1
            square=True,            # Make cells square
            linewidths=0.5,         # Add lines between cells
            cbar_kws={"shrink": 0.8},
            fmt='.2f',              # Format correlation values to 2 decimal places
            ax=ax,
        )
        ax.set_title(f'Spearman Rank Correlation of PFI Scores\nDataset: {dataset_name}', fontsize=14)

        fig.tight_layout()
        filename = os.path.join(output_dir, f'{dataset_name.lower()}_pfi_correlation.png')
        fig.savefig(filename, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving fails, so repeated
        # calls in a loop don't accumulate open figures.
        plt.close(fig)

    return filename

# Main function to process all datasets
def analyze_all_datasets():
    """
    Build, plot and report the PFI rank-correlation matrix for every entry in `datasets`.

    For each name/records pair in the module-level ``datasets`` mapping:
    - extract the per-model PFI scores;
    - compute their pairwise Spearman correlations;
    - save a heatmap PNG;
    - print the matrix.

    Returns:
        dict mapping dataset name -> {'correlation_df': DataFrame, 'heatmap_file': str}.
    """
    summary = {}
    for name in datasets:
        records = datasets[name]
        print(f"Processing {name}...")

        corr = calculate_pfi_correlations(extract_pfi_scores(records))
        png_path = create_correlation_heatmap(corr, name)
        summary[name] = {'correlation_df': corr, 'heatmap_file': png_path}

        print(f"  Heatmap saved to {png_path}")
        print(f"\nCorrelation Matrix for {name}:")
        print(corr)
        print("\n" + "-" * 50 + "\n")

    return summary

# Function to generate a combined figure with all heatmaps
def create_combined_heatmap(all_results, output_dir='heatmaps'):
    """
    Draw every dataset's PFI correlation matrix as one panel of a shared figure.

    Args:
        all_results: Mapping dataset name -> dict with a 'correlation_df' entry
            (as returned by analyze_all_datasets()).
        output_dir: Directory the PNG is written to; created if missing.

    Returns:
        Path of the saved ``combined_pfi_correlation.png``.

    Raises:
        ValueError: if all_results is empty.
    """
    dataset_names = list(all_results)
    if not dataset_names:
        raise ValueError("all_results is empty; nothing to plot.")

    os.makedirs(output_dir, exist_ok=True)

    # Near-square grid: ceil(sqrt(n)) columns, and only as many rows as needed
    # (e.g. 5 datasets -> 3x2 instead of 3x3 with four empty panels).
    n_cols = int(np.ceil(np.sqrt(len(dataset_names))))
    n_rows = int(np.ceil(len(dataset_names) / n_cols))
    # squeeze=False always returns a 2-D array of Axes, even for a single panel.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    try:
        for ax, dataset_name in zip(flat_axes, dataset_names):
            sns.heatmap(
                all_results[dataset_name]['correlation_df'],
                annot=True,
                cmap='coolwarm',
                vmin=-1, vmax=1,
                square=True,
                linewidths=0.5,
                ax=ax,
                cbar=False,   # a single shared colorbar is added below
                fmt='.2f'
            )
            ax.set_title(dataset_name)

        # Hide any unused subplots
        for ax in flat_axes[len(dataset_names):]:
            ax.axis('off')

        fig.suptitle('Spearman Rank Correlation of PFI Scores Across Datasets', fontsize=16)
        # Lay out the subplot grid first, leaving the right 10% free. Only then add
        # the colorbar axes: an axes created with fig.add_axes before tight_layout is
        # not tight_layout-compatible and triggers a UserWarning.
        fig.tight_layout(rect=[0, 0, 0.9, 0.95])

        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
        sm = plt.cm.ScalarMappable(cmap='coolwarm', norm=plt.Normalize(-1, 1))
        sm.set_array([])
        fig.colorbar(sm, cax=cbar_ax)

        combined_file = os.path.join(output_dir, 'combined_pfi_correlation.png')
        fig.savefig(combined_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    return combined_file

# Function to create a LaTeX table for the paper
def create_latex_table(all_results):
    # Extract average correlation for each model pair across datasets
    model_pairs = []
    correlations = []

    # Get unique model pairs
    first_dataset = list(all_results.keys())[0]
    df = all_results[first_dataset]['correlation_df']
    models = df.index.tolist()

    # Lower triangular part only (unique pairs)
    for i in range(len(models)):
        for j in range(i):
            model_pairs.append((models[j], models[i]))

    # Collect correlations for each dataset
    dataset_correlations = {}
    for dataset, result in all_results.items():
        df = result['correlation_df']
        dataset_correlations[dataset] = []

        for model1, model2 in model_pairs:
            dataset_correlations[dataset].append(df.loc[model1, model2])

    # Create a DataFrame for the LaTeX table
    table_df = pd.DataFrame(dataset_correlations, index=[f"{m1} vs {m2}" for m1, m2 in model_pairs])

    # Add an Average column
    table_df['Average'] = table_df.mean(axis=1)

    # Generate LaTeX table
    latex_table = table_df.round(2).to_latex(escape=False)

    # Save the LaTeX table to a file
    with open('pfi_correlation_table.tex', 'w') as f:
        f.write(latex_table)

    return latex_table

# Run the analysis. Inside a notebook kernel __name__ is "__main__", so this block
# executes every time the cell is run.
if __name__ == "__main__":
    all_results = analyze_all_datasets()

    # One figure with every dataset's matrix as a separate panel.
    combined_file = create_combined_heatmap(all_results)
    print(f"Combined heatmap saved to {combined_file}")

    # LaTeX table of the pairwise correlations, also written to disk by the helper.
    latex_table = create_latex_table(all_results)
    print("\nLaTeX Table for the paper:", latex_table, sep="\n")



Processing California...
  Heatmap saved to heatmaps/california_pfi_correlation.png

Correlation Matrix for California:
                  FT-Linear    FT-PLE  FT-Sparse-Linear  FT-Sparse-PLE
FT-Linear          1.000000  0.952381          1.000000       0.952381
FT-PLE             0.952381  1.000000          0.952381       1.000000
FT-Sparse-Linear   1.000000  0.952381          1.000000       0.952381
FT-Sparse-PLE      0.952381  1.000000          0.952381       1.000000

--------------------------------------------------

Processing Adult...
  Heatmap saved to heatmaps/adult_pfi_correlation.png

Correlation Matrix for Adult:
                  FT-Linear    FT-PLE  FT-Sparse-Linear  FT-Sparse-PLE
FT-Linear          1.000000  0.912088          0.916484       0.846154
FT-PLE             0.912088  1.000000          0.898901       0.885714
FT-Sparse-Linear   0.916484  0.898901          1.000000       0.907692
FT-Sparse-PLE      0.846154  0.885714          0.907692       1.000000

-----------

  plt.tight_layout(rect=[0, 0, 0.9, 0.95])


Combined heatmap saved to heatmaps/combined_pfi_correlation.png

LaTeX Table for the paper:
\begin{tabular}{lrrrrrr}
\toprule
 & California & Adult & Helena & Jannis & Higgs & Average \\
\midrule
FT-Linear vs FT-PLE & 0.950000 & 0.910000 & 0.870000 & 0.830000 & 0.800000 & 0.870000 \\
FT-Linear vs FT-Sparse-Linear & 1.000000 & 0.920000 & 0.680000 & 0.860000 & 0.880000 & 0.870000 \\
FT-PLE vs FT-Sparse-Linear & 0.950000 & 0.900000 & 0.690000 & 0.850000 & 0.840000 & 0.850000 \\
FT-Linear vs FT-Sparse-PLE & 0.950000 & 0.850000 & 0.860000 & 0.800000 & 0.890000 & 0.870000 \\
FT-PLE vs FT-Sparse-PLE & 1.000000 & 0.890000 & 0.900000 & 0.770000 & 0.880000 & 0.890000 \\
FT-Sparse-Linear vs FT-Sparse-PLE & 0.950000 & 0.910000 & 0.830000 & 0.790000 & 0.900000 & 0.880000 \\
\bottomrule
\end{tabular}



In [None]:
import os
import shutil

# The plotting cell saves heatmaps under the relative directory 'heatmaps'. Resolve
# it against the working directory rather than hard-coding '/content/heatmaps'.
source_dir = os.path.abspath('heatmaps')
destination_dir = '/content/drive/MyDrive/heatmaps'  # Replace with your desired destination

# copytree creates destination_dir (including missing parents). With
# dirs_exist_ok=True it merges into an existing folder, so no separate mkdir is needed.
shutil.copytree(source_dir, destination_dir, dirs_exist_ok=True)

print(f"Heatmaps folder copied to: {destination_dir}")


Heatmaps folder copied to: /content/drive/MyDrive/heatmaps
