In [None]:
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer
from PIL import Image
import json
import numpy as np
from pathlib import Path
from tqdm import tqdm
import gc

# ============================================================================
# ULTRA LIGHTWEIGHT CONFIG otherwise local memory overload
# ============================================================================

# Pipeline settings. Values are kept small so a run fits in local memory.
CONFIG = dict(
    # How many VQA items to use, and how many are processed between
    # memory clean-ups.
    n_samples=100,
    batch_size=10,

    # Locations on disk: the JSON file and the image folder are resolved
    # relative to data_dir; extracted hidden states are cached in cache_dir.
    data_dir='./vqa_data',
    vqa_json='vqa_ccs_object_detection.json',
    image_dir='train2014',
    cache_dir='./hidden_states_cache',

    # Model: lightweight ViT-GPT2 captioner (~500MB instead of ~14GB).
    model_name='nlpconnect/vit-gpt2-image-captioning',
    # Presumably the hidden layer to probe (-1 = last) -- TODO confirm it is
    # actually consumed somewhere.
    layer=-1,
)


def _make_contrast_pair(item):
    """Build the "Yes"/"No" contrast pair for a single VQA item."""
    # Strip any trailing question marks so exactly one is re-added below.
    question = item['question'].rstrip('?')
    return {
        'image_id': item['image_id'],
        'pos_text': f"{question}? Yes",
        'neg_text': f"{question}? No",
        'label': item['label'],
    }


def load_lightweight_samples(config):
    """Load the first ``n_samples`` VQA items and turn them into contrast pairs.

    Args:
        config: Mapping providing ``data_dir``, ``vqa_json`` (file name inside
            ``data_dir``) and ``n_samples``.

    Returns:
        A list with one dict per item, holding ``image_id``, ``pos_text``
        (the question answered "Yes"), ``neg_text`` (the question answered
        "No") and ``label``.

    Raises:
        OSError: If the JSON file cannot be opened.
        KeyError: If an item lacks ``question``, ``image_id`` or ``label``.
    """
    print("LOADING DATA")

    data_path = Path(config['data_dir']) / config['vqa_json']
    # JSON is UTF-8 by specification; relying on the platform default
    # encoding breaks non-ASCII questions on some systems.
    with open(data_path, 'r', encoding='utf-8') as f:
        vqa_data = json.load(f)

    # Take first N samples
    samples = vqa_data[:config['n_samples']]

    print(f"Using {len(samples)} samples")

    return [_make_contrast_pair(item) for item in samples]


def _pairs_fingerprint(pairs, model_name):
    """Return a hex digest identifying the (model, contrast pairs) combination."""
    import hashlib  # local import: only needed for cache bookkeeping

    payload = json.dumps([model_name, pairs], sort_keys=True, default=str)
    return hashlib.sha256(payload.encode('utf-8')).hexdigest()


def _load_cached_hiddens(cache_file, fingerprint):
    """Return the cached arrays if the cache matches ``fingerprint``, else None."""
    # np.load keeps the .npz archive open; the context manager closes it.
    with np.load(cache_file) as data:
        if 'fingerprint' not in data.files:
            return None
        if data['fingerprint'].item() != fingerprint:
            return None
        return data['pos_hiddens'], data['neg_hiddens'], data['labels']


def extract_in_batches(pairs, config):
    """Extract hidden states in small batches to avoid OOM.

    Results are cached in ``<cache_dir>/cached_hiddens_vit.npz`` together with
    a fingerprint of the model name and the contrast pairs. The cache is only
    reused when that fingerprint matches the current call, so changing
    ``n_samples``, the data file or the model triggers a fresh extraction
    instead of silently returning stale arrays. Cache files written without a
    fingerprint are treated as stale.

    Note: ``config['layer']`` is not consulted by this function.

    Args:
        pairs: List of dicts with ``image_id``, ``pos_text``, ``neg_text`` and
            ``label``.
        config: Mapping providing ``cache_dir``, ``batch_size``,
            ``model_name``, ``data_dir`` and ``image_dir``.

    Returns:
        Tuple ``(pos_hiddens, neg_hiddens, labels)`` of NumPy arrays with one
        entry per successfully processed pair. Pairs whose image is missing or
        whose extraction fails are skipped.
    """
    print("EXTRACTING HIDDEN STATES (BATCH MODE)")

    cache_dir = Path(config['cache_dir'])
    # parents=True so a nested cache_dir does not fail on first use.
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Check if a cache matching these pairs and this model already exists
    cache_file = cache_dir / 'cached_hiddens_vit.npz'
    fingerprint = _pairs_fingerprint(pairs, config['model_name'])
    if cache_file.exists():
        cached = _load_cached_hiddens(cache_file, fingerprint)
        if cached is not None:
            print("\nFound cached hidden states!")
            print("Loading from cache...")
            return cached
        print("\nCache does not match the current samples/model. Re-extracting...")
    else:
        print("\nNo cache found. Extracting...")
    print(f"Processing {len(pairs)} samples in batches of {config['batch_size']}")

    # Load lightweight model
    print("\nLoading lightweight model (ViT-GPT2, ~500MB)...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device: {device}")

    model = VisionEncoderDecoderModel.from_pretrained(config['model_name'])
    feature_extractor = ViTImageProcessor.from_pretrained(config['model_name'])
    tokenizer = AutoTokenizer.from_pretrained(config['model_name'])

    model = model.to(device)
    model.eval()

    print("Model loaded (much faster than LLaVA!)\n")

    # Extract in batches
    all_pos_hiddens = []
    all_neg_hiddens = []
    all_labels = []

    image_dir = Path(config['data_dir']) / config['image_dir']

    for i in tqdm(range(0, len(pairs), config['batch_size']), desc="Batches"):
        batch_pairs = pairs[i:i + config['batch_size']]

        for pair in batch_pairs:
            image_id = pair['image_id']
            image_path = image_dir / f"COCO_train2014_{image_id:012d}.jpg"

            if not image_path.exists():
                print(f"Skipping {image_id}")
                continue

            # Best effort per sample: one bad image must not abort the run,
            # so failures are reported and the pair is skipped.
            try:
                # Close the file handle promptly; convert() returns a copy
                # that stays usable after the file is closed.
                with Image.open(image_path) as raw_image:
                    image = raw_image.convert('RGB')

                pos_h = extract_one(model, feature_extractor, tokenizer, image, pair['pos_text'], device)
                neg_h = extract_one(model, feature_extractor, tokenizer, image, pair['neg_text'], device)
            except Exception as e:
                print(f"Error {image_id}: {e}")
                continue

            # Append only after both extractions succeeded so the three
            # lists stay aligned.
            all_pos_hiddens.append(pos_h)
            all_neg_hiddens.append(neg_h)
            all_labels.append(pair['label'])

        # Clear cache after each batch
        if device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # Unload model to free memory
    del model
    del feature_extractor
    del tokenizer
    if device == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    print("\nModel unloaded to free memory")

    # Convert to arrays
    pos_hiddens = np.array(all_pos_hiddens)
    neg_hiddens = np.array(all_neg_hiddens)
    labels = np.array(all_labels)

    print(f"\nExtracted: {pos_hiddens.shape}")

    if not all_labels:
        # Caching an empty result would make the next run reuse it.
        print("No hidden states were extracted; nothing cached")
        return pos_hiddens, neg_hiddens, labels

    # Cache for future runs, tagged with the fingerprint used for validation
    np.savez(
        cache_file,
        pos_hiddens=pos_hiddens,
        neg_hiddens=neg_hiddens,
        labels=labels,
        fingerprint=np.array(fingerprint),
    )
    print(f"💾 Cached to: {cache_file}")

    return pos_hiddens, neg_hiddens, labels


def extract_one(model, feature_extractor, tokenizer, image, text, device):
    """Return one pooled decoder hidden-state vector for an (image, text) pair.

    The image is run through ``model.encoder``; the tokenized text is run
    through ``model.decoder`` with the encoder's last hidden state passed as
    ``encoder_hidden_states``. The last entry of the decoder's
    ``hidden_states`` is averaged over dimension 1 and returned as a NumPy
    array on the CPU.

    Args:
        model: Object exposing ``encoder`` and ``decoder`` callables.
        feature_extractor: Image processor producing ``pixel_values``.
        tokenizer: Tokenizer producing ``input_ids``.
        image: Image to encode.
        text: Text fed to the decoder.
        device: Device the input tensors are moved to.
    """
    # Image -> pixel tensor on the target device
    image_batch = feature_extractor(images=image, return_tensors="pt")
    pixel_values = image_batch.pixel_values.to(device)

    # Text -> token ids on the target device
    tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)

    with torch.no_grad():
        # Vision pass
        vision_out = model.encoder(pixel_values, output_hidden_states=True)

        # Text pass, conditioned on the vision features
        text_out = model.decoder(
            input_ids=tokens['input_ids'],
            encoder_hidden_states=vision_out.last_hidden_state,
            output_hidden_states=True,
        )

        # assumes the layer output is (batch=1, tokens, dim) -- TODO confirm;
        # averaging dim 1 then dropping dim 0 yields a single vector.
        final_layer = text_out.hidden_states[-1]
        pooled = final_layer.mean(dim=1).squeeze(0)

    return pooled.cpu().numpy()


def train_simple_ccs(pos_hiddens, neg_hiddens, labels):
    """Train a simplified CCS probe and report its held-out accuracy.

    The first 70% of the samples train the probe with the unsupervised CCS
    objective (consistency + confidence); the remaining 30% are used for
    evaluation. Because training never sees the labels, the probe's output
    direction is arbitrary, so the reported accuracy is the better of the two
    possible orientations, i.e. ``max(acc, 1 - acc)``.

    Args:
        pos_hiddens: Array-like of shape ``(n, d)`` for the "Yes" statements.
        neg_hiddens: Array-like of shape ``(n, d)`` for the "No" statements.
        labels: Sequence of ``n`` binary labels (0/1 or bool).

    Returns:
        Test accuracy as a float in ``[0.5, 1.0]``.

    Raises:
        ValueError: If the inputs are not aligned 2-D arrays, if fewer than
            two samples are given, or if the labels are not binary.
    """
    print("TRAINING CCS PROBE")

    import torch.nn as nn
    import torch.optim as optim

    pos_hiddens = torch.as_tensor(np.asarray(pos_hiddens), dtype=torch.float32)
    neg_hiddens = torch.as_tensor(np.asarray(neg_hiddens), dtype=torch.float32)
    label_arr = np.asarray(labels)

    n = len(label_arr)
    if (
        pos_hiddens.ndim != 2
        or pos_hiddens.shape != neg_hiddens.shape
        or pos_hiddens.shape[0] != n
    ):
        raise ValueError(
            "pos_hiddens and neg_hiddens must both have shape (n, d) with "
            f"n == len(labels); got {tuple(pos_hiddens.shape)}, "
            f"{tuple(neg_hiddens.shape)} and {n} labels"
        )
    if n < 2:
        raise ValueError(
            f"need at least 2 samples for a train/test split, got {n}"
        )

    # The orientation fix below would turn non-binary labels (which never
    # match a prediction) into a bogus 100%, so reject them explicitly.
    try:
        numeric_labels = label_arr.astype(float)
    except (TypeError, ValueError) as err:
        raise ValueError("labels must be binary (0/1 or bool)") from err
    if not np.isin(numeric_labels, (0.0, 1.0)).all():
        raise ValueError("labels must be binary (0/1 or bool)")
    truth = numeric_labels.astype(bool)

    # Normalize: centre each statement class separately (mean taken over
    # all samples, as in the original pipeline).
    pos_hiddens = pos_hiddens - pos_hiddens.mean(dim=0)
    neg_hiddens = neg_hiddens - neg_hiddens.mean(dim=0)

    # Simple 70/30 split
    n_train = int(0.7 * n)
    n_test = n - n_train

    pos_train = pos_hiddens[:n_train]
    neg_train = neg_hiddens[:n_train]
    pos_test = pos_hiddens[n_train:]
    neg_test = neg_hiddens[n_train:]
    truth_test = truth[n_train:]

    print(f"Train: {n_train}, Test: {n_test}")

    # Simple probe
    class Probe(nn.Module):
        def __init__(self, d):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(d, 100),
                nn.ReLU(),
                nn.Linear(100, 1),
                nn.Sigmoid()
            )

        def forward(self, x):
            return self.net(x)

    probe = Probe(pos_hiddens.shape[1])
    opt = optim.Adam(probe.parameters(), lr=1e-3)

    print("\nTraining...")
    for epoch in range(100):
        p_pos = probe(pos_train)
        p_neg = probe(neg_train)

        # Consistency: p(yes) should equal 1 - p(no).
        # Confidence: discourage the degenerate p(yes) = p(no) = 0.5.
        loss = ((p_pos - (1 - p_neg))**2).mean() + (torch.min(p_pos, p_neg)**2).mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        if (epoch + 1) % 25 == 0:
            print(f"Epoch {epoch+1}/100, Loss: {loss.item():.6f}")

    # Test
    probe.eval()
    with torch.no_grad():
        # squeeze(-1) keeps the sample axis even for a single test sample.
        p_pos = probe(pos_test).squeeze(-1)
        p_neg = probe(neg_test).squeeze(-1)
        probs = 0.5 * (p_pos + (1 - p_neg))
        preds = (probs > 0.5).numpy()

    n_correct = int((preds == truth_test).sum())
    # Resolve the sign ambiguity of the unsupervised probe.
    n_correct = max(n_correct, n_test - n_correct)
    acc = n_correct / n_test

    print(f"\nTest Accuracy: {acc:.1%} ({n_correct}/{n_test})")

    return acc


def main():
    """Run the pipeline: build contrast pairs, get hidden states, train the probe."""
    contrast_pairs = load_lightweight_samples(CONFIG)

    # Hidden states come from the on-disk cache when one is available.
    hidden_pos, hidden_neg, truth = extract_in_batches(contrast_pairs, CONFIG)

    accuracy = train_simple_ccs(hidden_pos, hidden_neg, truth)

    summary = (
        "COMPLETE!",
        f"Final accuracy: {accuracy:.1%}",
        "\nNext: Increase n_samples to 50-100 for better results",
    )
    for line in summary:
        print(line)


if __name__ == "__main__":
    main()


CCS ULTRA-LIGHTWEIGHT PIPELINE

🚀 OPTIMIZED VERSION - Using ViT-GPT2 (~500MB)
   Instead of LLaVA (~14GB) - 28x smaller!

- Processes in small batches
- Caches results
- Unloads model after extraction

LOADING DATA
Using 100 samples

EXTRACTING HIDDEN STATES (BATCH MODE)

✅ Found cached hidden states!
Loading from cache...

TRAINING CCS PROBE
Train: 5, Test: 3

Training...
Epoch 25/100, Loss: 0.192710
Epoch 50/100, Loss: 0.179391
Epoch 75/100, Loss: 0.156329
Epoch 100/100, Loss: 0.116346

✅ Test Accuracy: 33.3% (1/3)

✅ COMPLETE!
Final accuracy: 33.3%

Next: Increase n_samples to 50-100 for better results
Epoch 100/100, Loss: 0.116346

✅ Test Accuracy: 33.3% (1/3)

✅ COMPLETE!
Final accuracy: 33.3%

Next: Increase n_samples to 50-100 for better results
