In [2]:
import argparse
import os
import random
import re
import sys

import clip
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import wandb

# Add the parent directory to the path so the local `utils` package can be imported.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))
sys.path.insert(0, parent_dir)

from utils.dcbm import *



# Load data, extract concepts, and train the CBM

In [3]:
# ----------------- Fixed paths -----------------
# Locations of precomputed artifacts, relative to this notebook's directory.
embed_path = "../data/Embeddings/"
dataset = "cub"
class_labels_path = "../data/classes/cub_classes.txt"
segment_path = "../data/Segments/"
selected_image_concepts = "../data/Embeddings/subsets"

# ----------------- Hyperparameters -----------------

model_name = "CLIP-ViT-L14"  # "CLIP-ViT-L14", "CLIP-RN50", "CLIP-ViT-B16"

segmentation_technique = "SAM2"  # "GDINO", "SAM", "SAM2", "DETR", "MaskRCNN"
concept_name = None  # Only needed for GDINO, e.g. "awa", "sun", "sun-lowthresh", "cub"

device = "cpu"

clusters = 2048  # Number of concept clusters; None skips clustering and uses the raw segments
cluster_method = "kmeans"  # "hierarchical", "kmeans"
centroid_method = "median"  # "mean", "median"

concept_per_class = 50  # Images per class used for concept extraction: 5, 10, 20, 50, or None (all)

one_hot = False
epochs = 200
lambda_1 = 1e-4  # presumably the L1 penalty weight (l1_loss is logged) — confirm in CBM.train
lr = 1e-4
batch_size = 32

crop = False  # True = use cropped segments without background

use_wandb = False
project = "YOUR_PROJECT_NAME"  # Define your own project name within wandb

# ----------------- Reproducibility -----------------
# k-means initialisation and model training are typically stochastic; seed every
# RNG so that Restart Kernel -> Run All reproduces the same concepts and model.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [5]:
# Build the concept bottleneck model on top of the precomputed image embeddings.
cbm = CBM(embed_path, dataset, model_name, class_labels_path, device=device)

# Load segment-level concept embeddings produced by the chosen segmentation technique.
cbm.load_concepts(
    segment_path,
    segmentation_technique,
    concept_name,
    selected_image_concepts,
    concept_per_class,
    crop=crop,
)

# With no cluster count, the raw segment embeddings are used as the concepts directly;
# otherwise they are clustered (with PCA enabled) into `clusters` groups.
if clusters is None:
    cbm.clustered_concepts = cbm.image_segments
else:
    cbm.cluster_image_concepts(cluster_method, clusters, pca=True)

# Compute the concept centroids using `centroid_method`.
cbm.centroid_concepts(centroid_method)

# Preprocess the data for training.
# NOTE(review): `label_type` receives the boolean `one_hot` flag — confirm this is what
# CBM.preprocess_data expects.
cbm.preprocess_data(type_="standard", label_type=one_hot)

# Train the model; kept as the cell's last expression so the notebook displays
# the returned model.
cbm.train(
    num_epochs=epochs,
    lambda_1=lambda_1,
    lr=lr,
    device=device,
    batch_size=batch_size,
    project=project,
    to_print=False,
    early_stopping_patience=None,
    one_hot=one_hot,
    use_wandb=use_wandb,
)

Train data loaded from ../Embeddings/images_CUB_train_CLIP-ViT-L14.torch
Validation data loaded from ../Embeddings/images_CUB_val_CLIP-ViT-L14.torch
Test data loaded from ../Embeddings/images_CUB_test_CLIP-ViT-L14.torch


  concepts_dict = torch.load(concept_path)


Concepts loaded from  ../Segments/Seg_embs/segcrop_CUB_200_2011_SAM2_CLIP-ViT-L14.torch
PCA:  True
Number of image embeddings:  38416
Number of clusters:  2048
Clustering method:  kmeans


Processing Batches: 100%|██████████| 6/6 [00:00<00:00, 367.28it/s]
Processing Batches: 100%|██████████| 1/1 [00:00<00:00, 686.47it/s]
Processing Batches: 100%|██████████| 6/6 [00:00<00:00, 676.85it/s]


(5394, 2048) (5794, 2048) (600, 2048) (5394, 200) (5794, 200) (600, 200)
learning rate:  0.0001
lambda_1:  0.0001


VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))



0,1
background,▁
batch_size,▁
ce_loss,█▄▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
clusters,▁
l1_loss,▅▇████▇▇▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁
lambda_1,▁
lr,▁
one_hot,▁
test_accuracy,▁▅▇▇▇▇██████████████████████████████████
test_loss,█▄▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Max_class_num,All
background,False
batch_size,32
ce_loss,0.1797
centroid_method,median
cluster_technique,kmeans
clusters,2048
l1_loss,0.36913
lambda_1,0.0001
lr,0.0001


LinearProbe(
  (linear): Linear(in_features=2048, out_features=200, bias=True)
)