# Same Class Different Color (SCDC) Text-Vision Comparison Test

This notebook tests model performance on distinguishing objects of the SAME class that have DIFFERENT colors using text-vision zero-shot classification.

Test format:
- 4-way forced choice
- Query: Image with specific class and color (e.g., red apple)
- Text prompts: Include color + class (e.g., "red apple", "green apple", "yellow apple", "blue apple")
- Distractors: Same class but different colors
- 4000 trials total per model

This is typically harder than DCDC since only color differs, not class.

In [1]:
# Imports
import os
import sys
import pandas as pd
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import random
from datetime import datetime
import clip
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader

# Path setup - Use absolute paths to avoid any confusion
REPO_ROOT = r'C:\Users\jbats\Projects\NTU-Synthetic'

# Add discover-hidden-visual-concepts to path
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')
sys.path.insert(0, DISCOVER_ROOT)
sys.path.insert(0, REPO_ROOT)

# Import from discover-hidden-visual-concepts repo
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor

# Paths
DATA_PATH = os.path.join(REPO_ROOT, 'data', 'SyntheticKonkle_224', 'SyntheticKonkle')
IMG_PATH = os.path.join(DATA_PATH, )
METADATA_PATH = os.path.join(REPO_ROOT, 'data', 'SyntheticKonkle', 'master_labels.csv')
RESULTS_PATH = os.path.join(REPO_ROOT, 'PatrickProject', 'Chart_Generation', 'text_vision_results.csv')

print(f"Data path: {DATA_PATH}")
print(f"Image path: {IMG_PATH}")
print(f"Metadata path: {METADATA_PATH}")
print(f"Results will be saved to: {RESULTS_PATH}")

  from pkg_resources import packaging


Data path: C:\Users\jbats\Projects\NTU-Synthetic\data\SyntheticKonkle_224\SyntheticKonkle
Image path: C:\Users\jbats\Projects\NTU-Synthetic\data\SyntheticKonkle_224\SyntheticKonkle
Metadata path: C:\Users\jbats\Projects\NTU-Synthetic\data\SyntheticKonkle\master_labels.csv
Results will be saved to: C:\Users\jbats\Projects\NTU-Synthetic\PatrickProject\Chart_Generation\text_vision_results.csv


In [2]:
# Load and prepare data
def load_konklab_data():
    """Load KonkLab dataset with metadata"""
    # Read metadata
    df = pd.read_csv(METADATA_PATH)
    
    # Standardize column names (handle both 'color' and 'colour')
    # SyntheticKonkle already has lowercase 'color' column
    
    # Build full paths
    df['image_path'] = df.apply(lambda row: os.path.join(DATA_PATH, row['folder'], row['filename']), axis=1)
    
    # Filter to only entries with valid color information
    df = df[df['color'].notna() & (df['color'] != '')].copy()
    
    # Standardize color names (lowercase)
    df['color'] = df['color'].str.lower().str.strip()
    
    # Create class-color combination column
    df['class_color'] = df['class'] + '_' + df['color']
    
    print(f"Loaded {len(df)} images with color annotations")
    print(f"Unique classes: {df['class'].nunique()}")
    print(f"Unique colors: {df['color'].nunique()}")
    print(f"Unique class-color combinations: {df['class_color'].nunique()}")
    
    # Find classes that have multiple colors (needed for SCDC test)
    class_color_counts = df.groupby('class')['color'].nunique()
    multi_color_classes = class_color_counts[class_color_counts >= 4].index.tolist()
    
    print(f"\nClasses with 4+ colors (suitable for SCDC): {len(multi_color_classes)}")
    if len(multi_color_classes) > 0:
        print(f"Examples: {multi_color_classes[:5]}")
    
    return df, multi_color_classes

# Load data
data_df, multi_color_classes = load_konklab_data()
print("\nSample data:")
print(data_df[['class', 'color', 'class_color']].head())

Loaded 7882 images with color annotations
Unique classes: 67
Unique colors: 12
Unique class-color combinations: 671

Classes with 4+ colors (suitable for SCDC): 67
Examples: ['abacus', 'apple', 'axe', 'babushkadolls', 'bagel']

Sample data:
    class   color    class_color
0  abacus     red     abacus_red
1  abacus   green   abacus_green
2  abacus    blue    abacus_blue
3  abacus  yellow  abacus_yellow
4  abacus  orange  abacus_orange


In [3]:
def run_scdc_text_vision_test(model_name='cvcl-resnext', seed=0, device=None, num_trials=4000):
    """Run Same Class Different Color text-vision test
    
    Args:
        model_name: Model to test ('cvcl-resnext' or 'clip-res')
        seed: Random seed for reproducibility (matches original Class test)
        device: Device to use (None for auto-detect)
        num_trials: Total number of trials to run
    """
    # Set seeds to match original test methodology
    random.seed(seed)
    torch.manual_seed(seed)
    
    print(f"\n{'='*60}")
    print(f"Running SCDC Text-Vision Test with {model_name}")
    print(f"{'='*60}")
    
    # Device selection
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    if device == 'cuda' and not torch.cuda.is_available():
        print("[ERROR] CUDA requested but not available! Falling back to CPU.")
        device = 'cpu'
    
    print(f"Using device: {device}")
    
    # Load model using same loader as original
    print(f"[INFO] Loading {model_name} on {device}...")
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)
    model.eval()
    
    # Load and prepare data
    df = pd.read_csv(METADATA_PATH)
    
    # Standardize column names (handle both 'color' and 'colour')
    # SyntheticKonkle already has lowercase 'color' column
    
    # Build full paths
    df['image_path'] = df.apply(lambda row: os.path.join(DATA_PATH, row['folder'], row['filename']), axis=1)
    
    # Filter to only entries with valid color information
    df = df[df['color'].notna() & (df['color'] != '')].copy()
    
    # Standardize color names (lowercase)
    df['color'] = df['color'].str.lower().str.strip()
    
    # Create class-color combination column
    df['class_color'] = df['class'] + '_' + df['color']
    
    # Find classes that have at least 4 different colors (needed for 4-way choice)
    class_colors = df.groupby('class')['color'].unique()
    valid_classes = [cls for cls, colors in class_colors.items() if len(colors) >= 4]
    
    if len(valid_classes) == 0:
        print("ERROR: No classes have 4+ different colors. Cannot run SCDC test.")
        return [], 0.0
    
    print(f"\nFound {len(valid_classes)} classes with 4+ colors")
    print(f"Examples: {valid_classes[:5]}")
    
    # Group data by class-color combinations for valid classes
    df_valid = df[df['class'].isin(valid_classes)]
    grouped = df_valid.groupby('class_color').agg({
        'image_path': list,
        'class': 'first',
        'color': 'first'
    }).reset_index()
    
    print(f"\nUsing {len(grouped)} class-color combinations from {len(valid_classes)} classes")
    
    # Pre-compute all image embeddings for efficiency
    print("\nExtracting image embeddings...")
    image_embeddings = {}
    skipped_images = []
    
    # Process in batches
    all_image_paths = [img for imgs in grouped['image_path'] for img in imgs]
    batch_size = 16
    
    for i in tqdm(range(0, len(all_image_paths), batch_size), desc="Extracting embeddings"):
        batch_paths = all_image_paths[i:i+batch_size]
        batch_images = []
        
        for img_path in batch_paths:
            if img_path not in image_embeddings:  # Skip if already processed
                try:
                    img = Image.open(img_path).convert('RGB')
                    img_processed = transform(img).unsqueeze(0).to(device)
                    batch_images.append((img_path, img_processed))
        
                except Exception as e:
                    # Skip corrupted/invalid images
                    skipped_images.append(img_path)
                    continue
        if batch_images:
            # Stack batch
            paths = [p for p, _ in batch_images]
            imgs = torch.cat([img for _, img in batch_images], dim=0)
            
            # Get embeddings
            with torch.no_grad():
                embeddings = extractor.get_img_feature(imgs)
                embeddings = extractor.norm_features(embeddings)
            
            # Store
            for path, emb in zip(paths, embeddings):
                image_embeddings[path] = emb.cpu().float()  # Ensure float32
    
    print(f"Extracted embeddings for {len(image_embeddings)} images")
    if skipped_images:
        print(f"Skipped {len(skipped_images)} corrupted/invalid images")
    
    # Prepare for trials
    correct_count = 0
    trial_results = []
    
    # Calculate trials per class to get exactly num_trials
    trials_per_class = num_trials // len(valid_classes)
    remaining_trials = num_trials % len(valid_classes)
    
    print(f"\nRunning {num_trials} trials across {len(valid_classes)} classes...")
    print(f"Trials per class: {trials_per_class}, with {remaining_trials} classes getting 1 extra")
    
    # Run trials for each valid class
    for class_idx, target_class in enumerate(tqdm(valid_classes, desc="Processing classes")):
        # Get all color combinations for this class
        class_combos = grouped[grouped['class'] == target_class]
        available_colors = class_combos['color'].unique()
        
        if len(available_colors) < 4:
            continue  # Skip if not enough colors
        
        # Determine number of trials for this class
        n_trials = trials_per_class + (1 if class_idx < remaining_trials else 0)
        
        for trial in range(n_trials):
            # Select 4 different colors for this class
            selected_colors = random.sample(list(available_colors), 4)
            
            # First color is the query
            query_color = selected_colors[0]
            query_combo = f"{target_class}_{query_color}"
            query_data = class_combos[class_combos['class_color'] == query_combo].iloc[0]
            # Select random query image from valid images
            valid_query_paths = [p for p in query_data['image_path'] if p in image_embeddings]
            if not valid_query_paths:
                continue
            query_img_path = random.choice(valid_query_paths)
            
            # Create candidate list (all 4 colors of same class)
            candidate_combos = [f"{target_class}_{color}" for color in selected_colors]
            random.shuffle(candidate_combos)
            
            # Get correct index
            correct_idx = candidate_combos.index(query_combo)
            
            # Create text prompts for each candidate
            candidate_texts = []
            for combo in candidate_combos:
                combo_data = class_combos[class_combos['class_color'] == combo].iloc[0]
                # Create text prompt with color + class
                text_prompt = f"{combo_data['color']} {combo_data['class'].lower()}"
                candidate_texts.append(text_prompt)
            
            # Encode text prompts (batch encoding)
            with torch.no_grad():
                if "clip" in model_name:
                    tokens = clip.tokenize(candidate_texts, truncate=True).to(device)
                    txt_features = model.encode_text(tokens)
                    txt_features = extractor.norm_features(txt_features)
                else:  # CVCL
                    tokens, token_len = model.tokenize(candidate_texts)
                    tokens = tokens.to(device)
                    if isinstance(token_len, torch.Tensor):
                        token_len = token_len.to(device)
                    txt_features = model.encode_text(tokens, token_len)
                    txt_features = extractor.norm_features(txt_features)
            
            # Get query image embedding
            query_embedding = image_embeddings[query_img_path].unsqueeze(0).to(device)
            
            # Calculate similarity following original predict() method
            # Ensure both are float32
            query_embedding = query_embedding.float()
            txt_features = txt_features.float()
            
            similarity = (100.0 * query_embedding @ txt_features.transpose(-2, -1)).softmax(dim=1)
            
            # Get prediction
            pred_idx = similarity.argmax(dim=1).item()
            
            # Check if correct
            is_correct = (pred_idx == correct_idx)
            if is_correct:
                correct_count += 1
            
            # Store trial result
            trial_results.append({
                'trial': len(trial_results) + 1,
                'query_class': target_class,
                'query_color': query_color,
                'query_img': os.path.basename(query_img_path),
                'correct_idx': correct_idx,
                'predicted_idx': pred_idx,
                'correct': is_correct,
                'candidate_texts': candidate_texts,
                'similarity_scores': similarity.cpu().numpy().tolist()
            })
    
    # Calculate accuracy
    accuracy = correct_count / len(trial_results) if trial_results else 0
    
    print(f"\n{'='*60}")
    print(f"Results for {model_name} - SCDC Text-Vision Test:")
    print(f"Total trials: {len(trial_results)}")
    print(f"Correct: {correct_count}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"{'='*60}")
    
    # Save results to CSV
    results_row = {
        'Model': model_name,
        'Test': 'SCDC-TextVision',
        'Dataset': 'SyntheticKonkle',
        'Correct': correct_count,
        'Trials': len(trial_results),
        'Accuracy': accuracy
    }
    
    # Append to results file
    os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
    if os.path.exists(RESULTS_PATH):
        results_df = pd.read_csv(RESULTS_PATH)
    else:
        results_df = pd.DataFrame()
    
    results_df = pd.concat([results_df, pd.DataFrame([results_row])], ignore_index=True)
    results_df.to_csv(RESULTS_PATH, index=False, float_format='%.4f')
    print(f"\nResults saved to {RESULTS_PATH}")
    
    return trial_results, accuracy

## Run CVCL SCDC Text-Vision Test

In [4]:
# Run CVCL test with seed=0 (matching original Class test)
cvcl_trials, cvcl_accuracy = run_scdc_text_vision_test('cvcl-resnext', seed=0, num_trials=4000)


Running SCDC Text-Vision Test with cvcl-resnext
Using device: cuda
[INFO] Loading cvcl-resnext on cuda...
Loading checkpoint from C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt`



Found 67 classes with 4+ colors
Examples: ['abacus', 'apple', 'axe', 'babushkadolls', 'bagel']

Using 671 class-color combinations from 67 classes

Extracting image embeddings...


Extracting embeddings: 100%|██████████| 493/493 [00:24<00:00, 19.98it/s]


Extracted embeddings for 7841 images
Skipped 32 corrupted/invalid images

Running 4000 trials across 67 classes...
Trials per class: 59, with 47 classes getting 1 extra


Processing classes: 100%|██████████| 67/67 [00:40<00:00,  1.63it/s]


Results for cvcl-resnext - SCDC Text-Vision Test:
Total trials: 4000
Correct: 982
Accuracy: 0.2455 (24.55%)

Results saved to C:\Users\jbats\Projects\NTU-Synthetic\PatrickProject\Chart_Generation\text_vision_results.csv





## Run CLIP SCDC Text-Vision Test

In [5]:
# Run CLIP test with seed=0 (matching original Class test)
clip_trials, clip_accuracy = run_scdc_text_vision_test('clip-resnext', seed=0, num_trials=4000)


Running SCDC Text-Vision Test with clip-resnext
Using device: cuda
[INFO] Loading clip-resnext on cuda...

Found 67 classes with 4+ colors
Examples: ['abacus', 'apple', 'axe', 'babushkadolls', 'bagel']

Using 671 class-color combinations from 67 classes

Extracting image embeddings...


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Extracting embeddings: 100%|██████████| 493/493 [00:17<00:00, 27.79it/s]


Extracted embeddings for 7841 images
Skipped 32 corrupted/invalid images

Running 4000 trials across 67 classes...
Trials per class: 59, with 47 classes getting 1 extra


Processing classes: 100%|██████████| 67/67 [00:24<00:00,  2.76it/s]


Results for clip-resnext - SCDC Text-Vision Test:
Total trials: 4000
Correct: 3927
Accuracy: 0.9818 (98.17%)

Results saved to C:\Users\jbats\Projects\NTU-Synthetic\PatrickProject\Chart_Generation\text_vision_results.csv





## Compare Results

In [6]:
# Display comparison
print("\n" + "="*60)
print("SCDC TEXT-VISION TEST COMPARISON")
print("="*60)
print(f"\nTest: Same Class Different Color (4-way forced choice)")
print(f"\nResults:")
print(f"  CVCL Accuracy: {cvcl_accuracy:.4f} ({cvcl_accuracy*100:.2f}%)")
print(f"  CLIP Accuracy: {clip_accuracy:.4f} ({clip_accuracy*100:.2f}%)")
print(f"\nDifference: {abs(cvcl_accuracy - clip_accuracy):.4f} ({abs(cvcl_accuracy - clip_accuracy)*100:.2f}%)")
if cvcl_accuracy > clip_accuracy:
    print(f"CVCL performs better by {(cvcl_accuracy - clip_accuracy)*100:.2f}%")
elif clip_accuracy > cvcl_accuracy:
    print(f"CLIP performs better by {(clip_accuracy - cvcl_accuracy)*100:.2f}%")
else:
    print("Both models perform equally")

print("\n" + "="*60)
print("\nAnalysis:")
print("- SCDC is typically harder than DCDC since only color differs")
print("- Models must rely on color understanding (both visual and textual)")
print("- Lower accuracy than DCDC would confirm color discrimination is challenging")


SCDC TEXT-VISION TEST COMPARISON

Test: Same Class Different Color (4-way forced choice)

Results:
  CVCL Accuracy: 0.2455 (24.55%)
  CLIP Accuracy: 0.9818 (98.17%)

Difference: 0.7363 (73.62%)
CLIP performs better by 73.62%


Analysis:
- SCDC is typically harder than DCDC since only color differs
- Models must rely on color understanding (both visual and textual)
- Lower accuracy than DCDC would confirm color discrimination is challenging


## Analysis Notes

### SCDC Text-Vision Test Characteristics:
- Tests discrimination when class is SAME but colors DIFFER
- All 4 candidates are the same object class (e.g., all apples)
- Text prompts include color information (e.g., "red apple", "green apple")
- Generally harder than DCDC since class provides no discriminative signal
- Pure test of color understanding (both visual and textual)

