# 1) In this project, we will learn to use OpenAI's CLIP model for an image retrieval task.
# 2) We will also learn how to fine-tune the CLIP model on the MNIST dataset to improve its performance.

### 1. Importing the required packages

In [None]:
# --- Core ML & Data ---
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import clip
import numpy as np
from PIL import Image, ImageOps

# --- Plotting & Visualization ---
import matplotlib.pyplot as plt
import seaborn as sns
import umap

# --- App & Utilities ---
import gradio as gr
from tqdm import tqdm

### 2. Load Base CLIP Model & Print Info

In [None]:
print("Loading original pre-trained CLIP model (ViT-B/32)...")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fetch the pre-trained weights together with the matching image
# preprocessing pipeline, placing the model directly on the chosen device.
model, preprocess = clip.load("ViT-B/32", device=device)
# torch.save(model, 'clip_vit_b32.pth') # For saving the model locally.

# This part of the notebook only runs inference.
model.eval()

# --- Report basic model facts ---
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
total_params = sum(int(np.prod(param.shape)) for param in model.parameters())

print(f"Model parameters:   {total_params:,}")
print(f"Input resolution:   {input_resolution}x{input_resolution}")
print(f"Context length:     {context_length} tokens")
print(f"Vocabulary size:    {vocab_size}")

### 3. Loading the MNIST dataset and creating a dataloader

In [None]:
# MNIST *test* split (10,000 images). CLIP's own preprocessing is attached as
# the transform, so every sample comes out as a model-ready tensor.
mnist_data = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=preprocess,
)

# Fixed-order batches (no shuffling), so embedding row i corresponds to test image i.
batch_size = 128
mnist_loader = DataLoader(mnist_data, shuffle=False, batch_size=batch_size)

### 4. About the MNIST dataset (Sample Visualizations)

In [None]:
# Un-transformed copy of the test split: samples are PIL images we can plot directly.
original_mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=None)

print("These are grayscale images of handwritten digits (0-9).")
print(f"Number of samples in MNIST test set: {len(original_mnist_test)}")
# The first sample, to show the raw image type/size before any preprocessing.
image, label = original_mnist_test[0]
print(f"Original image size: {image.size}")

# A row of five raw samples with their labels.
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
fig.suptitle('Sample Images from MNIST Test Set', fontsize=16)

for i, ax in enumerate(axes):
    image, label = original_mnist_test[i]
    ax.imshow(image, cmap='gray')  # imshow accepts PIL images directly
    ax.set_title(f"Label: {label}")
    ax.axis('off')

plt.show()

### 5. Calculate CLIP image embeddings

In [None]:
import os  # for creating the output directory

print("Starting embedding generation for 10,000 test images...")

all_embeddings_list = []
all_labels_list = []

# Inference only: torch.no_grad() skips autograd bookkeeping (less memory, faster).
with torch.no_grad():
    # The DataLoader yields already-preprocessed batches in dataset order.
    for images, labels in tqdm(mnist_loader):
        images = images.to(device)

        # One forward pass embeds the whole batch.
        batch_embeddings = model.encode_image(images)

        # Keep results on the CPU so GPU memory does not grow with the dataset.
        all_embeddings_list.append(batch_embeddings.cpu())
        all_labels_list.append(labels.cpu())

print("Embedding generation complete.")

# --- Concatenate and Save ---
print("Concatenating all batches...")

# Stack the per-batch tensors into one tensor each (row i = test image i).
final_embeddings = torch.cat(all_embeddings_list)
final_labels = torch.cat(all_labels_list)

print(f"Final embeddings shape: {final_embeddings.shape}")
print(f"Final labels shape: {final_labels.shape}")

# Save embeddings and labels for the app. torch.save does not create missing
# directories, so make sure 'checkpoints/' exists first.
os.makedirs('checkpoints', exist_ok=True)
torch.save(final_embeddings, 'checkpoints/mnist_clip_image_embeddings_test.pt')
torch.save(final_labels, 'checkpoints/mnist_labels_test.pt')

print("Embeddings and labels saved to 'checkpoints/'.")

# TESTING TIME.....

### 6. Loading embeddings and reducing dimensionality using UMAP

In [None]:
print("Loading pre-computed test set embeddings...")
# 1) Embeddings and labels written by the embedding cell (test split).
loaded_embeddings = torch.load('checkpoints/mnist_clip_image_embeddings_test.pt', map_location=device)
loaded_labels = torch.load('checkpoints/mnist_labels_test.pt', map_location=device)

# 2) The same test split as raw PIL images, used to display retrieved results.
#    It must be the split the embeddings came from so that indices line up.
original_mnist = datasets.MNIST(root='./data', train=False, download=True)

# 3) NumPy copies on the CPU for UMAP.
print("Pre-computing UMAP projection for 10,000 test images...")
embeddings_np = loaded_embeddings.cpu().numpy()
labels_np = loaded_labels.cpu().numpy()
# All points are projected; the plotting code reads these two names.
umap_embeddings_np, umap_labels_np = embeddings_np, labels_np

# 4) Fit a 2-D UMAP reducer once. The fitted object is kept so query
#    embeddings can later be projected into the same 2-D space.
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    metric='cosine',
    random_state=42,
    verbose=False,
)
print("Fitting UMAP reducer...")
umap_embeddings_2d = reducer.fit_transform(umap_embeddings_np)
print("UMAP pre-computation complete.")

# 5) Copy of the embeddings on `device`, used by the search functions.
all_features_gpu = loaded_embeddings.to(device)

print("App setup is complete.")

### 7. Defining a function to calculate similarity

In [None]:
# Helper function to calculate cosine similarity
def calculate_similarity(feature1, feature2):
    """Return the pairwise cosine-similarity matrix between two sets of vectors.

    Each row of ``feature1`` and ``feature2`` is scaled to unit length along
    the last dimension. The dot product of unit vectors is their cosine
    similarity, so ``result[i, j]`` compares row i of ``feature1`` with row j
    of ``feature2``.
    """
    def _unit(vectors):
        # Divide every vector by its own L2 norm.
        return vectors / vectors.norm(dim=-1, keepdim=True)

    return _unit(feature1) @ _unit(feature2).T

### 8. Function for Text-to-Image Search 🔎

In [None]:
def text_search_and_plot(digit_text, top_k=3):
    """
    Retrieve the MNIST test images whose CLIP embeddings best match a digit prompt.

    The selected digit ("0"-"9") is spelled out and placed in the prompt
    "A handwritten digit <word>". The prompt is encoded with CLIP's text
    encoder and compared with every pre-computed image embedding in
    ``all_features_gpu`` by cosine similarity.

    Args:
        digit_text (str): The digit chosen in the UI, e.g. "7".
        top_k (int): Number of best-matching images to return.

    Returns:
        tuple: (list of PIL images for the top matches, Matplotlib figure of the
        UMAP projection with the text query overlaid), or ([], None) when the
        input is empty or not a single digit character.
    """
    if not digit_text:
        print("Please provide a digit text (0-9).")
        return [], None

    # "0" -> "zero", ..., "9" -> "nine"; anything else maps to "".
    number_words = ["zero", "one", "two", "three", "four",
                    "five", "six", "seven", "eight", "nine"]
    word_for = {str(n): word for n, word in enumerate(number_words)}
    digit_word = word_for.get(digit_text, "")

    if not digit_word:
        print(f"Invalid digit text: {digit_text}")
        return [], None

    print(f"Searching for text: A handwritten digit {digit_word}")

    # Encode the prompt with CLIP's text tower.
    prompt = f"A handwritten digit {digit_word}"
    with torch.no_grad():
        text_features = model.encode_text(clip.tokenize([prompt]).to(device))

    # Rank all image embeddings by cosine similarity to the prompt.
    scores = calculate_similarity(text_features, all_features_gpu).squeeze(0)
    top_indices = torch.topk(scores, top_k).indices.cpu().numpy().tolist()
    print(f"Top {top_k} similar images indices: {top_indices}")

    retrieved_labels = [labels_np[i].item() for i in top_indices]
    print(f"Labels for top indices: {retrieved_labels}")

    # Raw test images share indices with the embeddings.
    result_images = [original_mnist[i][0] for i in top_indices]

    # --- Plot: every image embedding in 2-D UMAP space plus the text query ---
    query_xy = reducer.transform(text_features.cpu().numpy())

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111)

    sns.scatterplot(
        x=umap_embeddings_2d[:, 0],
        y=umap_embeddings_2d[:, 1],
        hue=umap_labels_np,
        palette=sns.color_palette("tab10", 10),
        s=10, alpha=0.6,
        ax=ax
    )

    # Drawn after the seaborn scatter, so its legend handle comes last.
    ax.scatter(
        query_xy[0, 0], query_xy[0, 1],
        marker='d', color='black', s=100,
        label=f'Text: "{prompt}"'
    )

    ax.set_title('UMAP Projection (Image Embeddings + Text Query)')
    ax.set_xlabel('UMAP Component 1')
    ax.set_ylabel('UMAP Component 2')
    ax.grid(True, linestyle='--', alpha=0.5)

    # Enlarge the tiny per-digit markers in the legend (at most 10). The last
    # handle, the query marker, is left at its size.
    handles, labels = ax.get_legend_handles_labels()
    for handle in handles[:min(10, len(handles) - 1)]:
        if hasattr(handle, 'set_sizes'):
            handle.set_sizes([handle.get_sizes()[0] * 6])
    ax.legend(handles, labels, title='Digit Label / Text', loc='center left', bbox_to_anchor=(1.05, 0.5))

    # Leave room on the right for the outside legend; close so figures don't accumulate.
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.close(fig)

    return result_images, fig

### 9. Function for Sketch-to-Image Search ✏️

In [None]:
# --- Function for Sketch Preprocessing ---
def preprocess_sketch(sketch_input_data):
    """
    Process a user sketch (from Gradio) so it resembles an MNIST digit.

    MNIST digits are white strokes on a black background. Each is
    size-normalised into a 20x20 box and placed roughly in the centre of a
    28x28 frame (a 4-pixel margin). This function reproduces that layout:
      1. Flatten any transparency onto white.
      2. Convert to grayscale and invert.
      3. Crop to the drawn strokes and pad to a centred square.
      4. Shrink to 20x20 and add a 4-pixel black border.

    Args:
        sketch_input_data (np.ndarray or dict): The sketch data from Gradio.
            For a dict, an 'image' or (failing that) 'composite' NumPy array is used.
    Returns:
        PIL.Image or None: A 28x28 grayscale image ready for CLIP's preprocess.
        An all-black 28x28 image is returned if nothing was drawn, and None
        if the input is invalid.
    """
    sketch_array = None # Initialize sketch_array

    if sketch_input_data is None:
        print("Sketch input is None.")
        return None

    # Check if input is a dictionary and extract the numpy array
    if isinstance(sketch_input_data, dict):
        print("Sketch input is a dict. Keys:", sketch_input_data.keys()) # DEBUG: See keys
        if 'image' in sketch_input_data and isinstance(sketch_input_data['image'], np.ndarray):
            sketch_array = sketch_input_data['image']
        elif 'composite' in sketch_input_data and isinstance(sketch_input_data['composite'], np.ndarray):
            sketch_array = sketch_input_data['composite']
        else:
            print("Could not find a valid NumPy array in the sketch dictionary.")
            return None # Cannot proceed
    elif isinstance(sketch_input_data, np.ndarray):
        sketch_array = sketch_input_data # Input is already the array
    else:
        print(f"Unexpected sketch input type: {type(sketch_input_data)}")
        return None

    if sketch_array is None:
        print("Failed to extract sketch array.")
        return None
    if sketch_array.size == 0:
        print("Received empty sketch array.")
        return None

    try:
        pil_image = Image.fromarray(sketch_array)
        # Sketch composites may carry an alpha channel. Converting RGBA -> L
        # ignores alpha, so transparent pixels with black RGB would become
        # indistinguishable from black strokes. Composite onto white first;
        # fully opaque images are unchanged by this.
        if pil_image.mode in ("RGBA", "LA"):
            white = Image.new("RGBA", pil_image.size, (255, 255, 255, 255))
            pil_image = Image.alpha_composite(white, pil_image.convert("RGBA"))
        pil_image = pil_image.convert("L")
    except Exception as e:
        print(f"Error in Image.fromarray: {e}")
        print(f"Sketch array shape: {sketch_array.shape}, dtype: {sketch_array.dtype}")
        return None

    # Dark strokes on a light canvas -> white strokes on black, like MNIST.
    pil_image = ImageOps.invert(pil_image)

    # Bounding box of the non-zero (drawn) pixels.
    bbox = pil_image.getbbox()
    if bbox is None:
        print("No content found in the sketch (bounding box is None).")
        return Image.new('L', (28, 28), color=0)

    pil_image = pil_image.crop(bbox)

    # Pad the crop to a square (black fill) so resizing keeps the aspect ratio.
    width, height = pil_image.size
    target_size = max(width, height)
    padding_left = (target_size - width) // 2
    padding_top = (target_size - height) // 2
    padding_right = target_size - width - padding_left
    padding_bottom = target_size - height - padding_top
    padding = (padding_left, padding_top, padding_right, padding_bottom)
    pil_image = ImageOps.expand(pil_image, padding, fill=0)

    # MNIST layout: digit in a 20x20 box, 4-pixel black margin -> 28x28 total.
    pil_image = pil_image.resize((20, 20), Image.Resampling.LANCZOS)
    pil_image = ImageOps.expand(pil_image, border=4, fill=0)

    return pil_image

# --- Function for Sketch Search AND Plotting ---
def sketch_search_and_plot(sketch_image, top_k=3):
    """
    Find the top_k MNIST test images most similar to the user's sketch.

    The sketch is converted to an MNIST-like image by preprocess_sketch, run
    through CLIP's preprocess + image encoder, and ranked against
    ``all_features_gpu`` by cosine similarity. Results are looked up in
    ``original_mnist`` by index.

    Args:
        sketch_image (np.ndarray or dict): Sketch data from Gradio.
        top_k (int): Number of results.
    Returns:
        tuple: (List of result PIL Images, Matplotlib Figure object or None).
        Returns ([], None) when there is no sketch or preprocessing fails; the
        figure alone is None if plotting raises.
    """
    # Check 1: Input is None (preprocess_sketch also handles this; kept as an early exit)
    if sketch_image is None:
        print("Please provide a sketch image.")
        return [], None

    print("Processing sketch...")
    # --- 1. Preprocess the Sketch ---
    # Custom MNIST-style cleanup (grayscale, invert, crop, pad, resize)
    preprocessed_sketch_pil = preprocess_sketch(sketch_image)

    # Check 2: Preprocessing failed
    if preprocessed_sketch_pil is None:
        print("Sketch preprocessing failed.")
        return [], None

    # --- 2. Encode the Processed Sketch using CLIP ---
    # Apply CLIP's standard preprocessing AFTER our custom preprocessing,
    # then add a batch dimension of 1 for the encoder.
    clip_input_tensor = preprocess(preprocessed_sketch_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        # Inference only; no gradients needed
        sketch_features = model.encode_image(clip_input_tensor)

    # --- 3. Search (using all_features_gpu) ---
    print("Calculating similarities...")
    # similarities has one row (the sketch); squeeze it to a 1-D score vector
    similarities = calculate_similarity(sketch_features, all_features_gpu)
    _, top_indices = torch.topk(similarities.squeeze(0), top_k)
    top_indices = top_indices.cpu().numpy().tolist()
    print(f"Top sketch indices: {top_indices}")

    # Raw test images share indices with the embeddings
    result_images = [original_mnist[idx][0] for idx in top_indices]

    # --- 4. Generate Plot ---
    # Best-effort: a plotting failure still returns the retrieved images.
    try:
        sketch_features_cpu = sketch_features.cpu().numpy()
        # Project the sketch into the 2-D space of the fitted UMAP reducer
        sketch_embedding_2d = reducer.transform(sketch_features_cpu)

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111)

        # Background: all test-image embeddings coloured by digit label
        sns.scatterplot(
            x=umap_embeddings_2d[:, 0], y=umap_embeddings_2d[:, 1],
            hue=umap_labels_np, palette=sns.color_palette("tab10", 10),
            s=10, alpha=0.6, ax=ax
        )
        # Drawn after the seaborn scatter, so its legend handle comes last
        ax.scatter(
            sketch_embedding_2d[0, 0], sketch_embedding_2d[0, 1],
            marker='d', color='black', s=100, label='Your Sketch'
        )
        ax.set_title('UMAP Projection (Image Embeddings + Sketch Query)')
        ax.set_xlabel('UMAP Component 1')
        ax.set_ylabel('UMAP Component 2')
        ax.grid(True, linestyle='--', alpha=0.5)

        # Enlarge the per-digit legend markers (at most 10), skipping the last
        # handle (the sketch marker). Relies on the draw order above.
        handles, labels = ax.get_legend_handles_labels()
        dot_marker_size_scale = 6
        for i in range(min(10, len(handles) -1)):
             if hasattr(handles[i], 'set_sizes'):
                 current_size = handles[i].get_sizes()[0]
                 handles[i].set_sizes([current_size * dot_marker_size_scale])
        ax.legend(handles, labels, title='Digit Label / Sketch', loc='center left', bbox_to_anchor=(1.05, 0.5))

        # Leave room for the outside legend; close so figures don't accumulate
        plt.tight_layout(rect=[0, 0, 0.85, 1])
        plt.close(fig)
        plot_output = fig
    except Exception as e:
        print(f"Error generating plot for sketch: {e}")
        plot_output = None # Return None if plotting fails

    return result_images, plot_output

## 🚀 Launching the App!

In [None]:
def clear_sketch():
    """Return an empty value; wired below as the Sketchpad's "Clear" handler, this resets the canvas."""
    empty_canvas = None
    return empty_canvas

# --- Gradio Interface ---
# Two tabs share the same search back end: a digit picker for text queries and
# a sketchpad for drawn queries. Each search fills a gallery and a UMAP plot.
with gr.Blocks(theme=gr.themes.Soft()) as app_interface:
    gr.Markdown(
        """
        # CLIP MNIST Search Engine 🧠🖼️✏️
        Use CLIP embeddings to search the MNIST dataset using text prompts or sketches.
        """
    )

    with gr.Tabs():
        # --- Text Search Tab ---
        with gr.TabItem("Text Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Radio(
                        choices=[str(i) for i in range(10)],
                        label="Select a Digit",
                        info="Choose the digit you want to search for.",
                    )
                    submit_btn_text = gr.Button("Search Text", variant="primary")
                with gr.Column(scale=3):
                    gallery_text = gr.Gallery(
                        label="Search Results", columns=3, object_fit="contain",
                        height=250, preview=True,
                    )
                    plot_text = gr.Plot(label="UMAP Visualization (Text Query)")

        # --- Sketch Search Tab ---
        with gr.TabItem("Sketch Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    # type="numpy": the handler receives NumPy data (possibly
                    # wrapped in a dict, which preprocess_sketch handles).
                    sketch_input = gr.Sketchpad(
                        label="Draw a Digit (0-9)",
                        type="numpy",
                    )
                    submit_btn_sketch = gr.Button("Search Sketch", variant="primary")
                    clear_btn_sketch = gr.Button("Clear Sketch")
                with gr.Column(scale=3):
                    gallery_sketch = gr.Gallery(
                        label="Search Results", columns=3, object_fit="contain",
                        height=250, preview=True,
                    )
                    plot_sketch = gr.Plot(label="UMAP Visualization (Sketch Query)")

    # --- Event Handlers ---
    # Each search returns (gallery images, figure) matching its outputs list.
    submit_btn_text.click(
        fn=text_search_and_plot,
        inputs=text_input,
        outputs=[gallery_text, plot_text],
        show_progress="full",
    )
    submit_btn_sketch.click(
        fn=sketch_search_and_plot,
        inputs=sketch_input,
        outputs=[gallery_sketch, plot_sketch],
        show_progress="full",
    )

    # clear_sketch returns None, which empties the sketchpad.
    clear_btn_sketch.click(fn=clear_sketch, inputs=None, outputs=sketch_input)

# Launch the app
if __name__ == "__main__":
    app_interface.launch()

# <span style="color:green;">Fine-tuning CLIP on MNIST dataset</span> 🔥

### 1. Importing the required libraries

In [None]:
# --- Core ML & Data ---
import torch
from torchvision import datasets
from torch.utils.data import DataLoader
import clip
import numpy as np
from PIL import Image, ImageOps

# --- Plotting & Visualization ---
import matplotlib.pyplot as plt
import seaborn as sns
import umap

# --- App & Utilities ---
import gradio as gr
from tqdm import tqdm

### 2. Preparing custom image-text paired dataset

In [None]:
# CLIP needs (image, text) pairs for fine-tuning.
# A custom Dataset class is created here to wrap the standard MNIST dataset, which only provides (image, label).
# This class will generate the required text prompts on the fly.

# `Dataset` is not imported by the import cell above; bring it into scope here
# so the class definition below does not raise NameError.
from torch.utils.data import Dataset


class PairedMNISTDataset(Dataset):
    """
    Wrap an (image, label) MNIST dataset as (preprocessed_image, tokenized_text) pairs.

    Args:
        mnist_dataset: Indexable dataset yielding (PIL image, label) pairs,
            e.g. torchvision's MNIST created with transform=None.
        clip_preprocess: CLIP's image preprocessing callable (from clip.load()).
    """

    def __init__(self, mnist_dataset, clip_preprocess):
        self.mnist = mnist_dataset
        self.preprocess = clip_preprocess

        # Maps numeric labels (e.g., 0) to their text words (e.g., "zero")
        # to create meaningful text prompts.
        self.digit_map = {
            0: "zero", 1: "one", 2: "two", 3: "three", 4: "four",
            5: "five", 6: "six", 7: "seven", 8: "eight", 9: "nine"
        }

    def __len__(self):
        # Same length as the wrapped MNIST dataset.
        return len(self.mnist)

    def __getitem__(self, idx):
        """Return (CLIP-preprocessed image tensor, tokenized prompt) for item ``idx``."""
        # 1. The original PIL image and its numeric label.
        image, label = self.mnist[idx]

        # 2. CLIP preprocessing (resize, 3-channel conversion, normalization)
        #    works directly on the PIL image.
        preprocessed_image = self.preprocess(image)

        # 3. Prompt text, e.g. 7 -> "a handwritten digit seven".
        #    int() also accepts 0-d tensor labels, which would never match the
        #    int dict keys (tensors hash by identity).
        digit_word = self.digit_map[int(label)]
        prompt = f"a handwritten digit {digit_word}"

        # 4. clip.tokenize() returns a batch; take its single row.
        tokenized_text = clip.tokenize([prompt])[0]

        return preprocessed_image, tokenized_text

### 3. Define the CLIP Loss Function

In [None]:
def calculate_clip_loss(logits_per_image, logits_per_text, device=None):
    """
    Calculate the symmetric CLIP (contrastive) loss.

    For a batch of N matched pairs, image i belongs with text i, so the target
    class for row i is i in both directions.

    Args:
        logits_per_image: Logits with one row per image (expected shape (N, N)).
        logits_per_text: Logits with one row per text (expected shape (N, N)).
        device: Device for the ground-truth labels. Defaults to the device of
            ``logits_per_image`` so labels and logits always agree.

    Returns:
        Scalar tensor: mean of the image->text and text->image cross-entropies.
    """
    # torch.nn.functional is not imported at file level.
    import torch.nn.functional as F

    if device is None:
        device = logits_per_image.device

    # Ground-truth labels [0, 1, ..., N-1]: the i-th image matches the i-th text.
    batch_size = logits_per_image.shape[0]
    ground_truth = torch.arange(batch_size, dtype=torch.long, device=device)

    # Cross-entropy in both directions, then average.
    loss_img = F.cross_entropy(logits_per_image, ground_truth)
    loss_txt = F.cross_entropy(logits_per_text, ground_truth)
    return (loss_img + loss_txt) / 2

### 4. Loading model, mnist data, and optimizer

In [None]:
# --- Names this cell needs that the import cell above does not provide ---
import os
from torchvision.datasets import MNIST

# --- Hyperparameters ---
device = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
EPOCHS = 20  # Set desired number of epochs
LR = 1e-5    # Using the stabler 1e-5 learning rate. 1e-4 is very high.

# --- Load the model to fine-tune ---
# Resume from a previous fine-tuning run if a checkpoint exists; otherwise
# (e.g. the very first run) start from the original pre-trained weights.
model_path = "checkpoints/finetuned_mnist_clip.pt"

# 1. Load the original "ViT-B/32" architecture
#    (jit=False is required to load a state_dict)
model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

# 2. Load fine-tuned weights into that architecture, when available
if os.path.exists(model_path):
    print("Loading fine-tuned model from checkpoint...")
    model.load_state_dict(torch.load(model_path, map_location=device))
    print(f"Successfully loaded weights from {model_path}")
else:
    print(f"No checkpoint found at {model_path}; starting from pre-trained CLIP weights.")

# 3. Train in float32. On CUDA, clip.load() returns float16 weights (it only
#    casts to float32 on CPU), and optimizer updates on fp16 parameters are
#    prone to underflow/overflow and NaN losses.
model = model.float()

# --- Load MNIST Training Data ---
# Raw PIL images (transform=None); the paired dataset applies CLIP preprocessing.
# Same root as the rest of the notebook, so the download is shared.
mnist_train = MNIST(root="./data", train=True, download=True, transform=None)

# --- Create Paired Dataset and DataLoader ---
paired_dataset = PairedMNISTDataset(mnist_dataset=mnist_train, clip_preprocess=preprocess)

train_loader = DataLoader(
    paired_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True  # Shuffle=True is crucial for training
)

# --- Define the Optimizer ---
print("Freezing logit_scale and creating SGD optimizer...")
# 1. Freeze the logit_scale (temperature) to prevent training instability
model.logit_scale.requires_grad = False

# 2. SGD with momentum over the trainable parameters only
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR,
    momentum=0.9
)

print("Setup complete. Ready for training.")

### 5. Running the Fine-Tuning Loop 🚀

In [None]:
# --- 1. Float32 weights + frozen logit_scale (run before building the optimizer) ---
# clip.load() yields float16 weights on CUDA; fp16 parameter updates are the
# usual cause of NaN losses. Converting is a no-op if already float32.
# (An autocast(dtype=float32) context does not convert fp16 parameters.)
model = model.float()
print("Freezing logit_scale parameter...")
model.logit_scale.requires_grad = False

# --- 2. Optimizer: SGD with momentum ---
# NOTE: this deliberately replaces the LR and optimizer from the setup cell.
LR = 1e-4
print(f"Using SGD optimizer with LR={LR}")
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()), # Trainable params only
    lr=LR,
    momentum=0.9 # Standard momentum
)

# --- 3. RUN THE TRAINING LOOP ---
print("Starting fine-tuning with SGD...")
model.train() # Set model to training mode

stop_training = False
for epoch in range(EPOCHS):
    print(f"--- Epoch {epoch + 1}/{EPOCHS} ---")
    running_loss = 0.0
    num_batches = 0

    for images, texts in tqdm(train_loader):
        images = images.to(device)
        texts = texts.to(device)

        logits_per_image, logits_per_text = model(images, texts)
        loss = calculate_clip_loss(logits_per_image, logits_per_text, device)

        # Stop on NaN *or* inf before the bad gradients reach the weights.
        if not torch.isfinite(loss):
            print("!!! Loss became non-finite (NaN/inf). Stopping training. !!!")
            stop_training = True
            break

        # --- Backward Pass & Optimize ---
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping guards against occasional large updates
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if stop_training:
        print("Epoch stopped early due to non-finite loss.")
        break

    if num_batches == 0:
        print(f"Epoch {epoch + 1} finished. No batches were processed.")
    else:
        print(
            f"Epoch {epoch + 1} finished. Mean loss: {running_loss / num_batches:.4f}, "
            f"last batch loss: {loss.item():.4f}"
        )

print("Fine-tuning complete.")

### 6. Saving the fine-tuned model

In [None]:
import os  # for creating the output directory

# --- Save the fine-tuned model's weights ---
# This is the same path the setup cell resumes from, so an existing
# checkpoint is overwritten. torch.save does not create missing directories.
save_path = "checkpoints/finetuned_mnist_clip.pt"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
print(f"Fine-tuned model saved to {save_path}")

### 7. Running the app with the fine-tuned model 🔥

In [None]:
# --- 0. Names this cell needs beyond the shared import cell ---
from torchvision.datasets import MNIST

# --- 1. Setup ---
print("Setting up model and loading data...")
device = "cuda" if torch.cuda.is_available() else "cpu"
finetuned_model_path = "checkpoints/finetuned_mnist_clip.pt"

# --- 2. Load the Fine-Tuned Model ---
# Build the base architecture on the CPU (jit=False so a state_dict can be
# loaded); on CPU, clip.load() gives float32 weights.
model, preprocess = clip.load("ViT-B/32", device="cpu", jit=False)
# Load the fine-tuned weights into that architecture
model.load_state_dict(torch.load(finetuned_model_path, map_location="cpu"))
# Move the fully-loaded model to the target device and switch to inference mode
model = model.to(device)
model.eval()
print(f"Successfully loaded fine-tuned model from {finetuned_model_path}")

# --- 3. Load MNIST Test Data ---
# Test split with CLIP preprocessing, for generating new embeddings.
# Same root as the rest of the notebook, so the download is shared.
mnist_test_transformed = MNIST(root="./data", train=False, download=True, transform=preprocess)
# Raw test split (PIL images) for displaying results; indices match the embeddings
original_mnist_test = MNIST(root="./data", train=False, download=True)

# --- 4. Generate NEW Embeddings from the Fine-Tuned Model ---
# Embeddings from the base model are stale; re-embed with the new weights.
print("Generating new embeddings for the 10,000 test images...")
batch_size = 128
data_loader = DataLoader(mnist_test_transformed, batch_size=batch_size, shuffle=False)

all_embeddings_list = []
all_labels_list = []

with torch.no_grad():
    for images, labels in tqdm(data_loader):
        images = images.to(device)
        batch_embeddings = model.encode_image(images)
        all_embeddings_list.append(batch_embeddings.cpu())
        all_labels_list.append(labels.cpu())

# Concatenate all batches into single tensors (row i = test image i)
all_features_cpu = torch.cat(all_embeddings_list)
all_labels_cpu = torch.cat(all_labels_list)
all_features_gpu = all_features_cpu.to(device) # For fast search
print("New test embeddings generated.")

# --- 5. Pre-compute UMAP (Run Once at Startup) ---
print("Pre-computing UMAP projection...")
embeddings_np = all_features_cpu.numpy()
labels_np = all_labels_cpu.numpy()
umap_embeddings_np = embeddings_np # Use all 10k points
umap_labels_np = labels_np

# The fitted reducer is reused to project query embeddings into the same 2-D space.
reducer = umap.UMAP(
    n_neighbors=15, min_dist=0.1, n_components=2,
    metric='cosine', random_state=42, verbose=False
)
print("Fitting UMAP reducer...")
umap_embeddings_2d = reducer.fit_transform(umap_embeddings_np)
print("UMAP pre-computation complete.")

# --- 6. Helper & Search Functions ---

def calculate_similarity(features1, features2):
    """Cosine similarity between every row of ``features1`` and every row of ``features2``."""
    # Unit-length rows: the dot product of unit vectors is their cosine.
    unit1 = features1 / features1.norm(dim=-1, keepdim=True)
    unit2 = features2 / features2.norm(dim=-1, keepdim=True)
    return torch.matmul(unit1, unit2.T)

def clear_sketch():
    """Return an empty value; presumably wired to a Sketchpad clear button further down (not visible here)."""
    empty_canvas = None
    return empty_canvas

# --- Text Search ---
def text_search_and_plot(digit_text, top_k=3):
    """
    Retrieve test images matching "A handwritten digit <word>" with the fine-tuned model.

    Args:
        digit_text (str): A single digit character "0"-"9" from the UI.
        top_k (int): Number of best-matching images to return.

    Returns:
        tuple: (list of PIL images from ``original_mnist_test``, Matplotlib
        figure of the UMAP projection with the text query overlaid), or
        ([], None) for empty/invalid input.
    """
    if not digit_text:
        return [], None
    
    # Spell the digit out; unknown input maps to "" and is rejected below.
    digit_map = {
        "0": "zero", "1": "one", "2": "two", "3": "three", "4": "four",
        "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine"
    }
    digit_word = digit_map.get(digit_text, "")
    if not digit_word:
        return [], None

    print(f"Searching for text: A handwritten digit {digit_word}")
    prompt = f"A handwritten digit {digit_word}"
    text_token = clip.tokenize([prompt]).to(device)
    
    with torch.no_grad():
        text_features = model.encode_text(text_token)

    # One query row -> squeeze to a 1-D score vector over all test images
    similarities = calculate_similarity(text_features, all_features_gpu)
    _, top_indices = torch.topk(similarities.squeeze(0), top_k)
    top_indices = top_indices.cpu().numpy().tolist()
    
    retrieved_labels = [all_labels_cpu[idx].item() for idx in top_indices]
    print(f"Top indices: {top_indices}, Labels: {retrieved_labels}")

    # Raw test images share indices with the embeddings
    result_images = [original_mnist_test[idx][0] for idx in top_indices]

    # Generate Plot: image embeddings (coloured by label) + the projected text query
    text_features_cpu = text_features.cpu().numpy()
    text_embedding_2d = reducer.transform(text_features_cpu)
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111)
    sns.scatterplot(
        x=umap_embeddings_2d[:, 0], y=umap_embeddings_2d[:, 1],
        hue=umap_labels_np, palette=sns.color_palette("tab10", 10),
        s=10, alpha=0.6, ax=ax
    )
    # Drawn after the seaborn scatter, so its legend handle comes last
    ax.scatter(
        text_embedding_2d[0, 0], text_embedding_2d[0, 1],
        marker='s', color='black', s=100, label=f'Text: "{prompt}"'
    )
    ax.set_title('UMAP Projection (Image Embeddings + Text Query)')
    ax.set_xlabel('UMAP Component 1')
    ax.set_ylabel('UMAP Component 2')
    ax.grid(True, linestyle='--', alpha=0.5)

    # Enlarge the per-digit legend markers (at most 10), skipping the last
    # handle (the query marker). Relies on the draw order above.
    handles, labels = ax.get_legend_handles_labels()
    dot_marker_size_scale = 6
    for i in range(min(10, len(handles) -1)):
         if hasattr(handles[i], 'set_sizes'):
             current_size = handles[i].get_sizes()[0]
             handles[i].set_sizes([current_size * dot_marker_size_scale])
    ax.legend(handles, labels, title='Digit Label / Text', loc='center left', bbox_to_anchor=(1.05, 0.5))
    
    # Leave room for the outside legend; close so figures don't accumulate
    plt.tight_layout(rect=[0, 0, 0.85, 1])
    plt.close(fig)

    return result_images, fig

# --- Sketch Preprocessing ---
def preprocess_sketch(sketch_input_data):
    """
    Turn a Gradio sketch into a 28x28 MNIST-style grayscale image.

    Steps:
      1. Extract the NumPy array (an 'image', else 'composite', entry when a
         dict is given).
      2. Flatten any transparency onto white, convert to grayscale, and invert
         to white-on-black.
      3. Crop to the strokes and pad to a square.
      4. Shrink to MNIST's 20x20 digit box and add a 4-pixel black border.

    Returns:
        PIL.Image or None: The processed image, an all-black 28x28 image when
        nothing was drawn, or None when the input cannot be interpreted.
    """
    if sketch_input_data is None:
        return None

    if isinstance(sketch_input_data, dict):
        sketch_array = None
        for key in ("image", "composite"):
            candidate = sketch_input_data.get(key)
            if isinstance(candidate, np.ndarray):
                sketch_array = candidate
                break
    elif isinstance(sketch_input_data, np.ndarray):
        sketch_array = sketch_input_data
    else:
        return None

    if sketch_array is None or sketch_array.size == 0:
        return None

    try:
        pil_image = Image.fromarray(sketch_array)
        # RGBA -> L ignores alpha, so transparent pixels with black RGB would
        # look like strokes. Composite onto white first (no-op for opaque images).
        if pil_image.mode in ("RGBA", "LA"):
            white = Image.new("RGBA", pil_image.size, (255, 255, 255, 255))
            pil_image = Image.alpha_composite(white, pil_image.convert("RGBA"))
        pil_image = pil_image.convert("L")
    except Exception as e:
        print(f"Error converting sketch to image: {e}")
        return None

    # Dark strokes on a light canvas -> white on black, like MNIST.
    pil_image = ImageOps.invert(pil_image)
    bbox = pil_image.getbbox()
    if bbox is None:
        return Image.new('L', (28, 28), color=0)
    pil_image = pil_image.crop(bbox)

    # Centre the crop in a black square so resizing keeps the aspect ratio.
    width, height = pil_image.size
    side = max(width, height)
    left = (side - width) // 2
    top = (side - height) // 2
    pil_image = ImageOps.expand(
        pil_image, (left, top, side - width - left, side - height - top), fill=0
    )

    # MNIST layout: digit in a 20x20 box plus a 4-pixel margin -> 28x28.
    pil_image = pil_image.resize((20, 20), Image.Resampling.LANCZOS)
    return ImageOps.expand(pil_image, border=4, fill=0)

# --- Sketch Search ---
def sketch_search_and_plot(sketch_image, top_k=3):
    """Retrieve the MNIST test images most similar to a hand-drawn sketch.

    Relies on notebook-level names defined elsewhere in this file:
    - ``model``, ``preprocess`` and ``device``: the CLIP model and its preprocessing.
    - ``all_features_gpu`` and ``all_labels_cpu``: precomputed gallery
      embeddings and their labels.
    - ``calculate_similarity``: the similarity function.
    - ``original_mnist_test``: un-transformed MNIST test images.
    - ``reducer``, ``umap_embeddings_2d`` and ``umap_labels_np``: the fitted
      UMAP projection used for the plot.

    Args:
        sketch_image: Raw Sketchpad value (see ``preprocess_sketch``).
        top_k: Number of neighbours to return. Clamped to the number of
            gallery items, since ``torch.topk`` raises otherwise.

    Returns:
        ``(result_images, plot)``. ``result_images`` is the list of retrieved
        MNIST images for the gallery. ``plot`` is a matplotlib Figure showing
        the sketch in the UMAP space, or ``None`` if plotting failed;
        retrieval results are still returned in that case. Returns
        ``([], None)`` when there is no usable sketch.
    """
    if sketch_image is None:
        return [], None

    print("Processing sketch...")
    preprocessed_sketch_pil = preprocess_sketch(sketch_image)
    if preprocessed_sketch_pil is None:
        print("Sketch preprocessing failed.")
        return [], None

    # Embed the sketch with CLIP's image preprocessing, adding a batch dim of 1.
    clip_input_tensor = preprocess(preprocessed_sketch_pil).unsqueeze(0).to(device)
    with torch.no_grad():
        sketch_features = model.encode_image(clip_input_tensor)

    print("Calculating similarities...")
    similarities = calculate_similarity(sketch_features, all_features_gpu).squeeze(0)
    # torch.topk raises if k exceeds the number of candidates, so clamp it.
    k = max(0, min(int(top_k), similarities.shape[-1]))
    _, top_indices = torch.topk(similarities, k)
    top_indices = top_indices.cpu().tolist()
    print(f"Top sketch indices: {top_indices}")

    retrieved_labels = [all_labels_cpu[idx].item() for idx in top_indices]
    print(f"Labels for top indices: {retrieved_labels}")

    # The gallery shows the original MNIST images, not CLIP-preprocessed tensors.
    result_images = [original_mnist_test[idx][0] for idx in top_indices]

    fig = None
    plot_output = None
    try:
        sketch_embedding_2d = reducer.transform(sketch_features.cpu().numpy())
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111)
        sns.scatterplot(
            x=umap_embeddings_2d[:, 0], y=umap_embeddings_2d[:, 1],
            hue=umap_labels_np, palette=sns.color_palette("tab10", 10),
            s=10, alpha=0.6, ax=ax
        )
        ax.scatter(
            sketch_embedding_2d[0, 0], sketch_embedding_2d[0, 1],
            marker='s', color='black', s=100, label='Your Sketch'
        )
        ax.set_title('UMAP Projection (Image Embeddings + Sketch Query)')
        ax.set_xlabel('UMAP Component 1')
        ax.set_ylabel('UMAP Component 2')
        ax.grid(True, linestyle='--', alpha=0.5)

        handles, labels = ax.get_legend_handles_labels()
        dot_marker_size_scale = 6
        # Enlarge the (up to 10) digit-class legend markers, because the
        # scatter uses tiny s=10 dots. They are presumably listed before the
        # 'Your Sketch' entry.
        for handle in handles[:min(10, len(handles) - 1)]:
            if hasattr(handle, 'set_sizes'):
                handle.set_sizes([handle.get_sizes()[0] * dot_marker_size_scale])
        ax.legend(handles, labels, title='Digit Label / Sketch', loc='center left', bbox_to_anchor=(1.05, 0.5))

        plt.tight_layout(rect=[0, 0, 0.85, 1])
        plot_output = fig
    except Exception as e:
        print(f"Error generating plot for sketch: {e}")
        plot_output = None
    finally:
        # Always release the figure from pyplot's global registry. Previously
        # a failure after plt.figure() left it open on every request.
        # The Figure object itself can still be rendered by gr.Plot.
        if fig is not None:
            plt.close(fig)

    return result_images, plot_output

# --- 7. Gradio Interface ---
# The app has two tabs with the same result layout (image gallery + UMAP
# plot). One queries with a selected digit; the other with a hand-drawn
# sketch. The handlers text_search_and_plot, sketch_search_and_plot and
# clear_sketch are defined elsewhere in this notebook.

with gr.Blocks(theme=gr.themes.Soft()) as app_interface:
    gr.Markdown(
        """
        # CLIP MNIST Search Engine 🧠🖼️✏️
        Use CLIP embeddings to search the MNIST dataset using text prompts or sketches.
        **Model: Fine-Tuned on MNIST Training Set**
        """
    )

    with gr.Tabs():
        # --- Text Search Tab ---
        with gr.TabItem("Text Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    text_input = gr.Radio(
                        choices=[str(i) for i in range(10)],
                        label="Select a Digit",
                        info="Choose the digit you want to search for."
                    )
                    submit_btn_text = gr.Button("Search Text", variant="primary")
                with gr.Column(scale=3):
                    gallery_text = gr.Gallery(
                        label="Search Results", columns=3, object_fit="contain",
                        height=250, preview=True
                    )
                    plot_text = gr.Plot(label="UMAP Visualization (Text Query)")

        # --- Sketch Search Tab ---
        with gr.TabItem("Sketch Search"):
            with gr.Row():
                with gr.Column(scale=1):
                    # type="numpy" delivers the drawing as ndarray data
                    # (possibly wrapped in a dict) for preprocess_sketch.
                    sketch_input = gr.Sketchpad(
                        label="Draw a Digit (0-9)",
                        type="numpy",
                    )
                    submit_btn_sketch = gr.Button("Search Sketch", variant="primary")
                    clear_btn_sketch = gr.Button("Clear Sketch")
                with gr.Column(scale=3):
                    gallery_sketch = gr.Gallery(
                        label="Search Results", columns=3, object_fit="contain",
                        height=250, preview=True
                    )
                    plot_sketch = gr.Plot(label="UMAP Visualization (Sketch Query)")

    # --- Event Handlers ---
    # Each search handler returns (gallery images, figure), matching the two outputs.
    submit_btn_text.click(
        fn=text_search_and_plot,
        inputs=text_input,
        outputs=[gallery_text, plot_text]
    )
    submit_btn_sketch.click(
        fn=sketch_search_and_plot,
        inputs=sketch_input,
        outputs=[gallery_sketch, plot_sketch]
    )
    clear_btn_sketch.click(fn=clear_sketch, inputs=None, outputs=sketch_input)

# --- Launch the app ---
if __name__ == "__main__":
    app_interface.launch()