In [1]:
import numpy as np
import pandas as pd
import pickle
import os
import torch
import clip
import transformers

from PIL import Image
from multilingual_clip import pt_multilingual_clip
from fashion_clip.fashion_clip import FashionCLIP
from tqdm.notebook import tqdm
from sklearn.metrics.pairwise import cosine_similarity

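# LOAD DATA

Load the product text descriptions (CSV indexed by product id `품번`) and list the product image folder.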

In [2]:
# Use the first GPU when available; the model-loading cells below place their
# models on this device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Product descriptions, indexed by product id (품번); the '설명' column holds the text.
text_folder = '/home/smart01/SFLAB/DATA/mind_br_data_prepro_full/'
text_df = pd.read_csv(os.path.join(text_folder, 'text_description.csv')).set_index('품번')

# os.listdir returns entries in arbitrary order. Sort the listing so the row order of
# every image-embedding array (and the shared image_embedding_ids.pickle written by
# both image sections) is deterministic across kernel sessions.
image_folder = '/home/smart01/SFLAB/DATA/mind_br_data_full_240227/images'
image_file_list = sorted(os.listdir(image_folder))

# Fashion CLIP TEXT ENCODER

In [5]:
# Load the FashionCLIP model ('fashion-clip' checkpoint). It is used below for both
# text embeddings (encode_text) and image embeddings (encode_images).
fclip = FashionCLIP('fashion-clip')

In [None]:
# Encode every product description with the FashionCLIP text encoder (512-d output).
# `text_embedding_ids` maps row index -> product id (품번). Rows are in text_df order.
batch_size = 256
text_embedding_ids = dict(enumerate(text_df.index))
descriptions = text_df['설명'].tolist()

# Fixed-size slices rather than np.array_split(df, len(df)//batch_size). That call
# raises when the frame has fewer than batch_size rows (0 sections) and is
# deprecated for DataFrames.
emb_chunks = []
for start in tqdm(range(0, len(descriptions), batch_size)):
    batch = descriptions[start:start + batch_size]
    emb_chunks.append(fclip.encode_text(batch, batch_size=batch_size))

# Concatenate once rather than once per batch, which was quadratic. The leading
# all-zero row is a placeholder kept for compatibility with the save cell, which
# stores `text_embeddings[1:]`.
text_embeddings = np.concatenate([np.zeros((1, 512)), *emb_chunks])

In [5]:
# Persist the FashionCLIP text embeddings and their row -> product-id map.
# Row 0 of `text_embeddings` is the all-zero placeholder seeded by the encoding
# loop, so it is dropped before saving.
# Note: every text-encoder section writes the same text_embedding_ids.pickle.
save_dir = '/home/smart01/SFLAB/bonbak/data/output/clip'
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'text_embedding_ids.pickle'), 'wb') as ids_file:
    pickle.dump(text_embedding_ids, ids_file)
np.save(os.path.join(save_dir, 'fclip_text_embedding.npy'), text_embeddings[1:])

# Reload from disk to confirm the round trip and show the saved shape.
text_embeddings = np.load(os.path.join(save_dir, 'fclip_text_embedding.npy'))
text_embeddings.shape

(18160, 512)

# Fashion CLIP IMAGE ENCODER

In [21]:
# Map each row index to the product-id prefix of its image filename (the text
# before the first '_').
image_embedding_ids = {
    row: file_name.split('_')[0] for row, file_name in enumerate(image_file_list)
}

# Embed every image with FashionCLIP. Row i of the result corresponds to
# image_embedding_ids[i].
# NOTE(review): batch_size=1 runs one forward pass per image; a larger batch is
# likely faster if GPU memory allows.
image_file_paths = [os.path.join(image_folder, name) for name in image_file_list]
image_embeddings = fclip.encode_images(image_file_paths, batch_size=1)

100%|██████████| 1/1 [00:00<00:00,  5.52it/s]


In [8]:
# Persist the FashionCLIP image embeddings and their row -> filename-prefix map.
# fclip.encode_images returns the array directly (no placeholder row), so it is
# saved as-is.
# Note: the CLIP image section writes the same image_embedding_ids.pickle.
save_dir = '/home/smart01/SFLAB/bonbak/data/output/clip'
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'image_embedding_ids.pickle'), 'wb') as ids_file:
    pickle.dump(image_embedding_ids, ids_file)
np.save(os.path.join(save_dir, 'fclip_image_embedding.npy'), image_embeddings)

# Reload from disk to confirm the round trip and show the saved shape.
image_embeddings = np.load(os.path.join(save_dir, 'fclip_image_embedding.npy'))
image_embeddings.shape

(15899, 512)

# Multilingual CLIP TEXT ENCODER

In [None]:
# Multilingual CLIP text encoder and its tokenizer. The checkpoint name indicates
# an XLM-RoBERTa-Large text tower paired with CLIP ViT-L/14.
model_name = 'M-CLIP/XLM-Roberta-Large-Vit-L-14'

text_tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
text_encoder = (
    pt_multilingual_clip.MultilingualCLIP
    .from_pretrained(model_name)
    .to(device)
)

In [None]:
# Encode every product description with the M-CLIP text encoder (768-d output).
# The steps mirror the `text_encoder.forward(txt, text_tokenizer)` call, done inline
# (presumably so the tokenized batch can be moved to `device`):
#   1. mean-pool the transformer's token states over the attention mask;
#   2. project the pooled vector with LinearTransformation.
# `text_embedding_ids` maps row index -> product id (품번). Rows are in text_df order.
batch_size = 8
text_embedding_ids = dict(enumerate(text_df.index))
descriptions = text_df['설명'].tolist()

emb_chunks = []
with torch.no_grad():  # inference only: don't build an autograd graph per batch
    for start in tqdm(range(0, len(descriptions), batch_size)):
        txt = descriptions[start:start + batch_size]
        # truncation=True: over-long descriptions would otherwise exceed XLM-R's
        # position embeddings and raise instead of being encoded.
        txt_tok = text_tokenizer(txt, padding=True, truncation=True, return_tensors='pt').to(device)
        token_states = text_encoder.transformer(**txt_tok)[0]
        att = txt_tok['attention_mask']
        pooled = (token_states * att.unsqueeze(2)).sum(dim=1) / att.sum(dim=1)[:, None]
        emb_chunks.append(text_encoder.LinearTransformation(pooled).cpu().numpy())

# Concatenate once (the per-batch concatenate was quadratic). The leading all-zero
# row is a placeholder kept for compatibility with the save cell, which stores
# `text_embeddings[1:]`.
text_embeddings = np.concatenate([np.zeros((1, 768)), *emb_chunks])

In [None]:
# Persist the M-CLIP text embeddings and their row -> product-id map.
# Row 0 of `text_embeddings` is the all-zero placeholder seeded by the encoding
# loop, so it is dropped before saving.
# Note: every text-encoder section writes the same text_embedding_ids.pickle.
save_dir = '/home/smart01/SFLAB/bonbak/data/output/clip'
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'text_embedding_ids.pickle'), 'wb') as ids_file:
    pickle.dump(text_embedding_ids, ids_file)
np.save(os.path.join(save_dir, 'mclip_text_embedding.npy'), text_embeddings[1:])

# Reload from disk to confirm the round trip and show the saved shape.
text_embeddings = np.load(os.path.join(save_dir, 'mclip_text_embedding.npy'))
text_embeddings.shape

# CLIP TEXT ENCODER

In [None]:
# Load OpenAI CLIP ViT-L/14 on `device`. `encoder` is used for the text embeddings
# below; the returned image transform is not used in this section.
encoder, image_preprocess = clip.load("ViT-L/14", device=device)

In [None]:
# Encode every product description with the CLIP ViT-L/14 text encoder (768-d output).
# `text_embedding_ids` maps row index -> product id (품번). Rows are in text_df order.
batch_size = 8
text_embedding_ids = dict(enumerate(text_df.index))
descriptions = text_df['설명'].tolist()

# Fixed-size slices rather than np.array_split(df, len(df)//batch_size). That call
# raises when the frame has fewer than batch_size rows (0 sections).
emb_chunks = []
with torch.no_grad():  # inference only
    for start in tqdm(range(0, len(descriptions), batch_size)):
        # truncate=True clips descriptions longer than CLIP's context window
        # instead of raising.
        tokens = clip.tokenize(descriptions[start:start + batch_size], truncate=True).to(device)
        emb_chunks.append(encoder.encode_text(tokens).cpu().numpy())

# Concatenate once (the per-batch concatenate was quadratic). The leading all-zero
# row is a placeholder kept for compatibility with the save cell, which stores
# `text_embeddings[1:]`.
text_embeddings = np.concatenate([np.zeros((1, 768)), *emb_chunks])

In [None]:
# Persist the CLIP text embeddings and their row -> product-id map.
# Row 0 of `text_embeddings` is the all-zero placeholder seeded by the encoding
# loop, so it is dropped before saving.
# Note: every text-encoder section writes the same text_embedding_ids.pickle.
save_dir = '/home/smart01/SFLAB/bonbak/data/output/clip'
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'text_embedding_ids.pickle'), 'wb') as ids_file:
    pickle.dump(text_embedding_ids, ids_file)
np.save(os.path.join(save_dir, 'clip_text_embedding.npy'), text_embeddings[1:])

# Reload from disk to confirm the round trip and show the saved shape.
text_embeddings = np.load(os.path.join(save_dir, 'clip_text_embedding.npy'))
text_embeddings.shape

# CLIP IMAGE ENCODER

In [None]:
# Load CLIP ViT-L/14 for image embeddings. `image_preprocess` is the image
# transform applied to each PIL image before encode_image.
# NOTE(review): this is the same checkpoint the CLIP text section loads into
# `encoder`. If both sections run in one kernel, the model is held in memory twice.
image_encoder, image_preprocess = clip.load("ViT-L/14", device=device)

In [None]:
# Encode each product image with the CLIP ViT-L/14 image encoder, one image at a time.
# `image_embedding_ids` maps row index -> product-id prefix of the filename (the
# text before the first '_').
image_embedding_ids = {}
image_emb_chunks = []
with torch.no_grad():  # inference only: don't build an autograd graph per image
    for idx, image_file in enumerate(tqdm(image_file_list)):
        # Context manager closes the file handle; the loop previously leaked one per image.
        with Image.open(os.path.join(image_folder, image_file)) as img:
            image = image_preprocess(img).unsqueeze(0).to(device)
        image_emb_chunks.append(image_encoder.encode_image(image).cpu().numpy())
        image_embedding_ids[idx] = image_file.split('_')[0]

# Concatenate once (the per-image concatenate was quadratic). The leading all-zero
# row is a placeholder kept for compatibility with the save cell, which stores
# `image_embeddings[1:]`.
image_embeddings = np.concatenate([np.zeros((1, 768)), *image_emb_chunks])

In [None]:
# Persist the CLIP image embeddings and their row -> filename-prefix map.
# Row 0 of `image_embeddings` is the all-zero placeholder seeded by the encoding
# loop, so it is dropped before saving.
# Note: the FashionCLIP image section writes the same image_embedding_ids.pickle.
save_dir = '/home/smart01/SFLAB/bonbak/data/output/clip'
os.makedirs(save_dir, exist_ok=True)
with open(os.path.join(save_dir, 'image_embedding_ids.pickle'), 'wb') as ids_file:
    pickle.dump(image_embedding_ids, ids_file)
np.save(os.path.join(save_dir, 'clip_image_embedding.npy'), image_embeddings[1:])

# Reload from disk to confirm the round trip and show the saved shape.
image_embeddings = np.load(os.path.join(save_dir, 'clip_image_embedding.npy'))
image_embeddings.shape

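# Cosine Similarity

Mean cosine similarity between each product's image embedding and its own text embedding (CLIP image vs. M-CLIP text), computed over products that have both.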

In [19]:
# Load one image-embedding set and one text-embedding set written by the save cells
# above. This previously read '/bonbak/output/clip/image_embedding_clip.npy' and
# 'text_embedding_mclip.npy', which no cell in this notebook writes.
# Embedding widths must match: fclip is 512-d; clip and mclip are 768-d.
image_model = 'clip'   # 'fclip' or 'clip'
text_model = 'mclip'   # 'fclip', 'mclip' or 'clip'
save_dir = '/home/smart01/SFLAB/bonbak/data/output/clip'

image_embeddings = np.load(os.path.join(save_dir, f'{image_model}_image_embedding.npy'))
with open(os.path.join(save_dir, 'image_embedding_ids.pickle'), 'rb') as ids_file:
    img_id_decoder = pickle.load(ids_file)  # row index -> image filename prefix
# Drop the last two characters of the filename prefix so it matches the text-side
# product id (presumably a colour/variant suffix — TODO confirm). When several
# images share a key, the last row index wins.
img_id_encoder = {item_id[:-2]: idx for idx, item_id in img_id_decoder.items()}

text_embeddings = np.load(os.path.join(save_dir, f'{text_model}_text_embedding.npy'))
with open(os.path.join(save_dir, 'text_embedding_ids.pickle'), 'rb') as ids_file:
    txt_id_decoder = pickle.load(ids_file)  # row index -> product id (품번)
txt_id_encoder = {item_id: idx for idx, item_id in txt_id_decoder.items()}

# Products that have both an image embedding and a text embedding.
item_id_list = set(img_id_encoder.keys()) & set(txt_id_encoder.keys())

In [20]:
# For each shared product, collect its image row and text row in the same order,
# so position k in both lists refers to the same product.
paired_items = list(item_id_list)
img_ids = [img_id_encoder[item_id] for item_id in paired_items]
txt_ids = [txt_id_encoder[item_id] for item_id in paired_items]

In [21]:
# Row k of both selected matrices belongs to the same product, so the diagonal of
# the image x text similarity matrix holds the matched-pair similarities.
# NOTE(review): this builds the full N x N matrix only to read its diagonal.
cossim = cosine_similarity(image_embeddings[img_ids], text_embeddings[txt_ids])
cossim.diagonal().mean()

0.24088835397938502