In [None]:
# Download spaCy's small English pipeline (en_core_web_sm) via an IPython shell escape.
# NOTE(review): no visible cell imports spaCy directly — presumably it is needed by the
# CVCL tokenizer loaded later; confirm, and consider pinning the version.
!python -m spacy download en_core_web_sm

# Class Text-Vision Comparison - KonkLab

This notebook compares the CVCL and CLIP models on class discrimination using text-based prototypes.
Instead of averaging image features to create a prototype for each class, we use the text embedding of the class label.
The task remains a 4-way forced-choice classification with 4000 trials per model.

In [None]:
import os
import sys
import random
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from collections import defaultdict
import clip

# Path setup - absolute paths avoid any dependence on the notebook's working directory.
# The repository root can be overridden with the NTU_SYNTHETIC_ROOT environment variable
# so the notebook runs on other machines; the original location remains the default.
REPO_ROOT = os.environ.get('NTU_SYNTHETIC_ROOT', r'C:\Users\jbats\Projects\NTU-Synthetic')

# Add discover-hidden-visual-concepts to path.
# Both entries go to the front of sys.path so they win over same-named installed packages.
DISCOVER_ROOT = os.path.join(REPO_ROOT, 'discover-hidden-visual-concepts')
sys.path.insert(0, DISCOVER_ROOT)
sys.path.insert(0, REPO_ROOT)

# Import from discover-hidden-visual-concepts repo
# (its `src` directory is appended so `utils` and `models` resolve as top-level packages).
sys.path.append(os.path.join(DISCOVER_ROOT, 'src'))
from utils.model_loader import load_model
from models.feature_extractor import FeatureExtractor

# KonkLab paths - all derived from REPO_ROOT
CSV_PATH = os.path.join(REPO_ROOT, 'data', 'KonkLab', 'testdata.csv')        # image metadata (Filename, Class)
IMG_DIR = os.path.join(REPO_ROOT, 'data', 'KonkLab', '17-objects')           # one sub-folder per class
MASTER_CSV = os.path.join(REPO_ROOT, 'PatrickProject', 'Chart_Generation', 'text_vision_results.csv')  # results log

In [None]:
# Environment sanity check: report the PyTorch build and, if a GPU is visible,
# time a tiny element-wise workload on it.
import time
import torch

cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")

if not cuda_ok:
    print("WARNING: Running on CPU will be VERY slow!")
    print("If you have a GPU, make sure CUDA is properly installed")
else:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")

    # Ten multiplications on an image-sized batch; synchronize before and after
    # so the wall-clock interval covers the GPU work itself.
    x = torch.randn(32, 3, 224, 224).cuda()
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(10):
        _ = x * 2
    torch.cuda.synchronize()
    elapsed = time.time() - start
    print(f"GPU test time: {elapsed:.3f}s")

In [None]:
# Dataset class with optimized loading
class ClassImageDataset(Dataset):
    """Image dataset described by a CSV with ``Filename`` and ``Class`` columns.

    Each image is expected at ``<img_dir>/<Class>/<Filename>``.

    Args:
        csv_path: Path to the metadata CSV.
        img_dir: Root directory containing one sub-folder per class.
        transform: Callable applied to the RGB PIL image before it is returned.

    Each item is a ``(transformed_image, class_label, index)`` tuple.
    """

    def __init__(self, csv_path, img_dir, transform):
        self.df = pd.read_csv(csv_path)
        assert 'Filename' in self.df and 'Class' in self.df, \
            "CSV needs Filename and Class columns"
        self.img_dir = img_dir
        self.transform = transform
        # Pre-compute labels and paths once so __getitem__ needs no DataFrame lookup
        self.classes = self.df['Class'].tolist()
        self.paths = [os.path.join(img_dir, cls, fname)
                      for cls, fname in zip(self.classes, self.df['Filename'])]

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        cls = self.classes[idx]
        path = self.paths[idx]
        # The context manager closes the underlying file handle; convert() returns
        # a fully loaded copy that stays valid after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb), cls, idx

def collate_fn(batch):
    """Collate ``(image_tensor, class_label, index)`` samples into a batch.

    Image tensors are stacked along a new leading dimension; labels and indices
    are returned as plain Python lists in sample order.
    """
    image_tensors, labels, positions = [], [], []
    for sample in batch:
        image_tensors.append(sample[0])
        labels.append(sample[1])
        positions.append(sample[2])
    return torch.stack(image_tensors), labels, positions

def run_class_text_vision_test(model_name, seed=0, device=None, batch_size=16,
                                trials_per_class=None, max_trials=4000):
    """
    Run a 4-way forced-choice test that matches a class label's text embedding to images.

    In every trial one query image of the target class and three images drawn from
    other classes are scored against the text embedding of the target class label;
    the trial is correct when the query image receives the highest score.
    A one-row summary is appended to MASTER_CSV.

    Args:
        model_name: Name passed to ``load_model``. Names containing "clip" are tokenized
            with ``clip.tokenize``; all others use the model's own ``tokenize`` method.
            'resnext' and 'dino_s_resnext50' are skipped because they have no text encoder.
        seed: Seed for ``random`` and ``torch``; also forwarded to ``load_model``.
        device: 'cuda', 'cpu', or None to pick CUDA when it is available.
        batch_size: Image batch size for embedding extraction (forced to 4 on CPU).
        trials_per_class: Trials per class. None spreads ``max_trials`` evenly over the
            classes, giving one extra trial to the first ``max_trials % n_classes`` classes.
        max_trials: Upper bound on the total number of trials.

    Returns:
        ``(class_results, overall_acc)``. ``class_results`` maps each class to a dict with
        'correct', 'trials' (trials actually run) and 'accuracy'. ``overall_acc`` is
        total correct / total trials, or 0.0 when no trial ran. Skipped models
        return ``({}, 0.0)``.
    """
    random.seed(seed)
    torch.manual_seed(seed)

    # Device selection
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if device == 'cuda' and not torch.cuda.is_available():
        print("[ERROR] CUDA requested but not available! Falling back to CPU.")
        device = 'cpu'

    if device == 'cpu':
        print("[WARNING] Running on CPU - this will be SLOW!")
        print("Reducing batch size to 4 for CPU")
        batch_size = 4
    else:
        print(f"[INFO] Using GPU: {torch.cuda.get_device_name(0)}")

    # Check if model supports text encoding
    if model_name in ['resnext', 'dino_s_resnext50']:
        print(f"[WARNING] {model_name} has no text encoder, skipping")
        return {}, 0.0

    # Imported locally so this function does not rely on another cell having been run
    import time
    from tqdm import tqdm

    # 1) Load model
    print(f"[INFO] Loading {model_name} on {device}...")
    start_time = time.time()
    model, transform = load_model(model_name, seed=seed, device=device)
    extractor = FeatureExtractor(model_name, model, device)
    print(f"[INFO] Model loaded in {time.time() - start_time:.1f}s")

    # 2) Load dataset
    ds = ClassImageDataset(CSV_PATH, IMG_DIR, transform)
    print(f"[INFO] Dataset: {len(ds)} images")

    # Worker processes are disabled on Windows (os.name == 'nt'); two are used elsewhere
    num_workers = 0 if os.name == 'nt' else 2

    dl = DataLoader(ds, batch_size=batch_size, shuffle=False,
                    num_workers=num_workers, collate_fn=collate_fn,
                    pin_memory=(device == 'cuda'))

    # 3) Extract image embeddings (normalized, kept on CPU)
    print(f"[INFO] Extracting embeddings (batch_size={batch_size})...")
    all_img_embs, all_classes, all_idxs = [], [], []
    start_time = time.time()

    with torch.no_grad():
        for imgs, classes, idxs in tqdm(dl, desc="Extracting embeddings"):
            imgs = imgs.to(device, non_blocking=True)
            feats = extractor.get_img_feature(imgs)
            feats = extractor.norm_features(feats).cpu()

            all_img_embs.append(feats)
            all_classes.extend(classes)
            all_idxs.extend(idxs)

    all_img_embs = torch.cat(all_img_embs, dim=0)
    print(f"[INFO] Extracted {len(all_idxs)} embeddings in {time.time() - start_time:.1f}s")

    # 4) Encode text labels
    # Sorted so the encoding order does not depend on string hashing between runs
    unique_classes = sorted(set(all_classes))
    print(f"[INFO] Encoding {len(unique_classes)} class labels...")

    with torch.no_grad():
        if "clip" in model_name:
            tokens = clip.tokenize(unique_classes, truncate=True).to(device)
            txt_features = model.encode_text(tokens)
        else:  # CVCL
            tokens, token_len = model.tokenize(unique_classes)
            tokens = tokens.to(device)
            if isinstance(token_len, torch.Tensor):
                token_len = token_len.to(device)
            txt_features = model.encode_text(tokens, token_len)
        txt_features = extractor.norm_features(txt_features).cpu()

    class_text_features = {cls: txt_features[i] for i, cls in enumerate(unique_classes)}
    print(f"[INFO] Text encoding complete")

    # 5) Build mappings: dataset index -> class, dataset index -> embedding row, class -> indices
    idx2class = {i: c for i, c in zip(all_idxs, all_classes)}
    idx2row = {i: r for r, i in enumerate(all_idxs)}
    class2idxs = defaultdict(list)
    for i, c in idx2class.items():
        class2idxs[c].append(i)

    # 6) Run trials (at most max_trials in total)
    class_results = {}
    total_correct = 0
    total_trials = 0

    # Get valid classes (those with at least 1 image)
    valid_classes = [c for c in class2idxs if len(class2idxs[c]) >= 1]
    n_classes = len(valid_classes)

    # Calculate trials per class if not specified
    if trials_per_class is None:
        # Distribute trials evenly across classes, with some getting extra
        trials_per_class = max_trials // n_classes
        extra_trials = max_trials % n_classes
    else:
        extra_trials = 0

    print(f"[INFO] Running {max_trials} total trials across {n_classes} classes")
    print(f"[INFO] Base trials per class: {trials_per_class}, extra trials for first {extra_trials} classes")

    for class_idx, cls in enumerate(valid_classes):
        if total_trials >= max_trials:
            break

        idxs = class2idxs[cls]

        # Add extra trial for first few classes to reach exactly max_trials,
        # and never exceed the remaining budget
        current_trials = trials_per_class + (1 if class_idx < extra_trials else 0)
        current_trials = min(current_trials, max_trials - total_trials)

        # The distractor pool depends only on the class, so build it once per class
        others = [i for i in all_idxs if idx2class[i] != cls]
        if len(others) < 3:
            # No 4-way trial can be formed; report zero trials rather than
            # counting trials that never ran
            print(f"[WARNING] Skipping {cls}: fewer than 3 images from other classes")
            class_results[cls] = {'correct': 0, 'trials': 0, 'accuracy': 0.0}
            continue

        # Text embedding reshaped from (D,) to (1, D, 1) so it can be right-multiplied
        # with a (1, 4, D) candidate stack
        txt_column = class_text_features[cls].unsqueeze(0).unsqueeze(1).transpose(-2, -1)

        correct = 0
        for _ in range(current_trials):
            # Pick query and distractors
            q = random.choice(idxs)
            distractors = random.sample(others, 3)

            # 4-way classification: the query always sits at position 0
            candidates = [q] + distractors
            cand_features = torch.stack([all_img_embs[idx2row[i]] for i in candidates]).unsqueeze(0)

            # Scaled similarity, softmax over the four candidates
            similarity = (100.0 * cand_features @ txt_column).softmax(dim=1).squeeze()

            if similarity.argmax().item() == 0:
                correct += 1
                total_correct += 1
            total_trials += 1

            # Checked per trial so the message fires at every multiple of 500
            if total_trials % 500 == 0:
                print(f"[PROGRESS] {total_trials}/{max_trials} trials completed")

        acc = correct / current_trials if current_trials > 0 else 0.0
        class_results[cls] = {'correct': correct, 'trials': current_trials, 'accuracy': acc}

    # Final progress update
    print(f"[FINAL] Completed {total_trials} trials")

    overall_acc = total_correct / total_trials if total_trials else 0.0

    print(f"\n[RESULTS] Accuracy: {total_correct}/{total_trials} ({overall_acc:.1%})")

    # Save results: append one summary row, writing the header only when the file is new
    summary_df = pd.DataFrame([{
        'Model': model_name,
        'Test': 'Class-TextVision-Original',
        'Dataset': 'KonkLab',
        'Correct': total_correct,
        'Trials': total_trials,
        'Accuracy': overall_acc
    }])

    os.makedirs(os.path.dirname(MASTER_CSV), exist_ok=True)
    write_header = not os.path.exists(MASTER_CSV)
    summary_df.to_csv(MASTER_CSV, mode='a', header=write_header, index=False, float_format='%.4f')

    return class_results, overall_acc

## CVCL Text-Vision Test

In [None]:
# Evaluate CVCL on the text-vision 4-way task
cvcl_results, cvcl_overall = run_class_text_vision_test('cvcl-resnext', max_trials=4000)

# Rank classes from highest to lowest accuracy and list at most the top 20
cvcl_ranked = sorted(cvcl_results.items(), key=lambda item: item[1]['accuracy'], reverse=True)

print("\nCVCL Results per Class:")
for cls, res in cvcl_ranked[:20]:
    print(f"{cls:20s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCVCL Overall Accuracy: {cvcl_overall:.1%}")

## CLIP Text-Vision Test

In [None]:
# Evaluate CLIP on the text-vision 4-way task
clip_results, clip_overall = run_class_text_vision_test('clip-resnext', max_trials=4000)

# Rank classes from highest to lowest accuracy and list at most the top 20
clip_ranked = sorted(clip_results.items(), key=lambda item: item[1]['accuracy'], reverse=True)

print("\nCLIP Results per Class:")
for cls, res in clip_ranked[:20]:
    print(f"{cls:20s}: {res['correct']}/{res['trials']} ({res['accuracy']:.1%})")
print(f"\nCLIP Overall Accuracy: {clip_overall:.1%}")

## Comparison Summary

In [None]:
# Side-by-side summary of the two text-vision runs
banner = "=" * 60
print(banner)
print("TEXT-VISION vs VISUAL PROTOTYPE COMPARISON")
print(banner)

print("\nText-Vision Classification (using text descriptions as prototypes):")
print(f"  CVCL: {cvcl_overall:.1%}")
print(f"  CLIP: {clip_overall:.1%}")

# Difference between the two overall accuracies, expressed in percentage points
clip_gain_pp = (clip_overall - cvcl_overall) * 100
print(f"  CLIP advantage: {clip_gain_pp:.1f} percentage points")

print("\nNote: Compare these results with the visual prototype version")
print("to see how text-based prototypes perform vs image-based prototypes.")