# SCDC (Same Class Different Color) Per-Class Analysis with Error Bars

This notebook tests color discrimination within the same object class.
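For example, can a model distinguish a red apple from a blue or a green apple? Each trial is a 4-way choice among images of the same class, size and texture that differ only in color, so chance accuracy is 25%.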

Key difference from Class test:
- Class test: Different objects, same properties (apple vs ball, both red-small-smooth)
- SCDC test: Same object, different colors (red apple vs blue apple, both small-smooth)

In [None]:
import os
import random
import sys
import time
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# Path setup
# NOTE(review): assumes the notebook's working directory is exactly three levels
# below the repository root — TODO confirm when running from elsewhere.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir, os.pardir))
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')
# Both inserts go to position 0, so REPO_ROOT ends up ahead of DISCOVER_ROOT on sys.path.
sys.path.insert(0, DISCOVER_ROOT)
sys.path.insert(0, REPO_ROOT)

# Import from discover-hidden-visual-concepts repo
# 'src' is appended (searched last); the two imports below must stay after these sys.path edits.
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor

# SyntheticKonkle paths - Using 224x224 resized images for faster processing
DATA_DIR = os.path.join(REPO_ROOT, 'data', 'SyntheticKonkle_224')
# Output directory for the CSVs and figures written below; created if missing.
RESULTS_DIR = os.path.join(REPO_ROOT, 'PatrickProject', 'Chart_Generation')
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
# Dataset setup
def build_synthetic_dataset(data_dir=None):
    """Combine every ``labels.csv`` from the ``*_color`` folders into one DataFrame.

    Args:
        data_dir: Dataset root. Defaults to the module-level ``DATA_DIR``. The
            ``*_color`` folders are looked up under ``<data_dir>/SyntheticKonkle/``
            (the nested layout used by SyntheticKonkle_224).

    Returns:
        DataFrame with the CSV columns plus a ``folder`` column naming the source
        folder. Rows whose ``class`` is missing are dropped.

    Raises:
        FileNotFoundError: if no ``*_color`` folder contains a ``labels.csv``.
    """
    if data_dir is None:
        data_dir = DATA_DIR
    base_dir = os.path.join(data_dir, 'SyntheticKonkle')

    # sorted(): row order determines dataset indices and therefore the seeded trial
    # sampling downstream, so it must not depend on filesystem listing order.
    class_folders = sorted(
        d for d in os.listdir(base_dir)
        if d.endswith('_color') and os.path.isdir(os.path.join(base_dir, d))
    )

    all_data = []
    for folder in class_folders:
        labels_path = os.path.join(base_dir, folder, 'labels.csv')
        if os.path.exists(labels_path):
            df = pd.read_csv(labels_path)
            df['folder'] = folder
            all_data.append(df)

    # pd.concat([]) would fail with an opaque "No objects to concatenate".
    if not all_data:
        raise FileNotFoundError(f"No labels.csv found in any *_color folder under {base_dir}")

    combined_df = pd.concat(all_data, ignore_index=True)
    combined_df = combined_df.dropna(subset=['class'])
    # Count folders that actually contributed labels, not every *_color folder seen.
    print(f"Loaded {len(combined_df)} images from {len(all_data)} class folders")
    unique_classes = combined_df['class'].unique()
    unique_colors = combined_df['color'].unique()
    print(f"Classes found: {len(unique_classes)}")
    print(f"Colors found: {len(unique_colors)}: {unique_colors}")
    return combined_df

class SyntheticImageDataset(Dataset):
    """Map-style dataset over the rows of a SyntheticKonkle labels DataFrame.

    Each item is ``(transform(image), class, color, size, texture, idx)``, where
    ``idx`` is the positional row index into ``df``.
    """

    def __init__(self, df, data_dir, transform):
        self.df = df
        # For SyntheticKonkle_224, images are in nested structure
        self.data_dir = os.path.join(data_dir, 'SyntheticKonkle')
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.data_dir, row['folder'], row['filename'])
        try:
            # The context manager releases the file handle; convert() returns a loaded copy.
            with Image.open(img_path) as src:
                img = src.convert('RGB')
        except OSError as err:
            # Missing or unreadable file (PIL's UnidentifiedImageError subclasses OSError).
            # Keep the run going with a placeholder, but make the substitution visible
            # instead of silently skewing accuracies. Transform errors are NOT swallowed.
            warnings.warn(f"Could not load {img_path} ({err}); substituting a black 224x224 image")
            img = Image.new('RGB', (224, 224), color='black')
        return self.transform(img), row['class'], row['color'], row['size'], row['texture'], idx

def collate_fn(batch):
    """Collate dataset samples into one batch.

    Image tensors (field 0) are stacked along a new leading dimension. The class,
    color, size, texture and index fields (fields 1-5) are each gathered into a
    plain Python list, in batch order.
    """
    images, cls_list, color_list, size_list, texture_list, idx_list = [], [], [], [], [], []
    for sample in batch:
        images.append(sample[0])
        cls_list.append(sample[1])
        color_list.append(sample[2])
        size_list.append(sample[3])
        texture_list.append(sample[4])
        idx_list.append(sample[5])
    return torch.stack(images), cls_list, color_list, size_list, texture_list, idx_list

In [None]:
def _extract_embeddings(extractor, loader, device):
    """Embed every batch from ``loader``; return (embeddings, classes, colors, sizes, textures, idxs).

    Features come from ``extractor.get_img_feature`` followed by
    ``extractor.norm_features`` and are moved to CPU as float32. Row ``r`` of the
    returned embedding tensor corresponds to ``idxs[r]``.
    """
    all_embs, all_classes, all_colors, all_sizes, all_textures, all_idxs = [], [], [], [], [], []
    with torch.no_grad():
        for imgs, classes, colors, sizes, textures, idxs in loader:
            feats = extractor.get_img_feature(imgs.to(device))
            feats = extractor.norm_features(feats).cpu().float()
            all_embs.append(feats)
            all_classes.extend(classes)
            all_colors.extend(colors)
            all_sizes.extend(sizes)
            all_textures.extend(textures)
            all_idxs.extend(idxs)
    return torch.cat(all_embs, dim=0), all_classes, all_colors, all_sizes, all_textures, all_idxs


def _run_class_trials(color_groups_by_combo, embs, row_of, rng, trials_per_class, trials_per_combo):
    """Run up to ``trials_per_class`` 4-way color-discrimination trials for one class.

    Args:
        color_groups_by_combo: ``{(size, texture): {color: [dataset_idx, ...]}}`` for one class.
        embs: Embedding matrix; row ``row_of[idx]`` belongs to dataset item ``idx``.
        row_of: Mapping from dataset index to embedding row.
        rng: ``random.Random`` instance used for all sampling.
        trials_per_class: Cap on the number of trials for this class.
        trials_per_combo: Cap on trials drawn from any single (size, texture) combination.

    Returns:
        ``(n_correct, n_total)``.
    """
    n_correct = n_total = 0
    for color_groups in color_groups_by_combo.values():
        if n_total >= trials_per_class:
            break
        # A target color needs >= 2 images so the query can be held out of its own
        # prototype; with a single image the prototype *is* the query and the
        # trial is trivially correct.
        target_colors = [c for c, members in color_groups.items() if len(members) >= 2]
        if len(color_groups) < 4 or not target_colors:
            continue

        for _ in range(min(trials_per_combo, trials_per_class - n_total)):
            target_color = rng.choice(target_colors)
            distractor_colors = rng.sample([c for c in color_groups if c != target_color], 3)

            # Query image plus prototype built from the *other* images of its color.
            q = rng.choice(color_groups[target_color])
            proto = embs[[row_of[i] for i in color_groups[target_color] if i != q]].mean(0)
            proto = proto / proto.norm()

            # One distractor per distractor color. Shuffle so argmax ties (e.g. identical
            # placeholder images) are not resolved in the query's favour by position.
            candidates = [q] + [rng.choice(color_groups[c]) for c in distractor_colors]
            rng.shuffle(candidates)

            sims = embs[[row_of[i] for i in candidates]] @ proto
            n_correct += int(candidates[sims.argmax().item()] == q)
            n_total += 1
    return n_correct, n_total


def run_scdc_test_per_class(model_name, seed=0, device='cuda' if torch.cuda.is_available() else 'cpu', 
                            batch_size=64, trials_per_class=500, trials_per_combo=50):
    """Run the SCDC (Same Class Different Color) test and return per-class accuracy.

    Tests whether the model can pick out an image's color among images of the same
    object with the same size and texture (e.g. red vs blue vs green apple). Each
    trial is a 4-way choice. The query's prototype is the mean embedding of the
    other same-color images. The candidates are the query plus one image from each
    of three other colors. The candidate most similar to the prototype is the guess.

    Args:
        model_name: Name passed to ``load_model`` / ``FeatureExtractor``.
        seed: Seeds the global RNGs (before model loading) and the trial sampler.
        device: Device for feature extraction.
        batch_size: DataLoader batch size for embedding extraction.
        trials_per_class: Upper bound on trials per class.
        trials_per_combo: Upper bound on trials per (size, texture) combination.

    Returns:
        ``{class_name: accuracy}``. Classes with no eligible (size, texture)
        combination are omitted rather than reported as 0.0. Eligible means at
        least 4 colors, one of which has at least 2 images.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Load model & transform
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)

    # Build dataset and extract embeddings
    df = build_synthetic_dataset()
    ds = SyntheticImageDataset(df, DATA_DIR, transform)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=collate_fn)
    all_embs, all_classes, all_colors, all_sizes, all_textures, all_idxs = _extract_embeddings(
        extractor, loader, device)

    # O(1) dataset-index -> embedding-row lookup (replaces an O(n) list.index per image).
    row_of = {idx: row for row, idx in enumerate(all_idxs)}

    # Group by class, size, texture (kept constant) and vary color:
    # groups[cls][(size, texture)] = {color: [idx_list]}
    groups = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    for idx, cls, col, size, texture in zip(all_idxs, all_classes, all_colors, all_sizes, all_textures):
        groups[cls][(size, texture)][col].append(idx)

    # Dedicated RNG created after model loading/feature extraction, so the trial sequence
    # for a given seed does not depend on how much global randomness those steps used.
    # Given the same dataset, different models therefore see the same trials.
    rng = random.Random(seed)

    class_accuracies = {}
    for target_class in tqdm(list(groups), desc=f"Testing {model_name} SCDC"):
        n_correct, n_total = _run_class_trials(
            groups[target_class], all_embs, row_of, rng, trials_per_class, trials_per_combo)
        if n_total > 0:
            class_accuracies[target_class] = n_correct / n_total
    return class_accuracies

In [None]:
# Run multiple seeds for both models - PUBLICATION SETTINGS
from scipy import stats as scipy_stats

n_seeds = 3  # Limited seeds due to CVCL rate limiting
trials_per_class = 500  # Upper bound; classes with few eligible size/texture combos run fewer trials
models_to_test = ['cvcl-resnext', 'clip-res']

# First, check dataset
test_df = build_synthetic_dataset()
n_classes = len(test_df['class'].unique())
print(f"Found {n_classes} unique classes in the dataset")

# Worst-case (p = 0.5) binomial margin of error for one class pooled over all seeds.
max_trials = n_seeds * trials_per_class
expected_moe = 1.96 * np.sqrt(0.25 / max_trials)

print(f"\nStarting SCDC (Same Class Different Color) evaluation:")
print(f"Configuration: {n_seeds} seeds × {trials_per_class} trials/class × {n_classes} classes")
print(f"Max trials per class: {max_trials}")
print(f"Expected margin of error: ~{expected_moe:.1%} at 95% confidence level (worst case, pooled over seeds)\n")

all_results = {model: defaultdict(list) for model in models_to_test}

# Run evaluation
for model_name in models_to_test:
    print(f"\n{'='*50}")
    print(f"Testing {model_name} with {n_seeds} seeds")
    print('='*50)

    for seed in range(n_seeds):
        print(f"\nSeed {seed+1}/{n_seeds} for {model_name}")

        class_acc = None
        try:
            class_acc = run_scdc_test_per_class(model_name, seed=seed, trials_per_class=trials_per_class)
        except Exception as e:
            print(f"  Error: {e}")
            if "404" in str(e) or "rate" in str(e).lower():
                print(f"  Rate limit hit - waiting 60 seconds before retry...")
                time.sleep(60)
                try:
                    class_acc = run_scdc_test_per_class(model_name, seed=seed, trials_per_class=trials_per_class)
                    print(f"  Retry successful!")
                except Exception as retry_err:
                    print(f"  Retry failed - skipping this seed ({retry_err})")

        # Store results and print progress (shared by first-try and retry success)
        if class_acc is not None:
            for cls, acc in class_acc.items():
                all_results[model_name][cls].append(acc)
            if len(class_acc) > 0:
                mean_acc = np.mean(list(class_acc.values()))
                print(f"  Mean color discrimination accuracy: {mean_acc:.3f}")
                print(f"  Classes tested: {len(class_acc)}")

        # Add delay between seeds for CVCL (also after a failed seed, to let the rate limit reset)
        if 'cvcl' in model_name and seed < n_seeds - 1:
            print("  Waiting 30 seconds before next seed to avoid rate limiting...")
            time.sleep(30)

# Calculate statistics (across seeds)
stats_results = {}
for model_name in models_to_test:
    stats_results[model_name] = {}
    for cls, accs in all_results[model_name].items():
        if len(accs) == 0:
            continue
        n_samples = len(accs)
        if n_samples > 1:
            std = np.std(accs, ddof=1)
            se = std / np.sqrt(n_samples)
            # With only a handful of seeds the normal 1.96 badly understates the
            # interval (t_{0.975, 2} ≈ 4.30 for 3 seeds); use Student's t.
            ci95 = scipy_stats.t.ppf(0.975, df=n_samples - 1) * se
        else:
            std = se = ci95 = 0
        stats_results[model_name][cls] = {
            'mean': np.mean(accs),
            'std': std,
            'se': se,
            'ci95': ci95,
            'n_samples': n_samples,
            # Upper bound: the runner may execute fewer than trials_per_class per seed.
            'total_trials': n_samples * trials_per_class,
            'raw': accs
        }

print("\n" + "="*50)
print("SCDC EVALUATION COMPLETE")
print("="*50)

# Report statistics
for model_name in models_to_test:
    if len(stats_results[model_name]) > 0:
        all_means = [stats['mean'] for stats in stats_results[model_name].values()]
        overall_mean = np.mean(all_means)
        print(f"{model_name}:")
        print(f"  - {len(stats_results[model_name])} classes tested")
        print(f"  - Overall color discrimination: {overall_mean:.3f}")
        print(f"  - Expected: CVCL should struggle, CLIP should excel at color")

# Save results
detailed_df = []
for model_name in models_to_test:
    for cls, stats in stats_results[model_name].items():
        for seed_idx, acc in enumerate(stats['raw']):
            detailed_df.append({
                'model': model_name,
                'class': cls,
                'seed': seed_idx,
                'accuracy': acc,
                'n_trials': trials_per_class,  # requested cap, see note on 'total_trials'
                'test_type': 'SCDC'
            })

if len(detailed_df) > 0:
    detailed_df = pd.DataFrame(detailed_df)
    detailed_df.to_csv(os.path.join(RESULTS_DIR, 'scdc_perclass_results.csv'), index=False)
    print(f"\nSaved detailed results to {os.path.join(RESULTS_DIR, 'scdc_perclass_results.csv')}")

    # Save summary
    summary_stats = []
    for model_name in models_to_test:
        for cls, stats in stats_results[model_name].items():
            summary_stats.append({
                'model': model_name,
                'class': cls,
                'mean_accuracy': stats['mean'],
                'std': stats['std'],
                'se': stats['se'],
                'ci95': stats['ci95'],
                'n_seeds': stats['n_samples'],
                'total_trials': stats['total_trials'],
                'test_type': 'SCDC'
            })
    summary_df = pd.DataFrame(summary_stats)
    summary_df.to_csv(os.path.join(RESULTS_DIR, 'scdc_perclass_summary.csv'), index=False)
    print(f"Saved summary to {os.path.join(RESULTS_DIR, 'scdc_perclass_summary.csv')}")

In [1]:
# Create publication-quality visualization
fig = plt.figure(figsize=(14, 11))

# Two stacked panels on a 20-row grid; rows 8-11 stay empty to leave room for the shared legend
ax1 = plt.subplot2grid((20, 1), (0, 0), rowspan=8)
ax2 = plt.subplot2grid((20, 1), (12, 0), rowspan=8)

# Classes of the first model, alphabetically, split into two halves (one per panel)
classes = sorted(stats_results[models_to_test[0]])
mid_point = len(classes) // 2
classes_first_half, classes_second_half = classes[:mid_point], classes[mid_point:]

# Per-model styling: marker color, marker shape, and average-line style
colors = {
    'cvcl-resnext': '#8b4513',  # Brown - CVCL expected to struggle with color
    'clip-res': '#4169e1',      # Royal blue - CLIP expected to excel at color
}
markers = {'cvcl-resnext': 'o', 'clip-res': 's'}
avg_line_styles = {'cvcl-resnext': '--', 'clip-res': '-.'}

# Filled in by plot_on_axis(..., is_first=True)
legend_elements = []

def plot_on_axis(ax, class_subset, is_first=False):
    """Plot per-class SCDC accuracy (mean ± ci95, in percent) for every model on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        class_subset: Class names to place on the x-axis, in order.
        is_first: If True, also write each average value next to its line and
            rebuild the module-level ``legend_elements`` list used for the shared legend.

    Reads the module-level ``models_to_test``, ``stats_results``, ``classes``,
    ``colors``, ``markers`` and ``avg_line_styles``. A class missing from a
    model's results is drawn as a gap (NaN) instead of raising ``KeyError``.
    The average lines are the NaN-ignoring mean over all ``classes``, not just
    ``class_subset``, so both panels show the same averages.
    """
    global legend_elements
    from matplotlib.lines import Line2D

    def _pct(model_name, cls, key):
        # stats_results[model_name][cls][key] in percent, or NaN when the class is absent
        entry = stats_results[model_name].get(cls)
        return entry[key] * 100 if entry is not None else np.nan

    x_pos = np.arange(len(class_subset))

    for model_name in models_to_test:
        means = [_pct(model_name, cls, 'mean') for cls in class_subset]
        errors = [_pct(model_name, cls, 'ci95') for cls in class_subset]

        ax.errorbar(x_pos, means, yerr=errors, 
                    label=model_name.upper().replace('-', ' '),
                    color=colors[model_name],
                    marker=markers[model_name],
                    markersize=7,
                    linewidth=0,
                    capsize=4,
                    capthick=1.5,
                    alpha=0.9,
                    markeredgecolor='black',
                    markeredgewidth=0.5)

    # Chance line (4-way choice)
    ax.axhline(y=25, color='#ffa500', linestyle=':', alpha=0.8, linewidth=1.5)

    # Overall averages across every class
    all_classes_means = {
        model_name: np.nanmean([_pct(model_name, cls, 'mean') for cls in classes])
        for model_name in models_to_test
    }

    # Add average lines
    for model_name in models_to_test:
        avg_performance = all_classes_means[model_name]
        ax.axhline(y=avg_performance, 
                  color=colors[model_name], 
                  linestyle=avg_line_styles[model_name], 
                  alpha=0.7, 
                  linewidth=2)

        if is_first:
            ax.text(len(class_subset) + 0.8, avg_performance, 
                   f'{avg_performance:.1f}%', 
                   fontsize=9, 
                   color=colors[model_name], 
                   va='center',
                   fontweight='bold')

    # Formatting
    ax.set_ylabel('Color Discrimination Accuracy (%)', fontsize=11, fontweight='bold')
    ax.set_xticks(x_pos)
    ax.set_xticklabels(class_subset, rotation=45, ha='right', fontsize=10)
    ax.set_ylim(0, 105)
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.grid(axis='y', alpha=0.3, linestyle='-', linewidth=0.5)
    ax.set_axisbelow(True)
    ax.set_facecolor('#fafafa')

    # Create legend elements (only once)
    if is_first:
        legend_elements = []

        for model_name in models_to_test:
            legend_elements.append(
                Line2D([0], [0], marker=markers[model_name], color='w', 
                      markerfacecolor=colors[model_name], markeredgecolor='black',
                      markersize=8, label=model_name.upper().replace('-', ' '))
            )

        for model_name in models_to_test:
            avg_val = all_classes_means[model_name]
            legend_elements.append(
                Line2D([0], [0], color=colors[model_name], 
                      linestyle=avg_line_styles[model_name], linewidth=2,
                      label=f'{model_name.upper().split("-")[0]} Average ({avg_val:.1f}%)')
            )

        legend_elements.append(
            Line2D([0], [0], color='#ffa500', linestyle=':', linewidth=1.5,
                  label='Chance Level (25%)')
        )

# Plot both halves: each panel is drawn and then titled, first panel first
panel_specs = [
    (ax1, classes_first_half, True,
     'SCDC: Color Discrimination Performance - Part 1\nSame Class, Different Colors (Size & Texture Controlled)'),
    (ax2, classes_second_half, False,
     'SCDC: Color Discrimination Performance - Part 2'),
]
for panel_ax, panel_classes, panel_is_first, panel_title in panel_specs:
    plot_on_axis(panel_ax, panel_classes, is_first=panel_is_first)
    panel_ax.set_title(panel_title, fontsize=13, fontweight='bold', pad=10)
ax2.set_xlabel('Target Category', fontsize=11, fontweight='bold')

# Shared legend, hosted by an invisible axes placed in the gap between the panels
legend_ax = fig.add_axes([0.125, 0.44, 0.775, 0.08])
legend_ax.axis('off')

legend_kwargs = dict(loc='center', ncol=3, fontsize=10, frameon=True, fancybox=True,
                     shadow=True, framealpha=0.95, columnspacing=2.5, handlelength=3)
legend = legend_ax.legend(handles=legend_elements, **legend_kwargs)

legend_frame = legend.get_frame()
legend_frame.set_facecolor('white')
legend_frame.set_edgecolor('gray')
legend_frame.set_linewidth(1.5)

plt.tight_layout()
plt.subplots_adjust(hspace=0.35)

# Save both raster and vector versions, then display
png_path = os.path.join(RESULTS_DIR, 'scdc_perclass_comparison.png')
pdf_path = os.path.join(RESULTS_DIR, 'scdc_perclass_comparison.pdf')
plt.savefig(png_path, dpi=300, bbox_inches='tight', facecolor='white')
plt.savefig(pdf_path, bbox_inches='tight', facecolor='white')
plt.show()

print(f"\nSaved SCDC plots to:")
for saved_path in (png_path, pdf_path):
    print(f"  - {saved_path}")

NameError: name 'plt' is not defined

In [None]:
# Statistical summary and analysis
from scipy import stats as scipy_stats

chance_level = 0.25  # 4-way choice: one target color + three distractor colors

print("\n" + "="*60)
print("SCDC COLOR DISCRIMINATION ANALYSIS")
print("="*60)

# Overall comparison (pooled per-(class, seed) accuracies)
for model in models_to_test:
    all_accs = [acc for cls in classes if cls in stats_results[model]
                for acc in stats_results[model][cls]['raw']]
    if all_accs:
        mean = np.mean(all_accs)
        std = np.std(all_accs)
        se = std / np.sqrt(len(all_accs))
        ci95 = 1.96 * se
        print(f"\n{model}:")
        print(f"  Overall: {mean:.3f} ± {ci95:.3f}")
        # Report the measured margin over chance rather than a pre-written verdict
        print(f"  Margin over chance ({chance_level:.0%}): {mean - chance_level:+.3f}")

# Find biggest differences
print("\n" + "-"*60)
print("CLASSES WITH LARGEST MODEL DIFFERENCES:")
print("-"*60)

differences = []
for cls in classes:
    if cls in stats_results['clip-res'] and cls in stats_results['cvcl-resnext']:
        cvcl_mean = stats_results['cvcl-resnext'][cls]['mean']
        clip_mean = stats_results['clip-res'][cls]['mean']
        differences.append((cls, clip_mean - cvcl_mean, cvcl_mean, clip_mean))

differences.sort(key=lambda x: abs(x[1]), reverse=True)
# Filter before slicing so up to five wins are listed even when the other model's
# wins rank higher by |difference|.
clip_wins = [d for d in differences if d[1] > 0]
cvcl_wins = [d for d in differences if d[1] < 0]

print("\nTop 5 classes where CLIP > CVCL (best color discrimination):")
if clip_wins:
    for cls, diff, cvcl_acc, clip_acc in clip_wins[:5]:
        print(f"  {cls:20s}: CVCL={cvcl_acc:.1%}, CLIP={clip_acc:.1%}, Diff={diff:+.1%}")
else:
    print("  None")

print("\nClasses where CVCL > CLIP (if any):")
if cvcl_wins:
    for cls, diff, cvcl_acc, clip_acc in cvcl_wins[:5]:
        print(f"  {cls:20s}: CVCL={cvcl_acc:.1%}, CLIP={clip_acc:.1%}, Diff={diff:+.1%}")
else:
    print("  None - CLIP dominates color discrimination across all classes")

# Statistical test
cvcl_all = []  # pooled per-(class, seed) accuracies
clip_all = []
paired_cvcl = []  # per-class means, aligned by class
paired_clip = []
for cls in classes:
    if cls in stats_results['cvcl-resnext'] and cls in stats_results['clip-res']:
        cvcl_all.extend(stats_results['cvcl-resnext'][cls]['raw'])
        clip_all.extend(stats_results['clip-res'][cls]['raw'])
        paired_cvcl.append(stats_results['cvcl-resnext'][cls]['mean'])
        paired_clip.append(stats_results['clip-res'][cls]['mean'])

if cvcl_all and clip_all:
    # Both models are scored on the same classes, so compare per-class means with a
    # paired t-test. An independent test on pooled (class, seed) values ignores the
    # pairing and treats the seeds of one class as independent observations.
    t_stat, p_value = scipy_stats.ttest_rel(paired_clip, paired_cvcl)
    # Direction of any significant effect comes from the sign of t, not the hypothesis
    better, worse = ('CLIP', 'CVCL') if t_stat > 0 else ('CVCL', 'CLIP')
    print(f"\n" + "="*60)
    print(f"Statistical Test (CLIP vs CVCL on color):")
    print(f"Paired t-test over {len(paired_clip)} per-class mean accuracies")
    print(f"t-statistic: {t_stat:.3f}")
    print(f"p-value: {p_value:.6f}")
    if p_value < 0.001:
        print("Result: HIGHLY SIGNIFICANT difference (p < 0.001)")
        print(f"Conclusion: {better} significantly outperforms {worse} at color discrimination")