In [71]:
import json
import numpy as np
from sentence_transformers import SentenceTransformer
from PIL import Image
import os


In [72]:
# Project root: the parent of the current working directory. All data and
# output paths in this notebook are resolved against it.
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))

# -----------------------------
# Load dataset
# -----------------------------
# JSON is UTF-8 by specification; pass the encoding explicitly so decoding does
# not depend on the platform's default locale encoding.
with open(os.path.join(BASE_DIR, "data/train/train_dataset.json"), "r", encoding="utf-8") as f:
    train_data = json.load(f)

with open(os.path.join(BASE_DIR, "data/val/val_dataset.json"), "r", encoding="utf-8") as f:
    val_data = json.load(f)

In [73]:

# -----------------------------
# Initialize text embedding model
# -----------------------------
# Model name handed to SentenceTransformer; kept as a constant so it is easy
# to spot and swap.
TEXT_MODEL_NAME = "all-MiniLM-L6-v2"
text_model = SentenceTransformer(TEXT_MODEL_NAME)

# -----------------------------
# Helper function
# -----------------------------
def get_embeddings(samples, image_size=(128, 128)):
    """Build one combined text + image feature vector per sample.

    The ``"user_input"`` of every sample is encoded with the module-level
    ``text_model``. The file at ``"image_path"`` (resolved against
    ``BASE_DIR``) is opened, converted to RGB, resized to ``image_size`` and
    flattened into raw pixel values. Text and image parts are then
    concatenated column-wise.

    Parameters
    ----------
    samples : sequence of dict
        Records providing a ``"user_input"`` and an ``"image_path"`` key.
    image_size : tuple of int, optional
        ``(width, height)`` handed to ``Image.resize``. Defaults to
        ``(128, 128)``.

    Returns
    -------
    numpy.ndarray
        float32 array with one row per sample: the text embedding followed by
        ``width * height * 3`` pixel values.

    Notes
    -----
    * An image that cannot be loaded is reported on stdout and replaced by an
      all-zero vector, so one bad file does not abort the whole run.
    * Pixel values are stored as read from the image; no rescaling is applied
      before concatenating them with the text embedding.
    """
    # Text embeddings (one batched call for all samples)
    text_embeds = text_model.encode(
        [s["user_input"] for s in samples], convert_to_numpy=True
    )

    width, height = image_size
    flat_len = width * height * 3  # RGB -> three channels per pixel

    # Image embeddings
    image_embeds = []
    for s in samples:
        img_path = os.path.join(BASE_DIR, s["image_path"])
        try:
            # Context manager closes the underlying file handle once the
            # pixels have been copied into the array.
            with Image.open(img_path) as img:
                img_array = np.array(
                    img.convert("RGB").resize(image_size), dtype=np.float32
                ).flatten()
        except Exception as e:
            # Deliberately best-effort: log the offending path and fall back
            # to zeros. The message uses img_path directly; indexing the
            # sample dict with the path would raise KeyError here.
            print(f"Error loading image {img_path}: {e}")
            img_array = np.zeros(flat_len, dtype=np.float32)
        image_embeds.append(img_array)

    image_embeds = np.array(image_embeds, dtype=np.float32)

    # Concatenate text + image embeddings
    return np.hstack([text_embeds, image_embeds]).astype(np.float32)


In [74]:

# -----------------------------
# Generate embeddings
# -----------------------------
# Train split is embedded first, then validation; both go through the same
# helper so the feature layout is identical across splits.
train_embeddings, val_embeddings = (
    get_embeddings(train_data),
    get_embeddings(val_data),
)


In [75]:

# -----------------------------
# Save embeddings
# -----------------------------
# Make sure the output directory exists under the project root.
embeddings_dir = os.path.join(BASE_DIR, "embeddings")
os.makedirs(embeddings_dir, exist_ok=True)

# Write each split to its own .npy file, train first, then validation.
for rel_path, split_embeddings in (
    ("embeddings/train_embeddings.npy", train_embeddings),
    ("embeddings/val_embeddings.npy", val_embeddings),
):
    np.save(os.path.join(BASE_DIR, rel_path), split_embeddings)

print("Embeddings generated and saved in /embeddings folder")


Embeddings generated and saved in /embeddings folder
