In [70]:
import pymysql
import getpass as gp
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings('ignore')

In [71]:
# 계정 정보 사전
# Registered DB accounts: user ID -> connection settings.
# Only the host is kept here; the password is prompted for at connect time.
# Register another account by adding one more keyword argument below.
accounts = dict(
    root={'host': '127.0.0.1'},
    reader={'host': '127.0.0.1'},
    writer={'host': '127.0.0.1'},
    edward={'host': '192.168.0.27'},
)

def connect_to_db(user=None, db_name=None, **connect_kwargs):
    """Open a PyMySQL connection for an account registered in ``accounts``.

    Anything not passed in is prompted for interactively. The password is
    always prompted for (via getpass, so it is not echoed).

    Parameters
    ----------
    user : str, optional
        Account ID; must be a key of ``accounts``. Prompted for when omitted.
    db_name : str, optional
        Database (schema) to select. When omitted or blank, it is prompted for
        repeatedly until a non-empty name is entered.
    **connect_kwargs
        Extra keyword arguments forwarded unchanged to ``pymysql.connect``
        (for example ``connect_timeout`` or ``read_timeout``).

    Returns
    -------
    pymysql.connections.Connection
        The open connection.

    Raises
    ------
    ValueError
        If the user ID is not registered in ``accounts``.
    """
    if user is None:
        user = input("ID: ")
    user = user.strip()
    if user not in accounts:
        raise ValueError("등록되지 않은 사용자 ID")

    password = gp.getpass("비밀번호: ")
    host = accounts[user]['host']

    # A database name is mandatory: keep asking until a non-empty one is given.
    db_name = (db_name or "").strip()
    while not db_name:
        db_name = input("접속할 DB 이름을 입력하세요: ").strip()
        if not db_name:
            print("DB 이름은 비어 있을 수 없습니다. 다시 입력해주세요.")

    return pymysql.connect(
        host=host,
        user=user,
        password=password,
        database=db_name,  # `db` is only a deprecated alias of `database` in PyMySQL
        charset='utf8mb4',
        **connect_kwargs,
    )

def q(query, connection=None):
    """Execute one SQL statement and show its outcome in the notebook.

    Statements that produce a result set (SELECT, SHOW, DESCRIBE, EXPLAIN,
    WITH ... SELECT, ...) are shown as a pandas DataFrame via ``display``;
    any other statement is committed and ``"Query OK."`` is printed.
    The statement is executed exactly once.

    Parameters
    ----------
    query : str
        A single SQL statement.
    connection : DB-API connection, optional
        Connection to use; defaults to the notebook-level ``conn``.

    Raises
    ------
    ValueError
        If ``query`` is empty or whitespace only.
    Exception
        Whatever the driver raises for a failing statement; the transaction is
        rolled back first so the shared connection is left clean.
    """
    if not query or not query.strip():
        raise ValueError("query must be a non-empty SQL statement")
    if connection is None:
        connection = conn  # notebook-level connection created by connect_to_db()

    # `with` closes the cursor automatically, even when execute() raises.
    with connection.cursor() as cursor:
        try:
            cursor.execute(query)
        except Exception:
            try:
                connection.rollback()
            except Exception:
                pass  # the connection itself may be gone; surface the original error
            raise

        # A result set exists exactly when the driver reports column metadata,
        # which is more reliable than guessing from the leading keyword.
        if cursor.description is not None:
            columns = [column[0] for column in cursor.description]
            # coerce_float=True mirrors pd.read_sql's default (e.g. Decimal -> float).
            df = pd.DataFrame.from_records(
                list(cursor.fetchall()), columns=columns, coerce_float=True
            )
            display(df)
        else:
            connection.commit()
            print("Query OK.")


In [72]:
# Open the notebook-wide connection used by q() and the ingestion cell.
# Connect first so a failed login keeps any existing connection usable, then
# close the stale connection so re-running this cell does not leave it open.
_new_conn = connect_to_db()
_old_conn = globals().get('conn')
if _old_conn is not None and getattr(_old_conn, 'open', False):
    _old_conn.close()
conn = _new_conn
del _new_conn, _old_conn

In [73]:
# Dataset-level metadata table: one row per dataset_name (its primary key).
# NOTE(review): since dataset_name alone is the primary key, only one `version`
# can be stored per dataset despite the table's name — confirm this is intended.
dataset_versions_ddl = """
CREATE TABLE IF NOT EXISTS dataset_versions (
  dataset_name   VARCHAR(255) PRIMARY KEY,
  version        VARCHAR(64) NOT NULL,
  file_format    VARCHAR(64),
  release_notes  TEXT
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""
q(dataset_versions_ddl)


Query OK.


In [74]:
# splits: one row per (dataset_name, split_name); removing a dataset cascades here.
splits_ddl = """
CREATE TABLE IF NOT EXISTS splits (
  split_id     INT AUTO_INCREMENT PRIMARY KEY,
  dataset_name VARCHAR(255) NOT NULL,
  split_name   VARCHAR(255) NOT NULL,
  num_bytes    BIGINT,
  num_shards   INT,
  UNIQUE KEY ux_dataset_split (dataset_name, split_name),
  FOREIGN KEY (dataset_name)
    REFERENCES dataset_versions(dataset_name)
      ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""

# shards: one row per (split_id, shard_index); removing a split cascades here.
shards_ddl = """
CREATE TABLE IF NOT EXISTS shards (
  shard_id     INT AUTO_INCREMENT PRIMARY KEY,
  split_id     INT        NOT NULL,
  shard_index  INT        NOT NULL,
  num_examples INT        NOT NULL,
  filepath     TEXT       NOT NULL,
  UNIQUE KEY ux_split_shard (split_id, shard_index),
  FOREIGN KEY (split_id)
    REFERENCES splits(split_id)
      ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""

# Order matters: shards references splits, so splits must exist first.
for ddl in (splits_ddl, shards_ddl):
    q(ddl)


Query OK.
Query OK.


In [75]:
# episodes: one row per episode_id.
episodes_ddl = """
CREATE TABLE IF NOT EXISTS episodes (
  episode_id      VARCHAR(255) PRIMARY KEY,
  file_path       TEXT        NOT NULL,
  recording_path  TEXT        NOT NULL
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""

# steps: one row per (episode_id, step_index); removing an episode cascades here.
# Vector-valued fields (action, action_dict, obs_cart_pos) are stored as JSON.
steps_ddl = """
CREATE TABLE IF NOT EXISTS steps (
  step_id         BIGINT      PRIMARY KEY AUTO_INCREMENT,
  episode_id      VARCHAR(255) NOT NULL,
  step_index      INT         NOT NULL,
  discount        FLOAT,
  is_first        TINYINT(1),
  is_last         TINYINT(1),
  is_terminal     TINYINT(1),
  reward          FLOAT,
  lang_inst_1     TEXT,
  lang_inst_2     TEXT,
  lang_inst_3     TEXT,
  action          JSON,
  action_dict     JSON,
  obs_cart_pos    JSON,
  UNIQUE KEY ux_episode_step (episode_id, step_index),
  FOREIGN KEY (episode_id)
    REFERENCES episodes(episode_id)
      ON DELETE CASCADE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
"""

# Order matters: steps references episodes, so episodes must exist first.
for ddl in (episodes_ddl, steps_ddl):
    q(ddl)


Query OK.
Query OK.


In [76]:
import json
import tensorflow as tf
import pymysql
from tensorflow.core.example import example_pb2

In [77]:
def get_bytes(feature, key):
    """Return the first bytes value stored under ``key``, decoded as UTF-8.

    ``feature`` is a feature map such as ``Example.features.feature``.
    Returns None when the key is missing or its bytes_list is empty.
    """
    if key not in feature:
        return None
    values = feature[key].bytes_list.value
    if not values:
        return None
    return values[0].decode('utf-8')

def get_floats(feature, key):
    """Return the float_list values under ``key`` as a new Python list ([] if the key is missing)."""
    if key not in feature:
        return []
    return [value for value in feature[key].float_list.value]

In [78]:
# Load every serialized record from `ds` into the `episodes` and `steps` tables.
#
# NOTE(review): `ds` and `shard_fp` are not defined in any visible cell. Presumably
# `ds` yields serialized tf.train.Example records (they are parsed with
# example_pb2.Example below) read from the file at `shard_fp`. Define both in an
# earlier cell so Restart & Run All works.
# NOTE(review): if each record is a whole RLDS-style episode (every `steps/*`
# feature flattened across all steps), the `value[0]` scalars below capture only
# the first step. Verify the record layout.
#
# The previous version ran the whole load as one transaction on a connection
# opened several cells earlier, and failed with error 2013 ("Lost connection to
# MySQL server during query"). Typical causes are a server-side idle timeout or a
# long-running transaction. This cell therefore:
#   - revives a dropped connection first,
#   - commits in bounded batches,
#   - rolls back on failure,
#   - uses idempotent inserts so a partial load can simply be re-run.

COMMIT_EVERY = 500  # records per transaction

EPISODE_INSERT_SQL = (
    "INSERT IGNORE INTO episodes (episode_id, file_path, recording_path) "
    "VALUES (%s, %s, %s)"
)
# The no-op ON DUPLICATE KEY UPDATE makes a re-run skip (episode_id, step_index)
# rows that an earlier, partially committed run already stored, instead of failing.
STEP_INSERT_SQL = """
    INSERT INTO steps
      (episode_id, step_index, discount, is_first, is_last, is_terminal,
       reward, lang_inst_1, lang_inst_2, lang_inst_3,
       action, action_dict, obs_cart_pos)
    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
    ON DUPLICATE KEY UPDATE step_id = step_id
"""


def _first_value(feature, key, kind, default=None):
    """Return ``feature[key].<kind>.value[0]``, or ``default`` if the key is missing or the list is empty."""
    if key not in feature:
        return default
    values = getattr(feature[key], kind).value
    return values[0] if values else default


def _build_step_row(feature, episode_id, step_index):
    """Return the parameter tuple for one STEP_INSERT_SQL row built from a parsed feature map."""
    # Every `steps/action_dict/<name>` feature goes into one JSON object keyed by <name>.
    action_dict = {
        key.split("/", 2)[-1]: get_floats(feature, key)
        for key in feature
        if key.startswith("steps/action_dict/")
    }
    return (
        episode_id,
        step_index,
        _first_value(feature, "steps/discount", "float_list"),
        int(_first_value(feature, "steps/is_first", "int64_list", 0)),
        int(_first_value(feature, "steps/is_last", "int64_list", 0)),
        int(_first_value(feature, "steps/is_terminal", "int64_list", 0)),
        _first_value(feature, "steps/reward", "float_list"),
        get_bytes(feature, "steps/language_instruction"),
        get_bytes(feature, "steps/language_instruction_2"),
        get_bytes(feature, "steps/language_instruction_3"),
        json.dumps(get_floats(feature, "steps/action")),
        json.dumps(action_dict),
        # Only the observed cartesian_position is stored, as an example observation.
        json.dumps(get_floats(feature, "steps/observation/cartesian_position")),
    )


# Transparently reconnect if the server dropped the connection since it was opened.
conn.ping(reconnect=True)

# step_index counts steps within each episode (the unique key is per episode),
# rather than reusing the global record index as before.
step_counters = {}  # episode_id -> next step_index
records_since_commit = 0

with conn.cursor() as cur:
    try:
        for record_idx, raw_record in enumerate(ds):
            ex = example_pb2.Example()
            ex.ParseFromString(raw_record.numpy())
            f = ex.features.feature

            # Episode metadata; records without it cannot be keyed and are skipped.
            episode_id = get_bytes(f, "episode_metadata/file_path")
            rec_path = get_bytes(f, "episode_metadata/recording_folderpath")
            if not episode_id or not rec_path:
                print(f"스킵 (episode_id 또는 rec_path 누락) at record {record_idx}")
                continue

            cur.execute(EPISODE_INSERT_SQL, (episode_id, shard_fp, rec_path))

            step_index = step_counters.get(episode_id, 0)
            step_counters[episode_id] = step_index + 1
            cur.execute(STEP_INSERT_SQL, _build_step_row(f, episode_id, step_index))

            records_since_commit += 1
            if records_since_commit >= COMMIT_EVERY:
                conn.commit()
                records_since_commit = 0

        conn.commit()
    except Exception:
        try:
            conn.rollback()
        except pymysql.MySQLError:
            pass  # the connection itself may be gone; surface the original error
        raise

print("✅ 모든 레코드 적재 완료")

OperationalError: (2013, 'Lost connection to MySQL server during query')