In [1]:
from utils.CleanData import Operate, filter_language
from utils.FilterData import FilterPara

import os
import time
import io
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import dask.dataframe as dd
import torch
import clip
import dask
import faiss
import dask.array as da

from datetime import datetime
from PIL import Image
from glob import glob
from dask.distributed import Client
from dask import delayed


pd.set_option('display.max_colwidth', 1024)


# min_ratio = 0.9
# max_ratio = 1.1

# min_width, max_width = 128, 1024
# min_height, max_height = 128, 1024

# min_length, max_length = 16,1024

# desired_rows_per_partition = 6000
# batch_size = 3200

# # faiss
# faiss_k = 100  # 只搜索最近的100个邻居
# faiss_threshold = 0.98 # image feature cos 

In [2]:
def show(df, showlen = 5):
    print(f"lenth : {len(df)}")
    print(f"{df.head(showlen)}")

In [3]:

directory = '/mnt/alluxio/alluxio-fuse/user/tc_agi/klara/datasets/laion2b_en/laion2b_en_20230417112304'
# directory = '/mnt/alluxio/alluxio-fuse/user/tc_agi/klara/datasets/laion2b_multi_chinese_subset/laion2b_multi_chinese_subset'
files = glob(os.path.join(directory, '*.parquet'))
files = [f for f in files if not os.path.basename(f).startswith('.')]

In [None]:
sample_files = files[100:110]
sample_df = dd.read_parquet(sample_files)

min_ratio = 0.9
max_ratio = 1.1
sample_df['RATIO'] = sample_df['WIDTH'] / sample_df['HEIGHT']

sample_df = sample_df[(sample_df['RATIO'] >= min_ratio) & (sample_df['RATIO'] <= max_ratio)]
sample_df = sample_df.drop('RATIO', axis=1)

In [None]:
min_width, max_width = 128, 1024
min_height, max_height = 128, 1024

sample_df = sample_df[(sample_df['WIDTH'] < max_width) & (sample_df['HEIGHT'] < max_height)]
sample_df = sample_df[(sample_df['WIDTH'] > min_width) & (sample_df['HEIGHT'] > min_height)]

In [None]:
min_length, max_length = 16,1024
sample_df= sample_df[(sample_df['TEXT'].str.len() >= min_length) & (sample_df['TEXT'].str.len() <= max_length)]

In [None]:
df = sample_df
df['TEXT'] = df['TEXT'].map(Operate, meta=('TEXT', 'object'))
df = df[df['TEXT'].apply(filter_language, meta=('TEXT', 'bool'))]

In [None]:
import clip
df_clip = df.compute()

desired_rows_per_partition = 6000

npartitions = len(df_clip) // desired_rows_per_partition
df_clip = dd.from_pandas(df_clip, npartitions=npartitions)

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("/mnt/data/user/tc_agi/multi_modal/checkpoints/clip/ViT-B-16.pt", device=device)

In [None]:
def process_image(row):
    try:
        img = Image.open(io.BytesIO(row['BUFFER']))
        tensor = preprocess(img).to(device)
        return tensor, row
    except Exception as e:
        print(f"Error processing image: {e}")
        return None, None

gpu_lock = threading.Lock()

def compute_embeddings(partition, batch_size=3200):
    with gpu_lock:

        all_embeddings = []
        valid_rows = []  

        for start in range(0, len(partition), batch_size):
            end = start + batch_size
            batch = partition.iloc[start:end]
            print(len(batch))
            processed_tensors = []
            for _, row in batch.iterrows():
                tensor, valid_row = process_image(row)
                if tensor is not None:
                    processed_tensors.append(tensor)
                    valid_rows.append(valid_row)

            tensor_stack = torch.stack(processed_tensors)
            
            with torch.no_grad():
                embedding = model.encode_image(tensor_stack)
            all_embeddings.extend(embedding.cpu().numpy())

        valid_partition = pd.concat(valid_rows, axis=1).transpose()
        valid_partition['CLIP_Features'] = all_embeddings

        return valid_partition




batch_size = 3200
df_clip = df_clip.map_partitions(compute_embeddings, batch_size=batch_size, meta=df_clip._meta.assign(CLIP_Features='f8'))
result = df_clip.compute()


In [None]:
df_pandas = result
df_pandas = df_pandas.set_index('SAMPLE_ID')

In [None]:
faiss_k = 100  # 只搜索最近的100个邻居
faiss_threshold = 0.98 # image feature cos 


embeddings_matrix = np.vstack(df_pandas['CLIP_Features'].to_list()).astype('float32')  
faiss.normalize_L2(embeddings_matrix)

nlist = 100
quantizer = faiss.IndexFlatL2(embeddings_matrix.shape[1])
index_cpu = faiss.IndexIVFFlat(quantizer, embeddings_matrix.shape[1], nlist, faiss.METRIC_L2)

res = faiss.StandardGpuResources()
index = faiss.index_cpu_to_gpu(res, 0, index_cpu)

assert not index.is_trained
index.train(embeddings_matrix)
assert index.is_trained

index.add(embeddings_matrix)

k = faiss_k  # 只搜索最近的100个邻居
D, I = index.search(embeddings_matrix, k)

threshold = faiss_threshold
similar_pairs = []

for i in range(I.shape[0]):
    similarities = 1 - D[i] / 2  
    filtered_sample_ids = df_pandas.index[I[i][(similarities > threshold) & (I[i] != i)]].tolist()
    if filtered_sample_ids:
        similar_pairs.append((df_pandas.index[i], filtered_sample_ids))


In [None]:
df1 = result
df1 = df1.set_index('SAMPLE_ID')
df1['CLIP_Features'] = df1['CLIP_Features'].apply(lambda x: np.array(x).astype('float32') if isinstance(x, (list, np.ndarray)) else x)

df_to_save = df1[['CLIP_Features']]

df_to_save.to_parquet('saved_data.parquet')


In [None]:
from PIL import Image
import io
import matplotlib.pyplot as plt

def retrieve_and_show_images(sample_id1, similar_sample_ids, dataframe, max_similar=5):
    
    similar_sample_ids = [sid for sid in similar_sample_ids if sid != sample_id1]
    
    image1 = Image.open(io.BytesIO(dataframe.loc[sample_id1]['BUFFER']))
    plt.figure(figsize=(10, 10))
    plt.subplot(1, min(len(similar_sample_ids), max_similar) + 1, 1)
    plt.imshow(image1)
    plt.axis('off')
    
    for j, sample_id2 in enumerate(similar_sample_ids[:max_similar], start=2): 
        image2 = Image.open(io.BytesIO(dataframe.loc[sample_id2]['BUFFER']))
        plt.subplot(1, min(len(similar_sample_ids), max_similar) + 1, j)
        plt.imshow(image2)
        plt.axis('off') # 移除坐标轴
        
    plt.tight_layout()
    plt.show()

for idx, (sample_id1, similar_sample_ids) in enumerate(similar_pairs):
    if idx >= 40: 
        break
    retrieve_and_show_images(sample_id1, similar_sample_ids, df_pandas)


# test

In [5]:
import dask.dataframe as dd
from dask.multiprocessing import get
from PIL import Image
import io

def is_image_valid(image_data):
    try:
        with Image.open(io.BytesIO(image_data)) as img:
            img.verify()
        return True
    except:
        return False

df['is_valid'] = df['BUFFER'].map(is_image_valid, meta=('BUFFER', 'bool'))

df_remove = df[df['is_valid'] == False]
df = df[df['is_valid'] == True]

df = df.drop('is_valid', axis=1)

computed_df = df.compute(scheduler='processes')  
computed_df_remove = df_remove.compute(scheduler='processes')  


In [6]:
computed_df_remove

Unnamed: 0,SAMPLE_ID,URL,TEXT,HEIGHT,WIDTH,LICENSE,NSFW,similarity,BUFFER,IMG_TYPE,is_valid
7903,4186990000000.0,https://cdn.shopify.com/s/files/1/0016/4390/51...,Gold Pigger Ring Women,960,960,?,UNLIKELY,0.314278,b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\...,png,False


In [None]:
print(len(result))
print(len(filtered_df))

10614
10615


In [None]:
import time

def vectorized_method(partition):
    partition['Image'] = partition['BUFFER'].apply(safe_open)
    partition.compute()

def loop_based_method(partition):
    images = []
    for idx, row in partition.iterrows():
        img = safe_open(row['BUFFER'])

start_time = time.time()
vectorized_method(filtered_df) 
vectorized_duration = time.time() - start_time

start_time = time.time()
loop_based_method(filtered_df)  
loop_based_duration = time.time() - start_time

print(f"Vectorized method took: {vectorized_duration} seconds")
print(f"Loop-based method took: {loop_based_duration} seconds")

You did not provide metadata, so Dask is running your function on a small dataset to guess output types. It is possible that Dask will guess incorrectly.
To provide an explicit output types or to silence this message, please provide the `meta=` keyword, as described in the map or apply function that you are using.
  Before: .apply(func)
  After:  .apply(func, meta=('BUFFER', 'object'))



Vectorized method took: 38.581528186798096 seconds
Loop-based method took: 42.90600252151489 seconds


In [None]:
# import os
# from multiprocessing import Pool

# def parallel_preprocess(buffer):
#     try:
#         img = Image.open(io.BytesIO(buffer))
#         return preprocess(img)
#     except Exception as e:
#         return None

# def compute_embeddings(partition, batch_size=4000):
#     all_embeddings = []

#     n_processes = os.cpu_count()

#     with Pool(n_processes) as pool:
#         partition['Processed_Image'] = list(pool.map(parallel_preprocess, partition['BUFFER'].tolist()))

#     partition['Is_Exception'] = partition['Processed_Image'].isnull()
#     valid_partition = partition[~partition['Is_Exception']]
    
#     tensors = [tensor for tensor in valid_partition['Processed_Image'].tolist() if tensor is not None]
    
#     for start in range(0, len(tensors), batch_size):
#         end = start + batch_size
#         batch_tensors = tensors[start:end]

#         tensor_stack = torch.stack(batch_tensors).to(device)
#         with torch.no_grad():
#             embedding = model.encode_image(tensor_stack)
#         all_embeddings.extend(embedding.cpu().numpy())

#     valid_partition['CLIP_Features'] = all_embeddings

#     return valid_partition
