In [2]:
from tabulate import tabulate

In [3]:
import wandb
import pandas as pd
import numpy as np

# W&B public-API client used to query the logged runs.
api = wandb.Api()

# "<entity>/<project>" path of the W&B project holding the experiments.
project = "oualidzari/Link Inference In FL"

# Server-side run filters on config values: only Cora runs of the label-defense experiment.
filters = dict()
filters["config.dataset"] = "Cora"
filters["config.experiment_comment"] = "Label defense"

# Runs matching the filters; later cells iterate and index into this collection.
runs = api.runs(project, filters=filters)



In [11]:
def _logged_value(history: pd.DataFrame, column: str, position: int):
    """Return the `position`-th non-null value logged under `column`.

    Args:
        history: a run's history frame (one column per logged metric).
        column: metric name to read.
        position: index into the non-null values (0 = first logged, -1 = last logged).

    Returns:
        The selected value, or None when the column is absent or was never logged.
    """
    if column not in history.columns:
        return None
    # Metrics logged at only some steps leave NaN in the other rows; skip those.
    logged = history[column].dropna()
    if logged.empty:
        return None
    return logged.iloc[position]


# Collect one summary record per run: its label-defense budget plus the
# attack / utility metrics read from the run's logged history.
data = []
for run in runs:
    name = run.name
    budget = run.config.get("label_defense_budget")

    # Fetch history
    history = run.history()
    # Attack accuracies: first value logged for each metric (in the captured
    # history, Accuracy-labels is only logged at step 0).
    accuracy_gradient = _logged_value(history, "Accuracy-gradients", 0)
    accuracy_labels = _logged_value(history, "Accuracy-labels", 0)
    # Final test accuracy: last logged value. A fixed row position (e.g. 299) is
    # fragile — runs may log fewer rows, and run.history() returns a sampled frame.
    test_accuracy = _logged_value(history, "accuracy_test", -1)

    data.append(
        {
            "name": name,
            "budget": budget,
            "acc_grad": accuracy_gradient,
            "acc_lab": accuracy_labels,
            "test_acc": test_accuracy,
        }
    )

In [12]:
# Tabulate the per-run records: one row per run, one column per record key.
df = pd.DataFrame(data=data)

In [13]:
# Inspect the logged history of the first matched run to see which metric columns exist.
first_run = runs[0]
first_run.history()

Unnamed: 0,accuracy_train,accuracy_test,AUC-output_server,loss,Accuracy-labels,_step,_runtime,Accuracy-features,AUC-features,epoch,AUC-forward_values,Accuracy-output_server,AUC-gradients,Accuracy-gradients,_timestamp,Accuracy-forward_values
0,0.132939,0.123846,0.669087,1.957257,0.001538,0,7.699032,0.710482,0.725107,0,0.557274,0.626069,0.570866,0.603226,1.722929e+09,0.568500
1,0.320532,0.279231,0.533283,1.818003,,1,8.434042,,,1,0.566263,0.543901,0.515573,0.498032,1.722929e+09,0.562628
2,0.320532,0.279231,0.542492,1.574450,,2,9.180580,,,2,0.572809,0.517524,0.526647,0.555532,1.722929e+09,0.575880
3,0.320532,0.279231,0.544418,1.112610,,3,9.885854,,,3,0.573756,0.512208,0.484646,0.519537,1.722929e+09,0.584864
4,0.320532,0.279231,0.540090,0.523473,,4,10.710140,,,4,0.575533,0.504233,0.473640,0.581472,1.722929e+09,0.551202
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,0.320532,0.279231,0.558502,0.000299,,295,193.622947,,,295,0.561659,0.675260,0.473897,0.482551,1.722929e+09,0.502703
296,0.320532,0.279231,0.561454,0.000299,,296,194.309980,,,296,0.561575,0.585928,0.473971,0.482553,1.722929e+09,0.504871
297,0.320532,0.279231,0.555526,0.000299,,297,195.026480,,,297,0.561386,0.679267,0.473698,0.482341,1.722929e+09,0.501860
298,0.320532,0.279231,0.562741,0.000299,,298,195.770182,,,298,0.561212,0.582897,0.473749,0.482118,1.722929e+09,0.500827


In [16]:
# Aggregate the per-run metrics by defense budget and render a "mean ± std" table in percent.
METRICS = ["acc_grad", "acc_lab", "test_acc"]

# Coerce metrics to float so runs where a metric was never logged (None) become NaN;
# an all-None column would otherwise be object dtype and fail to aggregate.
grouped_stats = (
    df.astype({metric: float for metric in METRICS})
    .groupby("budget")[METRICS]
    .agg(["mean", "std"])
)


def format_mean_std(row: pd.Series) -> str:
    """Format a row holding 'mean' and 'std' fractions as a percentage string.

    Args:
        row: a Series with 'mean' and 'std' entries (fractions in [0, 1]).

    Returns:
        "mean ± std" with two decimals. If the mean is NaN (metric not logged for
        this budget) the result is "n/a". If only the std is NaN (pandas' sample std
        is undefined for a single run) the mean is shown alone.
    """
    mean, std = row["mean"], row["std"]
    if pd.isna(mean):
        return "n/a"
    if pd.isna(std):
        return f"{mean * 100:.2f}"
    return f"{mean * 100:.2f} ± {std * 100:.2f}"


# One formatted column per metric, indexed by budget; reset_index turns budget into a column.
formatted_stats = pd.DataFrame(
    {metric: grouped_stats[metric].apply(format_mean_std, axis=1) for metric in METRICS},
    index=grouped_stats.index,
).reset_index()
# Budgets are fractions; round before casting so float error (0.29 * 100 = 28.999...)
# does not truncate to the wrong integer percentage.
formatted_stats["budget"] = (formatted_stats["budget"] * 100).round().astype(int)

# Create a formatted table without row dividers
table = tabulate(
    formatted_stats,
    headers=["budget(%)", *METRICS],
    tablefmt="simple",
    showindex=False,
)

print(table)

  budget(%)  acc_grad      acc_lab       test_acc
-----------  ------------  ------------  ------------
          5  78.36 ± 0.54  78.74 ± 0.56  74.00 ± 2.42
         10  75.64 ± 1.49  75.32 ± 0.80  72.03 ± 1.00
         20  67.85 ± 1.99  67.49 ± 1.03  68.23 ± 1.59
         30  63.50 ± 3.36  57.33 ± 0.83  58.28 ± 0.75
         40  60.65 ± 5.94  45.44 ± 0.89  53.03 ± 1.27
         50  61.76 ± 4.92  32.51 ± 1.13  45.13 ± 0.51
         60  58.97 ± 1.69  17.65 ± 0.89  34.08 ± 1.80
         70  59.00 ± 3.00  0.14 ± 0.01   28.74 ± 1.36
         80  59.00 ± 3.00  0.14 ± 0.01   28.74 ± 1.36
         90  59.00 ± 3.00  0.14 ± 0.01   28.74 ± 1.36
