In [1]:
import torch
import pprint  # 用来格式化输出

# 文件路径
pt_path = "/workspace/dataset/S5/features/Ses05M_impro06.pt"

# 读取文件
data = torch.load(pt_path, map_location="cpu")

print(f"Loaded {len(data)} utterances from {pt_path}")

# 查看前几个 key（每个 utterance 的 ID）
print("Utterance IDs:", list(data.keys())[:5])

# 取出一个样本看看
utt_id = list(data.keys())[0]
entry = data[utt_id]

print(f"\n=== Example: {utt_id} ===")
pprint.pprint({
    "start_time": entry.get("start_time"),
    "end_time": entry.get("end_time"),
    "text": entry.get("text"),
    "text_emb shape": entry["text_emb"].shape,
    "audio_emb shape": entry["audio_emb"].shape,
    "vision_emb shape": entry["vision_emb"].shape,
})


Loaded 34 utterances from /workspace/dataset/S5/features/Ses05M_impro06.pt
Utterance IDs: ['Ses05M_impro06_F000', 'Ses05M_impro06_M000', 'Ses05M_impro06_F001', 'Ses05M_impro06_M001', 'Ses05M_impro06_F002']

=== Example: Ses05M_impro06_F000 ===
{'audio_emb shape': torch.Size([1, 1024]),
 'end_time': 6.57,
 'start_time': 3.67,
 'text': "Ryan, what's wrong?",
 'text_emb shape': torch.Size([1, 1024]),
 'vision_emb shape': torch.Size([1, 1024])}


In [2]:
# 拼接所有文本 embedding
text_embs = torch.cat([v["text_emb"] for v in data.values()], dim=0)
audio_embs = torch.cat([v["audio_emb"] for v in data.values()], dim=0)
vision_embs = torch.cat([v["vision_emb"] for v in data.values()], dim=0)

In [3]:
text_embs

tensor([[-1.7944, -2.4691, -0.6758,  ..., -3.4899, -1.0594,  1.9511],
        [ 2.3022,  3.2462,  0.9003,  ..., -0.4022, -0.2051,  0.9829],
        [-0.7184, -5.3655, -0.0762,  ...,  2.2394,  1.6025, -5.4964],
        ...,
        [ 1.2145, -0.2882, -0.4721,  ...,  2.2102, -1.5978,  0.2008],
        [ 1.7834,  1.2176, -0.7155,  ..., -1.5025, -3.6112, -0.1797],
        [-0.3329, -4.4933, -2.5126,  ...,  1.9726, -0.3362, -4.8012]])

In [4]:
text_embs.shape

torch.Size([34, 1024])

In [10]:
import os
import torch

def load_all_embeddings(base_dir="/workspace/dataset", sessions=["S1", "S2", "S3", "S4", "S5"]):
    all_data = []
    skipped_files = []
    loaded_files = 0

    for s in sessions:
        feature_dir = os.path.join(base_dir, s, "features")
        if not os.path.exists(feature_dir):
            print(f"[WARN] Missing feature dir: {feature_dir}")
            continue

        for fname in os.listdir(feature_dir):
            #  只加载包含 with_emo 的文件
            if not fname.endswith("_with_emo.pt"):
                continue

            fpath = os.path.join(feature_dir, fname)
            try:
                data = torch.load(fpath, map_location="cpu")
            except Exception as e:
                print(f"[ERROR] Failed to load {fname}: {e}")
                skipped_files.append(fname)
                continue

            loaded_files += 1

            for utt_id, v in data.items():
                # 确保有情感字段才保存
                if "C-E1" not in v or v["C-E1"] is None:
                    continue

                all_data.append({
                    "session": s,
                    "video": fname.replace("_with_emo.pt", ""),
                    "utt_id": utt_id,
                    "text_emb": v["text_emb"].squeeze(0),
                    "audio_emb": v["audio_emb"].squeeze(0),
                    "vision_emb": v["vision_emb"].squeeze(0),
                    "start_time": v.get("start_time"),
                    "end_time": v.get("end_time"),
                    "VAD": v.get("VAD"),
                    "C-E1": v.get("C-E1"),
                    "C-E2": v.get("C-E2"),
                    "C-E4": v.get("C-E4"),
                })

    #print(f"Loaded {loaded_files} files with emotion labels.")
    #print(f"Total utterances collected: {len(all_data)}")
    #if skipped_files:
    #    print(f"[WARN] Skipped {len(skipped_files)} files: {skipped_files[:5]}")

    return all_data



In [11]:
import torch.nn as nn

class MultiModalBiGRU(nn.Module):
    def __init__(self, input_dim=1024*3, hidden_dim=512, num_layers=1, bidirectional=True):
        super().__init__()
        self.bigru = nn.GRU(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional
        )

    def forward(self, x):
        # x: [batch_size, seq_len, input_dim]
        out, _ = self.bigru(x)
        return out


In [12]:
from itertools import groupby
from operator import itemgetter

all_data = load_all_embeddings()
# 假设每个视频是一组对话


# 按视频分组
videos = {}
for item in all_data:
    key = (item["session"], item["video"])
    videos.setdefault(key, []).append(item)

# 初始化 BiGRU
model = MultiModalBiGRU().to("cuda")

for (session, video), utterances in videos.items():
    utterances = sorted(utterances, key=lambda x: x["start_time"] or 0.0)
    features = []
    for u in utterances:
        feat = torch.cat([u["text_emb"], u["audio_emb"], u["vision_emb"]], dim=-1)
        features.append(feat)
    x = torch.stack(features).unsqueeze(0).to("cuda")  # [1, seq_len, 3072]
    with torch.no_grad():
        output = model(x)
    print(f"{session}/{video}: GRU output shape = {output.shape}")


S1/Ses01F_impro02: GRU output shape = torch.Size([1, 38, 1024])
S1/Ses01F_impro03: GRU output shape = torch.Size([1, 52, 1024])
S1/Ses01F_impro04: GRU output shape = torch.Size([1, 71, 1024])
S1/Ses01F_impro05: GRU output shape = torch.Size([1, 67, 1024])
S1/Ses01F_impro06: GRU output shape = torch.Size([1, 47, 1024])
S1/Ses01F_impro07: GRU output shape = torch.Size([1, 37, 1024])
S1/Ses01F_script01_1: GRU output shape = torch.Size([1, 89, 1024])
S1/Ses01F_script01_2: GRU output shape = torch.Size([1, 33, 1024])
S1/Ses01F_script01_3: GRU output shape = torch.Size([1, 75, 1024])
S1/Ses01F_script03_1: GRU output shape = torch.Size([1, 71, 1024])
S1/Ses01F_script03_2: GRU output shape = torch.Size([1, 82, 1024])
S1/Ses01M_impro05: GRU output shape = torch.Size([1, 70, 1024])
S1/Ses01M_impro02: GRU output shape = torch.Size([1, 48, 1024])
S1/Ses01M_impro03: GRU output shape = torch.Size([1, 53, 1024])
S1/Ses01M_impro04: GRU output shape = torch.Size([1, 52, 1024])
S1/Ses01M_impro01: GRU ou

In [9]:
import torch
import pprint  # 用来格式化输出

# 文件路径
pt_path = "/workspace/dataset/S5/features/Ses05M_impro06_with_emo.pt"

# 读取文件
data = torch.load(pt_path, map_location="cpu")

print(f"Loaded {len(data)} utterances from {pt_path}")

# 查看前几个 key（每个 utterance 的 ID）
print("Utterance IDs:", list(data.keys())[:5])

# 取出一个样本看看
utt_id = list(data.keys())[0]
entry = data[utt_id]

print(f"\n=== Example: {utt_id} ===")
pprint.pprint({
    "start_time": entry.get("start_time"),
    "end_time": entry.get("end_time"),
    "text": entry.get("text"),
    "text_emb shape": entry["text_emb"].shape,
    "audio_emb shape": entry["audio_emb"].shape,
    "vision_emb shape": entry["vision_emb"].shape,
    "VAD": entry.get("VAD"),
    "C-E1": entry.get("C-E1"),
    "C-E2": entry.get("C-E2"),
    "C-E4": entry.get("C-E4"),
})


Loaded 34 utterances from /workspace/dataset/S5/features/Ses05M_impro06_with_emo.pt
Utterance IDs: ['Ses05M_impro06_F000', 'Ses05M_impro06_M000', 'Ses05M_impro06_F001', 'Ses05M_impro06_M001', 'Ses05M_impro06_F002']

=== Example: Ses05M_impro06_F000 ===
{'C-E1': 'Sadness',
 'C-E2': 'Neutral',
 'C-E4': 'Other',
 'VAD': [2.0, 2.0, 1.5],
 'audio_emb shape': torch.Size([1, 1024]),
 'end_time': 6.57,
 'start_time': 3.67,
 'text': "Ryan, what's wrong?",
 'text_emb shape': torch.Size([1, 1024]),
 'vision_emb shape': torch.Size([1, 1024])}
