# Plots

## Best Synthetic Cardinality for each Metric

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Data from tables.
# Every table below has one row per dataset (same order as `datasets`) and one
# column per synthetic cardinality (same order as `cardinalities`).
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet', 
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101'
]

# Cardinality values (column order of the tables) and their axis labels
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# AUC_MIA values (Table 5) - lower is better for privacy
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# AOP values (Table 6) - higher is better for combined accuracy/privacy.
# Every entry agrees, to within rounding, with CAS / (AUC_MIA / 50) ** 2 taken
# from the same cell of the CAS and AUC_MIA tables. Food101 @ 20x was listed
# as 62.54; that relation gives 86.93 / (58.90 / 50) ** 2 = 62.64, which is
# also the Food101 student AOP used in the Student-Teacher plot, so 62.64 is
# used here. NOTE(review): confirm against the paper's Table 6.
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.64],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# CAS values from previous table (for comparison) - higher is better
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]


# For every dataset, plot which synthetic cardinality gives the best value of
# each metric (AUC_MIA, AOP, CAS). Reads the tables defined above in this cell.

# Column index (into cardinality_labels) of the best entry in each table row.
# np.argmin / np.argmax return the first such index when values tie.
best_auc_indices = [int(np.argmin(row)) for row in auc_mia_values]  # Lower is better for AUC_MIA
best_aop_indices = [int(np.argmax(row)) for row in aop_values]      # Higher is better for AOP
best_cas_indices = [int(np.argmax(row)) for row in cas_values]      # Higher is better for CAS

# Per-dataset summary of the optimum (cardinality label and metric value).
dataset_optimal = [
    {
        'dataset': dataset,
        'best_auc_card': cardinality_labels[auc_i],
        'best_aop_card': cardinality_labels[aop_i],
        'best_cas_card': cardinality_labels[cas_i],
        'best_auc_val': auc_mia_values[i][auc_i],
        'best_aop_val': aop_values[i][aop_i],
        'best_cas_val': cas_values[i][cas_i],
    }
    for i, (dataset, auc_i, aop_i, cas_i) in enumerate(
        zip(datasets, best_auc_indices, best_aop_indices, best_cas_indices)
    )
]

# Use the default dataset ordering
default_order = datasets

# Bar positions: three bars per dataset, centred on x
x = np.arange(len(default_order))
width = 0.25

# Shift every bar up by `offset` so the smallest cardinality (index 0) still
# draws a visible bar instead of one of height 0; the y ticks are shifted by
# the same amount so each label lines up with the bar tops.
offset = 0.2
best_auc_indices_offset = [idx + offset for idx in best_auc_indices]
best_aop_indices_offset = [idx + offset for idx in best_aop_indices]
best_cas_indices_offset = [idx + offset for idx in best_cas_indices]

# A single figure only (a separate plt.figure() call here would leave an
# unused empty figure behind).
fig, ax = plt.subplots(figsize=(18, 8))

# Create bars for optimal cardinality by metric
bar1 = ax.bar(x - width, best_auc_indices_offset, width, label='Best AUC$_{MIA}$', color='green', alpha=0.7)
bar2 = ax.bar(x, best_aop_indices_offset, width, label='Best AOP', color='blue', alpha=0.7)
bar3 = ax.bar(x + width, best_cas_indices_offset, width, label='Best CAS', color='red', alpha=0.7)

# Y axis: one tick per cardinality, labelled with its multiplier
ax.tick_params(axis='y', which='both', left=True, labelleft=True)
ax.set_yticks([i + offset for i in range(len(cardinality_labels))])
ax.set_yticklabels(cardinality_labels)
ax.set_ylabel('Synthetic Dataset Cardinality', fontsize=16)

# X axis: dataset names in the default order
ax.set_xticks(x)
ax.set_xticklabels(default_order, fontsize=12, rotation=45)

# Legend for the three metrics and a horizontal grid
ax.legend(fontsize=12, loc='best')
ax.grid(True, axis='y', linestyle='--', alpha=0.7)
#plt.title('Optimal Cardinality for Each Metric', fontweight='bold', fontsize=22, pad=20)
fig.tight_layout()
fig.savefig('../images/metrics_best_cardinality.png', bbox_inches='tight', dpi=600) # do not move after plt.show() or it will save a blank image
plt.show()

## Student-Teacher Gap

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Extract data from the table
# Per-dataset results of the teacher and the student classifier; every list
# is aligned with `datasets`.
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet', 'StanfordCars',
    'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101',
]

# Accuracy (higher is better)
teacher_accuracy = [97.52, 85.49, 93.96, 75.67, 88.22, 86.79, 96.74, 93.05, 98.29, 92.26]
student_accuracy = [97.33, 85.74, 94.68, 76.22, 88.00, 86.93, 96.36, 92.52, 97.83, 92.62]

# AUC_MIA (lower is better)
teacher_auc_mia = [53.89, 70.32, 72.74, 70.29, 82.53, 65.63, 65.89, 58.67, 60.32, 67.81]
student_auc_mia = [52.81, 62.80, 64.60, 64.55, 79.44, 58.90, 58.38, 55.82, 55.32, 60.60]

# AOP (higher is better)
teacher_aop = [83.95, 43.22, 44.40, 38.29, 32.38, 50.37, 55.71, 67.58, 67.53, 50.16]
student_aop = [87.25, 54.35, 56.72, 45.73, 34.86, 62.64, 70.68, 74.23, 79.92, 63.05]

# Long-format table: one row per (metric, dataset), with the metrics stacked
# in the order Accuracy, AUC_MIA, AOP.
df = pd.DataFrame({
    'Dataset': datasets * 3,
    'Metric': [name for name in ('Accuracy', 'AUC_MIA', 'AOP') for _ in datasets],
    'Teacher': [*teacher_accuracy, *teacher_auc_mia, *teacher_aop],
    'Student': [*student_accuracy, *student_auc_mia, *student_aop],
})

# Student-minus-teacher gap; for AUC_MIA (lower is better) the sign is flipped
# to teacher-minus-student so a positive Difference always favours the student.
df['Difference'] = (df['Student'] - df['Teacher']).where(
    df['Metric'] != 'AUC_MIA',
    df['Teacher'] - df['Student'],
)

# Fix the category order used when plotting / grouping by metric
df['Metric'] = pd.Categorical(df['Metric'], categories=['Accuracy', 'AUC_MIA', 'AOP'])

# Set the style for all plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('muted')
colors = sns.color_palette()

###########################################################################################################################################

# One subplot per metric, each comparing teacher and student per dataset.
# (A separate plt.figure() call here would only leave an unused empty figure.)
fig, axes = plt.subplots(3, 1, figsize=(14, 18), sharex=True)

metrics = ['Accuracy', 'AUC_MIA', 'AOP']
titles = ['Classification Accuracy', 
          'AUC$_{MIA}$', 
          'AOP']
ylabels = ['Classification Accuracy', 'AUC$_{MIA}$', 'AOP']

# Use the default order (no sorting)
sorted_indices = {metric: np.arange(len(datasets)) for metric in metrics}

# (teacher values, student values) for each metric
metric_values = {
    'Accuracy': (teacher_accuracy, student_accuracy),
    'AUC_MIA': (teacher_auc_mia, student_auc_mia),
    'AOP': (teacher_aop, student_aop),
}
# Metrics for which a lower value is better (flips the annotation colour)
lower_is_better = {'AUC_MIA'}

width = 0.30

for ax, metric, title, ylabel in zip(axes, metrics, titles, ylabels):
    # Select this metric's values in the chosen dataset order
    idx = sorted_indices[metric]
    teacher_all, student_all = metric_values[metric]
    teacher_vals = [teacher_all[j] for j in idx]
    student_vals = [student_all[j] for j in idx]
    datasets_sorted = [datasets[j] for j in idx]

    # Side-by-side bars: teacher left of each tick, student right of it
    pos = np.arange(len(datasets_sorted))
    ax.bar(pos - width/2, teacher_vals, width, label='Teacher Classifier',
           color=colors[0], alpha=0.8, edgecolor='black', linewidth=0.5)
    ax.bar(pos + width/2, student_vals, width, label='Student Classifier',
           color=colors[1], alpha=0.8, edgecolor='black', linewidth=0.5)

    # Annotate each pair with the signed student-minus-teacher difference,
    # always shown with an explicit sign. Green marks a student improvement:
    # a decrease for lower-is-better metrics, an increase otherwise.
    diffs = np.array(student_vals) - np.array(teacher_vals)
    improved = diffs < 0 if metric in lower_is_better else diffs > 0
    for p, t_val, s_val, diff, better in zip(pos, teacher_vals, student_vals, diffs, improved):
        ax.annotate(f"{diff:+.2f}",
                    xy=(p, max(t_val, s_val) + 0.5),
                    ha='center', va='bottom',
                    color='green' if better else 'red', fontweight='bold', fontsize=13)

    # Labels, title and x-axis tick labels
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_title(title, fontsize=20, fontweight='bold')
    ax.set_xticks(pos)
    ax.set_xticklabels(datasets_sorted, rotation=45, ha='right', fontsize=13)

    # Add grid for easier reading
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Add a horizontal line for reference
    if metric == 'AUC_MIA':
        ax.axhline(y=50, color='gray', linestyle='--', alpha=0.5)  # 50% is random guess

    ax.legend(fontsize=10)

# tight_layout computes the vertical spacing between subplots itself, so an
# earlier subplots_adjust(hspace=...) would be overridden anyway.
fig.tight_layout()
# Save before plt.show(). NOTE(review): this writes to the working directory,
# whereas metrics_best_cardinality.png goes to ../images/ — confirm intent.
fig.savefig('student_vs_teacher.png', dpi=600, bbox_inches='tight')
plt.show()