In [None]:
from concepts.data import kitchens
from concepts.training import train_cbm_independent, train_cbm_joint, train_cbm_sequential, train_cem, train_black_box
import torch
import lightning

In [None]:
def get_intervention_accuracies(model, test_dl, trainer=None):
    """Measure test accuracy under cumulative concept interventions.

    The model is evaluated on ``test_dl`` ``model.n_concepts + 1`` times:
    first with an all-zero ``intervention_mask``, then with mask entry 0 set
    to 1, then entries 0..1, and so on until every entry is 1. Whatever
    ``intervention_mask`` the model had before the call is restored
    afterwards (also when evaluation raises), so the model is not left in
    the fully-intervened state.

    Args:
        model: Lightning module exposing ``n_concepts`` and an
            ``intervention_mask`` attribute; presumably its test step reads
            that mask to decide which concepts to intervene on (confirm in
            the model implementation).
        test_dl: Dataloader passed to ``trainer.test``; it must produce a
            single result dict (one dataloader).
        trainer: Optional trainer to reuse. When omitted, a default
            ``lightning.Trainer()`` is created, as before.

    Returns:
        list[float]: ``test_y_accuracy`` rounded to 4 decimals. Element ``k``
        is the accuracy with the first ``k`` concepts intervened on.
    """
    if trainer is None:
        trainer = lightning.Trainer()

    def _evaluate():
        # trainer.test returns one metrics dict per dataloader; exactly one is expected.
        [test_results] = trainer.test(model, test_dl)
        return round(test_results["test_y_accuracy"], 4)

    # Remember the caller's mask so the sweep does not leak into later use of the model.
    had_mask = hasattr(model, "intervention_mask")
    previous_mask = getattr(model, "intervention_mask", None)

    intervention_accuracies = []
    try:
        model.intervention_mask = torch.tensor([0] * model.n_concepts)
        intervention_accuracies.append(_evaluate())
        for c in range(model.n_concepts):
            # Interventions are cumulative: concept c is added to those already set.
            model.intervention_mask[c] = 1
            intervention_accuracies.append(_evaluate())
    finally:
        if had_mask:
            model.intervention_mask = previous_mask
        else:
            del model.intervention_mask
    return intervention_accuracies

In [None]:
# Seed Python, NumPy and torch RNGs so re-running the notebook trains comparable models.
SEED = 0
lightning.seed_everything(SEED)
datasets = kitchens.KitchensDatasets(foundation_model="dinov2")

# Joint CBM

In [None]:
joint_cbm, joint_cbm_results, joint_cbm_recipeless_results = train_cbm_joint(datasets)

In [None]:
intervention_accuracies = get_intervention_accuracies(joint_cbm, datasets.test_dl())

In [None]:
print(", ".join(map(str, intervention_accuracies)))

# Sequential CBM

In [None]:
sequential_cbm, sequential_cbm_results, sequential_cbm_recipeless_results = train_cbm_sequential(datasets)

In [None]:
intervention_accuracies = get_intervention_accuracies(sequential_cbm, datasets.test_dl())

In [None]:
print(", ".join(map(str, intervention_accuracies)))

# Independent CBM

In [None]:
independent_cbm, independent_cbm_results, independent_cbm_recipeless_results = train_cbm_independent(datasets)

In [None]:
intervention_accuracies = get_intervention_accuracies(independent_cbm, datasets.test_dl())

In [None]:
print(", ".join(map(str, intervention_accuracies)))

# Sequential CBM with Independent Concepts

In [None]:
recipeless_sequential_cbm, recipeless_sequential_cbm_results, recipeless_sequential_cbm_recipeless_results = train_cbm_sequential(datasets, use_recipeless=True)

In [None]:
intervention_accuracies = get_intervention_accuracies(recipeless_sequential_cbm, datasets.test_dl())

In [None]:
print(", ".join(map(str, intervention_accuracies)))

# Independent CBM with Independent Concepts

In [None]:
recipeless_independent_cbm, recipeless_independent_cbm_results, recipeless_independent_cbm_recipeless_results = train_cbm_independent(datasets, use_recipeless=True)

In [None]:
intervention_accuracies = get_intervention_accuracies(recipeless_independent_cbm, datasets.test_dl())

In [None]:
print(", ".join(map(str, intervention_accuracies)))

# CEM

In [None]:
cem, cem_results, cem_recipeless_results = train_cem(datasets)

In [None]:
intervention_accuracies = get_intervention_accuracies(cem, datasets.test_dl())

In [None]:
print(", ".join(map(str, intervention_accuracies)))

# Black Box

In [None]:
# Baseline for comparison. Unlike the sections above, no intervention sweep is run here
# (presumably because this model exposes no concept mask — confirm in concepts.training).
(black_box, black_box_results) = train_black_box(datasets)