# FINE TUNING TINYCLIP

In [None]:
# transformers provides CLIPModel/CLIPProcessor used below; scikit-learn and seaborn are
# imported by the evaluation cell. (open_clip_torch is not imported anywhere in this notebook.)
!pip install transformers torch torchvision pandas scikit-learn seaborn tqdm

In [None]:
# -p: succeed silently if the directory already exists (safe to re-run the cell)
!mkdir -p /content/train

In [None]:

# -o: overwrite existing files without the interactive "replace?" prompt on re-runs
!unzip -o /content/crops.zip -d /content/train/

## CREATE CSV DATA FILE: IMAGE_PATH + ONE-HOT CLASS LABEL COLUMNS

In [None]:
import os
import pandas as pd

# --- CONFIGURATION ---
# The main directory where your class folders are located
CROPS_DIR = '/content/train/crops'
# The final CSV file to be created
OUTPUT_CSV_PATH = 'training_data.csv'
# The class labels, in the desired order for the vector
CLASS_LABELS = ["calyx", "fruitlet", "peduncle", "negative"]
# File extensions (lower-case) that count as images
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def build_label_csv(crops_dir, output_csv_path, class_labels, extensions=IMAGE_EXTENSIONS):
    """Write a label CSV with one row per image found in ``crops_dir/<class>/``.

    Columns are ``image_path`` followed by one 0/1 column per entry of
    ``class_labels``. Exactly the column named after the image's folder is 1
    (one-hot). Sub-folders whose name is not in ``class_labels`` and files
    without a matching extension are skipped.

    Args:
        crops_dir: directory whose immediate sub-folders are named after classes.
        output_csv_path: where the CSV is written (overwritten if present).
        class_labels: class names, in the order of the label columns.
        extensions: tuple of lower-case file suffixes treated as images.

    Returns:
        The DataFrame that was written.
    """
    training_data = []
    # Sorted traversal makes the CSV identical across runs/filesystems.
    for class_name in sorted(os.listdir(crops_dir)):
        class_dir = os.path.join(crops_dir, class_name)
        if not os.path.isdir(class_dir) or class_name not in class_labels:
            continue

        print(f"Processing folder: {class_name}")
        for filename in sorted(os.listdir(class_dir)):
            if not filename.lower().endswith(extensions):
                continue
            row = {'image_path': os.path.join(class_dir, filename)}
            row.update({label: int(label == class_name) for label in class_labels})
            training_data.append(row)

    # Explicit columns keep the header (and column order) even when no image was found,
    # so downstream readers see "0 rows" instead of failing on a missing label column.
    df = pd.DataFrame(training_data, columns=['image_path', *class_labels])
    df.to_csv(output_csv_path, index=False)
    return df


# --- SCRIPT LOGIC ---
print(f"--- Generating '{OUTPUT_CSV_PATH}' from '{CROPS_DIR}' ---")
if not os.path.isdir(CROPS_DIR):
    print(f"Error: Crops directory '{CROPS_DIR}' not found.")
else:
    df = build_label_csv(CROPS_DIR, OUTPUT_CSV_PATH, CLASS_LABELS)
    print(f"\nSuccessfully created '{OUTPUT_CSV_PATH}' with {len(df)} entries. ✅")

## TRAINING


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import pandas as pd
from tqdm.auto import tqdm

## 1. Configuration
# Model / data locations
MODEL_ID = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
CSV_PATH = 'training_data.csv'  # The CSV you just created
SAVE_PATH = '/content/finetuned_tinyclip_multilabel.pt'
# Label columns read from the CSV, in label-vector order
CLASS_LABELS = ["calyx", "fruitlet", "peduncle", "negative"]
# Optimisation hyper-parameters
EPOCHS, BATCH_SIZE = 5, 32
LEARNING_RATE = 1e-5

## 2. Custom Dataset
class MultiLabelDataset(Dataset):
    """Map-style dataset over a label CSV, one item per CSV row.

    Each item is ``{"image_path": <path string>, "labels": <float32 tensor>}``.
    The label tensor holds the row's values for ``class_labels``, in that
    order. Images are not opened here; only their paths are returned.
    """

    def __init__(self, csv_path, class_labels):
        # The whole CSV is read eagerly; rows are looked up by position later.
        self.data = pd.read_csv(csv_path)
        self.class_labels = class_labels

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        path = record['image_path']
        label_values = [record[column] for column in self.class_labels]
        return {"image_path": path, "labels": torch.tensor(label_values, dtype=torch.float32)}

## 3. Main Training Logic
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

model = CLIPModel.from_pretrained(MODEL_ID)
model = model.to(device)
processor = CLIPProcessor.from_pretrained(MODEL_ID)
print("Model loaded successfully. ✅")

dataset = MultiLabelDataset(CSV_PATH, CLASS_LABELS)

# One fixed prompt per class, encoded once. Because this runs under no_grad,
# the text tower receives no gradients and these embeddings stay constant.
text_prompts = ["a photo of a " + label for label in CLASS_LABELS]
tokenized_prompts = processor(text=text_prompts, return_tensors="pt", padding=True, truncation=True)
text_inputs = tokenized_prompts.to(device)
with torch.no_grad():
    text_embeds = model.get_text_features(**text_inputs)

def collate_fn(batch):
    """Load and preprocess a batch of images for the training DataLoader.

    Args:
        batch: list of dataset items, each a dict with 'image_path' (str) and
            'labels' (float32 tensor with one entry per class).

    Returns:
        A dict with 'pixel_values' (processor output) and 'labels' (stacked
        tensor, one row per image that loaded successfully). Returns None when
        no image in the batch could be loaded; the training loop skips None.
    """
    images, labels = [], []
    for item in batch:
        try:
            # `with` closes the file; convert() reads the pixels before the file closes.
            with Image.open(item['image_path']) as img:
                images.append(img.convert("RGB"))
        except OSError as e:  # missing file (FileNotFoundError) or unreadable image
            # Drop only this sample instead of throwing away the whole batch.
            print(f"Error loading image: {e}")
            continue
        labels.append(item['labels'])  # appended only on success, so rows stay aligned

    if not images:
        return None
    image_inputs = processor(images=images, return_tensors="pt", padding=True)
    return {"pixel_values": image_inputs.pixel_values, "labels": torch.stack(labels)}

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = torch.nn.BCEWithLogitsLoss()

# Score images the same way CLIPModel.forward builds `logits_per_image`, which
# the evaluation cells use: cosine similarity of L2-normalised embeddings times
# the learned temperature exp(logit_scale). Raw, unnormalised dot products
# (as previously used here) can rank classes differently whenever the class
# text embeddings have different norms, so training optimised a different
# decision rule than the one evaluated.
# text_embeds is fixed (computed under no_grad), so normalise it once.
text_embeds_normalized = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

print("Starting multi-label fine-tuning...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0  # batches actually trained on; collate_fn may yield None
    pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}")
    for batch in pbar:
        if batch is None:
            continue

        pixel_values = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)
        optimizer.zero_grad()

        image_embeds = model.get_image_features(pixel_values=pixel_values)
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        logits = model.logit_scale.exp() * torch.matmul(image_embeds, text_embeds_normalized.t())
        loss = loss_fn(logits, labels)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({"Loss": loss.item()})

    # Average over trained batches only (guard against an epoch where every batch was skipped).
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1} finished with average loss: {avg_loss:.4f}")

torch.save(model.state_dict(), SAVE_PATH)
print(f"\nTraining complete! 🎉 Multi-label model saved to {SAVE_PATH}")

## CREATE VALIDATION DIR

In [None]:
# -p: succeed silently if the directory already exists (safe to re-run the cell)
!mkdir -p /content/validation

In [None]:
# -o: overwrite existing files without the interactive "replace?" prompt on re-runs
!unzip -o /content/valid.zip -d /content/validation/

## **EVALUATION SCRIPT**


In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import os
from tqdm.auto import tqdm
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, ConfusionMatrixDisplay
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import LabelBinarizer

## 1. Configuration
# Class names double as validation sub-folder names and as prompt words.
CLASS_LABELS = ["calyx", "fruitlet", "peduncle", "negative"]
VALIDATION_DIR = '/content/validation/valid/crops'
# Fine-tuned weights and the pretrained model they were derived from
SAVED_MODEL_PATH = '/content/finetuned_tinyclip_multilabel.pt'
MODEL_ID = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"


## 2. Load Model and Processor
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Build the pretrained architecture, then overwrite its weights with the
# fine-tuned state_dict, move it to the device and switch to inference mode.
model = CLIPModel.from_pretrained(MODEL_ID)
model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=device))
model = model.to(device).eval()

processor = CLIPProcessor.from_pretrained(MODEL_ID)


## 3. Evaluation
text_prompts = [f"a photo of a {label}" for label in CLASS_LABELS]
print(f"Testing against prompts: {text_prompts}")

all_true_labels = []
all_predicted_labels = []
all_pred_scores = []

if not os.path.isdir(VALIDATION_DIR):
    print(f"Error: Validation directory '{VALIDATION_DIR}' not found.")
else:
    # The prompts are identical for every image, so tokenize them only once.
    text_inputs = processor(text=text_prompts, return_tensors="pt", padding=True).to(device)
    num_correct = 0  # running tally for the progress bar (avoids re-counting every image)

    pbar = tqdm(os.walk(VALIDATION_DIR), desc="Evaluating")
    for root, _, files in pbar:
        # Ground truth is the name of the folder holding the image. Folders that are
        # not a known class (e.g. VALIDATION_DIR itself) are skipped, like the training
        # CSV builder does, so stray images cannot enter the metrics with a label the
        # model can never predict.
        true_label = os.path.basename(root)
        if true_label not in CLASS_LABELS:
            continue

        for file in sorted(files):
            if not file.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue

            image_path = os.path.join(root, file)
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            image_inputs = processor(images=image, return_tensors="pt").to(device)

            with torch.no_grad():
                outputs = model(**text_inputs, pixel_values=image_inputs["pixel_values"])

            # One image scored against every prompt -> a single row of logits.
            logits_per_image = outputs.logits_per_image
            predicted_label = CLASS_LABELS[logits_per_image.argmax().item()]

            all_true_labels.append(true_label)
            all_predicted_labels.append(predicted_label)
            all_pred_scores.append(logits_per_image.softmax(dim=1).cpu().numpy()[0])
            num_correct += int(predicted_label == true_label)

            pbar.set_postfix({"Correct": f"{num_correct}/{len(all_true_labels)}"})

    if all_true_labels:
        print("\n--- Evaluation Complete ---")
        report = classification_report(all_true_labels, all_predicted_labels, labels=CLASS_LABELS, digits=4)
        print("--- Classification Report ---")
        print(report)

        print("\n--- Confusion Matrix ---")
        cm = confusion_matrix(all_true_labels, all_predicted_labels, labels=CLASS_LABELS)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=CLASS_LABELS, yticklabels=CLASS_LABELS)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Confusion Matrix')
        plt.show()

        # Row-normalised: each row sums to 1, i.e. per-class recall on the diagonal.
        print("\n--- Normalized Confusion Matrix ---")
        cm_normalized = confusion_matrix(all_true_labels, all_predicted_labels, labels=CLASS_LABELS, normalize='true')

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
                    xticklabels=CLASS_LABELS, yticklabels=CLASS_LABELS)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.title('Normalized Confusion Matrix')
        plt.show()

        # One-hot ground truth with columns in CLASS_LABELS order, matching y_scores
        # (softmax over the prompts, which are built in CLASS_LABELS order).
        # LabelBinarizer().fit(CLASS_LABELS) must NOT be used here: it sorts its
        # classes alphabetically (calyx, fruitlet, negative, peduncle), which swaps
        # the 'peduncle' and 'negative' columns relative to y_scores.
        y_true_bin = (np.array(all_true_labels)[:, None] == np.array(CLASS_LABELS)[None, :]).astype(int)
        y_scores = np.array(all_pred_scores)

        # Per-class P-R curves, computed once and reused by both figures.
        pr_curves = {}
        for i, class_name in enumerate(CLASS_LABELS):
            precision, recall, _ = precision_recall_curve(y_true_bin[:, i], y_scores[:, i])
            pr_curves[class_name] = (recall, precision)

        # --- FIGURE 1: all P-R curves on one zoomed-in plot, plus a full-range inset ---
        print("\n--- Combined Precision-Recall Curves ---")
        fig, ax = plt.subplots(figsize=(10, 8))
        for class_name, (recall, precision) in pr_curves.items():
            # 'calyx' is drawn thicker so it stands out.
            linewidth = 4 if class_name == 'calyx' else 2
            ax.plot(recall, precision, lw=linewidth, label=f'P-R curve for {class_name}')

        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title("Precision-Recall Curves for the TinyCLIP Classifier")
        ax.legend(loc="best")
        ax.grid(alpha=0.4)
        ax.set_xlim(0.75, 1.05)
        ax.set_ylim(0.75, 1.05)

        ax_inset = fig.add_axes([0.18, 0.18, 0.4, 0.4])
        for class_name, (recall, precision) in pr_curves.items():
            linewidth = 4 if class_name == 'calyx' else 2
            ax_inset.plot(recall, precision, lw=linewidth)
        ax_inset.set_title("Full Range (0-1)")
        ax_inset.set_xlabel("Recall")
        ax_inset.set_ylabel("Precision")
        ax_inset.grid(alpha=0.4)

        plt.show()

        # --- FIGURE 2: one zoomed-in subplot per class, each with a full-range inset ---
        print("\n--- Per-Class Precision-Recall Curves ---")
        fig, axes = plt.subplots(1, len(CLASS_LABELS), figsize=(20, 6), sharey=True)
        fig.suptitle('Per-Class Precision-Recall Curves', fontsize=16)

        for i, (class_name, (recall, precision)) in enumerate(pr_curves.items()):
            axes[i].plot(recall, precision, lw=2, label=f'P-R for {class_name}')
            axes[i].set_title(f'Class: {class_name}')
            axes[i].set_xlabel('Recall')
            axes[i].set_ylabel('Precision')
            axes[i].grid(alpha=0.4)
            axes[i].legend()
            axes[i].set_xlim(0.75, 1.05)
            axes[i].set_ylim(0.75, 1.05)

            ax_inset = axes[i].inset_axes([0.1, 0.1, 0.5, 0.5])
            ax_inset.plot(recall, precision, lw=2)
            ax_inset.set_title("Full Range")
            ax_inset.set_xlabel("R", fontsize=8)
            ax_inset.set_ylabel("P", fontsize=8)
            ax_inset.grid(alpha=0.4)

        plt.tight_layout(rect=[0, 0, 1, 0.96])
        plt.show()

    else:
        print("No images found in the validation directory.")

In [None]:
import torch
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import os
import numpy as np
from tqdm.auto import tqdm

## 1. Configuration
MODEL_ID = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"  # pretrained architecture + processor
SAVED_MODEL_PATH = '/content/finetuned_tinyclip_multilabel.pt'  # fine-tuned weights to benchmark
VALIDATION_DIR = '/content/validation/valid/crops'  # images fed through the model while measuring

# --- Check if GPU is available ---
if not torch.cuda.is_available():
    print("Error: A CUDA-enabled GPU is required for accurate measurements.")
else:
    device = torch.device("cuda")

    ## 2. Load Model and Data Paths
    model = CLIPModel.from_pretrained(MODEL_ID)
    model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=device))
    model.to(device)
    model.eval()

    processor = CLIPProcessor.from_pretrained(MODEL_ID)

    # Every validation image, sorted so repeated runs process the same sequence.
    image_paths = sorted(
        os.path.join(root, file)
        for root, _, files in os.walk(VALIDATION_DIR)
        for file in files
        if file.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    num_images = len(image_paths)
    print(f"Found {num_images} images for evaluation.")

    def _prepare_inputs(image_path):
        """Load one image and preprocess it together with the single timing prompt, on the GPU."""
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        # Only one prompt is used, so latency/memory cover one image + one prompt,
        # not the one-prompt-per-class scoring done in the evaluation cell.
        return processor(text=["a photo"], images=image, return_tensors="pt", padding=True).to(device)

    # --- Warm-up the GPU ---
    print("Warming up the GPU...")
    # The dummy batch never changes, so build it once and run it 10 times.
    warmup_inputs = processor(
        text=["a photo"], images=Image.new('RGB', (224, 224)), return_tensors="pt", padding=True
    ).to(device)
    with torch.no_grad():
        for _ in range(10):
            _ = model(**warmup_inputs)

    torch.cuda.synchronize()

    # ===================================================================
    # METRIC 1: MODEL SIZE (on-disk size of the saved state_dict)
    # ===================================================================
    model_size_bytes = os.path.getsize(SAVED_MODEL_PATH)
    model_size_mb = model_size_bytes / (1024 * 1024)
    print(f"\n--- 1. Model Size ---")
    print(f"✅ Model Size: {model_size_mb:.2f} MB")

    if not image_paths:
        # np.mean([]) would yield nan (with a RuntimeWarning); report clearly instead.
        print("\nNo images found in the validation directory; skipping latency and memory measurements.")
    else:
        # ===================================================================
        # METRIC 2: INFERENCE LATENCY (GPU time of the forward pass only)
        # ===================================================================
        print("\n--- 2. Measuring Inference Latency ---")
        latencies = []
        for image_path in tqdm(image_paths, desc="Measuring Latency"):
            inputs = _prepare_inputs(image_path)

            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)

            start_event.record()
            with torch.no_grad():
                _ = model(**inputs)
            end_event.record()

            # elapsed_time() requires both events to have completed.
            torch.cuda.synchronize()
            latencies.append(start_event.elapsed_time(end_event))

        average_latency_ms = np.mean(latencies)
        print(f"✅ Average Latency: {average_latency_ms:.2f} ms per image")

        # ===================================================================
        # METRIC 3: PEAK MEMORY USAGE
        # ===================================================================
        print("\n--- 3. Measuring Peak Memory Usage ---")
        # The model weights remain allocated across the reset, so the reported
        # peak includes them as well as the activations.
        torch.cuda.reset_peak_memory_stats(device)

        for image_path in tqdm(image_paths, desc="Measuring Memory"):
            inputs = _prepare_inputs(image_path)
            with torch.no_grad():
                _ = model(**inputs)

        peak_memory_bytes = torch.cuda.max_memory_allocated(device)
        peak_memory_mb = peak_memory_bytes / (1024 * 1024)
        print(f"✅ Peak Memory Usage: {peak_memory_mb:.2f} MB")

## MISCLASSIFICATION: PREDICTIONS FOR ALL CROPS OF ONE SOURCE IMAGE (WRONG ONES IN RED)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import torch
from transformers import CLIPModel, CLIPProcessor
from PIL import Image

# --- 1. Configuration ---
# Fine-tuned checkpoint and the pretrained model it was initialised from
SAVED_MODEL_PATH = '/content/finetuned_tinyclip_multilabel.pt'
MODEL_ID = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
# Validation label CSV: an 'image_path' column plus one 0/1 column per class
VALIDATION_CSV_PATH = 'validation_data.csv'

# The uncropped source image sits directly under VALIDATION_ROOT_DIR; every CSV
# path containing IMAGE_IDENTIFIER is treated as one of its crops.
VALIDATION_ROOT_DIR = '/content/validation/valid'
ORIGINAL_IMAGE_BASE_FILENAME = 'IMG_2021_JPEG.rf.d08beab9aa2e94e562953893e89bdc15.jpg'
IMAGE_IDENTIFIER = 'IMG_2021'
# Class order also fixes the prompt order used for prediction
CLASS_LABELS = ["calyx", "fruitlet", "peduncle", "negative"]

# --- 2. Load Model and Data ---
# Load fine-tuned model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
try:
    model = CLIPModel.from_pretrained(MODEL_ID)
    model.load_state_dict(torch.load(SAVED_MODEL_PATH, map_location=device))
    model.to(device)
    model.eval()
    processor = CLIPProcessor.from_pretrained(MODEL_ID)
    text_prompts = [f"a photo of a {label}" for label in CLASS_LABELS]
    print("Successfully loaded model and processor.")
except Exception as e:
    print(f"Error loading model: {e}")
    # Re-raise instead of exit(): in a Jupyter/Colab kernel exit() asks the kernel
    # to shut down (losing all state), while an exception just stops this cell.
    raise

# Load the dataset CSV. No earlier cell writes VALIDATION_CSV_PATH, so when it is
# missing, rebuild the same schema (image_path + one 0/1 column per class) from the
# class sub-folders of the validation crops directory.
if os.path.exists(VALIDATION_CSV_PATH):
    df = pd.read_csv(VALIDATION_CSV_PATH)
    print(f"Successfully loaded data from '{VALIDATION_CSV_PATH}'.")
else:
    crops_dir = os.path.join(VALIDATION_ROOT_DIR, 'crops')
    if not os.path.isdir(crops_dir):
        raise FileNotFoundError(
            f"Error: Validation CSV '{VALIDATION_CSV_PATH}' not found "
            f"and no crops directory at '{crops_dir}'."
        )
    rows = []
    for class_name in CLASS_LABELS:
        class_dir = os.path.join(crops_dir, class_name)
        if not os.path.isdir(class_dir):
            continue
        for filename in sorted(os.listdir(class_dir)):
            if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
                row = {'image_path': os.path.join(class_dir, filename)}
                row.update({label: int(label == class_name) for label in CLASS_LABELS})
                rows.append(row)
    df = pd.DataFrame(rows, columns=['image_path', *CLASS_LABELS])
    print(f"Validation CSV '{VALIDATION_CSV_PATH}' not found; built {len(df)} rows from '{crops_dir}'.")

# --- 3. Identify Target Images ---
# Path to the original, un-cropped image
original_image_path = os.path.join(VALIDATION_ROOT_DIR, ORIGINAL_IMAGE_BASE_FILENAME)

# All CSV rows whose path contains the identifier (positive and negative crops).
# regex=False: the identifier is a literal file-name fragment, not a pattern
# (e.g. a '.' in it must not match any character).
cropped_images_info = df[df['image_path'].str.contains(IMAGE_IDENTIFIER, regex=False, na=False)]
cropped_image_paths = cropped_images_info['image_path'].tolist()

if not cropped_image_paths:
    # Raise instead of exit(): exit() inside a notebook kernel shuts the kernel down.
    raise ValueError(
        f"No cropped images found for identifier '{IMAGE_IDENTIFIER}'. Check your CSV and identifier."
    )

# --- 4. Create Display with Predictions ---
# Grid layout: slot 0 holds the original image, each later slot holds one crop.
num_cols = 3
num_cropped = len(cropped_image_paths)
total_plots = num_cropped + 1
num_rows = -(-total_plots // num_cols)  # ceiling division

fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, num_rows * 6))
axes = axes.flatten()
for slot in axes:
    slot.axis('off')


def _predict_label(image_file):
    """Return the CLASS_LABELS entry whose prompt scores highest for `image_file`."""
    pil_image = Image.open(image_file).convert("RGB")
    model_inputs = processor(text=text_prompts, images=pil_image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        model_outputs = model(**model_inputs)
    return CLASS_LABELS[model_outputs.logits_per_image.argmax().item()]


# --- Plot the Original Image (First Slot) ---
ax_idx = 0
original_ax = axes[ax_idx]
original_ax.set_title(f"Original:\n{os.path.basename(original_image_path)}", fontsize=10)
if not os.path.exists(original_image_path):
    original_ax.text(0.5, 0.5, 'Original Image Not Found', ha='center', color='red')
else:
    try:
        original_ax.imshow(mpimg.imread(original_image_path))
    except Exception as e:
        original_ax.text(0.5, 0.5, f'Error loading\n{e}', ha='center', color='red')

# --- Plot the Cropped Images with Predictions ---
for ax_idx, path in enumerate(cropped_image_paths, start=1):
    if ax_idx >= len(axes):
        break  # no subplot slots left

    ax = axes[ax_idx]
    if not os.path.exists(path):
        ax.text(0.5, 0.5, 'Image File Not Found', ha='center', color='red')
        ax.set_title(f"Crop:\n{os.path.basename(path)}", fontsize=10)
        continue

    try:
        # Ground truth: the first class column equal to 1 in the (first) CSV row for this path.
        image_data = cropped_images_info[cropped_images_info['image_path'] == path].iloc[0]
        true_label = next((label for label in CLASS_LABELS if image_data[label] == 1), 'unknown')

        predicted_label = _predict_label(path)

        ax.imshow(mpimg.imread(path))

        # Green title for a correct prediction, red for a misclassification.
        title_color = 'green' if predicted_label == true_label else 'red'
        ax.set_title(f"True: '{true_label}'\nPred: '{predicted_label}'", color=title_color, fontsize=12)

    except Exception as e:
        ax.text(0.5, 0.5, f'Error processing\n{e}', ha='center', color='red')
        print(f"Failed to process {path}: {e}")

# Figure-wide title, layout, and export next to the notebook.
plt.suptitle(f'Model Predictions for {IMAGE_IDENTIFIER}', fontsize=20, y=1.0)
plt.tight_layout(pad=3.0)
plt.savefig(IMAGE_IDENTIFIER + ".png")
# Call plt.show() here to also display the figure in an interactive environment.

print(f"\nGenerated plot and saved to {IMAGE_IDENTIFIER}.png")