In [1]:
import os
import glob
import torch
from tqdm import tqdm

SAVE_PATH = "/workspace/dataset/MELD/dev/pt"   # directory holding the per-sample .pt files
OUT_PATH  = "/workspace/dataset/MELD/dev/all_embeddings.pt"  # merged output file (single-file mode)

# === Optional: save in shards (avoids one very large file) ===
ENABLE_SHARD = False     # set to True to write shards instead of one file
SHARD_SIZE   = 1000      # number of samples per shard
SHARD_PREFIX = "/workspace/dataset/MELD/dev/all_embeddings_shard"  # shard path prefix; "_NNN.pt" is appended

def load_one(sample_path):
    """Load one per-sample ``.pt`` file and validate its structure.

    Parameters
    ----------
    sample_path : str
        Path to a file previously written with ``torch.save``.

    Returns
    -------
    dict
        The loaded dict. Its ``text_emb`` / ``audio_emb`` / ``video_emb``
        entries are moved to CPU and made contiguous when they are tensors.

    Raises
    ------
    ValueError
        If the file does not hold a dict, or the dict lacks any of the
        required keys (``id``, ``text_emb``, ``audio_emb``, ``video_emb``, ``meta``).
    """
    # NOTE(review): torch.load unpickles; only use it on files you produced yourself.
    obj = torch.load(sample_path, map_location="cpu")

    # A non-dict payload (e.g. a list or bare tensor) has no keys. Reject it with the
    # same exception type as a missing key instead of failing later on `.keys()`.
    if not isinstance(obj, dict):
        raise ValueError(f"{sample_path}: expected a dict, got {type(obj).__name__}")

    # Field check against the per-sample save format.
    required_keys = ["id", "text_emb", "audio_emb", "video_emb", "meta"]
    missing = [k for k in required_keys if k not in obj]
    if missing:
        raise ValueError(f"{sample_path} 缺少必要字段 {missing}，实际键={list(obj.keys())}")

    # map_location already places tensors on CPU. contiguous() normalises strides
    # so that the merged file stores compact tensors.
    for k in ("text_emb", "audio_emb", "video_emb"):
        if isinstance(obj[k], torch.Tensor):
            obj[k] = obj[k].cpu().contiguous()
    return obj

def _atomic_save(obj, path):
    """Save ``obj`` with ``torch.save`` without ever leaving ``path`` half-written.

    Writes ``path + ".tmp"`` first and then renames it over ``path`` with
    ``os.replace``, which is atomic on the same filesystem. Creates the parent
    directory if it does not exist yet.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = path + ".tmp"
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def merge_all(save_path=None, out_path=None, enable_shard=None,
              shard_size=None, shard_prefix=None):
    """Merge every per-sample ``.pt`` file under ``save_path`` into a list of dicts.

    Files are visited in sorted-filename order. A file that ``load_one`` rejects
    is reported with ``[WARN]`` and skipped. A sample whose ``id`` was already
    merged is reported with ``[DUP]`` and skipped.

    Every parameter defaults to the matching module-level config (``SAVE_PATH``,
    ``OUT_PATH``, ``ENABLE_SHARD``, ``SHARD_SIZE``, ``SHARD_PREFIX``). The defaults
    are read at call time, so editing those constants and re-calling takes effect.

    Output
    ------
    * single-file mode: the whole list is saved to ``out_path``;
    * shard mode: lists of at most ``shard_size`` samples are saved to
      ``f"{shard_prefix}_{idx:03d}.pt"``. NOTE(review): shard files left over
      from an earlier run that produced more shards are not removed.

    Raises
    ------
    FileNotFoundError
        If ``save_path`` contains no ``.pt`` files.
    ValueError
        If sharding is enabled with ``shard_size < 1``.
    """
    save_path = SAVE_PATH if save_path is None else save_path
    out_path = OUT_PATH if out_path is None else out_path
    enable_shard = ENABLE_SHARD if enable_shard is None else enable_shard
    shard_size = SHARD_SIZE if shard_size is None else shard_size
    shard_prefix = SHARD_PREFIX if shard_prefix is None else shard_prefix

    files = sorted(glob.glob(os.path.join(save_path, "*.pt")))
    if not files:
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        raise FileNotFoundError(f"目录为空：{save_path}")
    if enable_shard and shard_size < 1:
        raise ValueError(f"shard_size must be >= 1, got {shard_size}")

    seen = set()      # ids already merged, across all shards
    merged = []       # samples pending for the current shard / the single output
    shard_idx = 0
    n_failed = 0      # files load_one rejected
    n_dup = 0         # files skipped because their id was already merged

    for fp in tqdm(files, desc="Merging"):
        try:
            sample = load_one(fp)
        except Exception as e:
            # Best-effort merge: a bad file is reported and counted, not fatal.
            print(f"[WARN] 跳过 {fp}: {e}")
            n_failed += 1
            continue

        sid = sample["id"]
        if sid in seen:
            print(f"[DUP] 跳过重复 id: {sid}")
            n_dup += 1
            continue
        seen.add(sid)
        merged.append(sample)

        # Optional sharding: flush once the current shard is full.
        if enable_shard and len(merged) >= shard_size:
            shard_path = f"{shard_prefix}_{shard_idx:03d}.pt"
            _atomic_save(merged, shard_path)
            print(f"[SAVE] 分片 -> {shard_path} (samples={len(merged)})")
            shard_idx += 1
            merged.clear()

    if enable_shard:
        # Flush the final, possibly partial, shard.
        if merged:
            shard_path = f"{shard_prefix}_{shard_idx:03d}.pt"
            _atomic_save(merged, shard_path)
            print(f"[SAVE] 分片(最后) -> {shard_path} (samples={len(merged)})")
        print("分片保存完成。")
    else:
        # Single-file output.
        _atomic_save(merged, out_path)
        print(f"[DONE] 合并完成：{out_path}  (总样本数={len(merged)})")

    if n_failed or n_dup:
        print(f"[SUMMARY] skipped {n_failed} unreadable/invalid file(s), "
              f"{n_dup} duplicate id(s); merged {len(seen)} sample(s)")


# In a notebook kernel __name__ is "__main__", so this runs when the cell executes.
if __name__ == "__main__":
    merge_all()


Merging: 100%|██████████| 1108/1108 [00:00<00:00, 4242.73it/s]


[DONE] 合并完成：/workspace/dataset/MELD/dev/all_embeddings.pt  (总样本数=1108)


In [2]:
import torch
# Reload the merged file and spot-check the first sample's id and per-modality shapes.
# NOTE(review): this cell previously read ".../dev_embeddings.pt", but the merge cell
# writes OUT_PATH = ".../all_embeddings.pt". Reading that file makes Restart & Run All
# check what this notebook just produced. Revert if dev_embeddings.pt is a deliberate copy.
data = torch.load("/workspace/dataset/MELD/dev/all_embeddings.pt", map_location="cpu")
assert len(data) > 0, "merged embeddings file contains no samples"
first = data[0]
print(len(data), first["id"], first["text_emb"].shape, first["audio_emb"].shape, first["video_emb"].shape)


1108 dia0_utt0 torch.Size([1, 1024]) torch.Size([1, 1024]) torch.Size([1, 1024])


In [3]:
import os, glob, torch
from collections import defaultdict

# Audit the per-sample .pt files: unreadable files, files missing required keys,
# ids claimed by more than one file, and filenames that differ from the stored id.
SAVE_PATH = "/workspace/dataset/MELD/dev/pt"
REQUIRED_KEYS = ("id", "text_emb", "audio_emb", "video_emb", "meta")

files = sorted(glob.glob(os.path.join(SAVE_PATH, "*.pt")))
print("[SCAN] files on disk:", len(files))

failed_load = []                 # (path, repr(exception)) for files torch.load rejected
missing_keys = []                # (path, actual keys, or a type description for non-dicts)
id_to_files = defaultdict(list)  # id -> every file whose payload declares that id
name_id_mismatch = []            # (path, id) where the filename stem differs from the id

for fp in files:
    try:
        o = torch.load(fp, map_location="cpu")
    except Exception as e:
        failed_load.append((fp, repr(e)))
        continue

    # Required fields. A non-dict payload has no .keys(), so record its type instead
    # of letting the whole scan crash partway through.
    if not isinstance(o, dict):
        missing_keys.append((fp, f"<not a dict: {type(o).__name__}>"))
        continue
    if not all(k in o for k in REQUIRED_KEYS):
        missing_keys.append((fp, list(o.keys())))
        continue

    sid = o["id"]
    id_to_files[sid].append(fp)

    # Filename vs stored id. A mismatch helps locate overwritten / reused ids.
    base = os.path.splitext(os.path.basename(fp))[0]
    if base != sid:
        name_id_mismatch.append((fp, sid))

# Ids that appear in more than one file.
dup_ids = {k: v for k, v in id_to_files.items() if len(v) > 1}

print("[OK] loadable files:", len(files) - len(failed_load))
print("[FAIL] unreadable files:", len(failed_load))
print("[MISS-KEY] missing key files:", len(missing_keys))
print("[DUP-ID] duplicate ids:", len(dup_ids))
print("[NAME≠ID] filename-id mismatch:", len(name_id_mismatch))

if failed_load:
    print("\nUnreadable examples (first 5):")
    for x in failed_load[:5]:
        print(" ", x)

if missing_keys:
    print("\nMissing-key examples (first 5):")
    for x in missing_keys[:5]:
        print(" ", x)

if dup_ids:
    print("\nDuplicate id → files (first 5 ids):")
    for k, v in list(dup_ids.items())[:5]:
        print(" ", k, "->", v)

if name_id_mismatch:
    print("\nName vs id mismatch (first 5):")
    for x in name_id_mismatch[:5]:
        print(" ", x)

print("\n[SUMMARY] unique ids found:", len(id_to_files))


[SCAN] files on disk: 1108
[OK] loadable files: 1108
[FAIL] unreadable files: 0
[MISS-KEY] missing key files: 0
[DUP-ID] duplicate ids: 0
[NAME≠ID] filename-id mismatch: 0

[SUMMARY] unique ids found: 1108


In [4]:
import os, glob, torch
import pandas as pd

VIDEO_DIR = "/workspace/dataset/MELD/dev/dev_splits"
AUDIO_DIR = "/workspace/dataset/MELD/dev/wav"
SAVE_PATH = "/workspace/dataset/MELD/dev/pt"
CSV_PATH  = "/workspace/dataset/MELD/dev/dev_sent_emo.csv"   # the CSV the extraction step built its df from

df = pd.read_csv(CSV_PATH)

def make_uid(row):
    """Return the sample id ``"dia{D}_utt{U}"`` for one CSV row.

    Both ``Dialogue_ID`` and ``Utterance_ID`` go through ``int()``, mirroring the
    original extraction code. Float- or string-typed columns therefore still
    produce ids such as ``"dia3_utt5"``.
    """
    dialogue, utterance = (int(row[col]) for col in ("Dialogue_ID", "Utterance_ID"))
    return "dia{}_utt{}".format(dialogue, utterance)

# The extraction step skipped any row whose video or audio file was missing, so the
# ids it should have saved are exactly the rows for which both files exist.
expected = set()
for _, row in df.iterrows():
    uid = make_uid(row)
    has_video = os.path.exists(os.path.join(VIDEO_DIR, f"{uid}.mp4"))
    if has_video and os.path.exists(os.path.join(AUDIO_DIR, f"{uid}.wav")):
        expected.add(uid)

print("[EXPECTED] samples that should be saved:", len(expected))

# Ids actually present on disk, taken from the .pt filename stems.
on_disk = set()
for pt_file in glob.glob(os.path.join(SAVE_PATH, "*.pt")):
    on_disk.add(os.path.splitext(os.path.basename(pt_file))[0])

missing = sorted(expected - on_disk)
extra   = sorted(on_disk - expected)

print("[ON DISK] .pt files:", len(on_disk))
print("[MISSING] expected but not found:", len(missing))
print("[EXTRA] found but not expected:", len(extra))

for title, ids in (("Missing", missing), ("Extra", extra)):
    if not ids:
        continue
    print(f"\n{title} examples (first 20):")
    for sample_id in ids[:20]:
        print(" ", sample_id)


[EXPECTED] samples that should be saved: 1108
[ON DISK] .pt files: 1108
[MISSING] expected but not found: 0
[EXTRA] found but not expected: 0


In [6]:
import torch
from pprint import pprint   # pprint 打印字典结构

# Inspect the merged embeddings file: sample count, the first sample's id,
# per-modality shapes, its meta dict, and the first 10 values of each embedding.
# NOTE(review): this cell previously read ".../dev_embeddings.pt", while the merge cell
# writes ".../all_embeddings.pt". It now inspects the file the notebook produces;
# revert if dev_embeddings.pt is a deliberate copy.
path = "/workspace/dataset/MELD/dev/all_embeddings.pt"
data = torch.load(path, map_location="cpu")

print(f"样本总数: {len(data)}\n")
if not data:
    raise ValueError(f"{path} contains no samples")

first = data[0]
print("id:", first["id"])
print("text_emb shape :", tuple(first["text_emb"].shape))
print("audio_emb shape:", tuple(first["audio_emb"].shape))
print("video_emb shape:", tuple(first["video_emb"].shape))

print("\n==== meta ====")
pprint(first["meta"])  # structured print of the meta dict

# First 10 values of each flattened embedding, as a quick sanity check on the numbers.
for key in ("text_emb", "audio_emb", "video_emb"):
    print(f"\n==== {key} 前10维 ====")
    print(first[key].flatten()[:10])


样本总数: 1108

id: dia0_utt0
text_emb shape : (1, 1024)
audio_emb shape: (1, 1024)
video_emb shape: (1, 1024)

==== meta ====
{'Emotion': 'sadness',
 'EndTime': '00:21:00,049',
 'Episode': 7,
 'Season': 4,
 'Sentiment': 'negative',
 'Speaker': 'Phoebe',
 'StartTime': '00:20:57,256'}

==== text_emb 前10维 ====
tensor([-0.4725,  4.9154,  2.1138, -0.2989, -1.2533,  2.5271,  1.7447,  1.4300,
         1.0829, -0.7339])

==== audio_emb 前10维 ====
tensor([-0.5300,  0.1690, -0.1369,  0.0985, -0.2493,  0.8640,  0.7113, -0.1096,
        -0.6751, -0.9109])

==== video_emb 前10维 ====
tensor([-0.0135,  0.0637, -0.0195, -0.0052,  0.0076,  0.0189,  0.0313, -0.0248,
        -0.0190, -0.0358])
