# Plots

## Heatmap of CAS for 3 different prompt/label configurations

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Axis labels shared by every panel: rows are datasets, columns are sample ratios.
datasets = ['CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet']
sample_ratios = ['1/50', '1/20', '1/10', '1']

# One heatmap panel per prompt/label configuration.
config_names = ['"n: d" - Hard Labels', '"n: c" - Hard Labels', '"n: c" - Soft Labels']

# CAS tables, one per configuration (rows follow `datasets`, columns follow `sample_ratios`).
config1_data = np.array([
    [71.19, 75.27, 80.43, 80.52],  # CIFAR10
    [32.91, 37.48, 45.43, 47.37],  # CIFAR100
    [2.98, 2.52, 3.79, 2.43],      # Oxford-IIIT-Pet
    [17.64, 20.15, 20.79, 32.04]   # TinyImageNet
])

config2_data = np.array([
    [78.92, 82.76, 82.52, 81.52],  # CIFAR10
    [39.14, 47.19, 47.13, 50.09],  # CIFAR100
    [3.85, 2.07, 2.01, 2.22],      # Oxford-IIIT-Pet
    [30.37, 30.42, 31.34, 35.57]   # TinyImageNet
])

config3_data = np.array([
    [86.97, 87.79, 88.09, 90.28],  # CIFAR10
    [63.29, 66.13, 66.74, 71.17],  # CIFAR100
    [30.72, 65.86, 69.22, 79.98],  # Oxford-IIIT-Pet
    [63.36, 62.97, 62.78, 66.65]   # TinyImageNet
])

all_data = [config1_data, config2_data, config3_data]

# Sequential colormap; all panels share the fixed 0-100 colour scale set below.
cmap = sns.color_palette("YlGnBu", as_cmap=True)

# One row of three panels.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

for ax, data, name in zip(axes, all_data, config_names):
    # Draw the coloured grid only; the numbers are written by hand below.
    sns.heatmap(
        data,
        annot=False,
        cmap=cmap,
        xticklabels=sample_ratios,
        yticklabels=datasets,
        ax=ax,
        vmin=0,
        vmax=100,
    )

    # Heatmap cell (row, col) spans [col, col+1] x [row, row+1], so +0.5 is its centre.
    for (row, col), value in np.ndenumerate(data):
        ax.text(
            col + 0.5,
            row + 0.5,
            f'{value:.2f}',
            horizontalalignment='center',
            verticalalignment='center',
            fontsize=15,
        )

    ax.set_title(name, fontsize=20, fontweight='bold')
    ax.set_xlabel('Fine-Tuning Samples Ratio', fontsize=16)
    # Only the leftmost panel carries the y-axis label.
    if ax is axes[0]:
        ax.set_ylabel('Dataset', fontsize=16)
    ax.tick_params(axis='both', labelsize=11)

plt.tight_layout()
plt.show()

In [None]:
# Write the heatmap figure to disk as a 600-dpi PNG. This relies on `fig` still
# being bound to the figure created by `plt.subplots(1, 3, ...)` in the heatmap
# cell above, so that cell must have been run first.
fig.savefig('../images/impact_blip2_gkd_heatmap.png', bbox_inches='tight', dpi=600)

### Bar chart of top CAS for each of those 3 configurations

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data from Table 1 - best CAS per dataset for each configuration.
# NOTE: "best" means the maximum over ALL fine-tuning sample ratios, not the
# ratio=1 column: e.g. 3.79 (config 1, Oxford-IIIT-Pet) is the 1/10 value,
# 82.76 (config 2, CIFAR10) the 1/20 value and 3.85 (config 2, Oxford-IIIT-Pet)
# the 1/50 value in the heatmap tables earlier in this file.
datasets = ['CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet']
config_names = ['"n: d" - Hard Labels', '"n: c" - Hard Labels', '"n: c" - Soft Labels']

# Best scores for each configuration (one entry per dataset, in `datasets` order)
config1_best = [80.52, 47.37, 3.79, 32.04]  # "n: d" - Hard Labels
config2_best = [82.76, 50.09, 3.85, 35.57]  # "n: c" - Hard Labels
config3_best = [90.28, 71.17, 79.98, 66.65]  # "n: c" - Soft Labels

# Width of a single bar; three bars per dataset group
barWidth = 0.25

# X positions of the three bar series: each series is shifted one bar width
# to the right of the previous one
r1 = np.arange(len(datasets))
r2 = r1 + barWidth
r3 = r2 + barWidth

# Create the figure with larger size for better visibility
plt.figure(figsize=(12, 6))

# Draw one bar series per configuration
bars1 = plt.bar(r1, config1_best, width=barWidth, edgecolor='grey', label=config_names[0], color='skyblue')
bars2 = plt.bar(r2, config2_best, width=barWidth, edgecolor='grey', label=config_names[1], color='lightgreen')
bars3 = plt.bar(r3, config3_best, width=barWidth, edgecolor='grey', label=config_names[2], color='salmon')
 
# Add values on top of bars
def add_labels(bars):
    """Write each bar's height (two decimals) centred just above the bar.

    `bars` is an iterable of bar patches, such as the container returned by
    `plt.bar`. The text is drawn on the current pyplot axes.
    """
    for rect in bars:
        value = rect.get_height()
        x_center = rect.get_x() + rect.get_width()/2.
        # va='bottom' anchors the text's lower edge at the bar top.
        plt.text(x_center, value, f'{value:.2f}',
                 ha='center', va='bottom', fontsize=12)

# Value labels for all three bar series
for bar_group in (bars1, bars2, bars3):
    add_labels(bar_group)

# Axis labels (the figure title is deliberately disabled)
plt.xlabel('Dataset', fontsize=17)
plt.ylabel('CAS', fontsize=17)
#plt.title('Comparison of Best CAS Across Different Configurations', fontweight='bold', fontsize=23, pad=20)

# Dataset names go under the middle bar of each group (first bar + one bar width)
group_centers = np.arange(len(datasets)) + barWidth
plt.xticks(group_centers, datasets, fontsize=14)

# Legend below the axes, laid out in a single row
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=3, fontsize=12)

# Adjust layout and save
plt.tight_layout(rect=[0, 0, 1, 0.95])
# Save BEFORE plt.show(): saving afterwards would write a blank image
plt.savefig('../images/impact_blip2_gkd_chart.png', bbox_inches='tight', dpi=600)
plt.show()

## CAS vs Synthetic Cardinality

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec


datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

# Cardinality values (numeric form is not used by the plot below; the labels are)
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# CAS values for each dataset at different cardinalities
# (rows follow `datasets`, columns follow `cardinality_labels`)
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]

# Create a figure using GridSpec for the small multiples (2 rows x 5 columns)
fig = plt.figure(figsize=(18, 8))
gs_small = gridspec.GridSpec(2, 5, figure=fig, wspace=0.4, hspace=0.5)

# Evenly spaced numeric x positions, one per cardinality label
x = np.arange(len(cardinality_labels))
colors = plt.cm.tab20(np.linspace(0, 1, len(datasets)))

# Annotation offsets in points; identical for every subplot, so defined once
vertical_fixed_offset = 15
horizontal_fixed_offset = 6

for i, (dataset, values) in enumerate(zip(datasets, cas_values)):
    # Row and column of this dataset in the 2x5 grid
    ax = fig.add_subplot(gs_small[i // 5, i % 5])

    # Absolute CAS change from the smallest to the largest cardinality
    abs_improvement = values[-1] - values[0]

    # Plot the CAS curve using numeric x positions
    ax.plot(x, values, marker='o', linewidth=2, color=colors[i])

    # Annotate each vertex with its CAS value
    y_min, y_max = min(values), max(values)
    for xi, val in zip(x, values):
        # Points in the top quarter of the curve's range get their label
        # below the marker, all others above it
        if val > y_min + 0.75 * (y_max - y_min):
            vertical_offset = -vertical_fixed_offset
        else:
            vertical_offset = vertical_fixed_offset
        # Nudge the first/last labels inwards, away from the subplot edges
        if xi == x[0]:
            horizontal_offset = horizontal_fixed_offset
        elif xi == x[-1]:
            horizontal_offset = -horizontal_fixed_offset
        else:
            horizontal_offset = 0

        ax.annotate(f"{val:.2f}", xy=(xi, val),
                    xytext=(horizontal_offset, vertical_offset),
                    textcoords='offset points', ha='center', va='center', fontsize=7)

    # Star the highest CAS. Derived from the data rather than special-casing a
    # dataset by name, so it stays correct if the table changes (for the values
    # above this is the last point everywhere except Caltech101, where it is
    # the second-to-last).
    best_idx = int(np.argmax(values))
    ax.scatter(x[best_idx], values[best_idx], marker='*', color=colors[i],
               s=200, zorder=5)

    # Bold dataset name (mathtext) with the absolute improvement on a second line
    ax.set_title(r"$\mathbf{" + dataset + "}$" + f"\n(+{abs_improvement:.2f})", fontsize=12)

    # Pad the y-range by 5 on each side, clamped to [0, 100], plus 3 extra on top
    min_val = max(0, y_min - 5)
    max_val = min(100, y_max + 5)
    ax.set_ylim(min_val, max_val + 3)

    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(cardinality_labels, fontsize=10)

    # Only add x-axis label for bottom row subplots
    if i >= 5:
        ax.set_xlabel('Cardinality', fontsize=10)
    # Only add y-axis label for the leftmost column subplots
    if i % 5 == 0:
        ax.set_ylabel('CAS', fontsize=10)

plt.tight_layout(rect=[0, 0, 1, 0.95])

# Save the figure in high resolution. Done before plt.show(), like the other
# cells in this file, so the export never depends on the figure surviving show().
fig.savefig("../images/cas_vs_cardinality.png", bbox_inches='tight', dpi=600)
plt.show()

## CAS Improvement Rate for Increasing Synthetic Cardinality

In [None]:
import matplotlib.pyplot as plt
import numpy as np

datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101'
]

# Cardinality values and their display labels
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# CAS per dataset (rows) at each cardinality (columns)
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]


plt.figure(figsize=(12, 6))

# Step-to-step CAS gain for every dataset: each row holds the differences
# between consecutive cardinality levels (5 transitions for 6 levels)
improvements = [
    [after - before for before, after in zip(row, row[1:])]
    for row in cas_values
]

# Mean and standard deviation across datasets for each transition
avg_improvements = np.mean(improvements, axis=0)
std_improvements = np.std(improvements, axis=0)

# One label per consecutive pair of cardinality levels
transition_labels = ["0.1× → 0.2×", "0.2× → 1×", "1× → 5×", "5× → 10×", "10× → 20×"]

# Bars show the mean gain; error bars show the standard deviation
plt.bar(
    transition_labels,
    avg_improvements,
    yerr=std_improvements,
    alpha=0.7,
    capsize=5,
    color='royalblue',
)
plt.axhline(y=0, color='k', linestyle='-', alpha=0.3)
plt.grid(True, axis='y', linestyle='--', alpha=0.5)
plt.ylabel('Average CAS Improvement', fontsize=14)
plt.xlabel('Synthetic Dataset Cardinality Transition', fontsize=14)
#plt.title('Diminishing CAS Improvement (%) with Increasing Synthetic Dataset Cardinality', fontweight='bold', fontsize=17, pad=20)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

plt.tight_layout()
# Must stay ahead of plt.show(): saving afterwards would write a blank image
plt.savefig('../images/cas_improvement_rate.png', bbox_inches='tight', dpi=600)
plt.show()

## AUC & AOP vs Synthetic Cardinality

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec


datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# AUC_MIA per dataset (rows, `datasets` order) and cardinality (columns).
# The lowest value of each row is starred in the plot.
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# AOP per dataset (rows) and cardinality (columns).
# The highest value of each row is starred in the plot.
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.54],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# Annotation offsets (in points) shared by every panel
vertical_fixed_offset = 15   # distance of a value label above/below its marker
horizontal_fixed_offset = 5  # inward nudge for the first and last labels


def _draw_metric_panel(ax, x_positions, curve, best_idx, color, *, marker, linestyle, label):
    """Draw one annotated metric curve on ``ax``.

    Plots ``curve`` against ``x_positions``, writes each value next to its
    marker, stars the point at ``best_idx``, sets padded y-limits, enables the
    grid and places x ticks at ``x_positions``. Tick labels, title and legend
    are left to the caller because they differ between the two panel types.
    """
    ax.plot(x_positions, curve, marker=marker, linestyle=linestyle, linewidth=2,
            color=color, label=label)

    # Label placement heuristic kept exactly as originally written. It compares
    # a data value shifted by an offset given in points against the curve
    # minimum, i.e. it mixes units; algebraically it reduces to
    # "value < min(curve) + 3": such labels go above the marker, all others below.
    lower_bound = min(curve) - (vertical_fixed_offset - 3)
    for idx, xi in enumerate(x_positions):
        if xi == x_positions[0]:
            horizontal_offset = horizontal_fixed_offset
        elif xi == x_positions[-1]:
            horizontal_offset = -horizontal_fixed_offset
        else:
            horizontal_offset = 0

        offset = -vertical_fixed_offset
        if (curve[idx] + offset) < lower_bound:
            offset = vertical_fixed_offset

        ax.annotate(f"{curve[idx]:.2f}", xy=(xi, curve[idx]),
                    xytext=(horizontal_offset, offset),
                    textcoords='offset points', ha='center', va='center',
                    fontsize=7, color='black')

    ax.scatter(x_positions[best_idx], curve[best_idx], marker='*', color=color,
               s=200, zorder=5)

    # Pad the y-range by 5 on each side, clamped to [0, 100], plus 3 extra on top
    min_val = max(0, min(curve) - 5)
    max_val = min(100, max(curve) + 5)
    ax.set_ylim(min_val, max_val + 3)

    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(x_positions)


# Create a master figure for small multiples in a 2x5 grid
fig = plt.figure(figsize=(18, 8))
gs_small = gridspec.GridSpec(2, 5, figure=fig, wspace=0.3, hspace=0.5)
x = np.arange(len(cardinality_labels))  # Numeric positions for the x-axis
colors = plt.cm.tab20(np.linspace(0, 1, len(datasets)))

for i, dataset in enumerate(datasets):
    # Each grid cell is split vertically: AUC_MIA on top, AOP below
    outer_spec = gs_small[i // 5, i % 5]
    gs_nested = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=outer_spec, hspace=0.1)

    # Top panel: AUC_MIA, best point = minimum
    ax_auc = fig.add_subplot(gs_nested[0])
    auc_curve = auc_mia_values[i]
    best_idx_auc = np.argmin(auc_curve)
    _draw_metric_panel(ax_auc, x, auc_curve, best_idx_auc, colors[i],
                       marker='o', linestyle='-', label=r'$AUC_{MIA}$')
    ax_auc.set_title(r"$\mathbf{" + dataset + "}$", fontsize=12)
    ax_auc.set_xticklabels([])  # Remove x-tick labels for the top plot
    ax_auc.legend(loc='upper left', fontsize=10)

    # Bottom panel: AOP, best point = maximum
    ax_aop = fig.add_subplot(gs_nested[1])
    aop_curve = aop_values[i]
    best_idx_aop = np.argmax(aop_curve)
    _draw_metric_panel(ax_aop, x, aop_curve, best_idx_aop, colors[i],
                       marker='s', linestyle='--', label='AOP')
    ax_aop.set_xticklabels(cardinality_labels, fontsize=12)
    ax_aop.legend(loc='upper left', fontsize=10)

    # Add x-label only on the bottom row of cells (i >= 5 in a 2x5 grid)
    if i >= 5:
        ax_aop.set_xlabel('Cardinality', fontsize=10)

plt.tight_layout(rect=[0, 0.07, 1, 0.95])

# Save the figure in high resolution. Done before plt.show(), like the other
# cells in this file, so the export never depends on the figure surviving show().
fig.savefig("../images/privacy_vs_cardinality.png", bbox_inches='tight', dpi=600)
plt.show()

## Radar charts for metrics balance

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

# Cardinality values and their display labels
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# AUC_MIA values - lower is better for privacy
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# AOP values - higher is better for combined accuracy/privacy
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.54],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# CAS values for comparison
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]

# Radar charts laid out in a 2x5 grid, one per dataset
fig = plt.figure(figsize=(20, 10))

# One viridis colour per cardinality level, reused by the curves and the legend
line_colors = [
    plt.cm.viridis(j/len(cardinality_labels))
    for j in range(len(cardinality_labels))
]

# Three equally spaced spokes (order: CAS, AUC, AOP); the first angle is
# repeated at the end so that each polygon closes
angles = np.linspace(0, 2*np.pi, 3, endpoint=False).tolist()
angles += angles[:1]

panels = zip(datasets, cas_values, auc_mia_values, aop_values)
for slot, (dataset_name, cas_row, auc_row, aop_row) in enumerate(panels, start=1):
    ax = fig.add_subplot(2, 5, slot, projection='polar')

    # One closed polygon per cardinality, built from the three metrics at that level
    per_level = zip(cardinality_labels, line_colors, cas_row, auc_row, aop_row)
    for cardinality, color, *metrics in per_level:
        ring = [*metrics, metrics[0]]
        ax.plot(angles, ring, linewidth=2, label=f"{cardinality}", color=color)
        ax.fill(angles, ring, alpha=0.1, color=color)

    # Spoke labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(['CAS', 'AUC', 'AOP'])

    # Panel title
    ax.set_title(dataset_name, fontweight='bold', size=14, pad=20)

    # Fixed 0-100 radial scale with rings every 20
    ax.set_ylim(0, 100)
    ax.grid(True)
    ax.set_rgrids(np.linspace(20, 100, 5), angle=45, fontsize=8)

# Single legend for the whole figure, one entry per cardinality
handles = [
    Line2D([0], [0], color=color, linewidth=2, label=level)
    for level, color in zip(cardinality_labels, line_colors)
]
fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.07),
           ncol=6, fontsize=12)

plt.tight_layout(rect=[0, 0.1, 1, 0.95], h_pad=3)
# Must stay ahead of plt.show(): saving afterwards would write a blank image
plt.savefig('../images/metrics_balance_radar.png', bbox_inches='tight', dpi=600)
plt.show()

## Best Synthetic Cardinality for each Metric

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Bar colours taken from matplotlib's 'tab10' palette
palette = plt.get_cmap('tab10').colors

# Data from tables
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101'
]

# Cardinality values
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# AUC_MIA values (Table 5) - lower is better for privacy
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# AOP values (Table 6) - higher is better for combined accuracy/privacy
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.54],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# CAS values from previous table (for comparison)
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]

# NOTE: the original cell also called plt.figure(figsize=(14, 8)) here. That
# figure was never drawn on (plt.subplots below creates the one actually used),
# so it only produced a stray empty figure; it has been removed.

# For every dataset, find the cardinality index that optimises each metric.
# The index lists feed the bar heights directly; `dataset_optimal` keeps the
# same per-dataset summary records as before.
best_auc_indices = []
best_aop_indices = []
best_cas_indices = []
dataset_optimal = []
for i, dataset in enumerate(datasets):
    best_auc_idx = int(np.argmin(auc_mia_values[i]))  # Lower is better for AUC_MIA
    best_aop_idx = int(np.argmax(aop_values[i]))      # Higher is better for AOP
    best_cas_idx = int(np.argmax(cas_values[i]))      # Higher is better for CAS

    best_auc_indices.append(best_auc_idx)
    best_aop_indices.append(best_aop_idx)
    best_cas_indices.append(best_cas_idx)

    dataset_optimal.append({
        'dataset': dataset,
        'best_auc_card': cardinality_labels[best_auc_idx],
        'best_aop_card': cardinality_labels[best_aop_idx],
        'best_cas_card': cardinality_labels[best_cas_idx],
        'best_auc_val': auc_mia_values[i][best_auc_idx],
        'best_aop_val': aop_values[i][best_aop_idx],
        'best_cas_val': cas_values[i][best_cas_idx]
    })

# Use the default dataset ordering
default_order = datasets

# Group positions and bar width (three bars per dataset)
x = np.arange(len(default_order))
width = 0.25

fig, ax = plt.subplots(figsize=(18, 8))

# Bar height = cardinality index + a small constant, so that index 0 ('0.1×')
# is still drawn as a visible bar instead of having zero height
offset = 0.2
best_auc_indices_offset = [val + offset for val in best_auc_indices]
best_aop_indices_offset = [val + offset for val in best_aop_indices]
best_cas_indices_offset = [val + offset for val in best_cas_indices]

# Create bars for optimal cardinality by metric
bar1 = ax.bar(x - width, best_auc_indices_offset, width, label='Best AUC$_{MIA}$', color=palette[0], alpha=0.7)
bar2 = ax.bar(x, best_aop_indices_offset, width, label='Best AOP', color=palette[1], alpha=0.7)
bar3 = ax.bar(x + width, best_cas_indices_offset, width, label='Best CAS', color=palette[2], alpha=0.7)

# y ticks sit at the (offset) bar tops and are labelled with the cardinalities
ax.tick_params(axis='y', which='both', left=True, labelleft=True)
ax.set_yticks([i + offset for i in range(len(cardinality_labels))])
ax.set_yticklabels(cardinality_labels)
ax.set_ylabel('Synthetic Dataset Cardinality', fontsize=16)

# Add x-axis ticks with default dataset names
ax.set_xticks(x)
ax.set_xticklabels(default_order, fontsize=14, rotation=45)

# Legend for the three metrics
ax.legend(fontsize=12, loc='best')

# Add grid
ax.grid(True, axis='y', linestyle='--', alpha=0.7)
#plt.title('Optimal Cardinality for Each Metric', fontweight='bold', fontsize=22, pad=20)
plt.tight_layout()
plt.savefig('../images/metrics_best_cardinality.png', bbox_inches='tight', dpi=600) # do not move after plt.show() or it will save a blank image
plt.show()

## Student-Teacher Gap

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Extract data from the table (one entry per dataset, in `datasets` order)
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet', 'StanfordCars',
    'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101'
]

# Accuracy (higher is better)
teacher_accuracy = [97.52, 85.49, 93.96, 75.67, 88.22, 86.79, 96.74, 93.05, 98.29, 92.26]
student_accuracy = [97.33, 85.74, 94.68, 76.22, 88.00, 86.93, 96.36, 92.52, 97.83, 92.62]

# AUC_MIA (lower is better)
teacher_auc_mia = [53.89, 70.32, 72.74, 70.29, 82.53, 65.63, 65.89, 58.67, 60.32, 67.81]
student_auc_mia = [52.81, 62.80, 64.60, 64.55, 79.44, 58.90, 58.38, 55.82, 55.32, 60.60]

# AOP (higher is better)
# NOTE(review): the Food101 student AOP is 62.64 here, while the AOP-vs-cardinality
# tables elsewhere in this file list 62.54 for Food101 at 20× - one of the two is
# presumably a typo; confirm against the source table.
teacher_aop = [83.95, 43.22, 44.40, 38.29, 32.38, 50.37, 55.71, 67.58, 67.53, 50.16]
student_aop = [87.25, 54.35, 56.72, 45.73, 34.86, 62.64, 70.68, 74.23, 79.92, 63.05]

# Long-format DataFrame (one row per dataset/metric pair). It is not used by
# the plot below; it is kept for inspection of the numbers.
df = pd.DataFrame({
    'Dataset': datasets * 3,
    'Metric': ['Accuracy'] * len(datasets) + ['AUC_MIA'] * len(datasets) + ['AOP'] * len(datasets),
    'Teacher': teacher_accuracy + teacher_auc_mia + teacher_aop,
    'Student': student_accuracy + student_auc_mia + student_aop,
})

# Calculate the differences (Student - Teacher)
df['Difference'] = df['Student'] - df['Teacher']

# For AUC_MIA, we invert the difference because lower is better
df.loc[df['Metric'] == 'AUC_MIA', 'Difference'] = -df.loc[df['Metric'] == 'AUC_MIA', 'Difference']

# Desired order for plotting
df['Metric'] = pd.Categorical(df['Metric'], categories=['Accuracy', 'AUC_MIA', 'AOP'])

# Set the style for all plots (this changes global matplotlib/seaborn state)
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('muted')
colors = sns.color_palette()

###########################################################################################################################################

# NOTE: the original cell also called plt.figure(figsize=(15, 12)) here. That
# figure was never drawn on (plt.subplots below creates the one actually used),
# so it only produced a stray empty figure; it has been removed.

# Create a 3x1 grid for subplots (one for each metric)
fig, axes = plt.subplots(3, 1, figsize=(14, 18), sharex=True)
fig.subplots_adjust(hspace=0.3)

metrics = ['Accuracy', 'AUC_MIA', 'AOP']
titles = ['Classification Accuracy',
          'AUC$_{MIA}$',
          'AOP']
ylabels = ['Classification Accuracy', 'AUC$_{MIA}$', 'AOP']

# (teacher, student) series for each metric, plotted in the default dataset order
metric_values = {
    'Accuracy': (teacher_accuracy, student_accuracy),
    'AUC_MIA': (teacher_auc_mia, student_auc_mia),
    'AOP': (teacher_aop, student_aop),
}

# Bar group positions and bar width are the same for every panel
pos = np.arange(len(datasets))
width = 0.30

for ax, metric, title, ylabel in zip(axes, metrics, titles, ylabels):
    teacher_vals, student_vals = metric_values[metric]
    lower_is_better = metric == 'AUC_MIA'

    # Teacher bar on the left, student bar on the right of each group
    ax.bar(pos - width/2, teacher_vals, width, label='Teacher Classifier',
           color=colors[0], alpha=0.8, edgecolor='black', linewidth=0.5)
    ax.bar(pos + width/2, student_vals, width, label='Student Classifier',
           color=colors[1], alpha=0.8, edgecolor='black', linewidth=0.5)

    # Student - Teacher, annotated above the taller bar of each pair
    diffs = np.array(student_vals) - np.array(teacher_vals)
    for p, diff, teacher_val, student_val in zip(pos, diffs, teacher_vals, student_vals):
        # Green marks a change in the "better" direction for this metric:
        # a decrease for AUC_MIA, an increase for Accuracy and AOP.
        improved = diff < 0 if lower_is_better else diff > 0
        # The raw signed difference is shown for every metric. The original
        # AUC_MIA branch used the format flag '-' (the default, a no-op) with a
        # comment claiming it negated the value; '+' keeps all labels
        # consistently signed without changing any negative label.
        diff_text = f"{diff:+.2f}"

        ax.annotate(diff_text,
                    xy=(p, max(teacher_val, student_val) + 0.5),
                    ha='center', va='bottom',
                    color='green' if improved else 'red',
                    fontweight='bold', fontsize=13)

    # Labels, title and dataset names on the x axis
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_title(title, fontsize=20, fontweight='bold')
    ax.set_xticks(pos)
    ax.set_xticklabels(datasets, rotation=45, ha='right', fontsize=13)

    # Add grid for easier reading
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Reference line for AUC_MIA
    if lower_is_better:
        ax.axhline(y=50, color='gray', linestyle='--', alpha=0.5)  # 50% is random guess

    # Add legend
    ax.legend(fontsize=10)

plt.tight_layout()
# Save before plt.show() so the current figure is still available to plt.savefig
plt.savefig('../images/student_vs_teacher.png', dpi=600, bbox_inches='tight')
plt.show()