In [2]:
# Fetch spaCy's small English pipeline (en_core_web_sm) through the spaCy CLI.
# NOTE(review): no cell visible in this notebook imports spaCy — confirm this download is still needed.
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     ---------------------------------- ---- 11.3/12.8 MB 78.6 MB/s eta 0:00:01
     ---------------------------------------- 12.8/12.8 MB 72.9 MB/s  0:00:00
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


# Different Class Different Colors (DCDC) Comparison

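This notebook compares the CVCL and CLIP models on a 4-way prototype-matching evaluation: for each (class, color) group, a prototype embedding is built from the group's images, and the model must pick a query image from that group over three distractors drawn from outside the group.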

In [None]:
import os
import sys
import random
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from collections import defaultdict

# ─── Path setup ───
# The repository root sits four directory levels above the current working directory.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * 4)))

# Checkout of discover-hidden-visual-concepts inside the repository
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')

# Prepend both roots in one step: REPO_ROOT ends up first and DISCOVER_ROOT second,
# which is the same order two successive insert(0, ...) calls would produce.
sys.path[:0] = [REPO_ROOT, DISCOVER_ROOT]

# Import from discover-hidden-visual-concepts repo (its src/ directory goes last on the path)
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor
from models.multimodal.multimodal_lit import MultiModalLitModel

# ─── hard-coded paths ───
# Input metadata CSV and image folder, both under data/KonkLab
CSV_PATH = os.path.join(REPO_ROOT, 'data', 'KonkLab', 'testdata.csv')
IMG_DIR = os.path.join(REPO_ROOT, 'data', 'KonkLab', '17-objects')
# Shared results file that each run appends one summary row to
MASTER_CSV = os.path.join(REPO_ROOT, 'PatrickProject', 'Chart_Generation', 'all_prototype_results.csv')

In [None]:
# Shared Dataset and Helper Functions
class ColorImageDataset(Dataset):
    """Image dataset described by a CSV with Filename, Class, and Color columns.

    Each image is read from ``<img_dir>/<Class>/<Filename>``.
    """

    def __init__(self, csv_path, img_dir, transform):
        """Load the CSV at ``csv_path`` and store the image directory and transform.

        Raises AssertionError when any of the three required columns is missing.
        """
        self.df = pd.read_csv(csv_path)
        required = ('Filename', 'Class', 'Color')
        assert all(column in self.df for column in required), \
            "CSV needs Filename, Class, and Color columns"
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed RGB image, class, color, idx)`` for row ``idx``."""
        record = self.df.iloc[idx]
        label, hue, filename = (record[name] for name in ('Class', 'Color', 'Filename'))
        image_path = os.path.join(self.img_dir, label, filename)
        rgb = Image.open(image_path).convert('RGB')
        return self.transform(rgb), label, hue, idx

def collate_fn(batch):
    """Collate ``(image, class, color, idx)`` samples into one batch.

    Images are stacked into a single tensor; the class, color, and index
    fields are returned as plain Python lists in sample order.
    """
    imgs = torch.stack([sample[0] for sample in batch])
    # Fields 1..3 of each sample are class, color, and dataset index.
    columns = [[sample[field] for sample in batch] for field in (1, 2, 3)]
    return (imgs, *columns)

def run_dcdc_test(model_name, seed=0, device='cuda' if torch.cuda.is_available() else 'cpu',
                  batch_size=64, trials_per_class=10, max_images=None):
    """
    Run Different Class Different Color (DCDC) test using color-class prototype evaluation.

    For each (class, color) combination:
    - Create a prototype from images with that specific class AND color
    - Test if model can identify a query image from that group vs 3 random distractors
    - Distractors are ANY images not in the same (class, color) group

    This tests the model's ability to encode both semantic (class) and visual (color) information.

    Args:
        model_name: Name passed to ``load_model`` and ``FeatureExtractor``.
        seed: Seed for ``random``, ``torch`` and the optional subsampling.
        device: Device the images are moved to for feature extraction.
        batch_size: Batch size used while extracting embeddings.
        trials_per_class: Number of 4-way trials per (class, color) group; must be >= 1.
        max_images: If set and smaller than the dataset, evaluate on a random
            subsample of this many rows instead of the full CSV.

    Returns:
        ``(class_color_results, overall_acc)`` where ``class_color_results`` maps
        ``"<class>-<color>"`` to a dict with ``correct``, ``trials`` and ``accuracy``,
        and ``overall_acc`` is the accuracy over all trials (0.0 if none were run).

    Raises:
        ValueError: If ``trials_per_class`` is smaller than 1.

    Side effects:
        Prints per-group and overall results and appends one summary row to ``MASTER_CSV``.
    """
    if trials_per_class < 1:
        raise ValueError("trials_per_class must be at least 1")

    random.seed(seed)
    torch.manual_seed(seed)

    # 1) load model & transform
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)

    # 2) build the dataset, optionally subsampled.
    # The subsample replaces the frame the dataset actually reads from, so
    # max_images really limits what is embedded and evaluated.
    ds = ColorImageDataset(CSV_PATH, IMG_DIR, transform)
    if max_images and len(ds) > max_images:
        ds.df = ds.df.sample(n=max_images, random_state=seed).reset_index(drop=True)

    # 3) extract embeddings
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False,
                    num_workers=0, collate_fn=collate_fn)

    all_embs, all_classes, all_colors, all_idxs = [], [], [], []
    with torch.no_grad():
        for imgs, classes, colors, idxs in dl:
            feats = extractor.get_img_feature(imgs.to(device))
            feats = extractor.norm_features(feats).cpu().float()
            all_embs.append(feats)
            all_classes.extend(classes)
            all_colors.extend(colors)
            all_idxs.extend(idxs)
    all_embs = torch.cat(all_embs, dim=0)

    # Dataset index -> row in all_embs (replaces repeated list.index() scans).
    row_of = {idx: row for row, idx in enumerate(all_idxs)}

    # 4) organize by (class, color)
    class_color_idxs = defaultdict(lambda: defaultdict(list))
    for idx, cls, col in zip(all_idxs, all_classes, all_colors):
        class_color_idxs[cls][col].append(idx)

    # 5) run baseline trials
    total_correct = 0
    total_trials = 0
    class_color_results = {}

    print("[ℹ️] Running 4-way DCDC trials...")
    for cls, color_groups in class_color_idxs.items():
        for color, idx_list in color_groups.items():
            # Pool of distractors = all images NOT in this (class, color) group
            members = set(idx_list)
            pool = [i for i in all_idxs if i not in members]

            # A group needs a query plus at least one other image: with a single
            # image the prototype would be the mean of zero embeddings (NaN) and
            # the trial could not be scored meaningfully.
            if len(idx_list) < 2 or len(pool) < 3:
                print(f"{cls:20s} / {color:12s}: skipped "
                      f"(needs >= 2 images in group and >= 3 distractors)")
                continue

            correct = 0
            for _ in range(trials_per_class):
                # Select query from this (class, color) group
                q = random.choice(idx_list)

                # Create prototype from other images in same (class, color) group
                same_color = [i for i in idx_list if i != q]
                proto = all_embs[[row_of[i] for i in same_color]].mean(0)
                proto = proto / proto.norm()

                # Sample 3 distractors from pool
                distractors = random.sample(pool, 3)

                # 4-way classification: the candidate most similar to the prototype wins
                cands = [q] + distractors
                sims = all_embs[[row_of[i] for i in cands]] @ proto
                guess = cands[sims.argmax().item()]

                hit = int(guess == q)
                correct += hit
                total_correct += hit
                total_trials += 1

            acc = correct / trials_per_class
            key = f"{cls}-{color}"
            class_color_results[key] = {
                'correct': correct,
                'trials': trials_per_class,
                'accuracy': acc
            }
            print(f"{cls:20s} / {color:12s}: {correct}/{trials_per_class} ({acc:.1%})")

    overall_acc = total_correct / total_trials if total_trials else 0.0
    print(f"\nOverall accuracy: {total_correct}/{total_trials} ({overall_acc:.1%})")

    # 6) save results: append to the master CSV, writing the header only on first creation
    summary_df = pd.DataFrame([{
        'Model': model_name,
        'Test': 'Different-Class-Different-Colors',
        'Correct': total_correct,
        'Trials': total_trials,
        'Accuracy': overall_acc
    }])

    os.makedirs(os.path.dirname(MASTER_CSV), exist_ok=True)
    if os.path.exists(MASTER_CSV):
        summary_df.to_csv(MASTER_CSV, mode='a', header=False, index=False, float_format='%.4f')
    else:
        summary_df.to_csv(MASTER_CSV, index=False, float_format='%.4f')

    return class_color_results, overall_acc

## CVCL Color Classification Test

In [14]:
# Run CVCL color evaluation
cvcl_results, cvcl_overall = run_dcdc_test('cvcl-resnext')

# Size the key column to the longest "<class>-<color>" key (never narrower than 24)
# so rows stay aligned even for long keys such as "tennisracquet-Multicolored".
key_width = max([24] + [len(key) for key in cvcl_results])

print("\nCVCL Results by Class-Color:")
for key, res in cvcl_results.items():
    print(f"{key:{key_width}s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCVCL Overall Accuracy: {cvcl_overall:.1%}")

Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt`


Loading checkpoint from C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt

CVCL Results by Class-Color:
butterfly-Multicolored  : 9/10 (90.0%)
butterfly-Yellow        : 10/10 (100.0%)
butterfly-Red           : 10/10 (100.0%)
muffins-Multicolored    : 9/10 (90.0%)
muffins-Orange          : 9/10 (90.0%)
muffins-Yellow          : 9/10 (90.0%)
pitcher-Yellow          : 8/10 (80.0%)
pitcher-Multicolored    : 5/10 (50.0%)
pitcher-Green           : 8/10 (80.0%)
pitcher-Grey            : 6/10 (60.0%)
pitcher-Orange          : 10/10 (100.0%)
pitcher-Blue            : 5/10 (50.0%)
tennisracquet-Multicolored: 10/10 (100.0%)
tennisracquet-Grey      : 10/10 (100.0%)
tennisracquet-Pink      : 10/10 (100.0%)
tennisracquet-Green     : 10/10 (100.0%)
phone-Grey              : 6/10 (60.0%)
phone-Blue              : 9/10 (90.0%)
phone-Yellow            : 8/10 (80.0%)
phone-Green    

## CLIP Color Classification Test

In [15]:
# Run CLIP color evaluation
clip_results, clip_overall = run_dcdc_test('clip-resnext')

# Size the key column to the longest "<class>-<color>" key (never narrower than 24)
# so rows stay aligned even for long keys such as "tennisracquet-Multicolored".
key_width = max([24] + [len(key) for key in clip_results])

print("\nCLIP Results by Class-Color:")
for key, res in clip_results.items():
    print(f"{key:{key_width}s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCLIP Overall Accuracy: {clip_overall:.1%}")

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)



CLIP Results by Class-Color:
butterfly-Multicolored  : 10/10 (100.0%)
butterfly-Yellow        : 10/10 (100.0%)
butterfly-Red           : 10/10 (100.0%)
muffins-Multicolored    : 10/10 (100.0%)
muffins-Orange          : 10/10 (100.0%)
muffins-Yellow          : 8/10 (80.0%)
pitcher-Yellow          : 10/10 (100.0%)
pitcher-Multicolored    : 10/10 (100.0%)
pitcher-Green           : 8/10 (80.0%)
pitcher-Grey            : 9/10 (90.0%)
pitcher-Orange          : 10/10 (100.0%)
pitcher-Blue            : 10/10 (100.0%)
tennisracquet-Multicolored: 10/10 (100.0%)
tennisracquet-Grey      : 10/10 (100.0%)
tennisracquet-Pink      : 10/10 (100.0%)
tennisracquet-Green     : 10/10 (100.0%)
phone-Grey              : 10/10 (100.0%)
phone-Blue              : 8/10 (80.0%)
phone-Yellow            : 10/10 (100.0%)
phone-Green             : 10/10 (100.0%)
phone-Red               : 10/10 (100.0%)
phone-Multicolored      : 10/10 (100.0%)
headband-Multicolored   : 10/10 (100.0%)
headband-Purple         : 10/10 (