# Script Description

This script performs the following tasks:

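1. **Load Libraries**:
   - Uses `json` for handling JSON files and `ast.literal_eval` for restoring the dataset's tuple keys.
   - Uses `sentence-transformers` (MiniLM) or `gensim` (word2vec) for generating text embeddings.
   - Uses `torch` and `torchvision` (ResNet-18) or OpenAI's `clip` package for generating image embeddings.
   - Uses `PIL` for image processing.
   - Uses `os` for file path operations, `numpy` for averaging word vectors, and `tqdm` for a progress bar.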

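2. **Load Data**:
   - Loads the main dataset: `augmented_data/augmented_dataset_{TRAIN_TEST_VAL_TYPE}.json` in the code below (or another dataset file such as `initial_dataset_no_aug.json`). Its keys are stringified tuples whose first element is the prompt.
   - Loads `image_metadata.json`, which maps image IDs to image filenames.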

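3. **Initialize Models**:
   - Initializes the text embedding model selected by `TEXT_EMBEDDING_TYPE`: `all-MiniLM-L6-v2` from `sentence-transformers`, or the pre-trained GoogleNews word2vec vectors loaded with `gensim`.
   - Initializes the image embedding model selected by `IMAGE_EMBEDDING_TYPE`: a pre-trained ResNet-18 from `torchvision`, or CLIP `ViT-B/32`.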

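4. **Preprocess Images**:
   - For ResNet-18, defines a preprocessing pipeline that resizes the shorter side to 256 pixels, center-crops to 224 × 224, converts the image to a tensor, and normalizes it with the ImageNet mean and standard deviation. CLIP uses its own preprocessing function returned by `clip.load`.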

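5. **Generate Embeddings**:
   - For each prompt in the dataset, generates a text embedding.
   - Identifies the furniture item named in the prompt and generates an image embedding for every page image of that item's manual that exists under `data_wiki/`.
   - Converts the IDs of the top 5 relevant images (from the prompt's `idxs_and_scores`) into indices within that manual's list of image embeddings.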

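6. **Store Results**:
   - Stores, for each prompt, the text embedding, the top-5 image indices, the furniture name and the scores in one dictionary, and the per-manual image embeddings in a second dictionary.
   - Saves them to `embeddings/{text}_{image}_embeddings_{split}_aug.json` and `embeddings/{text}_{image}_embeddings_raw_{split}_aug.json`.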

7. **Print Confirmation**:
   - Prints a message confirming that the embeddings have been saved.

Before saving, the script checks that every top-5 index falls within the bounds of the corresponding manual's image embeddings and prints an error for any index that does not.
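The word2vec option represents a prompt as the average of the vectors of its in-vocabulary words, falling back to a zero vector when no word is known. The sketch below shows that mean pooling on a toy vocabulary; the helper name `mean_word_embedding` and the 3-dimensional vectors are illustrative assumptions, not the script's own code or real word2vec values.

```python
import numpy as np

def mean_word_embedding(text, vectors, dim):
    """Average the vectors of in-vocabulary words; zero vector if none are known.

    `vectors` is any mapping from word to 1-D array (a plain dict here;
    gensim's KeyedVectors supports the same `in` and `[]` access).
    """
    known = [vectors[w] for w in text.split() if w in vectors]
    if not known:
        return [0.0] * dim
    # Element-wise mean over the known word vectors, as a JSON-friendly list
    return np.mean(known, axis=0).tolist()

# Hypothetical toy vocabulary (not real word2vec vectors)
toy_vectors = {
    "assemble": np.array([1.0, 0.0, 2.0]),
    "shelf": np.array([3.0, 2.0, 0.0]),
}

print(mean_word_embedding("assemble the shelf", toy_vectors, 3))  # [2.0, 1.0, 1.0]
print(mean_word_embedding("unknown words", toy_vectors, 3))       # [0.0, 0.0, 0.0]
```

As in the script, the lookup is case-sensitive and splits on whitespace only, so "Shelf" or "shelf," would not match the key "shelf".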


In [1]:
!pip install sentence-transformers torchvision pillow torch git+https://github.com/openai/CLIP.git gensim


In [1]:
import json

# Load the JSON data from the file
with open('image_metadata.json', 'r') as file:
    data = json.load(file)

# Extract filenames
filenames = [item['filename'] for item in data.values()]

# Extract unique furniture items from filenames
unique_items = set()
for filename in filenames:
    item = filename.split('.page')[0]
    unique_items.add(item.lower())

# Count the number of unique furniture items
num_unique_items = len(unique_items)

print(f'Total number of unique furniture items: {num_unique_items}')


Total number of unique furniture items: 90


In [2]:
print(unique_items)

{'mammut_1', 'lack', 'jules_1', 'smussla', 'klappsta', 'fredrik', 'applaro_2', 'lerhamn', 'leifarne', 'vadholma', 'sigurd', 'preben', 'laiva', 'hemnes', 'vaniljstang', 'nilsove', 'lantliv', 'lantliv_2', 'askholmen', 'nesna', 'vittsjo_1', 'skogsta', 'klingsbo_1', 'perjohan', 'stig', 'gladom', 'vittsjo', 'skogsta_2', 'ekenas', 'vedbo', 'reidar', 'jokkmokk', 'vesken', 'flisat', 'grubban', 'ingolf', 'kyrre', 'poang_2', 'jules_2', 'marius', 'hemnes_2', 'satsumas_2', 'sjalland', 'glenn', 'bjorkudden', 'applaro_3', 'lunnarp', 'nilsolle', 'ivar', 'tobias', 'olivblad', 'pinnig', 'alex', 'ragrund', 'applaro', 'norraryd', 'dalfred', 'tornviken', 'omtanksam', 'ingolf_2', 'vittsjo_2', 'voxlov', 'nordviken', 'norraker', 'nils', 'bernhard', 'falholmen', 'fjallbo', 'kaustby', 'stefan', 'fanbyn', 'teodores', 'trogen', 'bjorkudden_2', 'mammut_2', 'silveran', 'tjusig', 'satsumas', 'herman', 'froset', 'pahl', 'poang_1', 'ronninge', 'bekvam_3017', 'tommaryd', 'sundvik', 'agam', 'yngvar', 'borje', 'lisabo'}
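The prompt-to-manual lookup in the embedding cell relies on these names: a prompt is matched to a manual when one of its words, lowercased, equals one of the names above. A minimal sketch of both steps, assuming filenames of the form `<name>.page<N>...` as the `split('.page')` call implies; the example filenames and prompts are hypothetical.

```python
def furniture_key(filename):
    """Manual name for an image file: the part before '.page', lowercased."""
    return filename.split('.page')[0].lower()

def find_furniture_name(prompt, known_items):
    """Return the last word of `prompt` that names a known item, or None."""
    match = None
    for word in prompt.split():
        if word.lower() in known_items:
            match = word.lower()
    return match

# Hypothetical filenames following the '<name>.page<N>...' pattern
example_files = ["LACK.page1.png", "lack.page2.png", "poang_2.page1.png"]
example_items = {furniture_key(f) for f in example_files}

print(sorted(example_items))                                                  # ['lack', 'poang_2']
print(find_furniture_name("How do I build the Lack table?", example_items))  # lack
print(find_furniture_name("Which screws go first?", example_items))          # None
```

Because matching is done on whole whitespace-separated words, a name with punctuation attached (for example "Lack?") is not recognized.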

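## Data Structure: embeddings JSON files
Each entry of `*_embeddings_{split}_aug.json`, keyed by the stringified dataset tuple, contains:
- `text_embedding`: prompt embedding
- `image_embeddings_top5_idx`: top 5 relevant images, saved as indices into the manual's list of image embeddings in the raw file
- `image_embeddings_key`: furniture name under which that list is stored in the raw file
- `scores`: scores for the top 5 relevant images

`*_embeddings_raw_{split}_aug.json` holds the image embeddings for the entire manual of each furniture item, keyed by furniture name.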
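Reading the two files back means restoring the tuple keys with `literal_eval` (as `unwrap_keys` does) and resolving each top-5 index through `image_embeddings_key` into the raw file. A minimal sketch on in-memory JSON strings; the tuple key, vectors and scores are made up, and the assumption of exactly five scores follows the field description above.

```python
import json
from ast import literal_eval

# Hypothetical miniature versions of the two saved files (real vectors are much longer)
saved_entries = json.dumps({
    str(("How do I attach the legs?", 0)): {
        "text_embedding": [0.1, 0.2],
        "image_embeddings_top5_idx": [2, 0, 1, 2, 0],
        "image_embeddings_key": "lack",
        "scores": [0.9, 0.8, 0.7, 0.6, 0.5],
    }
})
saved_raw = json.dumps({"lack": [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]})

# Restore the tuple keys, mirroring unwrap_keys() in the embedding cell
loaded_entries = {literal_eval(k): v for k, v in json.loads(saved_entries).items()}
raw_image_embeddings = json.loads(saved_raw)

for prompt_key, entry in loaded_entries.items():
    manual = raw_image_embeddings[entry["image_embeddings_key"]]  # every page of this manual
    top5 = [manual[i] for i in entry["image_embeddings_top5_idx"]]
    print(prompt_key[0], top5[0], entry["scores"][0])
```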

In [4]:
# Run this cell before the embedding cell when IMAGE_EMBEDDING_TYPE is "clip"
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
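The CLIP branch of `generate_image_embedding` L2-normalizes every image feature, so stored CLIP embeddings can be compared with a plain dot product, which then equals cosine similarity (the ResNet-18 outputs are not normalized, so this shortcut does not apply to them). A minimal NumPy sketch with toy 3-dimensional vectors standing in for CLIP ViT-B/32's 512-dimensional features; it does not claim to reproduce how the `idxs_and_scores` in the dataset were computed.

```python
import numpy as np

def l2_normalize(x):
    """Scale each row to unit length, as `features / features.norm(dim=-1, keepdim=True)` does in PyTorch."""
    x = np.asarray(x, dtype=np.float64)
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Toy vectors standing in for CLIP image features
query = l2_normalize([[1.0, 2.0, 2.0]])
pages = l2_normalize([[2.0, 4.0, 4.0], [0.0, 0.0, 1.0], [-1.0, -2.0, -2.0]])

# For unit-length vectors, the dot product is the cosine similarity
sims = pages @ query[0]
ranking = np.argsort(-sims)  # most similar page first
print(sims.round(3), ranking)
```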

In [None]:
import json
from sentence_transformers import SentenceTransformer
import torchvision.transforms as transforms
from gensim.models import KeyedVectors
from PIL import Image
import os
from tqdm import tqdm
import numpy as np
from ast import literal_eval  # restores the stringified tuple keys of the dataset JSON

# image embedding types: ["resnet-18", "clip"]
IMAGE_EMBEDDING_TYPE = "clip"
# text embedding types: ["MiniLM-L6", "word2vec"]
TEXT_EMBEDDING_TYPE = "word2vec"
# train test val types: ["train", "test", "val"]
TRAIN_TEST_VAL_TYPE = "train"
print("------------------------------------------------------------------------")
print("--------------------GENERATING EMBEDDINGS--------------------")
print(f"--------------------IMAGE EMBEDDINGS USING {IMAGE_EMBEDDING_TYPE}--------------------")
print(f"--------------------TEXT EMBEDDINGS USING {TEXT_EMBEDDING_TYPE}--------------------")
print(f"--------------------DATA TYPE: {TRAIN_TEST_VAL_TYPE}--------------------")

# Load in data and unwrap it 
# Function to convert stringified tuple keys back to tuples
def unwrap_keys(mapping):
    return {literal_eval(k): v for k, v in mapping.items()}

# Load the JSON file
with open(f'augmented_data/augmented_dataset_{TRAIN_TEST_VAL_TYPE}.json', 'r') as json_file:
    data_from_json = json.load(json_file)

# print(data_from_json)
# Unwrap the keys to their original tuple format
dataset = unwrap_keys(data_from_json)

# Load image metadata
with open('image_metadata.json', 'r') as file:
    image_metadata = json.load(file)

# Initialize text embedding model
if TEXT_EMBEDDING_TYPE == "MiniLM-L6":
    text_model = SentenceTransformer('all-MiniLM-L6-v2')
elif TEXT_EMBEDDING_TYPE == "word2vec":
    text_model = KeyedVectors.load_word2vec_format('./word2vec/GoogleNews-vectors-negative300.bin', binary=True)
else:
    raise ValueError("Invalid TEXT_EMBEDDING_TYPE. Supported types are MiniLM-L6, word2vec.")

# Initialize the ResNet-18 image model (used only when IMAGE_EMBEDDING_TYPE is "resnet-18"; CLIP is loaded in the cell above)
image_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
image_model.eval()

# Preprocessing transformations for the images
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to generate word2vec embeddings: the average of the vectors of in-vocabulary words
def get_word2vec_embedding(text, model):
    words = text.split()
    valid_word_vectors = [model[word] for word in words if word in model]

    if not valid_word_vectors:
        # None of the words are in the vocabulary: fall back to a zero vector,
        # returned as a list so it is JSON-serializable like the normal case
        return np.zeros(model.vector_size).tolist()

    # Compute the average of the word vectors
    return np.mean(valid_word_vectors, axis=0).tolist()

# Function to generate image embeddings
def generate_image_embedding(image_path):
    if IMAGE_EMBEDDING_TYPE == "resnet-18":
        image = Image.open(image_path).convert('RGB')
        image_tensor = preprocess(image).unsqueeze(0)
        with torch.no_grad():
            image_embedding = image_model(image_tensor).numpy().flatten()
        return image_embedding.tolist()
    elif IMAGE_EMBEDDING_TYPE == "clip":
        image = clip_preprocess(Image.open(image_path)).unsqueeze(0).to(device)
        with torch.no_grad():
            image_features = clip_model.encode_image(image)
            # L2-normalize the embedding
            image_features /= image_features.norm(dim=-1, keepdim=True)
        # Flatten the (1, 512) batch to a 1-D list so both branches return the same format
        return image_features.cpu().numpy().flatten().tolist()
    else:
        raise ValueError("Invalid IMAGE_EMBEDDING_TYPE. Supported types are resnet-18, clip.")

# Generate embeddings
embeddings = {}
image_embeddings = {}  # furniture name -> embeddings of every page image in its manual
for key, values in tqdm(dataset.items()):
    prompt = key[0]
    # Generate text embedding for the prompt
    if TEXT_EMBEDDING_TYPE == "MiniLM-L6":
        text_embedding = text_model.encode(prompt).tolist()
    elif TEXT_EMBEDDING_TYPE == "word2vec":
        text_embedding = get_word2vec_embedding(prompt, text_model)
    else:
        raise ValueError("Invalid TEXT_EMBEDDING_TYPE. Supported types are MiniLM-L6, word2vec.")

    # Find the furniture item named in the prompt (the last matching word wins).
    # Reset for every prompt so a name from a previous iteration is never reused.
    furniture_name = None
    for word in prompt.split():
        if word.lower() in unique_items:
            furniture_name = word.lower()
    if furniture_name is None:
        print("No furniture name found in prompt: ", prompt)
        continue

    # Convert the global image IDs of the top 5 images into positions within this
    # manual by subtracting the ID of the manual's first image (this assumes the
    # manual's pages have consecutive IDs). Embed the manual's images only the
    # first time it is seen; later prompts reuse the cached list.
    list_of_idxs = values["idxs_and_scores"]
    already_embedded = furniture_name in image_embeddings
    embed_idxs = None
    all_image_embeddings = []
    for uiud, image_info in image_metadata.items():
        if image_info['filename'].split('.page')[0].lower() != furniture_name:
            continue
        if embed_idxs is None:
            embed_idxs = [x - int(uiud) for x in list_of_idxs[:5]]
            if already_embedded:
                break
        full_image_path = os.path.join('data_wiki', image_info['filename'])
        if os.path.exists(full_image_path):
            all_image_embeddings.append(generate_image_embedding(full_image_path))
        else:
            print("Can't find image: ", full_image_path)

    if embed_idxs is None:
        print(f"Error - no images found for sample {furniture_name}")
        continue
    if not already_embedded:
        image_embeddings[furniture_name] = all_image_embeddings

    # Sanity check - indices should not be out of bounds
    num_images = len(image_embeddings[furniture_name])
    for x in embed_idxs:
        if x >= num_images or x < 0:
            print(f"Error - index out of bounds for sample {furniture_name}")

    # Save embeddings; the entries after the first 5 in idxs_and_scores are the scores
    embeddings[key] = {
        'text_embedding': text_embedding,
        'image_embeddings_top5_idx': embed_idxs,
        'image_embeddings_key': furniture_name,
        'scores': list_of_idxs[5:],
    }
    

# Report the embedding sizes of one entry as a quick check
first_manual = next(iter(image_embeddings.values()))
print("Size of image embedding", len(first_manual[0]))
print("Size of text embedding", len(next(iter(embeddings.values()))['text_embedding']))

def remap_keys(mapping):
    # JSON object keys must be strings, so stringify the tuple keys
    return {str(k): v for k, v in mapping.items()}

# Save the embeddings to JSON files, creating the output folder if needed
os.makedirs('embeddings', exist_ok=True)
prefix = f"embeddings/{TEXT_EMBEDDING_TYPE.lower()}_{IMAGE_EMBEDDING_TYPE.lower()}_embeddings"
with open(f'{prefix}_{TRAIN_TEST_VAL_TYPE}_aug.json', 'w') as file:
    json.dump(remap_keys(embeddings), file, indent=4)

with open(f'{prefix}_raw_{TRAIN_TEST_VAL_TYPE}_aug.json', 'w') as file:
    json.dump(image_embeddings, file, indent=4)

print(f"Embeddings have been saved to '{prefix}_{TRAIN_TEST_VAL_TYPE}_aug.json' and '{prefix}_raw_{TRAIN_TEST_VAL_TYPE}_aug.json'.")


------------------------------------------------------------------------
--------------------GENERATING EMBEDDINGS--------------------
--------------------IMAGE EMBEDDINGS USING clip--------------------
--------------------TEXT EMBEDDINGS USING word2vec--------------------
--------------------DATA TYPE: train--------------------


Using cache found in /Users/Ali/.cache/torch/hub/pytorch_vision_v0.10.0
 64%|██████████████████████████████████▎                   | 595/936 [06:16<04:01,  1.41it/s]
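The embedding loop converts the global image IDs in `idxs_and_scores` into positions within the matched manual by subtracting the ID of the manual's first image in `image_metadata.json`. That only works if each manual's pages have consecutive IDs and the first matching entry is the manual's first page. The sketch below makes that assumption explicit and raises an error instead of printing one; `to_manual_indices` and the example IDs are hypothetical.

```python
def to_manual_indices(global_ids, first_id, n_pages):
    """Convert global image IDs to 0-based positions within one manual.

    Assumes the manual's pages have consecutive IDs starting at `first_id`,
    which is what subtracting int(uiud) in the embedding loop relies on.
    Raises ValueError for IDs that fall outside the manual.
    """
    positions = [gid - first_id for gid in global_ids]
    bad = [p for p in positions if not 0 <= p < n_pages]
    if bad:
        raise ValueError(f"positions out of bounds for a {n_pages}-page manual: {bad}")
    return positions

# Hypothetical manual whose 6 pages have image IDs 120..125
print(to_manual_indices([123, 120, 125, 121, 122], first_id=120, n_pages=6))  # [3, 0, 5, 1, 2]
```

Failing loudly makes a misaligned manual visible immediately, whereas a printed warning is easy to miss in a 936-iteration progress log.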

In [22]:
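# Total prompt count across splits: 936 train prompts (matching the progress bar above) + 170 + 170
total = 936+170+170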
print(total)

1276


In [36]:
print("Size of image embedding", len(list(image_embeddings['alex'][0])))
print("Size of text embedding", len(list(embeddings.values())[0]['text_embedding']))


Size of image embedding 1000
Size of text embedding (300,)
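Checking one vector (`image_embeddings['alex'][0]`) confirms only a single entry. A small helper that collects the distinct lengths across every entry makes mismatches visible: each set should contain exactly one value, such as 300 for word2vec or the 1000 reported above for ResNet-18, whose final layer outputs one score per ImageNet class. The helper name and the toy data are assumptions for illustration.

```python
def embedding_sizes(text_entries, image_manuals):
    """Collect the distinct text and image embedding lengths across all entries."""
    text_sizes = {len(e["text_embedding"]) for e in text_entries.values()}
    image_sizes = {len(vec) for pages in image_manuals.values() for vec in pages}
    return text_sizes, image_sizes

# Hypothetical miniature data: 3-d text vectors, 2-d image vectors
text_entries = {"q1": {"text_embedding": [0.0, 1.0, 2.0]},
                "q2": {"text_embedding": [1.0, 1.0, 1.0]}}
image_manuals = {"lack": [[0.1, 0.2], [0.3, 0.4]], "ivar": [[0.5, 0.6]]}

text_sizes, image_sizes = embedding_sizes(text_entries, image_manuals)
print(text_sizes, image_sizes)  # {3} {2}
```

On the notebook's own data, the call would be `embedding_sizes(embeddings, image_embeddings)`.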
