In [1]:
import numpy as np
import os
import copy

In [2]:



def get_statistics(exp_data):
    """Summarise per-condition trial results as formatted mean/std strings.

    Args:
        exp_data: List of experiment dicts. Each dict must hold a sequence of
            numeric trial results under the keys 'in-domain', 'textile',
            'indoor', 'outdoor' and 'distractors'. Any other keys are carried
            over unchanged.

    Returns:
        A deep copy of ``exp_data`` (the input is not modified) in which each
        of the five condition entries is replaced by
        ``{'mean': '<x.xx>', 'std': '<x.xx>'}`` (strings with two decimals),
        plus an extra 'all' entry computed over the pooled trials of every
        condition except 'in-domain'. ``np.std`` is used with its default
        ``ddof=0`` (population standard deviation).
    """
    data_keys = ['in-domain', 'textile', 'indoor', 'outdoor', 'distractors']
    exp_stats = copy.deepcopy(exp_data)
    for exp in exp_stats:
        pooled = []
        for key in data_keys:
            values = exp[key]
            if key != 'in-domain':
                # Pool individual trials (extend, not append) so conditions
                # with differing numbers of trials still aggregate correctly
                # instead of forming a ragged list of lists.
                pooled.extend(values)
            exp[key] = {'mean': f'{np.mean(values):.2f}', 'std': f'{np.std(values):.2f}'}

        exp['all'] = {'mean': f'{np.mean(pooled):.2f}', 'std': f'{np.std(pooled):.2f}'}
    return exp_stats
            
def print_table(exp_data, format='markdown'):
    """Print a results table with one row per experiment.

    Statistics come from ``get_statistics``. For every column, the cell(s)
    whose (two-decimal) mean equals the best mean in that column are set in
    bold; ties are all bolded.

    Args:
        exp_data: List of experiment dicts accepted by ``get_statistics``;
            each must also provide 'Net' and 'task' strings for the first
            two columns.
        format: 'markdown' (default) or 'latex'.

    Raises:
        ValueError: if ``format`` is neither 'markdown' nor 'latex'.
    """
    if format not in ('markdown', 'latex'):
        raise ValueError(f"Unsupported format {format!r}; expected 'markdown' or 'latex'")

    columns = ['in-domain', 'textile', 'indoor', 'outdoor', 'distractors', 'all']
    exp_stats = get_statistics(exp_data)
    # default=None keeps an empty experiment list from crashing max(); in that
    # case only the table header/footer is printed.
    max_means = {
        key: max((float(exp[key]['mean']) for exp in exp_stats), default=None)
        for key in columns
    }

    def build_row(exp, plain, bold):
        # Assemble the cells of one row, using the `bold` template for the
        # column maxima and the `plain` template otherwise.
        row = [exp['Net'], exp['task']]
        for key in columns:
            mean_value = exp[key]['mean']
            std_value = exp[key]['std']
            template = bold if float(mean_value) == max_means[key] else plain
            row.append(template.format(mean=mean_value, std=std_value))
        return row

    if format == 'markdown':
        print('| Model | Task | In-domain | Textile | Indoor | Outdoor | Distractors | All |')
        print('| --- | --- | --- | --- | --- | --- | --- | --- |')
        for exp in exp_stats:
            row = build_row(exp, '{mean} +/- {std}', '**{mean} +/- {std}**')
            print("| " + " | ".join(row) + " |")
    else:  # 'latex'
        print('\\begin{table}[h]')
        print('\\centering')
        print('\\begin{tabular}{|c|c|c|c|c|c|c|c|}')
        print('\\hline')
        print('Model & Task & In-domain & Textile & Indoor & Outdoor & Distractors & All \\\\')
        print('\\hline')
        for exp in exp_stats:
            # \pm is a math-mode symbol, so it is wrapped in $...$; the
            # backslashes are escaped explicitly ("\p" is not a valid Python
            # string escape).
            row = build_row(exp, '{mean} $\\pm$ {std}', '\\textbf{{{mean} $\\pm$ {std}}}')
            print(" & ".join(row) + " \\\\")
        print('\\hline')
        print('\\end{tabular}')
        print('\\end{table}')
    
        

In [11]:
# Per-experiment results. Every dict holds the model label ('Net'), the task
# name and, for each evaluation condition, a list of three numeric results.

# --- Square task, BC-RNN policy ---
saga_bc_rnn_square = {
    'Net': 'SaGA', 'task': 'square',
    'in-domain': [0.82, 0.84, 0.84],
    'textile': [0.7, 0.64, 0.68],
    'indoor': [0.42, 0.42, 0.38],
    'outdoor': [0.5, 0.54, 0.46],
    'distractors': [0.9, 0.78, 0.76],
}
soda_bc_rnn_square = {
    'Net': 'SODA', 'task': 'square',
    'in-domain': [0.76, 0.8, 0.8],
    'textile': [0.6, 0.44, 0.48],
    'indoor': [0.28, 0.18, 0.3],
    'outdoor': [0.22, 0.18, 0.22],
    'distractors': [0.54, 0.58, 0.66],
}
overlay_bc_rnn_square = {
    'Net': 'Overlay', 'task': 'square',
    'in-domain': [0.82, 0.82, 0.74],
    'textile': [0.6, 0.56, 0.36],
    'indoor': [0.22, 0.22, 0.12],
    'outdoor': [0.24, 0.24, 0.2],
    'distractors': [0.5, 0.44, 0.56],
}

# --- Square task, diffusion policy ---
saga_diff_square = {
    'Net': 'SaGA', 'task': 'square',
    'in-domain': [0.88, 0.86, 0.86],
    'textile': [0.76, 0.72, 0.8],
    'indoor': [0.72, 0.54, 0.7],
    'outdoor': [0.64, 0.58, 0.56],
    'distractors': [0.78, 0.82, 0.82],
}
soda_diff_square = {
    'Net': 'SODA', 'task': 'square',
    'in-domain': [0.9, 0.88, 0.88],
    'textile': [0.72, 0.68, 0.68],
    'indoor': [0.46, 0.52, 0.44],
    'outdoor': [0.38, 0.52, 0.52],
    'distractors': [0.8, 0.8, 0.78],
}
overlay_diff_square = {
    'Net': 'Overlay', 'task': 'square',
    'in-domain': [0.92, 0.92, 0.90],
    'textile': [0.62, 0.74, 0.62],
    'indoor': [0.4, 0.4, 0.42],
    'outdoor': [0.44, 0.44, 0.52],
    'distractors': [0.76, 0.8, 0.78],
}
baseline_diff_square = {
    'Net': 'Baseline', 'task': 'square',
    'in-domain': [0.92, 0.92, 0.90],
    'textile': [0.08, 0.12, 0.06],
    'indoor': [0.0, 0.0, 0.0],
    'outdoor': [0.0, 0.0, 0.0],
    'distractors': [0.64, 0.66, 0.68],
}

# --- Lift task, BC-RNN policy ---
lift_bc_rnn_saga = {
    'Net': 'SaGA', 'task': 'lift',
    'in-domain': [1, 1, 1],
    'textile': [0.9, 0.96, 0.98],
    'indoor': [0.8, 0.82, 0.98],
    'outdoor': [0.86, 0.98, 0.94],
    'distractors': [1.0, 1.0, 0.92],
}
# The label was 'SaGA' here (copy-paste slip); corrected to match the variable.
# NOTE(review): these values are almost identical to lift_bc_rnn_saga (only one
# 'distractors' entry differs) -- confirm they are real SODA results before
# adding this entry to lift_exps.
lift_bc_rnn_soda = {
    'Net': 'SODA', 'task': 'lift',
    'in-domain': [1, 1, 1],
    'textile': [0.9, 0.96, 0.98],
    'indoor': [0.8, 0.82, 0.98],
    'outdoor': [0.86, 0.98, 0.94],
    'distractors': [1.0, 0.9, 0.92],
}

# Groupings passed to print_table.
bc_rnn_square_exps = [saga_bc_rnn_square, soda_bc_rnn_square, overlay_bc_rnn_square]

diff_square_exps = [saga_diff_square, soda_diff_square, overlay_diff_square, baseline_diff_square]

lift_exps = [lift_bc_rnn_saga]


In [12]:
# Print the lift results using print_table's default format ('markdown').
print_table(lift_exps)

| Model | Task | In-domain | Textile | Indoor | Outdoor | Distractors | All |
| --- | --- | --- | --- | --- | --- | --- | --- |
| SaGA | lift | **1.00 +/- 0.00** | **0.95 +/- 0.03** | **0.87 +/- 0.08** | **0.93 +/- 0.05** | **0.97 +/- 0.01** | **0.93 +/- 0.06** |



## Square
### BC-RNN
| Model | Task | In-domain | Textile | Indoor | Outdoor | Distractors | All |
| --- | --- | --- | --- | --- | --- | --- | --- |
| SaGA | square | **0.83 +/- 0.01** | **0.67 +/- 0.02** | **0.41 +/- 0.02** | **0.50 +/- 0.03** | **0.81 +/- 0.06** | **0.60 +/- 0.16** |
| SODA | square | 0.79 +/- 0.02 | 0.51 +/- 0.07 | 0.25 +/- 0.05 | 0.21 +/- 0.02 | 0.59 +/- 0.05 | 0.39 +/- 0.17 |
| Overlay | square | 0.79 +/- 0.04 | 0.51 +/- 0.10 | 0.19 +/- 0.05 | 0.23 +/- 0.02 | 0.50 +/- 0.05 | 0.35 +/- 0.16 |


### Diffusion Policy
| Model | Task | In-domain | Textile | Indoor | Outdoor | Distractors | All |
| --- | --- | --- | --- | --- | --- | --- | --- |
| SaGA | square | 0.87 +/- 0.01 | **0.76 +/- 0.03** | **0.65 +/- 0.08** | **0.59 +/- 0.03** | **0.81 +/- 0.02** | **0.70 +/- 0.10** |
| SODA | square | 0.89 +/- 0.01 | 0.69 +/- 0.02 | 0.47 +/- 0.03 | 0.47 +/- 0.07 | 0.79 +/- 0.01 | 0.61 +/- 0.14 |
| Overlay | square | **0.91 +/- 0.01** | 0.66 +/- 0.06 | 0.41 +/- 0.01 | 0.47 +/- 0.04 | 0.78 +/- 0.02 | 0.58 +/- 0.15 |
| Baseline | square | **0.91 +/- 0.01** | 0.09 +/- 0.02 | 0.00 +/- 0.00 | 0.00 +/- 0.00 | 0.66 +/- 0.02 | 0.19 +/- 0.28 |



## Lift
| Model | Task | In-domain | Textile | Indoor | Outdoor | Distractors | All |
| --- | --- | --- | --- | --- | --- | --- | --- |
| SaGA | lift | **1.00 +/- 0.00** | **0.95 +/- 0.03** | **0.87 +/- 0.08** | **0.93 +/- 0.05** | **0.97 +/- 0.01** | **0.93 +/- 0.06** |






