In [5]:
import pandas as pd
import pyarrow.parquet as pq
from PIL import Image
import io
import torch
from tqdm import tqdm
import clip
import pickle

In [6]:
# Keep every tensor and the model on the CPU, then load the CLIP "ViT-B/32"
# checkpoint; `clip.load` returns the model together with the image
# preprocessing callable used on each PIL image in later cells.
device = torch.device("cpu")
clip_model, preprocess = clip.load(name="ViT-B/32", jit=False, device=device)

## 加载yolo数据

In [7]:
import json

# Path of the JSON file holding the per-image text loaded into `xuanzi`.
json_file_path = "/data/lab/STA-AS2/big_yolo_train.json"

# Decode explicitly as UTF-8 so the result does not depend on the platform's
# default locale encoding. Later cells look entries up with `xuanzi[str(cocoid)]`.
with open(json_file_path, 'r', encoding='utf-8') as json_file:
    xuanzi = json.load(json_file)

## 分块跑代码保存数组

In [None]:
# Embed the parquet shards in chunks and persist each chunk to disk.
#
# For every shard `train-{q}.parquet` the rows are processed in CHUNKS_PER_FILE
# chunks of CHUNK_SIZE rows. Each row yields one tensor: the CLIP image feature
# concatenated (dim=1) with the CLIP text feature of `xuanzi[str(cocoid)]`.
# Every chunk is written as `{q}-{m}.pt` (list of tensors) plus `{q}-{m}.pkl`
# (list of caption dicts, same order), after which the buffers are emptied.
MAX_ROWS_PER_FILE = 2981  # rows at index >= 2981 are never processed
CHUNK_SIZE = 1000
CHUNKS_PER_FILE = 3

all_embeddings = []
all_captions = []
for q in range(20, 38):
    path = f"/data/lab/STA-AS2/Orogindata/train-{q}.parquet"
    df = pq.read_table(path).to_pandas()

    # Bound the row range by the actual table length as well as the fixed cap,
    # so a shard with fewer rows ends cleanly instead of raising KeyError.
    row_limit = min(len(df), MAX_ROWS_PER_FILE)

    for m in range(CHUNKS_PER_FILE):
        chunk_start = min(CHUNK_SIZE * m, row_limit)
        chunk_stop = min(CHUNK_SIZE * (m + 1), row_limit)

        for i in tqdm(range(chunk_start, chunk_stop), desc="Processing images"):
            row = df.loc[i]

            # Raw encoded image bytes stored under the "bytes" key of the cell.
            image_data = row["image"].get("bytes", None)
            image = Image.open(io.BytesIO(image_data))
            image = preprocess(image).unsqueeze(0).to(device)

            imageid = row["cocoid"]
            # Only the first sentence / sentence id of the row is kept.
            caption_entry = {
                "imageid": imageid,
                "id": row["sentids"][0],
                "caption": row["sentences_raw"][0],
                # NOTE(review): this is the row index inside the current
                # parquet file, not the position in any merged tensor.
                "clip_embedding": i,
            }

            # Text looked up by image id; presumably YOLO-derived (per the
            # notebook heading) — TODO confirm its exact format.
            more = xuanzi[str(imageid)]
            text = clip.tokenize(more).to(device)

            # Run BOTH encoders without autograd. Previously the text encoder
            # ran with gradients enabled, so every stored embedding kept its
            # computation graph alive and was saved as requiring grad.
            with torch.no_grad():
                prefix = clip_model.encode_image(image)
                text_features = clip_model.encode_text(text)

            prefix_more = torch.cat((prefix, text_features), dim=1)
            all_embeddings.append(prefix_more)
            all_captions.append(caption_entry)

        torch.save(all_embeddings, f'/data/lab/STA-AS2/Formerge/{q}-{m}.pt')
        with open(f'/data/lab/STA-AS2/Formerge/{q}-{m}.pkl', 'wb') as pkl_file:
            pickle.dump(all_captions, pkl_file)
        # Reuse the same buffers for the next chunk.
        all_embeddings.clear()
        all_captions.clear()
    print(f'{q}-Done')

## 合并数组

In [9]:
# Merge the per-chunk files back into two parallel lists:
#   merge_tensor - embedding tensors, in load order
#   merge_dict   - caption dicts, in the same order
#
# NOTE(review): this reads `littlemerge/{0..2}-{0..2}` while the chunking cell
# visible in this notebook writes `Formerge/{20..37}-{m}` — confirm that these
# are the intended inputs.
# NOTE: `pickle.load` / `torch.load` execute pickled code; only load files that
# were produced by this pipeline.
merge_tensor = []
merge_dict = []
for file_idx in tqdm(range(3), desc="总进度"):
    for chunk_idx in range(3):
        pklpath = f"/data/lab/STA-AS2/littlemerge/{file_idx}-{chunk_idx}.pkl"
        ptpath = f"/data/lab/STA-AS2/littlemerge/{file_idx}-{chunk_idx}.pt"

        with open(pklpath, 'rb') as pkl_file:
            chunk_captions = pickle.load(pkl_file)
        chunk_embeddings = torch.load(ptpath)

        # The two files of a chunk must describe the same samples.
        if len(chunk_captions) != len(chunk_embeddings):
            raise ValueError(
                f"caption/embedding count mismatch in chunk {file_idx}-{chunk_idx}: "
                f"{len(chunk_captions)} vs {len(chunk_embeddings)}"
            )

        merge_dict.extend(chunk_captions)
        merge_tensor.extend(chunk_embeddings)

# Each caption's "clip_embedding" was written as a row index local to its
# source file, so after merging it no longer points at the matching row of the
# concatenated tensor. Rewrite it as the first-row offset of that caption's
# embedding once all tensors are concatenated along dim 0.
# (assumes downstream uses "clip_embedding" as an index into that tensor, as in
# the ClipCap data format — TODO confirm)
row_offset = 0
for caption_entry, embedding in zip(merge_dict, merge_tensor):
    caption_entry["clip_embedding"] = row_offset
    row_offset += embedding.shape[0]

总进度: 100%|██████████| 3/3 [00:00<00:00,  7.92it/s]


## 输出训练数据格式

In [10]:
# Serialise the final training file: one dict holding every embedding
# concatenated along dim 0 under "clip_embedding" and the list of caption
# dicts under "captions".
out_path = "/data/lab/STA-AS2/oscar_split_ViT-B_train4.pkl"
with open(out_path, 'wb') as f:
    stacked_embeddings = torch.cat(merge_tensor, dim=0)
    training_payload = {
        "clip_embedding": stacked_embeddings,
        "captions": merge_dict,
    }
    pickle.dump(training_payload, f)