This notebook is inspired by the CEM (Zarlenga et al., 2022) and CBM (Koh et al., 2020) papers. Please visit their GitHub repositories:
[CEM GitHub](https://github.com/mateoespinosa/cem) and [CBM GitHub](https://github.com/yewsiang/ConceptBottleneck).

# CBM: oracle model training - V2 (trained on CF train & NCF test splits)

There are three main steps:
1. Loading the dataset.
2. Initializing a CBM with an InceptionV3 vision backbone for the dataset.
3. Training the CBM.

## Step 1: Load Data

The first step is to load the data. The CBM class, together with the PyTorch Lightning Trainer, takes a PyTorch DataLoader object as input.
Each batch produced by the DataLoader must contain three elements (in the following order):
1. the sample image, $\mathbf{x}$
2. the image label, $\mathbf{y}$
3. the concept labels, in binary format, $\mathbf{c}$
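
As a minimal sketch of this batch layout, the toy DataLoader below yields `(x, y, c)` triples in the order listed above. The tensor sizes and the name `toy_dl` are illustrative assumptions, not the CUB data used in this notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy sizes (assumption): small random images, 200 classes,
# 112 binary concepts. The real loader comes from cub_data_module.
n_samples, n_classes, n_concepts = 8, 200, 112

x = torch.randn(n_samples, 3, 32, 32)                      # sample images
y = torch.randint(0, n_classes, (n_samples,))              # integer image labels
c = torch.randint(0, 2, (n_samples, n_concepts)).float()   # binary concept labels

toy_dl = DataLoader(TensorDataset(x, y, c), batch_size=4)

# Each batch unpacks in the order (x, y, c).
x_batch, y_batch, c_batch = next(iter(toy_dl))
print(x_batch.shape, y_batch.shape, c_batch.shape)
```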

In [None]:
#from cub_data_module import *
import logging
import torch
import numpy as np
import yaml
import torchvision.models as models
import pytorch_lightning as pl

In [None]:
#import sys #for Kaggle
#sys.path.append("/kaggle/usr/lib/cub_data_module_confounded") #for Kaggle

import cub_data_module_confounded_oracle_v2 as cub_data_module

In [None]:
def _update_config_with_dataset(
    config,
    train_dl,
    n_concepts,
    n_tasks,
    concept_map,
):
    config["n_concepts"] = config.get(
        "n_concepts",
        n_concepts,
    )
    config["n_tasks"] = config.get(
        "n_tasks",
        n_tasks,
    )
    config["concept_map"] = config.get(
        "concept_map",
        concept_map,
    )

    task_class_weights = None

    if config.get('use_task_class_weights', False):
        logging.info(
            f"Computing task class weights in the training dataset with "
            f"size {len(train_dl)}..."
        )
        attribute_count = np.zeros((max(n_tasks, 2),))
        samples_seen = 0
        for i, data in enumerate(train_dl):
            if len(data) == 2:
                (_, (y, _)) = data
            else:
                (_, y, _) = data
            if n_tasks > 1:
                y = torch.nn.functional.one_hot(
                    y,
                    num_classes=n_tasks,
                ).cpu().detach().numpy()
            else:
                y = torch.cat(
                    [torch.unsqueeze(1 - y, dim=-1), torch.unsqueeze(y, dim=-1)],
                    dim=-1,
                ).cpu().detach().numpy()
            attribute_count += np.sum(y, axis=0)
            samples_seen += y.shape[0]
        print("Class distribution is:", attribute_count / samples_seen)
        if n_tasks > 1:
            task_class_weights = samples_seen / attribute_count - 1
        else:
            task_class_weights = np.array(
                [attribute_count[0]/attribute_count[1]]
            )
    return task_class_weights
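
The multi-class weight formula above, `samples_seen / attribute_count - 1`, is the ratio of negative to positive samples for each class. The sketch below reproduces it in plain Python on a made-up label list (the labels and the three-class setting are assumptions for illustration only).

```python
from collections import Counter

# Hypothetical labels for a 3-class toy problem (not CUB data).
labels = [0, 0, 0, 0, 1, 1, 1, 2]
n_tasks = 3

counts = Counter(labels)     # per-class sample counts
samples_seen = len(labels)   # total number of samples

# N / n_k - 1 == (N - n_k) / n_k: negatives over positives for class k.
task_class_weights = [samples_seen / counts[k] - 1 for k in range(n_tasks)]
print(task_class_weights)
```

Rare classes receive larger weights (class 2 appears once in eight samples and gets weight 7.0). A class that never appears in the training split has a zero count, so the ratio is undefined for it.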

In [None]:
def _generate_dataset_and_update_config(
    experiment_config
):
    if experiment_config.get("dataset_config", None) is None:
        raise ValueError(
            "A dataset_config must be provided for each experiment run!"
        )

    dataset_config = experiment_config['dataset_config']
    logging.debug(
        f"The dataset's root directory is {dataset_config.get('root_dir')}"
    )
    intervention_config = experiment_config.get('intervention_config', {})
    if dataset_config["dataset"] == "cub":
        data_module = cub_data_module
    else:
        raise ValueError(f"Unsupported dataset {dataset_config['dataset']}!")

    train_dl, val_dl, test_dl, imbalance, (n_concepts, n_tasks, concept_map) = \
        data_module.generate_data(
            config=dataset_config,
            seed=42,
            output_dataset_vars=True,
            root_dir=dataset_config.get('root_dir', None),
        )
    # For now, we assume that all concepts have the same
    # acquisition cost
    acquisition_costs = None
    if concept_map is not None:
        intervened_groups = list(
            range(
                0,
                len(concept_map) + 1,
                intervention_config.get('intervention_freq', 1),
            )
        )
    else:
        intervened_groups = list(
            range(
                0,
                n_concepts + 1,
                intervention_config.get('intervention_freq', 1),
            )
        )

    task_class_weights = _update_config_with_dataset(
        config=experiment_config,
        train_dl=train_dl,
        n_concepts=n_concepts,
        n_tasks=n_tasks,
        concept_map=concept_map,
    )
    return (
        train_dl,
        val_dl,
        test_dl,
        imbalance,
        concept_map,
        intervened_groups,
        task_class_weights,
        acquisition_costs,
    )
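
To make the `intervened_groups` computation concrete, the sketch below repeats the same `range(...)` call in a standalone helper. The helper name and the group counts are hypothetical; the list presumably holds the numbers of concept groups to intervene on during evaluation.

```python
def intervened_group_counts(n_groups, intervention_freq=1):
    """Hypothetical helper mirroring the range(...) call in the cell above."""
    return list(range(0, n_groups + 1, intervention_freq))

# With 28 groups and a step of 4, the full range is covered.
full = intervened_group_counts(28, 4)
print(full)

# If the step does not divide the number of groups, the last value is dropped.
partial = intervened_group_counts(10, 4)
print(partial)
```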

In [None]:
yaml_path = "data/cub.yaml" # for local development, might need to use whole path.

with open(yaml_path, "r") as file:
    yaml_config = yaml.safe_load(file)
yaml_config["shared_params"]["dataset_config"]["root_dir"] = "/kaggle/input/cem-cub2000-filtered/" #for Kaggle, replace this with locally downloaded folder.
yaml_config["shared_params"]["dataset_config"]["num_workers"] = 4 #change depending on resources available.
yaml_config["shared_params"]["dataset_config"]["batch_size"] = 64 #change depending on resources available.
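
For orientation, here is a sketch of the configuration layout that the cells in this notebook read, written as the Python dictionary that `yaml.safe_load` would return. Only the key names are taken from this notebook; every value, and the helper `missing_keys`, is a hypothetical placeholder.

```python
# Hypothetical values (assumption); only the key names come from this notebook.
example_config = {
    "shared_params": {
        "concept_loss_weight": 5,
        "learning_rate": 0.01,
        "use_task_class_weights": True,                   # optional, defaults to False
        "intervention_config": {"intervention_freq": 4},  # optional
        "dataset_config": {
            "dataset": "cub",
            "root_dir": "/path/to/cub",
            "num_workers": 4,
            "batch_size": 64,
        },
    }
}


def missing_keys(config):
    """Return the keys that this notebook reads unconditionally but that are absent."""
    shared = config.get("shared_params", {})
    required = ("concept_loss_weight", "learning_rate", "dataset_config")
    missing = [key for key in required if key not in shared]
    if "dataset" not in shared.get("dataset_config", {}):
        missing.append("dataset_config.dataset")
    return missing


print(missing_keys(example_config))
```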

In [None]:
train_dl, val_dl, test_dl, imbalance, concept_map, intervened_groups, task_class_weights, acquisition_costs = _generate_dataset_and_update_config(yaml_config["shared_params"])

## Step 2: Create the CBM
### Step 2.1: Define the model from input to concepts
We first need to define an architecture that extracts the concepts from the input image.

For this, we use a pre-trained InceptionV3 model. We replace its last linear layer with a new one sized for our task, so the model is ready for fine-tuning.

In [None]:
def latent_code_generator_model(output_dim=112):
    # Load pre-trained InceptionV3
    inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)

    # Remove auxiliary classifier (set to None)
    inception.aux_logits = False  # Disable aux_logits
    inception.AuxLogits = None  # Delete aux classifier branch

    inception.fc = torch.nn.Linear(2048, output_dim)  # Replace classification layer with output_dim

    return inception

### Step 2.2: Define the CBM model
We need to define the following:
1. `n_concepts`: the number of concepts in the dataset (112).
2. `n_tasks`: the number of output labels in the dataset (200).
3. `concept_loss_weight`: the weight of the concept prediction loss during training of the CBM. Chosen to match the CEM paper.
4. `learning_rate` and `optimizer`: the settings to use during training. The optimizer is Adam by default; otherwise SGD is used (this notebook passes `"sgd"`).
5. `c_extractor_arch`: the model architecture that maps the input space to the concepts.
6. `c2y_model` and `c2y_layers`: the model architecture that maps the concepts to the labels. It can be given directly as a model, like `c_extractor_arch`, or as a list of layers. We use a simple linear layer.
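
To illustrate how these parameters interact, the sketch below builds a toy joint CBM and combines the task loss with the weighted concept loss. It assumes joint training in the style of Koh et al. (task loss plus `concept_loss_weight` times concept loss); `ToyJointCBM` and `joint_loss` are hypothetical names, and the `ConceptBottleneckModel` imported from `src.cbm` may differ in its details.

```python
import torch
from torch import nn


class ToyJointCBM(nn.Module):
    """Hypothetical stand-in, not the ConceptBottleneckModel from src.cbm."""

    def __init__(self, in_dim, n_concepts, n_tasks):
        super().__init__()
        self.x2c = nn.Linear(in_dim, n_concepts)   # plays the role of c_extractor_arch
        self.c2y = nn.Linear(n_concepts, n_tasks)  # a single linear layer

    def forward(self, x):
        c_logits = self.x2c(x)
        # The label predictor only sees the concept activations (the bottleneck).
        y_logits = self.c2y(torch.sigmoid(c_logits))
        return c_logits, y_logits


def joint_loss(c_logits, y_logits, c, y, concept_loss_weight):
    """Task cross-entropy plus weighted binary cross-entropy on the concepts."""
    task_loss = nn.functional.cross_entropy(y_logits, y)
    concept_loss = nn.functional.binary_cross_entropy_with_logits(c_logits, c)
    return task_loss + concept_loss_weight * concept_loss


torch.manual_seed(0)
model = ToyJointCBM(in_dim=16, n_concepts=112, n_tasks=200)
x = torch.randn(4, 16)                        # toy features instead of images
c = torch.randint(0, 2, (4, 112)).float()     # binary concept labels
y = torch.randint(0, 200, (4,))               # integer class labels

c_logits, y_logits = model(x)
loss = joint_loss(c_logits, y_logits, c, y, concept_loss_weight=5.0)
print(c_logits.shape, y_logits.shape, loss.item())
```

A larger `concept_loss_weight` pushes the model to predict the concepts accurately, at the possible expense of task accuracy.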

In [None]:
from src.utils_cbm import *
from src.cbm import ConceptBottleneckModel

In [None]:
cbm_model = ConceptBottleneckModel(
  n_concepts=112,
  n_tasks=200,
  concept_loss_weight=yaml_config["shared_params"]["concept_loss_weight"],
  learning_rate=yaml_config["shared_params"]["learning_rate"],
  optimizer="sgd", 
  c_extractor_arch=latent_code_generator_model,
  c2y_model=None
)
#print(cbm_model)

## Step 3: Train the CBM

Let's construct a PyTorch Lightning Trainer object to take care of the training.

In [None]:
EPOCHS = 30

trainer = pl.Trainer(
    accelerator="gpu", #or "cpu"
    devices="auto",
    max_epochs=EPOCHS,
    check_val_every_n_epoch=5,
    logger=False,
)

trainer.fit(cbm_model, train_dl, val_dl) # Train the model

trainer.save_checkpoint(f"/kaggle/working/cbm_{EPOCHS}.ckpt") # Save the trainer with the model

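## Step 4: Optionally continue training
To continue training, create a new trainer and pass the path of the previous checkpoint to the `fit(...)` call via `ckpt_path`. The checkpoint restores the epoch counter, so `max_epochs=40` resumes from epoch 30 and trains for 10 more epochs.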

In [None]:
EPOCHS = 40

trainer = pl.Trainer(
    accelerator="gpu", #or "cpu"
    devices="auto",
    max_epochs=EPOCHS,
    check_val_every_n_epoch=5,
    logger=False,
)

trainer.fit(cbm_model, train_dl, val_dl, ckpt_path="/kaggle/working/cbm_30.ckpt")  # Train the model starting from the previous checkpoint

trainer.save_checkpoint(f"/kaggle/working/cbm_{EPOCHS}.ckpt") # Save the trainer with the model

## Step 5: Evaluation
Evaluation is done in a separate notebook.