In [16]:
#!/usr/bin/env python3
# generate_single_sts.py

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

# ------------ 配置区 ------------
# 待处理的单个输入 CSV
INPUT_FILE = r'E:/学习工作/PD/pks/SitToStand/Data/STS_2D_skeletons_coarsened/Pt204_C_n_301.csv'
# 生成后的保存路径（如果不填则自动在输入文件同目录下，加后缀 _synthetic）
OUTPUT_FILE = r'E:/学习工作/PD/pks/SitToStand/Data/Transformer_generate.csv'  # 默认路径
# 序列长度、关键点数量（x,y 对数）
SEQ_LEN   = 50
NUM_KPT   = 25
INPUT_DIM = NUM_KPT * 2

# 计算设备
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using Input File: {INPUT_FILE}")

# ------------ Positional Encoding ------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * -(np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (B, L, D)
        return x + self.pe[:, :x.size(1)]


# ------------ Transformer VAE ------------
class TransformerVAE(nn.Module):
    def __init__(self, input_dim, seq_len, d_model=128, nhead=4, num_layers=2, latent_dim=64):
        super().__init__()
        self.seq_len = seq_len
        self.input_linear = nn.Linear(input_dim, d_model)
        self.pos_enc      = PositionalEncoding(d_model, max_len=seq_len)
        encoder_layer     = nn.TransformerEncoderLayer(d_model, nhead)
        self.encoder      = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_mu        = nn.Linear(d_model*seq_len, latent_dim)
        self.fc_logvar    = nn.Linear(d_model*seq_len, latent_dim)
        self.fc_latent    = nn.Linear(latent_dim, d_model*seq_len)
        decoder_layer     = nn.TransformerDecoderLayer(d_model, nhead)
        self.decoder      = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_linear= nn.Linear(d_model, input_dim)

    def encode(self, x):
        # x: (B, L, D)
        h = self.input_linear(x)
        h = self.pos_enc(h)
        out = self.encoder(h.transpose(0,1)).transpose(0,1)
        flat = out.contiguous().view(out.size(0), -1)
        return self.fc_mu(flat), self.fc_logvar(flat)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        # z: (B, latent_dim)
        x = self.fc_latent(z).view(z.size(0), self.seq_len, -1)
        tgt = self.pos_enc(x)
        memory = torch.zeros(self.seq_len, z.size(0), tgt.size(2), device=z.device)
        out = self.decoder(tgt.transpose(0,1), memory).transpose(0,1)
        return self.output_linear(out)


# ------------ 主流程 ------------
def main():
    global OUTPUT_FILE  # 声明 OUTPUT_FILE 为全局变量

    # 1. 读入 CSV
    print("开始读取 CSV 文件...")
    df = pd.read_csv(INPUT_FILE, header=2, engine='python', on_bad_lines='skip')
    print(f"CSV 文件读取完成，形状为: {df.shape}")

    # 保留前两列（一般是 frame, timestamp 或其他），后面开始是关键点
    meta = df.iloc[:, :2].reset_index(drop=True)
    data_np = df.iloc[:, 2:2+INPUT_DIM].to_numpy(dtype=np.float32)
    orig_len = data_np.shape[0]
    print(f"原始数据长度: {orig_len}")

    if orig_len < SEQ_LEN:
        pad = np.zeros((SEQ_LEN - orig_len, INPUT_DIM), dtype=np.float32)
        data_for_model = np.vstack([data_np, pad])
    else:
        data_for_model = data_np[:SEQ_LEN]

    # 2. 初始化并加载（或训练）模型
    print("初始化模型...")
    model = TransformerVAE(INPUT_DIM, SEQ_LEN).to(DEVICE)
    model.to(DEVICE).eval()

    with torch.no_grad():
        x = torch.from_numpy(data_for_model[None]).to(DEVICE)
        mu, logvar = model.encode(x)
        z = model.reparameterize(mu, logvar)
        synth = model.decode(z)[0].cpu().numpy()

    # ========== 4. 按原始帧长“反填”或截断合成结果 ==========
    if orig_len > SEQ_LEN:
        pad_back = np.zeros((orig_len - SEQ_LEN, INPUT_DIM), dtype=np.float32)
        full_synth = np.vstack([synth, pad_back])
    else:
        full_synth = synth[:orig_len]

    # ========== 5. 拼回 meta 并保存 ==========
    print(f"Meta shape: {meta.shape}, Full synth shape: {full_synth.shape}")
    if meta.shape[0] != full_synth.shape[0]:
        if meta.shape[0] > full_synth.shape[0]:
            pad_back = np.zeros((meta.shape[0] - full_synth.shape[0], INPUT_DIM), dtype=np.float32)
            full_synth = np.vstack([full_synth, pad_back])
        else:
            full_synth = full_synth[:meta.shape[0]]

    out_arr = np.hstack([meta.values, full_synth])
    out_cols = list(meta.columns) + list(df.columns[2:2+INPUT_DIM])
    out_df = pd.DataFrame(out_arr, columns=out_cols)

    if OUTPUT_FILE is None or OUTPUT_FILE == '':
        base, _ = os.path.splitext(INPUT_FILE)
        OUTPUT_FILE = f"{base}_synthetic.csv"

    out_df.to_csv(OUTPUT_FILE, header=False, index=False, float_format='%.6f')
    print(f"✔ 合成结果已保存：{OUTPUT_FILE} （shape={out_df.shape}）")

if __name__ == "__main__":
    main()

Using Input File: E:/学习工作/PD/pks/SitToStand/Data/STS_2D_skeletons_coarsened/Pt204_C_n_301.csv
开始读取 CSV 文件...
CSV 文件读取完成，形状为: (508, 52)
原始数据长度: 508
初始化模型...
Meta shape: (508, 2), Full synth shape: (508, 50)
✔ 合成结果已保存：E:/学习工作/PD/pks/SitToStand/Data/Transformer_generate.csv （shape=(508, 52)）




In [17]:
import matplotlib.pyplot as plt
import random
import numpy as np
import pandas as pd
import os

# --- Configuration for Plotting ---
# Select a few keypoints to plot (indices based on 0 to 16 for 17 keypoints)
keypoints_to_plot = [0, 8, 16]  # Example: Nose (0), MidHip (8 if it exists), Right Ankle (16)
num_samples_to_plot = 3  # How many synthetic samples to plot (will create this many separate figures/plot sets)

# Define plot styles
color_original = 'blue'
color_synthetic = 'green'
linewidth_original = 2
linestyle_synthetic = '--'

# --- Load synthetic samples for comparison ---
OUTPUT_DIR = r'E:/学习工作/PD/pks/SitToStand/Data/Transformer_generate.csv'  # Ensure this matches the output directory
INPUT_FILE = r'E:/学习工作/PD/pks/SitToStand/Data/STS_2D_skeletons_coarsened/Pt204_C_n_301.csv'

if os.path.exists(OUTPUT_DIR) and os.path.exists(INPUT_FILE):

    base_for_filter = os.path.splitext(os.path.basename(INPUT_FILE))[0]
    synthetic_file_list = [
        os.path.join(OUTPUT_DIR, f)
        for f in os.listdir(OUTPUT_DIR)
        if f.endswith('.csv') and f.startswith(base_for_filter)  # Filter for related synthetic files
    ]

    if not synthetic_file_list:
        print("No synthetic files found in the output directory for the current input file.")
    else:
        # Select random synthetic samples to plot, or plot all up to num_samples_to_plot
        paths_to_plot = random.sample(synthetic_file_list, min(num_samples_to_plot, len(synthetic_file_list)))

        # Load the original sequence
        df_original = pd.read_csv(INPUT_FILE, header=2)
        original_data = df_original.to_numpy()
        num_frames_original = original_data.shape[0]
        time_steps_original = np.arange(num_frames_original)

        print(f"\nPlotting original vs. {len(paths_to_plot)} individual synthetic samples...")

        for idx, path in enumerate(paths_to_plot):
            try:
                df_synthetic = pd.read_csv(path, header=2)
                synthetic_sequence = df_synthetic.to_numpy()
                sample_name = os.path.basename(path)
            except Exception as e:
                print(f"Error reading synthetic file {path}: {e}")
                continue

            # Create a new figure for each original vs. synthetic comparison
            num_keypoints = len(keypoints_to_plot)
            fig, axes = plt.subplots(num_keypoints, 2,
                                     figsize=(14, 3.5 * num_keypoints),
                                     sharex=True)
            if num_keypoints == 1:  # Adjust axes array if only one keypoint is plotted
                axes = np.array([axes])

            fig.suptitle(f'Original vs. Synthetic: {sample_name}', fontsize=16, y=1.02)  # Adjust y for title

            for i, kp_index in enumerate(keypoints_to_plot):
                x_index = kp_index * 2
                y_index = kp_index * 2 + 1

                if x_index >= original_data.shape[1] or y_index >= original_data.shape[1]:
                    print(f"Warning: Keypoint index {kp_index} is out of bounds for original data. Skipping.")
                    if axes.ndim > 1: axes[i, 0].set_visible(False); axes[i, 1].set_visible(False)  # Hide unused subplot
                    else: axes[0].set_visible(False); axes[1].set_visible(False)
                    continue
                if x_index >= synthetic_sequence.shape[1] or y_index >= synthetic_sequence.shape[1]:
                    print(f"Warning: Keypoint index {kp_index} is out of bounds for synthetic data {sample_name}. Skipping.")
                    if axes.ndim > 1: axes[i, 0].set_visible(False); axes[i, 1].set_visible(False)
                    else: axes[0].set_visible(False); axes[1].set_visible(False)
                    continue

                # Plot Original Data
                axes[i, 0].plot(time_steps_original, original_data[:, x_index],
                                label=f'Original KP {kp_index} (X)', color=color_original, linewidth=linewidth_original)
                axes[i, 1].plot(time_steps_original, original_data[:, y_index],
                                label=f'Original KP {kp_index} (Y)', color=color_original, linewidth=linewidth_original)

                # Plot Current Synthetic Sample
                num_synth_frames = synthetic_sequence.shape[0]
                time_steps_synthetic = np.arange(num_synth_frames)
                axes[i, 0].plot(time_steps_synthetic, synthetic_sequence[:, x_index],
                                label=f'Synth (X)', linestyle=linestyle_synthetic, color=color_synthetic, alpha=0.9)
                axes[i, 1].plot(time_steps_synthetic, synthetic_sequence[:, y_index],
                                label=f'Synth (Y)', linestyle=linestyle_synthetic, color=color_synthetic, alpha=0.9)

                axes[i, 0].set_ylabel(f'KP {kp_index} X-Coord')
                axes[i, 1].set_ylabel(f'KP {kp_index} Y-Coord')
                axes[i, 0].legend(loc='best', fontsize='x-small')
                axes[i, 1].legend(loc='best', fontsize='x-small')
                axes[i, 0].grid(True, linestyle=':', alpha=0.7)
                axes[i, 1].grid(True, linestyle=':', alpha=0.7)

            if num_keypoints > 0:  # Only set xlabel if plots were made
                axes[-1, 0].set_xlabel('Frame Number')
                axes[-1, 1].set_xlabel('Frame Number')

            plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout
            plt.show()  # Show the figure for the current synthetic sample

        print("\nPlotting complete.")
        print("If synthetic lines are still very similar to original, consider increasing augmentation strengths (e.g., JITTER_SCALE).")
else:
    print("Error: OUTPUT_DIR or INPUT_FILE not defined or not found.")

NotADirectoryError: [WinError 267] 目录名称无效。: 'E:/学习工作/PD/pks/SitToStand/Data/Transformer_generate.csv'