# Same Class Different Colors (SCDC) Comparison

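This notebook compares the CVCL and CLIP models on a color-prototype evaluation within a single object class.
For example, it tests whether a model can tell a red apple apart from apples of other colors. In each 4-way trial, the query image is compared against a prototype built from the other images of its color, alongside three distractors from the same class in different colors.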

In [1]:
# NOTE(review): no cell visible in this notebook uses spaCy — confirm this download is still needed.
# `!python` runs whichever `python` is first on PATH, which may not be the kernel's interpreter.
!python -m spacy download en_core_web_sm

Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
     ---------------------------------------- 0.0/12.8 MB ? eta -:--:--
     ------------------------------- ------- 10.5/12.8 MB 72.5 MB/s eta 0:00:01
     ---------------------------------------- 12.8/12.8 MB 61.8 MB/s  0:00:00
[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')


In [None]:
import os
import sys
import random
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from collections import defaultdict

# ─── Path setup ───
# REPO_ROOT is resolved four directory levels above the current working directory,
# so this notebook assumes it is launched from its own folder.
# NOTE(review): this breaks if the notebook is moved or started from another cwd.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir, os.pardir))

# Add discover-hidden-visual-concepts to path
# Both entries are inserted at position 0, so modules found there take precedence
# over installed packages with the same name (REPO_ROOT ends up first).
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')
sys.path.insert(0, DISCOVER_ROOT)
sys.path.insert(0, REPO_ROOT)

# Import from discover-hidden-visual-concepts repo
# src/ is appended (lowest priority); presumably the utils/ and models/ packages
# below are resolved from DISCOVER_ROOT/src — confirm if imports pick up the wrong copy.
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor
from models.multimodal.multimodal_lit import MultiModalLitModel  # NOTE(review): not referenced by any visible cell

# ─── hard-coded paths ───
# CSV listing the images; it must contain Filename, Class and Color columns.
CSV_PATH = os.path.join(REPO_ROOT, 'data', 'KonkLab', 'testdata.csv')
# Image root; each image is read from IMG_DIR/<Class>/<Filename>.
IMG_DIR = os.path.join(REPO_ROOT, 'data', 'KonkLab', '17-objects')
# Shared results file; each test run appends one summary row (header written only on creation).
MASTER_CSV = os.path.join(REPO_ROOT, 'PatrickProject', 'Chart_Generation', 'all_prototype_results.csv')

In [None]:
# Shared Dataset and Helper Functions
class ColorImageDataset(Dataset):
    """Image dataset driven by a CSV with ``Filename``, ``Class`` and ``Color`` columns.

    Each CSV row points at ``<img_dir>/<Class>/<Filename>``. Items are returned as
    ``(transform(rgb_image), class_name, color_name, row_index)``. Here
    ``row_index`` is the positional index of the row in the CSV, which lets
    callers map embeddings back to rows.

    Args:
        csv_path: path to the CSV file.
        img_dir: root directory containing one sub-folder per class.
        transform: callable applied to each RGB PIL image.

    Raises:
        ValueError: if any required column is missing from the CSV.
    """

    REQUIRED_COLUMNS = ('Filename', 'Class', 'Color')

    def __init__(self, csv_path, img_dir, transform):
        self.df = pd.read_csv(csv_path)
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        missing = [col for col in self.REQUIRED_COLUMNS if col not in self.df.columns]
        if missing:
            raise ValueError(
                "CSV must have Filename, Class, and Color columns "
                f"(missing: {', '.join(missing)})"
            )
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        cls, fn = row['Class'], row['Filename']
        path = os.path.join(self.img_dir, cls, fn)
        # The context manager releases the file handle even if decoding fails;
        # convert() returns a fully loaded image that outlives the `with` block.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), cls, row['Color'], idx

def collate_fn(batch):
    """Combine ``(image, class, color, index)`` samples into one batch.

    The image tensors are stacked along a new leading batch dimension. The
    class names, colors and dataset indices are returned as parallel Python
    lists, in the same order as ``batch``.
    """
    image_tensor = torch.stack([sample[0] for sample in batch])
    class_names, color_names, row_ids = (
        [sample[field] for sample in batch] for field in (1, 2, 3)
    )
    return image_tensor, class_names, color_names, row_ids

def run_scdc_test(model_name, seed=0, device='cuda' if torch.cuda.is_available() else 'cpu',
                  batch_size=64, trials_per_class=10, max_images=None, max_trials=4000):
    """
    Run the Same Class Different Colors (SCDC) test.

    This test evaluates whether models can distinguish between different colors
    of the SAME class, e.g. red apple vs green apple vs yellow apple.

    Each trial works on one (class, color) group:
      * a query image of that color is drawn at random;
      * a leave-one-out prototype is the re-normalized mean embedding of the
        *other* images of the same class and color;
      * the query plus three distractors, drawn only from *other colors of the
        same class*, are scored by dot product with the prototype;
      * the trial is correct when the query scores highest (chance = 1/4).
    Groups with a single image are skipped, because no leave-one-out prototype
    can be built for them.

    Args:
        model_name: name passed to ``load_model`` / ``FeatureExtractor``
            (e.g. ``'cvcl-resnext'``, ``'clip-resnext'``).
        seed: seed for Python's ``random`` and torch; also passed to ``load_model``.
        device: torch device used for feature extraction.
        batch_size: DataLoader batch size for embedding extraction.
        trials_per_class: trials per (class, color) combination. It is reduced
            automatically when ``combinations * trials_per_class > max_trials``.
        max_images: currently ignored; kept for call-site compatibility.
        max_trials: hard cap on the total number of trials.

    Returns:
        ``(class_color_results, overall_acc)``. The first item is a dict keyed by
        ``"<class>-<color>"``, each value holding ``correct`` / ``trials`` /
        ``accuracy``. The second is the overall accuracy as a float (0.0 when no
        trial ran).

    Side effects:
        Seeds the global ``random`` / torch RNGs and prints progress. It also
        appends one summary row to ``MASTER_CSV``, creating the file with a
        header if needed.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    # 1) load model & transform
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)

    # 2) load data
    full_ds = ColorImageDataset(CSV_PATH, IMG_DIR, transform)
    full_loader = DataLoader(
        full_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,  # Avoid worker issues on Windows
        collate_fn=collate_fn
    )
    print(f"[ℹ️] Full dataset size: {len(full_ds)} samples")

    # 3) extract embeddings for all images
    all_embs, all_classes, all_colors, all_idxs = [], [], [], []
    with torch.no_grad():
        for imgs, classes, colors, idxs in full_loader:
            feats = extractor.get_img_feature(imgs.to(device))
            feats = extractor.norm_features(feats)
            feats = feats.float()  # ensure float dtype
            all_embs.append(feats.cpu())
            all_classes.extend(classes)
            all_colors.extend(colors)
            all_idxs.extend(idxs)
    all_embs = torch.cat(all_embs, dim=0)  # [N, D]
    # Dataset index -> row of all_embs, built once. A dict lookup is O(1),
    # whereas all_idxs.index(i) rescans the whole list on every lookup.
    emb_row = {ds_idx: row for row, ds_idx in enumerate(all_idxs)}

    # 4) organize by class -> color
    class_color_idxs = defaultdict(lambda: defaultdict(list))
    for idx, cls, col in zip(all_idxs, all_classes, all_colors):
        class_color_idxs[cls][col].append(idx)

    # 5) Calculate all valid combinations
    all_combinations = []
    skipped_singletons = []
    for cls, color_groups in class_color_idxs.items():
        for color, idx_list in color_groups.items():
            # Pick distractors only from *other colors* of *the same class*
            other_idxs = [
                i for col2, lst2 in color_groups.items() if col2 != color
                for i in lst2
            ]
            if len(other_idxs) < 3:
                continue
            # The prototype excludes the query, so at least one *other* image of
            # this color is required. With a single image the prototype would be
            # the mean of an empty tensor (all NaN). Every similarity would then
            # be NaN and argmax would return the first candidate (the query),
            # scoring a spurious 100%.
            if len(idx_list) < 2:
                skipped_singletons.append(f"{cls}-{color}")
                continue
            all_combinations.append((cls, color, idx_list, other_idxs))

    if skipped_singletons:
        print(f"[INFO] Skipping {len(skipped_singletons)} class-color group(s) with a single image "
              f"(no leave-one-out prototype possible): {', '.join(skipped_singletons)}")

    # Adjust trials per combination to stay under max_trials
    total_combinations = len(all_combinations)
    if total_combinations and total_combinations * trials_per_class > max_trials:
        trials_per_combo = max(1, max_trials // total_combinations)
        print(f"[INFO] Limiting to {trials_per_combo} trials per combination to stay under {max_trials} total trials")
    else:
        trials_per_combo = trials_per_class

    # 6) run within-class 4-way trials
    total_correct = 0
    total_trials = 0
    class_color_results = {}

    print("[ℹ️] Running 4-way color-vs-other-color trials *within each class* ...")
    for cls, color, idx_list, other_idxs in all_combinations:
        if total_trials >= max_trials:
            print(f"[INFO] Reached maximum trials limit ({max_trials})")
            break

        correct = 0
        actual_trials = min(trials_per_combo, max_trials - total_trials)

        for _ in range(actual_trials):
            q = random.choice(idx_list)
            # Leave-one-out prototype: mean of the remaining same-color images, re-normalized.
            same_color = [i for i in idx_list if i != q]
            proto = all_embs[[emb_row[i] for i in same_color]].mean(0)
            proto = proto / proto.norm()

            distractors = random.sample(other_idxs, 3)
            candidates = [q] + distractors
            feats_cand = all_embs[[emb_row[i] for i in candidates]]
            sims = feats_cand @ proto
            guess = candidates[sims.argmax().item()]

            hit = int(guess == q)
            correct += hit
            total_correct += hit
            total_trials += 1

        acc = correct / actual_trials if actual_trials > 0 else 0
        key = f"{cls}-{color}"
        class_color_results[key] = {
            'correct': correct,
            'trials': actual_trials,
            'accuracy': acc
        }
        print(f"{cls:20s} / {color:12s}: {correct}/{actual_trials} ({acc:.1%})")

    overall_acc = total_correct / total_trials if total_trials else 0.0
    print(f"\nOverall accuracy: {total_correct}/{total_trials} ({overall_acc:.1%})")

    # 7) save results (append one summary row to the shared results CSV)
    summary_df = pd.DataFrame([{
        'Model': model_name,
        'Test': 'Same-Class-Different-Colors',
        'Correct': total_correct,
        'Trials': total_trials,
        'Accuracy': overall_acc
    }])

    os.makedirs(os.path.dirname(MASTER_CSV), exist_ok=True)
    if os.path.exists(MASTER_CSV):
        summary_df.to_csv(MASTER_CSV, mode='a', header=False, index=False, float_format='%.4f')
    else:
        summary_df.to_csv(MASTER_CSV, index=False, float_format='%.4f')

    return class_color_results, overall_acc

## CVCL Color Test

In [4]:
# Run CVCL color evaluation
cvcl_results, cvcl_overall = run_scdc_test('cvcl-resnext')

print("\nCVCL Results by Class-Color:")
# Size the label column to the longest "<class>-<color>" key (minimum 24), so that
# long names such as "tennisracquet-Multicolored" keep the columns aligned.
cvcl_label_width = max([24] + [len(key) for key in cvcl_results])
for key, res in cvcl_results.items():
    print(f"{key:{cvcl_label_width}s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCVCL Overall Accuracy: {cvcl_overall:.1%}")

Loading checkpoint from C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt`



CVCL Results by Class-Color:
butterfly-Multicolored  : 7/10 (70.0%)
butterfly-Yellow        : 0/10 (0.0%)
butterfly-Red           : 10/10 (100.0%)
muffins-Multicolored    : 0/10 (0.0%)
muffins-Orange          : 1/10 (10.0%)
muffins-Yellow          : 5/10 (50.0%)
pitcher-Yellow          : 0/10 (0.0%)
pitcher-Multicolored    : 0/10 (0.0%)
pitcher-Green           : 2/10 (20.0%)
pitcher-Grey            : 3/10 (30.0%)
pitcher-Orange          : 8/10 (80.0%)
pitcher-Blue            : 0/10 (0.0%)
tennisracquet-Multicolored: 0/10 (0.0%)
tennisracquet-Grey      : 6/10 (60.0%)
tennisracquet-Pink      : 10/10 (100.0%)
tennisracquet-Green     : 10/10 (100.0%)
phone-Grey              : 0/10 (0.0%)
phone-Blue              : 0/10 (0.0%)
phone-Yellow            : 2/10 (20.0%)
phone-Green             : 10/10 (100.0%)
phone-Red               : 10/10 (100.0%)
phone-Multicolored      : 0/10 (0.0%)
headband-Multicolored   : 6/10 (60.0%)
headband-Purple         : 10/10 (100.0%)
headband-Grey           : 10/

## CLIP Color Test

In [5]:
# Run CLIP color evaluation
clip_results, clip_overall = run_scdc_test('clip-resnext')

print("\nCLIP Results by Class-Color:")
# Size the label column to the longest "<class>-<color>" key (minimum 24), so that
# long names such as "tennisracquet-Multicolored" keep the columns aligned.
clip_label_width = max([24] + [len(key) for key in clip_results])
for key, res in clip_results.items():
    print(f"{key:{clip_label_width}s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCLIP Overall Accuracy: {clip_overall:.1%}")

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)



CLIP Results by Class-Color:
butterfly-Multicolored  : 10/10 (100.0%)
butterfly-Yellow        : 5/10 (50.0%)
butterfly-Red           : 10/10 (100.0%)
muffins-Multicolored    : 1/10 (10.0%)
muffins-Orange          : 7/10 (70.0%)
muffins-Yellow          : 0/10 (0.0%)
pitcher-Yellow          : 0/10 (0.0%)
pitcher-Multicolored    : 0/10 (0.0%)
pitcher-Green           : 4/10 (40.0%)
pitcher-Grey            : 2/10 (20.0%)
pitcher-Orange          : 5/10 (50.0%)
pitcher-Blue            : 0/10 (0.0%)
tennisracquet-Multicolored: 2/10 (20.0%)
tennisracquet-Grey      : 10/10 (100.0%)
tennisracquet-Pink      : 10/10 (100.0%)
tennisracquet-Green     : 10/10 (100.0%)
phone-Grey              : 1/10 (10.0%)
phone-Blue              : 4/10 (40.0%)
phone-Yellow            : 2/10 (20.0%)
phone-Green             : 10/10 (100.0%)
phone-Red               : 10/10 (100.0%)
phone-Multicolored      : 1/10 (10.0%)
headband-Multicolored   : 1/10 (10.0%)
headband-Purple         : 10/10 (100.0%)
headband-Grey       