In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import yaml
import copy

In [None]:
# Path to the exported results CSV; fill this in before running the notebook.
results_file = ''
if not results_file:
    raise ValueError('Set `results_file` to the path of the exported results CSV before running this notebook.')

data = pd.read_csv(results_file)
# All analysis below reads `filtered_data`; currently it is the unfiltered frame.
filtered_data = data
datasets = filtered_data['dataset.name'].unique()

# Experiment configuration; later cells read its 'paper_replication_experiments' section.
with open('experiment_config.yml') as f:
    # safe_load: assumes the config is plain YAML (no Python-specific tags) — TODO confirm.
    experiment_config = yaml.safe_load(f)
datasets

In [None]:
# Section of the experiment config that describes the paper-replication grid.
paper_replication_experiments = experiment_config['paper_replication_experiments']

# Display name -> key used under paper_replication_experiments['datasets'].
dataset_to_dataset_choice = dict(
    FieldGuide28='fg28',
    FieldGuide2='fg2',
    CIFAR20='cifar20',
    CIFAR10='cifar10',
    ImageNet50='imagenet',
)

# Approach -> results-column name for each reported metric. SCAN has its own
# columns; every other approach reads the post-clustering columns.
approach_to_metric_map = {
    approach_name: (
        {'test_acc': 'scan_alone_best_acc', 'test_err': 'scan_alone_reconstruction_error_L1'}
        if approach_name == 'SCAN'
        else {'test_acc': 'test_post_cluster_acc', 'test_err': 'test_post_cluster_p_y_given_d_l1_norm'}
    )
    for approach_name in (
        'SCAN', 'DDFA (SPI)', 'DDFA (SI)', 'DDFA (RI)', 'Naïve', 'Naïve (ICA)', 'Naïve (PCA)',
    )
}

# Approaches evaluated on each dataset, in table / plot order.
approaches = dict(
    CIFAR10=['SCAN', 'DDFA (RI)', 'DDFA (SI)'],
    CIFAR20=['SCAN', 'DDFA (RI)', 'DDFA (SI)', 'Naïve', 'Naïve (ICA)', 'Naïve (PCA)'],
    ImageNet50=['SCAN', 'DDFA (SI)'],
    FieldGuide2=['SCAN', 'DDFA (SPI)'],
    FieldGuide28=['SCAN', 'DDFA (SPI)'],
)

# Per dataset: approach -> value of the 'discriminator.name' column identifying its runs.
approach_to_dd_name = dict(
    CIFAR10={
        'SCAN': 'scan_scan',
        'DDFA (RI)': 'CIFAR10PytorchCifar',
        'DDFA (SI)': 'scan_scan',
    },
    CIFAR20={
        'SCAN': 'scan_scan',
        'DDFA (RI)': 'CIFAR10PytorchCifar',
        'DDFA (SI)': 'scan_scan',
        'Naïve': 'scan_scan_naive',
        'Naïve (ICA)': 'scan_scan_naive_ica',
        'Naïve (PCA)': 'scan_scan_naive_pca',
    },
    ImageNet50={'SCAN': 'scan_scan_imagenet', 'DDFA (SI)': 'scan_scan_imagenet'},
    FieldGuide2={'SCAN': 'scan_pretext', 'DDFA (SPI)': 'scan_pretext'},
    FieldGuide28={'SCAN': 'scan_pretext', 'DDFA (SPI)': 'scan_pretext'},
)

# Same (dataset, approach) keys as approach_to_dd_name; every entry is filtered on
# the 'ClusterNMFClassPriorEstimation' class-prior estimator.
approach_to_class_prior_estimator_name = {
    ds_name: {approach_name: 'ClusterNMFClassPriorEstimation' for approach_name in per_approach}
    for ds_name, per_approach in approach_to_dd_name.items()
}

# Config entry of every dataset present in the results, looked up once.
_dataset_configs = {
    ds_name: paper_replication_experiments['datasets'][dataset_to_dataset_choice[ds_name]]
    for ds_name in datasets
}
dataset_to_kappa_list = {
    ds_name: ds_config['max_condition_numbers'] for ds_name, ds_config in _dataset_configs.items()
}
dataset_to_domains = {
    ds_name: sorted(ds_config['domains']) for ds_name, ds_config in _dataset_configs.items()
}

def compute_metric_agg(d, metric, alpha, approach, dataset, n_trials=3):
    """Mean and standard deviation of one metric over the runs of one setting.

    Selects the rows of the notebook-level ``filtered_data`` whose
    'class_prior.n_domains', 'class_prior.alpha' and 'dataset.name' equal
    ``d``, ``alpha`` and ``dataset``, and whose discriminator / class-prior
    estimator are the ones configured for ``approach`` on ``dataset`` in
    ``approach_to_dd_name`` / ``approach_to_class_prior_estimator_name``.

    Args:
        d: number of domains.
        metric: 'test_acc' or 'test_err'; mapped to a results column through
            ``approach_to_metric_map[approach]``.
        alpha: class-prior alpha of the setting.
        approach: approach label, e.g. 'SCAN' or 'DDFA (SI)'.
        dataset: dataset name, e.g. 'CIFAR20'.
        n_trials: exact number of matching runs required (default 3); any
            other count is reported as missing.

    Returns:
        ``(mean, std)`` rounded to 3 decimals. For 'test_err' both are divided
        by ``d`` times the mean 'class_prior.n_classes' of the dataset's rows.
        ``('---', '---')`` when the number of matching runs is not ``n_trials``.
    """
    dd_name = approach_to_dd_name[dataset][approach]
    class_prior_estimator_name = approach_to_class_prior_estimator_name[dataset][approach]

    mask = (
        (filtered_data['class_prior.n_domains'] == d)
        & (filtered_data['class_prior.alpha'] == alpha)
        & (filtered_data['dataset.name'] == dataset)
        & (filtered_data['discriminator.name'] == dd_name)
        & (filtered_data['class_prior_estimator.name'] == class_prior_estimator_name)
    )

    # Necessary for the clustering ablation: these CIFAR20 runs also sweep the
    # number of clusters ('class_prior_estimator.n_discretization'); keep only
    # the main setting.
    main_size_cifar_20 = 20
    if dataset == 'CIFAR20' and approach in ('DDFA (SI)', 'SCAN'):
        mask = mask & (filtered_data['class_prior_estimator.n_discretization'] == main_size_cifar_20)

    rows = filtered_data[mask]
    if len(rows) != n_trials:
        return '---', '---'

    values = rows[approach_to_metric_map[approach][metric]]
    mean_pre_round = float(values.mean())
    std_pre_round = float(values.std())
    if metric == 'test_err':
        # Normalize by d * n_classes; the factor is shared by mean and std.
        n_classes = filtered_data.loc[filtered_data['dataset.name'] == dataset, 'class_prior.n_classes'].mean()
        scale = d * n_classes
        mean_pre_round /= scale
        std_pre_round /= scale

    mean = np.round(mean_pre_round, decimals=3)
    std = np.round(std_pre_round, decimals=3)
    return mean, std

# Metric keys with their human-readable plot titles (also used by later plotting cells).
metrics_and_metric_names = [('test_acc','Test accuracy'), ('test_err',r'Avg $Q_{Y|D}$ L1 reconstruction error per domain')]

# plot_results[dataset][d][metric][(alpha, kappa)][approach] -> (mean, std) from compute_metric_agg.
plot_results = {}
for dataset in datasets:
    alpha_kappa_list = list(zip([0.5, 3, 10], dataset_to_kappa_list[dataset]))
    domains = dataset_to_domains[dataset]
    dataset_approaches = approaches[dataset]
    plot_results[dataset] = {
        d: {
            metric: {
                setting: {
                    approach: compute_metric_agg(d, metric, setting[0], approach, dataset)
                    for approach in dataset_approaches
                }
                for setting in alpha_kappa_list
            }
            for metric, _label in metrics_and_metric_names
        }
        for d in domains
    }
plot_results

In [None]:
# Print LaTeX table rows for the main results and the naive ablation.
# For each dataset and domain count there is one row per printed approach; in every
# (alpha, kappa) x metric column the best mean among the printed approaches is bolded
# (highest for 'test_acc', lowest for 'test_err').
print_std = True
print_naive = True
table_metrics = ('test_acc', 'test_err')
for dataset in ['CIFAR10', 'CIFAR20', 'ImageNet50', 'FieldGuide2', 'FieldGuide28']:
    print(dataset)
    print('\n')
    # Only the approaches that appear as rows compete for bold and are spanned
    # by the \multirow domain label.
    shown_approaches = [
        approach for approach in approaches[dataset]
        if print_naive or 'Naïve' not in approach
    ]
    alpha_kappa_list = list(zip([0.5, 3, 10], dataset_to_kappa_list[dataset]))
    for d in dataset_to_domains[dataset]:
        results_d = plot_results[dataset][d]

        # Best available mean per (metric, (alpha, kappa)); missing results are '---'.
        best_mean = {}
        for metric in table_metrics:
            pick = max if metric == 'test_acc' else min
            for setting in alpha_kappa_list:
                available = [
                    results_d[metric][setting][approach][0]
                    for approach in shown_approaches
                    if not isinstance(results_d[metric][setting][approach][0], str)
                ]
                if available:
                    best_mean[metric, setting] = pick(available)

        for approach in shown_approaches:
            format_string = '&'
            if approach == 'SCAN':
                print(r'\midrule')
                format_string = rf'\multirow{{{len(shown_approaches)}}}{{*}}{{{d}}} &'
            format_string += f'{approach}'
            for setting in alpha_kappa_list:
                for metric in table_metrics:
                    mean, std = results_d[metric][setting][approach]
                    if isinstance(mean, str) or isinstance(std, str):
                        format_string += ' & ---'
                        continue
                    best = best_mean[metric, setting]
                    is_best = mean >= best if metric == 'test_acc' else mean <= best
                    if is_best:
                        if print_std:
                            format_string += rf' & \textbf{{{mean:.3f}}} $\pm$ {std:.3f} '
                        else:
                            format_string += rf' & \textbf{{{mean:.3f}}}'
                    elif print_std:
                        format_string += rf' & {mean:.3f} $\pm$ {std:.3f} '
                    else:
                        format_string += f' & {mean:.3f} '
            format_string += r' \\'
            print(format_string)
    print('\n\n\n\n')

In [None]:
# Print poster barplots: SCAN vs. DDFA (SPI) on FieldGuide28, one row of panels per
# domain count and one column per metric, bars grouped by (alpha, kappa) setting.
poster_dataset = 'FieldGuide28'
poster_approaches = ['SCAN', 'DDFA (SPI)']
poster_metrics = [('test_acc', 'Test accuracy'), ('test_err', r'Avg $Q_{Y|D}$ L1 reconstruction error per domain')]
# compute_metric_agg only reports settings backed by exactly this many runs.
n_trials = 3

domains = dataset_to_domains[poster_dataset]
alpha_kappa_list = list(zip([0.5, 3, 10], dataset_to_kappa_list[poster_dataset]))
labels = [f'$\\alpha$: {alpha}, $\\kappa$: {kappa}' for alpha, kappa in alpha_kappa_list]
x = np.arange(len(alpha_kappa_list))  # the label locations
width = 0.35  # the width of the bars

# One row per domain count; squeeze=False keeps `axes` 2-D even for a single row.
fig, axes = plt.subplots(
    len(domains), len(poster_metrics),
    figsize=(9, 2.5 * len(domains)), dpi=1000, squeeze=False,
)
for row, d in enumerate(domains):
    for col, (metric, metric_name) in enumerate(poster_metrics):
        ax = axes[row, col]

        # https://stackoverflow.com/questions/25812255/row-and-column-headers-in-matplotlibs-subplots
        if row == 0:
            ax.set_title(metric_name)
        if col == 0:
            ax.set_ylabel(f'{d} domains', rotation=90, size='large')

        # https://matplotlib.org/stable/gallery/lines_bars_and_markers/barchart.html
        for offset, approach in zip((-width / 2, width / 2), poster_approaches):
            heights = [
                plot_results[poster_dataset][d][metric][setting][approach][0]
                for setting in alpha_kappa_list
            ]
            rects = ax.bar(x + offset, heights, width, label=approach)
            ax.bar_label(rects, padding=3)

        ax.set_xticks(x, labels)
        ax.set_ylim([0, 0.81] if metric == 'test_acc' else [0, 2.5])
        ax.legend()

fig.suptitle(f'FieldGuide-28 Performance at different domain counts, avg. over {n_trials} trials')
fig.tight_layout()

In [None]:
# Cross-check: the condition number logged with each run should match the
# prior lookup table (checked in the loop below).
prior = pd.read_csv('prior_lookup.csv')
# NOTE(review): this rebinds `data`, which the loading cell set to the main
# results frame; `filtered_data` still refers to that original frame.
data = pd.read_csv('wandb_export_2023-01-15T10_07_03.702-05_00.csv')

def get_prior_cond_number(n_classes, n_domains, alpha, seed, dataframe_lookup):
    """Return the condition number stored in a prior lookup table for one configuration.

    Args:
        n_classes, n_domains, alpha, seed: values matched against the
            'class_prior.n_classes', 'class_prior.n_domains',
            'class_prior.alpha' and 'class_prior.random_seed' columns.
        dataframe_lookup: lookup DataFrame with those columns plus
            'class_prior.condition_number'.

    Returns:
        The 'class_prior.condition_number' of the first matching row.

    Raises:
        IndexError: if no row matches; the message names the configuration.
    """
    matches = dataframe_lookup[
        (dataframe_lookup['class_prior.n_classes'] == n_classes)
        & (dataframe_lookup['class_prior.n_domains'] == n_domains)
        & (dataframe_lookup['class_prior.alpha'] == alpha)
        & (dataframe_lookup['class_prior.random_seed'] == seed)
    ]
    if matches.empty:
        # IndexError, as positional lookup of a missing first row would raise.
        raise IndexError(
            f'no prior-lookup row for n_classes={n_classes}, n_domains={n_domains}, '
            f'alpha={alpha}, seed={seed}'
        )
    return matches['class_prior.condition_number'].iloc[0]

# Every run's logged condition number must equal the lookup-table value for its
# (n_classes, n_domains, alpha, seed) configuration.
for i, row in data.iterrows():
    cond          = row['class_prior.condition_number']
    n_classes     = row['class_prior.n_classes']
    n_domains     = row['class_prior.n_domains']
    alpha         = row['class_prior.alpha']
    seed          = row['class_prior.random_seed']

    prior_cond = get_prior_cond_number(n_classes, n_domains, alpha, seed, prior)

    assert cond == prior_cond, (
        f'condition number mismatch for run at index {i} '
        f'(n_classes={n_classes}, n_domains={n_domains}, alpha={alpha}, seed={seed}): '
        f'logged {cond}, lookup {prior_cond}'
    )

In [None]:
# PLOT cluster ablation - accuracy
# Test accuracy vs. the number of clusters m ('class_prior_estimator.n_discretization')
# on CIFAR20 with 20 domains: DDFA (SI) post-clustering accuracy vs. SCAN alone, per alpha.

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df = pd.read_csv('wandb_export_2023-01-15T10_07_03.702-05_00.csv')

d = 20
dataset = 'CIFAR20'
dd_name = 'scan_scan'
class_prior_estimator_name = 'ClusterNMFClassPriorEstimation'

ms = [10, 20, 35, 50, 100, 150]
alphas = [0.5, 3, 10]
n_trials = 3  # runs expected for every (alpha, m) setting

colors_1 = ['red', 'blue', 'orange']    # DDFA (SI) line color per alpha
colors_2 = ['green', 'black', 'brown']  # SCAN line color per alpha

# Runs of the ablation: everything fixed except alpha and m.
ablation_runs = df[
    (df['class_prior.n_domains'] == d)
    & (df['dataset.name'] == dataset)
    & (df['discriminator.name'] == dd_name)
    & (df['class_prior_estimator.name'] == class_prior_estimator_name)
]

fig, ax = plt.subplots(figsize=(9, 6))

for alpha, color_1, color_2 in zip(alphas, colors_1, colors_2):
    # One frame of runs per m for this alpha, selected once and reused below.
    runs_per_m = [
        ablation_runs[
            (ablation_runs['class_prior.alpha'] == alpha)
            & (ablation_runs['class_prior_estimator.n_discretization'] == m)
        ]
        for m in ms
    ]
    for m, runs in zip(ms, runs_per_m):
        assert len(runs) == n_trials, (
            f'expected {n_trials} runs for alpha={alpha}, m={m}, found {len(runs)}'
        )

    # (results column, legend name, color, line style, mark the sampled m values)
    for attr, approach_label, color, linestyle, mark_points in [
        ('test_post_cluster_acc', 'DDFA (SI)', color_1, '-', True),
        ('scan_alone_best_acc', 'SCAN', color_2, '--', False),
    ]:
        perfs = np.array([runs[attr].mean() for runs in runs_per_m])
        std = np.array([runs[attr].std() for runs in runs_per_m])
        ax.fill_between(ms, perfs + std, perfs - std, alpha=0.1, color=color)
        ax.plot(ms, perfs, linestyle, label=f'alpha {alpha}: {approach_label}', color=color)
        if mark_points:
            ax.plot(ms, perfs, 'ko')

ax.set_title('Test accuracy vs. number of clusters m, on CIFAR20 with 20 domains, across problem settings (alpha)')
ax.set_xticks(np.arange(min(ms), max(ms) + 1, 10))
ax.legend()
ax.set_xlabel('Number of clusters m')
ax.set_ylabel('Test accuracy')
fig.savefig('m_ablation_test_acc.pdf')

In [None]:
# PLOT cluster ablation - error
# Q_{Y|D} L1 error vs. the number of clusters m for the same CIFAR20 / 20-domain
# ablation; uses df, d, dataset, dd_name, class_prior_estimator_name, ms and alphas
# defined in the accuracy cell.

n_trials = 3  # runs expected for every (alpha, m) setting
# Scale the L1 errors like compute_metric_agg: divide by d times the dataset's
# mean 'class_prior.n_classes'.
err_scale = d * df.loc[df['dataset.name'] == dataset, 'class_prior.n_classes'].mean()

colors_1 = ['red', 'blue', 'orange']    # DDFA (SI) line color per alpha
colors_2 = ['green', 'black', 'brown']  # SCAN line color per alpha

# Runs of the ablation: everything fixed except alpha and m.
ablation_runs = df[
    (df['class_prior.n_domains'] == d)
    & (df['dataset.name'] == dataset)
    & (df['discriminator.name'] == dd_name)
    & (df['class_prior_estimator.name'] == class_prior_estimator_name)
]

# Explicit figure: no plt.clf(), which would create and clear an extra empty figure.
fig, ax = plt.subplots(figsize=(9, 6))

for alpha, color_1, color_2 in zip(alphas, colors_1, colors_2):
    runs_per_m = [
        ablation_runs[
            (ablation_runs['class_prior.alpha'] == alpha)
            & (ablation_runs['class_prior_estimator.n_discretization'] == m)
        ]
        for m in ms
    ]
    # Count matching runs directly, independent of which metric column is plotted.
    for m, runs in zip(ms, runs_per_m):
        assert len(runs) == n_trials, (
            f'expected {n_trials} runs for alpha={alpha}, m={m}, found {len(runs)}'
        )

    # (results column, legend name, color, line style, mark the sampled m values)
    for attr, approach_label, color, linestyle, mark_points in [
        ('test_post_cluster_p_y_given_d_l1_norm', 'DDFA (SI)', color_1, '-', True),
        ('scan_alone_reconstruction_error_L1', 'SCAN', color_2, '--', False),
    ]:
        perfs = np.array([runs[attr].mean() for runs in runs_per_m]) / err_scale
        std = np.array([runs[attr].std() for runs in runs_per_m]) / err_scale
        ax.fill_between(ms, perfs + std, perfs - std, alpha=0.1, color=color)
        ax.plot(ms, perfs, linestyle, label=f'alpha {alpha}: {approach_label}', color=color)
        if mark_points:
            ax.plot(ms, perfs, 'ko')

ax.set_title(r'$Q_{Y|D}$ err vs. number of clusters m, on CIFAR20 with 20 domains, across problem settings (alpha)')
ax.set_xticks(np.arange(min(ms), max(ms) + 1, 10))
ax.legend()
ax.set_xlabel('Number of clusters m')
ax.set_ylabel(r'$Q_{Y|D}$ err')
fig.savefig('m_ablation_q_y_d_err.pdf')