In [1]:
# Smoke test: check that DummyDataset can read a .pt sample correctly.
from train import DummyDataset

# Build the dataset and pull out its first sample (index 0).
dataset = DummyDataset("./data/RAVDESS/train")  # or ./data_hdtf_train
sample = dataset[0]

print("✅ Sample loaded:")
for field_name, field_value in sample.items():
    # Show the shape when the value exposes one; otherwise show the value.
    shape_or_value = getattr(field_value, "shape", field_value)
    print(f"{field_name}: {type(field_value)}, shape: {shape_or_value}")

  from .autonotebook import tqdm as notebook_tqdm


✅ Sample loaded:
audio: <class 'torch.Tensor'>, shape: torch.Size([67925])
blendshape: <class 'torch.Tensor'>, shape: torch.Size([127, 52])
level: <class 'torch.Tensor'>, shape: torch.Size([])
person: <class 'torch.Tensor'>, shape: torch.Size([])


In [None]:
from model import EmoTalk
from loss_v0002 import EmoTalkLoss
from train import DummyDataset
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm
from types import SimpleNamespace
import os

# Debugging aid: request synchronous CUDA kernel launches so that a device-side
# error surfaces at the call that triggered it. This slows training down.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# === Configuration ===
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 10
LR = 1e-4
SAVE_PATH = 'emotalk_v1_retrained.pth'
# The training loop reads "level" and "person" with .item(), which only works
# on single-element tensors, so this must stay at 1 unless that code changes.
BATCH_SIZE = 1

args = SimpleNamespace(
    feature_dim=1024,
    bs_dim=52,
    device=DEVICE,
    batch_size=BATCH_SIZE,
    max_seq_len=512,
    period=20,
    emotion_dim=256,
    emo_gru_hidden=128,
    emo_gru_layers=2,
    transformer_layers=4,
    transformer_heads=8,
    transformer_dim=512,
    num_emotions=2,   # Note: the model only uses a 2-dim emotion one-hot
    num_person=24     # one_hot_person in the model is 24-dim
)

# === Model and optimizer ===
model = EmoTalk(args).to(DEVICE)
loss_fn = EmoTalkLoss(region_weighted=True)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

# === Data preparation (you must make sure DummyDataset returns the correct format) ===
train_dataset = DummyDataset('./data/RAVDESS/train')
# Use the shared BATCH_SIZE constant so the loader and `args.batch_size` agree.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Fail early with a clear message instead of a ZeroDivisionError when the
# per-epoch averages are computed.
num_batches = len(train_loader)
if num_batches == 0:
    raise RuntimeError("train_loader is empty: no training samples were found")

# === Training ===
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    logs_accum = {"main": 0.0, "smooth": 0.0, "vel": 0.0, "total": 0.0}

    for raw in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        # Move each tensor to the device once and reuse it for both branches.
        audio = raw["audio"].to(DEVICE)
        blendshape = raw["blendshape"].to(DEVICE)

        # Build the model input: both input slots receive the same audio and
        # both target slots the same blendshape sequence.
        data = {
            "input12": audio,
            "input21": audio,
            "target11": blendshape,
            "target12": blendshape,
            "level": raw["level"].item(),
            "person": raw["person"].item()
        }

        output1, output2, _ = model(data)

        # Average the loss of the two model outputs.
        loss1, logs1 = loss_fn(output1, data["target11"])
        loss2, logs2 = loss_fn(output2, data["target12"])
        loss = (loss1 + loss2) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        for k in logs1:
            # float() stores plain numbers, so no tensor (or autograd history)
            # is kept alive across the epoch.
            # NOTE(review): assumes every log value is a scalar — confirm
            # against EmoTalkLoss.
            step_value = (float(logs1[k]) + float(logs2[k])) / 2
            logs_accum[k] = logs_accum.get(k, 0.0) + step_value

    # Report per-batch averages for the total loss AND for each component, so
    # the two printed lines are on the same scale.
    avg_loss = total_loss / num_batches
    avg_logs = {k: v / num_batches for k, v in logs_accum.items()}
    print(f"\n[Epoch {epoch+1}] Train Loss: {avg_loss:.4f}")
    print(f"  -> main: {avg_logs['main']:.4f} | smooth: {avg_logs['smooth']:.4f} | vel: {avg_logs['vel']:.4f}")

    # Overwrite the checkpoint after every epoch.
    torch.save(model.state_dict(), SAVE_PATH)
    print(f"✅ Model saved to {SAVE_PATH}")
