In [1]:
# Download spaCy's small English pipeline (en_core_web_sm) by shelling out to `python -m spacy download`.
# NOTE(review): no visible cell in this notebook imports spaCy — confirm this download is still required.
# NOTE(review): bare `python` is resolved via PATH and may not be the kernel's interpreter — verify the target env.
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     --------------------------------- ----- 11.0/12.8 MB 68.5 MB/s eta 0:00:01
     ---------------------------------------- 12.8/12.8 MB 66.7 MB/s  0:00:00
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


# CVCL vs CLIP Classification Comparison

This notebook compares the classification performance of CVCL and CLIP models on the KonkLab dataset using prototype-based evaluation.

In [6]:
import os
import sys
import random
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# ─── Path setup ───
# The repository root is taken to be three directory levels above the current working directory.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * 3)))

# Location of the discover-hidden-visual-concepts checkout inside the repo
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')

# Prepend both roots in one step (REPO_ROOT ends up first, DISCOVER_ROOT second),
# then append the checkout's `src` directory at the end of the search path.
sys.path[:0] = [REPO_ROOT, DISCOVER_ROOT]
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))

# Project imports — these must come after the sys.path changes above
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor
from models.multimodal.multimodal_lit import MultiModalLitModel

# ─── hard-coded paths ───
# Test-set listing and image folder for the KonkLab data
CSV_PATH = os.path.join(REPO_ROOT, 'data', 'KonkLab', 'testdata.csv')
IMG_DIR = os.path.join(REPO_ROOT, 'data', 'KonkLab', '17-objects')
# Shared results file that every evaluation run appends to
MASTER_CSV = os.path.join(REPO_ROOT, 'PatrickProject', 'Chart_Generation', 'all_prototype_results.csv')

In [7]:
# Shared Dataset and Helper Functions
class ClassImageDataset(Dataset):
    """Image dataset driven by a CSV listing of files and their class labels.

    Each CSV row names one image; the file is expected at
    ``<img_dir>/<Class>/<Filename>``.

    Args:
        csv_path: Path to a CSV file with at least ``Filename`` and ``Class`` columns.
        img_dir: Root directory holding one sub-directory per class.
        transform: Callable applied to the RGB ``PIL.Image`` before it is returned.

    Raises:
        AssertionError: If the CSV lacks a ``Filename`` or ``Class`` column.
    """

    def __init__(self, csv_path, img_dir, transform):
        self.df = pd.read_csv(csv_path)
        assert 'Filename' in self.df and 'Class' in self.df, \
            "CSV needs Filename and Class columns"
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows (images) listed in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one image.

        Returns:
            A ``(transformed_image, class_label, idx)`` tuple, where ``idx`` is the
            positional row index that was requested.
        """
        row = self.df.iloc[idx]
        cls, fn = row['Class'], row['Filename']
        # pandas may parse purely numeric labels/filenames as numbers; os.path.join needs strings.
        path = os.path.join(self.img_dir, str(cls), str(fn))
        # Close the file handle deterministically; convert() yields an independent in-memory copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        return self.transform(img), cls, idx

def collate_fn(batch):
    """Collate dataset samples into one batch.

    Each sample is indexed as ``(image_tensor, class_label, index)``. The image
    tensors are stacked along a new leading dimension; labels and indices are
    returned as plain Python lists in sample order.
    """
    tensors, labels, positions = [], [], []
    for sample in batch:
        tensors.append(sample[0])
        labels.append(sample[1])
        positions.append(sample[2])
    return torch.stack(tensors), labels, positions

def run_classification_test(model_name, seed=0, device='cuda' if torch.cuda.is_available() else 'cpu',
                          batch_size=64, trials_per_class=10, max_images=None):
    """
    Run a 4-way classification test using prototype evaluation.

    For every class with at least two images, each trial picks a random query
    image, builds a normalized prototype from the mean embedding of the other
    images of that class, and checks whether the query is more similar to the
    prototype than three randomly drawn images from other classes.

    Args:
        model_name: Name passed to ``load_model`` and ``FeatureExtractor``.
        seed: Seed for ``random``, ``torch`` and the optional CSV subsample.
        device: Device the model and image batches are placed on.
        batch_size: Batch size used while extracting embeddings.
        trials_per_class: Number of 4-way trials run for each eligible class.
        max_images: If set and smaller than the number of CSV rows, evaluate on a
            random subsample of this many images.

    Returns:
        ``(class_results, overall_acc)`` where ``class_results`` maps each class
        to ``{'correct', 'trials', 'accuracy'}`` and ``overall_acc`` is the
        fraction of correct trials across all classes (0.0 if no trial ran).

    Raises:
        ValueError: If a class has fewer than three images from other classes to
            draw distractors from.

    Side effects:
        Appends one summary row to ``MASTER_CSV`` (creating the file, with a
        header, if it does not exist yet).
    """
    n_distractors = 3  # 1 query + 3 distractors = 4-way trial

    random.seed(seed)
    torch.manual_seed(seed)

    # 1) load model & transform
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)

    # 2) build the dataset, optionally restricted to a random subsample.
    #    The subsample is applied to the dataset itself so that `max_images`
    #    actually limits which images get embedded and evaluated.
    ds = ClassImageDataset(CSV_PATH, IMG_DIR, transform)
    if max_images and len(ds.df) > max_images:
        ds.df = ds.df.sample(n=max_images, random_state=seed).reset_index(drop=True)

    # 3) extract embeddings
    # Using single-process data loading to avoid worker issues
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False,
                   num_workers=0, collate_fn=collate_fn)

    all_embs, all_classes, all_idxs = [], [], []
    with torch.no_grad():
        for imgs, classes, idxs in dl:
            feats = extractor.get_img_feature(imgs.to(device))
            feats = extractor.norm_features(feats).cpu().float()
            all_embs.append(feats)
            all_classes.extend(classes)
            all_idxs.extend(idxs)
    all_embs = torch.cat(all_embs, dim=0)

    # 4) build maps for prototype eval
    idx2class = {i: c for i, c in zip(all_idxs, all_classes)}
    idx2row = {i: r for r, i in enumerate(all_idxs)}  # dataset index -> row in all_embs
    class2idxs = {}
    for i, c in idx2class.items():
        class2idxs.setdefault(c, []).append(i)

    # 5) run 4-way trials
    class_results = {}
    total_correct = 0
    total_trials = 0

    for cls, idxs in class2idxs.items():
        if len(idxs) < 2:
            # a prototype needs at least one same-class image besides the query
            continue

        # The distractor pool depends only on the class, so build it once per class.
        others = [i for i in all_idxs if idx2class[i] != cls]
        if trials_per_class > 0 and len(others) < n_distractors:
            raise ValueError(
                f"Need at least {n_distractors} images from other classes to build a "
                f"4-way trial for class {cls!r}, found {len(others)}"
            )

        correct = 0
        trials_run = 0
        for _ in range(trials_per_class):
            # query
            q = random.choice(idxs)
            # prototype over other same-class images
            proto_rows = [idx2row[i] for i in idxs if i != q]
            proto = all_embs[proto_rows].mean(0)
            proto = proto / proto.norm()
            # distractors
            distractors = random.sample(others, n_distractors)
            cands = [q] + distractors
            sims = all_embs[[idx2row[i] for i in cands]] @ proto
            guess = cands[sims.argmax().item()]
            if guess == q:
                correct += 1
            trials_run += 1

        total_correct += correct
        total_trials += trials_run
        # guard against trials_per_class == 0
        acc = correct / trials_run if trials_run else 0.0
        class_results[cls] = {'correct': correct, 'trials': trials_run, 'accuracy': acc}

    overall_acc = total_correct / total_trials if total_trials else 0.0

    # 6) save results
    summary_df = pd.DataFrame([{
        'Model': model_name,
        'Test': 'Class-Prototype',
        'Correct': total_correct,
        'Trials': total_trials,
        'Accuracy': overall_acc
    }])

    os.makedirs(os.path.dirname(MASTER_CSV), exist_ok=True)
    # Append to the shared results file; write the header only when creating it.
    write_header = not os.path.exists(MASTER_CSV)
    summary_df.to_csv(MASTER_CSV, mode='a', header=write_header, index=False, float_format='%.4f')

    return class_results, overall_acc

## CVCL Classification Test

In [4]:
# Evaluate the CVCL model with the prototype test, then report per-class and overall accuracy
cvcl_results, cvcl_overall = run_classification_test('cvcl-resnext')

print("\nCVCL Results per Class:")
for cls in cvcl_results:
    res = cvcl_results[cls]  # dict with 'correct', 'trials' and 'accuracy'
    print(f"{cls:20s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCVCL Overall Accuracy: {cvcl_overall:.1%}")

Loading checkpoint from C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt`



CVCL Results per Class:
butterfly           : 10/10 (100.0%)
muffins             : 10/10 (100.0%)
pitcher             : 6/10 (60.0%)
tennisracquet       : 10/10 (100.0%)
phone               : 8/10 (80.0%)
headband            : 7/10 (70.0%)
bagel               : 10/10 (100.0%)
grill               : 9/10 (90.0%)
basket              : 10/10 (100.0%)
bell                : 9/10 (90.0%)
sodacan             : 10/10 (100.0%)
microwave           : 10/10 (100.0%)
trophy              : 4/10 (40.0%)
fan                 : 9/10 (90.0%)
lei                 : 8/10 (80.0%)
stapler             : 8/10 (80.0%)
exercise_equipment  : 7/10 (70.0%)
handgun             : 10/10 (100.0%)
seashell            : 10/10 (100.0%)
powerstrip          : 9/10 (90.0%)
lipstick            : 10/10 (100.0%)
lantern             : 10/10 (100.0%)
doorknob            : 10/10 (100.0%)
abacus              : 10/10 (100.0%)
jack-o-lantern      : 10/10 (100.0%)
camcorder           : 10/10 (100.0%)
bird                : 7/10 (70.0%)


## CLIP Classification Test

In [5]:
# Evaluate the CLIP model with the prototype test, then report per-class and overall accuracy
clip_results, clip_overall = run_classification_test('clip-resnext')

print("\nCLIP Results per Class:")
for cls in clip_results:
    res = clip_results[cls]  # dict with 'correct', 'trials' and 'accuracy'
    print(f"{cls:20s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCLIP Overall Accuracy: {clip_overall:.1%}")

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)



CLIP Results per Class:
butterfly           : 10/10 (100.0%)
muffins             : 10/10 (100.0%)
pitcher             : 10/10 (100.0%)
tennisracquet       : 10/10 (100.0%)
phone               : 10/10 (100.0%)
headband            : 10/10 (100.0%)
bagel               : 10/10 (100.0%)
grill               : 10/10 (100.0%)
basket              : 10/10 (100.0%)
bell                : 7/10 (70.0%)
sodacan             : 10/10 (100.0%)
microwave           : 10/10 (100.0%)
trophy              : 10/10 (100.0%)
fan                 : 10/10 (100.0%)
lei                 : 10/10 (100.0%)
stapler             : 9/10 (90.0%)
exercise_equipment  : 10/10 (100.0%)
handgun             : 10/10 (100.0%)
seashell            : 10/10 (100.0%)
powerstrip          : 10/10 (100.0%)
lipstick            : 10/10 (100.0%)
lantern             : 10/10 (100.0%)
doorknob            : 10/10 (100.0%)
abacus              : 10/10 (100.0%)
jack-o-lantern      : 10/10 (100.0%)
camcorder           : 10/10 (100.0%)
bird             