# ELEC 475 Lab 4 - CLIP Visualization Generator (FIXED)

This notebook generates all required visualizations for Lab 4 Section 2.4:
1. Text→Image retrieval (including 'sport' and 'a dog playing')
2. Zero-shot image classification

**Features:**
- Clones Lab4 code from GitHub
- Downloads ONLY val2014 images (~6GB)
- **FIXED: Handles text_encoder key mismatch automatically**
- Handles different model architectures (base, batchnorm, dropout)
- Saves outputs to Google Drive

In [None]:
# Mount Google Drive: the checkpoint below is read from MyDrive and the
# generated figures can be copied back there at the end of the notebook.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Install the third-party packages imported by this notebook (transformers,
# torch, torchvision, pillow). tqdm and matplotlib are presumably used by the
# Lab4 helper modules (e.g. visualize.py), which are not visible here.
!pip install -q transformers torch torchvision tqdm pillow matplotlib

In [None]:
import os

if not os.path.exists('475_ML-CV_Labs'):
    !git clone https://github.com/Jcub05/475_ML-CV_Labs.git

os.chdir('/content/475_ML-CV_Labs/Lab4')
print(f"âœ“ Current directory: {os.getcwd()}")

In [None]:
import urllib.request
import zipfile
from pathlib import Path

data_dir = Path('/content/coco_data')
data_dir.mkdir(parents=True, exist_ok=True)
val_dir = data_dir / 'val2014'

# any() stops at the first .jpg instead of listing every file in the folder.
if not val_dir.exists() or not any(val_dir.glob('*.jpg')):
    val_url = 'http://images.cocodataset.org/zips/val2014.zip'
    val_zip = data_dir / 'val2014.zip'
    print(f"Downloading {val_url} (~6GB) ...")
    urllib.request.urlretrieve(val_url, val_zip)
    print(f"Extracting {val_zip} ...")
    with zipfile.ZipFile(val_zip, 'r') as z:
        z.extractall(data_dir)
    # Delete the archive only after extraction succeeded, to free disk space.
    val_zip.unlink()

num_val_images = sum(1 for _ in val_dir.glob('*.jpg'))
print(f"✓ Found {num_val_images} validation images")

In [None]:
import os

from google.colab import files

# CONFIGURE MODEL TYPE
# Must be one of the architectures handled by load_model_with_architecture().
MODEL_TYPE = 'batchnorm'  # Options: 'base', 'batchnorm', 'dropout', 'batchnorm_dropout'
VALID_MODEL_TYPES = ('base', 'batchnorm', 'dropout', 'batchnorm_dropout')

# UPLOAD OR SPECIFY MODEL PATH
# Option 1: Upload
# uploaded = files.upload()
# model_checkpoint_path = list(uploaded.keys())[0]

# Option 2: Google Drive path
model_checkpoint_path = '/content/drive/MyDrive/elec475_lab4/models/best_model_batch_norm.pth'

# Fail fast on configuration mistakes rather than deep inside model loading.
if MODEL_TYPE not in VALID_MODEL_TYPES:
    raise ValueError(f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected one of {VALID_MODEL_TYPES}")
if not os.path.isfile(model_checkpoint_path):
    raise FileNotFoundError(
        f"Checkpoint not found: {model_checkpoint_path} (is Google Drive mounted / the file uploaded?)"
    )

print(f"Model type: {MODEL_TYPE}")
print(f"Checkpoint: {model_checkpoint_path}")

In [None]:
import torch
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
from PIL import Image

from model import CLIPFineTuneModel
from model_modified import CLIPImageEncoderModified, CLIPFineTuneModelModified
from visualize import visualize_text_to_image_retrieval, zero_shot_classification, create_retrieval_grid

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Root folder for all figures; generate_visualizations adds one subfolder per model.
output_dir = Path('/content') / 'Visualizations'
output_dir.mkdir(exist_ok=True)

print(f"Device: {device}")

In [None]:
def get_clip_transform():
    """
    Build the image preprocessing pipeline used for every image in this notebook.

    Images are resized to exactly 224x224 (aspect ratio is not preserved),
    converted to a tensor, and normalized with CLIP's per-channel mean/std.
    """
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)
    pipeline = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=clip_mean, std=clip_std),
    ]
    return transforms.Compose(pipeline)

def _remap_text_encoder_keys(state_dict, expected_keys):
    """
    Rename ``text_encoder.text_model.*`` checkpoint keys to ``text_encoder.*``
    where the model expects the shorter name.

    A key is renamed only if its original name is NOT one of ``expected_keys``
    and the shortened name IS, so checkpoints that already match the model
    pass through unchanged. Key order is preserved.

    Returns:
        Tuple ``(remapped_state_dict, number_of_renamed_keys)``.
    """
    old_prefix = 'text_encoder.text_model.'
    new_prefix = 'text_encoder.'
    remapped = {}
    num_renamed = 0
    for key, value in state_dict.items():
        target_key = key
        if key.startswith(old_prefix) and key not in expected_keys:
            candidate = new_prefix + key[len(old_prefix):]
            if candidate in expected_keys:
                target_key = candidate
                num_renamed += 1
        remapped[target_key] = value
    return remapped, num_renamed


def load_model_with_architecture(model_path, model_type, device):
    """
    Build the requested architecture and load fine-tuned weights into it.

    Checkpoints saved from a model whose ``text_encoder`` wrapped an extra
    ``text_model`` level carry keys like ``text_encoder.text_model.X``. The
    modified architecture built here expects ``text_encoder.X``. Those keys are
    renamed automatically, but only when the model actually expects the
    shorter name.

    Args:
        model_path: Path to a ``torch.save`` checkpoint: either a dict holding
            ``'model_state_dict'`` or a bare state dict.
        model_type: One of ``'base'``, ``'batchnorm'``, ``'dropout'``,
            ``'batchnorm_dropout'``.
        device: Device the model and checkpoint tensors are placed on.

    Returns:
        The model with weights loaded, in eval mode.

    Raises:
        ValueError: If ``model_type`` is not a known architecture.
    """
    MODEL_CONFIGS = {
        'base': {'use_batchnorm': False, 'use_dropout': False, 'deeper_projection': False},
        'batchnorm': {'use_batchnorm': True, 'use_dropout': False, 'deeper_projection': False},
        'dropout': {'use_batchnorm': False, 'use_dropout': True, 'dropout_rate': 0.1, 'deeper_projection': False},
        'batchnorm_dropout': {'use_batchnorm': True, 'use_dropout': True, 'dropout_rate': 0.1, 'deeper_projection': False},
    }
    if model_type not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {list(MODEL_CONFIGS)}")

    print(f"\nLoading {model_type} model from {model_path}...")

    # Create model
    if model_type == 'base':
        model = CLIPFineTuneModel(
            embed_dim=512,
            pretrained_resnet=True,
            clip_model_name="openai/clip-vit-base-patch32",
            freeze_text_encoder=True
        ).to(device)
    else:
        config = MODEL_CONFIGS[model_type]
        image_encoder = CLIPImageEncoderModified(embed_dim=512, **config)
        # Only the text tower of the pretrained CLIP model is reused here.
        clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        model = CLIPFineTuneModelModified(
            image_encoder=image_encoder,
            text_encoder=clip_model.text_model,
            tokenizer=None
        ).to(device)

    # NOTE(review): from PyTorch 2.6, torch.load defaults to weights_only=True. A checkpoint
    # containing arbitrary pickled objects would need weights_only=False (trusted files only).
    checkpoint = torch.load(model_path, map_location=device)
    # Accept either a training checkpoint dict or a bare state dict.
    state_dict = checkpoint.get('model_state_dict', checkpoint)

    expected_keys = set(model.state_dict().keys())
    fixed_state_dict, num_fixed = _remap_text_encoder_keys(state_dict, expected_keys)
    if num_fixed > 0:
        print(f"✓ Fixed {num_fixed} text_encoder keys")

    # strict=False tolerates partial checkpoints, so report what did not line up
    # with a few examples; a missing image-encoder key means the model is not usable.
    missing, unexpected = model.load_state_dict(fixed_state_dict, strict=False)
    if missing:
        print(f"⚠ Missing keys: {len(missing)} (e.g. {list(missing)[:3]}) - verify these are frozen/pretrained layers")
    if unexpected:
        print(f"⚠ Unexpected keys: {len(unexpected)} (e.g. {list(unexpected)[:3]})")

    model.eval()
    print("✓ Model loaded successfully\n")
    return model

def precompute_image_embeddings(model, image_paths, transform, device, batch_size=32):
    """
    Encode every image in ``image_paths`` with ``model.encode_image``.

    Images are processed in batches of ``batch_size``. Each image is opened,
    converted to RGB, passed through ``transform``, and stacked into a batch on
    ``device``. Embeddings are moved to the CPU and concatenated in the same
    order as ``image_paths``, so row ``i`` corresponds to ``image_paths[i]``.

    Args:
        model: Object exposing ``eval()`` and ``encode_image(batch)``.
        image_paths: Sequence of image file paths.
        transform: Callable mapping a PIL image to a tensor.
        device: Device the image batches are sent to.
        batch_size: Number of images encoded per forward pass.

    Returns:
        CPU tensor whose first dimension is ``len(image_paths)``.

    Raises:
        ValueError: If ``image_paths`` is empty or ``batch_size`` < 1.
    """
    num_images = len(image_paths)
    if num_images == 0:
        raise ValueError("image_paths is empty; nothing to embed")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    print(f"Precomputing embeddings for {num_images} images...")
    all_embeds = []
    model.eval()

    with torch.no_grad():
        for batch_idx, start in enumerate(range(0, num_images, batch_size)):
            batch_paths = image_paths[start:start + batch_size]
            tensors = []
            for path in batch_paths:
                # The context manager closes the file handle deterministically;
                # convert() returns an independent, fully loaded copy.
                with Image.open(path) as img:
                    tensors.append(transform(img.convert('RGB')))
            images = torch.stack(tensors).to(device)
            all_embeds.append(model.encode_image(images).cpu())

            if (batch_idx + 1) % 10 == 0:
                print(f"  {start + len(batch_paths)}/{num_images}")

    all_embeds = torch.cat(all_embeds, dim=0)
    print(f"✓ Embeddings: {all_embeds.shape}")
    return all_embeds

class ModifiedModelWrapper:
    """
    Adapter that gives a model the ``eval`` / ``encode_text`` / ``encode_image``
    interface passed to the visualize helpers.

    When the wrapped model has a ``text_encoder`` attribute, text is encoded by
    calling that encoder directly. Otherwise the model's own ``encode_text`` is
    used.
    """

    def __init__(self, model, processor):
        self.model = model
        # Kept for callers; none of the methods below read it.
        self.processor = processor

    def eval(self):
        """Switch the wrapped model to eval mode and return this wrapper."""
        self.model.eval()
        return self

    def encode_text(self, input_ids, attention_mask):
        """
        Encode tokenized text without tracking gradients.

        With a ``text_encoder`` present, its ``pooler_output`` is L2-normalized
        along the last dimension. Otherwise the wrapped model's
        ``encode_text`` result is returned unchanged.
        """
        import torch.nn.functional as F

        with torch.no_grad():
            if not hasattr(self.model, 'text_encoder'):
                return self.model.encode_text(input_ids, attention_mask)
            # NOTE(review): pooler_output is used without CLIP's text_projection layer;
            # presumably this matches how the fine-tuned image encoder was trained - confirm.
            encoder_out = self.model.text_encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            pooled = encoder_out.pooler_output
            return F.normalize(pooled, p=2, dim=-1)

    def encode_image(self, images):
        """Delegate image encoding to the wrapped model."""
        return self.model.encode_image(images)

def generate_visualizations(model, model_name, image_paths, image_embeds, processor, transform, device, output_dir,
                            text_queries=None, class_labels=None, num_classification_examples=5,
                            num_grid_queries=4, top_k=5):
    """
    Write retrieval and zero-shot classification figures for one model.

    Produces, under ``output_dir / model_name``:
      * ``text2img_<query>.png`` for every query in ``text_queries``,
      * ``text2img_grid.png`` for the first ``num_grid_queries`` queries,
      * ``classification_example_<i>.png`` for the first
        ``num_classification_examples`` images of ``image_paths``.

    Args:
        model: Loaded model. ``CLIPFineTuneModelModified`` instances are wrapped
            in ``ModifiedModelWrapper`` before being handed to the helpers.
        model_name: Name of the output subfolder.
        image_paths: Gallery image paths, in the same order as ``image_embeds``.
        image_embeds: Precomputed gallery embeddings.
        processor: Passed through to the visualize helpers as ``clip_processor``.
        transform: Image transform passed to ``zero_shot_classification``.
        device: Device passed through to the visualize helpers.
        output_dir: Root directory for all figures.
        text_queries: Retrieval queries. Defaults to the lab's required set,
            which includes 'sport' and 'a dog playing'.
        class_labels: Zero-shot class prompts. Defaults to person/animal/landscape.
        num_classification_examples: How many gallery images, taken from the
            start of ``image_paths``, to classify.
        num_grid_queries: How many of ``text_queries`` appear in the grid figure.
        top_k: Images retrieved per query, for both single-query figures and the grid.
    """
    import re

    if text_queries is None:
        text_queries = ["sport", "a dog playing", "a person eating", "a beautiful sunset", "a cat on a couch"]
    if class_labels is None:
        class_labels = ['a person', 'an animal', 'a landscape']
    text_queries = list(text_queries)
    class_labels = list(class_labels)

    print(f"\nGenerating visualizations for: {model_name}")
    model_output_dir = output_dir / model_name
    model_output_dir.mkdir(parents=True, exist_ok=True)

    # The modified architecture is wrapped so text is encoded via its text_encoder.
    if isinstance(model, CLIPFineTuneModelModified):
        wrapped_model = ModifiedModelWrapper(model, processor)
    else:
        wrapped_model = model

    # Text-to-Image
    for query in text_queries:
        print(f"  Query: '{query}'")
        # Replace filename-unsafe runs (spaces, '/', punctuation) with '_'.
        safe_query = re.sub(r'[^\w-]+', '_', query)
        visualize_text_to_image_retrieval(
            query_text=query,
            model=wrapped_model,
            image_paths=image_paths,
            image_embeds=image_embeds,
            clip_processor=processor,
            device=device,
            top_k=top_k,
            save_path=model_output_dir / f"text2img_{safe_query}.png"
        )

    # Grid
    create_retrieval_grid(
        queries=text_queries[:num_grid_queries],
        model=wrapped_model,
        image_paths=image_paths,
        image_embeds=image_embeds,
        clip_processor=processor,
        device=device,
        images_per_query=top_k,
        save_path=model_output_dir / "text2img_grid.png"
    )

    # Zero-shot classification
    examples = image_paths[:num_classification_examples]
    for idx, img_path in enumerate(examples, start=1):
        # Report against the actual count; the gallery may hold fewer images.
        print(f"  Classifying image {idx}/{len(examples)}")
        predicted_class, confidence = zero_shot_classification(
            query_image_path=img_path,
            class_labels=class_labels,
            model=wrapped_model,
            clip_processor=processor,
            transform=transform,
            device=device,
            save_path=model_output_dir / f"classification_example_{idx}.png"
        )
        print(f"    → {predicted_class} ({confidence*100:.1f}%)")

    print(f"✓ Saved to: {model_output_dir}")

# Confirms the helper-definition cell ran to completion.
print("✓ Functions loaded")

In [None]:
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
transform = get_clip_transform()

# Retrieval gallery: the first NUM_GALLERY_IMAGES val2014 images in filename
# order, so results are reproducible across runs.
NUM_GALLERY_IMAGES = 1000
all_image_paths = sorted(val_dir.glob("*.jpg"))[:NUM_GALLERY_IMAGES]
print(f"✓ Using {len(all_image_paths)} images")

In [None]:
# LOAD MODEL (with automatic key fixing)
model = load_model_with_architecture(model_checkpoint_path, MODEL_TYPE, device)

# Generate embeddings for the gallery (row i corresponds to all_image_paths[i])
image_embeds = precompute_image_embeddings(model, all_image_paths, transform, device)

# Generate visualizations
generate_visualizations(
    model=model,
    model_name=MODEL_TYPE,
    image_paths=all_image_paths,
    image_embeds=image_embeds,
    processor=processor,
    transform=transform,
    device=device,
    output_dir=output_dir
)

# Drop references to the model/embeddings and release cached GPU memory.
del model, image_embeds
torch.cuda.empty_cache()
print("\n✅ Complete!")

In [None]:
# Download results
import shutil

archive_name = f'/content/Lab4_Visualizations_{MODEL_TYPE}'
# make_archive appends the '.zip' suffix itself and returns the path it wrote.
archive_path = shutil.make_archive(archive_name, 'zip', output_dir)
files.download(archive_path)
print("✓ Download started!")

In [None]:
# Copy to Drive (optional)
import os
import shutil

# Without a mounted Drive, copytree would silently create this path on the
# runtime's local disk (lost when the session ends) and still report success.
if not os.path.isdir('/content/drive/MyDrive'):
    raise RuntimeError("Google Drive is not mounted; run the drive.mount cell first")

drive_dir = f'/content/drive/MyDrive/Lab4_Visualizations_{MODEL_TYPE}'
# Replace any previous export so stale figures from earlier runs don't linger.
if os.path.exists(drive_dir):
    shutil.rmtree(drive_dir)
shutil.copytree(output_dir, drive_dir)
print(f"✓ Copied to: {drive_dir}")