This notebook is inspired by the work of CEM (Zarlenga et al. 2022) and CBM (Koh et al. 2020) papers. Please visit their GitHub repositories:
[CEM GitHub](https://github.com/mateoespinosa/cem) and [CBM GitHub](https://github.com/yewsiang/ConceptBottleneck).

# Ranker: ranker training

There are five main steps:
1. Loading the dataset.
2. Initializing a CBM with an InceptionV3 vision backbone for the dataset.
3. Loading the CBMs from their checkpoints.
4. Training the Ranker.
5. Evaluating the Ranker.

## Step 1: Load Data

The first step is to load the data. The CBM class, used together with the PyTorch Lightning Trainer, takes in a PyTorch DataLoader object.
Furthermore, each batch needs to contain three elements (in the following order):
1. the sample image, $\mathbf{x}$
2. the image label, $\mathbf{y}$
3. the concept labels, in binary format, $\mathbf{c}$

In [None]:
#from cub_data_module import *
import logging
import torch
import numpy as np
import yaml
import torchvision.models as models
import pytorch_lightning as pl

In [None]:
#import sys #for Kaggle
#sys.path.append("/kaggle/usr/lib/cub_data_module_confounded") #for Kaggle

import src.cub_data_module as cub_data_module

In [None]:
def _update_config_with_dataset(
    config,
    train_dl,
    n_concepts,
    n_tasks,
    concept_map,
):
    """Fill dataset-derived defaults into ``config`` and optionally compute
    task class weights.

    ``config`` is modified in place: ``n_concepts``, ``n_tasks`` and
    ``concept_map`` are set from the arguments unless the config already
    holds a value under that key.

    Args:
        config: experiment configuration dictionary (mutated).
        train_dl: iterable of training batches supporting ``len()``. Each
            batch is either a 2-element ``(x, (y, c))`` or a 3-element
            ``(x, y, c)`` structure; only ``y`` is used.
        n_concepts: number of concepts reported by the dataset.
        n_tasks: number of task outputs; ``1`` is treated as binary.
        concept_map: concept grouping reported by the dataset (may be None).

    Returns:
        ``None`` unless ``config['use_task_class_weights']`` is truthy.
        Otherwise a numpy array: for ``n_tasks > 1`` one weight per class,
        ``samples / class_count - 1``; for the binary case a single-element
        array holding ``negatives / positives``.
    """
    # Values already present in the config take precedence.
    for key, fallback in (
        ("n_concepts", n_concepts),
        ("n_tasks", n_tasks),
        ("concept_map", concept_map),
    ):
        config[key] = config.get(key, fallback)

    if not config.get('use_task_class_weights', False):
        return None

    logging.info(
        f"Computing task class weights in the training dataset with "
        f"size {len(train_dl)}..."
    )
    multiclass = n_tasks > 1
    # The binary case still tracks two buckets: negatives and positives.
    label_totals = np.zeros((max(n_tasks, 2),))
    n_seen = 0
    for batch in train_dl:
        if len(batch) == 2:
            _, (y, _) = batch
        else:
            _, y, _ = batch
        if multiclass:
            indicators = torch.nn.functional.one_hot(
                y,
                num_classes=n_tasks,
            )
        else:
            # Column 0 marks negatives (1 - y), column 1 marks positives (y).
            indicators = torch.cat(
                [torch.unsqueeze(1 - y, dim=-1), torch.unsqueeze(y, dim=-1)],
                dim=-1,
            )
        indicators = indicators.cpu().detach().numpy()
        label_totals += np.sum(indicators, axis=0)
        n_seen += indicators.shape[0]
    print("Class distribution is:", label_totals / n_seen)

    if multiclass:
        return n_seen / label_totals - 1
    return np.array(
        [label_totals[0]/label_totals[1]]
    )

In [None]:
def _generate_dataset_and_update_config(
    experiment_config
):
    """Build the data loaders for an experiment and sync its config.

    Args:
        experiment_config: dictionary that must contain a ``dataset_config``
            entry (whose ``dataset`` key must be ``"cub"``) and may contain
            an ``intervention_config`` entry. It is updated in place by
            ``_update_config_with_dataset``.

    Returns:
        Tuple ``(train_dl, val_dl, test_dl, imbalance, concept_map,
        intervened_groups, task_class_weights, acquisition_costs)``.
        ``acquisition_costs`` is always ``None``.

    Raises:
        ValueError: if ``dataset_config`` is missing/None or names an
            unsupported dataset.
    """
    dataset_config = experiment_config.get("dataset_config", None)
    if dataset_config is None:
        raise ValueError(
            "A dataset_config must be provided for each experiment run!"
        )

    logging.debug(
        f"The dataset's root directory is {dataset_config.get('root_dir')}"
    )
    intervention_config = experiment_config.get('intervention_config', {})
    if dataset_config["dataset"] != "cub":
        raise ValueError(f"Unsupported dataset {dataset_config['dataset']}!")
    data_module = cub_data_module

    generated = data_module.generate_data(
        config=dataset_config,
        seed=42,
        output_dataset_vars=True,
        root_dir=dataset_config.get('root_dir', None),
    )
    train_dl, val_dl, test_dl, imbalance, dataset_vars = generated
    n_concepts, n_tasks, concept_map = dataset_vars

    # For now, all concepts are assumed to share the same acquisition cost.
    acquisition_costs = None

    # Intervene on concept groups when a concept map exists, otherwise on
    # individual concepts; 0 (no intervention) and the full count are both
    # candidate values, stepped by the configured frequency.
    n_units = len(concept_map) if concept_map is not None else n_concepts
    intervened_groups = list(
        range(
            0,
            n_units + 1,
            intervention_config.get('intervention_freq', 1),
        )
    )

    task_class_weights = _update_config_with_dataset(
        config=experiment_config,
        train_dl=train_dl,
        n_concepts=n_concepts,
        n_tasks=n_tasks,
        concept_map=concept_map,
    )
    return (
        train_dl,
        val_dl,
        test_dl,
        imbalance,
        concept_map,
        intervened_groups,
        task_class_weights,
        acquisition_costs,
    )

In [None]:
# Experiment configuration file; for local development the whole path
# might be needed.
yaml_path = "data/cub.yaml"

with open(yaml_path, "r") as file:
    yaml_config = yaml.safe_load(file)

# Override the environment- and resource-specific dataset settings in place.
yaml_config["shared_params"]["dataset_config"].update(
    {
        # For Kaggle; replace this with the locally downloaded folder.
        "root_dir": "/kaggle/input/cem-cub2000-filtered/",
        # Change both depending on the resources available.
        "num_workers": 4,
        "batch_size": 64,
    }
)

In [None]:
# Build the data loaders and dataset-derived settings from the shared
# parameters; the shared-params dictionary is updated in place.
(
    train_dl,
    val_dl,
    test_dl,
    imbalance,
    concept_map,
    intervened_groups,
    task_class_weights,
    acquisition_costs,
) = _generate_dataset_and_update_config(yaml_config["shared_params"])

## Step 2: Create the CBM
### Step 2.1 Define model for input to concepts
We first need to define an architecture that will extract the concepts from the input image.

For this, we used a pre-trained InceptionV3 model. We remove the last linear layer and make one that we can use for our task, so it is ready for fine-tuning.

In [None]:
def latent_code_generator_model(output_dim=112):
    """Build a pre-trained InceptionV3 whose final layer has ``output_dim`` outputs.

    The network is loaded with torchvision's default InceptionV3 weights,
    its auxiliary classifier branch is removed, and its final fully
    connected layer is replaced by a freshly initialised linear layer.

    Args:
        output_dim: number of outputs of the new final layer (112 by default).

    Returns:
        The modified InceptionV3 module.
    """
    # Load pre-trained InceptionV3
    inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)

    # Disable and delete the auxiliary classifier branch
    inception.aux_logits = False
    inception.AuxLogits = None

    # Replace the classification layer; the input width is read from the
    # existing head rather than hard-coded.
    inception.fc = torch.nn.Linear(inception.fc.in_features, output_dim)

    return inception

### Step 2.2: define CBM model.
We need to define the following:
1. `n_concepts`: the number of concepts in the dataset (112).
2. `n_tasks`: the number of output labels in the dataset (200).
3. `concept_loss_weight`: the weight to use for the concept prediction loss during training of the CBM. Picked to be the same as the CEM paper.
4. `learning_rate` and `optimizer`: to use during training. Optimizer is Adam by default, otherwise SGD.
5. `c_extractor_arch`: the model architecture to use for going from the input space to the concepts.
6. `c2y_model` and `c2y_layers`: the model architecture to use for going from the concepts to the labels. It can be given either directly as a model (like `c_extractor_arch`) or as a list of layers. We choose a simple linear layer.

In [None]:
from src.utils_cbm import *
from src.cbm import ConceptBottleneckModel

In [None]:
def custom_saving(model, epoch, out_dir="/kaggle/working", prefix=""):
    """Save model, optimizer and LR-scheduler state dicts for ``epoch``.

    Three separate files are written under ``out_dir``
    (``{prefix}model_{epoch}.pt``, ``{prefix}optimizer_{epoch}.pt`` and
    ``{prefix}lr_scheduler_{epoch}.pt``), plus one combined
    ``{prefix}checkpoint_{epoch}.pt`` in the current working directory.
    With the default arguments the file names and locations are the same
    as before these parameters existed.

    Args:
        model: object exposing ``state_dict()`` and
            ``configure_optimizers()``; the latter must return a mapping
            with ``'optimizer'`` and ``'lr_scheduler'`` entries that each
            expose ``state_dict()``.
        epoch: value embedded in the file names and stored under
            ``'epoch'`` in the combined checkpoint.
        out_dir: directory receiving the three separate files.
        prefix: file-name prefix. Pass a distinct prefix per model when
            several models are saved at the same epoch; otherwise each
            call overwrites the files of the previous one.

    NOTE(review): ``configure_optimizers()`` presumably constructs new
    optimizer/scheduler objects; if so, the states saved here are not the
    ones updated during training -- confirm against the model class.
    """
    # Query once so the optimizer and scheduler come from the same call.
    optim_config = model.configure_optimizers()
    opt = optim_config['optimizer']
    lr_sched = optim_config['lr_scheduler']

    checkpoint = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optimizer': opt.state_dict(),
        'lr_scheduler': lr_sched.state_dict(),
    }

    # Separate saving
    torch.save(checkpoint['model'], f"{out_dir}/{prefix}model_{epoch}.pt")
    torch.save(checkpoint['optimizer'], f"{out_dir}/{prefix}optimizer_{epoch}.pt")
    torch.save(checkpoint['lr_scheduler'], f"{out_dir}/{prefix}lr_scheduler_{epoch}.pt")

    # All things together (relative to the current working directory)
    torch.save(checkpoint, f"{prefix}checkpoint_{epoch}.pt")

## Step 3: Load the CBMs

Now that we have both the dataset and the model architecture defined, we can
load the two CBMs from their PyTorch Lightning checkpoints (`models/cbm_buggy.ckpt`
and `models/cbm_oracle.ckpt`) via `load_from_checkpoint`; both are passed to the rankers in Step 4:

In [None]:
def load_model(typ):
    """Load a ConceptBottleneckModel from ``models/{typ}.ckpt``.

    The constructor arguments are fixed to 112 concepts, 200 tasks, the
    SGD optimizer, and the concept loss weight and learning rate found in
    the module-level ``yaml_config["shared_params"]``.

    Args:
        typ: checkpoint base name (the path might need to be changed for
            a local setup).

    Returns:
        The model returned by ``load_from_checkpoint``.
    """
    shared_params = yaml_config["shared_params"]
    model_kwargs = {
        "n_concepts": 112,
        "n_tasks": 200,
        "concept_loss_weight": shared_params["concept_loss_weight"],
        # The learning rate to use during training.
        "learning_rate": shared_params["learning_rate"],
        "optimizer": "sgd",
        # Generating function for the latent code generator model.
        "c_extractor_arch": latent_code_generator_model,
        "c2y_model": None,
    }
    return ConceptBottleneckModel.load_from_checkpoint(
        checkpoint_path=f"models/{typ}.ckpt",
        **model_kwargs,
    )

In [None]:
# Load the two CBMs that every ranker configuration below receives:
# cbm_model_f from models/cbm_buggy.ckpt and cbm_model_o from
# models/cbm_oracle.ckpt.
cbm_model_f = load_model(typ="cbm_buggy")
cbm_model_o = load_model(typ="cbm_oracle")

## Step 4: Ranker model

The ranker model has different configurations.
- Config 1 (Ranker1 class), all four: $(x, c, w, y)$. Ranker $r(x, c_m(x), w_m, m(x))$.
- Config 2 (Ranker2 class), without input image: $(c, w, y)$. Ranker $r(c_m(x), w_m, m(x))$.
- Config 3 (Ranker3 class), without predicted label: $(x, c, w)$. Ranker $r(x, c_m(x), w_m)$.
- Config 4 (Ranker4 class), only concepts and weights: $(c, w)$. Ranker $r(c_m(x), w_m)$.

### Config 1
all four: $(x, c, w, y)$. Ranker $r(x, c_m(x), w_m, m(x))$.

In [None]:
from src.rankers import Ranker1

In [None]:
# Train the Config 1 ranker (Ranker1) on the two loaded CBMs.
# NOTE(review): `margin` presumably parameterizes a margin-based ranking
# loss -- confirm in src.rankers.
EPOCHS_NEW = 2

ranker_config1 = Ranker1(cbm_model_f, cbm_model_o, lr=1e-5, margin=5)

# `trainer` is rebound by every configuration's training cell.
trainer = pl.Trainer(
    accelerator="gpu",  # Change to "cpu" if you are not running on a GPU!
    devices="auto", 
    max_epochs=EPOCHS_NEW,  # The number of epochs we will train our model for #ORIGINAL 500
    check_val_every_n_epoch=1,  # And how often we will check for validation metrics
    logger=False,  # No logs to be dumped for this trainer
)

trainer.fit(ranker_config1, train_dl, val_dl)

In [None]:
# Trainer saving. Run this before the next configuration is trained:
# `trainer` is rebound there and would then save a different ranker.
trainer.save_checkpoint(f"/kaggle/working/ranker_config1_{EPOCHS_NEW}.ckpt")
# custom_saving is called with the same epoch value for every configuration.
# NOTE(review): if its file names depend only on the epoch, these files get
# overwritten by the later configurations' calls -- confirm.
custom_saving(ranker_config1, EPOCHS_NEW)

In [None]:
# Sizes of the segments the ranker input is assumed to be made of, in
# order: image, concepts, CBM weights, labels.
image_size = 3 * 299 * 299  # 268203
concepts_size = 112
weights_size = 22600
labels_size = 200
total_features = image_size + concepts_size + weights_size + labels_size

# Flattened first-layer weights, truncated to one entry per input feature.
# NOTE(review): if lay[0] has more than one output unit, this keeps only
# the first unit's incoming weights -- confirm against Ranker1.
W = ranker_config1.lay[0].weight.detach().cpu().numpy().flatten()[:total_features]

# Importance of a feature is the absolute value of its weight.
importance = np.abs(W)

# Indices of the top_k most important features, most important first.
top_k = 2500
top_indices = np.argsort(-importance)[:top_k]


def get_feature_category(index):
    """Return the name of the input segment that position ``index`` lies in."""
    concept_start = image_size
    weight_start = concept_start + concepts_size
    label_start = weight_start + weights_size
    if index >= label_start:
        return "Label"
    if index >= weight_start:
        return "Weight"
    if index >= concept_start:
        return "Concept"
    return "Image"


# Per category: (number of top features, summed importance).
counter = {name: (0, 0) for name in ("Image", "Concept", "Weight", "Label")}

print(f"Top {top_k} Most Important Features by Category:")
for idx in top_indices:
    category = get_feature_category(idx)
    hits, total_importance = counter[category]
    counter[category] = (hits + 1, total_importance + importance[idx])
counter

### Config 4
only concepts and weights: $(c, w)$. Ranker $r(c_m(x), w_m)$.

In [None]:
from src.rankers import Ranker4

In [None]:
# Train the Config 4 ranker (Ranker4) on the same two CBMs and the same
# data loaders as the other configurations.
EPOCHS_NEW = 2 #same number of epochs

ranker_config4 = Ranker4(cbm_model_f, cbm_model_o, lr=1e-5, margin=5)

# `trainer` is rebound here; the previous configuration's trainer is gone.
trainer = pl.Trainer(
    accelerator="gpu",  # Change to "cpu" if you are not running on a GPU!
    devices="auto", 
    max_epochs=EPOCHS_NEW,  # The number of epochs we will train our model for #ORIGINAL 500
    check_val_every_n_epoch=1,  # And how often we will check for validation metrics
    logger=False,  # No logs to be dumped for this trainer
)

trainer.fit(ranker_config4, train_dl, val_dl)

In [None]:
# Sizes of the segments the ranker input is assumed to be made of, in
# order: concepts, CBM weights.
concepts_size = 112
weights_size = 22600
total_features = concepts_size + weights_size

# Flattened first-layer weights, truncated to one entry per input feature.
# NOTE(review): if lay[0] has more than one output unit, this keeps only
# the first unit's incoming weights -- confirm against Ranker4.
W = ranker_config4.lay[0].weight.detach().cpu().numpy().flatten()[:total_features]

# Importance of a feature is the absolute value of its weight.
importance = np.abs(W)

# Indices of the top_k most important features, most important first.
top_k = 2500
top_indices = np.argsort(-importance)[:top_k]


def get_feature_category(index):
    """Return the name of the input segment that position ``index`` lies in."""
    if index >= concepts_size:
        return "Weight"
    return "Concept"


# Per category: (number of top features, summed importance).
counter = {name: (0, 0) for name in ("Concept", "Weight")}

print(f"Top {top_k} Most Important Features by Category:")
for idx in top_indices:
    category = get_feature_category(idx)
    hits, total_importance = counter[category]
    counter[category] = (hits + 1, total_importance + importance[idx])
counter

In [None]:
# Trainer saving. Run this before the next configuration is trained:
# `trainer` is rebound there and would then save a different ranker.
trainer.save_checkpoint(f"/kaggle/working/ranker_config4_{EPOCHS_NEW}.ckpt")
# custom_saving is called with the same epoch value for every configuration.
# NOTE(review): if its file names depend only on the epoch, this call
# overwrites files written for another configuration -- confirm.
custom_saving(ranker_config4, EPOCHS_NEW)

### Config 3
without predicted label: $(x, c, w)$. Ranker $r(x, c_m(x), w_m)$.

In [None]:
from src.rankers import Ranker3

In [None]:
# Train the Config 3 ranker (Ranker3) on the same two CBMs and the same
# data loaders as the other configurations.
EPOCHS_NEW = 2 #same number of epochs

ranker_config3 = Ranker3(cbm_model_f, cbm_model_o, lr=1e-5, margin=5)

# `trainer` is rebound here; the previous configuration's trainer is gone.
trainer = pl.Trainer(
    accelerator="gpu",  # Change to "cpu" if you are not running on a GPU!
    devices="auto", 
    max_epochs=EPOCHS_NEW,  # The number of epochs we will train our model for #ORIGINAL 500
    check_val_every_n_epoch=1,  # And how often we will check for validation metrics
    logger=False,  # No logs to be dumped for this trainer
)

trainer.fit(ranker_config3, train_dl, val_dl)

In [None]:
# Sizes of the segments the ranker input is assumed to be made of, in
# order: image, concepts, CBM weights.
image_size = 3 * 299 * 299  # 268203
concepts_size = 112
weights_size = 22600
total_features = image_size + concepts_size + weights_size

# Flattened first-layer weights, truncated to one entry per input feature.
# NOTE(review): if lay[0] has more than one output unit, this keeps only
# the first unit's incoming weights -- confirm against Ranker3.
W = ranker_config3.lay[0].weight.detach().cpu().numpy().flatten()[:total_features]

# Importance of a feature is the absolute value of its weight.
importance = np.abs(W)

# Indices of the top_k most important features, most important first.
top_k = 2500
top_indices = np.argsort(-importance)[:top_k]


def get_feature_category(index):
    """Return the name of the input segment that position ``index`` lies in."""
    concept_start = image_size
    weight_start = concept_start + concepts_size
    if index >= weight_start:
        return "Weight"
    if index >= concept_start:
        return "Concept"
    return "Image"


# Per category: (number of top features, summed importance).
counter = {name: (0, 0) for name in ("Image", "Concept", "Weight")}

print(f"Top {top_k} Most Important Features by Category:")
for idx in top_indices:
    category = get_feature_category(idx)
    hits, total_importance = counter[category]
    counter[category] = (hits + 1, total_importance + importance[idx])
counter

In [None]:
# Trainer saving. Run this before the next configuration is trained:
# `trainer` is rebound there and would then save a different ranker.
trainer.save_checkpoint(f"/kaggle/working/ranker_config3_{EPOCHS_NEW}.ckpt")
# custom_saving is called with the same epoch value for every configuration.
# NOTE(review): if its file names depend only on the epoch, this call
# overwrites files written for another configuration -- confirm.
custom_saving(ranker_config3, EPOCHS_NEW)

### Config 2
without input image: $(c, w, y)$. Ranker $r(c_m(x), w_m, m(x))$.

In [None]:
from src.rankers import Ranker2

In [None]:
# Train the Config 2 ranker (Ranker2) on the same two CBMs and the same
# data loaders as the other configurations.
EPOCHS_NEW = 2 #same number of epochs

ranker_config2 = Ranker2(cbm_model_f, cbm_model_o, lr=1e-5, margin=5)

# `trainer` is rebound here; the previous configuration's trainer is gone.
trainer = pl.Trainer(
    accelerator="gpu",  # Change to "cpu" if you are not running on a GPU!
    devices="auto", 
    max_epochs=EPOCHS_NEW,  # The number of epochs we will train our model for #ORIGINAL 500
    check_val_every_n_epoch=1,  # And how often we will check for validation metrics
    logger=False,  # No logs to be dumped for this trainer
)

trainer.fit(ranker_config2, train_dl, val_dl)

In [None]:
# Sizes of the segments the ranker input is assumed to be made of, in
# order: concepts, CBM weights, labels.
concepts_size = 112
weights_size = 22600
labels_size = 200
total_features = concepts_size + weights_size + labels_size

# Flattened first-layer weights, truncated to one entry per input feature
# (`.weight` does not hold the layer's bias).
# NOTE(review): if lay[0] has more than one output unit, this keeps only
# the first unit's incoming weights -- confirm against Ranker2.
W = ranker_config2.lay[0].weight.detach().cpu().numpy().flatten()[:total_features]

# Importance of a feature is the absolute value of its weight.
importance = np.abs(W)

# Indices of the top_k most important features, most important first.
top_k = 2500
top_indices = np.argsort(-importance)[:top_k]


def get_feature_category(index):
    """Return the name of the input segment that position ``index`` lies in."""
    weight_start = concepts_size
    label_start = weight_start + weights_size
    if index >= label_start:
        return "Label"
    if index >= weight_start:
        return "Weight"
    return "Concept"


# Per category: (number of top features, summed importance).
counter = {name: (0, 0) for name in ("Concept", "Weight", "Label")}

print(f"Top {top_k} Most Important Features by Category:")
for idx in top_indices:
    category = get_feature_category(idx)
    hits, total_importance = counter[category]
    counter[category] = (hits + 1, total_importance + importance[idx])

counter

In [None]:
# Trainer saving. `trainer` here is the one that trained ranker_config2
# (it is rebound by every configuration's training cell).
trainer.save_checkpoint(f"/kaggle/working/ranker_config2_{EPOCHS_NEW}.ckpt")
# custom_saving is called with the same epoch value for every configuration.
# NOTE(review): if its file names depend only on the epoch, this call
# overwrites files written for another configuration -- confirm.
custom_saving(ranker_config2, EPOCHS_NEW)

## Step 5: Evaluation

In [None]:
def evaluate_model(dataloader, ranker):
    """Run batched inference with ``ranker`` over ``dataloader`` and print
    the mean of the per-batch results as the average test loss.

    A dedicated GPU ``pl.Trainer`` without a logger is created for the
    prediction pass. The per-batch results are stacked and averaged, so
    every batch contributes equally regardless of its size.

    Args:
        dataloader: data loader handed to ``Trainer.predict``.
        ranker: trained ranker module to evaluate.
            NOTE(review): presumably its predict step returns the batch
            loss as a tensor -- confirm in src.rankers.

    Returns:
        None; the average is only printed.
    """
    inference_options = {
        "accelerator": "gpu",
        "devices": "auto",
        "logger": False,  # No logs to be dumped for this trainer
    }
    trainer_inference = pl.Trainer(**inference_options)
    batch_results = trainer_inference.predict(ranker, dataloader)
    average_loss = torch.stack(batch_results).mean()
    print(f"The average test loss is: {average_loss.item()}")

In [None]:
# Evaluate the Config 1 ranker on the test data loader.
evaluate_model(test_dl, ranker_config1)

In [None]:
# Evaluate the Config 2 ranker on the test data loader.
evaluate_model(test_dl, ranker_config2)

In [None]:
# Evaluate the Config 3 ranker on the test data loader.
evaluate_model(test_dl, ranker_config3)

In [None]:
# Evaluate the Config 4 ranker on the test data loader.
evaluate_model(test_dl, ranker_config4)