<a href="https://colab.research.google.com/github/Adithyan773/IKEA_recomendation_system/blob/main/IKEA_final_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the CPU build of FAISS, which the next cell uses to load and search vector indices.
!pip install -q faiss-cpu

In [None]:
import numpy as np
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertModel
from tensorflow.keras.layers import Dense, Input, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import faiss
from IPython.display import Image, display

# Keras layer that turns (input_ids, attention_mask) into one vector per sequence.
class DistilBertEmbeddingLayer(Layer):
    """Wrap a DistilBERT model and emit the hidden state of the first token.

    Args:
        distilbert_model: the transformer to call; it receives ``input_ids``
            and ``attention_mask`` keyword arguments and must return an object
            exposing ``last_hidden_state``.
        **kwargs: forwarded unchanged to ``keras.layers.Layer``.
    """

    def __init__(self, distilbert_model, **kwargs):
        super().__init__(**kwargs)
        self.distilbert_model = distilbert_model

    def call(self, inputs):
        """Return ``last_hidden_state[:, 0, :]`` for the given token batch.

        ``inputs`` is a pair ``(input_ids, attention_mask)``. Position 0 is
        the first token (the [CLS] slot for BERT-style tokenizers).
        """
        token_ids, mask = inputs
        hidden_states = self.distilbert_model(
            input_ids=token_ids,
            attention_mask=mask,
        ).last_hidden_state
        return hidden_states[:, 0, :]

    def get_config(self):
        """Return only the base ``Layer`` config.

        The wrapped model is not serialized, so ``from_config`` alone cannot
        rebuild this layer.
        """
        return super().get_config()

# Load FAISS indices and metadata.
image_index = faiss.read_index('/content/image_embeddings_fixed.faiss')
text_index = faiss.read_index('/content/text_embeddings_fixed.faiss')
image_paths = np.load('/content/image_paths_fixed.npy')
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects (and can
# execute code while doing so); only load this file from a trusted source.
full_metadata = np.load('/content/full_metadata_fixed.npy', allow_pickle=True)

# Metadata field names, in the order the values appear in each metadata row
# (consistent with the text embedding code).
all_columns = ['name', 'category', 'short_description', 'designer', 'depth', 'height', 'width', 'price', 'old_price', 'image_description']

# Recover the stored vectors from each index. reconstruct_n fetches the whole
# range in one call instead of one Python-level reconstruct() call per row.
image_embeddings = image_index.reconstruct_n(0, image_index.ntotal)
text_embeddings = text_index.reconstruct_n(0, text_index.ntotal)

# The four sources are used as row-aligned arrays: the same row index selects
# the text, image, path and metadata of one product. Truncate to the shortest
# so every index is valid, but report a mismatch instead of silently dropping rows.
source_lengths = {
    'image_embeddings': len(image_embeddings),
    'text_embeddings': len(text_embeddings),
    'image_paths': len(image_paths),
    'full_metadata': len(full_metadata),
}
n_samples = min(source_lengths.values())
if len(set(source_lengths.values())) > 1:
    print(f"Warning: input lengths differ {source_lengths}; truncating to {n_samples} rows.")
image_embeddings = image_embeddings[:n_samples]
text_embeddings = text_embeddings[:n_samples]
image_paths = image_paths[:n_samples]
full_metadata = full_metadata[:n_samples]

# Create projection model
def create_projection_model(input_dim=256, output_dim=256, hidden_dim=512, output_activation=None):
    """Build a two-layer MLP that maps ``input_dim`` vectors into a shared ``output_dim`` space.

    The output layer is linear by default. A ReLU on the output would confine
    embeddings to the non-negative orthant and let a row collapse to all zeros.
    Such a row turns into NaN when it is L2-normalised for inner-product search.

    Args:
        input_dim: width of the incoming embedding.
        output_dim: width of the projected embedding.
        hidden_dim: width of the ReLU hidden layer.
        output_activation: Keras activation for the output layer (``None`` = linear).

    Returns:
        An uncompiled ``keras.Model`` taking ``(batch, input_dim)`` and
        returning ``(batch, output_dim)``.
    """
    input_layer = Input(shape=(input_dim,))
    x = Dense(hidden_dim, activation='relu')(input_layer)
    x = Dense(output_dim, activation=output_activation)(x)
    return Model(input_layer, x)

# One projection head per modality; text is created first, then image.
text_projection, image_projection = create_projection_model(), create_projection_model()

# Optimisation hyper-parameters for the triplet-loss loop below.
batch_size, num_epochs = 32, 20
optimizer = Adam(learning_rate=1e-4)
margin = 1.0

# Train the two projection heads jointly with a symmetric triplet loss:
# each text row should be closer to its own image than to a randomly chosen
# other image by at least `margin`, and each image closer to its own text
# than to a randomly chosen other text.
def _row_distance(a, b, eps=1e-12):
    """Return the Euclidean distance between matching rows of ``a`` and ``b``.

    ``tf.norm`` has a NaN gradient when its argument is exactly zero (two
    identical projections, e.g. both all-zero), and that NaN would corrupt
    every weight in the update. The small ``eps`` inside the square root keeps
    the gradient finite.
    """
    return tf.sqrt(tf.reduce_sum(tf.square(a - b), axis=-1) + eps)


def _sample_negatives(positive_indices, n_total):
    """Draw one random row index per positive, uniformly among the *other* rows.

    A random offset in ``[1, n_total)`` taken modulo ``n_total`` can never land
    on the positive itself. Plain ``np.random.choice`` occasionally returned the
    positive and turned that triplet into a no-op. With fewer than two rows
    there is no other row, so the positives are returned unchanged.
    """
    if n_total < 2:
        return np.copy(positive_indices)
    offsets = np.random.randint(1, n_total, size=len(positive_indices))
    return (positive_indices + offsets) % n_total


trainable_vars = text_projection.trainable_weights + image_projection.trainable_weights

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_losses = []
    permutation = np.random.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        batch_indices = permutation[start:start + batch_size]
        neg_indices = _sample_negatives(batch_indices, n_samples)

        text_batch = text_embeddings[batch_indices]
        image_batch = image_embeddings[batch_indices]
        neg_text_batch = text_embeddings[neg_indices]
        neg_image_batch = image_embeddings[neg_indices]

        with tf.GradientTape() as tape:
            text_proj = text_projection(text_batch)
            image_proj = image_projection(image_batch)
            neg_image_proj = image_projection(neg_image_batch)
            neg_text_proj = text_projection(neg_text_batch)

            # Distance between matching pairs is symmetric, so compute it once.
            pos_dist = _row_distance(text_proj, image_proj)
            # Text anchor vs. a mismatched image.
            neg_dist_text = _row_distance(text_proj, neg_image_proj)
            loss_text = tf.reduce_mean(tf.maximum(pos_dist - neg_dist_text + margin, 0.0))
            # Image anchor vs. a mismatched text.
            neg_dist_image = _row_distance(image_proj, neg_text_proj)
            loss_image = tf.reduce_mean(tf.maximum(pos_dist - neg_dist_image + margin, 0.0))

            loss = loss_text + loss_image

        grads = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(zip(grads, trainable_vars))
        epoch_losses.append(float(loss))

    avg_epoch_loss = np.mean(epoch_losses)
    print(f"Epoch {epoch+1}  Loss: {avg_epoch_loss:.4f}")

# Project every catalogue text embedding into the shared space and
# L2-normalise the rows, so that inner product equals cosine similarity.
projected_text_embeddings = text_projection.predict(text_embeddings, batch_size=32)
row_norms = np.linalg.norm(projected_text_embeddings, axis=1, keepdims=True)
# Clamp the norm so an all-zero row stays zero instead of becoming NaN,
# which would poison search results. FAISS needs contiguous float32 input.
projected_text_embeddings = np.ascontiguousarray(
    projected_text_embeddings / np.maximum(row_norms, 1e-12), dtype=np.float32
)

# Exact inner-product (cosine, given the normalisation) index over the
# projections. The dimension is taken from the data rather than hard-coded.
d = projected_text_embeddings.shape[1]
new_text_index = faiss.IndexFlatIP(d)
new_text_index.add(projected_text_embeddings)

# Load query model
# Rebuild the query-side text encoder and restore its trained weights:
# token ids + attention mask -> DistilBERT (first-token vector, via
# DistilBertEmbeddingLayer) -> 256-unit ReLU Dense layer named 'text_embedding'.
# NOTE(review): load_weights presumably requires this graph (layer order,
# names, sizes) to match the model the .weights.h5 file was saved from --
# keep it in sync with the training code.
tokenizer = DistilBertTokenizer.from_pretrained('/content/drive/MyDrive/distilbert_v3')
# from_pt=True tells transformers to load the weights from a PyTorch checkpoint in that directory.
distilbert_base = TFDistilBertModel.from_pretrained('/content/drive/MyDrive/distilbert_v3', from_pt=True)
# Variable-length (shape=(None,)) int32 token inputs, one row per query in the batch.
query_input_ids = Input(shape=(None,), dtype=tf.int32, name='input_ids')
query_attention_mask = Input(shape=(None,), dtype=tf.int32, name='attention_mask')
distilbert_layer = DistilBertEmbeddingLayer(distilbert_base)([query_input_ids, query_attention_mask])
# 256-d output; text_projection (built with input_dim=256) is applied to it downstream.
query_embedding = Dense(256, activation='relu', name='text_embedding')(distilbert_layer)
query_model = Model([query_input_ids, query_attention_mask], query_embedding)
query_model.load_weights('/content/text_model_weights_fixed.weights.h5')

# Function to get projected query embedding
def get_projected_query_embedding(query):
    """Embed a free-text query into the shared, L2-normalised projection space.

    The query is repeated three times (space-separated) before tokenisation.
    Per the original note, this aligns with the text embedding code; confirm
    that against the embedding-generation script.

    Args:
        query: the search string.

    Returns:
        A 1-D float32 numpy vector of unit length, or all zeros if the
        projection itself is all zeros.
    """
    enhanced_query = ' '.join([query] * 3)
    inputs = tokenizer(enhanced_query, return_tensors='tf', padding=True, truncation=True, max_length=128)
    # A single example fits in one batch, so call the models directly;
    # predict() adds noticeable per-call overhead for tiny inputs.
    initial_embedding = query_model([inputs['input_ids'], inputs['attention_mask']], training=False)
    projected_embedding = text_projection(initial_embedding, training=False).numpy()
    norm = np.linalg.norm(projected_embedding, axis=1, keepdims=True)
    # Clamp the norm so an all-zero projection does not become NaN.
    projected_embedding = projected_embedding / np.maximum(norm, 1e-12)
    return projected_embedding[0]

# Function to get top-k products (simplified, no material filtering)
def get_top_k_products(query, k=5):
    """Return the catalogue items whose projected text is most similar to ``query``.

    Args:
        query: the search string.
        k: maximum number of results. Fewer are returned when the index holds
            fewer items; an empty index (or ``k <= 0``) yields empty results.

    Returns:
        ``(similarities, image_paths, metadata)``. ``similarities`` is a numpy
        array of inner-product scores in descending order. ``image_paths`` lists
        the matching entries of the module-level ``image_paths``. ``metadata``
        holds one dict per hit, mapping ``all_columns`` to that row's values.
    """
    k = min(k, new_text_index.ntotal)
    if k <= 0:
        return np.empty(0, dtype=np.float32), [], []

    query_embedding = np.ascontiguousarray(
        get_projected_query_embedding(query).reshape(1, -1), dtype=np.float32
    )
    similarities, indices = new_text_index.search(query_embedding, k)
    # FAISS pads missing results with index -1. Indexing a numpy array with -1
    # would silently return the last catalogue item, so drop those slots.
    valid = indices[0] >= 0
    top_k_indices = indices[0][valid]
    top_k_similarities = similarities[0][valid]
    top_k_images = [image_paths[idx] for idx in top_k_indices]
    top_k_metadata = [dict(zip(all_columns, full_metadata[idx])) for idx in top_k_indices]
    return top_k_similarities, top_k_images, top_k_metadata

In [None]:
# Run a sample search and render each hit: score, image path, image, and metadata fields.
def _show_match(rank, score, path, record):
    """Print one search hit and try to render its product image inline."""
    print(f"{rank}. Similarity: {score:.4f}")
    print(f"   Image Path: {path}")
    try:
        display(Image(filename=path, width=200, height=200))
    except Exception as err:
        # Best effort: an unreadable image should not stop the listing.
        print(f"   [Error displaying image: {err}]")
    for field_name, field_value in record.items():
        print(f"   {field_name.capitalize()}: {field_value}")
    print()


query = "small chair"
top_k_similarities, top_k_images, top_k_metadata = get_top_k_products(query, k=5)

print(f"Matches for query: '{query}':")
for rank, hit in enumerate(zip(top_k_similarities, top_k_images, top_k_metadata), start=1):
    _show_match(rank, *hit)