## Environment Setup and Dataset Preparation


In [1]:
pip -q install torch torchvision transformers openai requests Pillow numpy scikit-learn matplotlib seaborn

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
import requests
import tarfile
import os

dataset_url = "http://vision.stanford.edu/aditya86/ImageNetDogs/images.tar"
dataset_path = "/content/drive/MyDrive/CSCI2470/final/stanford_dogs_dataset"

if not os.path.exists(dataset_path):
    os.makedirs(dataset_path)
    print(f"Downloading Stanford Dogs dataset from {dataset_url}...")
    response = requests.get(dataset_url, stream=True)
    response.raise_for_status() # Raise an exception for bad status codes

    tar_file_path = os.path.join(dataset_path, "images.tar")
    with open(tar_file_path, 'wb') as f:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
    print("Download complete. Extracting...")

    with tarfile.open(tar_file_path, 'r') as tar:
        tar.extractall(path=dataset_path)
    print("Extraction complete.")
    os.remove(tar_file_path) # Clean up the tar file after extraction
else:
    print("Stanford Dogs dataset already downloaded and extracted.")

images_root = os.path.join(dataset_path, 'Images')
if not os.path.exists(images_root):
    print(f"Warning: 'Images' subdirectory not found in {dataset_path}. Please check dataset structure.")
else:
    print(f"Stanford Dogs images root directory: {images_root}")

Stanford Dogs dataset already downloaded and extracted.
Stanford Dogs images root directory: /content/drive/MyDrive/CSCI2470/final/stanford_dogs_dataset/Images


In [4]:
import torch
from torchvision import transforms
from PIL import Image
import os

class StanfordDogsDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.class_to_idx = {}
        self.idx_to_class = {}

        # Collect image paths and assign labels
        for i, class_name in enumerate(sorted(os.listdir(root_dir))):
            if not class_name.startswith('n'): # Skip non-breed directories
                continue
            self.class_to_idx[class_name] = i
            self.idx_to_class[i] = class_name
            class_path = os.path.join(root_dir, class_name)
            for img_name in os.listdir(class_path):
                self.image_paths.append(os.path.join(class_path, img_name))
                self.labels.append(i)

        print(f"Found {len(self.image_paths)} images belonging to {len(self.class_to_idx)} classes.")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [5]:
import torch
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader
import logging

logging.getLogger("transformers").setLevel(logging.ERROR)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

clip_model_name = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name).to(device)
clip_model.eval()
print("CLIP ready")

Using device: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


CLIP ready


In [6]:
import os
import torch
from torch.utils.data import DataLoader

dataset = StanfordDogsDataset(root_dir=images_root, transform=None)

emb_dir = os.path.join(dataset_path, 'embeddings')
os.makedirs(emb_dir, exist_ok=True)
emb_path = os.path.join(emb_dir, 'clip_embeddings.pt')

if os.path.exists(emb_path):
    print(f"Embedding file {emb_path} already exists. Skipping generation.")
    data = torch.load(emb_path)
    embeddings = data["embeddings"]
    labels = data["labels"]
else:
    loader = DataLoader(
        dataset,
        batch_size=64,
        shuffle=False,
        num_workers=2,
        collate_fn=lambda batch: batch  # list of (image, label)
    )

    all_embeddings = []
    all_labels = []

    with torch.no_grad():
        for batch in loader:
            images, labels_batch = zip(*batch)
            inputs = processor(images=list(images), return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            emb = clip_model.get_image_features(**inputs)
            emb = emb / emb.norm(dim=-1, keepdim=True)

            all_embeddings.append(emb.cpu())
            all_labels.append(torch.tensor(labels_batch))

    embeddings = torch.cat(all_embeddings)
    labels = torch.cat(all_labels)

    torch.save(
        {
            "embeddings": embeddings,
            "labels": labels
        },
        emb_path
    )
    print(f"Saved embeddings to {emb_path}")


Found 20580 images belonging to 120 classes.
Embedding file /content/drive/MyDrive/CSCI2470/final/stanford_dogs_dataset/embeddings/clip_embeddings.pt already exists. Skipping generation.


In [7]:
import torch
from torch.utils.data import Dataset

class StanfordDogsEmbeddingDataset(Dataset):
    def __init__(self, emb_path):
        data = torch.load(emb_path)
        self.embeddings = data["embeddings"]
        self.labels = data["labels"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]


In [8]:
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import numpy as np

# embedding Dataset
full_dataset = StanfordDogsEmbeddingDataset(emb_path)

all_labels = [full_dataset[i][1].item() for i in range(len(full_dataset))]
all_indices = np.arange(len(full_dataset))

use_full_data = True

if use_full_data:
    train_idx, temp_idx = train_test_split(
        all_indices,
        train_size=0.7,
        stratify=all_labels,
        random_state=42
    )

    temp_labels = [all_labels[i] for i in temp_idx]
    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=0.5,
        stratify=temp_labels,
        random_state=42
    )
else:
    # 10% samples
    sample_indices, _ = train_test_split(
        all_indices,
        train_size=0.1,
        stratify=all_labels,
        random_state=42
    )

    sample_labels = [all_labels[i] for i in sample_indices]

    train_idx, temp_idx = train_test_split(
        sample_indices,
        train_size=0.7,
        stratify=sample_labels,
        random_state=42
    )

    temp_labels = [all_labels[i] for i in temp_idx]
    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=0.5,
        stratify=temp_labels,
        random_state=42
    )

train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)
test_dataset = Subset(full_dataset, test_idx)

print(f"Using full data: {use_full_data}")
print(
    "train =", len(train_dataset),
    "val =", len(val_dataset),
    "test =", len(test_dataset)
)


Using full data: True
train = 14405 val = 3087 test = 3088


In [9]:
from torch.utils.data import DataLoader

batch_size = 64

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("DataLoaders created successfully.")
print(f"Train DataLoader has {len(train_dataloader)} batches.")
print(f"Validation DataLoader has {len(val_dataloader)} batches.")
print(f"Test DataLoader has {len(test_dataloader)} batches.")


DataLoaders created successfully.
Train DataLoader has 226 batches.
Validation DataLoader has 49 batches.
Test DataLoader has 49 batches.


## Load LLM generated concepts & concept matrix

In [10]:
import json

# Extract unique dog breed names
dog_breed_names = [name.split('-')[-1].replace('_', ' ') for name in dataset.class_to_idx.keys()]
dog_breed_class_ids = list(dataset.class_to_idx.values())

# Create a mapping from cleaned breed name to class ID
breed_name_to_class_id = {name.split('-')[-1].replace('_', ' '): class_id for name, class_id in dataset.class_to_idx.items()}
class_id_to_breed_name = {v: k.split('-')[-1].replace('_', ' ') for k, v in dataset.class_to_idx.items()}

print(f"Extracted {len(dog_breed_names)} unique dog breed names.")
print(f"First 5 breed names: {dog_breed_names[:5]}")

dog_breed_names_path = os.path.join(dataset_path, 'dog_breed_names.txt')

if not os.path.exists(dog_breed_names_path):
    with open(dog_breed_names_path, "w") as f:
        json.dump(dog_breed_names, f)
    print(f"Saved {dog_breed_names_path}")
else:
    print(f"{dog_breed_names_path} already exists, skipping save.")

Extracted 120 unique dog breed names.
First 5 breed names: ['Chihuahua', 'Japanese spaniel', 'Maltese dog', 'Pekinese', 'Tzu']
/content/drive/MyDrive/CSCI2470/final/stanford_dogs_dataset/dog_breed_names.txt already exists, skipping save.


In [11]:
llm_concepts = ['long coat', 'short coat', 'wir y coat', 'curly coat', 'double coat', 'single coat', 'beard', 'mustache', 'facial wrinkles', 'bushy eyebrows', 'ears erect', 'ears droopy', 'tail curled', 'tail docked', 'tail carried high', 'tail carried low', 'muzzle long', 'muzzle short', 'facial mask', 'neck ruff', 'feathering on legs', 'dewclaws present', 'dewclaws absent', 'webbed feet', 'brindle pattern', 'merle pattern', 'spotted coat', 'solid coat', 'jowls present']

In [12]:
concept_matrix_path = '/content/drive/MyDrive/CSCI2470/final/outputs/concept_matrix.json'

with open(concept_matrix_path, "r", encoding="utf-8") as f:
    concept_matrix = json.load(f)

## Baseline CLIP Linear Probing

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

class CLIPLinearProbe(nn.Module):
    def __init__(self, clip_embed_dim=512, num_classes=120):
        super().__init__()
        self.fc = nn.Linear(clip_embed_dim, num_classes)
    def forward(self, x):
        return self.fc(x)

def train_clip_baseline(model, train_loader, val_loader, num_epochs=10, lr=1e-3, weight_decay=1e-4):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0
        all_preds, all_labels = [], []

        for feats, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            feats = feats.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(feats)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item() * len(feats)
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
        acc = (all_preds == all_labels).float().mean().item()
        print(f"[Train] Epoch {epoch+1} Loss: {epoch_loss/len(train_loader.dataset):.4f}, Top-1 Acc: {acc:.4f}")

        if epoch + 1 % 5 == 0:
            val_top1, val_top5, val_loss = evaluate_clip_baseline(model, val_loader, device)
            print(f"[Val] Epoch {epoch+1} Loss: {val_loss:.4f}, Top-1: {val_top1:.4f}, Top-5: {val_top5:.4f}")

def evaluate_clip_baseline(model, dataloader, device='cuda'):
    model.eval()
    criterion = nn.CrossEntropyLoss()
    total_loss = 0
    total_top1_correct = 0
    total_top5_correct = 0
    total_samples = 0

    with torch.no_grad():
        for feats, labels in dataloader:
            feats = feats.to(device)
            labels = labels.to(device)

            logits = model(feats)
            loss = criterion(logits, labels)
            total_loss += loss.item() * len(feats)
            total_samples += len(feats)

            # Top-1
            preds = torch.argmax(logits, dim=1)
            total_top1_correct += (preds == labels).sum().item()

            # Top-5
            top5_preds = torch.topk(logits, k=5, dim=1).indices
            for i in range(len(labels)):
                if labels[i] in top5_preds[i]:
                    total_top5_correct += 1

    top1_acc = total_top1_correct / total_samples
    top5_acc = total_top5_correct / total_samples
    avg_loss = total_loss / total_samples

    return top1_acc, top5_acc, avg_loss


In [14]:
baseline_model = CLIPLinearProbe(clip_embed_dim=512, num_classes=120)

train_clip_baseline(baseline_model, train_dataloader, val_loader=val_dataloader, num_epochs=25, lr=1e-3)

test_top1, test_top5, test_loss = evaluate_clip_baseline(baseline_model, test_dataloader)
print(f"[Test] Top-1: {test_top1:.4f}, Top-5: {test_top5:.4f}, Loss: {test_loss:.4f}")


Epoch 1/25: 100%|██████████| 226/226 [00:01<00:00, 181.01it/s]


[Train] Epoch 1 Loss: 4.6341, Top-1 Acc: 0.1377


Epoch 2/25: 100%|██████████| 226/226 [00:00<00:00, 374.77it/s]


[Train] Epoch 2 Loss: 4.3419, Top-1 Acc: 0.3243


Epoch 3/25: 100%|██████████| 226/226 [00:00<00:00, 314.65it/s]


[Train] Epoch 3 Loss: 4.0920, Top-1 Acc: 0.4208


Epoch 4/25: 100%|██████████| 226/226 [00:00<00:00, 355.87it/s]


[Train] Epoch 4 Loss: 3.8753, Top-1 Acc: 0.4875


Epoch 5/25: 100%|██████████| 226/226 [00:00<00:00, 383.57it/s]


[Train] Epoch 5 Loss: 3.6874, Top-1 Acc: 0.5111


Epoch 6/25: 100%|██████████| 226/226 [00:00<00:00, 380.10it/s]


[Train] Epoch 6 Loss: 3.5256, Top-1 Acc: 0.5588


Epoch 7/25: 100%|██████████| 226/226 [00:00<00:00, 361.23it/s]


[Train] Epoch 7 Loss: 3.3854, Top-1 Acc: 0.5707


Epoch 8/25: 100%|██████████| 226/226 [00:00<00:00, 370.33it/s]


[Train] Epoch 8 Loss: 3.2641, Top-1 Acc: 0.5858


Epoch 9/25: 100%|██████████| 226/226 [00:00<00:00, 325.21it/s]


[Train] Epoch 9 Loss: 3.1597, Top-1 Acc: 0.6021


Epoch 10/25: 100%|██████████| 226/226 [00:00<00:00, 283.47it/s]


[Train] Epoch 10 Loss: 3.0691, Top-1 Acc: 0.6126


Epoch 11/25: 100%|██████████| 226/226 [00:00<00:00, 319.71it/s]


[Train] Epoch 11 Loss: 2.9902, Top-1 Acc: 0.6230


Epoch 12/25: 100%|██████████| 226/226 [00:00<00:00, 562.27it/s]


[Train] Epoch 12 Loss: 2.9216, Top-1 Acc: 0.6287


Epoch 13/25: 100%|██████████| 226/226 [00:00<00:00, 530.44it/s]


[Train] Epoch 13 Loss: 2.8617, Top-1 Acc: 0.6377


Epoch 14/25: 100%|██████████| 226/226 [00:00<00:00, 545.29it/s]


[Train] Epoch 14 Loss: 2.8097, Top-1 Acc: 0.6372


Epoch 15/25: 100%|██████████| 226/226 [00:00<00:00, 528.22it/s]


[Train] Epoch 15 Loss: 2.7642, Top-1 Acc: 0.6482


Epoch 16/25: 100%|██████████| 226/226 [00:00<00:00, 546.86it/s]


[Train] Epoch 16 Loss: 2.7244, Top-1 Acc: 0.6494


Epoch 17/25: 100%|██████████| 226/226 [00:00<00:00, 533.92it/s]


[Train] Epoch 17 Loss: 2.6889, Top-1 Acc: 0.6562


Epoch 18/25: 100%|██████████| 226/226 [00:00<00:00, 523.11it/s]


[Train] Epoch 18 Loss: 2.6582, Top-1 Acc: 0.6580


Epoch 19/25: 100%|██████████| 226/226 [00:00<00:00, 494.42it/s]


[Train] Epoch 19 Loss: 2.6308, Top-1 Acc: 0.6644


Epoch 20/25: 100%|██████████| 226/226 [00:00<00:00, 452.27it/s]


[Train] Epoch 20 Loss: 2.6065, Top-1 Acc: 0.6669


Epoch 21/25: 100%|██████████| 226/226 [00:00<00:00, 444.10it/s]


[Train] Epoch 21 Loss: 2.5854, Top-1 Acc: 0.6677


Epoch 22/25: 100%|██████████| 226/226 [00:00<00:00, 452.39it/s]


[Train] Epoch 22 Loss: 2.5662, Top-1 Acc: 0.6666


Epoch 23/25: 100%|██████████| 226/226 [00:00<00:00, 475.60it/s]


[Train] Epoch 23 Loss: 2.5492, Top-1 Acc: 0.6748


Epoch 24/25: 100%|██████████| 226/226 [00:00<00:00, 435.51it/s]


[Train] Epoch 24 Loss: 2.5341, Top-1 Acc: 0.6753


Epoch 25/25: 100%|██████████| 226/226 [00:00<00:00, 459.73it/s]


[Train] Epoch 25 Loss: 2.5207, Top-1 Acc: 0.6771
[Test] Top-1: 0.6218, Top-5: 0.9258, Loss: 2.5877


## CBM Train & Test

In [15]:
import torch.nn as nn

class ConceptPredictor(nn.Module):
    def __init__(self, in_dim=512, num_concepts=20):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(),
            nn.Linear(256, num_concepts)
        )
    def forward(self, feats):
        return self.mlp(feats)

class LabelPredictor(nn.Module):
    def __init__(self, num_concepts=20, num_classes=120):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_concepts, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, c):
        return self.mlp(c)

class CBM(nn.Module):
    def __init__(self, num_concepts, num_classes):
        super().__init__()
        self.concept_pred = ConceptPredictor(512, num_concepts)
        self.classifier = LabelPredictor(num_concepts, num_classes)

    def forward(self, x):
        concepts = self.concept_pred(x)
        logits   = self.classifier(concepts)
        return concepts, logits

cbm_model = CBM(num_concepts=len(llm_concepts), num_classes=len(dog_breed_names)).to(device)


In [16]:
def train_cbm(cbm_model, train_loader, val_loader, concept_matrix, num_epochs_concept=15, num_epochs_label=10, lr=1e-3):
    cbm_model.to(device)

    # Concept Predictor
    optimizer_c = torch.optim.Adam(cbm_model.concept_pred.parameters(), lr=lr, weight_decay=1e-4)
    criterion_c = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs_concept):
        cbm_model.train()
        epoch_loss = 0
        for feats, class_labels in train_loader:
            feats = feats.to(device)
            class_labels = class_labels.to(device)

            # construct concept labels from LLM concept matrix
            concept_labels = torch.tensor(
                [
                    [concept_matrix[class_id_to_breed_name[cid.item()]][c] for c in llm_concepts]
                    for cid in class_labels
                ],
                dtype=torch.float32,
                device=device
            )

            optimizer_c.zero_grad()
            concepts, _ = cbm_model(feats)
            loss = criterion_c(concepts, concept_labels)
            loss.backward()
            optimizer_c.step()
            epoch_loss += loss.item() * len(feats)
        if (epoch + 1) % 5 == 0:
          print(f"[Concept] Epoch {epoch+1}/{num_epochs_concept} Loss: {epoch_loss/len(train_loader.dataset):.4f}")

    # Freeze Concept Predictor
    for param in cbm_model.concept_pred.parameters():
        param.requires_grad = False

    # Label Predictor
    optimizer_l = torch.optim.Adam(cbm_model.classifier.parameters(), lr=lr)
    criterion_l = nn.CrossEntropyLoss()

    for epoch in range(num_epochs_label):
        cbm_model.train()
        epoch_loss = 0
        for feats, class_labels in train_loader:
            feats = feats.to(device)
            class_labels = class_labels.to(device)

            optimizer_l.zero_grad()
            concepts, logits = cbm_model(feats)
            loss = criterion_l(logits, class_labels)
            loss.backward()
            optimizer_l.step()
            epoch_loss += loss.item() * len(feats)
        if (epoch + 1) % 5 == 0:
          print(f"[Label] Epoch {epoch+1}/{num_epochs_label} Loss: {epoch_loss/len(train_loader.dataset):.4f}")


In [17]:
def evaluate_cbm(cbm_model, dataloader, concept_matrix=None):
    cbm_model.eval()
    criterion_l = nn.CrossEntropyLoss()
    criterion_c = nn.BCEWithLogitsLoss()

    total_loss_label = 0.0
    total_loss_concept = 0.0
    total_correct_top1 = 0
    total_correct_top5 = 0
    total_samples = 0

    with torch.no_grad():
        for feats, class_labels in dataloader:
            feats = feats.to(device)
            class_labels = class_labels.to(device)
            batch_size = class_labels.size(0)
            total_samples += batch_size

            # ---------- Concept Predictor ----------
            concepts_pred, logits = cbm_model(feats)

            # Compute concept loss if concept_matrix given
            if concept_matrix is not None:
                concept_labels = torch.tensor(
                    [
                        [concept_matrix[class_id_to_breed_name[cid.item()]][c] for c in llm_concepts]
                        for cid in class_labels
                    ],
                    dtype=torch.float32,
                    device=device
                )
                loss_c = criterion_c(concepts_pred, concept_labels)
                total_loss_concept += loss_c.item() * batch_size

            # ---------- Label Predictor ----------
            loss_l = criterion_l(logits, class_labels)
            total_loss_label += loss_l.item() * batch_size

            # ---------- Predictions ----------
            # Top-1
            preds_top1 = torch.argmax(logits, dim=1)
            total_correct_top1 += (preds_top1 == class_labels).sum().item()
            # Top-5
            preds_top5 = torch.topk(logits, k=5, dim=1)[1]  # [B, 5]
            total_correct_top5 += (preds_top5 == class_labels.view(-1, 1)).any(dim=1).sum().item()

    accuracy_top1 = total_correct_top1 / total_samples
    accuracy_top5 = total_correct_top5 / total_samples
    avg_loss_label = total_loss_label / total_samples
    avg_loss_concept = total_loss_concept / total_samples if concept_matrix is not None else None

    return accuracy_top1, accuracy_top5, avg_loss_label, avg_loss_concept


In [18]:
train_cbm(cbm_model, train_dataloader, val_dataloader, concept_matrix, num_epochs_concept=20, num_epochs_label=35)

[Concept] Epoch 5/20 Loss: 0.3282
[Concept] Epoch 10/20 Loss: 0.3136
[Concept] Epoch 15/20 Loss: 0.3069
[Concept] Epoch 20/20 Loss: 0.3024
[Label] Epoch 5/35 Loss: 1.4371
[Label] Epoch 10/35 Loss: 1.2695
[Label] Epoch 15/35 Loss: 1.1836
[Label] Epoch 20/35 Loss: 1.1182
[Label] Epoch 25/35 Loss: 1.0723
[Label] Epoch 30/35 Loss: 1.0468
[Label] Epoch 35/35 Loss: 1.0015


In [20]:
# ---------- Training set ----------
train_top1, train_top5, train_loss_label, train_loss_concept = evaluate_cbm(
    cbm_model, train_dataloader, concept_matrix
)
print(f"[Train] Top-1: {train_top1:.4f}, Top-5: {train_top5:.4f}, Label Loss: {train_loss_label:.4f}, Concept Loss: {train_loss_concept:.4f}")

# ---------- Test set ----------
test_top1, test_top5, test_loss_label, test_loss_concept = evaluate_cbm(
    cbm_model, test_dataloader, concept_matrix
)
print(f"[Test] Top-1: {test_top1:.4f}, Top-5: {test_top5:.4f}, Label Loss: {test_loss_label:.4f}, Concept Loss: {test_loss_concept:.4f}")

# ---------- Validation ----------
val_top1, val_top5, val_loss_label, val_loss_concept = evaluate_cbm(
    cbm_model, val_dataloader, concept_matrix
)
print(f"[Test] Top-1: {val_top1:.4f}, Top-5: {val_top5:.4f}, Label Loss: {val_loss_label:.4f}, Concept Loss: {val_loss_concept:.4f}")


[Train] Top-1: 0.6816, Top-5: 0.9515, Label Loss: 0.9890, Concept Loss: 0.3001
[Test] Top-1: 0.6205, Top-5: 0.9317, Label Loss: 1.2203, Concept Loss: 0.3039
[Test] Top-1: 0.6168, Top-5: 0.9252, Label Loss: 1.2576, Concept Loss: 0.3047


## Label-free CBM

Reference:

https://github.com/Trustworthy-ML-Lab/Label-free-CBM


https://arxiv.org/pdf/2304.06129

In [21]:
concepts_path = '/content/drive/MyDrive/CSCI2470/final/outputs/label_free_concepts.json'

with open(concepts_path, "r", encoding="utf-8") as f:
    breed_to_concepts = json.load(f)

all_concepts = []
for concepts in breed_to_concepts.values():
    all_concepts.extend(concepts)

concepts_inputs = processor(text=all_concepts, padding=True, truncation=True, return_tensors="pt").to(device)

with torch.no_grad():
    text_features = clip_model.get_text_features(**concepts_inputs)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)


In [22]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ConceptProj(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim)
        )
    def forward(self, x):
        return self.mlp(x)

class LinearClassifier(nn.Module):
    def __init__(self, in_dim, num_classes):
        super().__init__()
        self.linear = nn.Linear(in_dim, num_classes)
    def forward(self, x):
        return self.linear(x)

def train_cbm_from_embeddings(train_loader, val_loader, text_features,
                              device='cuda', proj_steps=500, interpret_cutoff=0.2,
                              lr=1e-3, num_epochs_cls=20, batch_size_cls=64):

    all_feats, all_labels = [], []
    with torch.no_grad():
        for feats, labels in train_loader:
            all_feats.append(feats.to(device))
            all_labels.append(labels.to(device))

    target_features = torch.cat(all_feats)  # [N, image_emb_dim]
    labels = torch.cat(all_labels)           # [N]

    # projection layer
    text_features = text_features.to(device)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)  # normalize
    proj_layer = ConceptProj(target_features.size(1), text_features.size(1)).to(device)
    optimizer = torch.optim.Adam(proj_layer.parameters(), lr=lr)

    for step in range(proj_steps):
        optimizer.zero_grad()
        out = proj_layer(target_features)              # [N, text_emb_dim]
        # print("Projection output mean/std:", out.mean().item(), out.std().item())
        out = out / out.norm(dim=1, keepdim=True)      # normalize
        sim_matrix = out @ text_features.T            # [N, num_concepts]
        # print("Initial sim_matrix mean/std:", sim_matrix.mean().item(), sim_matrix.std().item())
        loss = -sim_matrix.mean()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            print(f"Proj Step {step}, Loss {loss.item():.4f}")

    # concept features
    with torch.no_grad():
        proj_out = proj_layer(target_features)
        proj_out = proj_out / proj_out.norm(dim=1, keepdim=True)
        concept_feats = proj_out @ text_features.T   # [N, num_concepts]

        # interpretability cutoff
        sim_per_concept = concept_feats.mean(dim=0)  # [num_concepts]
        active_idx = sim_per_concept > interpret_cutoff
        print(f"Keeping {active_idx.sum().item()}/{text_features.size(0)} concepts")

        concept_feats = concept_feats[:, active_idx]

    mean = concept_feats.mean(dim=0, keepdim=True)
    std  = concept_feats.std(dim=0, keepdim=True)
    concept_feats = (concept_feats - mean) / (std + 1e-8)

    # label predictor
    classifier = LinearClassifier(concept_feats.size(1), labels.max().item()+1).to(device)
    optimizer_cls = torch.optim.Adam(classifier.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    dataset = TensorDataset(concept_feats, labels)
    loader = DataLoader(dataset, batch_size=batch_size_cls, shuffle=True)

    for epoch in range(num_epochs_cls):
        classifier.train()
        total_loss = 0
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer_cls.zero_grad()
            logits = classifier(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer_cls.step()
            total_loss += loss.item() * x.size(0)
        if (epoch + 1) % 5 == 0:
          print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataset):.4f}")

    return proj_layer, classifier, mean, std, active_idx


In [23]:
def test_cbm_topk(test_loader, proj_layer, classifier, text_features, mean, std, active_idx, device='cuda'):
    proj_layer.eval()
    classifier.eval()
    text_features = text_features.to(device)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

    all_labels = []
    top1_correct = 0
    top5_correct = 0
    total = 0

    with torch.no_grad():
        for feats, labels in test_loader:
            feats = feats.to(device)
            labels = labels.to(device)
            batch_size = feats.size(0)

            proj_out = proj_layer(feats)
            proj_out = proj_out / proj_out.norm(dim=1, keepdim=True)

            # generate concept features
            concept_feats = proj_out @ text_features.T
            concept_feats = concept_feats[:, active_idx]
            concept_feats = (concept_feats - mean) / (std + 1e-8)

            logits = classifier(concept_feats)

            # top k predictions
            top5_preds = torch.topk(logits, k=5, dim=1).indices
            top1_preds = torch.argmax(logits, dim=1)

            top1_correct += (top1_preds == labels).sum().item()
            top5_correct += sum([1 if labels[i] in top5_preds[i] else 0 for i in range(batch_size)])
            total += batch_size

    top1_acc = top1_correct / total
    top5_acc = top5_correct / total
    print(f"Top-1 Accuracy: {top1_acc*100:.2f}%")
    print(f"Top-5 Accuracy: {top5_acc*100:.2f}%")
    return top1_acc, top5_acc


In [24]:
proj_layer, classifier, mean, std, active_idx = train_cbm_from_embeddings(
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    text_features=text_features,
    device=device,
    proj_steps=200,
    interpret_cutoff=0.85,
    lr=3e-3,
    num_epochs_cls=100,
    batch_size_cls=64
)


Proj Step 0, Loss 0.0348
Proj Step 50, Loss -0.8707
Proj Step 100, Loss -0.8712
Proj Step 150, Loss -0.8712
Keeping 425/585 concepts
Epoch 5, Loss: 2.2841
Epoch 10, Loss: 1.9162
Epoch 15, Loss: 1.7122
Epoch 20, Loss: 1.5785
Epoch 25, Loss: 1.4767
Epoch 30, Loss: 1.4006
Epoch 35, Loss: 1.3526
Epoch 40, Loss: 1.2985
Epoch 45, Loss: 1.2509
Epoch 50, Loss: 1.2609
Epoch 55, Loss: 1.2315
Epoch 60, Loss: 1.1609
Epoch 65, Loss: 1.1310
Epoch 70, Loss: 1.1709
Epoch 75, Loss: 1.0545
Epoch 80, Loss: 1.0626
Epoch 85, Loss: 1.0443
Epoch 90, Loss: 1.0317
Epoch 95, Loss: 1.0144
Epoch 100, Loss: 1.0093


In [25]:
# ---------- Training set ----------
train_top1, train_top5 = test_cbm_topk(
    test_loader=train_dataloader,
    proj_layer=proj_layer,
    classifier=classifier,
    text_features=text_features,
    mean=mean,
    std=std,
    active_idx=active_idx,
    device=device
)
print(f"[Train] Top-1: {train_top1:.4f}, Top-5: {train_top5:.4f}")

# ---------- Validation set ----------
val_top1, val_top5 = test_cbm_topk(
    test_loader=val_dataloader,
    proj_layer=proj_layer,
    classifier=classifier,
    text_features=text_features,
    mean=mean,
    std=std,
    active_idx=active_idx,
    device=device
)
print(f"[Val] Top-1: {val_top1:.4f}, Top-5: {val_top5:.4f}")

# ---------- Test set ----------
test_top1, test_top5 = test_cbm_topk(
    test_loader=test_dataloader,
    proj_layer=proj_layer,
    classifier=classifier,
    text_features=text_features,
    mean=mean,
    std=std,
    active_idx=active_idx,
    device=device
)
print(f"[Test] Top-1: {test_top1:.4f}, Top-5: {test_top5:.4f}")


Top-1 Accuracy: 74.99%
Top-5 Accuracy: 94.95%
[Train] Top-1: 0.7499, Top-5: 0.9495
Top-1 Accuracy: 50.63%
Top-5 Accuracy: 79.46%
[Val] Top-1: 0.5063, Top-5: 0.7946
Top-1 Accuracy: 51.10%
Top-5 Accuracy: 80.89%
[Test] Top-1: 0.5110, Top-5: 0.8089
