# 训练你自己的动作

## 准备工作
在这一部分，你将导入必要的库并定义几个函数来预处理训练图像，将其转换为包含关键点坐标和真实标签的CSV文件。

In [None]:
import csv
import cv2
import itertools
import numpy as np
import pandas as pd
import os
import sys
import tempfile
import tqdm

from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow import keras

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

### 用MoveNet执行姿势估计的代码

In [None]:
# Make the helper modules of the Raspberry Pi pose-estimation sample
# (`utils`, `data`, `ml`) importable.
# NOTE(review): assumes the sample code is checked out under the current
# working directory at the path below — confirm before running.
pose_sample_rpi_path = os.path.join(os.getcwd(), 'examples/lite/examples/pose_estimation/raspberry_pi')
# Re-running a notebook cell would otherwise append a duplicate entry each time.
if pose_sample_rpi_path not in sys.path:
  sys.path.append(pose_sample_rpi_path)

# Load the MoveNet Thunder model.
import utils
from data import BodyPart
from ml import Movenet
movenet = Movenet('movenet_thunder')

# 定义函数以使用MoveNet Thunder进行姿势估计
def detect(input_tensor, inference_count=3):
  """Run MoveNet Thunder pose estimation on an image.

  The first pass is made with `reset_crop_region=True`; every later pass uses
  `reset_crop_region=False`. Presumably this lets the detector reuse and refine
  the crop region from the previous pass (see `ml.Movenet`) — TODO confirm.

  Args:
    input_tensor: Image tensor of shape (height, width, channels) that supports
      `.numpy()`.
    inference_count: Total number of inference passes to run; must be >= 1.

  Returns:
    The detection result of the last `movenet.detect` call.

  Raises:
    ValueError: If `inference_count` is smaller than 1.
  """
  if inference_count < 1:
    raise ValueError(
        'inference_count must be >= 1, got {}'.format(inference_count))
  # Keep the first result so that inference_count == 1 still returns a
  # detection (previously `person` was only bound inside the loop).
  person = movenet.detect(input_tensor.numpy(), reset_crop_region=True)
  for _ in range(inference_count - 1):
    person = movenet.detect(input_tensor.numpy(), reset_crop_region=False)
  return person

### 可视化姿势估计结果的函数

In [None]:
def draw_prediction_on_image(image, person, crop_region=None, close_figure=True, keep_input_size=False):
  """Draw the detected keypoints of `person` on top of `image`.

  Args:
    image: Image as a NumPy array of shape (height, width, channels).
    person: Detection result, e.g. as returned by `detect`.
    crop_region: Unused; kept only so existing callers keep working.
    close_figure: If False, a matplotlib figure showing the annotated image is
      created and left open, so the notebook displays it. If True, no figure
      is created.
    keep_input_size: If False, the annotated image is passed through
      `utils.keep_aspect_ratio_resizer(image_np, (512, 512))` before returning.

  Returns:
    The annotated image as a NumPy array.
  """
  image_np = utils.visualize(image, [person])

  if not close_figure:
    # Only build a figure when it will actually be kept. Creating and closing
    # one for every image (as the preprocessing loop does) is wasted work.
    height, width = image.shape[:2]
    aspect_ratio = float(width) / height
    _, ax = plt.subplots(figsize=(12 * aspect_ratio, 12))
    ax.imshow(image_np)

  if not keep_input_size:
    image_np = utils.keep_aspect_ratio_resizer(image_np, (512, 512))
  return image_np

### 加载图像，检测姿势关键点并保存到CSV文件中的代码

In [None]:
class MoveNetPreprocessor(object):
  """Run MoveNet over a folder of labelled images and export keypoints to CSV.

  `images_in_folder` must contain one sub-folder per pose class; the sub-folder
  name is used as the class label. For every accepted image, an annotated copy
  is written to `images_out_folder/<class>/`. The keypoints of all accepted
  images are merged into a single CSV at `csvs_out_path`.
  """

  def __init__(self, images_in_folder, images_out_folder, csvs_out_path):
    """Store the paths and discover the pose classes.

    Args:
      images_in_folder: Folder with one sub-folder of images per pose class.
      images_out_folder: Folder where annotated copies of images are written.
      csvs_out_path: Path of the combined CSV file written by `process`.
    """
    self._images_in_folder = images_in_folder
    self._images_out_folder = images_out_folder
    self._csvs_out_path = csvs_out_path
    # Human-readable notes about skipped images, printed by `process`.
    self._messages = []
    # Per-class CSVs are staged in this scratch folder and merged at the end
    # of `process`.
    self._csvs_out_folder_per_class = tempfile.mkdtemp()
    # Only visible sub-directories count as classes. A stray file such as a
    # README would otherwise make `os.listdir` fail later in `process`.
    self._pose_class_names = sorted(
        name for name in os.listdir(self._images_in_folder)
        if not name.startswith('.')
        and os.path.isdir(os.path.join(self._images_in_folder, name)))

  def process(self, per_pose_class_limit=None, detection_threshold=0.1):
    """Detect poses in all images and write the combined landmarks CSV.

    Args:
      per_pose_class_limit: If not None, only the first N images (in sorted
        file-name order) of each class are processed.
      detection_threshold: An image is skipped unless its lowest keypoint
        score is >= this value.

    Raises:
      RuntimeError: If a class has no valid images, or there are no classes.
    """
    for pose_class_name in self._pose_class_names:
      print('Preprocessing', pose_class_name, file=sys.stderr)
      self._process_class(
          pose_class_name, per_pose_class_limit, detection_threshold)

    print('\n'.join(self._messages))
    all_landmarks_df = self._all_landmarks_as_dataframe()
    all_landmarks_df.to_csv(self._csvs_out_path, index=False)

  def class_names(self):
    """Return the sorted list of pose class names."""
    return self._pose_class_names

  def _process_class(self, pose_class_name, per_pose_class_limit,
                     detection_threshold):
    """Process one class folder and write its rows to a per-class CSV."""
    images_in_folder = os.path.join(self._images_in_folder, pose_class_name)
    images_out_folder = os.path.join(self._images_out_folder, pose_class_name)
    csv_out_path = os.path.join(
        self._csvs_out_folder_per_class, pose_class_name + '.csv')
    os.makedirs(images_out_folder, exist_ok=True)

    image_names = sorted(
        name for name in os.listdir(images_in_folder)
        if not name.startswith('.'))
    if per_pose_class_limit is not None:
      image_names = image_names[:per_pose_class_limit]

    valid_image_count = 0
    # The csv module requires newline=''. UTF-8 matches the encoding pandas
    # uses when the file is read back in `_all_landmarks_as_dataframe`.
    with open(csv_out_path, 'w', newline='', encoding='utf-8') as csv_out_file:
      csv_out_writer = csv.writer(
          csv_out_file, delimiter=',', quoting=csv.QUOTE_MINIMAL)
      for image_name in tqdm.tqdm(image_names):
        row = self._process_image(
            images_in_folder, images_out_folder, image_name,
            detection_threshold)
        if row is None:
          continue
        valid_image_count += 1
        csv_out_writer.writerow(row)

    if not valid_image_count:
      raise RuntimeError(
          'No valid images found for the "{}" class.'.format(pose_class_name))

  def _process_image(self, images_in_folder, images_out_folder, image_name,
                     detection_threshold):
    """Return the CSV row for one image, or None (recording why) if skipped."""
    image_path = os.path.join(images_in_folder, image_name)
    try:
      # Read and decode once; the decoded tensor is reused below.
      image = tf.io.decode_jpeg(tf.io.read_file(image_path))
    except tf.errors.OpError:
      self._messages.append('Skipped ' + image_path + '. Invalid image.')
      return None

    _, _, channel = image.shape
    if channel != 3:
      self._messages.append(
          'Skipped ' + image_path + '. Image isn\'t in RGB format.')
      return None

    person = detect(image)
    min_landmark_score = min(keypoint.score for keypoint in person.keypoints)
    # Written as `not (>=)` so that a NaN score is skipped, as before.
    if not min_landmark_score >= detection_threshold:
      self._messages.append(
          'Skipped ' + image_path + '. No pose was confidently detected.')
      return None

    output_overlay = draw_prediction_on_image(
        image.numpy().astype(np.uint8), person,
        close_figure=True, keep_input_size=True)
    output_frame = cv2.cvtColor(output_overlay, cv2.COLOR_RGB2BGR)
    cv2.imwrite(os.path.join(images_out_folder, image_name), output_frame)

    pose_landmarks = np.array(
        [[keypoint.coordinate.x, keypoint.coordinate.y, keypoint.score]
         for keypoint in person.keypoints],
        dtype=np.float32)
    return [image_name] + pose_landmarks.flatten().astype(str).tolist()

  def _all_landmarks_as_dataframe(self):
    """Merge the per-class CSVs into one DataFrame with named columns.

    Returns:
      A DataFrame with these columns, in order:
        - `file_name`: the class folder joined with the image name.
        - `<PART>_x`, `<PART>_y`, `<PART>_score` for each `BodyPart`.
        - `class_no`: the class index in sorted order.
        - `class_name`.

    Raises:
      RuntimeError: If no pose classes were found.
    """
    if not self._pose_class_names:
      raise RuntimeError(
          'No pose classes found in "{}".'.format(self._images_in_folder))

    per_class_dfs = []
    for class_index, class_name in enumerate(self._pose_class_names):
      csv_out_path = os.path.join(
          self._csvs_out_folder_per_class, class_name + '.csv')
      # Column 0 holds file names. Read it as text so that numeric-looking
      # names (e.g. "00000128") keep their leading zeros.
      per_class_df = pd.read_csv(csv_out_path, header=None, dtype={0: str})
      per_class_df['class_no'] = [class_index] * len(per_class_df)
      per_class_df['class_name'] = [class_name] * len(per_class_df)
      file_col = per_class_df.columns[0]
      per_class_df[file_col] = (
          os.path.join(class_name, '') + per_class_df[file_col].astype(str))
      per_class_dfs.append(per_class_df)
    # Concatenate once instead of growing a DataFrame inside the loop.
    total_df = pd.concat(per_class_dfs, axis=0)

    header_name = ['file_name'] + [
        bodypart.name + suffix
        for bodypart in BodyPart
        for suffix in ('_x', '_y', '_score')]
    header_map = {total_df.columns[i]: name
                  for i, name in enumerate(header_name)}
    total_df.rename(header_map, axis=1, inplace=True)
    return total_df

## 第1部分：预处理输入图像

由于我们的姿势分类器的输入是MoveNet模型的输出关键点，我们需要通过MoveNet运行已标记的图像，并将所有关键点数据和真实标签捕获到CSV文件中来生成我们的训练数据集。

我们为本教程提供的数据集是一个CG生成的瑜伽姿势数据集。它包含多个CG生成的模型在做5种不同瑜伽姿势的图像。目录已经分为`train`数据集和`test`数据集。

因此，在本节中，我们将下载瑜伽数据集并通过MoveNet运行，以便我们可以将所有关键点捕获到CSV文件中... **然而，将我们包含上千张图像的数据集输入MoveNet并生成这个CSV文件大约需要10分钟**。所以作为替代方案，你可以通过设置下面的`is_skip_step_1`参数为**True**，下载一个现成的瑜伽数据集的CSV文件。这样你将跳过此步骤，而是下载将在此预处理步骤中创建的相同CSV文件。

如果想用自己的图像数据集训练姿势分类器，你需要上传你的图像并运行此预处理步骤。请按照以下说明上传你自己的姿势数据集。

1. 准备一个压缩文件(ZIP、TAR或其他格式)，其中包含存放你的图像数据集的文件夹，每种姿势的图像放在以该姿势命名的子文件夹中。如果你的数据集尚未拆分为训练集和测试集，将根据指定的拆分比例进行拆分，此时你上传的图像文件夹应如下所示：
  ```
  yoga_poses/
  |__ downdog/
      |______ 00000128.jpg
      |______ 00000181.jpg
      |______ ...
  |__ goddess/
      |______ 00000243.jpg
      |______ 00000306.jpg
      |______ ...
  ```
2. 选择并上传你的压缩文件，等待上传完成后再继续。
3. 编辑以下代码块以指定你的压缩文件和图像目录的名称（默认情况下，我们期望一个ZIP文件，因此如果你的压缩文件是其他格式，也需要修改该部分）。
4. 现在运行笔记本的其余部分。