In [None]:
# *******************************************************************************************
# *                                                                                         *
# *  Hewlett Packard Enterprise Confidential.                                               *
# *                                                                                         *
# *  This software is the proprietary information of Hewlett Packard Enterprise.            *
# *                                                                                         *
# * Author : Nishant Chanduka (nishant.chanduka@hpe.com)                                    *
# *******************************************************************************************

In [None]:
# This notebook contains sample code to
#     1. Connect to an S3 bucket
#     2. Read images from the S3 bucket
#     3. Convert images to numpy arrays
#     4. Load an embedding model
#     5. Generate image embeddings
#     6. Connect to a Weaviate DB instance
#     7. Store the image embeddings in a Weaviate collection

In [None]:
# Before you choose to run this code, create a conda env. Sample steps/commands below
# Commands:
# conda create --name weaviate-env python=3.11
# conda activate weaviate-env
# pip install weaviate-client boto3 numpy tensorflow Pillow
# conda install -c anaconda ipykernel

# If the env is not listed as a kernel under notebooks, run the command below, then close the notebook browser window and relaunch it.
# python -m ipykernel install --user --name=weaviate-env

In [None]:
import json
from io import BytesIO

import boto3
import numpy as np
import tensorflow as tf
import weaviate
from PIL import Image
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import img_to_array

In [None]:
# Configure S3 connection.
# The endpoint URL is assembled from the service name, namespace, resource
# type, cluster domain and port defined below.
object_storage_service_name = "source-images-service"
object_storage_namespace = ".ezdata-system"
resource_type = ".svc"
domain = ".cluster.local"
object_storage_port = "30000"

s3_endpoint_url = f"http://{object_storage_service_name}{object_storage_namespace}{resource_type}{domain}:{object_storage_port}"
print(f"S3 endpoint URL: {s3_endpoint_url}")

# Create S3 clients (no credentials are passed here, so boto3 resolves them
# from its default credential chain).
s3_client = boto3.client('s3', endpoint_url=s3_endpoint_url)
s3_resource = boto3.resource('s3', endpoint_url=s3_endpoint_url)

# Set bucket name
bucket_name = "poc-mercedes-gp"

# Load the dataset JSON file
file_key = "training/training_dataset.json"
response = s3_client.get_object(Bucket=bucket_name, Key=file_key)
content = response["Body"].read().decode("utf-8")
dataset = json.loads(content)

# Display information about the dataset
print(f"Dataset size: {len(dataset)} images")
if dataset:
    # Guarded so an empty dataset does not raise IndexError.
    print("First image information:")
    print(dataset[0])

# Select the trailing entries of the dataset as a test subset:
# 20% of the dataset, capped at 10 images (whichever is smaller).
test_set_size = min(10, int(len(dataset) * 0.2))
test_indices = list(range(len(dataset) - test_set_size, len(dataset)))
if test_indices:
    print(f"\nSelected {test_set_size} test images (indices {test_indices[0]} to {test_indices[-1]})")
else:
    # With fewer than 5 entries, 20% rounds down to 0 and nothing is selected.
    print("\nDataset is too small to select any test images")

In [6]:
import weaviate, os
from weaviate.classes.init import Auth

def connect_to_weaviate():
    """Open a connection to the Weaviate instance.

    Reads the auth token from the mounted secret file, sends it as the
    ``x-auth-token`` header, and connects over plain HTTP (port 80) and
    plain gRPC (port 50051).

    Returns:
        The connected client when ``client.is_ready()`` is true, otherwise
        ``None``. On failure the client is closed before returning so the
        connection is not leaked.
    """
    # Getting the auth token
    secret_file_path = "/etc/secrets/ezua/.auth_token"
    with open(secret_file_path, "r") as file:
        token = file.read().strip()

    # Connect to Weaviate instance.
    # Both host names are spelled out in full so this function does not depend
    # on the `domain` variable defined in the S3 configuration cell.
    weaviate_http_host = "weaviate.poc-weaviate.svc.cluster.local"
    weaviate_grpc_host = "weaviate-grpc.poc-weaviate.svc.cluster.local"
    weaviate_headers = {"x-auth-token": token}

    client = weaviate.connect_to_custom(
        http_host=weaviate_http_host,        # Hostname for the HTTP API connection
        http_port=80,                        # Default is 80, WCD uses 443
        http_secure=False,                   # Whether to use https (secure) for the HTTP API connection
        grpc_host=weaviate_grpc_host,        # Hostname for the gRPC API connection
        grpc_port=50051,                     # Default is 50051, WCD uses 443
        grpc_secure=False,                   # Whether to use a secure channel for the gRPC API connection
        headers=weaviate_headers,
        skip_init_checks=False
    )

    # Test the connection
    try:
        if client.is_ready():
            # You can now interact with your Weaviate instance
            print("Successfully connected to Weaviate with custom configuration!")
            return client
        print("Failed to connect to Weaviate.")
    except Exception as e:
        print(f"Error connecting to Weaviate: {e}")

    # Only reached on failure: release the connection and signal it with None.
    client.close()
    return None

In [13]:
def load_embedding_model():
    """Build the ResNet50 network used to embed images.

    The model is created with ImageNet weights, without the top
    classification layers (``include_top=False``) and with average pooling
    applied to the output (``pooling='avg'``).
    """
    return ResNet50(include_top=False, pooling='avg', weights='imagenet')

In [7]:
def list_images_in_s3(bucket_name, prefix=''):
    """Return the keys of all image objects under ``prefix`` in a bucket.

    Args:
        bucket_name: Name of the S3 bucket to list.
        prefix: Only keys starting with this prefix are listed. The default
            empty string lists the whole bucket.

    Returns:
        A list of object keys whose name ends (case-insensitively) in
        ``.png``, ``.jpg`` or ``.jpeg``.
    """
    # The leading dot keeps keys such as "notes-jpg" from matching.
    image_suffixes = ('.png', '.jpg', '.jpeg')

    # A single list_objects_v2 call returns one page only, so walk every page.
    paginator = s3_client.get_paginator('list_objects_v2')
    image_keys = []
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for item in page.get('Contents', []):
            if item['Key'].lower().endswith(image_suffixes):
                image_keys.append(item['Key'])
    return image_keys

In [8]:
def read_image_from_s3(bucket_name, object_key):
    """Download one object from S3 and open it as a PIL image.

    Args:
        bucket_name: Bucket that holds the object.
        object_key: Key of the image object.

    Returns:
        The result of ``Image.open`` on the downloaded bytes.
    """
    s3_object = s3_client.get_object(Bucket=bucket_name, Key=object_key)
    raw_bytes = s3_object['Body'].read()
    # The bytes are wrapped in an in-memory buffer so PIL can read them like a file.
    return Image.open(BytesIO(raw_bytes))

In [9]:
def image_to_numpy_array(image, target_size=(224, 224)):
    """Turn a PIL image into a preprocessed, batched array for the model.

    Args:
        image: PIL image to convert.
        target_size: Size passed to ``image.resize``; defaults to (224, 224).

    Returns:
        The output of ``preprocess_input`` applied to the image array with a
        leading batch axis added.
    """
    # NOTE(review): the resize runs before the RGB conversion, so palette-mode
    # images are resampled in their original mode - consider converting first.
    rgb_image = image.resize(target_size).convert('RGB')
    # Add the leading batch axis, then apply the ResNet50 preprocessing.
    batch = np.expand_dims(img_to_array(rgb_image), axis=0)
    return preprocess_input(batch)

In [10]:
def get_image_embedding(model, img_array):
    """Run the model on one prepared image array and return a flat embedding.

    Args:
        model: Object exposing ``predict(img_array)`` that returns an array.
        img_array: Input array handed to ``model.predict`` unchanged.

    Returns:
        The prediction flattened to one dimension.
    """
    return model.predict(img_array).flatten()

In [16]:
def store_embeddings_in_weaviate(embeddings, image_keys, client, collection):
    """Insert one object per image into the given Weaviate collection.

    Each object carries the image key and the embedding as properties, and
    the embedding is also supplied as the object's vector so the collection
    can be queried by vector similarity.

    Args:
        embeddings: Iterable of embeddings; each must provide ``tolist()``.
        image_keys: Iterable of image keys, in the same order as ``embeddings``.
        client: Not used by this function; kept so existing callers still work.
        collection: Collection object exposing ``data.insert``.

    Raises:
        ValueError: If ``embeddings`` and ``image_keys`` differ in length.
    """
    embeddings = list(embeddings)
    image_keys = list(image_keys)
    # zip() would silently drop the surplus items, so fail loudly before inserting.
    if len(embeddings) != len(image_keys):
        raise ValueError(
            f"Got {len(embeddings)} embeddings for {len(image_keys)} image keys"
        )

    for key, embedding in zip(image_keys, embeddings):
        vector = embedding.tolist()
        collection.data.insert(
            properties={
                "image_key": key,
                "image_embedding": vector
            },
            vector=vector
        )

In [None]:
import weaviate, os
import weaviate.classes.config as wc
from weaviate.classes.config import Configure, DataType, Multi2VecField, Property

def main(prefix=''):
    """Embed every image under ``prefix`` in the bucket and store it in Weaviate.

    Steps: connect to Weaviate, recreate the ``MercedesImageEmbedding``
    collection, load the embedding model, list the images in ``bucket_name``,
    compute one embedding per image and insert them into the collection.

    Args:
        prefix: Key prefix used to select images in the bucket. The default
            empty string processes every image in the bucket.

    Raises:
        RuntimeError: If no Weaviate connection could be established.
    """
    collection_name = "MercedesImageEmbedding"

    # Initialize Weaviate client
    client_conn = connect_to_weaviate()
    if client_conn is None:
        raise RuntimeError("Could not connect to Weaviate; nothing was stored.")

    # try/finally so the connection is closed even if a step below fails.
    try:
        # Start from an empty collection on every run.
        if client_conn.collections.exists(collection_name):
            client_conn.collections.delete(collection_name)

        client_conn.collections.create(
            name=collection_name,
            properties=[
                Property(name="image_key", data_type=DataType.TEXT),
                Property(name="image_embedding", data_type=DataType.NUMBER_ARRAY),
            ],
            vectorizer_config=wc.Configure.Vectorizer.none(),
        )

        collection = client_conn.collections.get(collection_name)

        # Load the embedding model
        embedding_model = load_embedding_model()

        # List images in S3
        image_keys = list_images_in_s3(bucket_name, prefix)

        embeddings = []
        for key in image_keys:
            # Read image from S3
            pil_image = read_image_from_s3(bucket_name, key)
            print(pil_image)

            # Convert image to NumPy array
            image_array = image_to_numpy_array(pil_image)

            # Get image embedding and collect it
            embeddings.append(get_image_embedding(embedding_model, image_array))

        # Store embeddings in Weaviate
        store_embeddings_in_weaviate(embeddings, image_keys, client_conn, collection)
    finally:
        client_conn.close()
    
# Entry point: run the whole pipeline when this code is executed as the main
# module.
if __name__ == "__main__":
    main()