In [1]:
import os
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # For 3D scatter plot

从您的项目中导入 Dataset 类和骨架定义
如果 AMASSSubsetDataset 依赖其他自定义模块，请确保它们也能被正确导入

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from IPython.display import HTML # For displaying animation in Jupyter
import os
import sys
import torch # Added for torch.manual_seed (optional, for reproducibility)
import traceback # Added for more detailed error messages

# --- Setup Project Paths ---
# This assumes your notebook is in the 'Manifold' project root directory.
module_path = os.path.abspath(os.path.join('..')) 
if module_path not in sys.path:
    sys.path.insert(0, module_path) # Prepend to sys.path
    print(f"Added '{module_path}' to sys.path")
else:
    print(f"'{module_path}' is already in sys.path")

try:
    from src.datasets.amass_dataset import AMASSSubsetDataset
    from src.kinematics.skeleton_utils import get_skeleton_parents, get_num_joints
    print("Custom modules imported successfully.")
except ImportError as e:
    print(f"ImportError: {e}")
    print(f"Please ensure your Jupyter Notebook can find the 'src' directory from '{os.getcwd()}'")
    raise

# --- 1. Define Path to Your PROCESSED AMASS Data ---
# Adjust subject and file name as needed
processed_npz_path = os.path.join("..","data", "CMU", "00", "00_01_poses.npz") 

if not os.path.exists(processed_npz_path):
    print(f"ERROR: Processed data file not found at '{processed_npz_path}'")
    print(f"Please check the path. Current working directory is: {os.getcwd()}")
    # sys.exit() # Or raise an error if you want to stop
else:
    print(f"Using processed AMASS file at: {processed_npz_path}")

# --- 1. Define Path to Your PROCESSED AMASS Data ---
# Adjust to your data path structure
# Example: if project root is 'Manifold', data is in 'Manifold/data/processed/...'
data_base_dir = os.path.join(project_root, "data")
processed_npz_path = os.path.join(data_base_dir, "CMU", "01", "01_01_poses.npz") # More common CMU sequence


if not os.path.exists(processed_npz_path):
    print(f"ERROR: Data file not found at '{processed_npz_path}' and dummy creation failed.")
    sys.exit() # Stop if no data
else:
    print(f"Using AMASS file at: {processed_npz_path}")


# --- 2. Define Parameters for Dataset Loading ---
skeleton_type = 'smpl_24'
window_size = 150  # Number of frames for animation, ensure sequence in file is long enough
center_around_root = False # Set to False to see absolute noise effect
# *** MODIFICATION: Set noise_std to a non-zero value ***
noise_std = 0.03  # e.g., 3cm noise if coordinates are in meters

# --- 3. Load the Dataset ---

dataset = None
all_frames_clean_np = None
all_frames_noisy_np = None

# Set random seed for reproducible noise (optional)
torch.manual_seed(42)
np.random.seed(42)

if os.path.exists(processed_npz_path):
    try:
        # Assuming AMASSSubsetDataset from snippet_8 (does not take 'is_train')
        # It adds noise if noise_std > 0
        dataset = AMASSSubsetDataset(
            data_paths=[processed_npz_path],
            window_size=window_size,
            skeleton_type=skeleton_type,
            noise_std=noise_std,
            is_train=False, 
            center_around_root=center_around_root,
            joint_selector_indices=None # IMPORTANT: Set to None for pre-processed 24-joint data
        )
        print("AMASSSubsetDataset initialized.")
    except Exception as e:
        print(f"Error initializing AMASSSubsetDataset: {e}")
        traceback.print_exc()

    if dataset and len(dataset) > 0:
        print(f"Dataset loaded successfully with {len(dataset)} windows (note: only first window used for this animation script).")
        try:
            # Assuming your dataset returns (noisy_full_window, clean_full_window, some_metadata)
            # This matches your animation script's expectation for clean_window_torch
            # If AMASSSubsetDataset (snippet_8) is used, clean_window_torch will be the target_center_pose (single frame).
            # For full animation of clean sequence, we might need to load it directly or ensure dataset provides it.
            # For this script, we'll assume your previous animation script's data extraction was correct for your Dataset version.

            # Fetching the first (and only, due to how windows are typically generated from a full sequence)
            # item that corresponds to a window starting at the beginning of the sequence.
            # The `window_size` parameter in the Dataset now defines the length of these sequences.
            sample_idx = 0 # Get the first window available
            if len(dataset) == 0:
                 raise ValueError("Dataset is empty after initialization. Check sequence length vs window_size.")

            noisy_window_torch, clean_window_torch, _metadata = dataset[sample_idx]
            print(f"Successfully got item [{sample_idx}] from dataset.")

            all_frames_clean_np = clean_window_torch.numpy()
            all_frames_noisy_np = noisy_window_torch.numpy()

            # --- SHAPE VALIDATION ---
            # Based on your animation script, clean_window_torch was (window_size, 24, 3)
            # Noisy_window_torch from AMASSSubsetDataset (snippet_8) is (window_size, N_joints, 3)
            # Target_clean_pose (snippet_8) is (N_joints, 3)
            # This part requires your AMASSSubsetDataset to provide the full clean window as the second element.
            # If clean_window_torch is just a single frame, this animation will not work as intended for side-by-side.
            # For now, proceeding with your script's implication that clean_window_torch is a full window.

            if all_frames_clean_np.ndim != 3 or all_frames_clean_np.shape[0] != window_size:
                print(f"WARNING: Expected clean_frames shape ({window_size}, N_joints, 3), but got {all_frames_clean_np.shape}")
                print("This might indicate that clean_window_torch is not the full clean window as expected by original animation script.")
                print("If it's a single frame, side-by-side animation of full sequence needs different data loading for clean part.")
                # If clean part is single frame, we can choose to animate only noisy or tile the clean frame (not ideal)
                # For now, we'll let it proceed and it might fail in animation or look static for clean part.

            if all_frames_noisy_np.ndim != 3 or all_frames_noisy_np.shape[0] != window_size:
                 print(f"ERROR: Expected noisy_frames shape ({window_size}, N_joints, 3), but got {all_frames_noisy_np.shape}")
                 all_frames_noisy_np = None # Prevent animation if shape is wrong


            print(f"Shape of clean animation data: {all_frames_clean_np.shape if all_frames_clean_np is not None else 'None'}")
            print(f"Shape of noisy animation data: {all_frames_noisy_np.shape if all_frames_noisy_np is not None else 'None'}")


            if all_frames_clean_np is not None and (np.isnan(all_frames_clean_np).any() or np.isinf(all_frames_clean_np).any()):
                print("WARNING: NaN or Inf values found in clean animation data!")
            if all_frames_noisy_np is not None and (np.isnan(all_frames_noisy_np).any() or np.isinf(all_frames_noisy_np).any()):
                print("WARNING: NaN or Inf values found in noisy animation data!")
            else:
                print("Animation data appears to have valid numbers (no NaN/Inf checked).")

        except Exception as e:
            print(f"Error getting item from dataset or processing it: {e}")
            traceback.print_exc()
            all_frames_clean_np = None
            all_frames_noisy_np = None
    else:
        print("Dataset could not be loaded or is empty.")
else:
    print(f"Skipping dataset loading as file was not found: {processed_npz_path}")


# --- 4. Setup Animation (only if data is loaded) ---
html_output = None
if all_frames_clean_np is not None and all_frames_noisy_np is not None and \
   all_frames_clean_np.shape == all_frames_noisy_np.shape and \
   all_frames_clean_np.ndim == 3 and all_frames_clean_np.shape[0] > 0 :

    num_animation_frames = all_frames_clean_np.shape[0]
    num_joints_to_plot = all_frames_clean_np.shape[1]
    skeleton_parents = get_skeleton_parents(skeleton_type)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8), subplot_kw={'projection': '3d'})
    fig.suptitle(f"Clean vs. Noisy (Noise Std: {noise_std}) Animation", fontsize=16)

    # Determine fixed axis limits based on the clean sequence for consistency
    # (or use combined range if noisy version might go significantly out of bounds)
    combined_data_for_limits = np.concatenate((all_frames_clean_np, all_frames_noisy_np), axis=0)
    x_min, x_max = combined_data_for_limits[..., 0].min(), combined_data_for_limits[..., 0].max()
    y_min, y_max = combined_data_for_limits[..., 1].min(), combined_data_for_limits[..., 1].max()
    z_min, z_max = combined_data_for_limits[..., 2].min(), combined_data_for_limits[..., 2].max()
    margin = 0.2

    axes_list = [ax1, ax2]
    plot_titles_str = ['Clean Pose', 'Noisy Pose']
    data_list_for_anim = [all_frames_clean_np, all_frames_noisy_np]
    colors_list = ['green', 'red']
    bone_colors_list = ['lime', 'lightcoral']

    scatter_plots = []
    bone_lines_list = []
    titles_obj_list = []

    for i, ax in enumerate(axes_list):
        ax.set_xlim(x_min - margin, x_max + margin)
        ax.set_ylim(y_min - margin, y_max + margin)
        ax.set_zlim(z_min - margin, z_max + margin)
        
        current_max_range = max(x_max-x_min, y_max-y_min, z_max-z_min)
        if current_max_range == 0 : current_max_range = 1
        try:
            ax.set_box_aspect((x_max-x_min if (x_max-x_min)>0 else 1,
                               y_max-y_min if (y_max-y_min)>0 else 1,
                               z_max-z_min if (z_max-z_min)>0 else 1))
        except AttributeError: # Older matplotlib
            ax.set_aspect('auto')


        ax.set_xlabel('X'); ax.set_ylabel('Y'); ax.set_zlabel('Z')
        title_obj = ax.set_title(f'{plot_titles_str[i]} - Frame 0 / {num_animation_frames-1}')
        titles_obj_list.append(title_obj)

        initial_pose = data_list_for_anim[i][0]
        sc_plot = ax.scatter(initial_pose[:, 0], initial_pose[:, 1], initial_pose[:, 2],
                               c=colors_list[i], marker='o', s=50, edgecolors='black', linewidth=0.3, depthshade=True)
        scatter_plots.append(sc_plot)

        current_bone_lines = []
        for j_idx, parent_idx in enumerate(skeleton_parents):
            if parent_idx != -1:
                line, = ax.plot([initial_pose[j_idx, 0], initial_pose[parent_idx, 0]],
                                [initial_pose[j_idx, 1], initial_pose[parent_idx, 1]],
                                [initial_pose[j_idx, 2], initial_pose[parent_idx, 2]],
                                color=bone_colors_list[i], linewidth=2)
                current_bone_lines.append(line)
        bone_lines_list.append(current_bone_lines)
        ax.view_init(elev=15., azim=-75)


    def update_dual_plot(frame_num, clean_data, noisy_data, scatter_plots_list, bones_list, skeleton_parents, titles_list):
        all_artists = []
        current_poses = [clean_data[frame_num], noisy_data[frame_num]]

        for plot_idx in range(2): # 0 for clean, 1 for noisy
            pose_to_update = current_poses[plot_idx]
            scatter_to_update = scatter_plots_list[plot_idx]
            bones_to_update = bones_list[plot_idx]
            title_to_update = titles_list[plot_idx]

            scatter_to_update._offsets3d = (pose_to_update[:, 0], pose_to_update[:, 1], pose_to_update[:, 2])
            all_artists.append(scatter_to_update)

            line_idx = 0
            for i, parent_idx in enumerate(skeleton_parents):
                if parent_idx != -1:
                    bones_to_update[line_idx].set_data(
                        [pose_to_update[i, 0], pose_to_update[parent_idx, 0]],
                        [pose_to_update[i, 1], pose_to_update[parent_idx, 1]])
                    bones_to_update[line_idx].set_3d_properties(
                        [pose_to_update[i, 2], pose_to_update[parent_idx, 2]])
                    all_artists.append(bones_to_update[line_idx])
                    line_idx += 1
            
            base_title = 'Clean Pose' if plot_idx == 0 else 'Noisy Pose'
            title_to_update.set_text(f'{base_title} - Frame {frame_num} / {num_animation_frames-1}')
            all_artists.append(title_to_update)
        
        return all_artists

    print("Creating animation... This might take a moment.")
    try:
        anim = FuncAnimation(fig, update_dual_plot, frames=num_animation_frames,
                             fargs=(all_frames_clean_np, all_frames_noisy_np, scatter_plots, bone_lines_list, skeleton_parents, titles_obj_list),
                             interval=max(20, 1000 // 30), blit=False, repeat=True) # Target ~30 FPS, min 20ms interval

        html_output = HTML(anim.to_jshtml(fps=30)) # Specify FPS for jshtml output
        print("Animation created. If it doesn't display below, ensure your Jupyter environment supports jshtml.")
    except Exception as e:
        print(f"Error during animation creation or HTML conversion: {e}")
        traceback.print_exc()
        html_output = "Error creating animation."

    plt.close(fig)
else:
    html_output_message = "Data not loaded or shapes are mismatched, cannot create animation."
    print(html_output_message)
    # Create a simple HTML object to display the message if in Jupyter
    if 'get_ipython' in globals() and get_ipython() is not None:
        html_output = HTML(f"<p>{html_output_message}</p>")
    else: # if not in IPython (e.g. running as script)
        html_output = None # Or handle differently

# Display the animation (or error message)
if html_output:
    display(html_output)
elif html_output is None and not (all_frames_clean_np is not None and all_frames_noisy_np is not None):
    pass # Message already printed if data was not loaded

'D:\git_repo_tidy\ESE6500\FinalProj\Manifold' is already in sys.path
Custom modules imported successfully.
Using processed AMASS file at: ..\data\CMU\00\00_01_poses.npz
Using AMASS file at: D:\git_repo_tidy\ESE6500\FinalProj\Manifold\data\CMU\01\01_01_poses.npz
Loading AMASS data from 1 path(s)...
Successfully loaded 1 sequences, processed for 24 joints.
Created 2602 windows.
AMASSSubsetDataset initialized.
Dataset loaded successfully with 2602 windows (note: only first window used for this animation script).
Successfully got item [0] from dataset.
Shape of clean animation data: (150, 24, 3)
Shape of noisy animation data: (150, 24, 3)
Animation data appears to have valid numbers (no NaN/Inf checked).
Creating animation... This might take a moment.
