In [2]:
#!pip install git+https://github.com/openai/CLIP.git

In [3]:
import os
import pickle

import clip
import numpy as np
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
#from google.colab import drive

# Run on the GPU when torch can see one, otherwise fall back to the CPU, and
# load CLIP ViT-B/32 together with its matching image-preprocessing transform.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)



In [4]:
# drive.mount('/content/drive')
# DATA_PATH = "/content/drive/MyDrive"

# Root folder that contains `required_dataset/`. Set the DATA_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
DATA_PATH = os.environ.get('DATA_PATH', '/home/saurav/Documents')

csv_file = os.path.join(DATA_PATH, 'required_dataset', 'styles2.csv')

# `id` is only ever used as text (to build "<id>.jpg" file names). The column
# appears to contain at least one non-numeric value (a missing-image message
# for "l.jpg" was logged), so letting pandas infer its type can produce a mix
# of ints and strings. Read it as str to keep it consistent.
df = pd.read_csv(csv_file, dtype={'id': str})

# Fail fast if the CSV lacks a column the embedding cells read.
_required_columns = {
    'id', 'gender', 'masterCategory', 'subCategory', 'articleType',
    'baseColour', 'season', 'year', 'usage', 'productDisplayName',
}
assert _required_columns.issubset(df.columns), (
    f"CSV is missing columns: {sorted(_required_columns - set(df.columns))}"
)
df

Unnamed: 0,id,gender,masterCategory,subCategory,articleType,baseColour,season,year,usage,productDisplayName,Unnamed: 10,Unnamed: 11
0,15970,Men,Apparel,Topwear,Shirts,Navy Blue,Fall,2011,Casual,Turtle Check Men Navy Blue Shirt,,
1,39386,Men,Apparel,Bottomwear,Jeans,Blue,Summer,2012,Casual,Peter England Men Party Blue Jeans,,
2,59263,Women,Accessories,Watches,Watches,Silver,Winter,2016,Casual,Titan Women Silver Watch,,
3,21379,Men,Apparel,Bottomwear,Track Pants,Black,Fall,2011,Casual,Manchester United Men Solid Black Track Pants,,
4,53759,Men,Apparel,Topwear,Tshirts,Grey,Summer,2012,Casual,Puma Men Grey T-shirt,,
...,...,...,...,...,...,...,...,...,...,...,...,...
44442,17036,Men,Footwear,Shoes,Casual Shoes,White,Summer,2013,Casual,Gas Men Caddy Casual Shoe,,
44443,6461,Men,Footwear,Flip Flops,Flip Flops,Red,Summer,2011,Casual,Lotto Men's Soccer Track Flip Flop,,
44444,18842,Men,Apparel,Topwear,Tshirts,Blue,Fall,2011,Casual,Puma Men Graphic Stellar Blue Tshirt,,
44445,46694,Women,Personal Care,Fragrance,Perfume and Body Mist,Blue,Spring,2017,Casual,Rasasi Women Blue Lady Perfume,,


In [5]:
# Settings for the embedding passes below
image_folder = os.path.join(DATA_PATH, 'required_dataset', 'images')  # holds <id>.jpg files
batch_size = 8                  # CSV rows handled per batch
total_rows = len(df)            # number of CSV rows to embed
checkpoint_file = "clip_embeddings_checkpoint.pkl"  # relative to the working directory
embeddings = {}                 # id (str) -> {'text_embedding': ..., 'image_embedding': ...}

In [6]:
# Build the text that CLIP's text encoder sees for each product
def create_text_description(row, columns=(
        'gender', 'masterCategory', 'subCategory', 'articleType',
        'baseColour', 'season', 'year', 'usage', 'productDisplayName')):
    """Join a product's metadata fields into one caption-like string.

    Args:
        row: a mapping indexable by column name -- e.g. a row Series from
            ``df.iterrows()`` or a plain dict.
        columns: the fields to include, in output order. Defaults to the
            nine product-metadata columns of the styles CSV.

    Returns:
        The present field values, converted to ``str`` and joined by single
        spaces. Missing values (None / NaN / pd.NA) and blank strings are
        skipped rather than emitted as the literal words "nan" / "None".
        Integer-valued floats are written without a trailing ".0" (pandas
        upcasts an int column such as ``year`` to float when it has NaNs).

    Raises:
        KeyError: if ``row`` lacks one of ``columns``.
    """
    parts = []
    for column in columns:
        value = row[column]
        # is_scalar guard: pd.isna on a list-like returns an array, not a bool.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            continue
        if isinstance(value, (float, np.floating)) and float(value).is_integer():
            value = int(value)
        text = str(value).strip()
        if text:
            parts.append(text)
    return ' '.join(parts)

In [None]:
# Unbatched embedding pass: one CLIP text forward pass per CSV row and, when
# the image file exists, one image forward pass.
# NOTE(review): the batched cell that follows walks every row again and
# overwrites each entry written here, so running both doubles the runtime --
# keep only one of them.
for _, row in tqdm(df.iterrows(), total=len(df)):
    image_id = str(row['id'])
    # Merge the metadata columns into one descriptive string.
    text_description = create_text_description(row)

    # truncate=True clips an over-long description to CLIP's context length
    # instead of raising and aborting the loop.
    text_input = clip.tokenize([text_description], truncate=True).to(device)
    with torch.no_grad():
        # [0] drops the batch-of-one axis so every stored vector is 1-D.
        text_embedding = model.encode_text(text_input).cpu().numpy()[0]

    image_path = os.path.join(image_folder, f"{image_id}.jpg")
    if os.path.exists(image_path):
        # Context manager releases the file handle as soon as preprocessing is done.
        with Image.open(image_path) as pil_image:
            image = preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            image_embedding = model.encode_image(image).cpu().numpy()[0]
    else:
        # tqdm.write keeps the message from breaking the progress bar.
        tqdm.write(f"Image not found for ID: {image_id}")
        image_embedding = None

    embeddings[image_id] = {
        'text_embedding': text_embedding,
        'image_embedding': image_embedding
    }


In [None]:
# Embed every CSV row in batches of `batch_size`: one CLIP text forward pass
# and one CLIP image forward pass per batch. Every stored vector is 1-D.
#
# Progress is checkpointed to `checkpoint_file` every CHECKPOINT_EVERY batches.
# A crash or kernel restart therefore does not discard hours of work, and
# re-running this cell resumes after the last checkpointed row. Resuming is
# position-based, so delete `checkpoint_file` if the CSV changes.
CHECKPOINT_EVERY = 250  # batches between checkpoint writes


def _save_pickle(obj, path):
    """Pickle `obj` to `path` via a temp file + rename, so an interrupted
    write never leaves a truncated file in place."""
    tmp_path = path + ".tmp"
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


start_row = 0
if os.path.exists(checkpoint_file):
    # NOTE: pickle.load can execute arbitrary code -- only load a checkpoint
    # that this notebook wrote itself.
    with open(checkpoint_file, 'rb') as f:
        checkpoint = pickle.load(f)
    embeddings.update(checkpoint['embeddings'])
    start_row = checkpoint['rows_done']
    print(f"Resuming from {checkpoint_file} at row {start_row}")

for batch_no, start_idx in enumerate(tqdm(range(start_row, total_rows, batch_size)), start=1):
    end_idx = min(start_idx + batch_size, total_rows)
    batch = df.iloc[start_idx:end_idx]

    # One pass over the batch collects both the ids and the text descriptions.
    image_ids, text_descriptions = [], []
    for _, row in batch.iterrows():
        image_ids.append(str(row['id']))
        text_descriptions.append(create_text_description(row))

    # truncate=True clips an over-long description to CLIP's context length
    # instead of raising and aborting the whole run.
    text_inputs = clip.tokenize(text_descriptions, truncate=True).to(device)

    # Preprocess the images that can be loaded. `image_slots` remembers which
    # batch position each tensor belongs to, so missing images stay None.
    image_tensors, image_slots = [], []
    for slot, image_id in enumerate(image_ids):
        image_path = os.path.join(image_folder, f"{image_id}.jpg")
        if not os.path.exists(image_path):
            tqdm.write(f"Image {image_id}.jpg not found, skipping image embedding.")
            continue
        try:
            with Image.open(image_path) as pil_image:
                image_tensors.append(preprocess(pil_image))
        except OSError as exc:  # corrupt/unreadable file (UnidentifiedImageError is an OSError)
            tqdm.write(f"Image {image_id}.jpg could not be read ({exc}), skipping image embedding.")
            continue
        image_slots.append(slot)

    image_embeddings = [None] * len(image_ids)
    with torch.no_grad():
        text_embeddings = model.encode_text(text_inputs).cpu().numpy()
        if image_tensors:
            # A single forward pass for all loadable images in the batch.
            image_batch = torch.stack(image_tensors).to(device)
            encoded_images = model.encode_image(image_batch).cpu().numpy()
            for slot, vector in zip(image_slots, encoded_images):
                image_embeddings[slot] = vector

    # Store embeddings for the current batch
    for image_id, text_vector, image_vector in zip(image_ids, text_embeddings, image_embeddings):
        embeddings[image_id] = {
            "text_embedding": text_vector,
            "image_embedding": image_vector
        }

    if batch_no % CHECKPOINT_EVERY == 0:
        _save_pickle({'rows_done': end_idx, 'embeddings': embeddings}, checkpoint_file)

# Save embeddings to a pickle file
output_file = "clip_embeddings_batch.pkl"
_save_pickle(embeddings, output_file)

print(f"Embeddings have been saved to {output_file}")

# The run completed, so the crash-recovery checkpoint is no longer needed.
if os.path.exists(checkpoint_file):
    os.remove(checkpoint_file)

 15%|██████████████████████▎                                                                                                                             | 837/5556 [22:18<2:02:58,  1.56s/it]

Image 39403.jpg not found, skipping image embedding.


 23%|██████████████████████████████████▍                                                                                                                | 1303/5556 [34:38<1:50:46,  1.56s/it]

Image l.jpg not found, skipping image embedding.


 36%|█████████████████████████████████████████████████████▌                                                                                             | 2026/5556 [54:10<1:32:29,  1.57s/it]

Image 39410.jpg not found, skipping image embedding.


 73%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                        | 4040/5556 [1:47:05<39:23,  1.56s/it]

Image 39401.jpg not found, skipping image embedding.


 82%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                          | 4550/5556 [2:00:30<25:34,  1.53s/it]

Image 39425.jpg not found, skipping image embedding.


 90%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎              | 5003/5556 [2:12:29<13:44,  1.49s/it]

Image 12347.jpg not found, skipping image embedding.


 94%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌         | 5200/5556 [2:17:39<09:10,  1.55s/it]

In [None]:
# Query image to compare against the stored product embeddings.
image_path = os.path.join(DATA_PATH, 'dataset', 'random_test_images', 'lehenga.png')

# Apply CLIP's preprocessing, add a leading batch dimension with unsqueeze(0),
# and move the tensor to the same device (CPU or GPU) as the model.
pil_image = Image.open(image_path)
image = preprocess(pil_image).unsqueeze(0).to(device)

In [None]:
# Map the query image into CLIP's embedding space. Under no_grad, no autograd
# graph is recorded because this is inference only.
with torch.no_grad():
    image_features = model.encode_image(image)
    # Scale each row to unit L2 norm so a dot product equals cosine similarity.
    # NOTE(review): the stored CSV embeddings are not normalized when saved --
    # confirm the comparison step normalizes them too.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)