In [74]:
from pymilvus import (
    connections,
    utility,
    FieldSchema,
    CollectionSchema,
    DataType,
    Collection,
    db,
    MilvusClient
)
import clip
import torch
import numpy as np
from PIL import Image
from torch.nn.functional import cosine_similarity
import os

In [2]:
# Address of the Milvus server this notebook talks to
host = "localhost"
port = 19530  # 19530 is noted in the original cell as the default Milvus port

# Register and open the "default" connection alias used by the rest of the notebook.
# NOTE(review): connections.connect() appears to return None in pymilvus, so `client`
# is presumably not a usable handle; the binding is kept only so the notebook
# namespace stays unchanged — confirm and remove if nothing relies on it.
client = connections.connect("default", host=host, port=port)

# Report whether the "default" alias is registered as connected
print("Is Milvus connected:", connections.has_connection("default"))

# Optional sanity check: list the collections reachable through this connection.
# `utility` comes from the notebook's import cell, so it is not re-imported here.
print("Collections:", utility.list_collections())

Is Milvus connected: True
Collections: ['product_collection', 'products', 'images']


In [4]:
# database = db.create_database("Building_Designs")

In [5]:
# Select the database that holds the design collection, creating it first when it
# does not exist yet so the notebook also runs against a fresh Milvus instance.
DATABASE_NAME = "Building_Designs"
if DATABASE_NAME not in db.list_database():
    db.create_database(DATABASE_NAME)
db.using_database(DATABASE_NAME)

In [86]:
# Fetch the collection names reported by Milvus and print them as a sanity check
collections = utility.list_collections()
print(collections)

['designs']


In [109]:
# Dimensionality shared by both embedding columns.
# NOTE(review): presumably matches the output size of the CLIP model loaded later — confirm.
EMBEDDING_DIM = 768

# Auto-generated INT64 primary key
building_id_field = FieldSchema(
    name="Building_id",
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True,
)

# Base name of the image pair (VARCHAR, up to 1000 characters)
image_name_field = FieldSchema(
    name="Image_Name",
    dtype=DataType.VARCHAR,
    max_length=1000,
)

# Embedding of the generated building image
generated_embedding_field = FieldSchema(
    name="building_generated_image_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=EMBEDDING_DIM,
)

# Embedding of the building sketch
sketch_embedding_field = FieldSchema(
    name="building_sketch_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=EMBEDDING_DIM,
)

# Column order matters: column-based inserts must follow it (minus the auto-id key)
fields = [
    building_id_field,
    image_name_field,
    generated_embedding_field,
    sketch_embedding_field,
]

In [110]:
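# Manual reset (kept commented out on purpose): uncomment these lines to release and
# drop the existing 'designs' collection, e.g. before recreating it with a changed
# schema. Dropping the collection discards the data stored in it.
# buildings_collection = Collection('designs')
# buildings_collection.release()
# buildings_collection.drop()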

In [111]:
# Bundle the field definitions into a schema and bind a Collection object named 'designs' to it.
# NOTE(review): presumably this creates the collection when it does not exist yet and
# otherwise attaches to the existing one — confirm against the pymilvus Collection docs.
design_schema = CollectionSchema(fields, description="Architectural Designs collection")
buildings_collection = Collection(name = 'designs', schema=design_schema)

In [112]:
# Index settings shared by both vector fields: an IVF_FLAT index searched with the
# cosine metric, built with nlist=256.
index_params = dict(
    metric_type="COSINE",
    index_type="IVF_FLAT",
    params=dict(nlist=256),
)

In [113]:
# Build the same index on each vector field (sketch first, then generated image),
# then load the collection.
for vector_field in (
    "building_sketch_embedding",
    "building_generated_image_embedding",
):
    buildings_collection.create_index(field_name=vector_field, index_params=index_params)

buildings_collection.load()

In [102]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the CLIP "ViT-L/14@336px" model on that device together with the image
# preprocessing callable that is applied to PIL images before encoding
clip_model, preprocess = clip.load("ViT-L/14@336px", device=device)

In [103]:
def generate_image_embedding(image):
    """Return the L2-normalised CLIP embedding of a single image.

    Args:
        image: Either a file path (``str`` or ``os.PathLike``) of an image to load,
            or a NumPy array that ``PIL.Image.fromarray`` can convert to an image.

    Returns:
        A 1-D ``float32`` NumPy array holding the unit-length image embedding
        produced by ``clip_model.encode_image``.

    Raises:
        ValueError: If ``image`` is neither a file path nor a NumPy array.
    """
    if isinstance(image, (str, os.PathLike)):
        # Image.open keeps the file handle open until the pixel data is read;
        # convert() reads the data and returns an independent copy, so the file
        # can be closed deterministically by the context manager.
        with Image.open(image) as opened_image:
            image_pil = opened_image.convert("RGB")
    elif isinstance(image, np.ndarray):
        # Convert the NumPy array to a PIL Image
        image_pil = Image.fromarray(image)
    else:
        raise ValueError("The input must be a file path or a NumPy array")

    # Preprocess the image and add the batch dimension expected by the encoder
    image_tensor = preprocess(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        image_embedding = clip_model.encode_image(image_tensor)
        # The encoder may return half precision (e.g. on GPU); cast to float32 so
        # the normalisation is done in full precision and the result always has
        # the dtype expected for a float vector column.
        image_embedding = image_embedding.float()
        image_embedding = image_embedding / image_embedding.norm(dim=-1, keepdim=True)

    # Drop the batch dimension and hand back a plain NumPy vector
    return image_embedding.cpu().numpy()[0]


In [104]:
# Directory whose entries are listed as the sketch images to ingest
SKETCHES_IMAGES_DIR = 'sketches_images'
# Directory holding the generated counterparts, looked up as "<image_name>_generated.jpg"
GENERATED_IMAGES_DIR = 'generated_images_512'

In [105]:
# os.listdir returns entries in arbitrary order; sort them so the ingestion order
# (and therefore the batches and auto-generated ids) is reproducible across runs.
sketches_images = sorted(os.listdir(SKETCHES_IMAGES_DIR))

In [114]:
# Per-batch column buffers, kept in the same order as the non-auto-id schema fields
image_names, generated_images_embeddings, sketch_embeddings = [], [], []

# Number of rows buffered before each insert
batch_size = 100

total_sketches = len(sketches_images)

for position, image in enumerate(sketches_images, start=1):
    # NOTE(review): dropping the last 11 characters presumably strips a fixed
    # file-name suffix (something like "_sketch.jpg") — confirm against the data.
    image_name = image[:-11]
    generated_image_name = f"{image_name}_generated.jpg"

    sketch_image_path = os.path.join(SKETCHES_IMAGES_DIR, image)
    generated_image_path = os.path.join(GENERATED_IMAGES_DIR, generated_image_name)

    # Embed both images before touching the buffers so the columns stay aligned
    sketch_image_embedding = generate_image_embedding(sketch_image_path)
    generated_image_embedding = generate_image_embedding(generated_image_path)

    image_names.append(image_name)
    generated_images_embeddings.append(generated_image_embedding)
    sketch_embeddings.append(sketch_image_embedding)

    # Keep buffering until a full batch is collected or the final image is reached
    if position % batch_size != 0 and position != total_sketches:
        continue

    # Column-based insert: one list per field, in schema order
    entities = [image_names, generated_images_embeddings, sketch_embeddings]
    buildings_collection.insert(entities)

    # Empty the buffers in place so the next batch starts clean
    for column in entities:
        column.clear()

# Flush the collection once every batch has been inserted
buildings_collection.flush()