In [1]:
import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

In [2]:
# ============================================================
# 1️⃣ 與訓練時完全相同的模型定義
# ============================================================
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1/num_embeddings, 1/num_embeddings)

    def forward(self, x):
        # x: (batch, latent_dim)
        distances = (
            torch.sum(x**2, dim=1, keepdim=True)
            + torch.sum(self.embedding.weight**2, dim=1)
            - 2 * torch.matmul(x, self.embedding.weight.t())
        )
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices)
        return quantized, encoding_indices


class VQVAE(nn.Module):
    def __init__(self, input_dim=99, hidden_dim=128, latent_dim=32, num_embeddings=64):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim)
        )
        self.vq = VectorQuantizer(num_embeddings, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, x):
        z = self.encoder(x)
        z_q, indices = self.vq(z)
        x_recon = self.decoder(z_q)
        return x_recon, indices, z, z_q

In [3]:
# ============================================================
# 2️⃣ 載入訓練好的模型
# ============================================================
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VQVAE().to(device)
state_dict = torch.load("vqvae_model.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()
print("✅ 模型已成功載入")

✅ 模型已成功載入


  state_dict = torch.load("vqvae_model.pth", map_location=device)


In [4]:
# ============================================================
# 3️⃣ 定義 CSV → 骨架陣列轉換函式
# ============================================================
def parse_point(s):
    s = s.strip("()")
    parts = [float(p.strip(" '")) for p in s.split(",")]
    return parts

def csv_to_pose_array(csv_path):
    df = pd.read_csv(csv_path)
    pose_cols = [c for c in df.columns if c != "frame"]
    poses = []
    for _, row in df.iterrows():
        pose = []
        for c in pose_cols:
            pose += parse_point(row[c])
        poses.append(pose)
    return np.array(poses)  # (N_frames, 99)

In [5]:
# ============================================================
# 4️⃣ 處理資料夾中所有 CSV，轉成符號序列
# ============================================================
folder_path = "./dance_csv/"
output_json = "symbol_sequences.json"

symbol_dict = {}
scaler = StandardScaler()

for file_name in os.listdir(folder_path):
    if file_name.endswith(".csv"):
        file_path = os.path.join(folder_path, file_name)
        print(f"📄 處理中：{file_name}")

        # 讀取骨架資料
        poses = csv_to_pose_array(file_path)
        poses = scaler.fit_transform(poses)

        data = torch.tensor(poses, dtype=torch.float32).to(device)

        with torch.no_grad():
            _, indices, _, _ = model(data)
            indices = indices.cpu().numpy()

        # 轉成字母序列（A~Z 重複）
        symbols = [chr(65 + (i % 26)) for i in indices]
        symbol_seq = "".join(symbols)

        symbol_dict[file_name] = {
            "length": len(symbols),
            "symbols": symbol_seq
        }

📄 處理中：Ballet_1.csv
📄 處理中：Ballet_10.csv
📄 處理中：Ballet_11.csv
📄 處理中：Ballet_12.csv
📄 處理中：Ballet_13.csv
📄 處理中：Ballet_14.csv
📄 處理中：Ballet_15.csv
📄 處理中：Ballet_16.csv
📄 處理中：Ballet_17.csv
📄 處理中：Ballet_18.csv
📄 處理中：Ballet_19.csv
📄 處理中：Ballet_2.csv
📄 處理中：Ballet_20.csv
📄 處理中：Ballet_21.csv
📄 處理中：Ballet_22.csv
📄 處理中：Ballet_23.csv
📄 處理中：Ballet_24.csv
📄 處理中：Ballet_25.csv
📄 處理中：Ballet_26.csv
📄 處理中：Ballet_27.csv
📄 處理中：Ballet_28.csv
📄 處理中：Ballet_29.csv
📄 處理中：Ballet_3.csv
📄 處理中：Ballet_30.csv
📄 處理中：Ballet_31.csv
📄 處理中：Ballet_32.csv
📄 處理中：Ballet_33.csv
📄 處理中：Ballet_34.csv
📄 處理中：Ballet_35.csv
📄 處理中：Ballet_36.csv
📄 處理中：Ballet_37.csv
📄 處理中：Ballet_38.csv
📄 處理中：Ballet_39.csv
📄 處理中：Ballet_4.csv
📄 處理中：Ballet_40.csv
📄 處理中：Ballet_41.csv
📄 處理中：Ballet_42.csv
📄 處理中：Ballet_43.csv
📄 處理中：Ballet_44.csv
📄 處理中：Ballet_45.csv
📄 處理中：Ballet_46.csv
📄 處理中：Ballet_47.csv
📄 處理中：Ballet_48.csv
📄 處理中：Ballet_5.csv
📄 處理中：Ballet_6.csv
📄 處理中：Ballet_7.csv
📄 處理中：Ballet_8.csv
📄 處理中：Ballet_9.csv
📄 處理中：JiaJiangDance_49.csv
📄 處理中：JiaJiangDance_50

In [6]:
# ============================================================
# 5️⃣ 儲存成 JSON
# ============================================================
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(symbol_dict, f, indent=4, ensure_ascii=False)

print(f"🎉 已完成！符號序列儲存於：{output_json}")

🎉 已完成！符號序列儲存於：symbol_sequences.json
