In [9]:
import numpy as np
import pandas as pd
import pickle

import os

import torch
import matplotlib.pyplot as plt
import plotly.express as px

from sklearn.manifold import TSNE

In [22]:
# Directory holding the precomputed image/text embeddings and their id maps.
save_dir = '/home/smart01/SFLAB/bonbak/output/clip'

# NOTE(security): pickle.load can execute arbitrary code while deserializing;
# only point this cell at pickle files produced by a trusted pipeline.

# --- image side ---
image_embeddings = np.load(os.path.join(save_dir, 'image_embedding_clip.npy'))
# Use a context manager so the file handle is closed deterministically
# (a bare `pickle.load(open(...))` leaves it open until garbage collection).
with open(os.path.join(save_dir, 'image_embedding_ids.pickle'), 'rb') as id_file:
    img_id_decoder = pickle.load(id_file)
# Invert the decoder (idx -> id) into an encoder (id -> idx). The last two
# characters of every image id are dropped — presumably a per-image suffix,
# TODO confirm. If several ids share the same truncated key, the entry
# iterated last overwrites the earlier ones.
img_id_encoder = {item_id[:-2]: idx for idx, item_id in img_id_decoder.items()}

# --- text side ---
text_embeddings = np.load(os.path.join(save_dir, 'text_embedding_mclip.npy'))
with open(os.path.join(save_dir, 'text_embedding_ids.pickle'), 'rb') as id_file:
    txt_id_decoder = pickle.load(id_file)
# Text ids are used unchanged as keys.
txt_id_encoder = {item_id: idx for idx, item_id in txt_id_decoder.items()}

# Ids present on both the image and the text side (a set, despite the name).
item_id_list = set(img_id_encoder.keys()) & set(txt_id_encoder.keys())

In [23]:
# Folder with the preprocessed metadata tables.
data_folder = '/home/smart01/SFLAB/DATA/mind_br_data_prepro_full/'

# Item metadata; the first CSV column becomes the DataFrame index.
meta_df = pd.read_csv(
    os.path.join(data_folder, 'meta_data_240312.csv'),
    index_col=0,
)

In [24]:

# Build one row per item that has BOTH an image and a text embedding:
#   item_number | category | image_index | text_index
# The frame ends with a fresh 0..n-1 index.

# Columns whose name starts with 'category' (presumably one indicator column
# per category, named 'category_<CODE>' — TODO confirm).
category_columns = meta_df.loc[:, meta_df.columns.str.startswith('category')]

cat_df = (
    category_columns
    # Per row, take the name of the column holding the largest value ...
    .idxmax(axis=1)
    # ... and keep the token after the first underscore as the category code.
    .apply(lambda column_name: column_name.split('_')[1])
    .rename('category')
    # Assumes meta_df's index is named 'item_number'; it becomes a column here.
    .reset_index()
    # Dict lookup via Series.map: ids missing from an encoder become NaN.
    .assign(
        image_index=lambda frame: frame['item_number'].map(img_id_encoder),
        text_index=lambda frame: frame['item_number'].map(txt_id_encoder),
    )
    # Drop every row containing a NaN, i.e. items lacking either embedding.
    .dropna()
    # The NaNs forced a float dtype; restore integer row indices.
    .astype({'image_index': int, 'text_index': int})
    # Optional: restrict the plots to a subset of categories.
    # .loc[lambda frame: frame['category'].isin(['TS', 'PT', 'OP', 'KT', 'WS', 'CA', 'BL', 'DP'])]
    # Renumber rows 0..n-1 after the drops.
    .reset_index(drop=True)
)
cat_df

Unnamed: 0,item_number,category,image_index,text_index
0,JTBL126B,BL,15729,13805
1,JTBL226A,BL,13863,13804
2,JTBL320B,BL,997,13803
3,JTBL321A,BL,9841,13801
4,JTBL321B,BL,13344,13800
...,...,...,...,...
6845,MYTS0221,TS,5618,129
6846,MYWS0120,WS,13747,59
6847,MYWS01A1,WS,14660,57
6848,MYWS0201,WS,14380,55


In [18]:
# Project the image embeddings of the items in cat_df to 2-D with t-SNE.

# Rows of image_embeddings selected by cat_df's image_index, in cat_df row order.
embedding = image_embeddings[cat_df['image_index'].values]

label = cat_df['category']
# t-SNE is stochastic: a fixed random_state makes the layout reproducible
# across kernel restarts. perplexity=10 is kept from the original analysis.
x_tsne = TSNE(n_components=2, perplexity=10, random_state=0).fit_transform(embedding)
# Labels are attached positionally (to_numpy) so the result does not depend on
# cat_df's index lining up with the new frame's 0..n-1 index.
tsne_df = pd.DataFrame(x_tsne, columns=['x', 'y']).assign(label=label.to_numpy())

In [19]:
# Scatter plot of the 2-D points in tsne_df, one colour per label value.
fig = px.scatter(data_frame=tsne_df, x='x', y='y', color='label', title='Image Embedding Visualization')
fig.show()

In [25]:
# Project the text embeddings of the items in cat_df to 2-D with t-SNE.

# Rows of text_embeddings selected by cat_df's text_index, in cat_df row order.
embedding = text_embeddings[cat_df['text_index'].values]

label = cat_df['category']
# t-SNE is stochastic: a fixed random_state makes the layout reproducible
# across kernel restarts. perplexity=10 is kept from the original analysis.
x_tsne = TSNE(n_components=2, perplexity=10, random_state=0).fit_transform(embedding)
# Labels are attached positionally (to_numpy) so the result does not depend on
# cat_df's index lining up with the new frame's 0..n-1 index.
tsne_df = pd.DataFrame(x_tsne, columns=['x', 'y']).assign(label=label.to_numpy())

In [26]:
# Scatter plot of the 2-D points in tsne_df, one colour per label value.
fig = px.scatter(data_frame=tsne_df, x='x', y='y', color='label', title='Text Embedding Visualization')
fig.show()