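# DCDC Text-Vision Test - CVCL Training Classes Only

DCDC = Different Class, Different Color: every distractor differs from the target in both object class and color.

**This version only tests on the 24 classes that appear in CVCL's training data (as listed in `CVCLKonkMatches.csv`).**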

In [1]:
# Imports
import os
import sys
import pandas as pd
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import random
from datetime import datetime
import clip
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader

# Path setup
# REPO_ROOT defaults to the original author's checkout; set the
# NTU_SYNTHETIC_ROOT environment variable to run from a different location
# without editing the notebook.
REPO_ROOT = os.environ.get('NTU_SYNTHETIC_ROOT', r'C:\Users\jbats\Projects\NTU-Synthetic')

# Add discover-hidden-visual-concepts to path
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')
sys.path.insert(0, DISCOVER_ROOT)
# Inserted last, so REPO_ROOT ends up ahead of DISCOVER_ROOT on sys.path.
sys.path.insert(0, REPO_ROOT)

# Import from discover-hidden-visual-concepts repo
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor

# Paths: all inputs live under <REPO_ROOT>/data; results go to the chart folder.
CVCL_CLASSES_PATH = os.path.join(os.path.join(REPO_ROOT, 'data'), 'CVCL_Konkle_Overlap', 'CVCLKonkMatches.csv')
DATA_PATH = os.path.join(os.path.join(REPO_ROOT, 'data'), 'SyntheticKonkle_224', 'SyntheticKonkle')
METADATA_PATH = os.path.join(os.path.join(REPO_ROOT, 'data'), 'SyntheticKonkle', 'master_labels.csv')
RESULTS_PATH = os.path.join(
    REPO_ROOT,
    'PatrickProject',
    'Chart_Generation',
    'cvcl_training_text_vision_results.csv',
)

# Echo the resolved locations so a wrong REPO_ROOT is obvious immediately.
print(
    f"Data path: {DATA_PATH}\n"
    f"CVCL classes file: {CVCL_CLASSES_PATH}\n"
    f"Results will be saved to: {RESULTS_PATH}"
)

  from pkg_resources import packaging


Data path: C:\Users\jbats\Projects\NTU-Synthetic\data\SyntheticKonkle_224\SyntheticKonkle
CVCL classes file: C:\Users\jbats\Projects\NTU-Synthetic\data\CVCL_Konkle_Overlap\CVCLKonkMatches.csv
Results will be saved to: C:\Users\jbats\Projects\NTU-Synthetic\PatrickProject\Chart_Generation\cvcl_training_text_vision_results.csv


In [2]:
# Load the CVCL training-class list: one class name per row of the 'Class'
# column, with surrounding whitespace removed.
cvcl_df = pd.read_csv(CVCL_CLASSES_PATH)
CVCL_TRAINING_CLASSES = cvcl_df.loc[:, 'Class'].str.strip().to_list()

print(f"CVCL Training Classes ({len(CVCL_TRAINING_CLASSES)}):")
for class_name in CVCL_TRAINING_CLASSES:
    print(f"  {class_name}")

CVCL Training Classes (24):
  ball
  butterfly
  phone
  bagel
  basket
  bell
  fan
  seashell
  bird
  stool
  train
  ring
  tricycle
  toothpaste
  pen
  tree
  apple
  cookie
  bread
  pumpkin
  camera
  rabbit
  pillow
  horse


In [3]:
# Load and prepare data - FILTERED TO CVCL TRAINING CLASSES
def _rows_from_class_folder(cls, data_path):
    """Build metadata rows for ``cls`` by parsing PNG filenames in ``<data_path>/<cls>_color``.

    Filenames are parsed as ``<prefix>_<size>_<texture>_<variant>_<color...>.png``,
    where the color part may itself contain underscores. Files with fewer than
    five underscore-separated fields, or without a ``.png`` extension, are ignored.

    Returns:
        A list of row dicts (possibly empty).
    """
    folder = f"{cls}_color"
    folder_path = os.path.join(data_path, folder)
    if not os.path.isdir(folder_path):
        print(f"Warning: no image folder for missing class '{cls}': {folder_path}")
        return []

    rows = []
    # Sort so row order (and therefore seeded sampling downstream) does not
    # depend on the filesystem's directory-listing order.
    for img_file in sorted(os.listdir(folder_path)):
        stem, ext = os.path.splitext(img_file)
        if ext != '.png':
            continue
        parts = stem.split('_')
        if len(parts) < 5:
            continue
        rows.append({
            'folder': folder,
            'filename': img_file,
            'class': cls,
            'color': '_'.join(parts[4:]),
            'size': parts[1],
            'texture': parts[2],
            'variant': parts[3],
        })
    return rows


def load_cvcl_synthetickonkle_data(metadata_path=None, data_path=None, classes=None):
    """Load SyntheticKonkle metadata filtered to CVCL training classes.

    Classes requested in ``classes`` but absent from the metadata CSV are
    recovered by parsing filenames in ``<data_path>/<class>_color``.

    Args:
        metadata_path: CSV with at least ``folder``, ``filename``, ``class``,
            ``color``, ``size`` and ``texture`` columns. Defaults to ``METADATA_PATH``.
        data_path: Root directory holding the image folders. Defaults to ``DATA_PATH``.
        classes: Class names to keep. Defaults to ``CVCL_TRAINING_CLASSES``.

    Returns:
        DataFrame with an added ``image_path`` column. It contains only rows
        whose color/size/texture are non-empty, and those three columns are
        stripped and lower-cased.
    """
    metadata_path = METADATA_PATH if metadata_path is None else metadata_path
    data_path = DATA_PATH if data_path is None else data_path
    classes = CVCL_TRAINING_CLASSES if classes is None else list(classes)

    # Read metadata and keep only the requested classes
    df = pd.read_csv(metadata_path)
    df = df[df['class'].isin(classes)].copy()

    # Some classes (e.g. ball) are absent from the metadata CSV; rebuild them from folders
    missing_classes = set(classes) - set(df['class'].unique())
    if missing_classes:
        print(f"Adding missing classes from folders: {missing_classes}")
        extra_rows = [
            row
            for cls in sorted(missing_classes)
            for row in _rows_from_class_folder(cls, data_path)
        ]
        if extra_rows:
            # Concatenate once instead of once per file (avoids quadratic copying).
            df = pd.concat([df, pd.DataFrame(extra_rows)], ignore_index=True)

    # Build full paths (list comprehension also works when df is empty)
    df['image_path'] = [
        os.path.join(data_path, folder, filename)
        for folder, filename in zip(df['folder'], df['filename'])
    ]

    # Normalise first, then drop missing/empty values, so whitespace-only
    # entries are removed rather than surviving as ''.
    for col in ('color', 'size', 'texture'):
        df = df[df[col].notna()].copy()
        df[col] = df[col].astype(str).str.strip().str.lower()
        df = df[df[col] != ''].copy()

    print(f"Loaded {len(df)} images from {df['class'].nunique()} CVCL training classes")
    print(f"Classes: {sorted(df['class'].unique())}")
    print(f"Unique colors: {df['color'].nunique()}")
    print(f"Unique sizes: {df['size'].nunique()}")
    print(f"Unique textures: {df['texture'].nunique()}")

    return df

# Load the filtered dataset once and preview the attributes the test relies on
data_df = load_cvcl_synthetickonkle_data()
print("\nSample data:")
print(data_df.loc[:, ['class', 'color', 'size', 'texture']].head())

Adding missing classes from folders: {'ball'}
Loaded 2832 images from 24 CVCL training classes
Classes: ['apple', 'bagel', 'ball', 'basket', 'bell', 'bird', 'bread', 'butterfly', 'camera', 'cookie', 'fan', 'horse', 'pen', 'phone', 'pillow', 'pumpkin', 'rabbit', 'ring', 'seashell', 'stool', 'toothpaste', 'train', 'tree', 'tricycle']
Unique colors: 12
Unique sizes: 4
Unique textures: 4

Sample data:
   class   color   size texture
0  apple     red  large   bumpy
1  apple   green  large   bumpy
2  apple    blue  large   bumpy
3  apple  yellow  large   bumpy
4  apple  orange  large   bumpy


In [4]:
def _extract_image_embeddings(image_paths, transform, extractor, device, batch_size=16):
    """Embed the images at ``image_paths`` in batches of ``batch_size``.

    Returns:
        ``(embeddings, skipped)``. ``embeddings`` maps path -> float32 CPU
        tensor (the output of ``extractor.norm_features``). ``skipped`` lists
        paths whose image could not be opened or transformed.
    """
    embeddings = {}
    skipped = []
    for start in tqdm(range(0, len(image_paths), batch_size), desc="Extracting embeddings"):
        batch = []
        for img_path in image_paths[start:start + batch_size]:
            if img_path in embeddings:  # Skip if already processed
                continue
            try:
                # The context manager closes the underlying file handle promptly.
                with Image.open(img_path) as img:
                    rgb = img.convert('RGB')
                batch.append((img_path, transform(rgb).unsqueeze(0).to(device)))
            except Exception:
                # Deliberate best-effort: skip corrupted/invalid images, but record them.
                skipped.append(img_path)
        if not batch:
            continue

        paths = [path for path, _ in batch]
        imgs = torch.cat([tensor for _, tensor in batch], dim=0)
        with torch.no_grad():
            feats = extractor.norm_features(extractor.get_img_feature(imgs))
        for path, emb in zip(paths, feats):
            embeddings[path] = emb.cpu().float()  # Ensure float32
    return embeddings, skipped


def _encode_texts(model_name, model, extractor, texts, device):
    """Encode ``texts`` with the model's text encoder; return normalised float32 features.

    Models whose name contains ``"clip"`` are tokenised with ``clip.tokenize``.
    Every other model is treated as CVCL, whose ``model.tokenize`` also returns
    token lengths that ``model.encode_text`` consumes.
    """
    with torch.no_grad():
        if "clip" in model_name:
            tokens = clip.tokenize(texts, truncate=True).to(device)
            feats = model.encode_text(tokens)
        else:  # CVCL
            tokens, token_len = model.tokenize(texts)
            tokens = tokens.to(device)
            if isinstance(token_len, torch.Tensor):
                token_len = token_len.to(device)
            feats = model.encode_text(tokens, token_len)
        feats = extractor.norm_features(feats)
    return feats.float()


def _append_results_row(results_row, results_path):
    """Append ``results_row`` to the CSV at ``results_path``, creating the file/directory if needed."""
    results_dir = os.path.dirname(results_path)
    if results_dir:
        os.makedirs(results_dir, exist_ok=True)
    if os.path.exists(results_path):
        results_df = pd.read_csv(results_path)
    else:
        results_df = pd.DataFrame()
    results_df = pd.concat([results_df, pd.DataFrame([results_row])], ignore_index=True)
    results_df.to_csv(results_path, index=False, float_format='%.4f')


def run_dcdc_test(model_name='cvcl-resnext', seed=0, device=None, num_trials=4000):
    """Run the DCDC text-vision test on CVCL training classes only.

    Each trial picks a query image of one class-color combination and scores
    it against four "<color> <class>" text prompts. One prompt is the query's
    own; the other three are distractor combinations that differ from the query
    in BOTH class and color. The prediction is the prompt with the highest
    image-text similarity.

    Args:
        model_name: Model to test ('cvcl-resnext' or 'clip-resnext').
        seed: Random seed. It seeds ``random`` (all trial sampling) and ``torch``.
        device: Device to use (None for auto-detect).
        num_trials: Total number of trials to schedule. Trials are spread as
            evenly as possible over the combinations; a trial is skipped if
            fewer than 3 distractors exist for its query.

    Returns:
        ``(trial_results, accuracy)``: a list of per-trial dicts and the
        fraction of non-skipped trials answered correctly. A summary row is
        also appended to ``RESULTS_PATH``.

    Raises:
        ValueError: if no class-color combination has a readable image.
    """
    # Set seeds. All trial sampling below goes through ``random`` so that
    # ``seed`` fully determines the trials.
    random.seed(seed)
    torch.manual_seed(seed)

    print(f"\n{'='*60}")
    print(f"Running DCDC Text-Vision Test with {model_name}")
    print("CVCL Training Classes Only")
    print(f"{'='*60}")

    # Device selection
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # Load model
    print(f"[INFO] Loading {model_name} on {device}...")
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)
    model.eval()

    # Load and prepare data (already filtered to CVCL training classes)
    df = load_cvcl_synthetickonkle_data()
    df['class_color'] = df['class'] + '_' + df['color']

    print(f"Loaded {len(df)} images with color annotations")
    print(f"Unique classes: {df['class'].nunique()}")
    print(f"Unique colors: {df['color'].nunique()}")
    print(f"Unique class-color combinations: {df['class_color'].nunique()}")

    # Group data by class-color combinations (groupby sorts by class_color)
    grouped = df.groupby('class_color').agg({
        'image_path': list,
        'class': 'first',
        'color': 'first'
    }).reset_index()
    print(f"\nUsing {len(grouped)} unique class-color combinations")

    # Pre-compute all image embeddings once; trials then only need lookups
    print("\nExtracting image embeddings...")
    all_image_paths = [img for imgs in grouped['image_path'] for img in imgs]
    image_embeddings, skipped_images = _extract_image_embeddings(
        all_image_paths, transform, extractor, device
    )
    print(f"Extracted embeddings for {len(image_embeddings)} images")
    if skipped_images:
        print(f"Skipped {len(skipped_images)} corrupted/invalid images")

    # combo -> (class, color, embeddable image paths). Combinations left with
    # no usable image are dropped. Insertion order follows ``grouped``.
    combos = {}
    for combo, cls, color, paths in zip(
        grouped['class_color'], grouped['class'], grouped['color'], grouped['image_path']
    ):
        valid_paths = [p for p in paths if p in image_embeddings]
        if valid_paths:
            combos[combo] = (cls, color, valid_paths)
    if not combos:
        raise ValueError("No class-color combinations with readable images; cannot run DCDC trials.")

    # Spread num_trials as evenly as possible over the combinations
    combinations_list = list(combos)
    trials_per_combo, remaining_trials = divmod(num_trials, len(combinations_list))
    trial_distribution = []
    for i, combo in enumerate(combinations_list):
        n_trials = trials_per_combo + (1 if i < remaining_trials else 0)
        trial_distribution.extend([combo] * n_trials)
    random.shuffle(trial_distribution)

    # Distractor pools depend only on the query combination, so cache them.
    distractor_pools = {}

    correct_count = 0
    trial_results = []
    print(f"\nRunning {len(trial_distribution)} trials...")

    for trial_idx, query_combo in enumerate(tqdm(trial_distribution, desc="Trials")):
        query_class, query_color, query_paths = combos[query_combo]
        query_img_path = random.choice(query_paths)

        # Distractors: different class AND different color from the query
        pool = distractor_pools.get(query_combo)
        if pool is None:
            pool = [
                combo for combo, (cls, color, _) in combos.items()
                if cls != query_class and color != query_color
            ]
            distractor_pools[query_combo] = pool
        if len(pool) < 3:
            continue  # Skip if not enough valid distractors
        # Use the seeded ``random`` module. DataFrame.sample() without
        # random_state draws from NumPy's global RNG, which ``seed`` does not control.
        distractor_combos = random.sample(pool, 3)

        # Candidate list (query + distractors) in random order
        all_combos = [query_combo] + distractor_combos
        random.shuffle(all_combos)
        correct_idx = all_combos.index(query_combo)

        # Text prompt per candidate: "<color> <class>"
        candidate_texts = [f"{combos[c][1]} {combos[c][0].lower()}" for c in all_combos]
        txt_features = _encode_texts(model_name, model, extractor, candidate_texts, device)

        # Similarity following the original predict() method: scaled dot product + softmax
        query_embedding = image_embeddings[query_img_path].unsqueeze(0).to(device).float()
        similarity = (100.0 * query_embedding @ txt_features.transpose(-2, -1)).softmax(dim=1)
        pred_idx = similarity.argmax(dim=1).item()

        is_correct = (pred_idx == correct_idx)
        if is_correct:
            correct_count += 1

        trial_results.append({
            'trial': trial_idx + 1,
            'query_class': query_class,
            'query_color': query_color,
            'query_img': os.path.basename(query_img_path),
            'correct_idx': correct_idx,
            'predicted_idx': pred_idx,
            'correct': is_correct,
            'candidate_texts': candidate_texts,
            'similarity_scores': similarity.cpu().numpy().tolist()
        })

    # Accuracy over trials actually run (skipped trials are excluded)
    accuracy = correct_count / len(trial_results) if trial_results else 0

    print(f"\n{'='*60}")
    print(f"Results for {model_name} - DCDC Text-Vision Test:")
    print(f"Total trials: {len(trial_results)}")
    print(f"Correct: {correct_count}")
    print(f"Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"{'='*60}")

    # Append summary to the results file
    _append_results_row(
        {
            'Model': model_name,
            'Test': 'DCDC-TextVision-CVCLTraining',
            'Dataset': 'SyntheticKonkle_224',
            'Correct': correct_count,
            'Trials': len(trial_results),
            'Accuracy': accuracy
        },
        RESULTS_PATH,
    )
    print(f"\nResults saved to {RESULTS_PATH}")

    return trial_results, accuracy

## Run CVCL DCDC Text-Vision Test

In [5]:
# Evaluate CVCL (ResNeXt backbone) on the DCDC text-vision task
cvcl_trials, cvcl_accuracy = run_dcdc_test(model_name='cvcl-resnext', seed=0, num_trials=4000)


Running DCDC Text-Vision Test with cvcl-resnext
CVCL Training Classes Only
Using device: cuda
[INFO] Loading cvcl-resnext on cuda...
Loading checkpoint from C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt


Lightning automatically upgraded your loaded checkpoint from v1.5.8 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint C:\Users\jbats\.cache\huggingface\hub\models--wkvong--cvcl_s_dino_resnext50_embedding\snapshots\f50eaa0c50a6076a5190b1dd52aeeb6c3e747045\cvcl_s_dino_resnext50_embedding.ckpt`


Adding missing classes from folders: {'ball'}
Loaded 2832 images from 24 CVCL training classes
Classes: ['apple', 'bagel', 'ball', 'basket', 'bell', 'bird', 'bread', 'butterfly', 'camera', 'cookie', 'fan', 'horse', 'pen', 'phone', 'pillow', 'pumpkin', 'rabbit', 'ring', 'seashell', 'stool', 'toothpaste', 'train', 'tree', 'tricycle']
Unique colors: 12
Unique sizes: 4
Unique textures: 4
Loaded 2832 images with color annotations
Unique classes: 24
Unique colors: 12
Unique class-color combinations: 243

Using 243 unique class-color combinations

Extracting image embeddings...


Extracting embeddings: 100%|██████████| 177/177 [00:08<00:00, 21.47it/s]


Extracted embeddings for 2806 images
Skipped 23 corrupted/invalid images

Running 4000 trials...


Trials: 100%|██████████| 4000/4000 [00:42<00:00, 93.05it/s] 


Results for cvcl-resnext - DCDC Text-Vision Test:
Total trials: 4000
Correct: 1197
Accuracy: 0.2993 (29.93%)

Results saved to C:\Users\jbats\Projects\NTU-Synthetic\PatrickProject\Chart_Generation\cvcl_training_text_vision_results.csv





## Run CLIP DCDC Text-Vision Test

In [6]:
# Evaluate CLIP (ResNeXt backbone) on the same DCDC text-vision task
clip_trials, clip_accuracy = run_dcdc_test(model_name='clip-resnext', seed=0, num_trials=4000)


Running DCDC Text-Vision Test with clip-resnext
CVCL Training Classes Only
Using device: cuda
[INFO] Loading clip-resnext on cuda...
Adding missing classes from folders: {'ball'}
Loaded 2832 images from 24 CVCL training classes
Classes: ['apple', 'bagel', 'ball', 'basket', 'bell', 'bird', 'bread', 'butterfly', 'camera', 'cookie', 'fan', 'horse', 'pen', 'phone', 'pillow', 'pumpkin', 'rabbit', 'ring', 'seashell', 'stool', 'toothpaste', 'train', 'tree', 'tricycle']
Unique colors: 12
Unique sizes: 4
Unique textures: 4
Loaded 2832 images with color annotations
Unique classes: 24
Unique colors: 12
Unique class-color combinations: 243

Using 243 unique class-color combinations

Extracting image embeddings...


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Extracting embeddings: 100%|██████████| 177/177 [00:06<00:00, 27.29it/s]


Extracted embeddings for 2806 images
Skipped 23 corrupted/invalid images

Running 4000 trials...


Trials: 100%|██████████| 4000/4000 [00:26<00:00, 149.66it/s]


Results for clip-resnext - DCDC Text-Vision Test:
Total trials: 4000
Correct: 3957
Accuracy: 0.9892 (98.92%)

Results saved to C:\Users\jbats\Projects\NTU-Synthetic\PatrickProject\Chart_Generation\cvcl_training_text_vision_results.csv





## Compare Results

In [None]:
# Display comparison. Accuracies are fractions in [0, 1]; their difference is
# an absolute gap, so it is reported in percentage points, not as a relative %.
accuracy_gap = cvcl_accuracy - clip_accuracy

print("\n" + "="*60)
print("DCDC TEXT-VISION TEST COMPARISON - CVCL TRAINING CLASSES")
print("="*60)
print("\nResults:")
print(f"  CVCL Accuracy: {cvcl_accuracy:.4f} ({cvcl_accuracy*100:.2f}%)")
print(f"  CLIP Accuracy: {clip_accuracy:.4f} ({clip_accuracy*100:.2f}%)")
print(f"\nDifference: {abs(accuracy_gap):.4f} ({abs(accuracy_gap)*100:.2f} percentage points)")
if accuracy_gap > 0:
    print(f"CVCL performs better by {accuracy_gap*100:.2f} percentage points on its training classes")
elif accuracy_gap < 0:
    print(f"CLIP performs better by {-accuracy_gap*100:.2f} percentage points even on CVCL's training classes")
else:
    print("Both models perform equally")