In [None]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
import plot_settings as plot_settings
from PIL import Image
from IPython.display import clear_output
# Clear whatever the imports above printed so this cell's output stays empty.
clear_output(wait=True)

In [None]:
# Apply the project-wide matplotlib settings from the local `plot_settings` module
# (presumably LaTeX text rendering, judging by the name — see plot_settings.py).
plot_settings.set_latex_settings()

## Efficiency

In [None]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
import plot_settings as plot_settings
from PIL import Image

plot_settings.set_latex_settings()

shapes = Image.open(r"images/Picture1.png")
mpi = Image.open(r"images/Picture2.png")
cars = Image.open(r"images/Picture3.png")
dsprites = Image.open(r"images/Picture4.png")
raven = Image.open(r"images/Picture5.png")
clevr = Image.open(r"images/Picture6.png")
cub = Image.open(r"images/Picture7.jpg")

# Define x values (number of attributes)
x = np.arange(2, 31)

# Compute y values (number of combinations: 2^n - 1)
y_comb = [math.comb(n, 2) for n in x]
y_const = [1 for _ in x]

# Create the plot
figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
ax1.plot(x, y_comb, label=r'Pair-wise evaluation', color="#a00000",)
ax1.plot(x, y_const, label=r'Orthotopic evaluation', color="#298c8c", )

plt.axvline(x = 3, ymin = 0, ymax = 400, linestyle=":", color=(0, 0, 0, 0.4))
plt.axvline(x = 4, ymin = 0, ymax = 400, linestyle=":", color=(0, 0, 0, 0.4))
plt.axvline(x = 5, ymin = 0, ymax = 400, linestyle=":", color=(0, 0, 0, 0.4))
plt.axvline(x = 6, ymin = 0, ymax = 400, linestyle=":", color=(0, 0, 0, 0.4))
plt.axvline(x = 24, ymin = 0, ymax = 400, linestyle=":", color=(0, 0, 0, 0.4))

# Log scale for y-axis
# plt.yscale('log')
plt.xscale('log')
ax1.set_xticks([2,3,4,5,6,7,8,9,10,20,30])
ax1.set_xticklabels([2,3,4,5,6,7,8,9,10,20,30])
ax1.set_yticklabels([0, 1, 100, 200, 300, 400])
ax1.tick_params(axis="x", direction="in")
ax1.tick_params(axis="y", direction="in")

# Labels and title
plt.xlabel("$P$", fontsize=16)
plt.ylabel("Complexity", fontsize=16)
# plt.title("Efficiency of the proposed evaluation scheme", fontsize=14)
plt.legend(bbox_to_anchor=(0.5, -0.37), loc="lower center", fontsize=14)


ax_image = figure1.add_axes([0.30,0.76,0.08,0.08])
ax_image.imshow(clevr)
ax_image.set_title("CLEVR", y=-.8, fontsize=14)
ax_image.axis('off')

ax_image = figure1.add_axes([0.36,0.6,0.08,0.08])
ax_image.imshow(shapes)
ax_image.set_title("Shapes3D", y=-.8, fontsize=14)
ax_image.axis('off')

ax_image = figure1.add_axes([0.23,0.52,0.08,0.08])
ax_image.imshow(cars)
ax_image.set_title("Cars3D", y=-.8, fontsize=14)
ax_image.axis('off')



ax_image = figure1.add_axes([0.36,0.37,0.08,0.08])
ax_image.imshow(mpi)
ax_image.set_title("MPI3D", y=-.8, fontsize=14)
ax_image.axis('off')


ax_image = figure1.add_axes([0.23,0.36,0.08,0.08])
ax_image.imshow(raven)
ax_image.set_title("Raven", y=-.8, fontsize=14)
ax_image.axis('off')

ax_image = figure1.add_axes([0.23,0.22,0.08,0.08])
ax_image.imshow(dsprites)
ax_image.set_title("dSprites", y=-.8, fontsize=14)
ax_image.axis('off')

ax_image = figure1.add_axes([0.77,0.3,0.08,0.08])
ax_image.imshow(cub)
ax_image.set_title("CUB", y=-.8, fontsize=14)
ax_image.axis('off')

# Grid for readability
# plt.grid(True, which="both", linestyle="--", linewidth=0.5)

# Show the plot
figure1.savefig("results/efficiency.pgf", bbox_inches="tight")
!./pgf_compiler.sh efficiency

## Similarity

In [None]:
import torch 
import math
from numpy import dot
from numpy.linalg import norm

attributes = 6  # Dimension of the tensors
num_samples = 100  # Number of random trials per c
c_values = np.arange(0, attributes + 1)  # Values of c
values = 8
figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))

colors = {
    0: "#a00000",              # Deep Purple
    1: "#5B2C6F",              # Deep Purple
    2: "#2874A6",  # Steel Blue
    3: "#148F77",   # Dark Cyan
    4: "#D4AC0D",            # Amber Gold
}
codebook_comp = torch.randn(values, 2048)
codebook = torch.randn(int(math.pow(values, attributes)), 1024)
for r in range(0,5):
    if r == 1:
        continue
    mean_similarities = []
    variance_similarities = []
    
    for c in c_values:
        similarities = []
        for _ in range(num_samples):
            a = np.random.randint(0, values, attributes)
            b = np.random.randint(0, values, attributes)
            # shared_indices = np.random.choice(attributes, c, replace=False)
            for i in range(len(a)):
                if i < c:
                    b[i] = a[i]
                else:
                    b[i] = (a[i]+1) % values
            def encode(l):
                return int(sum([l[i]*math.pow(values, i) for i in range(len(l)-r)]))
            if r > 0:
                a_hol = codebook[encode(a)]
                b_hol = codebook[encode(b)]
                a_hol = torch.cat([a_hol] + [codebook_comp[a[-tmp]] for tmp in range(1,r+1)])
                b_hol = torch.cat([b_hol] + [codebook_comp[b[-tmp]] for tmp in range(1,r+1)])
            else:
                a_hol = codebook[encode(a)]
                b_hol = codebook[encode(b)]
            sim = dot(a_hol, b_hol)/(norm(a_hol)*norm(b_hol))
            similarities.append(sim)
        mean_similarities.append(np.clip(np.mean(similarities), a_min=0, a_max=1))
        variance_similarities.append(np.clip(np.std(similarities), a_min=0, a_max=1))
    mean_similarities = np.array(mean_similarities)
    variance_similarities = np.array(variance_similarities)
    # ax1.plot(c_values, mean_similarities, color="#a00000", label="Holistic ")
    # lo = np.clip(mean_similarities - variance_similarities, a_min=0, a_max=1) 
    # up = np.clip(mean_similarities + variance_similarities, a_min=0, a_max=1) 
    # ax1.fill_between(c_values, lo, up, color="#a00000")
    ax1.plot(c_values,mean_similarities, "d", ls="-", color=colors[r])

ax1.text(4.6, 0.15, "$n=P-1$", rotation=78, fontsize=12, color=colors[0], va='center', ha='left')
# ax1.text(3.7, 0.15, "$n=P-2$", rotation=63, fontsize=12, color=colors[1], va='center', ha='left')
ax1.text(2.7, 0.10, "$n=4$", rotation=50, fontsize=12, color=colors[2], va='center', ha='left')
ax1.text(1.8, 0.08, "$n=3$", rotation=40, fontsize=12, color=colors[3], va='center', ha='left')
ax1.text(.9, 0.08, "$n=2$", rotation=30, fontsize=12, color=colors[4], va='center', ha='left')
hol = plt.Line2D([], [], color='k', marker='d', linestyle='-', label='Holistic representation')
comp = plt.Line2D([], [], color='k', marker='o', linestyle='-', label='Compositional representation')




# concat compositional representations
mean_similarities = []
variance_similarities = []
for c in c_values:
    similarities = []

    for _ in range(num_samples):
        a = np.random.randint(0, values, attributes)
        b = np.random.randint(0, values, attributes)
        shared_indices = np.random.choice(attributes, c, replace=False)
        for i in range(len(a)):
            if i in shared_indices:
                b[i] = a[i]
            else:
                b[i] = (a[i]+2) % values
        codebook = torch.randn(values, 64)
        a_conc_rep = torch.cat([codebook[i,:] for i in a])
        b_conc_rep = torch.cat([codebook[i,:] for i in b])
        sim = dot(a_conc_rep, b_conc_rep)/(norm(a_conc_rep)*norm(b_conc_rep))
        similarities.append(sim)
    mean_similarities.append(np.mean(similarities))
    variance_similarities.append(np.std(similarities))

mean_similarities = np.array(mean_similarities)
variance_similarities = np.array(variance_similarities)
# comp = ax1.plot(c_values, mean_similarities, "c", color="#298c8c", label="Concatenative compositional")
# lo = np.clip(mean_similarities - variance_similarities, a_min=0, a_max=1) 
# up = np.clip(mean_similarities + variance_similarities, a_min=0, a_max=1) 
# ax1.fill_between(c_values, lo, up, color="#298c8c")
ax1.plot(c_values,mean_similarities, linestyle='-', marker="o", color="k", label="Compositional representation")




ax1.set_xticks([0,1,2,3,4,5,6])
ax1.set_xticklabels([0,1,2,3,"$\dots","$P-1$","$P$"])

ax1.axvline(x=0.5, linestyle='--', color='#298c8c', linewidth=0.8)
ax1.axvline(x=1.5, linestyle='--', color='#298c8c', linewidth=0.8)
ax1.axvline(x=5.5, linestyle='--', color='#298c8c', linewidth=0.8)

ax1.text(0.07, 0.5, "extrapolation", rotation=90, ha='center', va='center', fontsize=16,                 color='#298c8c', transform=ax1.transAxes)
ax1.text(0.21, 0.65, "comp. generalization", rotation=90, ha='center', va='center', fontsize=16, color='#298c8c', transform=ax1.transAxes)
ax1.text(0.49, 0.7, "weak comp. \ngeneralization", rotation=45, ha='center', va='center', fontsize=16,  color='#298c8c', transform=ax1.transAxes)
ax1.text(0.93, 0.5, "in-distribution", rotation=90, ha='center', va='center', fontsize=16,               color='#298c8c', transform=ax1.transAxes)

# style
plt.legend(handles=[hol,comp], bbox_to_anchor=(0.5, -0.35), loc="lower center", fontsize=14)
plt.xlim(-0.5,6.5)
ax1.tick_params(axis="x", direction="in")
ax1.tick_params(axis="y", direction="in")
plt.xlabel("$c$", fontsize=16)
plt.ylabel("Cosine similarity", fontsize=16)
figure1.savefig("results/similarity.pgf", bbox_inches="tight")
plt.close()
!./pgf_compiler.sh similarity

## Grokking

In [None]:
import pandas as pd

# Cars3D grokking curves, one figure per held-out attribute pair.
splits = {
    "or_el": "Orientation-Elevation",
    "el_ty": "Elevation-Type",
    "or_ty": "Orientation-Type"
}

# One colour per curve, shared by every split so the figures are comparable.
colors = {
    'train': '#5F8B4C',
    'val':   '#FFDDAB',
    'wio':   '#FF9A9A',
    'test':  '#945034'
}
# (colour key / column prefix, legend label) in drawing order.
curves = [
    ('train', 'Train Accuracy'),
    ('val', 'Validation Accuracy'),
    ('wio', 'WIO Accuracy'),
    ('test', 'Test Accuracy'),
]

for code in splits:
    df = pd.read_csv(f'grokking_data/cars_{code}.csv')
    df = df[["Step", "Grouped runs - train_acc__MAX", "Grouped runs - val_acc__MAX", "Grouped runs - test_acc__MAX", "Grouped runs - wio_acc__MAX"]]
    df.columns = df.columns.str.replace(r'^Grouped runs - ', '', regex=True)
    # Keep every 10th logged row to thin out the scatter.
    df = df.iloc[::10].reset_index(drop=True)
    # Trailing 100-row rolling mean. "Step" is averaged too, so each point sits
    # at the mean step of the window it summarizes.
    df1 = df.rolling(window=100, min_periods=1).mean()

    figure1, ax1 = plt.subplots(1, 1, figsize=(16, 5))
    for key, label in curves:
        ax1.scatter(df1["Step"], df1[f"{key}_acc__MAX"], label=label, color=colors[key], marker='o', s=1.2)
    ax1.tick_params(axis='both', which='both', labelsize=16)
    # NOTE(review): x data is the logged "Step" column but the axis says "Epoch" — confirm they coincide.
    ax1.set_xlabel('Epoch', fontsize=16)
    ax1.set_ylabel('Accuracy', fontsize=16)
    ax1.legend(markerscale=5, fontsize=16)
    figure1.savefig(f"results/grokking_{code}.pdf", bbox_inches="tight")
 

In [None]:
import pandas as pd



# dSprites grokking curves.
colors = {
    'train': '#5F8B4C',
    'val':   '#FFDDAB',
    'wio':   '#FF9A9A',
    'test':  '#945034'
}
# (colour key / column prefix, legend label) in drawing order.
curves = [
    ('train', 'Train Accuracy'),
    ('val', 'Validation Accuracy'),
    ('wio', 'WIO Accuracy'),
    ('test', 'Test Accuracy'),
]

df = pd.read_csv('grokking_data/dsprites.csv')
df = df[["Step", "Grouped runs - train_acc__MAX", "Grouped runs - val_acc__MAX", "Grouped runs - test_acc__MAX", "Grouped runs - wio_acc__MAX"]]
df.columns = df.columns.str.replace(r'^Grouped runs - ', '', regex=True)
# Trailing 100-row rolling mean. "Step" is averaged too, so each point sits at
# the mean step of the window it summarizes.
df1 = df.rolling(window=100, min_periods=1).mean()

figure1, ax1 = plt.subplots(1, 1, figsize=(16, 5))
for key, label in curves:
    ax1.scatter(df1["Step"], df1[f"{key}_acc__MAX"], label=label, color=colors[key], marker='o', s=1.2)
ax1.tick_params(axis='both', which='both', labelsize=16)
# NOTE(review): x data is the logged "Step" column but the axis says "Epoch" — confirm they coincide.
ax1.set_xlabel('Epoch', fontsize=16)
ax1.set_ylabel('Accuracy', fontsize=16)
ax1.legend(markerscale=5, fontsize=16)
figure1.savefig("results/grokking_dsprites.pdf", bbox_inches="tight")
 

In [None]:
import pandas as pd

# I-RAVEN grokking curves.
colors = {
    'train': '#5F8B4C',
    'val':   '#FFDDAB',
    'wio':   '#FF9A9A',
    'test':  '#945034'
}
# (colour key / column prefix, legend label) in drawing order.
curves = [
    ('train', 'Train Accuracy'),
    ('val', 'Validation Accuracy'),
    ('wio', 'WIO Accuracy'),
    ('test', 'Test Accuracy'),
]

df = pd.read_csv('grokking_data/iraven.csv')
df = df[["Step", "Grouped runs - train_acc__MAX", "Grouped runs - val_acc__MAX", "Grouped runs - test_acc__MAX", "Grouped runs - wio_acc__MAX"]]
df.columns = df.columns.str.replace(r'^Grouped runs - ', '', regex=True)
# Trailing 100-row rolling mean. "Step" is averaged too, so each point sits at
# the mean step of the window it summarizes.
df1 = df.rolling(window=100, min_periods=1).mean()

figure1, ax1 = plt.subplots(1, 1, figsize=(16, 5))
for key, label in curves:
    ax1.scatter(df1["Step"], df1[f"{key}_acc__MAX"], label=label, color=colors[key], marker='o', s=1.2)
ax1.tick_params(axis='both', which='both', labelsize=16)
# NOTE(review): x data is the logged "Step" column but the axis says "Epoch" — confirm they coincide.
ax1.set_xlabel('Epoch', fontsize=16)
ax1.set_ylabel('Accuracy', fontsize=16)
ax1.legend(markerscale=5, fontsize=16)
figure1.savefig("results/grokking_iraven.pdf", bbox_inches="tight")
 

## Selection metric

In [None]:
# Per-experiment results under each model-selection rule plus the oracle choice
# (presumably test accuracies in percent — the histograms below label them "\%").
id_sel, wio_sel, ood_sel, oracle = (
    np.load(f"selection/{stem}.npy") for stem in ("id_sel", "wio_sel", "ood_sel", "oracle")
)

In [None]:
mean = np.mean(oracle-id_sel)
figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
ax1.hist(oracle-id_sel, bins=80, alpha=0.7, color="#FF9A9A")
plt.axvline(x = mean, ymin = 0, ymax = 400, linestyle=":", color="r")
plt.yscale("log")
plt.xlabel("Accuracy $\Delta$", fontsize=16)
plt.ylabel("Number of experiments", fontsize=16)
figure1.savefig("results/selection_idvswio.pgf", bbox_inches="tight")
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='minor', labelsize=14)
ax1.text(0.35, 0.75, f"$\mu={mean:.2f}$\%", ha='center', va='center', fontsize=16, transform=ax1.transAxes)
figure1.savefig("results/selection_idvswio.pgf", bbox_inches="tight")
!./pgf_compiler.sh selection_idvswio

In [None]:
mean = np.mean(oracle-ood_sel)
figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
ax1.hist(oracle-ood_sel, bins=80, alpha=0.7, color="#FFDDAB")
plt.axvline(x = mean, ymin = 0, ymax = 400, linestyle=":", color="r")
plt.yscale("log")
plt.xlabel("Accuracy $\Delta$", fontsize=16)
plt.ylabel("Number of experiments", fontsize=16)
figure1.savefig("results/selection_idvswio.pgf", bbox_inches="tight")
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='minor', labelsize=14)
ax1.text(0.4, 0.75, f"$\mu={mean:.2f}$\%", ha='center', va='center', fontsize=16, transform=ax1.transAxes)
figure1.savefig("results/selection_oodvswio.pgf", bbox_inches="tight")
!./pgf_compiler.sh selection_oodvswio

In [None]:
mean = np.mean(oracle-wio_sel)
figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
ax1.hist(oracle-wio_sel, bins=80, alpha=0.7, color="#945034")
plt.axvline(x = mean, ymin = 0, ymax = 400, linestyle=":", color="r")
plt.yscale("log")
plt.xlabel("Accuracy $\Delta$", fontsize=16)
plt.ylabel("Number of experiments", fontsize=16)
figure1.savefig("results/selection_wiovsoracle.pgf", bbox_inches="tight")
ax1.text(0.35, 0.75, f"$\mu={mean:.2f}$\%", ha='center', va='center', fontsize=16, transform=ax1.transAxes)
# ax1.set_xlim(-110, 110)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='minor', labelsize=14)
figure1.savefig("results/selection_oraclevswio.pgf", bbox_inches="tight")
!./pgf_compiler.sh selection_oraclevswio

## Pairwise

In [None]:
import pandas as pd
METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]
rdf = pd.read_pickle("pairwise/mpi3d.pkl")

# Per-architecture mean and SEM of every metric, rounded for the LaTeX table.
group_columns = ["arch"]
res = rdf.groupby(group_columns)[METRICS].agg(['mean', 'sem']).reset_index()

stat_cols = [(col, 'mean') for col in METRICS] + [(col, 'sem') for col in METRICS]
res[stat_cols] = res[stat_cols].round(2)
print(res.to_latex(index=False, float_format="{:.2f}".format))

In [None]:
# Total number of pair-wise runs across the five datasets.
# The last file was previously the nameless "pairwise/.pkl"; iraven is the fifth dataset used everywhere else.
sum(len(pd.read_pickle(f"pairwise/{dataset}.pkl")) for dataset in ["mpi3d", "shapes3d", "cars3d", "dsprites", "iraven"])

In [None]:
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]
# Architecture families, checked in this order by substring match, so e.g.
# "densenet121_pretrained" resolves to "densenet" before "ed" is tried.
stems = ['convnext', 'resnet', 'vit', "swin", "densenet", "mlp", "ed"]


def assign_stem(model_name):
    """Return the first entry of `stems` contained in model_name, or 'other'."""
    return next((stem for stem in stems if stem in model_name), 'other')


for dataset in ["mpi3d", "shapes3d", "cars3d", "dsprites", "iraven"]:
    rdf = pd.read_pickle(f"pairwise/{dataset}.pkl")
    # PReLU variants are excluded here; they get their own comparison.
    rdf = rdf[~rdf['arch'].str.contains('prelu', case=False, na=False)]
    # .assign returns a new frame instead of writing into a filtered slice.
    rdf = rdf.assign(stem=rdf['arch'].apply(assign_stem))
    res = rdf.groupby(["combination", "stem"])[METRICS].agg(['mean']).reset_index()
    res.columns = res.columns.droplevel(1)
    # "a_b" -> "(a, b)" for the x tick labels.
    res['combination'] = res['combination'].apply(lambda x: f"({x.replace('_', ', ')})")

    fig, ax = plt.subplots(figsize=(9, 6))
    sns.stripplot(
        x="combination",
        y="test_acc",
        data=res,
        palette="muted",
        hue="stem",
        # A single marker-size argument; seaborn lets `s` override `size`, so the
        # previous size=5 / s=9 pair was ambiguous and rendered at size 9.
        size=9,
        marker="o",
        edgecolor="black",
        alpha=.75,
        linewidth=1.0,
        ax=ax,
    )
    ax.tick_params(axis='both', which='both', labelsize=14)
    ax.tick_params(axis='x', labelrotation=30)
    ax.set_xlabel("Generative Factors Combination", fontsize=16)
    ax.set_ylabel("Test Accuracy (%)", fontsize=16)
    ax.set_ylim([-10, 110])
    ax.legend(loc='best', fontsize=16)
    fig.savefig(f"results/attributewise_{dataset}.pdf", bbox_inches="tight")
    plt.close(fig)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plot_settings as plot_settings
plot_settings.set_latex_settings()

METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]
models = [
 'resnet152',
 'resnet101',
 'resnet34',
 'resnet18',
 'densenet161',
 'convnext',
 'densenet201',
 'densenet121',
 'convnext',
 'resnet50'
]

for dataset in ["mpi3d", "shapes3d", "cars3d", "iraven"]:
    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
    rdf = pd.read_pickle(f"pairwise/{dataset}.pkl")
    rdf = rdf[~rdf['arch'].str.contains("pretrained", case=False, na=False)]
    rdf = rdf[rdf['arch'].str.contains('|'.join(models), case=False, na=False)]
    stems = ['resnet', "densenet", 'convnext']
    def assign_stem(model_name):
        for stem in stems:
            if stem in model_name:
                return stems.index(stem)
        return 'other'
    rdf['stem'] = rdf['arch'].apply(assign_stem)
    rdf["prelu"] = rdf['arch'].str.contains("prelu", case=False, na=False)
    res = rdf.groupby(['prelu', 'stem'])[METRICS].agg(['mean', 'sem']).reset_index()
    x = list(range(3))
    y = [res[np.logical_and(res["stem"] == idx, res["prelu"] == False)]["test_acc"]["mean"].item() for idx in x]
    y_err = [res[np.logical_and(res["stem"] == idx, res["prelu"] == False)]["test_acc"]["sem"].item() for idx in x]
    ax1.errorbar(
        x,
        y,
        yerr=y_err,
        fmt='.',
        label="Standard",
        markersize=8,
        capsize=5,
        color="#FF9A9A"
    )
    y = [res[np.logical_and(res["stem"] == idx, res["prelu"] == True)]["test_acc"]["mean"].item() for idx in x]
    y_err = [res[np.logical_and(res["stem"] == idx, res["prelu"] == True)]["test_acc"]["sem"].item() for idx in x]

    ax1.errorbar(
        x,
        y,
        yerr=y_err,
        fmt='.',
        label="PReLU",
        markersize=8,
        capsize=5,
        color="#945034"
    )

    ax1.tick_params(axis='both', which='major', labelsize=14)
    ax1.tick_params(axis='both', which='minor', labelsize=14)
    plt.xticks([0, 1, 2], ["ResNets", "DenseNets", "ConvNeXts"])
    plt.ylabel("Test Accuracy (\%)", fontsize=16)
    plt.ylim([-10, 110])
    plt.xlim([-.5, 2.5])
    plt.legend(loc='best', fontsize=16)

    figure1.savefig(f"results/prelu_{dataset}.pgf", bbox_inches="tight")

!./pgf_compiler.sh prelu

In [None]:
# PReLU-minus-standard mean test accuracy per base architecture and dataset.
# Uses `models` and `METRICS` as defined in the PReLU figure cell.
prelu_datasets = ["iraven", "cars3d", "shapes3d", "mpi3d"]
deltas = {}
for dataset in prelu_datasets:
    rdf = pd.read_pickle(f"pairwise/{dataset}.pkl")
    rdf = rdf[~rdf['arch'].str.contains("pretrained", case=False, na=False)]
    rdf = rdf[rdf['arch'].str.contains('|'.join(models), case=False, na=False)]
    df = rdf.groupby(['arch'])[METRICS].agg(['mean']).reset_index()
    df.columns = df.columns.droplevel(1)
    df['base_arch'] = df['arch'].str.replace('_prelu', '', regex=False)
    # Rows: base architecture. Columns: variant names (base and base + "_prelu").
    pivot = df.pivot(index='base_arch', columns='arch', values='test_acc')
    # NaN when either variant is missing.
    deltas[dataset] = pivot.apply(
        lambda row: row.get(f"{row.name}_prelu", float('nan')) - row.get(row.name, float('nan')),
        axis=1
    )
# Build the table from all Series at once so it aligns on the union of
# architectures. Growing an empty DataFrame column by column would fix the index
# to the first dataset and silently drop architectures that only appear later.
result = pd.DataFrame(deltas)
result['AVG'] = result[prelu_datasets].mean(axis=1)
print(result.to_latex(index=True, float_format="{:.2f}".format))


In [None]:
# Display the PReLU-minus-standard test-accuracy table computed in the previous cell.
result

## Main results

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plot_settings
plot_settings.set_latex_settings()

METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]
pw = pd.concat([pd.read_pickle(f"pairwise/dsprites.pkl"), pd.read_pickle(f"pairwise/iraven.pkl"), pd.read_pickle(f"pairwise/shapes3d.pkl"), pd.read_pickle(f"pairwise/mpi3d.pkl"), pd.read_pickle(f"pairwise/cars3d.pkl")])
ot = pd.concat([pd.read_pickle(f"main/dsprites.pkl"), pd.read_pickle(f"main/iraven.pkl"), pd.read_pickle(f"main/shapes3d.pkl"), pd.read_pickle(f"main/cars3d.pkl")])

for name, pairwise in zip(["pairwise", "orthotopic"], [pw, ot]):

    if name == "orthotopic":
        pairwise = pairwise[pairwise["c"]=="1"]
    pairwise = pairwise[~pairwise['arch'].str.contains('prelu|convnext_tiny', case=False, na=False)]
    res = pairwise.groupby(["arch"])[METRICS].agg(['mean', 'sem']).reset_index()

    def modtoname(arch):
        mp = {
            'convnext_base': "CN-base", 
            'convnext_small': "CN-small",
            'densenet121': "DN-121",
            'densenet121_pretrained': "DN-121-PT",
            'densenet161': "DN-161",
            'densenet201': "DN-201",
            'ed': "ED",
            'mlp': "MLP",
            'resnet101': "RN-101",
            'resnet101_pretrained': "RN-101-PT",
            'resnet152': "RN-152",
            'resnet152_pretrained': "RN-151-PT",
            'resnet18': "RN-18",
            'resnet34': "RN-34",
            'resnet50': "RN-50",
            'swin_base': "ST-base",
            'swin_tiny': "ST-tiny",
            'vit': "ViT",
            'wideresnet': "WRN",
        }
        return mp[arch]

    color_dict = {
        'convnext': '#FF9A9A',
        'resnet':   '#FFDDAB',
        'vit':      '#945034',
        'densenet': '#7CA982',
        'mlp':      '#769ECB',
        'ed':       '#C287E8',
    }
    model_sizes = {
        "mlp":                406850,
        'densenet':       6965131,
        'densenet121':       6965131,
        'densenet121_pretrained': 6965131,
        'densenet161':      26486891,
        'densenet201':      18107787,
        "resnet18":         11175883,
        "resnet34":         21284043,
        "resnet50":         24556491,
        "resnet101_pretrained":        43548619,
        "resnet152_pretrained":        59192267,
        "resnet101":        43548619,
        "resnet152":        59192267,
        "wideresnet":       67882699,
        "ed":               22347136,
        "vit":              86576115,
        "convnext_base":    87573632,
        "convnext_small":   49460064,
        "swin_tiny":        27532469,
        "swin_base":        86771459,
    }
    families = ['convnext', 'resnet', 'vit', 'densenet', 'mlp', 'ed']
    families_caps = ['ConvNeXt', 'ResNet', 'ViT', 'DenseNet', 'MLP', 'ED']
    def get_family(arch):
        for fam in families:
            if arch.startswith(fam):
                return fam
            elif arch.startswith("wideresnet"):
                return "resnet"
            elif arch.startswith("swin"):
                return "vit"
        return "other"

    res["family"] = res["arch"].apply(get_family)
    res["model_size"] = res["arch"].map(model_sizes)

    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/2, plot_settings.column_width*1/2))
    res["arch"] = res["arch"].apply(modtoname)
    res.sort_values(by=('test_acc', "mean"), inplace=True)
    res.reset_index(inplace=True)
    res['index1'] = res.index
    texts = []
    stds = []
    for _, row in res.iterrows():
        ax1.errorbar(
            row["index1"],
            row["test_acc"]["mean"],
            yerr=row["test_acc"]["sem"],
            fmt='.',
            color=color_dict[row["family"].item()],
            capsize=3,
            markersize=8,
            alpha=0.8
        )
        if name == "orthotopic":
            stds.append((row["arch"], row["test_acc"]["sem"]))
        texts.append(ax1.text(
            row["index1"].item(),
            row["test_acc"]["mean"]+10,
            f"{row['test_acc']['mean']:.1f}\%",
            fontsize=8,
            ha='center',
            va='center'
        ))

    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_dict[f], label=fc, markersize=6)
            for f, fc in zip(families, families_caps)]
    ax1.legend(handles=handles, loc='upper left', ncol=6, fontsize=8)

    # Axes
    # plt.xscale("log")
    # plt.xlabel("Number of Parameters", fontsize=16)
    plt.ylabel("Test Accuracy (\%)", fontsize=16)
    ax1.tick_params(axis='both', which='major', labelsize=10, rotation=25)
    ax1.tick_params(axis='both', which='minor', labelsize=14)
    ax1.set_xticks(list(range(len(res['arch'].tolist()))), res['arch'].tolist(), ha='right')
    plt.ylim([-5, 105])
    figure1.savefig(f"results/overall_{name}_alt.pgf", bbox_inches="tight")
    if name == "pairwise": res1 = res
!./pgf_compiler.sh overall_

In [None]:
# Mean orthotopic SEM over all architectures except ED, minus a hard-coded reference.
# NOTE(review): 5.57544752412038 is not derived anywhere in this notebook — presumably
# the matching pair-wise mean SEM; confirm and compute it instead of hard-coding.
np.mean([sem for arch, sem in stds if arch.item() != "ED"]) - 5.57544752412038

In [None]:
single_level_cols = [col for col, col2 in res.columns if col2 == '']
multi_level_means = res.xs('mean', axis=1, level=1)
res_sanitized = pd.concat([res[single_level_cols], multi_level_means], axis=1)
for _, row in res_sanitized.iterrows():
    if "PT" in row[("arch","")]:
        row[("family","")] = "pretrained"
        res_sanitized[res_sanitized[("arch",'')] == row[("arch","")]] = row
res_ortho = res_sanitized.groupby([("family", '')])[METRICS].agg(['mean', 'sem']).reset_index()
res_ortho.sort_values(by=('test_acc', "mean"), inplace=True)

single_level_cols = [col for col, col2 in res1.columns if col2 == '']
multi_level_means = res1.xs('mean', axis=1, level=1)
res_sanitized = pd.concat([res1[single_level_cols], multi_level_means], axis=1)
for _, row in res_sanitized.iterrows():
    if "PT" in row[("arch","")]:
        row[("family","")] = "pretrained"
        res_sanitized[res_sanitized[("arch",'')] == row[("arch","")]] = row
res_pair = res_sanitized.groupby([("family", '')])[METRICS].agg(['mean', 'sem']).reset_index()
res_pair.sort_values(by=('test_acc', "mean"), inplace=True)

families_caps = {
    "convnext":'ConvNeXt', 
    "resnet":'ResNet',
    "vit":'ViT',
    "densenet":'DenseNet',
    "mlp":'MLP',
    "ed": 'ED',
    "pretrained": "Pre-trained"
}


figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*0.6, plot_settings.column_width*1/2))
ax1.bar([families_caps[f] for f in res_ortho["family"]], res_ortho[('test_acc', 'mean')], color='#945034')
plt.ylabel("Test Accuracy (\%)", fontsize=16)
ax1.tick_params(axis='both', which='major', labelsize=10, rotation=25)
ax1.tick_params(axis='both', which='minor', labelsize=14)
plt.ylim([-5, 105])
ax1.set_xticks(list(range(len(res_ortho["family"]))), [families_caps[f] for f in res_ortho["family"]], ha='right')
figure1.savefig(f"results/comp_barplot_ortho.pgf", bbox_inches="tight")

figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*0.6, plot_settings.column_width*1/2))
ax1.bar([families_caps[f] for f in res_pair["family"]], res_pair[('test_acc', 'mean')], color='#FFDDAB')
plt.ylabel("Test Accuracy (\%)", fontsize=16)
ax1.tick_params(axis='both', which='major', labelsize=10, rotation=25)
ax1.tick_params(axis='both', which='minor', labelsize=14)
ax1.set_xticks(list(range(len(res_pair["family"]))), [families_caps[f] for f in res_pair["family"]], ha='right')
plt.ylim([-5, 105])
figure1.savefig(f"results/comp_barplot_pair.pgf", bbox_inches="tight")
!./pgf_compiler.sh comp_barplot

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import plot_settings
plot_settings.set_latex_settings()

METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]
families = ['convnext', 'resnet', 'vit', 'densenet', 'mlp', 'ed']
color_dict = {
    'convnext': '#FF9A9A',
    'resnet':   '#FFDDAB',
    'vit':      '#945034',
    'densenet': '#7CA982',
    'mlp':      '#769ECB',
    'ed':       '#C287E8',
}
families_caps = ['ConvNeXt', 'ResNet', 'ViT', 'DenseNet', 'MLP', 'ED']
def get_family(arch):
    for fam in families:
        if arch.startswith(fam):
            return fam
        elif arch.startswith("wideresnet"):
            return "resnet"
        elif arch.startswith("swin"):
            return "vit"
    return "other"

# for dataset in ["dsprites", "iraven", "cars3d", "shapes3d"]:
for dataset in ["clevr"]:
    data = pd.read_pickle(f"main/{dataset}.pkl")
    data["family"] = data["arch"].apply(get_family)
    data = data.groupby(["family", "c"])[METRICS].agg(['mean', 'sem']).reset_index()
    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
    for fam in families:
        res = data[data["family"] == fam]
        if res.empty: continue
        # --- add in-dist result by taking val accuracy from last c
        a = res.sort_values(by='c').iloc[-1]
        a["c"] = int(a["c"].item()) + 1
        a["test_acc"] = res.iloc[0]["val_acc"]
        res = res.append(a).reset_index()
        # ---
        ax1.plot([int(cc) for cc in res['c']], list(res['test_acc']["mean"]), label=fam, color=color_dict[fam], marker="o")
        ax1.fill_between(
            [int(cc) for cc in res['c']],
            res['test_acc']["mean"] - res['test_acc']["sem"],
            res['test_acc']["mean"] + res['test_acc']["sem"],
            alpha=0.3,
            color=color_dict[fam]
        )
    handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_dict[f], label=fc, markersize=12)
            for f, fc in zip(families, families_caps)]
    if dataset == "dsprites":
        ax1.legend(handles=handles, loc='best',fontsize=16, bbox_to_anchor=(1.5, -0.2), ncol=6)
    ax1.set_xticks([int(x) for x in data["c"].unique()]+[int(max(data["c"].unique()))+1])
    plt.xlabel("$c$", fontsize=16)
    plt.ylabel("Test Accuracy (\%)", fontsize=16)
    ax1.tick_params(axis='both', which='major', labelsize=14)
    ax1.tick_params(axis='both', which='minor', labelsize=14)
    plt.ylim([-5, 105])
    figure1.savefig(f"results/neurips_main_{dataset}.pgf", bbox_inches="tight")

!./pgf_compiler.sh neurips_main

## Extended main results

In [None]:
import pandas as pd
# Metrics summarised (mean and standard error over runs) in the extended table.
METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]

# Load all CLEVR runs, list the available `c` splits, then keep only c == "3".
rdf = pd.read_pickle("main/clevr.pkl")
print(rdf["c"].unique())
rdf = rdf[rdf["c"] == "3"]
group_columns = ["arch"]

# One row per architecture with (metric, "mean") / (metric, "sem") columns,
# rounded to two decimals for the LaTeX rows generated further down.
res = rdf.groupby(group_columns)[METRICS].agg(['mean', 'sem']).reset_index()
stat_cols = [(metric, stat) for stat in ("mean", "sem") for metric in METRICS]
res[stat_cols] = res[stat_cols].round(2)

def format_model_name(arch):
    """Turn an architecture key into a (display name, pretrained mark) pair for LaTeX.

    Args:
        arch: architecture identifier, optionally suffixed with ``"_pretrained"``
            (e.g. ``"resnet50"``, ``"densenet121_pretrained"``).

    Returns:
        ``(pretty_name, pretrained)`` where ``pretrained`` is the LaTeX macro
        ``\\cmark`` if ``arch`` contains "pretrained" and ``\\xmark`` otherwise.
        Keys missing from the table are returned unchanged as the name.
    """
    name = arch.replace("_pretrained", "")
    pretrained = r"\cmark" if "pretrained" in arch else r"\xmark"
    pretty_map = {
        "resnet18": "ResNet-18",
        # resnet34 and vit also appear in the Pareto plot's size/name tables.
        "resnet34": "ResNet-34",
        "resnet50": "ResNet-50",
        "resnet101": "ResNet-101",
        "resnet152": "ResNet-152",
        "densenet121": "DenseNet-121",
        "densenet161": "DenseNet-161",
        "densenet201": "DenseNet-201",
        "convnext_tiny": "ConvNeXt-Tiny",
        "convnext_small": "ConvNeXt-Small",
        "convnext_base": "ConvNeXt-Base",
        "swin_tiny": "Swin-Tiny",
        "swin_base": "Swin-Base",
        "vit": "ViT",
        "wideresnet": "WideResNet",
        "ed": "ED",
        "mlp": "MLP"
    }
    pretty_name = pretty_map.get(name, name)
    return pretty_name, pretrained

def _latex_row(row):
    """Build one LaTeX table row: name & pretrained mark & every (mean, sem) value."""
    model, pretrained = format_model_name(row["arch"].item())
    numbers = " & ".join(f"{row[col]:.2f}" for col in res.columns[1:])
    return f"{model} & {pretrained} & {numbers} \\\\"


# One row per architecture, joined into the table body and printed for copy-paste.
latex_rows = [_latex_row(row) for _, row in res.iterrows()]
latex_table_body = "\n".join(latex_rows)
print(latex_table_body)

## FPE vs Linear

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import plot_settings
plot_settings.set_latex_settings()


def _load_runs(csv_path):
    """Load per-run accuracy curves from an exported CSV as a (runs, epochs) array.

    Columns whose name matches ``MAX|MIN|Step`` are dropped; every remaining column
    is treated as one run. NaNs are removed per run and all runs are truncated to
    the shortest remaining length, so runs of unequal length can still be stacked.

    Raises:
        ValueError: if the CSV contains no run columns.
    """
    df = pd.read_csv(csv_path)
    df = df.drop(columns=list(df.filter(regex='MAX|MIN|Step')))
    runs = [df[col].dropna().to_numpy() for col in df.columns]
    if not runs:
        raise ValueError(f"no run columns found in {csv_path}")
    n_epochs = min(len(run) for run in runs)
    return np.stack([run[:n_epochs] for run in runs])


def _plot_mean_std(ax, runs, color, label):
    """Plot the across-run mean curve with a +/-1 std band clipped to [0, 100]."""
    epochs = np.arange(runs.shape[1])
    mean = np.mean(runs, axis=0)
    std = np.std(runs, axis=0)
    ax.plot(epochs, mean, color=color, label=label)
    ax.fill_between(
        epochs,
        np.clip(mean - std, 0, 100),
        np.clip(mean + std, 0, 100),
        alpha=0.3,
        color=color
    )


for dataset in ["dsprites", "iraven", "shapes3d", "cars3d", "mpi3d"]:
    figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*6/4, plot_settings.column_width*3/4))
    _plot_mean_std(ax1, _load_runs(f"fpe_linear/{dataset}_fpe.csv"), '#FF9A9A', "FPE")
    _plot_mean_std(ax1, _load_runs(f"fpe_linear/{dataset}_linear.csv"), '#945034', "Linear")
    ax1.set_xlabel("Epoch", fontsize=16)
    ax1.set_ylabel(r"Test Accuracy (\%)", fontsize=16)
    ax1.tick_params(axis='both', which='major', labelsize=14)
    ax1.tick_params(axis='both', which='minor', labelsize=14)
    ax1.set_ylim([-5, 105])
    ax1.legend(loc='best', fontsize=16)
    figure1.savefig(f"results/fpe_linear_{dataset}.pgf", bbox_inches="tight")
!./pgf_compiler.sh fpe_linear_

## AIN

In [None]:
import pandas as pd
import numpy as np

# Metrics summarised (mean and standard error over runs) for AIN vs. baselines.
METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]
group_columns = ["arch"]


def _mean_sem_table(frame, metrics, by):
    """Per-group mean and SEM of each metric, with every stat rounded to 2 decimals."""
    table = frame.groupby(by)[metrics].agg(['mean', 'sem']).reset_index()
    stat_cols = [(m, "mean") for m in metrics] + [(m, "sem") for m in metrics]
    table[stat_cols] = table[stat_cols].round(2)
    return table


for data in ["shapes3d", "mpi3d", "dsprites", "iraven", "cars3d", "clevr"]:
    ain = pd.read_pickle(f"ain/{data}.pkl")
    oth = pd.read_pickle(f"main/{data}.pkl")
    # Baselines are compared at the c == "1" setting only.
    oth = oth[oth["c"] == "1"]

    res_ain = _mean_sem_table(ain, METRICS, group_columns)
    res_oth = _mean_sem_table(oth, METRICS, group_columns)
    # Only ResNet-18 and the ED model are shown next to AIN.
    res_oth = res_oth[res_oth["arch"].isin(["resnet18", "ed"])]

    print(res_ain)
    print(res_oth)


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
from adjustText import adjust_text
import plot_settings
plot_settings.set_latex_settings()

METRICS = [
    "train_acc",
    "val_acc",
    "ood_val_0_acc",
    "test_acc",
]

# Datasets pooled into the Pareto plot. The SAME list is used for the AIN runs and
# for the baselines, so every point is an average over the same set of datasets
# (CLEVR must not be present for the baselines only).
PARETO_DATASETS = ["shapes3d", "mpi3d", "dsprites", "iraven", "cars3d", "clevr"]
ain = pd.concat([pd.read_pickle(f"ain/{dataset}.pkl") for dataset in PARETO_DATASETS])
oth = pd.concat([pd.read_pickle(f"main/{dataset}.pkl") for dataset in PARETO_DATASETS])
# Baselines are taken at the c == "1" setting, as in the AIN comparison tables.
oth = oth[oth["c"] == "1"]
data = pd.concat([ain, oth])
# One row per architecture with (metric, "mean") / (metric, "sem") columns.
res = data.groupby(["arch"])[METRICS].agg(['mean', 'sem']).reset_index()


def modtoname(arch):
    """Return the short plot label for an architecture key.

    Args:
        arch: architecture identifier as stored in the results frames
            (e.g. ``"resnet152_pretrained"``, ``"split"``).

    Returns:
        The abbreviated label (e.g. ``"RN-152-PT"``, ``"AIN"``). Keys that are not
        in the table are returned unchanged instead of raising ``KeyError``.
    """
    mp = {
        'convnext_base': "CN-base",
        'convnext_small': "CN-small",
        'densenet121': "DN-121",
        'convnext_tiny': "CN-tiny",
        'densenet121_pretrained': "DN-121-PT",
        'densenet161': "DN-161",
        'densenet201': "DN-201",
        'ed': "ED",
        'mlp': "MLP",
        'resnet101': "RN-101",
        'resnet101_pretrained': "RN-101-PT",
        'resnet152': "RN-152",
        'resnet152_pretrained': "RN-152-PT",
        'resnet18': "RN-18",
        'resnet34': "RN-34",
        'resnet50': "RN-50",
        'swin_base': "ST-base",
        'swin_tiny': "ST-tiny",
        'vit': "ViT",
        'wideresnet': "WRN",
        "split": "AIN"
    }
    return mp.get(arch, arch)

# Marker colour per model family in the Pareto plot.
color_dict = {
    'convnext': '#FF9A9A',
    'resnet': '#FFDDAB',
    'vit': '#945034',
    'densenet': '#7CA982',
    'mlp': '#769ECB',
    'split': "#E63946",
    'ed': '#C287E8',
}

# Parameter count used as the x coordinate of each architecture.
_RESNET18_PARAMS = 11175883
model_sizes = {
    "mlp": 406850,
    'densenet': 6965131,
    'densenet121': 6965131,
    'densenet121_pretrained': 6965131,
    'densenet161': 26486891,
    'densenet201': 18107787,
    "resnet18": _RESNET18_PARAMS,
    "resnet34": 21284043,
    "resnet50": 24556491,
    "resnet101_pretrained": 43548619,
    "resnet152_pretrained": 59192267,
    "resnet101": 43548619,
    "resnet152": 59192267,
    # AIN and ED sizes are derived from the ResNet-18 count; the 0.032 and 4.16
    # factors are not explained in this notebook -- TODO confirm their origin.
    "split": _RESNET18_PARAMS + _RESNET18_PARAMS * 0.032 * 4.16,
    "wideresnet": 67882699,
    "ed": _RESNET18_PARAMS + _RESNET18_PARAMS * 4.16,
    "vit": 86576115,
    "convnext_base": 87573632,
    "convnext_small": 49460064,
    "convnext_tiny": 28600064,
    "swin_tiny": 27532469,
    "swin_base": 86771459,
}

# Family key -> legend caption, in legend order.
_FAMILY_CAPTIONS = {
    'convnext': 'ConvNeXt',
    'resnet': 'ResNet',
    'vit': 'ViT',
    'densenet': 'DenseNet',
    'mlp': 'MLP',
    'ed': 'ED',
    "split": "AIN",
}
families = list(_FAMILY_CAPTIONS)
families_caps = list(_FAMILY_CAPTIONS.values())

# Architectures whose name does not start with their family key.
_FAMILY_ALIASES = {"wideresnet": "resnet", "swin": "vit"}


def get_family(arch):
    """Return the plotting family of an architecture name.

    WideResNet counts as ResNet and Swin as ViT; otherwise the first family key in
    ``families`` that prefixes ``arch`` wins. Unmatched names map to "other".
    """
    for prefix, fam in _FAMILY_ALIASES.items():
        if arch.startswith(prefix):
            return fam
    return next((fam for fam in families if arch.startswith(fam)), "other")

res["family"] = res["arch"].apply(get_family)
res["model_size"] = res["arch"].map(model_sizes)

# Set to True to label every point with its short model name (uses adjust_text).
ANNOTATE_POINTS = False
# Marker colour for architectures whose family has no entry in `color_dict`
# (get_family returns "other" for those); avoids a KeyError mid-plot.
OTHER_COLOR = "#808080"

figure1, ax1 = plt.subplots(1, 1, figsize=(plot_settings.column_width*3/4, plot_settings.column_width*3/4))
res["arch"] = res["arch"].apply(modtoname)
texts = []
for _, row in res.iterrows():
    # One point per architecture: parameter count vs. mean test accuracy +/- SEM.
    ax1.errorbar(
        row["model_size"],
        row["test_acc"]["mean"],
        yerr=row["test_acc"]["sem"],
        fmt='.',
        color=color_dict.get(row["family"].item(), OTHER_COLOR),
        capsize=3,
        markersize=8,
        alpha=0.8
    )
    if ANNOTATE_POINTS:
        texts.append(ax1.text(
            row["model_size"].item() * 1.01,
            row["test_acc"]["mean"],
            row["arch"].item(),
            fontsize=9,
            verticalalignment='center'
        ))

if ANNOTATE_POINTS:
    adjust_text(texts, expand=(1, 1), arrowprops=dict(arrowstyle='-', color='k', lw=1))

# Legend: one coloured marker per family.
handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=color_dict[f], label=fc, markersize=6)
           for f, fc in zip(families, families_caps)]
ax1.legend(handles=handles, loc='best', ncol=3)

# Axes (a log x-scale, ax1.set_xscale("log"), was tried and left disabled).
ax1.set_xlabel("Number of Parameters", fontsize=16)
ax1.set_ylabel(r"Test Accuracy (\%)", fontsize=16)
ax1.tick_params(axis='both', which='major', labelsize=14)
ax1.tick_params(axis='both', which='minor', labelsize=14)
ax1.set_ylim([-5, 105])
figure1.savefig("results/pareto.pgf", bbox_inches="tight")
!./pgf_compiler.sh pareto