In [3]:
import os
import json
import numpy as np
import tensorflow as tf
from tensorflow import keras

# 데이터 경로 설정
keypoint_dir = "/home/uk/Desktop/dev_ws/DeepLearning/data/label_data/SYN_word_keypoint/02_sys_word_keypoint.zip/WORD/keypoint/03"
label_dir = "/home/uk/Desktop/dev_ws/DeepLearning/data/label_data/SYN_word_keypoint/02_sys_word_keypoint.zip/WORD/morpheme/03"

# 이미지 크기 및 키포인트 개수 설정
img_size = 128

# 데이터 로드 및 전처리 함수 (Generator)
def data_generator(keypoint_dir, label_dir, img_size):
    keypoint_folders = sorted(os.listdir(keypoint_dir))
    for folder in keypoint_folders:
        folder_path = os.path.join(keypoint_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        frame_files = sorted(os.listdir(folder_path))
        for frame_file in frame_files:
            if not frame_file.endswith(".json"):
                continue
            
            frame_path = os.path.join(folder_path, frame_file)
            try:
                with open(frame_path, "r") as f:
                    keypoint_data = json.load(f)
                
                if "people" in keypoint_data and keypoint_data["people"]:
                    pose_keypoints = np.array(keypoint_data["people"]["pose_keypoints_2d"]).reshape(-1, 3)
                    keypoints = pose_keypoints[:, :2].flatten()

                    img = np.zeros((img_size, img_size, 1), dtype=np.uint8)
                    for i in range(0, len(keypoints), 2):
                        x, y = int(keypoints[i] * img_size / 1920), int(keypoints[i + 1] * img_size / 1080)
                        if 0 <= x < img_size and 0 <= y < img_size:
                            img[y, x] = 255

                    img = img / 255.0  # 정규화
                    yield img  # 한 번에 하나씩 반환

            except Exception as e:
                print(f"Error processing keypoint file: {frame_path}, {e}")

# Dataset 생성 (제너레이터 사용)
dataset = tf.data.Dataset.from_generator(
    lambda: data_generator(keypoint_dir, label_dir, img_size),
    output_signature=tf.TensorSpec(shape=(img_size, img_size, 1), dtype=tf.float32)
)

# 배치 및 셔플
batch_size = 8
train_dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# CNN 모델 정의
model = keras.Sequential([
    keras.layers.Conv2D(32, (3, 3), activation="relu", input_shape=(img_size, img_size, 1)),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation="relu"),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation="relu"),
    keras.layers.Dense(10, activation="softmax")  # 클래스 수는 데이터셋 확인 후 수정
])

# 모델 컴파일
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

# 모델 학습
checkpoint_path = "model_checkpoint.weights.h5"
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    checkpoint_path, save_weights_only=True, verbose=1
)

model.fit(train_dataset, epochs=10, callbacks=[checkpoint_callback])

# GPU 설정
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

Epoch 1/10


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


ValueError: None values not supported.

In [1]:
!nvidia-smi

Mon Mar 17 20:30:07 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.120                Driver Version: 550.120        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3060 ...    Off |   00000000:01:00.0  On |                  N/A |
| N/A   49C    P8             13W /   70W |      58MiB /   6144MiB |     27%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                