# CREATING ONNX FILE

In [None]:

# Step 1: Install required libraries
!pip install -q open_clip_torch onnx onnxruntime

# It's good practice to import after installation
import open_clip
import torch
import onnx
import onnxruntime
import numpy as np
import os

In [None]:
import torch
import torch.nn as nn
from transformers import CLIPModel, CLIPProcessor
from PIL import Image
import onnx
import os

# --- Step 1: Define Paths and Model ID ---
MODEL_ID = "wkcn/TinyCLIP-ViT-61M-32-Text-29M-LAION400M"
WEIGHTS_PATH = '/content/finetuned_tinyclip_multilabel.pt' # Make sure this path is correct
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fail fast on a wrong checkpoint path, before from_pretrained() does any work.
# torch.load would raise the same exception type, only later.
if not os.path.isfile(WEIGHTS_PATH):
    raise FileNotFoundError(f"Fine-tuned weights not found: {WEIGHTS_PATH}")

# --- Step 2: Load Model and Processor ---
print("Loading model and processor...")
# Build the base architecture from MODEL_ID, then overwrite its parameters with
# the fine-tuned state dict stored at WEIGHTS_PATH.
model = CLIPModel.from_pretrained(MODEL_ID).to(device)
# NOTE(review): torch.load deserializes a pickle-based checkpoint; depending on
# the PyTorch version this can execute code embedded in the file, so only load
# checkpoints you trust.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
# Switch to inference mode before the model is wrapped and exported.
model.eval()

processor = CLIPProcessor.from_pretrained(MODEL_ID)
print("✅ Model and processor loaded.")

# --- Step 3: Create a Wrapper Model that includes the Projection Head ---
class CompleteVisionModel(nn.Module):
    """Vision tower of a CLIP model chained with its projection head.

    Borrows the ``vision_model`` and ``visual_projection`` attributes of the
    model passed in, so one forward pass maps pixel values straight to the
    projected image features.
    """

    def __init__(self, clip_model):
        """Keep references to the two submodules of ``clip_model`` that are used."""
        super().__init__()
        self.vision_model = clip_model.vision_model
        self.visual_projection = clip_model.visual_projection

    def forward(self, pixel_values):
        """Return the projection of the vision model's pooled output."""
        # The projection head consumes the pooled output of the vision model.
        pooled = self.vision_model(pixel_values=pixel_values).pooler_output
        return self.visual_projection(pooled)

# Wrap the loaded CLIP model, place the wrapper on the target device and put it
# in inference mode; both .to() and .eval() return the module itself.
complete_vision_model_to_export = CompleteVisionModel(model).to(device).eval()
print("✅ Created a complete vision model wrapper.")

# --- Step 4: Create a Correct Dummy Input ---
print("Creating a correctly preprocessed dummy input...")
# Push a blank 224x224 RGB image through the processor so the example tensor
# has whatever shape and dtype the processor emits.
dummy_pil_image = Image.new(mode='RGB', size=(224, 224))
inputs = processor(images=dummy_pil_image, return_tensors="pt")
inputs = inputs.to(device)
# Only the pixel tensor is fed to the vision wrapper during export.
dummy_input = inputs['pixel_values']

# --- Step 5: Export the COMPLETE Vision Model to ONNX ---
onnx_file_path = "tinyclip_dynamic.onnx"
print("\nExporting the complete vision model to ONNX...")
torch.onnx.export(
    complete_vision_model_to_export,
    dummy_input,
    onnx_file_path,
    export_params=True,  # store the trained weights inside the ONNX file
    opset_version=14,
    input_names=['input'],
    output_names=['output'],
    # Only axis 0 is marked dynamic, so the batch size may vary at inference
    # time while the remaining dimensions stay as traced.
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

# Validate the written file so a malformed graph fails here, before success is
# reported, rather than when the file is later loaded for inference.
onnx.checker.check_model(onnx_file_path)

print(f"✅ Export complete: {onnx_file_path}")