In [None]:
# === mozart만 6:2:2로 분리 + flattened_metadata_with_split.json 생성 ===
from pathlib import Path
import os, json, re, math, shutil, glob, random

# --- Paths --- # 🙂🙂 adjust to your own environment 🙂🙂
# Project folder on the mounted Google Drive.
PROJ = Path("/content/drive/MyDrive/Deep_Learning_project")

# Source metadata file (entries looked up by un-padded numeric id, e.g. "4").
META_JSON = PROJ / "metadata.json"

# Root folder that is scanned recursively for .mid files.
DATA_ROOT = PROJ / "original_token" / "mozart_midis"

# Output root; gets train/, validation/ and test/ sub-folders.
SPLIT_ROOT = PROJ / "mozart_dataset_TF"
SPLIT_ROOT.mkdir(parents=True, exist_ok=True)

# When True, existing split folders are wiped first so a re-run starts from scratch.
CLEAN_SPLIT = True
for sub in ["train", "validation", "test"]:
    if CLEAN_SPLIT:
        shutil.rmtree(SPLIT_ROOT / sub, ignore_errors=True)
    (SPLIT_ROOT / sub).mkdir(parents=True, exist_ok=True)

# --- 유틸 ---
def parse_filename(fp: Path):
    """Split a ``<6-digit id>_<take>.mid`` file name into ``(id, take)``.

    The id is returned without zero padding (``"000004"`` -> ``"4"``) because
    metadata.json is keyed that way; the take is returned verbatim as a string.
    Returns ``(None, None)`` when the name does not follow the pattern.
    """
    match = re.match(r"^(\d{6})_(\d+)\.mid$", fp.name)
    if match is None:
        return None, None
    raw_id, take_str = match.groups()
    return str(int(raw_id)), take_str

def pick_audio_score(meta_entry: dict, take: str):
    """Return the audio-quality score for ``take`` from a metadata entry.

    Falls back to the first score listed when ``take`` has none, and returns
    ``None`` when ``audio_scores`` is missing, empty, or not a dict.
    """
    scores = meta_entry.get("audio_scores", {})
    if not (isinstance(scores, dict) and scores):
        return None
    if take in scores:
        return scores[take]
    return next(iter(scores.values()), None)

def is_mozart(composer_val):  # 🙂🙂 change the composer name here 🙂🙂
    """Return True when the composer field mentions "mozart" (case-insensitive).

    A missing composer (``None``) never matches; other values are compared via ``str()``.
    """
    return composer_val is not None and "mozart" in str(composer_val).lower()  # 🙂🙂 change the composer name here 🙂🙂

# --- Load metadata ---
# Explicit check instead of `assert` (asserts are stripped under `python -O`).
if not META_JSON.exists():
    raise FileNotFoundError(f"metadata.json not found: {META_JSON}")
# JSON files are UTF-8 by spec; pin it so non-ASCII metadata text is not
# mis-decoded on platforms whose default encoding differs (e.g. cp949).
with open(META_JSON, "r", encoding="utf-8") as f:
    meta_raw = json.load(f)

# --- Scan data: every .mid file under DATA_ROOT ---
# Sorted so the list order (and therefore the seeded shuffle in the split step)
# does not depend on the filesystem's directory-listing order, which glob does
# not guarantee; otherwise SEED=42 would not reproduce the same split elsewhere.
all_mid_paths = sorted(Path(p) for p in glob.glob(str(DATA_ROOT / "**" / "*.mid"), recursive=True))

# --- Keep Mozart only ---  # 🙂🙂 change the composer name here 🙂🙂
beet_items = []  # (path, id, take, metadata entry) for the selected composer
skipped_no_meta = 0
skipped_bad_name = 0

for fp in all_mid_paths:
    id_str, take = parse_filename(fp)
    if not id_str:
        skipped_bad_name += 1
        continue
    entry = meta_raw.get(id_str)
    if not entry:
        skipped_no_meta += 1
        continue

    md = entry.get("metadata", {})
    if not is_mozart(md.get("composer")):  # 🙂🙂 change the composer name here 🙂🙂
        continue

    beet_items.append((fp, id_str, take, entry))

print(f"총 MIDI: {len(all_mid_paths)}개")
print(f"mozart 후보: {len(beet_items)}개")  # 🙂🙂 change the composer name here 🙂🙂
print(f"메타 없음으로 스킵: {skipped_no_meta}개, 파일명 규칙 불일치 스킵: {skipped_bad_name}개")

# --- 6:2:2 split ---
# Seeded shuffle, then the first 60% -> train, next 20% -> validation, rest -> test.
SEED = 42
random.Random(SEED).shuffle(beet_items)

N = len(beet_items)
n_train = math.floor(N * 0.6)
n_val   = math.floor(N * 0.2)
n_test  = N - n_train - n_val  # remainder, so the three counts always sum to N

# One (split name, nominal ratio) label per shuffled item, in item order.
splits = [
    (name, ratio)
    for name, ratio, count in (
        ("train", 0.6, n_train),
        ("validation", 0.2, n_val),
        ("test", 0.2, n_test),
    )
    for _ in range(count)
]

# --- Copy files into their split folder & build the flattened metadata ---
flat_meta = {}  # key: file name, value: flattened metadata dict
# Per optional field: how many selected files lack it (reported after saving).
missing_optionals = {"music_period": 0, "difficulty": 0, "genre": 0, "opus": 0}

# `splits` has exactly N = len(beet_items) labels, so zip pairs every file with its split.
for (item, (split_name, split_ratio)) in zip(beet_items, splits):
    fp, id_str, take, entry = item
    md = entry.get("metadata", {})

    basename = fp.name
    # Score for this take; pick_audio_score falls back to another take's score.
    audio_score = pick_audio_score(entry, take)

    music_period = md.get("music_period")
    difficulty   = md.get("difficulty")
    genre        = md.get("genre")
    opus         = md.get("opus")

    if music_period is None: missing_optionals["music_period"] += 1
    if difficulty   is None: missing_optionals["difficulty"]   += 1
    if genre        is None: missing_optionals["genre"]        += 1
    if opus         is None: missing_optionals["opus"]         += 1

    dst = SPLIT_ROOT / split_name / basename
    # NOTE(review): DATA_ROOT is scanned recursively, so two files with the same
    # basename in different sub-folders would collide here (and in flat_meta).
    shutil.copy2(fp, dst)  # overwrites an existing file with the same name

    flat_meta[basename] = {
        "file_path": basename,       # e.g. 000004_0.mid (relative to SPLIT_ROOT/<split>/)
        "split": split_name,         # train / validation / test
        "composer": md.get("composer"),
        "music_period": music_period,
        "difficulty": difficulty,
        "genre": genre,
        "audio_score": audio_score,
        "opus": opus,
        "split_ratio": split_ratio,
    }

# --- Save JSON/CSV into the writable SPLIT_ROOT ---
OUT_JSON = SPLIT_ROOT / "flattened_metadata_with_split.json"
OUT_CSV  = SPLIT_ROOT / "flattened_metadata_with_split.csv"

# ensure_ascii=False keeps non-ASCII text human-readable; the file is written in
# the platform default encoding. NOTE(review): if this must run on non-UTF-8
# platforms, pin encoding="utf-8" here AND where the file is read back, together.
with open(OUT_JSON, "w") as f:
    json.dump(flat_meta, f, ensure_ascii=False, indent=2)

# Also save a CSV for convenience: one row per file, same fields as the JSON values.
import pandas as pd
pd.DataFrame.from_dict(flat_meta, orient="index").reset_index(drop=True).to_csv(OUT_CSV, index=False)

print("\n=== 분할 결과 ===")
print(f"train: {n_train}, validation: {n_val}, test: {n_test}")
print(f"저장(JSON): {OUT_JSON}")
print(f"저장(CSV):  {OUT_CSV}")
print(f"출력 폴더: {SPLIT_ROOT} (train/ validation/ test)")
# Only optional fields that were missing at least once are reported.
print("\n(참고) optional 필드 결측 개수 →", {k:v for k,v in missing_optionals.items() if v>0})

총 MIDI: 536개
mozart 후보: 536개
메타 없음으로 스킵: 0개, 파일명 규칙 불일치 스킵: 0개

=== 분할 결과 ===
train: 321, validation: 107, test: 108
저장(JSON): /content/drive/MyDrive/Deep_Learning_project/mozart_dataset_TF/flattened_metadata_with_split.json
저장(CSV):  /content/drive/MyDrive/Deep_Learning_project/mozart_dataset_TF/flattened_metadata_with_split.csv
출력 폴더: /content/drive/MyDrive/Deep_Learning_project/mozart_dataset_TF (train/ validation/ test)

(참고) optional 필드 결측 개수 → {'music_period': 37, 'difficulty': 393, 'opus': 27}


In [None]:
# [1] Runtime check (MPS + CPU)
import torch, platform, sys, os, subprocess, textwrap, random, numpy as np

print("Python:", sys.version)
print("MPS available:", torch.backends.mps.is_available())

# Apple-silicon GPU when available, otherwise CPU (this cell never selects CUDA).
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print("Using device:", device)

# Reproducibility: seed Python, NumPy and torch with the same value, in that order.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if device.type != "mps":
    # Only relevant if CUDA turns out to be present (fallback case).
    torch.cuda.manual_seed_all(seed)

Python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
MPS available: False
Using device: cpu


In [None]:
# Install the third-party packages used by later cells (pretty_midi, einops,
# music21, plus mido / pyfluidsynth) into the interpreter running this kernel.
!{sys.executable} -m pip install pretty_midi mido einops pyfluidsynth music21


Collecting pretty_midi
  Downloading pretty_midi-0.2.10.tar.gz (5.6 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.6 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━[0m [32m3.1/5.6 MB[0m [31m168.6 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━[0m [32m4.2/5.6 MB[0m [31m64.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━[0m [32m5.4/5.6 MB[0m [31m54.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.6/5.6 MB[0m [31m43.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido
  Downloading mido-1.3.3-py3-none-any.whl.metadata (6.4 kB)
Collecting pyfluidsyn

In [None]:
from pathlib import Path
import json, random, math, os
import numpy as np
import pretty_midi as pm
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from einops import rearrange

# Training device: CUDA when present, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Flattened per-file metadata (split, composer, ...) written by the split cell.
SPLIT_META_JSON = SPLIT_ROOT / "flattened_metadata_with_split.json"
assert SPLIT_META_JSON.exists(), f"split 메타가 없습니다: {SPLIT_META_JSON}"
with SPLIT_META_JSON.open("r") as f:
    META = json.load(f)

# Tokenizer artifacts: vocabulary file and per-file token cache directory.
VOCAB_JSON    = SPLIT_ROOT / "vocab.json"
TOK_CACHE_DIR = SPLIT_ROOT / "tok_cache"
TOK_CACHE_DIR.mkdir(parents=True, exist_ok=True)

### 토큰화

In [None]:
from music21 import converter

import functools

# Pitch-class names: MIDI pitch class (0-11) -> 'C', 'C#', ..., 'B'
PC2NAME = ['C','C#','D','Eb','E','F','F#','G','Ab','A','Bb','B']


@functools.lru_cache(maxsize=4096)
def detect_key_with_music21(midi_path: Path):
    """Estimate the global key of a MIDI file with music21 and return a KEY label.

    Returns e.g. ``"Cmaj"``, ``"Amin"``, ``"Ebmaj"``: music21 flats such as
    ``"E-"`` are rewritten with ``b``. Returns ``None`` when parsing or analysis
    fails (deliberate best effort: any exception is swallowed).

    Memoized per path: the tokenization pipeline asks for the key of the same
    file more than once (conditioning tokens and the REMI stream), and the
    music21 parse is by far the most expensive step. The cache does not notice
    if a file changes on disk during the session.
    """
    try:
        key = converter.parse(str(midi_path)).analyze('key')
        tonic = key.tonic.name.replace('-', 'b')  # E- -> Eb
        mode = 'maj' if key.mode.lower().startswith('maj') else 'min'
        return f"{tonic}{mode}"  # e.g. Cmaj, Amin, Ebmaj
    except Exception:
        return None

In [None]:
import numpy as np

# Chord templates: pitch-class sets relative to a root of 0.
# Dict order matters: on equal scores the earlier template wins.
CHORD_TEMPLATES = {
    'maj'       : {0,4,7},
    'min'       : {0,3,7},
    'dim'       : {0,3,6},
    'aug'       : {0,4,8},
    'dom7'      : {0,4,7,10},
    'maj7'      : {0,4,7,11},
    'min7'      : {0,3,7,10},
    'halfdim7'  : {0,3,6,10},
    'dim7'      : {0,3,6,9},
}

def best_chord_label(pitches):
    """Label a group of MIDI pitches with the best-matching chord.

    Every (root, quality) pair over the 12 roots and CHORD_TEMPLATES is scored
    as coverage (matched pitch classes / template size) minus 0.15 per pitch
    class outside the template. Candidates matching fewer than two pitch
    classes are ignored, and ties keep the first candidate in (root, template)
    iteration order.

    Returns a label such as ``"C:maj"``, ``"A:min"`` or ``"G:dom7"``, or
    ``None`` when ``pitches`` is empty or nothing matches at least two notes.
    """
    if not pitches:
        return None
    pitch_classes = {p % 12 for p in pitches}
    if not pitch_classes:
        return None

    def _candidates():
        for root in range(12):
            rel = {(pc - root) % 12 for pc in pitch_classes}
            for quality, template in CHORD_TEMPLATES.items():
                matched = len(rel & template)
                if matched < 2:  # need at least two matching tones to call it a chord
                    continue
                cover = matched / max(1, len(template))
                extra_penalty = -0.15 * max(0, len(rel - template))
                yield cover + extra_penalty, root, quality

    # max() keeps the first of several equal maxima, matching "first wins".
    best = max(_candidates(), key=lambda cand: cand[0], default=None)
    if best is None:
        return None
    _, root, quality = best
    return f"{PC2NAME[root]}:{quality}"


In [None]:
TS_DIV = 16                 # grid steps per bar (POS_0 .. POS_15)
MIN_VEL, MAX_VEL = 20, 100  # velocities are clipped into this range
MIN_DUR, MAX_DUR = 1, 16    # durations (in grid steps) are clipped into this range

def quantize_time(pm_obj, ts_div=TS_DIV):
    """Quantize every note of a PrettyMIDI object onto a bar / position grid.

    Only the FIRST time signature (default 4/4) and the FIRST tempo (default
    120 BPM) are used; later changes are ignored.

    Returns:
        notes: list of (bar, pos, pitch, dur, vel) sorted by (bar, pos, pitch);
            pos is clamped to [0, ts_div-1], dur to [MIN_DUR, MAX_DUR], vel to
            [MIN_VEL, MAX_VEL].
        (num, den): the time signature used.
        tempo: the tempo rounded to the nearest integer BPM.
        bar_sec: bar length in seconds.
    """
    ts = pm_obj.time_signature_changes or [pm.TimeSignature(4, 4, 0.0)]
    _, tempi = pm_obj.get_tempo_changes()
    tempo = float(tempi[0]) if len(tempi) else 120.0
    num, den = ts[0].numerator, ts[0].denominator
    beat_len = 60.0 / tempo
    bar_sec = (4 / den) * num * beat_len

    notes = []
    for inst in pm_obj.instruments:
        for n in inst.notes:
            bar = int(n.start // bar_sec)
            pos = int(((n.start - bar * bar_sec) / bar_sec) * ts_div)
            pos = max(0, min(ts_div - 1, pos))
            dur = int(((n.end - n.start) / bar_sec) * ts_div)
            dur = max(MIN_DUR, min(MAX_DUR, dur))
            vel = int(np.clip(n.velocity, MIN_VEL, MAX_VEL))
            notes.append((bar, pos, n.pitch, dur, vel))
    notes.sort(key=lambda x: (x[0], x[1], x[2]))
    # MIDI stores tempo as integer microseconds per beat, so e.g. 110 BPM reads
    # back as 109.9999...; int() would truncate that to 109. Round instead.
    return notes, (num, den), int(round(tempo)), bar_sec

def encode_remi_harmony(midi_path: Path, add_chords=True, chord_every='pos'):
    """Tokenize a MIDI file into a REMI-style stream with optional chord tokens.

    Output layout:
        TSig_<num>_<den>, TEMPO_<bpm>, [KEY_<key> if music21 found one], then per
        note (in quantize_time order): enough BAR tokens to reach the note's bar
        (empty bars included), POS_<pos> (emitted for every note, even when
        the position repeats), an optional CHORD_<root>:<quality>, then
        NOTE_ON_<pitch>, DUR_<steps>, VEL_<velocity>.

    Args:
        midi_path: path of the MIDI file.
        add_chords: emit CHORD_ tokens. A chord is emitted only when its label
            differs from the last chord emitted in the current bar.
        chord_every: how notes are grouped for chord labeling —
            'bar'  → one group per bar,
            'beat' → per approximate beat (TS_DIV // num grid steps),
            'pos'  → per onset position (bar, pos).
            NOTE(review): any other value groups by 'pos' but looks up by beat
            index, so labels would be computed from the wrong groups.

    Returns:
        list of token strings.
    """
    pm_obj = pm.PrettyMIDI(str(midi_path))
    notes,(num,den),tempo,bar_sec = quantize_time(pm_obj)  # bar_sec is not used below

    # 0) KEY (major/minor) token
    key_token = detect_key_with_music21(midi_path)  # e.g. 'Cmaj', 'Amin', None

    toks = []
    toks.append(f"TSig_{num}_{den}")
    toks.append(f"TEMPO_{tempo}")
    if key_token:
        toks.append(f"KEY_{key_token}")  # e.g. KEY_Cmaj

    # 1) Group note pitches for chord labeling
    #    - chord_every == 'bar' : all note onsets within the same bar
    #    - chord_every == 'pos' : note onsets per (bar, pos)
    #    - chord_every == 'beat': approximate beat boundaries inside the bar (num beats)
    from collections import defaultdict
    onset_map = defaultdict(list)  # key: (bar, pos or beat index) → pitches

    if chord_every == 'bar':
        for (bar,pos,pitch,dur,vel) in notes:
            onset_map[(bar, -1)].append(pitch)

    elif chord_every == 'beat':
        # Approximate pos → beat index mapping (TS_DIV divided by num)
        step_per_beat = max(1, TS_DIV // num)
        for (bar,pos,pitch,dur,vel) in notes:
            beat_idx = pos // step_per_beat
            onset_map[(bar, beat_idx)].append(pitch)

    else:  # 'pos'
        for (bar,pos,pitch,dur,vel) in notes:
            onset_map[(bar, pos)].append(pitch)

    # 2) Emit the token sequence
    cur_bar = -1
    last_chord = None
    for (bar,pos,pitch,dur,vel) in notes:
        # BAR tokens (one per bar advanced, so empty bars still get a BAR)
        while cur_bar < bar:
            toks.append("BAR")
            cur_bar += 1
            last_chord = None  # re-decide the chord in each new bar

        # POS token
        toks.append(f"POS_{pos}")

        # 2-a) Chord token (optional); the key must mirror the grouping above
        if add_chords:
            key = (bar, -1) if chord_every=='bar' else ((bar, pos) if chord_every=='pos' else (bar, pos // max(1, TS_DIV // num)))
            chord_label = best_chord_label(onset_map.get(key, []))
            if chord_label and chord_label != last_chord:
                toks.append(f"CHORD_{chord_label}")   # e.g. CHORD_C:maj, CHORD_A:min
                last_chord = chord_label

        # 2-b) Note event
        toks += [f"NOTE_ON_{pitch}", f"DUR_{dur}", f"VEL_{vel}"]

    return toks


In [None]:
def _composer_token_name(raw):
    """Normalize a free-text composer field to one capitalized word for COMPOSER_.

    Text before the first comma is used ("Mozart, W. A." -> "Mozart"), then the
    last whitespace-separated word is taken ("Wolfgang Amadeus Mozart" -> "Mozart",
    "mozart" -> "Mozart"). A missing or blank value yields "Unknown".
    NOTE(review): assumes the surname is last (or first before a comma).
    """
    if raw is None:
        return "Unknown"
    words = str(raw).split(",")[0].split()
    return words[-1].capitalize() if words else "Unknown"


def cond_tokens(meta, midi_path_for_key: Path = None, *, composer=None):
    """Build the conditioning-token prefix for one piece.

    The composer token was previously hard-coded to COMPOSER_Beethoven for every
    piece, so prompts such as COMPOSER_Mozart mapped to [UNK]. It is now taken
    from the metadata. Delete any vocab/token caches built with the old token.

    Args:
        meta: flattened metadata dict for one file (keys read: "composer",
            "music_period", "genre", "difficulty", "opus", "audio_score").
        midi_path_for_key: if given, the key is estimated with
            ``detect_key_with_music21`` and appended as ``KEY_<key>``.
        composer: explicit name for the ``COMPOSER_`` token; when omitted it is
            derived from ``meta["composer"]`` (see ``_composer_token_name``).

    Returns:
        list of tokens, e.g. ["COMPOSER_Mozart", "PERIOD_Classical", "QUALITY_High", "KEY_Cmaj"].
    """
    name = composer if composer else _composer_token_name(meta.get("composer"))
    t = [f"COMPOSER_{name}"]
    if meta.get("music_period"): t.append(f"PERIOD_{meta['music_period']}")
    if meta.get("genre"):        t.append(f"GENRE_{meta['genre']}")
    if meta.get("difficulty"):   t.append(f"DIFF_{meta['difficulty']}")
    if meta.get("opus"):         t.append(f"OPUS_{str(meta['opus']).replace(' ','_')}")
    q = meta.get("audio_score")
    if q is not None:
        # Bucket the audio-quality score: >=0.8 High, >=0.5 Med, else Low.
        t.append(f"QUALITY_{'High' if q >= 0.8 else 'Med' if q >= 0.5 else 'Low'}")
    # KEY (estimated from the MIDI file)
    if midi_path_for_key is not None:
        k = detect_key_with_music21(midi_path_for_key)
        if k:
            t.append(f"KEY_{k}")  # KEY_Cmaj / KEY_Amin ...
    return t

In [None]:
def build_vocab():
    """Return the token -> id vocabulary, building it and the token cache on first use.

    If VOCAB_JSON already exists it is loaded and returned as-is. Delete it (and
    TOK_CACHE_DIR) after changing the tokenizer, otherwise the old tokens are reused.
    Otherwise every META entry whose file exists under SPLIT_ROOT/<split>/ is
    tokenized as [BOS] + cond_tokens + REMI/harmony tokens + [EOS], the token
    list is cached as TOK_CACHE_DIR/<file name>.json, and ids are assigned in
    order of first appearance after the four special tokens.
    """
    if VOCAB_JSON.exists():
        with open(VOCAB_JSON) as f:
            return json.load(f)
    vocab = {"[PAD]": 0, "[BOS]": 1, "[EOS]": 2, "[UNK]": 3}
    for fname, meta in META.items():
        p = SPLIT_ROOT / meta["split"] / meta["file_path"]
        if not p.exists():
            continue
        toks = (["[BOS]"]
                + cond_tokens(meta, midi_path_for_key=p)
                + encode_remi_harmony(p, add_chords=True, chord_every='pos')
                + ["[EOS]"])
        # `with` closes (and flushes) each file deterministically instead of
        # leaving hundreds of handles to the garbage collector.
        with open(TOK_CACHE_DIR / (fname + ".json"), "w") as f:
            json.dump(toks, f)
        for t in toks:
            vocab.setdefault(t, len(vocab))  # next free id; ids stay contiguous
    with open(VOCAB_JSON, "w") as f:
        json.dump(vocab, f)
    return vocab

VOCAB = build_vocab()
IVOCAB = dict(zip(VOCAB.values(), VOCAB.keys()))  # id -> token (inverse lookup)
VOCAB_SIZE = len(VOCAB)
print("VOCAB_SIZE:", VOCAB_SIZE)

VOCAB_SIZE: 471


In [None]:
SEQ_LEN=512  # window length; each sample yields SEQ_LEN-1 (input, target) pairs

def toks_to_ids(toks):
    """Map tokens to vocabulary ids; tokens missing from VOCAB become [UNK]."""
    unk_id = VOCAB["[UNK]"]
    return [VOCAB.get(t, unk_id) for t in toks]

class BeethovenMIDIDataset(Dataset):
    """Next-token-prediction windows over the tokenized pieces of one split.

    Each item is ``(x, y)``: two LongTensors of length SEQ_LEN-1, with y shifted
    one step ahead of x. Sequences shorter than SEQ_LEN are right-padded with
    [PAD]. Longer ones are cut to a SEQ_LEN window.

    Args:
        split: "train" / "validation" / "test" (matched against META[...]["split"]).
        random_crop: where the window of a long piece starts. True draws a new
            random offset on every access (augmentation). False uses a fixed
            per-file offset, so evaluation sees the same windows every epoch and
            checkpoint selection is not driven by crop noise. Default: True only
            for "train".
    """
    def __init__(self, split, random_crop=None):
        self.items = [(k, v) for k, v in META.items() if v["split"] == split]
        random.Random(42).shuffle(self.items)
        self.random_crop = (split == "train") if random_crop is None else bool(random_crop)

    def __len__(self):
        return len(self.items)

    def _tokens(self, fname, meta):
        """Load cached tokens, or tokenize exactly like build_vocab on a cache miss."""
        cache = TOK_CACHE_DIR / (fname + ".json")
        if cache.exists():
            with open(cache) as f:
                return json.load(f)
        p = SPLIT_ROOT / meta["split"] / meta["file_path"]
        # Must mirror build_vocab so a cache miss yields the same token layout
        # the vocabulary was built from.
        return (["[BOS]"] + cond_tokens(meta, midi_path_for_key=p)
                + encode_remi_harmony(p, add_chords=True, chord_every='pos') + ["[EOS]"])

    def __getitem__(self, i):
        fname, meta = self.items[i]
        ids = toks_to_ids(self._tokens(fname, meta))
        if len(ids) >= SEQ_LEN:
            max_start = len(ids) - SEQ_LEN
            if self.random_crop:
                st = random.randint(0, max_start)
            else:
                # Seeding Random with a str is stable across runs (unlike hash()).
                st = random.Random(fname).randint(0, max_start)
            seq = ids[st:st + SEQ_LEN]
        else:
            seq = ids + [VOCAB["[PAD]"]] * (SEQ_LEN - len(ids))
        x = torch.tensor(seq[:-1], dtype=torch.long)
        y = torch.tensor(seq[1:], dtype=torch.long)
        return x, y

# Training batches: reshuffled every epoch; the ragged last batch is dropped.
train_dl = DataLoader(
    BeethovenMIDIDataset("train"),
    batch_size=16,
    shuffle=True,
    drop_last=True,
)
# Validation batches: fixed order, every item kept.
val_dl = DataLoader(
    BeethovenMIDIDataset("validation"),
    batch_size=16,
    shuffle=False,
    drop_last=False,
)


In [None]:
import math, torch, torch.nn as nn
from einops import rearrange

# ===== 1) RoPE (Rotary Positional Embedding) =====
def apply_rope(q,k):
    """Apply rotary position embeddings to query and key tensors.

    q, k: 4-D tensors (batch, heads, seq, head_dim) with even head_dim.
    Positions run 0..seq-1. Each channel pair (2i, 2i+1) is rotated by the angle
    pos * 10000^(-2i/head_dim), so q·k depends only on relative offsets.
    Returns the rotated (q, k) with unchanged shapes.
    """
    _, _, seq_len, head_dim = q.shape
    dev = q.device
    positions = torch.arange(seq_len, device=dev).float()
    freqs = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, device=dev).float() / head_dim))
    angles = positions[:, None] * freqs[None, :]  # (seq, head_dim/2) outer product
    cos = angles.cos()[None, None]
    sin = angles.sin()[None, None]

    def _rotate(t):
        even, odd = t[..., 0::2], t[..., 1::2]
        rotated = torch.stack((even * cos - odd * sin, even * sin + odd * cos), dim=-1)
        return rotated.flatten(-2)  # re-interleave the pairs

    return _rotate(q), _rotate(k)

# ===== 2) Self-Attention =====
class CausalSelfAttn(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Input and output: (batch, seq, d_model). Future positions are masked out and
    the attention weights are dropped out with probability ``p``.
    """
    def __init__(self,d_model=512,n_head=8,p=0.1):
        super().__init__()
        assert d_model % n_head == 0
        self.nh = n_head
        self.dk = d_model // n_head
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(p)

    def forward(self,x):
        _, seq_len, _ = x.shape
        split_heads = 'b t (h d)->b h t d'
        q, k, v = (rearrange(part, split_heads, h=self.nh) for part in self.qkv(x).chunk(3, -1))
        q, k = apply_rope(q, k)
        scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.dk)
        future = torch.ones(seq_len, seq_len, device=x.device).triu(1).bool()
        weights = self.drop(scores.masked_fill(future, float('-inf')).softmax(-1))
        merged = rearrange(weights @ v, 'b h t d->b t (h d)')
        return self.proj(merged)

# ===== 3) Transformer Block =====
class Block(nn.Module):
    """Pre-LayerNorm transformer block: x + Attn(LN(x)), then x + MLP(LN(x)).

    The MLP widens to ``d * mlp`` with GELU. Dropout ``p`` is applied inside the
    MLP and on both residual branches.
    """
    def __init__(self,d=512,h=8,p=0.1,mlp=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(d)
        self.att = CausalSelfAttn(d, h, p)
        self.ln2 = nn.LayerNorm(d)
        hidden = d * mlp
        self.mlp = nn.Sequential(
            nn.Linear(d, hidden),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden, d),
        )
        self.drop = nn.Dropout(p)

    def forward(self,x):
        attn_out = self.att(self.ln1(x))
        x = x + self.drop(attn_out)
        mlp_out = self.mlp(self.ln2(x))
        return x + self.drop(mlp_out)

# ===== 4) MiniGPT Model =====
class MiniGPT(nn.Module):
    """Decoder-only language model over the token vocabulary.

    Token embedding plus a learned absolute position table (SEQ_LEN-1 rows),
    combined with RoPE inside every attention layer. Then ``L`` blocks, a final
    LayerNorm and a bias-free linear head.
    Input: (batch, seq) ids with seq <= SEQ_LEN-1. Output: (batch, seq, vocab_size) logits.
    """
    def __init__(self,vocab_size, d=512, L=8, H=8, p=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, d)
        self.pos = nn.Parameter(torch.zeros(1, SEQ_LEN - 1, d))  # learned absolute positions (used alongside RoPE)
        self.blocks = nn.ModuleList(Block(d, H, p) for _ in range(L))
        self.ln = nn.LayerNorm(d)
        self.head = nn.Linear(d, vocab_size, bias=False)

    def forward(self,idx):
        seq_len = idx.size(1)
        h = self.emb(idx) + self.pos[:, :seq_len, :]
        for block in self.blocks:
            h = block(h)
        return self.head(self.ln(h))

# ===== 5) Model initialization =====
# NOTE: the training cells below re-create `opt` and `criterion` with other
# settings (weight_decay / label_smoothing); only `model` carries over from here.
model = MiniGPT(VOCAB_SIZE).to(DEVICE)
opt   = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9,0.95), weight_decay=0.1)
criterion = nn.CrossEntropyLoss(ignore_index=VOCAB["[PAD]"], label_smoothing=0.1)


In [None]:
import pretty_midi as pm
from pathlib import Path

def detokenize_to_midi(tokens, out_path, ts_div=16, default_time_sig=(4,4), default_tempo=120,
                       program=0, safe_pitch=(36,96)):
    """Rebuild a single-instrument MIDI file from a REMI-style token sequence.

    Example tokens: ["TSig_4_4", "TEMPO_112", "BAR", "POS_0", "NOTE_ON_60", "DUR_4", "VEL_80", ...]

    - TSig_<num>_<den> / TEMPO_<bpm> update the grid for the notes that follow;
      malformed or non-positive values are ignored.
    - KEY_* and CHORD_* tokens are skipped (not rendered).
    - BAR moves to the next bar and resets the position. POS_<p> sets the
      position inside the bar, clamped to [0, ts_div-1].
    - NOTE_ON_<pitch>, optionally followed by DUR_<steps> (default 4) and
      VEL_<v> / VELOCITY_<v> (default 80), emits a note. Pitch is clamped into
      ``safe_pitch``, velocity into 1..127, and the length is at least one grid step.
    - A note before any BAR implicitly opens bar 0 instead of landing at a negative time.
    - Unknown tokens are ignored.

    ``ts_div`` must equal the TS_DIV used at tokenization time (default 16).
    Returns ``out_path`` as a Path. Parent folders are created as needed.
    """
    out_path = Path(out_path)

    # Current grid; TSig_/TEMPO_ tokens update it for subsequent notes.
    num, den = default_time_sig
    tempo = default_tempo
    bar_sec = (4 / den) * num * (60.0 / tempo)

    cur_bar = -1  # no BAR seen yet
    pos = 0

    pm_out = pm.PrettyMIDI()
    inst = pm.Instrument(program=program)
    safe_low, safe_high = safe_pitch

    i = 0
    n_tok = len(tokens)
    while i < n_tok:
        t = tokens[i]

        # --- meta tokens ---
        if t.startswith("TSig_"):
            try:
                _, a, b = t.split("_")
                new_num, new_den = int(a), int(b)
            except ValueError:
                new_num = new_den = 0
            if new_num > 0 and new_den > 0:  # commit only a valid signature
                num, den = new_num, new_den
                bar_sec = (4 / den) * num * (60.0 / tempo)
            i += 1
            continue

        if t.startswith("TEMPO_"):
            try:
                new_tempo = int(t.split("_")[1])
            except ValueError:
                new_tempo = 0
            if new_tempo > 0:  # a zero tempo would divide by zero
                tempo = new_tempo
                bar_sec = (4 / den) * num * (60.0 / tempo)
            i += 1
            continue

        if t.startswith(("KEY_", "CHORD_")):
            i += 1
            continue  # key/chord annotations are not rendered into MIDI

        # --- structure tokens ---
        if t == "BAR":
            cur_bar += 1
            pos = 0
            i += 1
            continue

        if t.startswith("POS_"):
            try:
                pos = int(t.split("_")[1])
            except ValueError:
                pass  # keep the previous position
            pos = max(0, min(ts_div - 1, pos))
            i += 1
            continue

        # --- note events ---
        if t.startswith("NOTE_ON_"):
            try:
                pitch = int(t.split("_")[2])
            except ValueError:
                i += 1
                continue
            i += 1

            dur, vel = 4, 80  # defaults when DUR_/VEL_ do not follow
            if i < n_tok and tokens[i].startswith("DUR_"):
                try:
                    dur = int(tokens[i].split("_")[1])
                except ValueError:
                    pass
                i += 1
            if i < n_tok and tokens[i].startswith(("VEL_", "VELOCITY_")):
                try:
                    vel = int(tokens[i].split("_")[1])
                except ValueError:
                    pass
                i += 1

            if cur_bar < 0:
                cur_bar = 0  # note before any BAR: open bar 0 rather than use negative time
            pitch = max(safe_low, min(safe_high, pitch))
            vel = max(1, min(127, vel))  # MIDI range; velocity 0 would act as note-off
            start = cur_bar * bar_sec + (pos / ts_div) * bar_sec
            end = start + (dur / ts_div) * bar_sec
            if end <= start:
                end = start + (1 / ts_div) * bar_sec  # at least one grid step

            inst.notes.append(pm.Note(velocity=vel, pitch=pitch, start=start, end=end))
            continue

        i += 1  # unknown token: ignore

    pm_out.instruments.append(inst)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    pm_out.write(str(out_path))
    return out_path


In [None]:
# ====== 0) 준비 ======
import math, torch
import torch.nn as nn
from torch.amp import GradScaler, autocast  # ✅ AMP 최신 API

VOCAB_SIZE = len(VOCAB)  # logits per position; used to flatten logits for the loss

# ====== EMA utility ======
class EMA:
    """Exponential moving average of a model's state_dict (parameters and buffers).

    Call ``update`` after every optimizer step. The effective decay is warmed up
    as ``min(decay, (1 + n) / (10 + n))``, where n is the number of updates so
    far. Without this, the average stays anchored to the random initialization
    for roughly 1/(1-decay) steps (about 1000 for decay=0.999), which can exceed
    the total number of optimizer steps of a short run. Non-floating tensors
    (e.g. integer counters) are copied, not averaged.

    ``store`` / ``copy_to`` / ``restore`` temporarily swap the EMA weights into
    the model (e.g. for validation) and put the live weights back afterwards.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.num_updates = 0
        self.shadow = {k: p.detach().clone() for k, p in model.state_dict().items()}
        self.backup = None

    @torch.no_grad()
    def update(self, model):
        d = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        self.num_updates += 1
        for k, p in model.state_dict().items():
            if self.shadow[k].is_floating_point():
                self.shadow[k].mul_(d).add_(p.detach(), alpha=1 - d)
            else:
                self.shadow[k].copy_(p)  # in-place float math on int buffers would fail

    def store(self, model):
        """Snapshot the live weights so ``restore`` can put them back."""
        self.backup = {k: p.detach().clone() for k, p in model.state_dict().items()}

    def copy_to(self, model):
        """Load the EMA weights into ``model``."""
        model.load_state_dict(self.shadow, strict=False)

    def restore(self, model):
        """Reload the weights saved by ``store`` (no-op if nothing is stored)."""
        if self.backup is not None:
            model.load_state_dict(self.backup, strict=False)
            self.backup = None

ema = EMA(model, decay=0.999)

# ====== 1) Loss, optimizer, scheduler, checkpoint ======
criterion = nn.CrossEntropyLoss(
    ignore_index=VOCAB["[PAD]"],
    label_smoothing=0.05,    # lowered to 0.05
)

opt = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.03,       # weaker regularization
)

# Gradient accumulation: one optimizer step per ACC micro-batches (effective batch x2).
# Defined before the schedule because the schedule is measured in optimizer steps.
ACC = 2

EPOCHS = 50
# sched.step() runs once per optimizer step (including the end-of-epoch flush
# in run_epoch), so the horizon is ceil(batches / ACC) per epoch, not len(train_dl).
steps_per_epoch = max(1, math.ceil(len(train_dl) / ACC))
total_steps  = steps_per_epoch * EPOCHS
# Cap warm-up at 10% of training so the LR actually reaches its peak; a fixed
# 1000-step warm-up can be longer than the whole run.
warmup_steps = min(1000, max(1, total_steps // 10))

def lr_lambda(step):
    """LR multiplier: linear warm-up, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    return 0.5 * (1 + math.cos(math.pi * progress))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

# Mixed precision only on CUDA. On CPU, autocast would silently run in bfloat16
# and loss scaling is pointless, so both are disabled there.
use_cuda_amp = (DEVICE.type == 'cuda')
amp_device = 'cuda' if use_cuda_amp else 'cpu'
scaler = GradScaler(amp_device, enabled=use_cuda_amp)

# ====== 2) Train / eval functions (ACC=2) ======
def _optimizer_step(grad_clip):
    """Apply accumulated gradients: unscale, clip, step, advance LR, reset, update EMA."""
    # Gradients are still multiplied by the loss scale until unscale_, so clip after it.
    scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(opt)
    scaler.update()
    sched.step()
    opt.zero_grad(set_to_none=True)
    ema.update(model)  # EMA update after the optimizer step

def run_epoch(dl, train=True, grad_clip=1.0):
    """Run one pass over ``dl`` and return the mean per-batch loss.

    train=True accumulates gradients over ACC batches (the reported loss is
    un-divided by ACC). train=False computes the loss without gradients.
    """
    model.train(train)
    total, n = 0.0, 0
    pending = False  # gradients accumulated but not yet applied
    if train:
        opt.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(dl, 1):
        x, y = x.to(DEVICE), y.to(DEVICE)
        with torch.set_grad_enabled(train), autocast(amp_device, enabled=use_cuda_amp):
            logits = model(x)
            loss = criterion(logits.view(-1, VOCAB_SIZE), y.view(-1))
            if train and ACC > 1:
                loss = loss / ACC
        if train:
            scaler.scale(loss).backward()
            pending = True
            if step % ACC == 0:
                _optimizer_step(grad_clip)
                pending = False
        total += loss.item() * (ACC if train and ACC > 1 else 1.0)
        n += 1
    if train and pending:
        # Flush a trailing partial accumulation instead of silently discarding it.
        _optimizer_step(grad_clip)
    return total / max(1, n)

# ====== 3) Training loop + EMA validation + save best ======
best_val = float('inf')
ckpt_path = SPLIT_ROOT / "minigpt_best.pt"

for ep in range(1, EPOCHS+1):
    tr = run_epoch(train_dl, train=True)

    # Validate with the EMA weights and write the checkpoint while they are
    # loaded, so the saved weights are exactly the ones that scored `vl`.
    # Saving after restore would store the live weights instead.
    ema.store(model); ema.copy_to(model)
    try:
        vl = run_epoch(val_dl, train=False)
        if vl < best_val:
            best_val = vl
            torch.save(model.state_dict(), str(ckpt_path))
    finally:
        ema.restore(model)  # always put the live training weights back
    print(f"[{ep}/{EPOCHS}] train {tr:.4f} | val {vl:.4f} | best {best_val:.4f} | lr {sched.get_last_lr()[0]:.6f}")

# ====== 4) Load the best (EMA) weights ======
_ = model.load_state_dict(torch.load(str(ckpt_path), map_location=DEVICE))

# ====== 5) Basic generate (input window cropped to the position-table length) ======
@torch.no_grad()
def generate(prompt_tokens, max_new=700, top_k=50, top_p=0.95, temp=0.9):
    """Sample up to ``max_new`` tokens after ``prompt_tokens`` from the global ``model``.

    Sampling steps:
    1. Temperature scaling.
    2. Top-k (clamped to the vocabulary size; 0 disables).
    3. Nucleus/top-p (1.0 disables).

    Prompt tokens missing from VOCAB become [UNK]. Only the last ``pos_len``
    tokens (rows of ``model.pos``) are fed to the model. Generation stops early
    after [EOS]. The model's train/eval mode is restored on exit.

    Returns the prompt plus the generated tokens, as token strings.
    Raises ValueError if VOCAB has no [UNK] or the prompt is empty.
    """
    unk_id = VOCAB.get("[UNK]", None)
    if unk_id is None:
        raise ValueError("Vocab must contain [UNK] token.")
    if not prompt_tokens:
        raise ValueError("prompt_tokens must contain at least one token.")
    was_training = model.training
    model.eval()
    try:
        pos_len = getattr(getattr(model, 'pos', None), 'shape', [1, 100000, 0])[1]
        ids = torch.tensor([[VOCAB.get(t, unk_id) for t in prompt_tokens]], device=DEVICE)
        eos_id = VOCAB["[EOS]"]
        for _ in range(max_new):
            ids_win = ids[:, -pos_len:] if ids.size(1) > pos_len else ids
            logits = model(ids_win)[:, -1, :] / max(temp, 1e-6)
            probs = torch.softmax(logits, dim=-1)[0]
            if top_k > 0:
                # torch.topk raises if k exceeds the number of entries.
                topk = torch.topk(probs, min(top_k, probs.size(-1)))
                mask = torch.ones_like(probs, dtype=torch.bool); mask[topk.indices] = False
                probs = probs.masked_fill(mask, 0)
            if top_p < 1.0:
                sprob, sidx = torch.sort(probs, descending=True)
                keep = torch.cumsum(sprob, dim=-1) <= top_p
                keep[0] = True  # always keep the most likely token
                mask = torch.ones_like(probs, dtype=torch.bool); mask[sidx[keep]] = False
                probs = probs.masked_fill(mask, 0)
            probs = probs / probs.sum()
            nxt = torch.multinomial(probs, 1)
            ids = torch.cat([ids, nxt.view(1, 1)], dim=1)
            if nxt.item() == eos_id:
                break
    finally:
        model.train(was_training)
    return [IVOCAB[i.item()] for i in ids[0]]

# ====== 6) Long generation by stitching chunks (no model changes) ======
def stitch_generate(prompt_tokens, total_new=512, chunk_new=700, context=480,
                    top_k=50, top_p=0.95, temp=0.9, stop_on_eos=False):
    """Generate ``total_new`` tokens in chunks of at most ``chunk_new``.

    Each chunk is conditioned on the last ``context`` tokens produced so far.
    Stops early when a chunk yields nothing new, or, with ``stop_on_eos``, at
    the first [EOS] (which is not included). Returns prompt + generated tokens.
    """
    tokens = list(prompt_tokens)
    produced = 0
    while produced < total_new:
        budget = min(chunk_new, total_new - produced)
        window = tokens[-context:] if len(tokens) > context else tokens
        out = generate(window, max_new=budget, top_k=top_k, top_p=top_p, temp=temp)
        fresh = out[len(window):] if len(out) > len(window) else []
        if stop_on_eos and "[EOS]" in fresh:
            tokens += fresh[:fresh.index("[EOS]")]
            break
        tokens += fresh
        produced += len(fresh)
        if not fresh:
            break
    return tokens

# ====== 7) Prompt setup, long generation, save to MIDI ======
# Tokens not present in VOCAB are mapped to [UNK] inside generate(); check each
# prompt token against VOCAB (conditioning tokens come from cond_tokens).
prompt = ["[BOS]","COMPOSER_Mozart","PERIOD_Middle","GENRE_Sonata","KEY_Cmin",
          "TSig_4_4","TEMPO_112","BAR","POS_0","BAR","POS_0","BAR","POS_0"]

tokens_long = stitch_generate(
    prompt_tokens=prompt,
    total_new=512,     # keep the length
    chunk_new=700,
    context=480,
    top_k=50, top_p=0.95, temp=0.9,
    stop_on_eos=False
)

out_mid_long = SPLIT_ROOT / "sample_beethoven_long_v1.mid"
detokenize_to_midi(tokens_long, out_mid_long)
print("Saved →", out_mid_long)

# ====== 8) (Optional) sanity check via perplexity ======
# The validation loss includes label smoothing (see criterion), so exp(loss)
# overestimates the true perplexity — hence "approx".
print("Best val loss:", best_val, "| approx PPL:", math.exp(best_val))


[1/50] train 6.3653 | val 6.3803 | best 6.3803 | lr 0.000003
[2/50] train 6.3194 | val 6.3728 | best 6.3728 | lr 0.000006
[3/50] train 6.2248 | val 6.3799 | best 6.3728 | lr 0.000009
[4/50] train 6.0734 | val 6.3722 | best 6.3722 | lr 0.000012
[5/50] train 5.8805 | val 6.3636 | best 6.3636 | lr 0.000015
[6/50] train 5.6667 | val 6.3607 | best 6.3607 | lr 0.000018
[7/50] train 5.4336 | val 6.3488 | best 6.3488 | lr 0.000021
[8/50] train 5.2518 | val 6.3348 | best 6.3348 | lr 0.000024
[9/50] train 5.0532 | val 6.3199 | best 6.3199 | lr 0.000027
[10/50] train 4.8896 | val 6.3068 | best 6.3068 | lr 0.000030
[11/50] train 4.7120 | val 6.2826 | best 6.2826 | lr 0.000033
[12/50] train 4.5868 | val 6.2588 | best 6.2588 | lr 0.000036
[13/50] train 4.4825 | val 6.2335 | best 6.2335 | lr 0.000039
[14/50] train 4.3387 | val 6.2103 | best 6.2103 | lr 0.000042
[15/50] train 4.1987 | val 6.1779 | best 6.1779 | lr 0.000045
[16/50] train 4.0673 | val 6.1568 | best 6.1568 | lr 0.000048
[17/50] train 3.9

In [None]:
# ====== 0) 준비 ======
import math, torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import GradScaler, autocast  # AMP 최신 API
from torch.optim.swa_utils import AveragedModel  # SWA

VOCAB_SIZE = len(VOCAB)  # logits per position; used to flatten logits for the loss
PAD_ID = VOCAB["[PAD]"]  # ignored by the loss; also the replacement id in token_dropout
UNK_ID = VOCAB.get("[UNK]", None)
assert UNK_ID is not None, "[UNK] 토큰이 vocab에 필요합니다."

# ----- (Optional) token dropout: mask a small share of VEL_/DUR_ input tokens -----
DROP_PROB = 0.05  # 3-7% recommended
VEL_IDS = {VOCAB[tok] for tok in VOCAB if tok.startswith("VEL_")}
DUR_IDS = {VOCAB[tok] for tok in VOCAB if tok.startswith("DUR_")}
DROP_SET = VEL_IDS | DUR_IDS

def token_dropout(batch_ids, drop_prob=DROP_PROB):
    """Replace each VEL_/DUR_ token in ``batch_ids`` with [PAD] with probability ``drop_prob``.

    batch_ids: LongTensor [B, T]. Returned unchanged when drop_prob <= 0 or
    there are no droppable ids.
    """
    if not DROP_SET or drop_prob <= 0:
        return batch_ids
    candidates = torch.tensor(list(DROP_SET), device=batch_ids.device)
    coin = torch.rand_like(batch_ids.float()) < drop_prob
    hit = coin & torch.isin(batch_ids, candidates)
    return batch_ids.masked_fill(hit, PAD_ID)

# ====== EMA utility ======
class EMA:
    """Exponential moving average of a model's state_dict (parameters and buffers).

    Call ``update`` after every optimizer step. The effective decay is warmed up
    as ``min(decay, (1 + n) / (10 + n))``, where n is the number of updates so
    far. Without this, the average stays anchored to its starting weights for
    roughly 1/(1-decay) steps (about 1000 for decay=0.999), which can exceed
    the total number of optimizer steps of a short run. Non-floating tensors
    (e.g. integer counters) are copied, not averaged.

    ``store`` / ``copy_to`` / ``restore`` temporarily swap the EMA weights into
    the model (e.g. for validation) and put the live weights back afterwards.
    """
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.num_updates = 0
        self.shadow = {k: p.detach().clone() for k, p in model.state_dict().items()}
        self.backup = None

    @torch.no_grad()
    def update(self, model):
        d = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
        self.num_updates += 1
        for k, p in model.state_dict().items():
            if self.shadow[k].is_floating_point():
                self.shadow[k].mul_(d).add_(p.detach(), alpha=1 - d)
            else:
                self.shadow[k].copy_(p)  # in-place float math on int buffers would fail

    def store(self, model):
        """Snapshot the live weights so ``restore`` can put them back."""
        self.backup = {k: p.detach().clone() for k, p in model.state_dict().items()}

    def copy_to(self, model):
        """Load the EMA weights into ``model``."""
        model.load_state_dict(self.shadow, strict=False)

    def restore(self, model):
        """Reload the weights saved by ``store`` (no-op if nothing is stored)."""
        if self.backup is not None:
            model.load_state_dict(self.backup, strict=False)
            self.backup = None

ema = EMA(model, decay=0.999)

# ====== 1) Loss, optimizer, scheduler ======
criterion = nn.CrossEntropyLoss(
    ignore_index=PAD_ID,
    label_smoothing=0.05,
)

opt = torch.optim.AdamW(
    model.parameters(),
    lr=3e-4,
    betas=(0.9, 0.95),
    weight_decay=0.03,
)

# Gradient accumulation: one optimizer step per ACC micro-batches (effective batch x2).
# Defined before the schedule because the schedule is measured in optimizer steps.
ACC = 2

EPOCHS = 50
# sched.step() runs once per optimizer step (including the end-of-epoch flush
# in run_epoch), so the horizon is ceil(batches / ACC) per epoch, not len(train_dl).
steps_per_epoch = max(1, math.ceil(len(train_dl) / ACC))
total_steps  = steps_per_epoch * EPOCHS
# Cap warm-up at 10% of training so the LR actually reaches its peak; a fixed
# 1000-step warm-up can be longer than the whole run.
warmup_steps = min(1000, max(1, total_steps // 10))

def lr_lambda(step):
    """LR multiplier: linear warm-up, then cosine decay to 0 at total_steps."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    return 0.5 * (1 + math.cos(math.pi * progress))

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

# Mixed precision only on CUDA. On CPU, autocast would silently run in bfloat16
# and loss scaling is pointless, so both are disabled there.
use_cuda_amp = (DEVICE.type == 'cuda')
amp_device = 'cuda' if use_cuda_amp else 'cpu'
scaler = GradScaler(amp_device, enabled=use_cuda_amp)

# ====== R-Drop settings ======
kl_factor = 1.0  # try values between 0.5 and 2.0

def sym_kl(logits1, logits2, mask=None):
    """Symmetric KL divergence KL(p||q) + KL(q||p) between two logit tensors.

    logits: [B, T, V]; mask: [B, T] (1 = valid, 0 = ignore). Returns the mean
    over valid positions, or over all positions when mask is None.
    """
    p = F.log_softmax(logits1, dim=-1)
    q = F.log_softmax(logits2, dim=-1)
    p_exp = p.exp(); q_exp = q.exp()
    kl = (p_exp * (p - q)).sum(-1) + (q_exp * (q - p)).sum(-1)  # [B, T]
    if mask is not None:
        kl = kl * mask
        denom = mask.sum().clamp_min(1)
        return kl.sum() / denom
    return kl.mean()

# ====== 2) Train / eval functions (ACC=2 + R-Drop + token dropout) ======
def _optimizer_step(grad_clip):
    """Apply accumulated gradients: unscale, clip, step, advance LR, reset, update EMA."""
    # Gradients are still multiplied by the loss scale until unscale_, so clip after it.
    scaler.unscale_(opt)
    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(opt)
    scaler.update()
    sched.step()
    opt.zero_grad(set_to_none=True)
    ema.update(model)

def run_epoch(dl, train=True, grad_clip=1.0):
    """Run one pass over ``dl`` and return the mean per-batch loss.

    train=True: token dropout on inputs, R-Drop (two dropout-active forward
    passes; mean CE + kl_factor * symmetric KL; the reported loss includes the
    KL term), gradient accumulation over ACC batches.
    train=False: plain cross-entropy without gradients.
    """
    model.train(train)
    total, n = 0.0, 0
    pending = False  # gradients accumulated but not yet applied
    if train:
        opt.zero_grad(set_to_none=True)
    for step, (x, y) in enumerate(dl, 1):
        x, y = x.to(DEVICE), y.to(DEVICE)
        if train:
            x = token_dropout(x)  # small input noise for robustness

        with torch.set_grad_enabled(train), autocast(amp_device, enabled=use_cuda_amp):
            if train:
                # R-Drop: two forward passes with dropout active
                logits1 = model(x)
                logits2 = model(x)
                ce1 = criterion(logits1.view(-1, VOCAB_SIZE), y.view(-1))
                ce2 = criterion(logits2.view(-1, VOCAB_SIZE), y.view(-1))
                y_mask = (y != PAD_ID).float()
                kl = sym_kl(logits1, logits2, y_mask)
                loss = 0.5*(ce1+ce2) + kl_factor*kl
                if ACC > 1:
                    loss = loss / ACC
            else:
                logits = model(x)
                loss = criterion(logits.view(-1, VOCAB_SIZE), y.view(-1))

        if train:
            scaler.scale(loss).backward()
            pending = True
            if step % ACC == 0:
                _optimizer_step(grad_clip)
                pending = False

        total += loss.item() * (ACC if train and ACC > 1 else 1.0)
        n += 1
    if train and pending:
        # Flush a trailing partial accumulation instead of silently discarding it.
        _optimizer_step(grad_clip)
    return total / max(1, n)

# ====== 3) SWA (late-stage averaging) + EMA validation + best checkpoint ======
best_val = float('inf')
ckpt_path = SPLIT_ROOT / "minigpt_best.pt"

# Average the last 5 epochs (EPOCHS-4 .. EPOCHS), but never start before epoch 5.
# (The previous `EPOCHS - 5` averaged six epochs.)
swa_start_epoch = max(5, EPOCHS - 5 + 1)
swa_model = AveragedModel(model)

for ep in range(1, EPOCHS + 1):
    tr = run_epoch(train_dl, train=True)

    # Accumulate the SWA average once per epoch (from the raw, non-EMA weights).
    if ep >= swa_start_epoch:
        swa_model.update_parameters(model)

    # Validate with the EMA weights temporarily swapped into the model.
    ema.store(model)
    ema.copy_to(model)
    try:
        vl = run_epoch(val_dl, train=False)
        if vl < best_val:
            best_val = vl
            # Save while the EMA weights are still loaded so the checkpoint holds
            # exactly the weights that produced `vl` (saving after restore stored
            # the raw training weights instead).
            torch.save(model.state_dict(), str(ckpt_path))
    finally:
        ema.restore(model)
    print(f"[{ep}/{EPOCHS}] train {tr:.4f} | val {vl:.4f} | best {best_val:.4f} | lr {sched.get_last_lr()[0]:.6f}")

# ====== 4) Reload the best checkpoint (safe) + optional SWA evaluation/saving ======
_ = model.load_state_dict(torch.load(str(ckpt_path), map_location=DEVICE))

# Also evaluate the SWA-averaged weights once; if they beat the best validation
# loss, save them as a separate checkpoint. Best-effort: failures are reported and
# skipped, and the reloaded best weights are always put back into `model`.
if int(swa_model.n_averaged) == 0:
    # No epoch reached swa_start_epoch, so the "average" would just be the copy
    # AveragedModel took when it was constructed.
    print("SWA eval skipped: no epochs were averaged.")
else:
    ema.store(model)  # used here only as a snapshot/restore helper for `model`
    try:
        # AveragedModel keeps the averaged network in `.module`; its own
        # state_dict() keys carry a "module." prefix (plus "n_averaged"), so
        # loading that into `model` with strict=False silently loaded nothing.
        model.load_state_dict(swa_model.module.state_dict())
        vl_swa = run_epoch(val_dl, train=False)
        print(f"[SWA] val {vl_swa:.4f}")
        if vl_swa < best_val:
            torch.save(model.state_dict(), str(SPLIT_ROOT / "minigpt_best_swa.pt"))
            print("SWA checkpoint saved (better than EMA-best).")
    except Exception as e:
        print("SWA eval skipped:", e)
    finally:
        ema.restore(model)

# ====== 5) Basic generate (automatic window crop + light repetition penalty / temperature schedule) ======
def temp_schedule(step, t_max, t0=0.9, t1=0.8):
    """Linearly interpolate the sampling temperature from ``t0`` to ``t1``.

    The interpolation fraction is ``step / t_max`` (with ``t_max`` floored at 1),
    clamped to [0, 1], so steps at or beyond ``t_max`` return ``t1`` and negative
    steps return ``t0``.
    """
    frac = step / max(1, t_max)
    frac = max(0.0, min(1.0, frac))
    return (1 - frac) * t0 + frac * t1

@torch.no_grad()
def generate(prompt_tokens, max_new=700, top_k=50, top_p=0.95, temp=0.9, rep_penalty=1.05):
    """Autoregressively sample up to ``max_new`` tokens after ``prompt_tokens``.

    Per step:
      * the model sees at most ``pos_len`` most recent ids, where ``pos_len`` is
        ``model.pos.shape[1]`` when the model has a ``pos`` attribute (presumably a
        learned positional table -- confirm), else 100000;
      * the temperature moves linearly from ``temp`` to ``max(0.6, temp - 0.1)``
        over ``max_new`` steps (``temp_schedule``);
      * the probability of the immediately preceding token is divided by
        ``rep_penalty``;
      * optional top-k filtering (``top_k > 0``; capped at the vocabulary size),
        then nucleus filtering (``top_p < 1``) on the renormalized distribution,
        keeping the smallest prefix of sorted tokens whose mass reaches ``top_p``.

    Sampling stops early when "[EOS]" is drawn (it is included in the output).
    Puts the model in eval mode and leaves it there.

    Args:
        prompt_tokens: token strings; unknown ones map to ``UNK_ID``.

    Returns:
        List of token strings: the prompt followed by the sampled tokens.
    """
    model.eval()
    pos_len = getattr(getattr(model, 'pos', None), 'shape', [1, 100000, 0])[1]
    eos_id = VOCAB["[EOS]"]
    ids = torch.tensor([[VOCAB.get(t, UNK_ID) for t in prompt_tokens]], device=DEVICE)
    for step in range(max_new):
        ids_win = ids[:, -pos_len:] if ids.size(1) > pos_len else ids
        cur_temp = temp_schedule(step, max_new, t0=temp, t1=max(0.6, temp - 0.1))
        logits = model(ids_win)[:, -1, :] / max(cur_temp, 1e-6)
        probs = torch.softmax(logits, dim=-1)[0]

        # Light repetition penalty: only the immediately preceding token.
        last_tok = ids[0, -1]
        probs[last_tok] = probs[last_tok] / rep_penalty

        if top_k > 0:
            # torch.topk raises if k exceeds the number of candidates.
            k = min(top_k, probs.numel())
            topk = torch.topk(probs, k)
            mask = torch.ones_like(probs, dtype=torch.bool)
            mask[topk.indices] = False
            probs = probs.masked_fill(mask, 0)
        if top_p < 1.0:
            # Nucleus mass must be measured on the distribution actually sampled,
            # i.e. after the penalty/top-k have been renormalized away.
            probs = probs / probs.sum()
            sprob, sidx = torch.sort(probs, descending=True)
            # Keep a token while the mass *before* it is below top_p, so the
            # threshold-crossing token is included (kept mass >= top_p).
            keep = (torch.cumsum(sprob, dim=-1) - sprob) < top_p
            keep[0] = True
            mask = torch.ones_like(probs, dtype=torch.bool)
            mask[sidx[keep]] = False
            probs = probs.masked_fill(mask, 0)
        probs = probs / probs.sum()

        nxt = torch.multinomial(probs, 1)
        ids = torch.cat([ids, nxt.view(1, 1)], dim=1)
        if nxt.item() == eos_id:
            break
    return [IVOCAB[i.item()] for i in ids[0]]

# ====== 6) Long generation: chunk stitching (no model changes) ======
def stitch_generate(prompt_tokens, total_new=512, chunk_new=700, context=480,
                    top_k=50, top_p=0.95, temp=0.9, stop_on_eos=False):
    """Build a long token sequence by calling ``generate`` repeatedly on a sliding window.

    Each round requests at most ``chunk_new`` tokens (never more than are still
    missing from ``total_new``) and conditions only on the last ``context`` tokens
    accumulated so far. The newly generated tokens are appended to the result.

    Args:
        prompt_tokens: initial token strings (copied, not mutated).
        stop_on_eos: if True, stop at the first "[EOS]" in a round and drop it and
            everything after it.

    Returns:
        The prompt followed by the generated tokens. Generation also stops early
        if a round produces no new tokens.
    """
    result = list(prompt_tokens)
    produced = 0
    while produced < total_new:
        budget = min(chunk_new, total_new - produced)
        window = result if len(result) <= context else result[-context:]
        chunk = generate(window, max_new=budget, top_k=top_k, top_p=top_p, temp=temp)
        fresh = chunk[len(window):] if len(chunk) > len(window) else []
        if stop_on_eos and "[EOS]" in fresh:
            cut = fresh.index("[EOS]")
            result.extend(fresh[:cut])
            break
        result.extend(fresh)
        produced += len(fresh)
        if not fresh:
            break
    return result

# ====== 7) Prompt setup, long generation, and saving ======
# Conditioning tokens followed by three "BAR", "POS_0" pairs.
prompt = [
    "[BOS]", "COMPOSER_Mozart", "PERIOD_Middle", "GENRE_Sonata", "KEY_Cmin",
    "TSig_4_4", "TEMPO_112",
    *(["BAR", "POS_0"] * 3),
]

long_gen_kwargs = dict(
    total_new=512,   # keep the generated length unchanged
    chunk_new=700,
    context=480,
    top_k=50,
    top_p=0.95,
    temp=0.9,
    stop_on_eos=False,
)
tokens_long = stitch_generate(prompt_tokens=prompt, **long_gen_kwargs)

out_mid_long = SPLIT_ROOT / "sample_mozart_long_v2.mid"
detokenize_to_midi(tokens_long, out_mid_long)
print("Saved →", out_mid_long)

# ====== 8) (optional) Sanity check via perplexity ======
# Only approximate: the validation loss was computed with label smoothing.
approx_ppl = math.exp(best_val)
print("Best val loss (EMA):", best_val, "| approx PPL:", approx_ppl)


[1/50] train 3.1018 | val 2.9482 | best 2.9482 | lr 0.000003
[2/50] train 3.0835 | val 2.9442 | best 2.9442 | lr 0.000006
[3/50] train 3.0633 | val 2.9568 | best 2.9442 | lr 0.000009
[4/50] train 3.0389 | val 2.9537 | best 2.9442 | lr 0.000012
[5/50] train 3.0170 | val 2.9465 | best 2.9442 | lr 0.000015
[6/50] train 3.0054 | val 2.9490 | best 2.9442 | lr 0.000018
[7/50] train 2.9993 | val 2.9420 | best 2.9420 | lr 0.000021
[8/50] train 2.9986 | val 2.9555 | best 2.9420 | lr 0.000024
[9/50] train 2.9865 | val 2.9406 | best 2.9406 | lr 0.000027
[10/50] train 2.9873 | val 2.9527 | best 2.9406 | lr 0.000030
[11/50] train 2.9921 | val 2.9548 | best 2.9406 | lr 0.000033
[12/50] train 2.9947 | val 2.9597 | best 2.9406 | lr 0.000036
[13/50] train 2.9844 | val 2.9508 | best 2.9406 | lr 0.000039
[14/50] train 2.9849 | val 2.9557 | best 2.9406 | lr 0.000042
[15/50] train 2.9740 | val 2.9500 | best 2.9406 | lr 0.000045
[16/50] train 2.9781 | val 2.9574 | best 2.9406 | lr 0.000048
[17/50] train 2.9

In [None]:
!pip install pyFluidSynth==1.3.4


---
