### 환경 설정

In [None]:
# [1] 런타임 체크
import torch, platform, sys, os, subprocess, textwrap, random, numpy as np
# Report the interpreter version and the visible CUDA device (if any).
print("Python:", sys.version)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
else:
    print("GPU:", None)

# Reproducibility: seed every RNG used by this notebook with the same value.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# [2] Libraries
!pip -q install pretty_midi miditoolkit music21 datasets --progress-bar off
# (Optional) visualization/logging: add wandb or tensorboard if you want them

In [None]:
### Music sample test
!apt-get -q install -y fluidsynth
!pip install midi2audio

In [None]:
# [3] Mount Google Drive & project folder
from google.colab import drive; drive.mount('/content/drive')

In [None]:
# Assumes Drive is already mounted (if not: from google.colab import drive; drive.mount('/content/drive'))

# Project root on Drive; the cells below search for metadata.json under this directory.
PROJ = "/content/drive/MyDrive/Deep_Learning_project/original_token"
print("PROJ:", PROJ)

In [None]:
import os, glob, json, pandas as pd

# Find the metadata file under PROJ (change the glob pattern if the file is named differently).
# The matches are sorted so that the chosen file does not depend on filesystem enumeration order.
candidates = sorted(glob.glob(os.path.join(PROJ, "**", "metadata.json"), recursive=True))
if not candidates:
    raise FileNotFoundError("metadata.json 을 찾지 못했습니다. 위치를 확인하세요.")
META_JSON = candidates[0]
print("META_JSON:", META_JSON)

# Load the JSON. The encoding is explicit so decoding does not depend on the platform default.
with open(META_JSON, "r", encoding="utf-8") as f:
    meta_raw = json.load(f)

# Assumed JSON layout: {id: {"metadata": {...}, "audio_scores": {...}}}.
# The contents of "metadata" become one DataFrame row; the JSON id becomes the index.
_ALWAYS_META_KEYS = ("file_path", "split", "composer", "music_period", "difficulty", "genre", "opus")
# These two columns are only copied when they live inside "metadata" (not at the top level).
# If they are stored elsewhere (e.g. item["audio_scores"][<key>] or item["split_ratio"]),
# read them from `item` here instead.
_OPTIONAL_META_KEYS = ("audio_score", "split_ratio")

data_dict = {}
# `item_id` (not `id`) so the builtin id() is not shadowed for the rest of the notebook.
for item_id, item in meta_raw.items():
    # Entries that are not dicts or carry no "metadata" section are skipped.
    if not isinstance(item, dict) or 'metadata' not in item:
        continue
    metadata = item['metadata']
    # Map the needed columns explicitly; missing keys become None.
    row = {key: metadata.get(key) for key in _ALWAYS_META_KEYS}
    for key in _OPTIONAL_META_KEYS:
        if key in metadata:
            row[key] = metadata.get(key)
    data_dict[item_id] = row

meta_df = pd.DataFrame.from_dict(data_dict, orient="index")


# Order the important columns for readability (edit as needed); absent columns are dropped from the list.
cols = [c for c in ["file_path","split","composer","music_period","difficulty","genre","audio_score","opus","split_ratio"] if c in meta_df.columns]
meta_df = meta_df[cols]
display(meta_df.head())

In [None]:
import os, pandas as pd
from pathlib import Path
from collections import defaultdict
import re
import random # Import random for shuffling

# NOTE: Make sure PROJ_Mozart is set correctly to the directory containing your MIDI files
# PROJ_Mozart = "/content/drive/MyDrive/Deep_Learning_project/original_token/mozart_midis" # Example path
# At cell (module) level locals() and globals() are the same namespace, so one check suffices.
if 'PROJ_Mozart' not in globals():
    # Fall back to a placeholder default - PLEASE UPDATE THIS TO YOUR ACTUAL MIDI DIRECTORY IF NEEDED
    PROJ_Mozart = "/content/drive/MyDrive/Deep_Learning_project/original_token/mozart_midis"
    print(f"PROJ_Mozart was not defined, using default: {PROJ_Mozart}")


# Recursively index every .mid/.midi file under PROJ_Mozart.
# The result is sorted because rglob() yields entries in filesystem order; without sorting,
# which of several same-named files ends up first in by_name[...] could change between runs.
_MIDI_SUFFIXES = {".mid", ".midi"}
all_midis = sorted(p for p in Path(PROJ_Mozart).rglob("*") if p.suffix.lower() in _MIDI_SUFFIXES)
print("Indexed MIDI files:", len(all_midis))

# Index by lower-cased file name -> list of full paths (several directories may hold the same name).
by_name = defaultdict(list)
for p in all_midis:
    by_name[p.name.lower()].append(str(p))

# print("Sample keys in by_name:", list(by_name.keys())[:10]) # Debugging: print sample keys from the indexed files

def find_midi_for_meta(meta_index: str):
    """
    Look up the MIDI file that belongs to one metadata index (a JSON key).

    The key is parsed as an integer, zero-padded to six digits and combined
    with the suffix ``_0`` to form the expected file name, e.g. ``'4'`` ->
    ``'000004_0.mid'``; ``.mid`` is tried before ``.midi``. Names are looked
    up lower-cased in the module-level ``by_name`` index.

    Returns the first indexed path for the matching name, or ``None`` when
    the index is not a string, is not numeric, or no file matches.
    """
    if not isinstance(meta_index, str):
        return None

    try:
        stem = f"{str(int(meta_index)).zfill(6)}_0"
    except ValueError:
        # Non-numeric keys cannot be mapped to a file name.
        return None

    for extension in (".mid", ".midi"):
        hits = by_name.get(f"{stem}{extension}".lower(), [])
        if hits:
            return hits[0]

    return None


# Run the matching: the DataFrame index (not file_path) is used to locate each MIDI file.
print("\n[EDA 실행 전 파일 목록 검토]")
# A private, fixed-seed RNG is used for the shuffle so that re-running this cell
# reproduces the same split regardless of how the global `random` state has advanced.
SPLIT_SEED = 42
if 'meta_df' not in globals():
    print("meta_df DataFrame이 존재하지 않습니다. 이전 셀을 실행해주세요.")
    # Define every name that later cells read, as empty containers, to avoid NameError.
    train_files, val_files, test_files = [], [], []
    train_df, val_df, test_df = (pd.DataFrame(columns=["full_path", "split"]) for _ in range(3))
else:
    # Resolve each metadata index to a file path (None when no file matches).
    meta_df["full_path"] = meta_df.index.map(find_midi_for_meta)

    # Remove rows where full_path is None before splitting
    matched_df = meta_df.dropna(subset=["full_path"]).copy()

    # Randomly split the matched files (60% train, 20% val, 20% test).
    # Paths are de-duplicated first (order preserved) so that one file can never
    # be placed in two different splits when several indices resolve to it.
    matched_files_list = list(dict.fromkeys(matched_df["full_path"].tolist()))
    random.Random(SPLIT_SEED).shuffle(matched_files_list)

    total_matched = len(matched_files_list)
    train_size = int(0.6 * total_matched)
    val_size = int(0.2 * total_matched)
    # The test split takes whatever remains after train and val.

    train_files = matched_files_list[:train_size]
    val_files = matched_files_list[train_size:train_size + val_size]
    test_files = matched_files_list[train_size + val_size:]

    # Overwrite the 'split' column with the new random split (the metadata's own split is discarded).
    split_of_path = {}
    for split_name, split_paths in (("train", train_files), ("val", val_files), ("test", test_files)):
        split_of_path.update(dict.fromkeys(split_paths, split_name))
    matched_df['split'] = matched_df['full_path'].map(split_of_path)

    # Per-split views of the matched metadata.
    train_df = matched_df[matched_df["split"] == "train"]
    val_df   = matched_df[matched_df["split"] == "val"]
    test_df  = matched_df[matched_df["split"] == "test"]


    # Result check
    matched_count = len(matched_df)  # metadata rows that were matched to a file and assigned a split
    total_count = len(meta_df)
    print(f"원본 메타데이터 총 개수: {total_count}")
    print(f"매칭 및 분할된 파일 수: {matched_count}")


    print("train:", len(train_files), " | val:", len(val_files), " | test:", len(test_files))
    display(train_df.head())

### EDA
- 이 데이터로 LSTM을 안정적으로 학습시킬 수 있는가?
- 토큰화 규칙(시간 분할, 벨로시티 bin, max_len)을 어떻게 정할 것인가?
- 학습 전 배제해야 할 샘플(너무 짧음/깨짐/이상치)은 있는가?

- TIME_SHIFT 분할: IOI 분포 기반 32 또는 64 결정
- VEL bin 개수: 벨로시티 분포 기반 8/16 중 택1
- max_len/TBPTT: 길이 P95 기반(예: 512)
- 폴리포니 처리: 동시 발음 분포에 맞춰 간단/확장 설계
- 배제 규칙: 짧은 곡/무음/깨짐 사례 기준치 확정
- 샘플링 제약: 반복률 높으면 반복 페널티/no-repeat n-gram 도입

In [None]:
import pretty_midi
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def analyze_midi(path):
    """
    Parse one MIDI file and collect per-file statistics for the EDA.

    Returns ``None`` (after printing the error) when the file cannot be
    parsed. Otherwise returns a dict with:
      - "file": the given path
      - "duration": ``pm.get_end_time()``
      - "events": total note count over all instruments
      - "density": events / duration (0 when duration is not positive)
      - "iois": array of differences between sorted note starts, computed
        per instrument and concatenated
      - "velocities": velocity of every note
      - "max_poly": largest number of simultaneously sounding notes
      - "poly_hist": active-note count after every note start/end boundary
    """
    try:
        pm = pretty_midi.PrettyMIDI(path)
    except Exception as e:
        print("Parse error:", path, e)
        return None

    instruments = pm.instruments

    # Length / density
    duration = pm.get_end_time()
    events = sum(len(inst.notes) for inst in instruments)
    density = events / duration if duration > 0 else 0

    # Rhythm: inter-onset intervals within each instrument
    ioi_values = []
    for inst in instruments:
        onsets = sorted(n.start for n in inst.notes)
        if len(onsets) > 1:
            ioi_values.extend(np.diff(onsets).tolist())
    iois = np.array(ioi_values)

    # Dynamics: velocity of every note
    velocities = [n.velocity for inst in instruments for n in inst.notes]

    # Polyphony: sweep over (time, +1/-1) boundaries. Sorting the tuples puts a
    # note end (-1) before a note start (+1) that falls on the same time.
    boundaries = sorted(
        edge
        for inst in instruments
        for n in inst.notes
        for edge in ((n.start, +1), (n.end, -1))
    )
    poly_hist = []
    active = 0
    for _, delta in boundaries:
        active += delta
        poly_hist.append(active)
    max_poly = max([0, *poly_hist])

    return {
        "file": path,
        "duration": duration,
        "events": events,
        "density": density,
        "iois": iois,
        "velocities": velocities,
        "max_poly": max_poly,
        "poly_hist": poly_hist,
    }

In [None]:
import pandas as pd
from tqdm import tqdm

all_files = train_files + val_files + test_files
results = []
ioi_all, vel_all, poly_all = [], [], []

# Scalar columns copied from each analysis result into the summary table.
_SUMMARY_KEYS = ("file", "duration", "events", "density", "max_poly")

# First pass on a subset only (the first 100 files).
for f in tqdm(all_files[:100]):
    r = analyze_midi(f)
    if not r:
        # Unparseable files return None and are skipped.
        continue
    results.append({key: r[key] for key in _SUMMARY_KEYS})
    # Pool the per-note distributions across files.
    ioi_all.extend(r["iois"].tolist())
    vel_all.extend(r["velocities"])
    poly_all.extend(r["poly_hist"])

eda_df = pd.DataFrame(results)
eda_df.head()

In [None]:
# 필요 라이브러리
import pretty_midi, numpy as np, math
from collections import Counter, defaultdict
from statistics import median

# ---- 공통 유틸 ----
def safe_load_midi(path):
    """
    Parse ``path`` with pretty_midi, returning ``None`` when parsing fails.

    Failures are reported through ``warnings.warn`` instead of being dropped
    silently, so a report that skipped files says which ones and why. The
    function still never raises for a bad file: callers treat ``None`` as
    "skip this file".
    """
    import warnings  # local import so this cell does not depend on another cell's imports

    try:
        return pretty_midi.PrettyMIDI(path)
    except Exception as e:
        # Deliberately broad: any parse failure means "skip", but it is made visible.
        warnings.warn(f"MIDI parse failed, skipping: {path} ({e})", RuntimeWarning, stacklevel=2)
        return None

def median_beat_period(pm: pretty_midi.PrettyMIDI):
    """
    Estimate the duration of one beat in seconds.

    Uses the median spacing of ``pm.get_beats()`` when at least two beats are
    available; otherwise the median of ``60 / tempo`` over
    ``pm.get_tempo_changes()``; otherwise 0.5 s (120 bpm).

    NOTE(review): callers treat this value as the quarter-note length.
    Whether ``get_beats()`` spacing equals a quarter note for every time
    signature depends on pretty_midi - confirm for compound meters.
    """
    beat_times = pm.get_beats()
    if len(beat_times) < 2:
        # Too few beats to measure spacing: fall back to the tempo map.
        _, tempi = pm.get_tempo_changes()
        if len(tempi) == 0:
            # Last resort: 120 bpm.
            return 0.5
        return float(np.median(60.0 / tempi))
    return float(np.median(np.diff(beat_times)))

def all_notes(pm, include_drums=False):
    """
    Collect the notes of every instrument into one list sorted by (start, end).

    Instruments flagged ``is_drum`` are left out unless ``include_drums`` is
    true. Notes with identical (start, end) keep their collection order.
    """
    selected = [
        note
        for inst in pm.instruments
        if include_drums or not inst.is_drum
        for note in inst.notes
    ]
    return sorted(selected, key=lambda n: (n.start, n.end))

# ---- (1) TIME_SHIFT 분해능 평가: IOI 스냅 오차 ----
def ioi_snap_report(files, limit=None):
    """
    Evaluate how well inter-onset intervals (IOIs) snap to a 32- or 64-per-beat grid.

    For every parseable file, the differences between consecutive note starts
    (all non-drum notes merged) are expressed in grid steps and the distance
    to the nearest whole step is recorded (range 0..0.5). The 95th percentile
    of those distances is reported per grid.

    Recommendation rule: a grid is acceptable when its P95 is <= 0.25 step.
    The coarser 32 grid is preferred whenever it is acceptable; otherwise 64
    is recommended (it is the finer of the two even when neither passes).
    ``choice`` is ``None`` when no IOIs were collected.

    Args:
        files: iterable of MIDI paths.
        limit: stop after this many entries of ``files`` (falsy = no limit).

    Returns:
        ``{"p95_32": float, "p95_64": float, "choice": 32 | 64 | None}``;
        the P95 values are NaN when no IOIs were collected.
    """
    divisions = (32, 64)
    ok_threshold = 0.25  # max acceptable P95 error, in grid steps
    errs = {div: [] for div in divisions}

    for i, path in enumerate(files):
        if limit and i >= limit:
            break
        pm = safe_load_midi(path)
        if pm is None:
            continue
        bp = median_beat_period(pm)  # seconds per beat
        starts = [n.start for n in all_notes(pm)]
        if len(starts) < 2:
            continue
        iois = np.diff(sorted(starts))
        for div in divisions:
            step = bp / div  # seconds per grid step
            x = iois / step  # IOIs in step units
            frac_err = np.abs(x - np.round(x))  # distance to the nearest step
            errs[div].extend(frac_err.tolist())

    rep = {
        div: (float(np.percentile(errs[div], 95)) if errs[div] else np.nan)
        for div in divisions
    }

    if all(math.isnan(rep[div]) for div in divisions):
        choice = None
    elif rep[32] <= ok_threshold:
        # 32 passes on its own merit; it is chosen whether or not 64 also passes.
        choice = 32
    else:
        # Either only 64 passes, or neither does and 64 is the finer grid.
        choice = 64

    print(f"[IOI 스냅 오차] P95(스텝 단위) → 32분할:{rep[32]:.3f}, 64분할:{rep[64]:.3f}  | 추천:{choice}")
    return {"p95_32": rep[32], "p95_64": rep[64], "choice": choice}

# ---- (2) Velocity 분포 요약: IQR 기반 bin 추천 ----
def velocity_report(files, limit=None):
    """
    Summarise the note-velocity distribution and recommend a velocity bin count.

    Velocities of all non-drum notes of every parseable file are pooled.
    Recommendation rule: interquartile range < 20 -> 8 bins, otherwise 16.

    Args:
        files: iterable of MIDI paths.
        limit: stop after this many entries of ``files`` (falsy = no limit).

    Returns:
        ``{"iqr", "q25", "q75", "choice"}``. All four keys are always present;
        every value is ``None`` when no velocities were collected.
    """
    vels = []
    for i, path in enumerate(files):
        if limit and i >= limit:
            break
        pm = safe_load_midi(path)
        if pm is None:
            continue
        vels.extend(n.velocity for n in all_notes(pm))

    if not vels:
        print("[Velocity] 수집된 벨로시티가 없습니다.")
        # Same keys as the normal result so callers can index without checking.
        return {"iqr": None, "q25": None, "q75": None, "choice": None}

    v = np.array(vels, dtype=float)
    q25, q75 = np.percentile(v, [25, 75])
    iqr = float(q75 - q25)
    choice = 8 if iqr < 20 else 16
    print(f"[Velocity] IQR={iqr:.1f}  (Q25={q25:.1f}, Q75={q75:.1f})  | 추천 bin={choice}")
    return {"iqr": iqr, "q25": float(q25), "q75": float(q75), "choice": choice}

# ---- (3) 반복률: 3–5그램 상위 점유율 & no-repeat 권고 ----
def ngram_repetition_report(files, n_vals=(3,4,5), limit=None):
    """
    Measure how dominant the most frequent n-gram is, for each n in ``n_vals``.

    Each file becomes a sequence of ``(pitch, duration_bin)`` pairs, where the
    duration is rounded to steps of one eighth of a beat and clamped to 31.
    n-grams are counted globally across files; the share of the single most
    common n-gram is reported. A share of 2.5% or more yields the suggestion
    ``no_repeat_ngram_size=<n>``, otherwise ``optional``.

    Returns a dict ``{n: {"top_ratio": float | None, "suggest": str | None}}``;
    both values are ``None`` when no n-gram of that size was seen.
    """
    counts_by_n = {n: Counter() for n in n_vals}
    windows_by_n = {n: 0 for n in n_vals}

    for idx, midi_path in enumerate(files):
        if limit and idx >= limit:
            break
        pm = safe_load_midi(midi_path)
        if pm is None:
            continue
        # Coarse view of phrasing: one beat split into 8 steps.
        step = median_beat_period(pm) / 8.0
        seq = [
            (note.pitch, min(int(round(max((note.end - note.start) / step, 0))), 31))
            for note in all_notes(pm)
        ]
        for n in n_vals:
            n_windows = len(seq) - n + 1
            if n_windows <= 0:
                # Sequence shorter than n (or empty): nothing to count.
                continue
            counts_by_n[n].update(tuple(seq[j:j + n]) for j in range(n_windows))
            windows_by_n[n] += n_windows

    report = {}
    for n in n_vals:
        n_total = windows_by_n[n]
        counter = counts_by_n[n]
        if n_total == 0 or len(counter) == 0:
            report[n] = {"top_ratio": None, "suggest": None}
            print(f"[n={n}-gram] 데이터 부족")
            continue
        _, top_ct = counter.most_common(1)[0]
        top_ratio = top_ct / n_total
        if top_ratio >= 0.025:
            suggest = f"no_repeat_ngram_size={n}"
        else:
            suggest = "optional"
        print(f"[n={n}-gram] 상위 점유율={top_ratio*100:.2f}%  | 권고: {suggest}")
        report[n] = {"top_ratio": top_ratio, "suggest": suggest}
    return report

# ---- 실행: 파일 목록을 넣어 한 번에 리포트 ----
def run_added_eda(train_files, val_files, test_files, limit_per_split=None):
    """
    Run the IOI, velocity and n-gram reports over the three splits and print a summary.

    Args:
        train_files, val_files, test_files: lists of MIDI paths.
        limit_per_split: when truthy, only the first this-many files of each
            split are used.

    Returns:
        ``{"ioi": ..., "vel": ..., "rep": ...}`` holding the three reports.
    """
    files = []
    for split_files in (train_files, val_files, test_files):
        if limit_per_split:
            files.extend(split_files[:limit_per_split])
        else:
            files.extend(split_files)

    print("파일 수:", len(files))
    out = {
        "ioi": ioi_snap_report(files),
        "vel": velocity_report(files),
        "rep": ngram_repetition_report(files),
    }

    print("\n요약:")
    print(" - TIME_SHIFT 추천:", out['ioi']['choice'])
    print(" - Velocity bin 추천:", out['vel']['choice'])

    rep = out["rep"]
    ratios = [entry["top_ratio"] for entry in rep.values() if entry["top_ratio"] is not None]
    best_rep = max(ratios) if ratios else None
    if best_rep is None or best_rep < 0.025:
        print(" - 반복 억제: optional (강한 반복 패턴 증거 약함)")
    else:
        # Report the n whose top n-gram has the largest share.
        pick_n = max(rep, key=lambda k: (rep[k]["top_ratio"] or -1))
        picked = rep[pick_n]
        print(f" - 반복 억제 권고: {picked['suggest']} (상위 점유율={picked['top_ratio']*100:.2f}%)")
    return out

# 사용 예:
# result = run_added_eda(train_files, val_files, test_files, limit_per_split=50)

In [None]:
# Run the added EDA reports on at most 50 files per split.
result = run_added_eda(train_files, val_files, test_files, limit_per_split=50)

### 토큰화
- 이벤트 토큰 방식

In [None]:
# !pip -q install pretty_midi

import math, os, json, hashlib
import numpy as np
import pretty_midi
from collections import defaultdict

# =======================
# 0) Settings
# =======================
TS_DIV = 64          # one beat is divided into 64 steps
VEL_BINS = 16        # velocity 0~127 → 16 bins
TS_MAX = 16          # most steps a single TS token expresses (longer gaps are split into several tokens)
PROGRAM = 0          # piano (Acoustic Grand)

# Fixed special-token ids
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2

# Start offsets of the VEL/TS/NOTE id ranges (laid out back to back so they cannot collide).
# [PAD, BOS, EOS] reserve three ids → regular tokens start at id=3.
VEL_BASE = 3                     # VEL_1 .. VEL_16 → 16 ids
TS_BASE  = VEL_BASE + VEL_BINS   # TS_1 .. TS_16   → TS_MAX ids
NON_BASE = TS_BASE  + TS_MAX     # NOTE_ON_0 .. NOTE_ON_127
NOFF_BASE= NON_BASE + 128        # NOTE_OFF_0 .. NOTE_OFF_127
VOCAB_SIZE = NOFF_BASE + 128

def vel_to_bin(v: int, bins: int = VEL_BINS):
    """
    Quantize a MIDI velocity to a 1-based bin index in ``1..bins``.

    The velocity is first clamped to ``1..127``. Bin ``b`` covers the
    velocities ``v`` with ``(b - 1) * 128 / bins < v <= b * 128 / bins``.
    """
    v = max(1, min(127, int(v)))
    step = 128 / bins
    b = int(math.ceil(v / step))
    return max(1, min(bins, b))

def bin_to_vel(b: int, bins: int = VEL_BINS):
    """
    Return a representative velocity (the middle value) for bin ``b``.

    The bin bounds mirror ``vel_to_bin`` exactly - bin ``b`` holds the integer
    velocities ``floor((b-1)*128/bins) + 1 .. min(127, floor(b*128/bins))`` -
    so the returned value always quantizes back to the same bin, also for
    fine binnings such as 64 or 128 bins, and is always in ``1..127``.
    ``b`` is clamped to ``1..bins``.
    """
    b = max(1, min(bins, int(b)))
    hi = min(127, (b * 128) // bins)
    # The last bin of a 128-bin layout is empty (velocity 128 does not exist); collapse it onto 127.
    lo = min(((b - 1) * 128) // bins + 1, hi)
    return (lo + hi) // 2

def beat_period_seconds(pm: pretty_midi.PrettyMIDI):
    """
    Length of one beat in seconds.

    Median spacing of ``pm.get_beats()`` when there are at least two beats,
    otherwise the median of ``60 / tempo`` from ``pm.get_tempo_changes()``,
    otherwise 0.5 (120 bpm).

    NOTE(review): the tokenizer treats this as the quarter-note length;
    confirm that ``get_beats()`` spacing matches that for the meters in use.
    """
    beats = pm.get_beats()
    if len(beats) >= 2:
        intervals = np.diff(beats)
        return float(np.median(intervals))
    tempi = pm.get_tempo_changes()[1]
    # Fallback of 0.5 s corresponds to 120 bpm.
    return float(np.median(60.0 / tempi)) if len(tempi) else 0.5

def _id_vel(b):
    """Token id of velocity bin ``b`` (1-based): VEL_BASE .. VEL_BASE + VEL_BINS - 1."""
    return VEL_BASE + (b - 1)


def _id_ts(k):
    """Token id of a time shift of ``k`` steps (1..TS_MAX)."""
    return TS_BASE + (k - 1)


def _id_non(pitch):
    """Token id of NOTE_ON for ``pitch`` (0..127)."""
    return NON_BASE + pitch


def _id_noff(pitch):
    """Token id of NOTE_OFF for ``pitch`` (0..127)."""
    return NOFF_BASE + pitch


def _is_vel(tok):
    """True when ``tok`` lies in the velocity id range."""
    return VEL_BASE <= tok < VEL_BASE + VEL_BINS


def _is_ts(tok):
    """True when ``tok`` lies in the time-shift id range."""
    return TS_BASE <= tok < TS_BASE + TS_MAX


def _is_non(tok):
    """True when ``tok`` lies in the NOTE_ON id range."""
    return NON_BASE <= tok < NON_BASE + 128


def _is_noff(tok):
    """True when ``tok`` lies in the NOTE_OFF id range."""
    return NOFF_BASE <= tok < NOFF_BASE + 128

# =======================
# 1) 토큰화 (MIDI → ids)
# =======================
def tokenize_midi(path, ts_div=TS_DIV, vel_bins=VEL_BINS, ts_max=TS_MAX):
    """
    Convert a MIDI file into a list of event-token ids.

    - One beat is divided into ``ts_div`` grid steps; note starts and ends are
      rounded to that grid, and every note keeps a length of at least one step.
    - Velocities are quantized into ``vel_bins`` bins.
    - Drum instruments are skipped; notes of all other instruments are merged.
    - Output layout: BOS, then for each occupied grid step the TS tokens
      needed to reach it (gaps longer than ``ts_max`` are split, e.g.
      17 -> 16 + 1), then ``VEL, NOTE_ON`` pairs in ascending pitch, then the
      NOTE_OFF tokens of that step in ascending pitch, and finally EOS.
      Within one step NOTE_ON tokens therefore precede NOTE_OFF tokens.

    Returns:
        ``(ids, aux)`` where ``ids`` is a list of ints and ``aux`` carries
        ``step_sec``, ``program``, ``ts_div``, ``vel_bins`` and ``ts_max``.
    """
    pm = pretty_midi.PrettyMIDI(path)
    step_sec = beat_period_seconds(pm) / ts_div

    # Bucket events by grid step: onsets carry (pitch, velocity bin), offsets carry the pitch.
    onsets = defaultdict(list)
    offsets = defaultdict(list)
    for inst in pm.instruments:
        if inst.is_drum:
            continue
        for note in inst.notes:
            grid_start = int(round(note.start / step_sec))
            # Keep at least one step between a note's start and its end.
            grid_end = max(grid_start + 1, int(round(note.end / step_sec)))
            onsets[grid_start].append((note.pitch, vel_to_bin(note.velocity, vel_bins)))
            offsets[grid_end].append(note.pitch)

    tokens = [BOS_ID]
    cur_t = 0
    for t in sorted(onsets.keys() | offsets.keys()):
        if t < cur_t:
            continue

        # Advance time: as many full-size TS tokens as fit, then the remainder.
        full_shifts, remainder = divmod(t - cur_t, ts_max)
        tokens.extend([_id_ts(ts_max)] * full_shifts)
        if remainder:
            tokens.append(_id_ts(remainder))
        cur_t = t

        # Simultaneous onsets in ascending pitch, each as "VEL → NOTE_ON".
        for pitch, vb in sorted(onsets.get(t, ()), key=lambda pair: pair[0]):
            tokens.append(_id_vel(vb))
            tokens.append(_id_non(pitch))

        # Offsets of this step, ascending pitch.
        tokens.extend(_id_noff(pitch) for pitch in sorted(offsets.get(t, ())))

    tokens.append(EOS_ID)
    aux = {"step_sec": step_sec, "program": PROGRAM, "ts_div": ts_div, "vel_bins": vel_bins, "ts_max": ts_max}
    return tokens, aux

# =======================
# 2) 디토큰화 (ids → MIDI)
# =======================
def detokenize_to_pretty_midi(tokens, aux):
    """
    Rebuild a ``pretty_midi.PrettyMIDI`` object from event-token ids.

    - TS_k advances the grid position by k steps.
    - VEL_b sets the current velocity bin; the following NOTE_ON_p opens a
      note at the current grid position with that velocity. The velocity is
      stored with the open note, so later VEL tokens cannot change it.
    - NOTE_OFF_p closes the open note of pitch p (minimum length 1 ms).
      A NOTE_OFF for a pitch that is not open is ignored.
    - NOTE_ON_p for a pitch that is already open (a re-strike) closes the
      sounding note at the re-strike time (but at least one step after its
      start) and opens the new one. ``tokenize_midi`` emits NOTE_ON before
      NOTE_OFF within a step, so a NOTE_OFF_p that follows such a re-strike in
      the same step belongs to the note just closed and is skipped.
    - Notes still open at EOS / end of input are closed at the last grid
      position, with a length of at least one step.
    - BOS and any unknown token are ignored; decoding stops at EOS.

    Args:
        tokens: iterable of token ids.
        aux: dict with optional ``step_sec`` (seconds per grid step, default
            ``0.5 / TS_DIV`` i.e. 120 bpm) and ``program``.
    """
    step_sec = float(aux.get("step_sec", 0.5/TS_DIV))
    program  = int(aux.get("program", PROGRAM))

    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=program, is_drum=False)
    pm.instruments.append(inst)

    cur_grid = 0
    current_vel_bin = vel_to_bin(64, VEL_BINS)  # initial value: medium loudness
    open_notes = {}        # pitch → (start grid index, velocity fixed at NOTE_ON)
    restruck_here = set()  # pitches re-struck at the current grid position

    def grid_to_time(g): return g * step_sec

    def add_note(pitch, start_grid, velocity, end):
        inst.notes.append(pretty_midi.Note(
            velocity=velocity, pitch=pitch, start=grid_to_time(start_grid), end=end
        ))

    for tok in tokens:
        if tok == BOS_ID:
            continue
        if tok == EOS_ID:
            break

        if _is_ts(tok):
            cur_grid += (tok - TS_BASE) + 1
            # Re-strike bookkeeping only applies within a single grid step.
            restruck_here.clear()

        elif _is_vel(tok):
            current_vel_bin = (tok - VEL_BASE) + 1

        elif _is_non(tok):
            pitch = tok - NON_BASE
            if pitch in open_notes:
                # Re-strike: end the sounding note where the new one begins.
                old_grid, old_vel = open_notes[pitch]
                end = max(grid_to_time(old_grid) + step_sec, grid_to_time(cur_grid))
                add_note(pitch, old_grid, old_vel, end)
                restruck_here.add(pitch)
            open_notes[pitch] = (cur_grid, bin_to_vel(current_vel_bin, VEL_BINS))

        elif _is_noff(tok):
            pitch = tok - NOFF_BASE
            if pitch in restruck_here:
                # This NOTE_OFF is the end of the note the re-strike already closed.
                restruck_here.discard(pitch)
            elif pitch in open_notes:
                start_grid, vel = open_notes.pop(pitch)
                # Guarantee a minimum audible length.
                end = max(grid_to_time(start_grid) + 1e-3, grid_to_time(cur_grid))
                add_note(pitch, start_grid, vel, end)
            # A NOTE_OFF for a pitch that is not open is ignored.

        # Unknown tokens (e.g. PAD) are ignored.

    # Close whatever is still sounding at the last grid position.
    end_time = grid_to_time(cur_grid)
    for pitch, (start_grid, vel) in open_notes.items():
        add_note(pitch, start_grid, vel, max(grid_to_time(start_grid) + step_sec, end_time))
    return pm

def detokenize_to_midi_file(tokens, aux, out_path):
    """Decode ``tokens`` with ``detokenize_to_pretty_midi`` and write the result to ``out_path``.

    Returns ``out_path``.
    """
    detokenize_to_pretty_midi(tokens, aux).write(out_path)
    return out_path

# =======================
# 3) 간단 라운드트립 테스트/리포트
# =======================
def tokenize_and_reconstruct(path, out_midi_path=None):
    """
    Round-trip check: tokenize ``path``, decode the tokens, and compare
    duration and non-drum note count of the original and the reconstruction.

    When ``out_midi_path`` is given, the reconstruction is written there and
    the path is added to the report under ``"saved"``.

    Returns:
        ``(tokens, aux, report)``.
    """
    def non_drum_note_count(midi):
        return sum(len(inst.notes) for inst in midi.instruments if not inst.is_drum)

    toks, aux = tokenize_midi(path)
    pm_orig = pretty_midi.PrettyMIDI(path)
    pm_recon = detokenize_to_pretty_midi(toks, aux)

    dur_o, dur_r = pm_orig.get_end_time(), pm_recon.get_end_time()
    cnt_o, cnt_r = non_drum_note_count(pm_orig), non_drum_note_count(pm_recon)

    # Relative errors are in percent; the denominators are floored to avoid division by zero.
    dur_err_pct = (abs(dur_o - dur_r) / max(1e-6, dur_o)) * 100.0
    evt_err_pct = (abs(cnt_o - cnt_r) / max(1, cnt_o)) * 100.0

    report = {
        "tokens": len(toks),
        "orig_duration": dur_o,
        "recon_duration": dur_r,
        "dur_rel_err_%": dur_err_pct,
        "orig_events": cnt_o,
        "recon_events": cnt_r,
        "evt_rel_err_%": evt_err_pct,
    }

    if out_midi_path:
        pm_recon.write(out_midi_path)
        report["saved"] = out_midi_path
    return toks, aux, report

# Echo the tokenizer configuration defined at the top of this cell.
print("VOCAB_SIZE:", VOCAB_SIZE, "| TS_DIV:", TS_DIV, "| VEL_BINS:", VEL_BINS)

In [None]:
# Example: take train_files[0] and run a tokenize → detokenize round-trip check.
# out_midi_path=None means the reconstruction is not written to disk.
sample_path = train_files[0]
tokens, aux, rep = tokenize_and_reconstruct(sample_path, out_midi_path=None)
print(rep)

### Dataset/DataLoader

In [None]:
# 전제: 앞서 정의한 tokenize_midi()가 이미 세션에 존재한다고 가정합니다.
# 필요시: from your_module import tokenize_midi

import os, csv, json, math, hashlib, random, time
from pathlib import Path
from typing import List, Dict, Any, Tuple
import numpy as np
import pretty_midi
import torch
from torch.utils.data import Dataset, DataLoader

# =========================
# Settings (edit as needed)
# =========================
PAD_ID = 0
BOS_ID = 1
EOS_ID = 2

# Tokenization-rule version (part of the cache key): change it whenever the token rules change.
TOKEN_RULE_VERSION = "evt_ts64_vel16_tsm16_v1"

# Filter rules (recommended defaults)
MIN_EVENTS = 200          # exclude pieces with too few notes
MAX_DENSITY = 10.0        # upper bound on events/sec
MIN_DURATION_SEC = 30.0   # exclude pieces shorter than 30 seconds

# Cache/log directories (created below if they do not exist)
PROJ = "/content/drive/MyDrive/Deep_Learning_project/original_token"
CACHE_DIR = f"{PROJ}/data/processed"
LOG_DIR   = f"{PROJ}/logs"
REPORT_DIR= f"{PROJ}/reports"
os.makedirs(CACHE_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(REPORT_DIR, exist_ok=True)

# CSV that records every file rejected by the dataset filters, with the reason.
FILTER_REPORT_CSV = f"{LOG_DIR}/filter_report.csv"

# =========================
# 유틸
# =========================
def sha1_text(s: str) -> str:
    """Return the hexadecimal SHA-1 digest of ``s`` encoded as UTF-8."""
    digest = hashlib.sha1()
    digest.update(s.encode("utf-8"))
    return digest.hexdigest()

def safe_midi_stats(path: str) -> Dict[str, Any]:
    """
    Quick statistics used by the dataset filters.

    Returns ``{"ok": True, "duration", "events", "density"}`` on success,
    where ``events`` counts the notes of non-drum instruments and ``density``
    is events per second (infinity when the duration is not positive).
    Returns ``{"ok": False, "error": <message>}`` when anything fails.
    """
    try:
        pm = pretty_midi.PrettyMIDI(path)
        duration = pm.get_end_time()
        events = 0
        for inst in pm.instruments:
            if not inst.is_drum:
                events += len(inst.notes)
        if duration > 0:
            density = events / max(1e-6, duration)
        else:
            density = float("inf")
    except Exception as e:
        return {"ok": False, "error": str(e)}
    return {"ok": True, "duration": duration, "events": events, "density": density}

def cache_paths_for(midipath: str) -> Tuple[str, str]:
    """
    Return the cache file paths ``(tokens .npy, aux .json)`` for a MIDI file.

    The cache key is the SHA-1 of
    ``TOKEN_RULE_VERSION | resolved absolute path | modification time (ns)``,
    so a changed rule version, a moved file or an edited file each get a new
    cache entry. Both files live in ``CACHE_DIR``.
    """
    src = Path(midipath)
    mtime_ns = src.stat().st_mtime_ns
    key = sha1_text("|".join((TOKEN_RULE_VERSION, str(src.resolve()), str(mtime_ns))))
    stem = os.path.join(CACHE_DIR, key)
    return f"{stem}.npy", f"{stem}.json"

def load_or_tokenize(midipath: str):
    """
    Load the cached tokens for ``midipath``, tokenizing and caching on a miss.

    A cache entry counts as present only when both the ``.npy`` and the
    ``.json`` file exist. An entry that exists but cannot be read (for
    example a file truncated by an interrupted run) is treated as a miss and
    rebuilt instead of failing on every later call. New entries are written
    to temporary files and moved into place with ``os.replace`` so that a
    partially written file is never visible under the final name; the
    ``.json`` file is moved last.

    Returns:
        ``(tokens, aux, from_cache)``: ``tokens`` is a NumPy array (stored as
        int32 when freshly tokenized), ``aux`` the tokenizer's side-info
        dict, and ``from_cache`` tells whether the cache was used.
    """
    npy_path, js_path = cache_paths_for(midipath)
    if os.path.exists(npy_path) and os.path.exists(js_path):
        try:
            toks = np.load(npy_path)
            with open(js_path, "r") as f:
                aux = json.load(f)
            return toks, aux, True
        except (OSError, ValueError, EOFError):
            # Unreadable/corrupt cache entry: fall through and regenerate it.
            pass

    # Cache miss → tokenize.
    tokens, aux = tokenize_midi(midipath)
    tokens = np.asarray(tokens, dtype=np.int32)

    # Write under temporary names, then rename into place.
    tmp_suffix = f".tmp{os.getpid()}"
    npy_tmp = npy_path + tmp_suffix
    js_tmp = js_path + tmp_suffix
    with open(npy_tmp, "wb") as f:
        # A file object is passed so np.save does not append ".npy" to the temporary name.
        np.save(f, tokens)
    os.replace(npy_tmp, npy_path)
    with open(js_tmp, "w") as f:
        json.dump(aux, f)
    os.replace(js_tmp, js_path)

    return tokens, aux, False

def append_filter_report(rows: List[Dict[str, Any]]):
    """
    Append rejected files and their rejection reasons to ``FILTER_REPORT_CSV``.

    The header row is written only when the file does not exist yet.
    """
    columns = ["path", "reason", "duration", "events", "density", "error"]
    is_new_file = not os.path.exists(FILTER_REPORT_CSV)
    with open(FILTER_REPORT_CSV, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=columns)
        if is_new_file:
            writer.writeheader()
        writer.writerows(rows)

# =========================
# Dataset
# =========================
class MidiTokenDataset(Dataset):
    """
    File-level dataset of tokenized MIDI pieces.

    One item corresponds to one file; each access returns a random contiguous
    window of at most ``max_len + 1`` tokens, split into input ``x`` (all but
    the last token) and target ``y`` (all but the first). Padding is left to
    the collate function.
    """

    def __init__(self,
                 paths: List[str],
                 max_len: int = 512,
                 pad_id: int = PAD_ID,
                 apply_filters: bool = True,
                 seed: int = 42):
        """
        paths: list of MIDI file paths
        max_len: truncated-BPTT window length
        pad_id: padding id (stored on the instance)
        apply_filters: when true, drop files that fail to parse or violate
            MIN_EVENTS / MIN_DURATION_SEC / MAX_DENSITY and log them to the
            filter report
        seed: seed of the private RNG that picks window offsets

        Raises RuntimeError when no file remains after filtering.
        """
        super().__init__()
        self.paths_all = list(paths)
        self.max_len = int(max_len)
        self.pad_id = int(pad_id)
        self.rng = random.Random(seed)

        # 1) Optional filtering
        if apply_filters:
            self.paths = []
            filtered_rows = []
            for p in self.paths_all:
                rejection = self._rejection_row(p)
                if rejection is None:
                    self.paths.append(p)
                else:
                    filtered_rows.append(rejection)
        else:
            self.paths = self.paths_all
            filtered_rows = []

        # Log the rejected files
        if filtered_rows:
            append_filter_report(filtered_rows)

        if not self.paths:
            raise RuntimeError("유효한 학습 샘플이 없습니다. 필터 기준을 조정하세요.")

        # 2) Token length of every file (this also creates any missing cache entries)
        t0 = time.time()
        self.lengths = [int(len(load_or_tokenize(p)[0])) for p in self.paths]
        t1 = time.time()
        print(f"[MidiTokenDataset] 캐시/토큰 길이 준비 완료: {len(self.paths)}개, {t1-t0:.1f}s")

    @staticmethod
    def _rejection_row(path):
        """Return the filter-report row for ``path``, or None when the file passes every filter."""
        st = safe_midi_stats(path)
        if not st["ok"]:
            return {"path": path, "reason": "parse_error",
                    "duration": None, "events": None, "density": None,
                    "error": st.get("error")}

        # Checks are applied in this order; the first violated rule is the reported reason.
        if st["events"] < MIN_EVENTS:
            reason = "too_few_events"
        elif st["duration"] < MIN_DURATION_SEC:
            reason = "too_short_duration"
        elif st["density"] > MAX_DENSITY:
            reason = "too_high_density"
        else:
            return None
        return {"path": path, "reason": reason,
                "duration": st["duration"], "events": st["events"],
                "density": st["density"], "error": None}

    def __len__(self):
        # Items are files, not slices; a slice-level view would need an IterableDataset design.
        # __getitem__ draws a random slice from the selected file instead.
        return len(self.paths)

    def __getitem__(self, idx):
        toks, _aux, _from_cache = load_or_tokenize(self.paths[idx])
        n_tokens = len(toks)

        # Training uses x = t[:-1], y = t[1:], which needs at least two tokens.
        if n_tokens < 2:
            # Near-empty piece: return a one-step dummy sample (normally removed by the filters).
            return (torch.from_numpy(np.array([BOS_ID], dtype=np.int64)),
                    torch.from_numpy(np.array([EOS_ID], dtype=np.int64)),
                    torch.from_numpy(np.array([1], dtype=np.int64)))

        # Truncated BPTT: pick a random contiguous window of max_len + 1 tokens;
        # pieces that are not longer than the window are used in full.
        window = self.max_len + 1
        if n_tokens > window:
            start = self.rng.randint(0, n_tokens - window)
            toks = toks[start:start + window]

        x = torch.from_numpy(toks[:-1].astype(np.int64))
        y = torch.from_numpy(toks[1:].astype(np.int64))

        # Sequences are still unpadded here; the third element is a dummy mask flag.
        return x, y, torch.tensor(1, dtype=torch.int64)

In [None]:
def collate_pad(batch, pad_id: int = PAD_ID):
    """
    Collate ``(x, y, _)`` samples into padded batch tensors.

    - Dynamic padding: every sequence is right-padded with ``pad_id`` to the
      longest sequence in the batch.
    - The mask is built from the true sequence lengths (True = real token),
      not by comparing values with ``pad_id``, so it stays correct even when
      ``pad_id`` is also a token id that occurs in the data.

    The third element of each sample is ignored.

    Returns:
        ``xpad``, ``ypad`` (long, shape [B, T]) and ``mask`` (bool, shape [B, T]).
    """
    xs, ys, _ = zip(*batch)
    lens = torch.tensor([len(x) for x in xs], dtype=torch.long)
    B = len(xs)
    T = int(lens.max())

    xpad = torch.full((B, T), pad_id, dtype=torch.long)
    ypad = torch.full((B, T), pad_id, dtype=torch.long)

    for i, (x, y) in enumerate(zip(xs, ys)):
        t = len(x)
        xpad[i, :t] = x
        ypad[i, :t] = y

    # Position j of row i is real data exactly when j < len(xs[i]).
    mask = torch.arange(T).unsqueeze(0) < lens.unsqueeze(1)
    return xpad, ypad, mask

In [None]:
from torch.utils.data import DataLoader

# Uses the file lists built earlier in the notebook:
# train_files, val_files, test_files

max_len = 512
batch_size = 32

train_ds = MidiTokenDataset(train_files, max_len=max_len, apply_filters=True, seed=42)
val_ds   = MidiTokenDataset(val_files,   max_len=max_len, apply_filters=True, seed=43)

# collate_pad already defaults pad_id to PAD_ID, so it is passed directly instead
# of being wrapped in a lambda. A module-level function can be pickled, which is
# required when DataLoader workers are started with the "spawn" method; a lambda
# cannot.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=2, pin_memory=True, collate_fn=collate_pad)

val_dl   = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                      num_workers=2, pin_memory=True, collate_fn=collate_pad)

# Pull one batch as a smoke test and display its shapes/dtype as the cell output.
xb, yb, mb = next(iter(train_dl))
xb.shape, yb.shape, mb.shape, xb.dtype

In [None]:
import os, json, numpy as np, pandas as pd, torch, time
from collections import Counter

# 0) Environment summary
print("CUDA:", torch.cuda.is_available(), "| device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")

# 1) Check the split DataFrames: row count and number of non-null full_path values
print("\n[split CSV 확인]")
for name, df in [("train", train_df), ("val", val_df), ("test", test_df)]:
    print(f"{name:<5} rows={len(df):4d},  has_full_path={df['full_path'].notna().sum():4d}")

# 2) Cache directory summary: count .npy and .json files
print("\n[캐시 파일 요약]")
cache_dir = CACHE_DIR  # uses the CACHE_DIR configured earlier in the notebook
n_npy = len([f for f in os.listdir(cache_dir) if f.endswith(".npy")])
n_js  = len([f for f in os.listdir(cache_dir) if f.endswith(".json")])
print("CACHE_DIR:", cache_dir)
print("npy:", n_npy, "json:", n_js)

# 3) Token-length statistics over a sample of at most 200 files
print("\n[토큰 길이 통계 (샘플)]")
sample_files = (train_files[:100] + val_files[:50] + test_files[:50])[:200]
lens = []
t0=time.time()
for p in sample_files:
    toks, aux, from_cache = load_or_tokenize(p)
    lens.append(len(toks))
print(f"샘플 {len(lens)}개, 로드 {time.time()-t0:.1f}s  |  P50={np.percentile(lens,50):.0f}, P95={np.percentile(lens,95):.0f}, MAX={np.max(lens)}")

# 4) DataLoader batch integrity: matching shapes, long dtype, PAD fraction
print("\n[DataLoader 배치 무결성]")
xb, yb, mb = next(iter(train_dl))
print("xb/yb/mb shapes:", xb.shape, yb.shape, mb.shape, "| dtype:", xb.dtype)
assert xb.shape == yb.shape == mb.shape, "배치 텐서 shape 불일치"
assert xb.dtype == torch.long, "토큰 dtype은 torch.long이어야 합니다"
pad_id = PAD_ID
pad_frac = (xb==pad_id).float().mean().item()
print(f"PAD 비율(배치 평균): {pad_frac*100:.1f}%")

# 5) Quick round-trip check (tokenize <-> reconstruct) on up to 3 random training files
print("\n[라운드트립 테스트 3곡]")
from random import sample
for p in sample(train_files, k=min(3,len(train_files))):
    _, aux, rep = tokenize_and_reconstruct(p, out_midi_path=None)
    print(os.path.basename(p), "| dur_err%={:.2f}, evt_err%={:.2f}, tokens={}".format(
        rep["dur_rel_err_%"], rep["evt_rel_err_%"], rep["tokens"]))

# 6) Dry run of the loss computation (no model; only checks the ignore_index path)
print("\n[ignore_index 논리 점검]")
ce = torch.nn.CrossEntropyLoss(ignore_index=pad_id, reduction="mean")
# vocab_size comes from the notebook-level VOCAB_SIZE variable
vocab_size = VOCAB_SIZE
with torch.no_grad():
    # Fake logits of shape [B, T, V]
    logits = torch.randn(xb.size(0), xb.size(1), vocab_size)
    loss = ce(logits.view(-1, vocab_size), yb.reshape(-1))
print("dummy CE(loss, ignore PAD) =", float(loss))

# Final summary lines
print("\n[요약]")
print("- split CSV/파일 경로 OK")
print("- 캐시 파일 개수:", n_npy, "(npy) /", n_js, "(json)")
print("- DataLoader 배치/마스크 OK, PAD ignore CE OK")
print("- 라운드트립 dur/evt 오차가 매우 크면 토큰화 규칙 재점검 필요")

## 모델링 (1)

In [None]:
# ===== 하이퍼파라미터/경로 =====
import os, math, csv, time, random
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# Base paths (edit if needed); the output directories are created on first run
PROJ = "/content/drive/MyDrive/Deep_Learning_project/original_token"
CKPT_DIR   = f"{PROJ}/ckpt"
LOG_DIR    = f"{PROJ}/logs"
SAMPLES_DIR= f"{PROJ}/samples"
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
os.makedirs(SAMPLES_DIR, exist_ok=True)

# Train on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Required token constants: keep values defined earlier in the notebook and only
# fall back to these defaults when a name is missing
try: PAD_ID
except NameError: PAD_ID = 0
try: BOS_ID
except NameError: BOS_ID = 1
try: EOS_ID
except NameError: EOS_ID = 2
try: VOCAB_SIZE
except NameError: VOCAB_SIZE = 3 + 16 + 16 + 128 + 128  # fallback vocabulary size

@dataclass
class TrainConfig:
    """Hyperparameters for the baseline LSTM run; instantiated once below as the global ``cfg``."""
    vocab_size: int = VOCAB_SIZE   # size of the output layer / embedding table
    pad_id: int = PAD_ID           # token id ignored by the loss and used as padding_idx
    d_model: int = 512             # embedding dimension fed to the LSTM
    lstm_hidden: int = 768         # LSTM hidden size
    lstm_layers: int = 2           # number of stacked LSTM layers
    dropout: float = 0.25          # dropout probability passed to the model
    max_len: int = 512             # context window used when sampling
    lr: float = 3e-4               # AdamW learning rate
    weight_decay: float = 0.01     # AdamW weight decay
    grad_clip: float = 1.0         # max gradient norm; None skips clipping
    epochs: int = 100              # number of training epochs
    log_every: int = 200           # log training loss every N steps
    val_every: int = 1000          # run validation every N steps
    amp: bool = True               # enable autocast / GradScaler mixed precision
    seed: int = 42                 # seed used for torch/random before training
cfg = TrainConfig()

# ===== 모델 정의 =====
class EventLSTM(nn.Module):
    """Token language model: embedding -> dropout -> LSTM -> LayerNorm -> vocabulary logits.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: embedding dimension.
        hidden: LSTM hidden size.
        layers: number of stacked LSTM layers.
        dropout: dropout probability for the embedding and between LSTM layers.
        pad_id: token id whose embedding is fixed as the padding vector.
    """

    def __init__(self, vocab_size, d_model=512, hidden=768, layers=2, dropout=0.2, pad_id=0):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            # nn.LSTM applies dropout only between stacked layers; with a single
            # layer a non-zero value has no effect and only triggers a warning.
            dropout=dropout if layers > 1 else 0.0,
        )
        self.norm = nn.LayerNorm(hidden)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, x, hidden=None):
        """Return ``(logits, hidden)`` for a batch of token ids.

        ``x`` is indexed batch-first (the LSTM is built with batch_first=True);
        ``hidden`` is the optional LSTM state to continue from.
        """
        x = self.embed(x)
        x = self.dropout(x)
        out, hidden = self.lstm(x, hidden)
        out = self.norm(out)
        logits = self.head(out)
        return logits, hidden

# ===== 유틸 =====
def count_params(model):
    """Return the total element count of the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def perplexity(nll):
    """Return exp(nll), or float("inf") when the exponential overflows."""
    try:
        value = math.exp(nll)
    except OverflowError:
        value = float("inf")
    return value

def save_checkpoint(model, opt, scaler, step, path):
    """Write model, optimizer and (optional) scaler state plus the step counter to ``path``.

    ``scaler`` may be None, in which case None is stored under the "scaler" key.
    """
    payload = {
        "model": model.state_dict(),
        "opt": opt.state_dict(),
        "scaler": None,
        "step": step,
    }
    if scaler is not None:
        payload["scaler"] = scaler.state_dict()
    torch.save(payload, path)

def sample_and_save(model, start_token=BOS_ID, max_tokens=1024, temperature=1.0, top_k: Optional[int]=50,
                    out_midi_path: Optional[str]=None, aux_for_detok: Optional[dict]=None):
    """Autoregressively sample a token sequence and optionally write it as MIDI.

    Starting from ``start_token``, the model is re-run on the last
    ``cfg.max_len`` generated tokens at every step, the final-position logits
    are divided by ``temperature`` and the next token is drawn either from the
    ``top_k`` most likely tokens or, when ``top_k`` is None or non-positive,
    from the full distribution. Generation stops at ``max_tokens`` tokens or
    after EOS is produced.

    When both ``out_midi_path`` and ``aux_for_detok`` are given, the tokens are
    passed to ``detokenize_to_midi_file``; a failure there is only reported.
    The model is left in eval mode.

    Returns:
        The generated token ids as a list (including the start token).
    """
    model.eval()
    generated = [start_token]
    restrict_to_top_k = top_k is not None and top_k > 0

    with torch.no_grad():
        for _ in range(max_tokens - 1):
            context = generated[-cfg.max_len:]
            x = torch.tensor(context, dtype=torch.long, device=device).unsqueeze(0)
            # The whole context is re-encoded each step; no state is carried.
            logits, _state = model(x, hidden=None)
            step_logits = logits[:, -1, :] / max(1e-6, temperature)

            if restrict_to_top_k:
                k = min(top_k, step_logits.size(-1))
                top_vals, top_ids = torch.topk(step_logits, k=k, dim=-1)
                choice = torch.multinomial(torch.softmax(top_vals, dim=-1), num_samples=1)
                next_id = int(top_ids.gather(-1, choice).item())
            else:
                full_probs = torch.softmax(step_logits, dim=-1)
                next_id = int(torch.multinomial(full_probs, num_samples=1).item())

            generated.append(next_id)
            if next_id == EOS_ID:
                break

    if out_midi_path is not None and aux_for_detok is not None:
        try:
            detokenize_to_midi_file(generated, aux_for_detok, out_midi_path)
        except Exception as e:
            print("[WARN] detokenize failed:", e)
    return generated

# ===== 학습/검증 루프 =====
def train_one_epoch(model, dl, opt, scheduler, scaler, ce, step0=0,
                    log_path=f"{LOG_DIR}/train_val_curve.csv"):
    """Train ``model`` for one pass over ``dl`` and return the updated global step.

    Args:
        model: network returning ``(logits, hidden)``.
        dl: training DataLoader yielding ``(xb, yb, mb)`` batches.
        opt: optimizer.
        scheduler: LR scheduler stepped once per batch, or None.
        scaler: GradScaler used when ``cfg.amp`` is true.
        ce: loss applied to flattened logits and targets.
        step0: global step count before this epoch.
        log_path: CSV file that receives train/val log rows.

    Every ``cfg.log_every`` steps the mean training loss since the previous log
    is printed and appended to the CSV. Every ``cfg.val_every`` steps the model
    is evaluated on the notebook-level ``val_dl`` and a sample is generated.
    """
    model.train()  # make sure training mode is active at the start of each epoch
    # running_batches counts the batches accumulated in running_loss. The step
    # counter continues across epochs while the accumulator restarts at zero, so
    # the first log of an epoch may cover fewer than cfg.log_every batches.
    running_loss, running_batches, step = 0.0, 0, step0
    t0 = time.time()

    for xb, yb, mb in dl:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)

        if cfg.amp:
            with autocast():
                logits, _ = model(xb)
                loss = ce(logits.reshape(-1, cfg.vocab_size), yb.reshape(-1))
            scaler.scale(loss).backward()
            if cfg.grad_clip is not None:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(opt)
            scaler.update()
        else:
            logits, _ = model(xb)
            loss = ce(logits.reshape(-1, cfg.vocab_size), yb.reshape(-1))
            loss.backward()
            if cfg.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()

        # The scheduler always steps after the optimizer step.
        if scheduler is not None:
            scheduler.step()

        step += 1
        running_loss += float(loss.item())
        running_batches += 1

        # Logging
        if step % cfg.log_every == 0:
            # Average over the batches actually accumulated, not cfg.log_every.
            avg_nll = running_loss / running_batches
            ppl = perplexity(avg_nll)
            print(f"[train] step {step}  nll={avg_nll:.3f}  ppl={ppl:.1f}  ({time.time()-t0:.1f}s)")
            append_log(log_path, {"step": step, "split": "train", "nll": avg_nll, "ppl": ppl})
            running_loss, running_batches = 0.0, 0

        # Mid-epoch validation
        if step % cfg.val_every == 0:
            nll, ppl = evaluate(model, val_dl)  # evaluate() switches to eval mode
            print(f"[val]   step {step}  nll={nll:.3f}  ppl={ppl:.1f}")
            append_log(log_path, {"step": step, "split": "val", "nll": nll, "ppl": ppl})

            # (Optional) save a generated sample
            aux = {"step_sec": 0.5/64, "program": 0}
            out_mid = os.path.join(SAMPLES_DIR, f"step{step}_sample.mid")
            _ = sample_and_save(model, out_midi_path=out_mid, aux_for_detok=aux)

            model.train()  # important: return to training mode after evaluation
    return step

@torch.no_grad()
def evaluate(model, dl):
    """Return ``(nll, ppl)``: the mean per-token loss over ``dl`` and its perplexity.

    The loss is summed over all non-PAD target tokens and divided by their
    count, so batches with different amounts of padding are weighted by the
    tokens they actually contain. Switches the model to eval mode and leaves it
    there.
    """
    model.eval()
    # reduction="sum" with ignore_index adds up the loss of non-PAD targets only.
    ce = nn.CrossEntropyLoss(ignore_index=cfg.pad_id, reduction="sum")
    total_loss, total_tokens = 0.0, 0
    for xb, yb, mb in dl:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        logits, _ = model(xb)
        loss = ce(logits.reshape(-1, cfg.vocab_size), yb.reshape(-1))
        total_loss += float(loss.item())
        total_tokens += int((yb != cfg.pad_id).sum().item())
    # max(1, ...) avoids division by zero for an empty or all-PAD loader.
    nll = total_loss / max(1, total_tokens)
    ppl = perplexity(nll)
    return nll, ppl

def append_log(csv_path, row: dict):
    """Append one ``step/split/nll/ppl`` row to the CSV log.

    The header line is written first when ``csv_path`` does not exist yet.
    """
    columns = ["step", "split", "nll", "ppl"]
    is_new_file = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        if is_new_file:
            writer.writeheader()
        writer.writerow(row)

In [None]:
# ===== Actual run =====
torch.manual_seed(cfg.seed)
random.seed(cfg.seed)

model = EventLSTM(
    vocab_size=cfg.vocab_size,
    d_model=cfg.d_model,
    hidden=cfg.lstm_hidden,
    layers=cfg.lstm_layers,
    dropout=cfg.dropout,
    pad_id=cfg.pad_id
).to(device)

opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

# Cosine LR schedule; T_max is expressed in total optimizer steps (epochs * steps per epoch)
steps_per_epoch = len(train_dl)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, cfg.epochs * steps_per_epoch))

# The existing AMP setup is kept as is (it may emit a warning but still works)
scaler = GradScaler(enabled=cfg.amp)
ce = nn.CrossEntropyLoss(ignore_index=cfg.pad_id, reduction="mean")

print(f"Model params: {count_params(model)/1e6:.2f}M")
print(f"steps/epoch: {steps_per_epoch}")

global_step = 0
for ep in range(1, cfg.epochs+1):
    print(f"\n=== EPOCH {ep}/{cfg.epochs} ===")
    # The scheduler is handed to train_one_epoch, which steps it once per batch
    global_step = train_one_epoch(model, train_dl, opt, scheduler, scaler, ce, step0=global_step)

    # End-of-epoch validation + checkpoint
    nll, ppl = evaluate(model, val_dl)
    print(f"[val @epoch{ep}] nll={nll:.3f}  ppl={ppl:.1f}")
    append_log(f"{LOG_DIR}/train_val_curve.csv",
               {"step": global_step, "split": f"val_ep{ep}", "nll": nll, "ppl": ppl})
    save_checkpoint(model, opt, scaler, global_step,
                    os.path.join(CKPT_DIR, f"lstm_ep{ep}_step{global_step}.pt"))

### 결과

In [None]:
# 2) Sample tokens -> MIDI -> WAV -> playback
from IPython.display import Audio, display
import tempfile, os
from midi2audio import FluidSynth

toks = sample_and_save(
    model,
    start_token=BOS_ID,
    max_tokens=512,
    temperature=1.0,
    top_k=50,
    out_midi_path=None, # nothing is written here; only the tokens are returned
    aux_for_detok={"step_sec": 0.5/64, "program": 0}
)

# Output paths (under SAMPLES_DIR)
sample_idx = int(time.time()) # timestamp in the file name to avoid collisions
out_mid_path = os.path.join(SAMPLES_DIR, f"generated_sample_{sample_idx}.mid")
out_wav_path = out_mid_path.replace(".mid", ".wav")


# Write the sampled tokens as a MIDI file
detokenize_to_midi_file(toks, {"step_sec": 0.5/64, "program": 0}, out_mid_path)

# Render the MIDI file to WAV with FluidSynth and the given SoundFont
sf2_path = "/usr/share/sounds/sf2/FluidR3_GM.sf2"
fs = FluidSynth(sf2_path)
fs.midi_to_audio(out_mid_path, out_wav_path)

display(Audio(out_wav_path))
print("MIDI:", out_mid_path, "| WAV:", out_wav_path)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# Plot with the DejaVu Sans font and turn off the Unicode minus glyph on axes.
plt.rcParams['font.family'] = 'DejaVu Sans'
plt.rcParams['axes.unicode_minus'] = False

def _read_log_safe(log_path):
    """Load the training-log CSV into a clean DataFrame sorted by step.

    Malformed lines are skipped while parsing, only the ``step``, ``split``,
    ``nll`` and ``ppl`` columns are kept, the numeric columns are coerced
    (unparseable values become NaN) and rows with any missing value are dropped.

    Raises:
        FileNotFoundError: if ``log_path`` does not exist.
    """
    if not os.path.exists(log_path):
        raise FileNotFoundError(f"로그 파일을 찾을 수 없습니다: {log_path}")

    # Python engine + on_bad_lines="skip": tolerate broken rows in the log.
    raw = pd.read_csv(log_path, engine="python", on_bad_lines="skip")

    # Keep only the expected columns that are actually present.
    expected = ["step", "split", "nll", "ppl"]
    df = raw[[col for col in expected if col in raw.columns]]

    # Force numeric dtypes; stray header rows or garbage turn into NaN.
    for numeric_col in ("step", "nll", "ppl"):
        if numeric_col in df.columns:
            df[numeric_col] = pd.to_numeric(df[numeric_col], errors="coerce")

    # Valid rows only, ordered by step with a fresh index.
    cleaned = df.dropna(subset=["step", "split", "nll", "ppl"])
    return cleaned.sort_values("step").reset_index(drop=True)

def plot_training_curves(log_path=f"{LOG_DIR}/train_val_curve.csv", save_path=None):
    """Draw a 2x2 overview of the training log, save it as a PNG and show it.

    Panels: NLL curves, perplexity curves (log y-axis), per-epoch validation
    NLL/PPL, and a bar chart of the last 10 epoch-level validation perplexities.

    Args:
        log_path: CSV log read through ``_read_log_safe``.
        save_path: output image path; defaults to ``training_curves.png``
            inside ``LOG_DIR``.
    """
    df = _read_log_safe(log_path)

    # 'train' rows are matched exactly; every split containing 'val' (both the
    # mid-epoch 'val' rows and the 'val_ep<N>' rows) counts as validation.
    train_data = df[df['split'] == 'train']
    val_data   = df[df['split'].str.contains('val', na=False)]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('LSTM Music Generation Training Results', fontsize=16, fontweight='bold')

    # 1) NLL
    ax1 = axes[0, 0]
    if len(train_data): ax1.plot(train_data['step'], train_data['nll'], 'b-', label='Train NLL', alpha=0.7, linewidth=1)
    if len(val_data):   ax1.plot(val_data['step'],   val_data['nll'],   'r-', label='Validation NLL', linewidth=2)
    ax1.set_xlabel('Training Steps'); ax1.set_ylabel('NLL'); ax1.set_title('Loss Curve (NLL)')
    ax1.legend(); ax1.grid(True, alpha=0.3)

    # 2) PPL
    ax2 = axes[0, 1]
    if len(train_data): ax2.plot(train_data['step'], train_data['ppl'], 'b-', label='Train PPL', alpha=0.7, linewidth=1)
    if len(val_data):   ax2.plot(val_data['step'],   val_data['ppl'],   'r-', label='Validation PPL', linewidth=2)
    ax2.set_xlabel('Training Steps'); ax2.set_ylabel('Perplexity'); ax2.set_title('Perplexity Curve')
    ax2.legend(); ax2.grid(True, alpha=0.3); ax2.set_yscale('log')

    # 3) Epoch-wise validation (the epoch number is parsed with a raw-string regex)
    ax3 = axes[1, 0]
    epoch_val_data = df[df['split'].str.contains('val_ep', na=False)].copy()
    if len(epoch_val_data):
        epoch_val_data['epoch'] = epoch_val_data['split'].str.extract(r'val_ep(\d+)').astype(int)
        epoch_val_data = epoch_val_data.sort_values('epoch')
        ax3.plot(epoch_val_data['epoch'], epoch_val_data['nll'], 'ro-', label='Val NLL', linewidth=2, markersize=6)
        # Second y-axis so NLL and PPL can share the epoch axis
        ax3_twin = ax3.twinx()
        ax3_twin.plot(epoch_val_data['epoch'], epoch_val_data['ppl'], 'go-', label='Val PPL', linewidth=2, markersize=6)
        ax3.set_xlabel('Epoch'); ax3.set_ylabel('Validation NLL', color='red')
        ax3_twin.set_ylabel('Validation Perplexity', color='green')
        ax3.set_title('Epoch-wise Validation Performance'); ax3.grid(True, alpha=0.3)

    # 4) Validation PPL of the last 10 epochs
    ax4 = axes[1, 1]
    if len(epoch_val_data):
        recent = epoch_val_data.tail(10)
        x = np.arange(len(recent))
        bars = ax4.bar(x, recent['ppl'], alpha=0.7, color='skyblue', edgecolor='navy')
        ax4.set_xlabel('Recent Epochs'); ax4.set_ylabel('Perplexity')
        ax4.set_title('Recent Validation Perplexity (Last 10 Epochs)')
        ax4.set_xticks(x); ax4.set_xticklabels(recent['epoch'], rotation=45)
        # Print each bar's value just above it
        for bar, v in zip(bars, recent['ppl']):
            ax4.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1, f'{v:.1f}', ha='center', va='bottom', fontsize=8)

    plt.tight_layout()
    if save_path is None:
        save_path = os.path.join(LOG_DIR, 'training_curves.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    print(f"그래프가 저장되었습니다: {save_path}")
    plt.show()

def plot_detailed_analysis(log_path=f"{LOG_DIR}/train_val_curve.csv"):
    """Show a 2x2 figure with a closer look at the training log.

    Panels: smoothed training NLL against validation NLL, histogram of
    validation perplexity, epoch-to-epoch perplexity change, and the standard
    deviation of training NLL over consecutive groups of log rows.

    Args:
        log_path: CSV log read through ``_read_log_safe``.
    """
    df = _read_log_safe(log_path)

    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Detailed Training Analysis', fontsize=16, fontweight='bold')

    # 1) Loss with moving average
    ax1 = axes[0, 0]
    train_data = df[df['split'] == 'train']
    val_data   = df[df['split'].str.contains('val', na=False)]
    if len(train_data):
        # Window is about 1/20 of the logged training rows (at least 1)
        window = max(1, len(train_data) // 20)
        smooth = train_data['nll'].rolling(window=window, center=True).mean()
        ax1.plot(train_data['step'], smooth, 'b-', linewidth=2, label=f'Train NLL (MA-{window})')
        ax1.plot(train_data['step'], train_data['nll'], 'b-', alpha=0.3, linewidth=0.5)
    if len(val_data):
        ax1.plot(val_data['step'], val_data['nll'], 'r-', linewidth=2, label='Validation NLL')
    ax1.set_xlabel('Training Steps'); ax1.set_ylabel('NLL'); ax1.set_title('Loss Convergence (with Moving Average)')
    ax1.legend(); ax1.grid(True, alpha=0.3)

    # 2) Distribution of validation PPL
    ax2 = axes[0, 1]
    val_only = df[df['split'].str.contains('val', na=False)]
    if len(val_only):
        ax2.hist(val_only['ppl'], bins=20, alpha=0.7, color='red', edgecolor='black')
        ax2.axvline(val_only['ppl'].mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {val_only["ppl"].mean():.1f}')
        ax2.axvline(val_only['ppl'].median(), color='orange', linestyle='--', linewidth=2, label=f'Median: {val_only["ppl"].median():.1f}')
    ax2.set_xlabel('Perplexity'); ax2.set_ylabel('Frequency'); ax2.set_title('Validation Perplexity Distribution')
    ax2.legend(); ax2.grid(True, alpha=0.3)

    # 3) Per-epoch change in perplexity
    ax3 = axes[1, 0]
    epoch_val = df[df['split'].str.contains('val_ep', na=False)].copy()
    if len(epoch_val):
        epoch_val['epoch'] = epoch_val['split'].str.extract(r'val_ep(\d+)').astype(int)
        epoch_val = epoch_val.sort_values('epoch')
        # diff() is current minus previous epoch, so negative bars mean lower perplexity
        epoch_val['ppl_impr'] = epoch_val['ppl'].diff()
        ax3.bar(epoch_val['epoch'][1:], epoch_val['ppl_impr'][1:], alpha=0.7, color='green')
        ax3.axhline(0, color='black', linestyle='-', alpha=0.5)
        ax3.set_xlabel('Epoch'); ax3.set_ylabel('Perplexity Change'); ax3.set_title('Epoch-wise Perplexity Improvement')
        ax3.grid(True, alpha=0.3)

    # 4) Training stability
    ax4 = axes[1, 1]
    if len(train_data) > 10:
        # Split the logged training rows into roughly 10 consecutive groups
        group_size = max(1, len(train_data) // 10)
        groups = [train_data['nll'].iloc[i:i+group_size] for i in range(0, len(train_data), group_size)]
        stds = [g.std() for g in groups if len(g) > 1]
        # x positions are row offsets into the training log, not optimizer steps
        xs = [i * group_size for i in range(len(stds))]
        ax4.plot(xs, stds, 'bo-', linewidth=2, markersize=6)
        ax4.set_xlabel('Training Steps (Grouped)'); ax4.set_ylabel('NLL Std'); ax4.set_title('Training Stability (Loss Variance)')
        ax4.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Render the overview figure (also saved to disk by plot_training_curves).
print("=== 학습 곡선 시각화 ===")
plot_training_curves()

# Render the detailed analysis figure (shown only).
print("\n=== 상세 분석 ===")
plot_detailed_analysis()

## 모델링 (2)
- 어텐션 (o)
- 계층적 구조 (x)

In [None]:
# ===== 개선버전: LSTM(+Attention) 실험 프레임 (Colab-ready) =====
import os, math, csv, time, json, random, itertools
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

# ========== Paths / device ==========
PROJ = "/content/drive/MyDrive/Deep_Learning_project/original_token"
CKPT_DIR    = f"{PROJ}/ckpt"
LOG_DIR     = f"{PROJ}/logs"
SAMPLES_DIR = f"{PROJ}/samples"
RESULTS_DIR = f"{PROJ}/results"
# Create every output directory if it does not exist yet
for d in [CKPT_DIR, LOG_DIR, SAMPLES_DIR, RESULTS_DIR]:
    os.makedirs(d, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== Required token constants (fallbacks used only when a name is undefined) =====
try: PAD_ID
except NameError: PAD_ID = 0
try: BOS_ID
except NameError: BOS_ID = 1
try: EOS_ID
except NameError: EOS_ID = 2
try: VOCAB_SIZE
except NameError: VOCAB_SIZE = 3 + 16 + 16 + 128 + 128  # PAD/BOS/EOS + VEL(16) + TS(16) + ON(128) + OFF(128)

# ========== 설정 ==========
# ========== Configuration ==========
@dataclass
class TrainConfig:
    """Hyperparameters for the LSTM / LSTM+attention experiments; instantiated below as ``cfg``."""
    vocab_size: int = VOCAB_SIZE
    pad_id: int = PAD_ID
    d_model: int = 512      # embedding dimension
    lstm_hidden: int = 768  # LSTM hidden size (also the attention embed_dim)
    lstm_layers: int = 2
    max_len: int = 512      # context window used when sampling with attention

    # learning rate / optimizer
    lr: float = 3e-4
    weight_decay: float = 0.01
    epochs: int = 50
    grad_clip: float = 1.0
    amp: bool = True

    # logging
    log_every: int = 200
    val_every: int = 1000
    seed: int = 42

    # dropout (one rate per component)
    dropout_emb: float = 0.1
    dropout_lstm: float = 0.25
    dropout_attn: float = 0.1
    dropout_ffn: float = 0.1

    # attention options
    use_attention: bool = False
    num_attention_heads: int = 8

    # random search
    n_trials: int = 6   # number of trials (kept small to save time)
    # sampling defaults (their consumer is not visible in this section)
    default_temperature: float = 1.0
    default_top_k: int = 50
    default_top_p: float = 0.0
    default_no_repeat_ngram: int = 0

cfg = TrainConfig()

def set_seed(seed: int):
    """Seed Python's ``random`` and torch (CPU and all CUDA devices).

    Also configures cuDNN with ``deterministic = False`` and
    ``benchmark = True``.
    """
    for seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed the RNGs once with the configured base seed.
set_seed(cfg.seed)

# ========== 모델 ==========
class EventLSTM(nn.Module):
    """LSTM-only token model: embedding -> LSTM -> LayerNorm -> dropout -> vocabulary logits.

    ``use_attention`` is set to False; ``sample_and_save`` reads this flag to
    decode one token at a time while carrying the LSTM state.
    """

    def __init__(self, vocab_size, d_model=512, hidden=768, layers=2,
                 dropout_emb=0.1, dropout_lstm=0.25, pad_id=0):
        super().__init__()
        self.use_attention = False
        # Inter-layer LSTM dropout is only enabled for stacked LSTMs.
        inter_layer_dropout = dropout_lstm if layers > 1 else 0.0

        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.drop_emb = nn.Dropout(dropout_emb)
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.norm = nn.LayerNorm(hidden)
        self.drop_out = nn.Dropout(dropout_lstm)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, x: torch.Tensor, hidden: Optional[Tuple[torch.Tensor,torch.Tensor]]=None):
        """Map token ids (B, T) to logits (B, T, V); also return the LSTM state."""
        embedded = self.drop_emb(self.embed(x))          # (B, T, d_model)
        features, hidden = self.lstm(embedded, hidden)   # (B, T, H)
        features = self.drop_out(self.norm(features))    # (B, T, H)
        return self.head(features), hidden               # (B, T, V)

class EventLSTMWithAttention(nn.Module):
    """LSTM followed by one causal self-attention + feed-forward block.

    Pipeline: embedding -> LSTM -> (self-attention, residual, LayerNorm)
    -> (FFN, residual, LayerNorm) -> vocabulary logits. ``use_attention`` is
    set to True; ``sample_and_save`` reads this flag to re-encode the whole
    prefix at every decoding step.
    """

    def __init__(self, vocab_size, d_model=512, hidden=768, layers=2,
                 dropout_emb=0.1, dropout_lstm=0.25,
                 num_heads=8, dropout_attn=0.1, dropout_ffn=0.1, pad_id=0):
        super().__init__()
        self.use_attention = True
        # Inter-layer LSTM dropout is only enabled for stacked LSTMs.
        inter_layer_dropout = dropout_lstm if layers > 1 else 0.0

        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.drop_emb = nn.Dropout(dropout_emb)
        self.lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=hidden,
            num_layers=layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden,
            num_heads=num_heads,
            dropout=dropout_attn,
            batch_first=True,
        )
        self.dropout_attn = nn.Dropout(dropout_attn)
        self.dropout_ffn = nn.Dropout(dropout_ffn)

        # Transformer-style post-norm: residual + LayerNorm after each sub-layer.
        self.norm1 = nn.LayerNorm(hidden)
        self.norm2 = nn.LayerNorm(hidden)
        expanded = 4 * hidden
        self.ffn = nn.Sequential(
            nn.Linear(hidden, expanded),
            nn.GELU(),
            nn.Linear(expanded, hidden),
        )
        self.head = nn.Linear(hidden, vocab_size)

    def _causal_mask(self, T: int, device):
        """Return a (T, T) boolean mask where True blocks attention to future positions."""
        everything = torch.ones(T, T, dtype=torch.bool, device=device)
        return torch.triu(everything, diagonal=1)

    def forward(self, x: torch.Tensor, hidden: Optional[Tuple[torch.Tensor,torch.Tensor]]=None):
        """Map token ids (B, T) to logits (B, T, V); also return the LSTM state."""
        embedded = self.drop_emb(self.embed(x))                  # (B, T, d_model)
        lstm_out, hidden = self.lstm(embedded, hidden)           # (B, T, H)

        future_mask = self._causal_mask(lstm_out.size(1), lstm_out.device)
        attended, _weights = self.attn(lstm_out, lstm_out, lstm_out, attn_mask=future_mask)
        mixed = self.norm1(lstm_out + self.dropout_attn(attended))    # (B, T, H)

        refined = self.norm2(mixed + self.dropout_ffn(self.ffn(mixed)))  # (B, T, H)
        return self.head(refined), hidden                        # (B, T, V)

# ========== 유틸 ==========
def count_params(model):
    """Return the number of scalar values held in the model's trainable parameters."""
    trainable = [p for p in model.parameters() if p.requires_grad]
    return sum(p.numel() for p in trainable)

def perplexity(nll):
    """Return exp(nll); an overflowing exponential is reported as float("inf")."""
    try:
        result = math.exp(nll)
    except OverflowError:
        result = float("inf")
    return result

def save_checkpoint(model, opt, scaler, step, path):
    """Save a training checkpoint to ``path`` with torch.save.

    The checkpoint holds the model and optimizer state dicts, the scaler state
    dict (None when ``scaler`` is None) and the step counter.
    """
    scaler_state = None if scaler is None else scaler.state_dict()
    checkpoint = {
        "model": model.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler_state,
        "step": step,
    }
    torch.save(checkpoint, path)

# nucleus(top-p) & top-k & no-repeat-ngram
def _apply_sampling_filters(logits: torch.Tensor, top_k: int=0, top_p: float=0.0) -> torch.Tensor:
    """Turn 1-D logits into a sampling distribution after top-k / top-p filtering.

    Args:
        logits: tensor of shape (V,).
        top_k: keep only the ``top_k`` most likely tokens; 0 (or a value not
            smaller than V) disables the filter.
        top_p: nucleus threshold in (0, 1); keeps the smallest set of most
            likely tokens whose cumulative probability reaches ``top_p``.
            Values outside (0, 1) disable the filter.

    Returns:
        Probabilities of shape (V,) that sum to 1; uniform if everything was
        filtered out.
    """
    probs = torch.softmax(logits, dim=-1)

    # top-k: zero out everything except the k largest probabilities
    if top_k and 0 < top_k < probs.numel():
        topv, topi = torch.topk(probs, k=top_k)
        probs = torch.zeros_like(probs).scatter_(0, topi, topv)

    # top-p (nucleus)
    if top_p and 0.0 < top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        # Probability mass accumulated *before* each token. A token is dropped
        # only when the tokens ahead of it already cover top_p, so the token
        # that crosses the threshold is kept and the best token always survives.
        mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_probs = sorted_probs.masked_fill(mass_before >= top_p, 0.0)
        probs = torch.zeros_like(probs).scatter_(0, sorted_idx, sorted_probs)

    # renormalize; fall back to uniform if no mass is left
    total = probs.sum()
    if total.item() > 0:
        probs = probs / total
    else:
        probs = torch.full_like(probs, 1.0 / probs.numel())

    return probs

def _forbidden_next_tokens_by_ngram(prefix: List[int], n: int) -> set:
    """Return the tokens that would repeat an n-gram already present in ``prefix``.

    The current last (n-1)-gram of ``prefix`` is looked up among all earlier
    occurrences; every token that followed one of those occurrences is
    forbidden. An empty set is returned when ``n <= 1`` or the prefix is
    shorter than n-1 tokens.
    """
    context_len = n - 1
    if n <= 1 or len(prefix) < context_len:
        return set()

    tail = tuple(prefix[-context_len:])
    last_start = len(prefix) - n   # last start index of a complete n-gram
    return {
        prefix[start + context_len]
        for start in range(last_start + 1)
        if tuple(prefix[start:start + context_len]) == tail
    }

@torch.no_grad()
def sample_and_save(model, start_token=BOS_ID, max_tokens=1024,
                    temperature: float=1.0, top_k: Optional[int]=50, top_p: float=0.0,
                    no_repeat_ngram_size: int=0,
                    out_midi_path: Optional[str]=None, aux_for_detok: Optional[dict]=None):
    """Sample a token sequence from ``model`` and optionally write it as MIDI.

    Decoding depends on the model's ``use_attention`` flag: attention models
    are re-run on the last ``cfg.max_len`` tokens at every step, while
    LSTM-only models consume one token per step and carry their hidden state.
    The last-position logits are divided by ``temperature``, filtered with
    top-k / top-p, and tokens that would repeat an n-gram of size
    ``no_repeat_ngram_size`` (when > 1) are removed before sampling.
    Generation ends at ``max_tokens`` tokens or after EOS.

    When both ``out_midi_path`` and ``aux_for_detok`` are given, the tokens are
    passed to ``detokenize_to_midi_file``; a failure there is only reported.
    The model is left in eval mode.

    Returns:
        The generated token ids as a list (including the start token).
    """
    model.eval()
    generated: List[int] = [start_token]
    state = None
    block_repeats = bool(no_repeat_ngram_size) and no_repeat_ngram_size > 1

    for _ in range(max_tokens - 1):
        if getattr(model, "use_attention", False):
            # With attention the whole (windowed) prefix is recomputed: accuracy first.
            context = torch.tensor(generated[-cfg.max_len:], dtype=torch.long, device=device).unsqueeze(0)
            logits, state = model(context, hidden=None)
            step_logits = logits[:, -1, :].squeeze(0)       # (V,)
        else:
            # LSTM-only: feed one token and carry the hidden state (efficient).
            last_token = torch.tensor([[generated[-1]]], dtype=torch.long, device=device)
            logits, state = model(last_token, state)
            step_logits = logits.squeeze(0).squeeze(0)      # (V,)

        step_logits = step_logits / max(1e-6, temperature)

        # Tokens that would complete an already-seen n-gram.
        if block_repeats:
            banned = _forbidden_next_tokens_by_ngram(generated, no_repeat_ngram_size)
        else:
            banned = set()

        # top-k / top-p filtering
        probs = _apply_sampling_filters(step_logits, top_k=top_k or 0, top_p=top_p)

        if banned:
            probs[list(banned)] = 0.0
            remaining = probs.sum()
            if remaining.item() > 0:
                probs = probs / remaining
            else:
                # Everything was banned: fall back to a uniform distribution.
                probs = torch.full_like(probs, 1.0 / probs.numel())

        next_id = int(torch.multinomial(probs, num_samples=1).item())
        generated.append(next_id)
        if next_id == EOS_ID:
            break

    if out_midi_path is not None and aux_for_detok is not None:
        try:
            detokenize_to_midi_file(generated, aux_for_detok, out_midi_path)
        except Exception as e:
            print("[WARN] detokenize failed:", e)

    return generated

# ========== 학습/검증 루프 ==========
def append_log(csv_path, row: Dict[str, Any]):
    """Append ``row`` (keys step/split/nll/ppl) to the CSV at ``csv_path``.

    A header line is emitted first if the file does not exist yet.
    """
    needs_header = not os.path.exists(csv_path)
    with open(csv_path, "a", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=["step", "split", "nll", "ppl"])
        if needs_header:
            writer.writeheader()
        writer.writerow(row)

def train_one_epoch(model, dl, opt, scheduler, scaler, ce, step0=0, log_path=f"{LOG_DIR}/train_val_curve.csv",
                    val_dl=None):
    """Train ``model`` for one pass over ``dl`` and return the updated global step.

    Args:
        model: network returning ``(logits, hidden)``.
        dl: training DataLoader yielding ``(xb, yb, mb)`` batches.
        opt: optimizer.
        scheduler: LR scheduler stepped once per batch, or None.
        scaler: GradScaler used when ``cfg.amp`` is true.
        ce: loss applied to flattened logits and targets.
        step0: global step count before this epoch.
        log_path: CSV file that receives train/val log rows.
        val_dl: optional validation DataLoader; when given, the model is
            evaluated every ``cfg.val_every`` steps.

    Every ``cfg.log_every`` steps the mean training loss since the previous log
    is printed and appended to the CSV.
    """
    model.train()
    # running_batches counts the batches accumulated in running_loss. The step
    # counter continues across epochs while the accumulator restarts at zero, so
    # the first log of an epoch may cover fewer than cfg.log_every batches.
    running_loss, running_batches, step = 0.0, 0, step0
    t0 = time.time()

    for xb, yb, mb in dl:
        xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)

        if cfg.amp:
            with autocast():
                logits, _ = model(xb)
                loss = ce(logits.reshape(-1, cfg.vocab_size), yb.reshape(-1))
            scaler.scale(loss).backward()
            if cfg.grad_clip is not None:
                # Gradients must be unscaled before their norm is clipped.
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            scaler.step(opt)
            scaler.update()
        else:
            logits, _ = model(xb)
            loss = ce(logits.reshape(-1, cfg.vocab_size), yb.reshape(-1))
            loss.backward()
            if cfg.grad_clip is not None:
                nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
            opt.step()

        # The scheduler steps after the optimizer step.
        if scheduler is not None:
            scheduler.step()

        step += 1
        running_loss += float(loss.item())
        running_batches += 1

        if step % cfg.log_every == 0:
            # Average over the batches actually accumulated, not cfg.log_every.
            avg_nll = running_loss / running_batches
            ppl = perplexity(avg_nll)
            print(f"[train] step {step}  nll={avg_nll:.3f}  ppl={ppl:.1f}  ({time.time()-t0:.1f}s)")
            append_log(log_path, {"step": step, "split": "train", "nll": avg_nll, "ppl": ppl})
            running_loss, running_batches = 0.0, 0

        if (val_dl is not None) and (step % cfg.val_every == 0):
            nll, ppl = evaluate(model, val_dl)
            print(f"[val]   step {step}  nll={nll:.3f}  ppl={ppl:.1f}")
            append_log(log_path, {"step": step, "split": "val", "nll": nll, "ppl": ppl})
            model.train()  # evaluate() switched to eval mode; resume training mode
    return step

@torch.no_grad()
def evaluate(model, dl):
    """Return ``(nll, ppl)``: the mean loss per non-PAD target token over ``dl`` and its perplexity.

    Switches the model to eval mode and leaves it there.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(ignore_index=cfg.pad_id, reduction="none")
    loss_sum = 0.0
    token_count = 0

    for xb, yb, _mb in dl:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        logits, _ = model(xb)                                   # (B, T, V)
        flat_targets = yb.reshape(-1)
        per_token = criterion(logits.reshape(-1, cfg.vocab_size), flat_targets)  # (B*T,)
        is_real = flat_targets != cfg.pad_id
        loss_sum += float(per_token[is_real].sum().item())
        token_count += int(is_real.sum().item())

    # max(1, ...) guards against a loader without any non-PAD target.
    nll = loss_sum / max(1, token_count)
    return nll, perplexity(nll)

# ========== 랜덤 서치 ==========
def sample_config_space() -> Dict[str, Any]:
    """Draw one random configuration: an independent ``random.choice`` per hyperparameter."""
    search_space = (
        ("use_attention", [False, True]),
        ("num_attention_heads", [4, 8, 16]),
        ("lr", [1e-4, 3e-4, 1e-3]),
        ("weight_decay", [0.01, 0.1, 0.5]),
        ("dropout_emb", [0.05, 0.1, 0.2]),
        ("dropout_lstm", [0.2, 0.25, 0.3]),
        ("dropout_attn", [0.05, 0.1, 0.2]),
        ("dropout_ffn", [0.05, 0.1, 0.2]),
    )
    # Options are drawn in declaration order so the RNG is consumed in a fixed sequence.
    return {name: random.choice(options) for name, options in search_space}

def build_model_from_cfg() -> nn.Module:
    """Instantiate the model described by the global ``cfg`` on ``device``.

    ``cfg.use_attention`` selects between ``EventLSTMWithAttention`` and the
    plain ``EventLSTM``. The attention-specific settings (head count and the
    attention / FFN dropouts) are only passed to the attention variant.

    Returns:
        The freshly constructed model, already moved to ``device``.
    """
    # Constructor arguments common to both architectures.
    shared_kwargs = dict(
        vocab_size=cfg.vocab_size,
        d_model=cfg.d_model,
        hidden=cfg.lstm_hidden,
        layers=cfg.lstm_layers,
        dropout_emb=cfg.dropout_emb,
        dropout_lstm=cfg.dropout_lstm,
        pad_id=cfg.pad_id,
    )

    if not cfg.use_attention:
        return EventLSTM(**shared_kwargs).to(device)

    return EventLSTMWithAttention(
        num_heads=cfg.num_attention_heads,
        dropout_attn=cfg.dropout_attn,
        dropout_ffn=cfg.dropout_ffn,
        **shared_kwargs,
    ).to(device)

def run_random_search(train_dl, val_dl, n_trials: int = None):
    """Run a random hyper-parameter search and report the best configuration.

    For every trial a configuration is drawn with ``sample_config_space``,
    copied onto the global ``cfg``, and a fresh model is trained for
    ``cfg.epochs`` epochs. The lowest validation perplexity measured at the
    end of any epoch is recorded as the trial's score.

    Side effects:
        * Mutates the global ``cfg``; after the call it holds the settings of
          the LAST trial, not of the best one.
        * Creates ``RESULTS_DIR`` if needed and rewrites
          ``random_search_results.json`` after every trial, so finished trials
          survive an interrupted session.
        * Prints progress and finally calls ``analyze_results``.

    Args:
        train_dl: Training data loader, forwarded to ``train_one_epoch``.
        val_dl: Validation data loader, forwarded to ``train_one_epoch`` and
            ``evaluate``.
        n_trials: Number of trials; ``cfg.n_trials`` is used when falsy.

    Returns:
        ``(results, best_config)`` where ``results`` is a list of per-trial
        dicts (``config``, ``best_val_perplexity``, ``model_params``,
        ``experiment_id``) and ``best_config`` is the sampled config with the
        lowest perplexity (``None`` if no trial beat ``inf``).
    """
    n_trials = n_trials or cfg.n_trials
    results = []
    best_perplexity = float("inf")
    best_config = None

    # open(..., "w") does not create missing parent directories.
    os.makedirs(RESULTS_DIR, exist_ok=True)
    results_file = f"{RESULTS_DIR}/random_search_results.json"

    def _save_results():
        # Rewrite the whole (small) JSON file with everything collected so far.
        with open(results_file, "w") as f:
            json.dump(results, f, indent=2)

    print(f"총 {n_trials}개 랜덤 서치를 실행합니다...")

    for i in range(1, n_trials+1):
        # Sample a configuration and copy it onto cfg. With attention off the
        # head count / attention dropouts are still stored, but the model
        # builder ignores them.
        conf = sample_config_space()
        for key, value in conf.items():
            setattr(cfg, key, value)

        print(f"\n=== 실험 {i}/{n_trials} ===")
        print(f"설정: {conf}")

        set_seed(cfg.seed + i)  # vary the seed per trial
        model = build_model_from_cfg()
        n_params = count_params(model)
        print(f"모델 파라미터: {n_params/1e6:.2f}M")

        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        steps_per_epoch = max(1, len(train_dl))
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, cfg.epochs * steps_per_epoch))
        scaler = GradScaler(enabled=cfg.amp)
        ce = nn.CrossEntropyLoss(ignore_index=cfg.pad_id, reduction="mean")

        global_step = 0
        best_val_ppl = float("inf")

        for ep in range(1, cfg.epochs + 1):
            global_step = train_one_epoch(model, train_dl, opt, scheduler, scaler, ce, step0=global_step, val_dl=val_dl)
            nll, ppl = evaluate(model, val_dl)
            print(f"[val @epoch{ep}] nll={nll:.3f}  ppl={ppl:.2f}")
            if ppl < best_val_ppl:
                best_val_ppl = ppl

        results.append({
            "config": conf,
            "best_val_perplexity": float(best_val_ppl),
            "model_params": int(n_params),
            "experiment_id": i
        })
        print(f"최고 검증 Perplexity: {best_val_ppl:.2f}")

        if best_val_ppl < best_perplexity:
            best_perplexity = best_val_ppl
            best_config = conf
            print(f"★ 새로운 최고 성능! Perplexity: {best_perplexity:.2f}")

        # Persist after every trial so a crash or disconnect keeps finished work.
        _save_results()

        # Drop this trial's model/optimizer state before the next model is
        # built, otherwise both generations are alive at the same time.
        del model, opt, scheduler, scaler
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Final write also covers the case where the loop body never ran.
    _save_results()

    print("\n=== 랜덤 서치 완료 ===")
    print(f"총 {len(results)}개 실험 완료")
    print(f"최고 성능: Perplexity {best_perplexity:.2f}")
    print(f"최고 설정: {best_config}")
    print(f"결과 저장: {results_file}")

    analyze_results(results)
    return results, best_config

def analyze_results(results: List[Dict[str,Any]]):
    """Print a short report of the random-search results.

    When both attention and baseline trials are present, their mean best
    validation perplexities and the relative improvement of attention over the
    baseline are printed. The five trials with the lowest perplexity are
    always listed afterwards.

    Args:
        results: Per-trial dicts with at least ``config`` (containing
            ``use_attention``) and ``best_val_perplexity``.
    """
    print("\n=== 결과 분석 ===")

    def _mean_ppl(rows):
        # Arithmetic mean of the best validation perplexity over the rows.
        return sum(row["best_val_perplexity"] for row in rows) / len(rows)

    # Split trials by architecture.
    with_attn, without_attn = [], []
    for row in results:
        bucket = with_attn if row["config"]["use_attention"] else without_attn
        bucket.append(row)

    if with_attn and without_attn:
        avg_attn = _mean_ppl(with_attn)
        avg_base = _mean_ppl(without_attn)
        # Relative gain in percent; the denominator is floored to avoid /0.
        imp = (avg_base - avg_attn) / max(1e-9, avg_base) * 100.0
        print(f"Baseline 평균 PPL: {avg_base:.2f}")
        print(f"Attention 평균 PPL: {avg_attn:.2f}")
        print(f"개선율: {imp:.1f}%")

    ranked = sorted(results, key=lambda row: row["best_val_perplexity"])
    print("\n상위 5개 결과:")
    for rank, row in enumerate(ranked[:5], start=1):
        print(f"{rank}. PPL {row['best_val_perplexity']:.2f}  | 설정: {row['config']}")

# ========== 메인 ==========
if __name__ == "__main__":
    # Entry point: print a banner plus runtime info, then launch the random search.
    print("🔥 LSTM 업그레이드 1차: Attention + 랜덤 서치")
    print(f"디바이스: {device}")
    print(f"VOCAB_SIZE: {VOCAB_SIZE}")
    # NOTE: train_dl / val_dl must already be defined before this point.
    # e.g. train_dl, val_dl = ...
    results, best_config = run_random_search(train_dl, val_dl, n_trials=cfg.n_trials)

    # Example of generating a sample (kept commented out):
    # sample_and_save(model, start_token=BOS_ID, max_tokens=1024,
    #                 temperature=cfg.default_temperature, top_k=cfg.default_top_k,
    #                 top_p=cfg.default_top_p, no_repeat_ngram_size=cfg.default_no_repeat_ngram,
    #                 out_midi_path=f"{SAMPLES_DIR}/sample.mid", aux_for_detok=aux_dict)

In [None]:
# === 베스트 설정으로 단일 학습 → 베스트 가중치 저장/로드 → 샘플 생성/재생 (Colab-ready) ===
import os, json, time
from IPython.display import Audio, display
from midi2audio import FluidSynth
import torch
from torch.cuda.amp import GradScaler
import torch.nn as nn

# Paths (all located under the project folder on the mounted Drive)
PROJ = "/content/drive/MyDrive/Deep_Learning_project/original_token"
CKPT_DIR = f"{PROJ}/ckpt"
RESULTS_JSON = f"{PROJ}/results/random_search_results.json"
SAMPLES_DIR = f"{PROJ}/samples"


# File that receives the best-scoring weights of the single training run below.
BEST_CKPT_PATH = f"{CKPT_DIR}/best_model.pt"


# 0) Obtain best_config: prefer a variable that already exists in the session;
#    otherwise pick the entry with the lowest validation PPL from the results JSON.
try:
    best_config  # bare name lookup: raises NameError when no earlier cell defined it
except NameError:
    with open(RESULTS_JSON, "r") as f:
        results = json.load(f)
    best_row = min(results, key=lambda r: r["best_val_perplexity"])
    best_config = best_row["config"]
print("[best_config]", best_config)

# 1) Copy the selected hyper-parameters onto the global cfg
for _name in (
    "use_attention",
    "num_attention_heads",
    "lr",
    "weight_decay",
    "dropout_emb",
    "dropout_lstm",
    "dropout_attn",
    "dropout_ffn",
):
    # A missing key raises KeyError, exactly like direct indexing would.
    setattr(cfg, _name, best_config[_name])

# 2) Build the model plus optimizer, cosine LR schedule, AMP scaler and loss
best_model = build_model_from_cfg()
opt = torch.optim.AdamW(best_model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
steps_per_epoch = max(1, len(train_dl))
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, cfg.epochs * steps_per_epoch))
scaler = GradScaler(enabled=cfg.amp)
ce = nn.CrossEntropyLoss(ignore_index=cfg.pad_id, reduction="mean")

# 3) Single training run; only the weights with the lowest validation PPL are saved
os.makedirs(CKPT_DIR, exist_ok=True)  # torch.save does not create missing parent directories
best_val_ppl = float("inf")
saved_best_ckpt = False  # becomes True once THIS run has written BEST_CKPT_PATH
global_step = 0
for ep in range(1, cfg.epochs + 1):
    global_step = train_one_epoch(best_model, train_dl, opt, scheduler, scaler, ce, step0=global_step, val_dl=val_dl)
    nll, ppl = evaluate(best_model, val_dl)
    print(f"[single-train @epoch{ep}] nll={nll:.3f}  ppl={ppl:.2f}")
    if ppl < best_val_ppl:
        best_val_ppl = ppl
        torch.save(best_model.state_dict(), BEST_CKPT_PATH)
        saved_best_ckpt = True
        print(f"✔ 베스트 갱신: PPL={best_val_ppl:.2f}  → 저장: {BEST_CKPT_PATH}")

# 4) Reload the best weights. Skip the reload when this run never saved a
#    checkpoint (e.g. every epoch produced NaN perplexity or cfg.epochs == 0):
#    the file would then be missing or a stale leftover from an earlier run.
if saved_best_ckpt:
    best_model.load_state_dict(torch.load(BEST_CKPT_PATH, map_location=device))
else:
    print("⚠ 이번 학습에서 저장된 베스트 체크포인트가 없습니다. 마지막 에폭 가중치를 그대로 사용합니다.")
best_model.eval()
print("✅ 베스트 모델 준비 완료")

# 5) Generate tokens → save MIDI/WAV → play back (several samples)
import random

# Generation settings
N_SAMPLES = 3
temps = [0.9, 1.0, 1.1]              # varied for diversity (cycled when fewer than N_SAMPLES entries)
topks = [50, 64, 32]                 # top-k is varied slightly as well to change the character
program = 0                          # GM Acoustic Grand Piano
aux_base = {"step_sec": 0.5/64, "program": program}  # must match the training/tokenization settings — TODO confirm
sf2 = "/usr/share/sounds/sf2/FluidR3_GM.sf2"        # SoundFont path expected on Colab — verify the file exists

# Make sure the output folder for generated samples exists.
os.makedirs(SAMPLES_DIR, exist_ok=True)

def generate_one(idx, temperature, top_k, seed=None):
    """Sample one sequence from ``best_model``, save it as MIDI + WAV and play it.

    Args:
        idx: Zero-based sample index; used in the file name, in the progress
            line and to offset the default seed.
        temperature: Sampling temperature forwarded to ``sample_and_save``.
        top_k: Top-k cutoff forwarded to ``sample_and_save``.
        seed: Seed for ``torch`` and ``random``. When ``None`` it is derived
            from the current time and ``idx``.

    Side effects:
        Writes a ``.mid`` and a ``.wav`` file into ``SAMPLES_DIR``, displays an
        audio player and prints the settings and file paths.
    """
    if seed is None:
        seed = int(time.time()) + idx * 1234567
    torch.manual_seed(seed)
    random.seed(seed)

    # 5-1) Sample tokens
    toks = sample_and_save(
        best_model,
        start_token=BOS_ID,
        max_tokens=512,
        temperature=temperature,
        top_k=top_k,
        out_midi_path=None,                 # None: detokenization is done explicitly below
        aux_for_detok=aux_base
    )

    # 5-2) Build output paths
    tag = f"t{temperature:.1f}_k{top_k}_s{seed % 100000}"
    mid = os.path.join(SAMPLES_DIR, f"best_{int(time.time())}_{idx+1}_{tag}.mid")
    # Swap only the extension; str.replace would also rewrite any ".mid"
    # occurring earlier in the directory or file name.
    wav = os.path.splitext(mid)[0] + ".wav"

    # 5-3) Detokenize (BOS/PAD removed first)
    toks = [t for t in toks if t not in (PAD_ID, BOS_ID)]
    detokenize_to_midi_file(toks, aux_base, mid)

    # 5-4) Render to WAV and play
    FluidSynth(sf2).midi_to_audio(mid, wav)
    display(Audio(wav))
    print(f"[{idx+1}/{N_SAMPLES}] TEMP={temperature}, TOPK={top_k}, SEED={seed}")
    print("MIDI:", mid, "| WAV:", wav, "\n")

# === 생성 실행 ===
# Generate N_SAMPLES samples; the modulo cycles through temps/topks when a
# list has fewer entries than N_SAMPLES.
for i in range(N_SAMPLES):
    temperature = temps[i % len(temps)]
    top_k = topks[i % len(topks)]
    generate_one(i, temperature, top_k)

### 결과

In [None]:
# 5) Generate tokens → save MIDI/WAV → play back (minimal options)
# One aux dict shared by sampling and detokenization so the two cannot drift
# apart. Assumes TS_DIV=64; must match the training settings — TODO confirm.
aux_min = {"step_sec": 0.5/64, "program": 0}
toks = sample_and_save(
    best_model,
    start_token=BOS_ID,
    max_tokens=1024,
    temperature=1.0,
    top_k=50,
    out_midi_path=None,
    aux_for_detok=aux_min,
)

# Output paths
os.makedirs(SAMPLES_DIR, exist_ok=True)
mid = os.path.join(SAMPLES_DIR, f"best_{int(time.time())}.mid")
# Swap only the extension; str.replace would also rewrite any ".mid" that
# occurs earlier in the path.
wav = os.path.splitext(mid)[0] + ".wav"

# Detokenize (minimal cleanup: drop BOS/PAD)
toks = [t for t in toks if t not in (PAD_ID, BOS_ID)]
detokenize_to_midi_file(toks, aux_min, mid)

# Render to WAV and play (only the SoundFont path needs checking)
sf2 = "/usr/share/sounds/sf2/FluidR3_GM.sf2"  # path expected on Colab (install via apt if missing)
FluidSynth(sf2).midi_to_audio(mid, wav)

display(Audio(wav))
print("MIDI:", mid, "| WAV:", wav)