# Plots

## Heatmap of CAS for 3 different prompt/label configurations

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

datasets = ['CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet']

# Each configuration is a (4 datasets x 4 sample ratios) array of CAS values:
# rows follow `datasets`, columns follow `sample_ratios`
config_names = ['"n: d" - Hard Labels', '"n: c" - Hard Labels', '"n: c" - Soft Labels']
sample_ratios = ['1/50', '1/20', '1/10', '1']

config1_data = np.array([
    [71.19, 75.27, 80.43, 80.52],  # CIFAR10
    [32.91, 37.48, 45.43, 47.37],  # CIFAR100
    [2.98, 2.52, 3.79, 2.43],      # Oxford-IIIT-Pet
    [17.64, 20.15, 20.79, 32.04]   # TinyImageNet
])

config2_data = np.array([
    [78.92, 82.76, 82.52, 81.52],  # CIFAR10
    [39.14, 47.19, 47.13, 50.09],  # CIFAR100
    [3.85, 2.07, 2.01, 2.22],      # Oxford-IIIT-Pet
    [30.37, 30.42, 31.34, 35.57]   # TinyImageNet
])

config3_data = np.array([
    [86.97, 87.79, 88.09, 90.28],  # CIFAR10
    [63.29, 66.13, 66.74, 71.17],  # CIFAR100
    [30.72, 65.86, 69.22, 79.98],  # Oxford-IIIT-Pet
    [63.36, 62.97, 62.78, 66.65]   # TinyImageNet
])

# One heatmap per configuration, side by side
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Sequential colormap (light = low CAS, dark = high CAS) on a fixed 0-100 scale shared
# by all three panels, so colours are comparable across configurations
cmap = sns.color_palette("YlGnBu", as_cmap=True)
norm = plt.Normalize(vmin=0, vmax=100)

all_data = [config1_data, config2_data, config3_data]


def _annotation_color(value):
    """Return the text colour ('black' or 'white') that stays legible on the cell for *value*.

    Mirrors seaborn's own heatmap annotation colouring: compute the WCAG relative
    luminance of the cell colour and use white text on dark cells (luminance <= 0.408).
    """
    rgb = np.asarray(cmap(norm(value))[:3])
    linear = np.where(rgb <= 0.03928, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    luminance = float(linear @ np.array([0.2126, 0.7152, 0.0722]))
    return 'black' if luminance > 0.408 else 'white'


for i, (ax, data, name) in enumerate(zip(axes, all_data, config_names)):
    # Draw the heatmap without seaborn's annotations; values are written manually below
    sns.heatmap(data,
                annot=False,
                cmap=cmap,
                xticklabels=sample_ratios,
                yticklabels=datasets,
                ax=ax,
                vmin=0,
                vmax=100)

    # Write each value at its cell centre (cell (row, col) spans [col, col+1] x [row, row+1]).
    # A fixed black text colour is hard to read on the darkest cells (values near 90 with
    # vmax=100), so the colour is chosen per cell from the cell's luminance.
    for row in range(data.shape[0]):
        for col in range(data.shape[1]):
            value = data[row, col]
            ax.text(col + 0.5, row + 0.5, f'{value:.2f}',
                    horizontalalignment='center',
                    verticalalignment='center',
                    fontsize=15,
                    color=_annotation_color(value))

    ax.set_title(name, fontsize=20, fontweight='bold')
    ax.set_xlabel('Fine-Tuning Samples Ratio', fontsize=16)
    if i == 0:  # the dataset names are shared, so label the y axis only once
        ax.set_ylabel('Dataset', fontsize=16)
    ax.tick_params(axis='both', labelsize=11)

plt.tight_layout()
plt.show()

In [None]:
# Export the heatmap figure built in the previous cell
heatmap_path = '../images/impact_blip2_gkd_heatmap.pdf'
fig.savefig(heatmap_path, bbox_inches='tight')

### Bar chart of the best CAS (maximum over sample ratios) for each of the 3 configurations

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data from Table 1: for each configuration, the best CAS reached over all fine-tuning
# sample ratios, i.e. the row-wise maximum of the heatmap data above. This is usually the
# ratio=1 column but not always: the hard-label configurations peak on Oxford-IIIT-Pet at
# ratio 1/10 ("n: d") and 1/50 ("n: c"), and "n: c" - Hard Labels peaks on CIFAR10 at 1/20.
datasets = ['CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet']
config_names = ['"n: d" - Hard Labels', '"n: c" - Hard Labels', '"n: c" - Soft Labels']

# Best CAS (maximum over sample ratios) per dataset, in the order of `datasets`
config1_best = [80.52, 47.37, 3.79, 32.04]  # "n: d" - Hard Labels
config2_best = [82.76, 50.09, 3.85, 35.57]  # "n: c" - Hard Labels
config3_best = [90.28, 71.17, 79.98, 66.65]  # "n: c" - Soft Labels

# Width of each bar (three bars per dataset group)
barWidth = 0.25

# x positions: one group per dataset, each configuration shifted right by one bar width
r1 = np.arange(len(datasets))
r2 = [x + barWidth for x in r1]
r3 = [x + barWidth for x in r2]

# Create the figure with larger size for better visibility
plt.figure(figsize=(12, 6))

# One bar series per configuration
bars1 = plt.bar(r1, config1_best, width=barWidth, edgecolor='grey', label=config_names[0], color='skyblue')
bars2 = plt.bar(r2, config2_best, width=barWidth, edgecolor='grey', label=config_names[1], color='lightgreen')
bars3 = plt.bar(r3, config3_best, width=barWidth, edgecolor='grey', label=config_names[2], color='salmon')
def add_labels(bars):
    """Print each bar's height (two decimals) just above the bar, centred on it."""
    for rect in bars:
        top = rect.get_height()
        centre_x = rect.get_x() + rect.get_width() / 2.
        plt.text(centre_x, top, f'{top:.2f}', ha='center', va='bottom', fontsize=12)

# Value labels on every bar of the three series
for bar_group in (bars1, bars2, bars3):
    add_labels(bar_group)

# Axis labels (the figure title is kept disabled)
plt.xlabel('Dataset', fontsize=17)
plt.ylabel('CAS', fontsize=17)
#plt.title('Comparison of Best CAS Across Different Configurations', fontweight='bold', fontsize=23, pad=20)

# Dataset names sit under the middle bar of each group
group_centres = np.arange(len(datasets)) + barWidth
plt.xticks(group_centres, datasets, fontsize=14)

# Legend in a single row below the axes
plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), ncol=3, fontsize=12)

# Layout, then save; saving must happen before plt.show() or the PDF comes out blank
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig('../images/impact_blip2_gkd_chart.pdf', bbox_inches='tight')
plt.show()

## CAS vs Synthetic Cardinality

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec

# Data for both sub-plots
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

# Cardinality values
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1', '0.2', '1', '5', '10', '20']

# CAS values for each dataset at different cardinalities (higher is better)
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]

# Create a figure with custom grid layout
fig = plt.figure(figsize=(22, 9))

# GridSpec with 2 rows and 7 columns: columns 0-4 hold the 2x5 small multiples,
# column 5 is a narrow spacer and column 6 holds the bar chart
gs = gridspec.GridSpec(2, 7, width_ratios=[1, 1, 1, 1, 1, 0.1, 1.5], wspace=0.3, hspace=0.3)

# Numeric x positions; the cardinality labels are applied as tick labels
x = np.arange(len(cardinality_labels))
colors = plt.cm.tab20(np.linspace(0, 1, len(datasets)))

num_ticks = 10               # exactly 10 equally spaced horizontal grid lines per subplot
vertical_fixed_offset = 22   # distance of a value label from its vertex, in points
horizontal_fixed_offset = 5  # inward nudge (points) for the first and last value labels

# Plot the small multiples (first plot)
for i, (dataset, values) in enumerate(zip(datasets, cas_values)):
    row, col = divmod(i, 5)  # position in the 2x5 grid
    ax = fig.add_subplot(gs[row, col])

    ax.plot(x, values, marker='o', linewidth=2, color=colors[i])

    # Annotate each vertex: labels of points in the top quarter of the value range go
    # below the point, all others above; the outermost labels are nudged inwards
    y_min, y_max = min(values), max(values)
    for xi, val in zip(x, values):
        if val > y_min + 0.75 * (y_max - y_min):
            vertical_offset = -vertical_fixed_offset
        else:
            vertical_offset = vertical_fixed_offset
        if xi == x[0]:
            horizontal_offset = horizontal_fixed_offset
        elif xi == x[-1]:
            horizontal_offset = -horizontal_fixed_offset
        else:
            horizontal_offset = 0

        ax.annotate(f"{val:.2f}", xy=(xi, val),
                    xytext=(horizontal_offset, vertical_offset),
                    textcoords='offset points', ha='center', va='center', fontsize=7)

    # Star the best (highest) CAS; np.argmax picks the first maximum on ties
    best_idx = np.argmax(values)
    ax.scatter(x[best_idx], values[best_idx], marker='*', color=colors[i],
               s=200, zorder=5)

    # Plain bold text: inside mathtext the hyphens of 'Oxford-IIIT-Pet' would be
    # typeset as spaced minus signs
    ax.set_title(dataset, fontsize=15, fontweight='bold')

    # Pad the y range by 10% on each side, clamped to the valid [0, 100] range
    range_val = y_max - y_min
    min_val = max(0, y_min - 0.1 * range_val)
    max_val = min(100, y_max + 0.1 * range_val)
    ax.set_ylim(min_val, max_val)

    yticks = np.linspace(min_val, max_val, num_ticks)
    ax.set_yticks(yticks)
    # Use as many decimals as the tick spacing needs: integer labels repeated the same
    # value on neighbouring grid lines whenever the spacing was below 1
    tick_step = yticks[1] - yticks[0]
    decimals = max(0, int(np.ceil(-np.log10(tick_step)))) if tick_step > 0 else 0
    ax.set_yticklabels([f"{tick:.{decimals}f}" for tick in yticks], fontsize=10)

    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(cardinality_labels, fontsize=10)

    # Axis labels only on the outer edges of the grid
    if row == 1:
        ax.set_xlabel('Cardinality', fontsize=14)
    if col == 0:
        ax.set_ylabel('CAS', fontsize=14)

########################################################################################

# Horizontal bar chart (second plot) spanning both rows of the last column
ax_bar = fig.add_subplot(gs[:, 6])

#ax_bar.set_title("Classification Accuracy Score", fontsize=15, fontweight='bold')

# CAS change between consecutive cardinality levels (one row per dataset),
# averaged over datasets for each transition
improvements = np.diff(cas_values, axis=1)
avg_improvements = improvements.mean(axis=0)

transition_labels = ["0.1 → 0.2", "0.2 → 1", "1 → 5", "5 → 10", "10 → 20"]
y_pos = np.arange(len(transition_labels))

# Higher CAS is better: green for a gain, red for a loss
bar_colors = ['green' if v >= 0 else 'red' for v in avg_improvements]

ax_bar.barh(y_pos, avg_improvements, alpha=0.7, color=bar_colors)
ax_bar.axvline(x=0, color='k', linestyle='-', alpha=0.3)
ax_bar.grid(True, axis='x', linestyle='--', alpha=0.5)
ax_bar.set_xlabel("Average Change", fontsize=14)
ax_bar.set_yticks(y_pos)
ax_bar.set_yticklabels(transition_labels, fontsize=13)
ax_bar.invert_yaxis()  # To have 0.1→0.2 at the top

# Keep the original fixed 0-10 range, but widen it when the data would not fit
# (a padded gain above 10, or a negative average that a left limit of 0 would hide)
max_value = np.abs(avg_improvements).max() * 1.15  # 15% padding
ax_bar.set_xlim(min(0.0, avg_improvements.min() * 1.15), max(10.0, max_value))

# Value label just past the end of each bar, on the side the bar points to
for i, v in enumerate(avg_improvements):
    ax_bar.text(v + np.sign(v) * max_value * 0.01, i, f"{v:.2f}", va='center',
                ha='left' if v >= 0 else 'right', fontsize=11)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("../images/cas_vs_cardinality.pdf", bbox_inches='tight')
plt.show()

## AUC vs Synthetic Cardinality

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec

# Data for both sub-plots
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

# Cardinality values
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1', '0.2', '1', '5', '10', '20']

# AUC_MIA values for each dataset at different cardinalities (lower is better, 50 = random guessing)
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# Create a figure with custom grid layout
fig = plt.figure(figsize=(22, 9))

# GridSpec with 2 rows and 7 columns: columns 0-4 hold the 2x5 small multiples,
# column 5 is a narrow spacer and column 6 holds the bar chart
gs = gridspec.GridSpec(2, 7, width_ratios=[1, 1, 1, 1, 1, 0.1, 1.5], wspace=0.3, hspace=0.3)

# Numeric x positions; the cardinality labels are applied as tick labels
x = np.arange(len(cardinality_labels))
colors = plt.cm.tab20(np.linspace(0, 1, len(datasets)))

num_ticks = 10               # exactly 10 equally spaced horizontal grid lines per subplot
vertical_fixed_offset = 22   # distance of a value label from its vertex, in points
horizontal_fixed_offset = 5  # inward nudge (points) for the first and last value labels

# Plot the small multiples (first plot)
for i, (dataset, values) in enumerate(zip(datasets, auc_mia_values)):
    row, col = divmod(i, 5)  # position in the 2x5 grid
    ax = fig.add_subplot(gs[row, col])

    ax.plot(x, values, marker='o', linewidth=2, color=colors[i])

    # Annotate each vertex: labels of points in the top quarter of the value range go
    # below the point, all others above; the outermost labels are nudged inwards
    y_min, y_max = min(values), max(values)
    for xi, val in zip(x, values):
        if val > y_min + 0.75 * (y_max - y_min):
            vertical_offset = -vertical_fixed_offset
        else:
            vertical_offset = vertical_fixed_offset
        if xi == x[0]:
            horizontal_offset = horizontal_fixed_offset
        elif xi == x[-1]:
            horizontal_offset = -horizontal_fixed_offset
        else:
            horizontal_offset = 0

        ax.annotate(f"{val:.2f}", xy=(xi, val),
                    xytext=(horizontal_offset, vertical_offset),
                    textcoords='offset points', ha='center', va='center', fontsize=7)

    # Star the best (lowest) AUC_MIA; np.argmin picks the first minimum on ties
    best_idx = np.argmin(values)
    ax.scatter(x[best_idx], values[best_idx], marker='*', color=colors[i],
               s=200, zorder=5)

    # Plain bold text: inside mathtext the hyphens of 'Oxford-IIIT-Pet' would be
    # typeset as spaced minus signs
    ax.set_title(dataset, fontsize=15, fontweight='bold')

    # Pad the y range by 10% on each side. The padding never goes below 50 (random
    # guessing, the ideal value), but the data itself is never cut off if it dips below 50.
    range_val = y_max - y_min
    min_val = max(min(50, y_min), y_min - 0.1 * range_val)
    max_val = min(100, y_max + 0.1 * range_val)
    ax.set_ylim(min_val, max_val)

    yticks = np.linspace(min_val, max_val, num_ticks)
    ax.set_yticks(yticks)
    # Use as many decimals as the tick spacing needs: integer labels repeated the same
    # value on neighbouring grid lines whenever the spacing was below 1
    tick_step = yticks[1] - yticks[0]
    decimals = max(0, int(np.ceil(-np.log10(tick_step)))) if tick_step > 0 else 0
    ax.set_yticklabels([f"{tick:.{decimals}f}" for tick in yticks], fontsize=10)

    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(cardinality_labels, fontsize=10)

    # Axis labels only on the outer edges of the grid
    if row == 1:
        ax.set_xlabel('Cardinality', fontsize=14)
    if col == 0:
        ax.set_ylabel('AUC$_{MIA}$', fontsize=14)

########################################################################################

# Bar chart of the privacy change between cardinality levels, spanning both rows of the last column
ax_bar = fig.add_subplot(gs[:, 6])

#ax_bar.set_title("Area Under the ROC Curve for MIAs", fontsize=15, fontweight='bold')

# AUC_MIA change between consecutive cardinality levels, computed as current - previous
# (one row per dataset) and averaged over datasets. A negative change is better.
changes = np.diff(auc_mia_values, axis=1)
avg_changes = changes.mean(axis=0)

transition_labels = ["0.1 → 0.2", "0.2 → 1", "1 → 5", "5 → 10", "10 → 20"]
y_pos = np.arange(len(transition_labels))

# Lower AUC_MIA is better: green for a decrease, red for an increase
bar_colors = ['green' if v <= 0 else 'red' for v in avg_changes]

ax_bar.barh(y_pos, avg_changes, alpha=0.7, color=bar_colors)
ax_bar.axvline(x=0, color='k', linestyle='-', alpha=0.3)
ax_bar.grid(True, axis='x', linestyle='--', alpha=0.5)
ax_bar.set_xlabel("Average Change", fontsize=14)
ax_bar.set_yticks(y_pos)
ax_bar.set_yticklabels(transition_labels, fontsize=13)
ax_bar.invert_yaxis()  # To have 0.1→0.2 at the top

# Symmetric x range: the original fixed [-6, 6], widened if the padded largest |change| exceeds it
max_abs_value = np.abs(avg_changes).max() * 1.30
x_limit = max(6.0, max_abs_value)
ax_bar.set_xlim(-x_limit, x_limit)

# Value label just past the end of each bar, on the side the bar points to
for i, v in enumerate(avg_changes):
    ax_bar.text(v + np.sign(v) * max_abs_value * 0.02, i, f"{v:.2f}",
                va='center', ha='left' if v >= 0 else 'right', fontsize=11)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("../images/auc_vs_cardinality.pdf", bbox_inches='tight')
plt.show()

## AOP vs Synthetic Cardinality

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec

# Data for both sub-plots
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

# Cardinality values
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1', '0.2', '1', '5', '10', '20']

# AOP values for each dataset at different cardinalities (higher is better).
# Food101 at 20 is 62.64 (it was mistyped as 62.54). Every entry equals
# CAS / (AUC_MIA / 50)**2 rounded to 2 decimals, using the CAS and AUC_MIA tables of this
# notebook. For Food101 at 20 that gives 86.93 / (58.90 / 50)**2 = 62.64, the value the
# Student-Teacher cell also uses.
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.64],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# Create a figure with custom grid layout
fig = plt.figure(figsize=(22, 9))

# GridSpec with 2 rows and 7 columns: columns 0-4 hold the 2x5 small multiples,
# column 5 is a narrow spacer and column 6 holds the bar chart
gs = gridspec.GridSpec(2, 7, width_ratios=[1, 1, 1, 1, 1, 0.1, 1.5], wspace=0.3, hspace=0.3)

# Numeric x positions; the cardinality labels are applied as tick labels
x = np.arange(len(cardinality_labels))
colors = plt.cm.tab20(np.linspace(0, 1, len(datasets)))

num_ticks = 10               # exactly 10 equally spaced horizontal grid lines per subplot
vertical_fixed_offset = 22   # distance of a value label from its vertex, in points
horizontal_fixed_offset = 5  # inward nudge (points) for the first and last value labels

# Plot the small multiples (first plot)
for i, (dataset, values) in enumerate(zip(datasets, aop_values)):
    row, col = divmod(i, 5)  # position in the 2x5 grid
    ax = fig.add_subplot(gs[row, col])

    ax.plot(x, values, marker='o', linewidth=2, color=colors[i])

    # Annotate each vertex: labels of points in the top quarter of the value range go
    # below the point, all others above; the outermost labels are nudged inwards
    y_min, y_max = min(values), max(values)
    for xi, val in zip(x, values):
        if val > y_min + 0.75 * (y_max - y_min):
            vertical_offset = -vertical_fixed_offset
        else:
            vertical_offset = vertical_fixed_offset
        if xi == x[0]:
            horizontal_offset = horizontal_fixed_offset
        elif xi == x[-1]:
            horizontal_offset = -horizontal_fixed_offset
        else:
            horizontal_offset = 0

        ax.annotate(f"{val:.2f}", xy=(xi, val),
                    xytext=(horizontal_offset, vertical_offset),
                    textcoords='offset points', ha='center', va='center', fontsize=7)

    # Star the best (highest) AOP; np.argmax picks the first maximum on ties
    best_idx = np.argmax(values)
    ax.scatter(x[best_idx], values[best_idx], marker='*', color=colors[i],
               s=200, zorder=5)

    # Plain bold text: inside mathtext the hyphens of 'Oxford-IIIT-Pet' would be
    # typeset as spaced minus signs
    ax.set_title(dataset, fontsize=15, fontweight='bold')

    # Pad the y range by 10% on each side, clamped to [0, 100]
    range_val = y_max - y_min
    min_val = max(0, y_min - 0.1 * range_val)
    max_val = min(100, y_max + 0.1 * range_val)
    ax.set_ylim(min_val, max_val)

    yticks = np.linspace(min_val, max_val, num_ticks)
    ax.set_yticks(yticks)
    # Use as many decimals as the tick spacing needs: integer labels repeated the same
    # value on neighbouring grid lines whenever the spacing was below 1
    tick_step = yticks[1] - yticks[0]
    decimals = max(0, int(np.ceil(-np.log10(tick_step)))) if tick_step > 0 else 0
    ax.set_yticklabels([f"{tick:.{decimals}f}" for tick in yticks], fontsize=10)

    ax.grid(True, linestyle='--', alpha=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(cardinality_labels, fontsize=10)

    # Axis labels only on the outer edges of the grid
    if row == 1:
        ax.set_xlabel('Cardinality', fontsize=14)
    if col == 0:
        ax.set_ylabel('AOP', fontsize=14)

########################################################################################

# Bar chart of the AOP change between cardinality levels, spanning both rows of the last column
ax_bar = fig.add_subplot(gs[:, 6])

#ax_bar.set_title("Accuracy Over Privacy", fontsize=15, fontweight='bold')

# AOP change between consecutive cardinality levels (current - previous, one row per
# dataset), averaged over datasets for each transition
improvements = np.diff(aop_values, axis=1)
avg_improvements = improvements.mean(axis=0)

transition_labels = ["0.1 → 0.2", "0.2 → 1", "1 → 5", "5 → 10", "10 → 20"]
y_pos = np.arange(len(transition_labels))

# Higher AOP is better: green for a gain, red for a loss
bar_colors = ['green' if v >= 0 else 'red' for v in avg_improvements]

ax_bar.barh(y_pos, avg_improvements, alpha=0.7, color=bar_colors)
ax_bar.axvline(x=0, color='k', linestyle='-', alpha=0.3)
ax_bar.grid(True, axis='x', linestyle='--', alpha=0.5)
ax_bar.set_xlabel("Average Change", fontsize=14)
ax_bar.set_yticks(y_pos)
ax_bar.set_yticklabels(transition_labels, fontsize=13)
ax_bar.invert_yaxis()  # To have 0.1→0.2 at the top

# Symmetric x range around 0 with 40% padding beyond the largest |change|
max_abs_value = np.abs(avg_improvements).max() * 1.40
ax_bar.set_xlim(-max_abs_value, max_abs_value)

# Value label just past the end of each bar, on the side the bar points to
for i, v in enumerate(avg_improvements):
    ax_bar.text(v + np.sign(v) * max_abs_value * 0.02, i, f"{v:.2f}",
                va='center', ha='left' if v >= 0 else 'right', fontsize=11)

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.savefig("../images/aop_vs_cardinality.pdf", bbox_inches='tight')
plt.show()

## Radar charts for metrics balance

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D


datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof',
    'Imagenette', 'Caltech101'
]

# Cardinality values
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# AUC_MIA values - lower is better for privacy
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# AOP values - higher is better for combined accuracy/privacy.
# Food101 at 20× is 62.64 (it was mistyped as 62.54). Every entry equals
# CAS / (AUC_MIA / 50)**2 rounded to 2 decimals, using the tables in this cell. For Food101
# at 20× that gives 86.93 / (58.90 / 50)**2 = 62.64, the value the Student-Teacher cell uses.
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.64],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# CAS values for comparison
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]

# One radar chart per dataset, arranged in a 2x5 grid
fig = plt.figure(figsize=(20, 10))

# Angles of the three metric axes (order: CAS, AUC, AOP); the first angle is
# repeated so every polygon closes on itself
angles = np.linspace(0, 2 * np.pi, 3, endpoint=False).tolist()
angles += angles[:1]

# One viridis colour per cardinality, shared by the polygons and the legend
line_colors = [plt.cm.viridis(j / len(cardinality_labels)) for j in range(len(cardinality_labels))]

for i, dataset in enumerate(datasets):
    ax = fig.add_subplot(2, 5, i + 1, projection='polar')

    # One closed polygon per cardinality through its (CAS, AUC_MIA, AOP) values
    for j, (cardinality, color) in enumerate(zip(cardinality_labels, line_colors)):
        values = [cas_values[i][j], auc_mia_values[i][j], aop_values[i][j]]
        values += values[:1]  # Close the loop
        ax.plot(angles, values, linewidth=2, label=f"{cardinality}", color=color)
        ax.fill(angles, values, alpha=0.1, color=color)

    # Labels for the three metric axes
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(['CAS', 'AUC', 'AOP'], fontsize=12)

    ax.set_title(dataset, fontweight='bold', size=18, pad=20)

    # Common 0-100 radial scale for all three metrics
    ax.set_ylim(0, 100)
    ax.grid(True)
    ax.set_rgrids(np.linspace(20, 100, 5), angle=45, fontsize=12)

# One legend for all subplots: a line per cardinality
handles = [Line2D([0], [0], color=color, linewidth=2, label=label)
           for color, label in zip(line_colors, cardinality_labels)]
fig.legend(handles=handles, loc='upper center', bbox_to_anchor=(0.5, 0.07), ncol=6, fontsize=16)

plt.tight_layout(rect=[0, 0.1, 1, 0.95], h_pad=6)
plt.subplots_adjust(wspace=0.2)
plt.savefig('../images/metrics_balance_radar.pdf', bbox_inches='tight')  # do not move after plt.show() or it will save a blank image
plt.show()

## Best Synthetic Cardinality for each Metric

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# One colour per metric, taken from the qualitative 'tab10' palette
palette = plt.get_cmap('tab10').colors

# Data from tables
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet',
    'StanfordCars', 'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101'
]

# Cardinality values
cardinalities = [0.1, 0.2, 1, 5, 10, 20]
cardinality_labels = ['0.1×', '0.2×', '1×', '5×', '10×', '20×']

# AUC_MIA values (Table 5) - lower is better for privacy
auc_mia_values = [
    [52.03, 51.53, 52.53, 52.33, 53.03, 52.81],  # CIFAR10
    [51.00, 52.54, 54.26, 59.24, 61.15, 62.80],  # CIFAR100
    [50.97, 50.59, 56.13, 61.33, 63.09, 64.60],  # Oxford-IIIT-Pet
    [51.02, 51.45, 52.66, 61.90, 63.70, 64.55],  # TinyImageNet
    [55.00, 55.42, 65.61, 75.34, 78.03, 79.44],  # StanfordCars
    [51.06, 51.21, 51.61, 55.62, 57.08, 58.90],  # Food101
    [52.67, 52.57, 52.22, 55.99, 56.84, 58.38],  # STL10
    [51.58, 52.76, 53.94, 55.86, 55.51, 55.82],  # Imagewoof
    [54.55, 54.26, 54.68, 55.20, 55.01, 55.32],  # Imagenette
    [57.84, 60.53, 56.44, 60.29, 60.60, 58.37],  # Caltech101
]

# AOP values (Table 6) - higher is better for combined accuracy/privacy.
# Food101 at 20× is 62.64 (it was mistyped as 62.54). Every entry equals
# CAS / (AUC_MIA / 50)**2 rounded to 2 decimals, using the tables in this cell. For Food101
# at 20× that gives 86.93 / (58.90 / 50)**2 = 62.64, the value the Student-Teacher cell uses.
aop_values = [
    [83.37, 87.06, 86.48, 88.45, 86.41, 87.25],  # CIFAR10
    [68.41, 69.34, 70.15, 60.50, 57.02, 54.35],  # CIFAR100
    [76.96, 82.44, 73.27, 62.39, 59.30, 56.72],  # Oxford-IIIT-Pet
    [64.01, 65.93, 67.09, 49.54, 46.82, 45.73],  # TinyImageNet
    [24.82, 44.12, 47.97, 38.51, 36.09, 34.86],  # StanfordCars
    [68.49, 73.41, 79.20, 69.92, 66.54, 62.64],  # Food101
    [81.64, 83.81, 86.87, 76.69, 74.38, 70.68],  # STL10
    [82.92, 81.35, 78.11, 73.88, 74.88, 74.23],  # Imagewoof
    [79.86, 81.49, 80.91, 80.00, 80.70, 79.92],  # Imagenette
    [41.28, 47.25, 71.93, 63.44, 63.05, 67.95],  # Caltech101
]

# CAS values from previous table (for comparison)
cas_values = [
    [90.28, 92.47, 95.45, 96.89, 97.20, 97.33],  # CIFAR10
    [71.17, 76.56, 82.61, 84.93, 85.28, 85.74],  # CIFAR100
    [79.98, 84.40, 92.34, 93.87, 94.41, 94.68],  # Oxford-IIIT-Pet
    [66.65, 69.81, 74.42, 75.92, 76.00, 76.22],  # TinyImageNet
    [30.03, 54.20, 82.60, 87.44, 87.90, 88.00],  # StanfordCars
    [71.42, 77.01, 84.38, 86.52, 86.72, 86.93],  # Food101
    [90.59, 92.65, 94.76, 96.16, 96.12, 96.36],  # STL10
    [88.24, 90.58, 90.91, 92.21, 92.29, 92.52],  # Imagewoof
    [95.06, 95.97, 96.76, 97.50, 97.68, 97.83],  # Imagenette
    [55.24, 69.25, 91.65, 92.24, 92.62, 92.60],  # Caltech101
]

# Index into `cardinality_labels` of the best cardinality per dataset and metric.
# On ties np.argmin/np.argmax return the first, i.e. the smallest, cardinality.
best_auc_indices = [int(np.argmin(v)) for v in auc_mia_values]  # lower AUC_MIA is better
best_aop_indices = [int(np.argmax(v)) for v in aop_values]      # higher AOP is better
best_cas_indices = [int(np.argmax(v)) for v in cas_values]      # higher CAS is better

# Per-dataset summary of the optimal cardinalities and the values reached there
dataset_optimal = [
    {
        'dataset': dataset,
        'best_auc_card': cardinality_labels[auc_idx],
        'best_aop_card': cardinality_labels[aop_idx],
        'best_cas_card': cardinality_labels[cas_idx],
        'best_auc_val': auc_mia_values[i][auc_idx],
        'best_aop_val': aop_values[i][aop_idx],
        'best_cas_val': cas_values[i][cas_idx],
    }
    for i, (dataset, auc_idx, aop_idx, cas_idx) in enumerate(
        zip(datasets, best_auc_indices, best_aop_indices, best_cas_indices))
]

# Use the default dataset ordering
default_order = datasets

# Group positions (one per dataset) and bar width (three bars per group)
x = np.arange(len(default_order))
width = 0.25

# A single figure; plt.savefig below saves this one
fig, ax = plt.subplots(figsize=(18, 8))

# Bar height = cardinality index + offset, so the smallest cardinality (index 0) still
# gets a visible bar; the y ticks are shifted by the same offset
offset = 0.2
best_auc_indices_offset = [val + offset for val in best_auc_indices]
best_aop_indices_offset = [val + offset for val in best_aop_indices]
best_cas_indices_offset = [val + offset for val in best_cas_indices]

# Bars for the optimal cardinality by metric
bar1 = ax.bar(x - width, best_auc_indices_offset, width, label='Best AUC$_{MIA}$', color=palette[0], alpha=0.7)
bar2 = ax.bar(x, best_aop_indices_offset, width, label='Best AOP', color=palette[1], alpha=0.7)
bar3 = ax.bar(x + width, best_cas_indices_offset, width, label='Best CAS', color=palette[2], alpha=0.7)

# y axis shows cardinality labels at the (offset) bar heights
ax.tick_params(axis='y', which='both', left=True, labelleft=True)
ax.set_yticks([i + offset for i in range(len(cardinality_labels))])
ax.set_yticklabels(cardinality_labels)
ax.set_ylabel('Synthetic Dataset Cardinality', fontsize=16)

# x axis: dataset names in the default order
ax.set_xticks(x)
ax.set_xticklabels(default_order, fontsize=14, rotation=45)

# Legend for the three metrics
ax.legend(fontsize=12, loc='best')

ax.grid(True, axis='y', linestyle='--', alpha=0.7)
#plt.title('Optimal Cardinality for Each Metric', fontweight='bold', fontsize=22, pad=20)
plt.tight_layout()
plt.savefig('../images/metrics_best_cardinality.pdf', bbox_inches='tight') # do not move after plt.show() or it will save a blank image
plt.show()

## Student-Teacher Gap

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Per-dataset results of the teacher and student classifiers (extracted from the results table)
datasets = [
    'CIFAR10', 'CIFAR100', 'Oxford-IIIT-Pet', 'TinyImageNet', 'StanfordCars', 
    'Food101', 'STL10', 'Imagewoof', 'Imagenette', 'Caltech101'
]

# Accuracy (higher is better)
teacher_accuracy = [97.52, 85.49, 93.96, 75.67, 88.22, 86.79, 96.74, 93.05, 98.29, 92.26]
student_accuracy = [97.33, 85.74, 94.68, 76.22, 88.00, 86.93, 96.36, 92.52, 97.83, 92.62]

# AUC_MIA (lower is better)
teacher_auc_mia = [53.89, 70.32, 72.74, 70.29, 82.53, 65.63, 65.89, 58.67, 60.32, 67.81]
student_auc_mia = [52.81, 62.80, 64.60, 64.55, 79.44, 58.90, 58.38, 55.82, 55.32, 60.60]

# AOP (higher is better)
teacher_aop = [83.95, 43.22, 44.40, 38.29, 32.38, 50.37, 55.71, 67.58, 67.53, 50.16]
student_aop = [87.25, 54.35, 56.72, 45.73, 34.86, 62.64, 70.68, 74.23, 79.92, 63.05]

# Long-format table of the same numbers, kept for inspection
df = pd.DataFrame({
    'Dataset': datasets * 3,
    'Metric': ['Accuracy'] * len(datasets) + ['AUC_MIA'] * len(datasets) + ['AOP'] * len(datasets),
    'Teacher': teacher_accuracy + teacher_auc_mia + teacher_aop,
    'Student': student_accuracy + student_auc_mia + student_aop,
})

# Difference = Student - Teacher, negated for AUC_MIA (lower is better) so that
# a positive Difference always means the student improved.
df['Difference'] = df['Student'] - df['Teacher']
df.loc[df['Metric'] == 'AUC_MIA', 'Difference'] = -df.loc[df['Metric'] == 'AUC_MIA', 'Difference']

# Desired order for plotting
df['Metric'] = pd.Categorical(df['Metric'], categories=['Accuracy', 'AUC_MIA', 'AOP'])

# Set the style for all plots
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('muted')
colors = sns.color_palette()

###########################################################################################################################################

# One subplot per metric, stacked vertically. (No separate plt.figure() call:
# plt.subplots creates the figure, and an extra call only leaves an empty one behind.)
fig, axes = plt.subplots(3, 1, figsize=(14, 18), sharex=True)
fig.subplots_adjust(hspace=0.3)

metrics = ['Accuracy', 'AUC_MIA', 'AOP']
titles = ['Classification Accuracy', 
          'AUC$_{MIA}$', 
          'AOP']
ylabels = ['Classification Accuracy', 'AUC$_{MIA}$', 'AOP']

# metric -> (teacher values, student values, whether a higher value is better)
metric_values = {
    'Accuracy': (teacher_accuracy, student_accuracy, True),
    'AUC_MIA': (teacher_auc_mia, student_auc_mia, False),
    'AOP': (teacher_aop, student_aop, True),
}

# Datasets are plotted in their original order for every metric
pos = np.arange(len(datasets))
width = 0.30

for ax, metric, title, ylabel in zip(axes, metrics, titles, ylabels):
    teacher_vals, student_vals, higher_is_better = metric_values[metric]

    # Side-by-side teacher / student bars
    ax.bar(pos - width/2, teacher_vals, width, label='Teacher Classifier', 
           color=colors[0], alpha=0.8, edgecolor='black', linewidth=0.5)
    ax.bar(pos + width/2, student_vals, width, label='Student Classifier', 
           color=colors[1], alpha=0.8, edgecolor='black', linewidth=0.5)

    # Annotate each pair with the signed raw change (Student - Teacher); green
    # when the change goes in the metric's "better" direction, red otherwise.
    diffs = np.array(student_vals) - np.array(teacher_vals)
    for p, t_val, s_val, diff in zip(pos, teacher_vals, student_vals, diffs):
        improved = diff > 0 if higher_is_better else diff < 0
        ax.annotate(f"{diff:+.2f}", 
                    xy=(p, max(t_val, s_val) + 0.5), 
                    ha='center', va='bottom', 
                    color='green' if improved else 'red', fontweight='bold', fontsize=13)

    # Labels, title and dataset tick labels
    ax.set_ylabel(ylabel, fontsize=16)
    ax.set_title(title, fontsize=20, fontweight='bold')
    ax.set_xticks(pos)
    ax.set_xticklabels(datasets, rotation=45, ha='right', fontsize=13)

    # Add grid for easier reading
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # 50% AUC_MIA corresponds to a random-guess attacker
    if metric == 'AUC_MIA':
        ax.axhline(y=50, color='gray', linestyle='--', alpha=0.5)

    ax.legend(fontsize=10)

plt.tight_layout()
plt.savefig('../images/student_vs_teacher.pdf', bbox_inches='tight')
plt.show()

## Prompts: Claude vs BLIP-2

In [None]:
from datasets import load_dataset
import matplotlib.pyplot as plt
import textwrap

# One example row per dataset for the prompt-comparison figure.
# Keys are printed as the "Dataset:" label of the row. For each dataset:
#   class_name        - class shown in the row ("Class:" label)
#   class_description - short textual description of the class
#                       (presumably the Claude-generated description — confirm)
#   hf_dataset        - Hugging Face dataset id, loaded with split="train"
#   img_column        - column of that dataset holding the image
#   captions          - one caption per image, matched to image_indices by position
#                       (presumably BLIP-2 outputs — confirm)
#   image_indices     - row indices into the "train" split
dataset_info = {
    "Cifar100": {
        "class_name": "castle",
        "class_description": "a large fortified building or set of buildings with thick walls",
        "hf_dataset": "uoft-cs/cifar100",
        "img_column": "img",
        "captions": [
            "a large pink castle with a tower in the middle of a field",
            "a castle on a hill with trees and bushes",
            "a large castle sitting on a green field next to a river"
        ],
        "image_indices": [31, 1802, 6130]
    },
    "TinyImageNet": {
        "class_name": "scorpion",
        "class_description": "a venomous arachnid known for its painful sting, equipped with powerful claws and a segmented tail",
        "hf_dataset": "zh-plus/tiny-imagenet", 
        "img_column": "image",  
        "captions": [
            "a scorpion with red and yellow legs on a concrete floor",
            "a scorpion on top of a book with a black background",
            "a person holding a small yellow scorpion on their finger"
        ],
        "image_indices": [3576, 3613, 3699] 
    },
    "Food101": {
        "class_name": "donuts",
        "class_description": "fried ring-shaped dough pastries often glazed or decorated",
        "hf_dataset": "Multimodal-Fatima/Food101_train",
        "img_column": "image", 
        "captions": [
            "a tray of donuts with pink frosting",
            "a person holding a half eaten donut in front of a bakery",
            "a doughnut that is shaped like a letter s"
        ],
        "image_indices": [57760, 57896, 57960] 
    }
}

def _center_crop_square(img):
    """Return *img* center-cropped to its largest centered square.

    Objects without a ``size`` attribute and images that are already square are
    returned unchanged; otherwise ``img.crop`` is called with the square box.
    """
    if not hasattr(img, "size"):
        return img
    w, h = img.size
    if w == h:
        return img
    side = min(w, h)
    x0 = (w - side) // 2
    y0 = (h - side) // 2
    return img.crop((x0, y0, x0 + side, y0 + side))


# 3x4 grid: column 0 holds the dataset/class text, columns 1-3 the captioned images
fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(10, 8))
fig.tight_layout(pad=3)

for row, (ds_label, meta) in enumerate(dataset_info.items()):
    # Text column: bold dataset and class names, then the italic class description
    text_ax = axes[row, 0]
    text_ax.axis('off')
    text_ax.text(0.5, 0.7, f"$\\bf{{Dataset:}}$ {ds_label}\n",
                 fontsize=13, ha='center', va='center', wrap=True)
    text_ax.text(0.5, 0.5, f"$\\bf{{Class:}}$ {meta['class_name']}\n\n",
                 fontsize=13, ha='center', va='center', wrap=True)
    text_ax.text(0.5, 0.3, meta['class_description'],
                 fontsize=11, ha='center', va='center', wrap=True, style='italic')

    # Image columns: load the train split and show each selected sample with its caption
    train_split = load_dataset(meta["hf_dataset"], split="train")
    img_key = meta["img_column"]
    for offset, sample_idx in enumerate(meta["image_indices"]):
        img_ax = axes[row, offset + 1]
        picture = _center_crop_square(train_split[sample_idx][img_key])
        img_ax.imshow(picture, interpolation="nearest", resample=False)
        img_ax.axis('off')
        img_ax.set_title(textwrap.fill(meta['captions'][offset], width=25), fontsize=9)

# Horizontal rules between consecutive rows, placed midway between the text axes
# of the two rows (plus a small upward nudge), in figure coordinates
line_vertical_shift = 0.02
nrows = len(dataset_info)
for lower in range(1, nrows):
    above = axes[lower - 1, 0].get_position()
    below = axes[lower, 0].get_position()
    y_sep = (above.y0 + below.y1) / 2 + line_vertical_shift
    fig.add_artist(plt.Line2D([0.05, 0.95], [y_sep, y_sep],
                              transform=fig.transFigure, color="black", lw=1))

plt.show()

In [None]:
# Write the prompt-comparison figure to PDF; bbox_inches='tight' trims the surrounding whitespace
fig.savefig(fname="../images/claude_vs_blip2_prompts.pdf", bbox_inches='tight')

## Images: Real vs Synthetic

In [None]:
import torch
import torch.nn as nn
from diffusers import AutoPipelineForText2Image
import os
from datasets import load_dataset
import matplotlib.pyplot as plt


def load_sd_model(pretrained_sd, finetuned_sd_path, finetuned_sd_weights, device, noise_sigma):
    """Load a float16 text-to-image pipeline, optionally applying LoRA weights.

    Args:
        pretrained_sd: model id or path passed to
            ``AutoPipelineForText2Image.from_pretrained``.
        finetuned_sd_path: directory of LoRA weights, or None.
        finetuned_sd_weights: weight file name inside ``finetuned_sd_path``, or None.
            LoRA weights are loaded only when both this and the path are given.
        device: device the pipeline is moved to.
        noise_sigma: value assigned to ``pipeline.scheduler.init_noise_sigma``.

    Returns:
        The loaded pipeline.
    """
    pipeline = AutoPipelineForText2Image.from_pretrained(
        pretrained_sd,
        torch_dtype=torch.float16,
        use_safetensors=True,
        safety_checker=None,
        requires_safety_checker=False,
    ).to(device)

    has_lora = (finetuned_sd_path is not None) and (finetuned_sd_weights is not None)
    if not has_lora:
        print("NOT-finetuned SD loaded.\n")
    else:
        pipeline.load_lora_weights(finetuned_sd_path, weight_name=finetuned_sd_weights)

    # NOTE(review): some schedulers expose init_noise_sigma as a read-only
    # property; confirm the scheduler used here accepts this assignment.
    pipeline.scheduler.init_noise_sigma = noise_sigma
    if noise_sigma != 1:
        print(f"Initial noise sigma (std) is set to {noise_sigma}, instead of default 1.\n")

    return pipeline

def generate_batch_images_with_stable_diffusion(pipeline, text_prompts, sd_output_resolution, num_inference_steps, guidance_scale):
    """Run *pipeline* on a batch of prompts and return its ``.images``.

    Args:
        pipeline: callable text-to-image pipeline whose result has an ``images`` attribute.
        text_prompts: prompts forwarded as the first positional argument.
        sd_output_resolution: ``(width, height)`` of the generated images.
        num_inference_steps: forwarded unchanged to the pipeline.
        guidance_scale: forwarded unchanged to the pipeline.
    """
    out_w, out_h = sd_output_resolution
    result = pipeline(
        text_prompts,
        width=out_w,
        height=out_h,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
    return result.images

In [None]:
# Five sample images per dataset, each from a different class, with their prompts.
# NOTE: the indices are chosen per dataset (e.g. strided through the train split);
# they are not always the first five images.
# Each entry:
#   display_name  - row label used in the real-vs-synthetic figure
#   hf_dataset    - Hugging Face dataset id, loaded with split="train"
#   img_column    - column holding the image
#   captions      - "<class name>: <BLIP-2 caption>" strings, used verbatim as the
#                   Stable Diffusion prompts; matched to image_indices by position
#   image_indices - row indices into the train split
dataset_info = {
    "cifar10": {
        "display_name": "CIFAR10",
        "hf_dataset": "uoft-cs/cifar10",
        "img_column": "img",
        "captions": [
            "airplane: a fedex plane is parked on the tarmac at the san diego international airport",
            "frog: a close up of a toad on a rock",
            "bird: a kiwi bird on the grass with its beak open",
            "horse: a black horse standing in the grass",
            "automobile: a black suv driving down a road"
        ],
        "image_indices": [0, 1, 3, 4, 6]
    },
    "cifar100": {
        "display_name": "CIFAR100",
        "hf_dataset": "uoft-cs/cifar100",
        "img_column": "img",
        "captions": [
            "cattle: a cow with a red nose is standing in a field",
            "dinosaur: a dinosaur with a long tail and a long neck",
            "apple: apple cinnamon spice bath bomb",
            "boy: a young boy in a red and white shirt",
            "aquarium_fish: a fish with a black and white stripe on its back"
        ],
        "image_indices": [0, 1, 2, 3, 4]
    },
    "pets": {
        "display_name": "Oxford-IIIT-Pet",
        "hf_dataset": "Isamu136/oxford_pets_with_l14_emb",
        "img_column": "image",
        "captions": [
            "Siamese: a siamese cat sitting in the grass near plants",
            "Birman: a white and black cat with long hair",
            "shiba inu: a dog standing on a log in the middle of a forest",
            "staffordshire bull terrier: a black dog laying on a couch with its head on the armrest",
            "basset hound: a basset hound dog standing on the grass"
        ],
        "image_indices": [0, 1, 2, 3, 4]
    },
    "tiny": {
        "display_name": "TinyImageNet",
        "hf_dataset": "zh-plus/tiny-imagenet",
        "img_column": "image",
        # Class names here are the raw class IDs (e.g. n01443537), not readable names
        "captions": [
            "n01443537: a fish in a tank with plants and plants",
            "n01629819: a small black and white lizard on the ground",
            "n01641577: a frog sitting on the ground with its head facing the camera",
            "n01644900: a toad is sitting in the grass",
            "n01698640: a large alligator laying on the ground near a body of water"
        ],
        "image_indices": [0, 500, 1000, 1500, 2000]
    },
    "cars": {
        "display_name": "StanfordCars",
        "hf_dataset": "Multimodal-Fatima/StanfordCars_train", 
        "img_column": "image",  
        # NOTE(review): the 4th caption names a different car than its class label;
        # presumably the verbatim BLIP-2 output — kept as-is.
        "captions": [
            "audi tts coupe 2012: a white audi tt parked in front of a dealers showroom",
            "acura tl sedan 2012: the black acura is parked in front of a dealership",
            "dodge dakota club cab 2007: a red dodge truck parked in a gravel lot",
            "hyundai sonata hybrid sedan 2012: the rear end of a red 2019 bmw i3",
            "ford f-450 super duty crew cab 2012: two white trucks parked in a parking lot"
        ],
        "image_indices": [0, 1, 2, 3, 4] 
    },
    "food": {
        "display_name": "Food101",
        "hf_dataset": "Multimodal-Fatima/Food101_train",
        "img_column": "image",
        "captions": [
            "beignets: a plate with a piece of cake and fork on it",
            "prime rib: a piece of meat in a sauce on a plate",
            "ramen: a bowl of ramen with meat and vegetables",
            "hamburger: a plate of food",
            "bruschetta: a sandwich with tomatoes and cheese on a white plate"
        ],
        "image_indices": [0, 750, 1500, 2250, 3000]
    },
    "stl": {
        "display_name": "STL10",
        "hf_dataset": "tanganke/stl10",
        "img_column": "image",
        "captions": [
            "airplane: a small plane sitting on a runway",
            "bird: a yellow bird with a yellow beak sitting on a branch",
            "car: a small yellow car with a driver's seat",
            "cat: a bobcat is sitting on the ground in the sun",
            "deer: a deer is walking down the street at night"
        ],
        "image_indices": [0, 500, 1000, 1500, 2000]
    },
    # Loaded with the "full_size" config by the generation loop
    "imagewoof": {
        "display_name": "Imagewoof",
        "hf_dataset": "frgfm/imagewoof",
        "img_column": "image",
        "captions": [
            "Samoyed: a brown and white dog",
            "English foxhound: a dog sitting in the grass",
            "Golden retriever: a white dog laying on the floor in a cage",
            "Shih-Tzu: a dog is running in the grass",
            "Old English sheepdog: a dingo is sitting in the grass"
        ],
        "image_indices": [0, 932, 1875, 2796, 3745]
    },
    "imagenette": {
        "display_name": "Imagenette",
        "hf_dataset": "Multimodal-Fatima/Imagenette_train",
        "img_column": "image",
        "captions": [
            "cassette player: a radio sitting on top of a counter",
            "tench: a man in a hat holding a fish in front of a pond",
            "chain saw: a red chainsaw with a black handle",
            "church: a large church with a tower",
            "parachute: a large sculpture of a man hanging from a rope"
        ],
        "image_indices": [0, 993, 1956, 2814, 3755]
    },
    "caltech101": {
        "display_name": "Caltech101",
        "hf_dataset": "dpdl-benchmark/caltech101",
        "img_column": "image",
        "captions": [
            "brain: a drawing of a brain with a red line through it",
            "inline_skate: a purple and black inline skate with wheels",
            "beaver: a beaver sitting on a log with a bird on the ground",
            "airplanes: a boeing 737-800 plane is parked at the airport",
            "dragonfly: a dragonfly necklace with a black and white bead"
        ],
        "image_indices": [0, 1, 2, 3, 4]
    },
}


# --- Accumulate both real and synthetic images for comparison ---
import gc


def _center_crop_square(img):
    """Return *img* center-cropped to its largest centered square.

    Objects without a ``size`` attribute and images that are already square are
    returned unchanged; otherwise ``img.crop`` is called with the square box.
    """
    if not hasattr(img, "size"):
        return img
    w, h = img.size
    if w == h:
        return img
    side = min(w, h)
    x0 = (w - side) // 2
    y0 = (h - side) // 2
    return img.crop((x0, y0, x0 + side, y0 + side))


# display_name -> {"real": [...], "synthetic": [...]} for the comparison grid
comparisons = {}

for dataset_name, info in dataset_info.items():
    # Load the train split; Imagewoof is loaded with its "full_size" config.
    if dataset_name == "imagewoof":
        hf_dataset = load_dataset(info["hf_dataset"], "full_size", split="train")
    else:
        hf_dataset = load_dataset(info["hf_dataset"], split="train")
    image_column = info["img_column"]

    # Real images, center-cropped to squares
    real_images = [_center_crop_square(hf_dataset[idx][image_column]) for idx in info["image_indices"]]

    # LoRA checkpoint directory: CIFAR runs are stored under "50k", the rest under "1.00"
    run_dir = "50k" if dataset_name in ["cifar10", "cifar100"] else "1.00"
    finetuned_sd_path = f"../storage/finetuned_SD/{dataset_name}/sd_2/only_unet/className_and_blip2/{run_dir}"

    # Generate one synthetic image per caption with the finetuned SD model
    sd_model = load_sd_model(
        pretrained_sd="stabilityai/stable-diffusion-2", 
        finetuned_sd_path=finetuned_sd_path, 
        finetuned_sd_weights="pytorch_lora_weights.safetensors", 
        device="cuda", 
        noise_sigma=1,
    )
    synthetic_images = generate_batch_images_with_stable_diffusion(
        pipeline=sd_model,
        text_prompts=info["captions"],
        sd_output_resolution=(224, 224),
        num_inference_steps=20,
        guidance_scale=2
    )

    comparisons[info["display_name"]] = {
        "real": real_images,
        "synthetic": synthetic_images
    }

    # Release this dataset's pipeline before the next iteration. Without this,
    # `sd_model = load_sd_model(...)` loads the next pipeline while the previous
    # one is still referenced, so two pipelines sit on the GPU at once.
    del sd_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

dataset_names = list(comparisons.keys())
# Images per row, taken from the first dataset's real images
n_samples = len(next(iter(comparisons.values()))["real"])

# Outer grid: one row per dataset, column 0 = real images, column 1 = synthetic
fig = plt.figure(figsize=(n_samples * 2, len(dataset_names)))
outer = gridspec.GridSpec(len(dataset_names), 2, wspace=0.1, hspace=0.1)

# Column headings, placed in figure coordinates
fig.text(0.31, 0.90, "Real", ha='center', fontsize=10, fontweight='bold', color='blue')
fig.text(0.715, 0.90, "Synthetic", ha='center', fontsize=10, fontweight='bold', color='green')

for row, ds_name in enumerate(dataset_names):
    for side, kind in enumerate(("real", "synthetic")):
        # Each outer cell is split into a 1 x n_samples strip of images
        strip = gridspec.GridSpecFromSubplotSpec(
            1, n_samples,
            subplot_spec=outer[row, side],
            wspace=0.05, hspace=0.05
        )
        images = comparisons[ds_name][kind]
        for col in range(n_samples):
            ax = fig.add_subplot(strip[0, col])
            ax.imshow(images[col])
            ax.axis("off")
            # Row label: once per row, to the left of the first real image
            if side == 0 and col == 0:
                ax.text(-0.3, 0.5, ds_name,
                        transform=ax.transAxes,
                        va="center", ha="right", fontsize=8, color='black', fontweight='bold')

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
plt.show()

In [None]:
# Write the real-vs-synthetic grid to PDF; bbox_inches='tight' trims the surrounding whitespace
fig.savefig(fname="../images/real_vs_synthetic_images.pdf", bbox_inches='tight')