In [6]:
import os
import pickle
import numpy as np

In [9]:
# Methods compared in the ablation table: lettered variants A-E, then the
# unsuffixed WMSSDA (presumably the full model -- confirm against the paper).
methods = [f'WMSSDA-{variant}' for variant in 'ABCDE'] + ['WMSSDA']

# Result-dict metric keys mapped to the column labels shown in the table (order = row order).
clean_metrics = dict(bal_acc='bACC', auc='AUC', f1='F1')

# Experiment tag used as the prefix of every result file name.
dataset = 'ABLATION'

# Result-dict domain keys mapped to the header labels shown in the table (order = column order).
clean_domains = dict(mnist='MNIST', mnistm='MNIST-M', svhn='SVHN', syn='SYN', usps='USPS')

# Run settings encoded in the result file names (..._r{runs}_e{epochs}_b{batch_size}.pickle).
runs = 5
batch_size = 128
epochs = 150

# Optional extra file-name prefix inserted before the method name; empty for this table.
py = ''

In [11]:
# Print a LaTeX ablation table to stdout: one \multirow group per method and one row
# per metric. Each domain gets a "mean \pm std" cell followed by an empty spacer cell,
# because every domain header spans 2 columns. A pooled Avg cell closes the row.
# The best mean per (domain, metric) across all methods is set in bold.
# LaTeX snippets use raw strings so backslashes are not parsed as Python escapes
# ('\h', '\m', '\p' are invalid escapes and warn on Python 3.12+).
end_line = ' \\\\'  # LaTeX row terminator, i.e. " \\"


def result_path(method):
    """Return the pickle path holding `method`'s results for the configured run settings."""
    return f'results/{dataset}_{py}{method.replace("-", "_")}_r{runs}_e{epochs}_b{batch_size}.pickle'


def load_results(method):
    """Load the pickled results for `method`, or return None if the file does not exist.

    The loaded object is indexed as results[domain][metric] and yields a sequence of
    values (presumably one value per run -- TODO confirm against the training script).
    NOTE: pickle.load can execute arbitrary code; only load result files you produced.
    """
    path = result_path(method)
    if not os.path.exists(path):
        return None
    with open(path, 'rb') as f:  # close the handle instead of leaking it
        return pickle.load(f)


def format_value(values, metric, with_std=True):
    """Return the mean of `values` as LaTeX text, optionally followed by a std term.

    AUC values are divided by 100 and printed with 4 decimals and no leading zero
    (e.g. ".9985"). Every other metric is printed with 2 decimals.
    """
    if metric == 'auc':
        def fmt(x):
            return f'{x / 100.:.04f}'.lstrip('0')
    else:
        def fmt(x):
            return f'{x:.02f}'
    text = fmt(np.mean(values))
    if with_std:
        text += rf' \pm {fmt(np.std(values))}'
    return text


def math_cell(text, bold):
    """Wrap `text` as an inline-math table cell ending in '&', bolded via mathbf when `bold`."""
    body = rf'\mathbf{{{text}}}' if bold else text
    return f' ${body}$ &'


# Load every method's results exactly once; None marks a method with no result file.
results_by_method = {method: load_results(method) for method in methods}

# Best (largest) mean per domain, and for the pooled average, across all methods with results.
bests = {domain: {metric: 0 for metric in clean_metrics}
         for domain in list(clean_domains) + ['avg']}
for metrics in results_by_method.values():
    if metrics is None:
        continue
    for metric in clean_metrics:
        pooled = []
        for domain in clean_domains:
            values = metrics[domain][metric]
            pooled.extend(values)
            bests[domain][metric] = max(bests[domain][metric], np.mean(values))
        bests['avg'][metric] = max(bests['avg'][metric], np.mean(pooled))

# Header row: every domain and Avg spans 2 columns; a double rule precedes Avg.
print(r'\hline')
header = 'Method & Metric &'
last_domain = len(clean_domains) - 1
for i, clean_domain in enumerate(clean_domains.values()):
    rule = '||' if i == last_domain else '|'
    header += rf' \multicolumn{{2}}{{l{rule}}}{{{clean_domain}}} &'
header += r' \multicolumn{2}{l|}{Avg}'
print(header + end_line)
print(r'\hline')

# Cells after the metric label in a populated row: (value + spacer) per domain, then Avg.
# Rows for methods without results get the same number of (empty) cells so every row
# matches the header width.
n_value_cells = 2 * len(clean_domains) + 1

for i, method in enumerate(methods):
    metrics = results_by_method[method]
    for row, (metric, clean_metric) in enumerate(clean_metrics.items()):
        # Only the first row of a method group opens the \multirow label cell.
        if row == 0:
            line = f'\\multirow{{{len(clean_metrics)}}}{{*}}{{{method}}} &'
        else:
            line = '&'
        line += f' {clean_metric} &'
        if metrics is None:
            line += ' &' * n_value_cells
        else:
            pooled = []
            for domain in clean_domains:
                values = metrics[domain][metric]
                pooled.extend(values)
                is_best = np.mean(values) == bests[domain][metric]
                line += math_cell(format_value(values, metric), is_best)
                line += ' &'  # empty spacer cell under the 2-wide domain header
            is_best = np.mean(pooled) == bests['avg'][metric]
            line += math_cell(format_value(pooled, metric, with_std=False), is_best)
        print(line + end_line)
    if i < len(methods) - 1:
        print(r'\hline')
print(r'\hline')

\hline
Method & Metric & \multicolumn{2}{l|}{MNIST} & \multicolumn{2}{l|}{MNIST-M} & \multicolumn{2}{l|}{SVHN} & \multicolumn{2}{l|}{SYN} & \multicolumn{2}{l||}{USPS} & \multicolumn{2}{l|}{Avg} \\
\hline
\multirow{3}{*}{WMSSDA-A} & bACC & $94.95 \pm 0.58$ & & $72.02 \pm 1.24$ & & $67.69 \pm 2.28$ & & $80.94 \pm 0.49$ & & $96.57 \pm 0.22$ & & $82.43$ & \\
& AUC & $.9985 \pm .0001$ & & $.9620 \pm .0024$ & & $.9429 \pm .0054$ & & $.9783 \pm .0012$ & & $.9989 \pm .0002$ & & $.9761$ & \\
& F1 & $94.89 \pm 0.61$ & & $71.20 \pm 1.35$ & & $65.64 \pm 2.70$ & & $80.63 \pm 0.53$ & & $96.83 \pm 0.19$ & & $81.84$ & \\
\hline
\multirow{3}{*}{WMSSDA-B} & bACC & $93.77 \pm 0.90$ & & $74.62 \pm 1.50$ & & $66.34 \pm 2.24$ & & $79.72 \pm 1.20$ & & $95.89 \pm 0.35$ & & $82.07$ & \\
& AUC & $.9980 \pm .0003$ & & $.9721 \pm .0021$ & & $.9414 \pm .0048$ & & $.9805 \pm .0013$ & & $.9982 \pm .0005$ & & $.9781$ & \\
& F1 & $93.71 \pm 0.92$ & & $73.79 \pm 1.72$ & & $62.83 \pm 2.40$ & & $79.48 \pm 1.18$ & & $96.0