In [None]:
# NOTE(review): the visible heatmap cell imports `transformers` (CLIPModel /
# CLIPProcessor), not `open_clip`, and does not use pandas; `transformers` itself is
# not installed here and is presumably provided by the runtime. Confirm which of
# these installs are actually needed.
!pip install open_clip_torch torch torchvision pandas

In [None]:
!mkdir /content/validation
!unzip /content/valid.zip -d /content/validation/

# HEATMAPS


In [None]:
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
import json
import os
import warnings
import numpy as np
import matplotlib.pyplot as plt
import time # Import time to measure performance

# Suppress warnings emitted from PIL modules.
# NOTE: PIL's large-image warning (PIL.Image.DecompressionBombWarning) derives from
# RuntimeWarning, not UserWarning, so the UserWarning filter alone never silenced
# it; it gets its own filter. The original UserWarning filter is kept unchanged.
warnings.filterwarnings("ignore", category=UserWarning, module='PIL')
warnings.filterwarnings("ignore", category=Image.DecompressionBombWarning)

# --- 1. CONFIGURATION ---
# Hub id of the base checkpoint; provides both the architecture and the processor.
MODEL_ID = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
# Fine-tuned weights, loaded as a state_dict on top of MODEL_ID below.
SAVED_MODEL_PATH = '/content/finetuned_tinyclip_multilabel.pt'
VALIDATION_DIR = '/content/validation/valid'
CLASS_LABELS = ["calyx", "fruitlet", "peduncle", "negative"]
# Per-class cut-off applied to the sigmoid of the image-text logits.
CONFIDENCE_THRESHOLD = 0.5

# --- NEW: BATCH_SIZE for processing patches to prevent out-of-memory errors ---
BATCH_SIZE = 64

# Sliding Window Parameters
PATCH_SIZE = 224  # side length (pixels) of each square patch
STRIDE = 112      # step (pixels) between patch origins; half of PATCH_SIZE

# --- 2. MODEL AND PROCESSOR SETUP ---
print("--- Loading FINE-TUNED Model for Stage 2 ---")
if not os.path.exists(SAVED_MODEL_PATH):
    raise FileNotFoundError(f"Fine-tuned model not found at {SAVED_MODEL_PATH}.")

processor = CLIPProcessor.from_pretrained(MODEL_ID)
model = CLIPModel.from_pretrained(MODEL_ID)
# Map tensors to CPU first so a checkpoint saved on GPU also loads on CPU-only hosts;
# the model is moved to the selected device afterwards.
model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=torch.device('cpu')))
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference mode (e.g. disables dropout)
# One text prompt per class, in CLASS_LABELS order; the model's per-image logits
# therefore have one column per class in that same order.
text_prompts = [f"a photo of a {label}" for label in CLASS_LABELS]
print(f"Using device: {device}\n")


# --- 3. LOAD AND PROCESS GROUND TRUTH ---
print("--- Loading and Processing COCO Ground Truth ---")
annotation_file_path = os.path.join(VALIDATION_DIR, '_annotations.coco.json')
if not os.path.exists(annotation_file_path):
    raise FileNotFoundError(f"Annotation file not found at {annotation_file_path}.")

# JSON is UTF-8 by spec; be explicit rather than relying on the platform default.
with open(annotation_file_path, 'r', encoding='utf-8') as f:
    coco_data = json.load(f)

# Create mappings and load all annotations
# COCO category id -> category name, as declared in the annotation file.
coco_id_to_name = {cat['id']: cat['name'] for cat in coco_data['categories']}
# "negative" is excluded from the positive labels; get_patch_ground_truth derives it
# as "no positive label present".
train_class_labels = [lbl for lbl in CLASS_LABELS if lbl != "negative"]
name_to_class_idx = {name: i for i, name in enumerate(train_class_labels)}
# COCO category id -> index into train_class_labels. Categories whose name is not a
# training label are left out, so lookups for them return None via .get().
coco_id_to_class_idx = {
    coco_id: name_to_class_idx[name]
    for coco_id, name in coco_id_to_name.items()
    if name in name_to_class_idx
}

image_id_to_filename = {img['id']: img['file_name'] for img in coco_data['images']}
# Independent copy of the images to analyse; restrict it here to run on a subset.
image_id_to_filename_subset = dict(image_id_to_filename)
# Every listed image gets an entry (possibly empty) so later lookups never fail.
image_id_to_annotations = {img_id: [] for img_id in image_id_to_filename}
for ann in coco_data['annotations']:
    # setdefault tolerates an annotation whose image_id is missing from 'images'
    # instead of aborting the whole run with a KeyError.
    image_id_to_annotations.setdefault(ann['image_id'], []).append(ann)
print(f"Processed ground truth for {len(image_id_to_filename)} images.\n")


# --- 4. HELPER FUNCTION (Unchanged) ---
def get_patch_ground_truth(patch_box, image_annotations, overlap_threshold=0.1,
                           class_labels=None, category_to_class_idx=None):
    """Return the multi-hot ground-truth vector for one sliding-window patch.

    A class is marked present when the intersection of one of its annotation
    boxes with the patch covers strictly more than ``overlap_threshold`` of the
    *patch* area. A trailing "negative" flag is set when no class is present.

    Args:
        patch_box: ``(x1, y1, x2, y2)`` patch corners in pixel coordinates.
        image_annotations: COCO annotation dicts, each with ``'bbox'`` given as
            ``[x, y, width, height]`` and a ``'category_id'``.
        overlap_threshold: Fraction of the patch area an annotation must exceed
            to mark its class as present.
        class_labels: Positive class names; only their count is used, to size
            the vector. Defaults to the module-level ``train_class_labels``.
        category_to_class_idx: Mapping from COCO category id to an index into
            ``class_labels``; unmapped categories are ignored. Defaults to the
            module-level ``coco_id_to_class_idx``.

    Returns:
        A list of ``len(class_labels) + 1`` ints (0/1): one per positive class,
        followed by the negative flag.
    """
    if class_labels is None:
        class_labels = train_class_labels
    if category_to_class_idx is None:
        category_to_class_idx = coco_id_to_class_idx

    px1, py1, px2, py2 = patch_box
    patch_area = (px2 - px1) * (py2 - py1)
    patch_truth = [0] * len(class_labels)

    # A patch with non-positive area cannot meaningfully overlap anything; it is
    # reported as negative instead of raising ZeroDivisionError below.
    if patch_area > 0:
        for ann in image_annotations:
            bx1, by1, bw, bh = ann['bbox']
            bx2, by2 = bx1 + bw, by1 + bh
            # Clamp to zero so disjoint boxes contribute no intersection.
            inter_w = max(0, min(px2, bx2) - max(px1, bx1))
            inter_h = max(0, min(py2, by2) - max(py1, by1))
            if (inter_w * inter_h) / patch_area > overlap_threshold:
                class_idx = category_to_class_idx.get(ann['category_id'])
                if class_idx is not None:
                    patch_truth[class_idx] = 1

    is_negative = 0 if any(patch_truth) else 1
    return patch_truth + [is_negative]


# --- 5. MAIN LOOP: OPTIMIZED WITH BATCH PROCESSING ---
# Each image produces a printed report and a figure, so only the first
# MAX_IMAGES_TO_PROCESS images are analysed. Set it to None to process them all.
MAX_IMAGES_TO_PROCESS = 15

images_to_process = list(image_id_to_filename_subset.items())[:MAX_IMAGES_TO_PROCESS]
print(f"--- Running OPTIMIZED Sliding Window Analysis on {len(images_to_process)} of {len(image_id_to_filename_subset)} Images ---")

image_processing_times = []

for image_id, filename in images_to_process:
    print(f"\n\n=========================================================")
    print(f"Processing Image: {filename}")
    print(f"=============================================================")
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()

    image_path = os.path.join(VALIDATION_DIR, filename)
    if not os.path.exists(image_path):
        print(f"--> SKIPPING: File not found at {image_path}")
        continue

    # convert() returns an independent in-memory image, so the source file can be
    # closed immediately by the context manager.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    image_width, image_height = image.size
    current_image_annotations = image_id_to_annotations[image_id]

    # --- OPTIMIZATION 1: Extract all patches into a list first ---
    # Row-major order (y outer, x inner) so the flat result tensor can be reshaped
    # directly into a (num_patches_y, num_patches_x) grid below.
    num_patches_y = (image_height - PATCH_SIZE) // STRIDE + 1
    num_patches_x = (image_width - PATCH_SIZE) // STRIDE + 1
    patches = [
        image.crop((x, y, x + PATCH_SIZE, y + PATCH_SIZE))
        for y in range(0, image_height - PATCH_SIZE + 1, STRIDE)
        for x in range(0, image_width - PATCH_SIZE + 1, STRIDE)
    ]

    # Images smaller than PATCH_SIZE in either dimension yield no patches.
    if not patches:
        print("--> SKIPPING: No patches were generated for this image.")
        continue

    print(f"Extracted {len(patches)} patches. Processing them in batches of {BATCH_SIZE}...")

    all_probs = []
    image_patch_predictions = []

    # --- OPTIMIZATION 2: Process the list of patches in batches ---
    with torch.no_grad():
        for i in range(0, len(patches), BATCH_SIZE):
            batch = patches[i:i + BATCH_SIZE]

            # The processor naturally handles a list of PIL Images
            inputs = processor(text=text_prompts, images=batch, return_tensors="pt", padding=True).to(device)
            outputs = model(**inputs)

            # Independent per-class probabilities (multi-label), not a softmax.
            probs = outputs.logits_per_image.sigmoid()
            predictions = (probs > CONFIDENCE_THRESHOLD).int()

            all_probs.append(probs.cpu())
            image_patch_predictions.append(predictions.cpu())

    # Concatenate results from all batches: shape (total_patches, num_classes).
    full_probs_tensor = torch.cat(all_probs)
    full_predictions_tensor = torch.cat(image_patch_predictions)

    # --- OPTIMIZATION 3: Reshape the results to form the heatmap ---
    # Batch order follows patch order, so the rows reshape directly into the grid:
    # (total_patches, num_classes) -> (y_dim, x_dim, num_classes).
    heatmap_tensor = full_probs_tensor.view(num_patches_y, num_patches_x, len(CLASS_LABELS))
    # Permute to (num_classes, y_dim, x_dim) for per-class plotting.
    heatmap = heatmap_tensor.permute(2, 0, 1)

    elapsed_time = time.perf_counter() - start_time
    image_processing_times.append(elapsed_time)
    print(f"--> Image processing finished in {elapsed_time:.2f} seconds.")

    # --- Per-Image Report: annotated object counts vs. positively predicted patches ---
    print(f"\nImage Report Card for: {filename}")
    print("-" * 60)
    print(f"{'Class':<12} | {'Ground Truth Count':<20} | {'Predicted Patch Count':<22}")
    print("-" * 60)
    true_counts = [0] * len(train_class_labels)
    for ann in current_image_annotations:
        class_idx = coco_id_to_class_idx.get(ann['category_id'])
        if class_idx is not None:
            true_counts[class_idx] += 1

    # Number of patches predicted positive per (non-negative) class.
    predicted_counts = full_predictions_tensor[:, :len(train_class_labels)].sum(dim=0).numpy()
    for i, label in enumerate(train_class_labels):
        print(f"{label:<12} | {true_counts[i]:<20} | {predicted_counts[i]:<22}")
    print("-" * 60)

    # --- Per-Image Heatmaps ---
    print(f"\nHeatmap Visualizations for: {filename}")
    heatmap_np = heatmap.numpy()
    # constrained_layout handles the shared colorbar; tight_layout cannot.
    fig, axes = plt.subplots(1, len(CLASS_LABELS) + 1, figsize=(20, 5), constrained_layout=True)

    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    for i, class_name in enumerate(CLASS_LABELS):
        # Pin every panel to the sigmoid range [0, 1] so all heatmaps share one
        # colour scale and the single colorbar below is valid for each of them.
        im = axes[i + 1].imshow(heatmap_np[i], cmap='viridis', interpolation='nearest', vmin=0.0, vmax=1.0)
        axes[i + 1].set_title(f"Heatmap for '{class_name}'")
        axes[i + 1].axis('off')

    fig.colorbar(im, ax=axes.ravel().tolist())
    plt.show()

print(f"\n\n--- Analysis complete: {len(image_processing_times)} of {len(images_to_process)} selected validation images processed. ---")
if image_processing_times:
    print(f"Average time per image: {np.mean(image_processing_times):.2f} seconds")
else:
    print("No images were processed, so no average time is available.")