# Step 1: Import Libraries and Load the CLIP Model

In [9]:
import torch
import os
import clip
from PIL import Image, UnidentifiedImageError
import requests
from io import BytesIO
import mysql.connector
import pandas as pd
import time
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

In [None]:
# Pick the compute device: use CUDA when torch reports it, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP weights together with the matching image transform.
model, preprocess = clip.load("ViT-B/32", device=device)

# Step 2: Connect to TiDB

In [None]:
def reconnect():
    """Open a new connection to TiDB using settings from the environment.

    Reads MYSQL_HOST, MYSQL_USER, MYSQL_PASSWORD, MYSQL_DATABASE and
    MYSQL_PORT via ``os.getenv``. A local ``.env`` file is loaded first.

    Returns:
        The connected ``mysql.connector`` connection, or ``None`` when
        MYSQL_PORT is missing/invalid or the connection could not be
        established (the reason is printed).
    """
    # Pick up settings from a .env file if one exists. With load_dotenv's
    # defaults, variables already set in the environment are not overridden.
    load_dotenv()

    port_setting = os.getenv('MYSQL_PORT')
    try:
        port = int(port_setting)
    except (TypeError, ValueError):
        # int(None) or int("abc") is a configuration problem, not a
        # mysql.connector.Error, so it must be handled separately.
        print(f"Error: MYSQL_PORT must be an integer, got {port_setting!r}")
        return None

    try:
        connection = mysql.connector.connect(
            host=os.getenv('MYSQL_HOST'),
            user=os.getenv('MYSQL_USER'),
            password=os.getenv('MYSQL_PASSWORD'),
            database=os.getenv('MYSQL_DATABASE'),
            port=port,
        )
        if connection.is_connected():
            print("Successfully connected to TiDB")
            return connection
    except mysql.connector.Error as e:
        print(f"Error: {e}")
        return None

    # connect() returned without raising, but the link is not usable.
    print("Error: connection to TiDB was not established")
    return None

# Establish the initial connection
connection = reconnect()
if connection is None:
    # reconnect() signals failure by returning None; stop here with a clear
    # message instead of an AttributeError on the cursor() call below.
    raise RuntimeError("Could not connect to TiDB; check the MYSQL_* settings")

# dictionary=True makes each fetched row a dict keyed by column name.
cursor = connection.cursor(dictionary=True)

# Step 3: Retrieve Data from TiDB

In [None]:
# Query the rows whose motorbike image has no embedding stored yet
query = """SELECT id, motorbike_image_path
FROM detected_motorbikes_and_plates
WHERE motorbike_image_embedding IS NULL"""
cursor.execute(query)

# Pin the column names so that an empty result set still produces a
# DataFrame with the expected columns; without this, later lookups such as
# data['motorbike_image_path'] raise KeyError when no rows match.
data = pd.DataFrame(cursor.fetchall(), columns=['id', 'motorbike_image_path'])

# Display the retrieved data
print(data.head())

# Step 4: Generate Image Embeddings Using CLIP

In [None]:
# Function to generate image embeddings using CLIP
def get_image_embedding(image_url, *, timeout=10):
    """Download an image and return its CLIP embedding.

    Args:
        image_url: URL of the image to fetch with ``requests.get``.
        timeout: Seconds to wait on the server before giving up, so a single
            unresponsive host cannot stall the whole run.

    Returns:
        The output of ``model.encode_image`` as a flattened 1-D NumPy array,
        or ``None`` if the image could not be downloaded, decoded or encoded
        (the failure is printed).
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        # Fail on 4xx/5xx here; otherwise the error page body would be handed
        # to PIL and misreported as an unidentifiable image.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))

        # Preprocess the image for the CLIP model and add a batch dimension
        img_preprocessed = preprocess(img).unsqueeze(0).to(device)

        # Generate embedding using the CLIP model (inference only, no grads)
        with torch.no_grad():
            embedding = model.encode_image(img_preprocessed).cpu().numpy().flatten()

        return embedding
    except UnidentifiedImageError:
        print(f"Cannot identify image from URL: {image_url}")
        return None
    except Exception as e:
        # Deliberately broad: one bad image must not abort the batch.
        print(f"Error processing image from URL: {image_url}, Error: {e}")
        return None

# Compute one embedding per image path; rows whose image fails get None
image_paths = data['motorbike_image_path']
data['embedding'] = image_paths.apply(get_image_embedding)

# Preview the ids alongside their embeddings to confirm the column was filled
preview = data[['id', 'embedding']].head()
print(preview)