In [2]:
#!pip install git+https://github.com/openai/CLIP.git

In [15]:
import torch
from tqdm import tqdm
import os
import clip
from PIL import Image
import pandas as pd
import numpy as np
import pickle
#from google.colab import drive

# Pick the GPU when torch can see one, otherwise fall back to the CPU, then
# load the ViT-B/32 CLIP model together with its matching image preprocessing function.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [16]:
# Root folder that contains `required_dataset/`.
# When running on Colab, mount Drive and use the Drive path instead:
# drive.mount('/content/drive')
# DATA_PATH = "/content/drive/MyDrive"
DATA_PATH = '/home/saurav/Documents'

# Read the product metadata table and show it as the cell output.
csv_file = f"{DATA_PATH}/required_dataset/styles2.csv"
df = pd.read_csv(csv_file)
df

Unnamed: 0,id,gender,masterCategory,subCategory,articleType,baseColour,season,year,usage,productDisplayName,Unnamed: 10,Unnamed: 11
0,15970,Men,Apparel,Topwear,Shirts,Navy Blue,Fall,2011,Casual,Turtle Check Men Navy Blue Shirt,,
1,39386,Men,Apparel,Bottomwear,Jeans,Blue,Summer,2012,Casual,Peter England Men Party Blue Jeans,,
2,59263,Women,Accessories,Watches,Watches,Silver,Winter,2016,Casual,Titan Women Silver Watch,,
3,21379,Men,Apparel,Bottomwear,Track Pants,Black,Fall,2011,Casual,Manchester United Men Solid Black Track Pants,,
4,53759,Men,Apparel,Topwear,Tshirts,Grey,Summer,2012,Casual,Puma Men Grey T-shirt,,
...,...,...,...,...,...,...,...,...,...,...,...,...
44442,17036,Men,Footwear,Shoes,Casual Shoes,White,Summer,2013,Casual,Gas Men Caddy Casual Shoe,,
44443,6461,Men,Footwear,Flip Flops,Flip Flops,Red,Summer,2011,Casual,Lotto Men's Soccer Track Flip Flop,,
44444,18842,Men,Apparel,Topwear,Tshirts,Blue,Fall,2011,Casual,Puma Men Graphic Stellar Blue Tshirt,,
44445,46694,Women,Personal Care,Fragrance,Perfume and Body Mist,Blue,Spring,2017,Casual,Rasasi Women Blue Lady Perfume,,


In [17]:
# Paths and settings used by the embedding run
image_folder = f"{DATA_PATH}/required_dataset/images"   # images are looked up here as "<id>.jpg"
checkpoint_file = "clip_embeddings_checkpoint.pkl"      # pickle read back when resuming a run
batch_size = 8                                          # number of CSV rows handled per loop iteration
total_rows = len(df)
embeddings = {}                                         # str(id) -> {"text_embedding", "image_embedding"}

In [18]:
#function to combine text features from multiple columns
def create_text_description(row):
  """Build a single space-separated text description from a product row.

  Parameters
  ----------
  row : mapping-like (e.g. a pandas Series or a dict)
      Must provide the keys 'gender', 'masterCategory', 'subCategory',
      'articleType', 'baseColour', 'season', 'year', 'usage' and
      'productDisplayName'.

  Returns
  -------
  str
      The present field values joined by single spaces, in the order above.
      Missing values (None / NaN / pd.NA) are skipped instead of being
      rendered as the literal text "nan", and whole-number floats are written
      without the trailing ".0" (a numeric column that contains missing
      entries is read as float, which would otherwise turn 2011 into "2011.0").

  Raises
  ------
  KeyError
      If one of the expected keys is absent from ``row``.
  """
  fields = [
      'gender',
      'masterCategory',
      'subCategory',
      'articleType',
      'baseColour',
      'season',
      'year',
      'usage',
      'productDisplayName',
  ]

  parts = []
  for field in fields:
    value = row[field]
    # skip empty cells so they do not add noise words to the description
    if pd.isna(value):
      continue
    # e.g. 2011.0 -> 2011
    if isinstance(value, float) and value.is_integer():
      value = int(value)
    # join only supports strings, so every remaining value is converted with str()
    parts.append(str(value))

  #concatenating all relevant columns into a single description
  return ' '.join(parts)

In [19]:
# Resume support: reuse previously saved embeddings when a checkpoint pickle
# is present on disk, otherwise start from an empty result dictionary.
checkpoint_found = os.path.exists(checkpoint_file)
if not checkpoint_found:
    embeddings = {}
else:
    with open(checkpoint_file, 'rb') as checkpoint_handle:
        embeddings = pickle.load(checkpoint_handle)

# The dictionary keys are the IDs that already have embeddings.
processed_ids = set(embeddings)

In [20]:
# Process the data in batches.
#
# For every slice of `batch_size` CSV rows this cell:
#   1. drops the rows whose ID is already in `processed_ids` (checkpoint resume),
#   2. encodes the text descriptions of the remaining rows in one forward pass,
#   3. encodes each row's "<id>.jpg" image (None is stored when the file is missing),
#   4. stores both embeddings in `embeddings` under the string ID.
# Progress is written to `checkpoint_file` periodically so an interrupted run
# can actually be resumed; the full result is written to `output_file` at the end.
checkpoint_every = 500          # write a checkpoint after this many newly embedded batches
batches_since_checkpoint = 0

for start_idx in tqdm(range(0, total_rows, batch_size)):
    end_idx = min(start_idx + batch_size, total_rows)
    batch = df.iloc[start_idx:end_idx]

    # Filter out rows whose IDs have already been processed
    batch = batch[~batch['id'].astype(str).isin(processed_ids)]
    if batch.empty:
        # Nothing new in this slice: skip it instead of running the model on an empty batch
        continue

    # String IDs derived the same way as in the filter above, so dictionary
    # keys and the resume check always agree.
    image_ids = batch['id'].astype(str).tolist()

    # Batch processing text descriptions
    text_descriptions = [create_text_description(row) for _, row in batch.iterrows()]
    # truncate=True: descriptions longer than CLIP's context length are cut
    # instead of raising and aborting the whole run.
    text_inputs = clip.tokenize(text_descriptions, truncate=True).to(device)

    with torch.no_grad():
        text_embeddings = model.encode_text(text_inputs).cpu().numpy()

    # Batch processing images (one forward pass per image; each image tensor
    # gets a leading batch dimension of 1 via unsqueeze(0) before encoding)
    image_embeddings = []
    for image_id in image_ids:
        image_path = os.path.join(image_folder, f"{image_id}.jpg")

        if os.path.exists(image_path):
            # Context manager closes the underlying file handle once the
            # image has been converted to a tensor.
            with Image.open(image_path) as pil_image:
                image = preprocess(pil_image).unsqueeze(0).to(device)
            with torch.no_grad():
                image_embedding = model.encode_image(image).cpu().numpy()
            image_embeddings.append(image_embedding)
        else:
            print(f"Image {image_id}.jpg not found, skipping image embedding.")
            image_embeddings.append(None)

    # Store embeddings for the current batch
    for image_id, text_embedding, image_embedding in zip(image_ids, text_embeddings, image_embeddings):
        embeddings[image_id] = {
            "text_embedding": text_embedding,
            "image_embedding": image_embedding
        }
        processed_ids.add(image_id)   # mark the image ID as processed

    # Periodically persist progress. Write to a temporary file first and then
    # rename it, so an interruption mid-write cannot corrupt the checkpoint.
    batches_since_checkpoint += 1
    if batches_since_checkpoint >= checkpoint_every:
        checkpoint_tmp = checkpoint_file + ".tmp"
        with open(checkpoint_tmp, 'wb') as f:
            pickle.dump(embeddings, f)
        os.replace(checkpoint_tmp, checkpoint_file)
        batches_since_checkpoint = 0

# Save embeddings to a pickle file
output_file = "clip_embeddings_batch.pkl"
with open(output_file, 'wb') as f:
    pickle.dump(embeddings, f)

print(f"Embeddings have been saved to {output_file}")

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5556/5556 [01:46<00:00, 52.14it/s]


Embeddings have been saved to clip_embeddings_batch.pkl


In [14]:
# Write the current in-memory embeddings to the checkpoint pickle so a later
# run can pick them up again. Note that this rebinds `output_file`.
output_file = "clip_embeddings_checkpoint.pkl"
with open(output_file, mode='wb') as checkpoint_handle:
    pickle.dump(embeddings, checkpoint_handle)

In [13]:
# Inspect the set of string IDs that already have embeddings stored.
# The whole set is echoed, so the output of this cell can be very long.
processed_ids

{'25556',
 '9777',
 '15883',
 '10664',
 '35539',
 '27773',
 '29608',
 '41244',
 '16206',
 '14584',
 '38440',
 '35462',
 '58482',
 '59160',
 '58563',
 '52887',
 '57767',
 '13934',
 '41356',
 '31846',
 '39183',
 '17264',
 '3402',
 '35021',
 '38156',
 '17403',
 '33848',
 '4609',
 '24408',
 '48237',
 '10507',
 '35997',
 '8694',
 '27737',
 '3615',
 '37786',
 '14639',
 '18892',
 '36053',
 '54168',
 '40424',
 '6681',
 '32353',
 '22844',
 '58734',
 '14845',
 '30745',
 '24688',
 '6171',
 '7432',
 '36747',
 '5702',
 '33037',
 '38502',
 '39939',
 '44149',
 '53832',
 '25107',
 '17614',
 '59639',
 '26598',
 '13811',
 '39960',
 '43153',
 '33090',
 '26040',
 '18619',
 '2608',
 '42651',
 '44782',
 '49039',
 '55112',
 '12186',
 '58529',
 '59962',
 '57506',
 '23165',
 '19301',
 '24604',
 '32845',
 '8333',
 '48183',
 '12653',
 '48189',
 '20080',
 '47326',
 '7544',
 '25333',
 '45723',
 '57625',
 '32971',
 '9545',
 '46139',
 '13233',
 '7498',
 '53424',
 '21422',
 '44167',
 '57881',
 '33171',
 '51159',
 '55

In [33]:
# Number of IDs that currently have an entry in `embeddings`.
len(embeddings)

44447

In [None]:
# # Superseded by the batched loop above (kept for reference only): iterating through CSV rows to generate embeddings for each ID
# for index, row in tqdm(df.iterrows()):
#   image_id = str(row['id'])
#   #combine multiple columns into one descriptive text
#   text_description = create_text_description(row)

#   #process the text to get text embeddings
#   text_input = clip.tokenize([text_description]).to(device)
#   with torch.no_grad():
#     text_embedding = model.encode_text(text_input).cpu().numpy()

#   #load and preprocess image corresponding to 'id'
#   image_path = os.path.join(image_folder, f"{image_id}.jpg")
#   if os.path.exists(image_path):
#     image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
#     with torch.no_grad():
#       image_embedding = model.encode_image(image).cpu().numpy()

#   else:
#     print(f"Image not found for ID: {image_id}")
#     image_embedding = None

#   #store embeddings in a dictionary
#   embeddings[image_id] = {
#       'text_embedding': text_embedding,
#       'image_embedding': image_embedding
#   }
