# Qwen3-TTS Marlene - Batch/Epoch Comparison Test

**Goal:** compare batch size 8 vs. 16 and 10 vs. 50 epochs.

| Model | batch | epoch |
|------|-------|-------|
| A | 8 | 10 |
| B | 8 | 50 |
| C | 16 | 10 |
| D | 16 | 50 |

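**Method:** train only twice (batch 8 for 50 epochs and batch 16 for 50 epochs) and take models A and C from the intermediate epoch-10 checkpoints of those runs. This matches separate 10-epoch runs only if the learning rate does not depend on the total number of epochs (for example, no decay schedule tied to `num_epochs`).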
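
At the same epoch count and learning rate, batch 16 makes roughly half as many optimizer updates as batch 8. The sketch below counts updates per configuration. The dataset size `N` is a made-up placeholder, and it assumes no gradient accumulation (`grad_accum=1`) and that the final partial batch is kept; neither is confirmed for `sft_12hz.py`.

```python
import math

def optimizer_steps(num_samples: int, batch_size: int, epochs: int, grad_accum: int = 1) -> int:
    """Approximate number of optimizer updates for one training run.

    Assumes the last partial batch is kept (drop_last=False) and that
    `grad_accum` batches are accumulated per update (assumed 1 here).
    """
    batches_per_epoch = math.ceil(num_samples / batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    return updates_per_epoch * epochs

N = 400  # hypothetical dataset size; use the sample count printed by the JSONL cell

for model, batch, epoch in [("A", 8, 10), ("B", 8, 50), ("C", 16, 10), ("D", 16, 50)]:
    print(f"{model}: batch {batch}, epoch {epoch} -> {optimizer_steps(N, batch, epoch)} updates")
```

With `N = 400`, model C gets 250 updates versus 500 for model A, so a difference between them may reflect the number of updates as much as the batch size itself.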

## Cell 1: Environment Setup

In [None]:
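!apt-get install -y sox
!pip install -q soundfile librosa tqdm huggingface_hub
# flash-attn may compile from source if no prebuilt wheel matches; this can take a long time
!pip install flash-attn --no-build-isolation
# Clone the Qwen3-TTS repo (provides the finetuning/ scripts) and install it in editable mode
!git clone https://github.com/QwenLM/Qwen3-TTS.git /content/Qwen3-TTS-repo
%cd /content/Qwen3-TTS-repo
!pip install -e .
print("Environment setup complete!")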
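
A quick sanity check after installation, sketched with the standard library only: it reports whether each module can be found, without importing it. The module names (`flash_attn`, `qwen_tts`, and so on) are assumed import names for the packages installed above.

```python
import importlib.util

def check_modules(names):
    """Map each top-level module name to True if Python can find it (no import)."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

# Assumed import names for the packages installed in the cell above.
status = check_modules(["soundfile", "librosa", "tqdm", "huggingface_hub", "flash_attn", "qwen_tts"])
for name, found in status.items():
    print(f"{name:16} {'OK' if found else 'MISSING'}")
```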

In [None]:
from google.colab import drive
drive.mount('/content/drive')

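# Patch sft_12hz.py to disable TensorBoard logging (log_with=None)
sft_path = "/content/Qwen3-TTS-repo/finetuning/sft_12hz.py"
with open(sft_path, 'r', encoding='utf-8') as f:
    code = f.read()
if 'log_with="tensorboard"' not in code:
    print('Warning: log_with="tensorboard" not found; the script may have changed or is already patched')
code = code.replace('log_with="tensorboard"', 'log_with=None')
with open(sft_path, 'w', encoding='utf-8') as f:
    f.write(code)
print("Patch complete!")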
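
`str.replace` does nothing when the pattern is absent, so a source patch can quietly become a no-op if the upstream script changes. A small reusable helper (hypothetical, not part of the Qwen3-TTS repo) that returns the number of replacements makes this explicit:

```python
from pathlib import Path

def patch_once(path, old, new):
    """Replace every `old` with `new` in a UTF-8 text file.

    Returns how many occurrences were replaced; 0 means nothing changed
    (already patched, or the target text is different upstream).
    """
    p = Path(path)
    text = p.read_text(encoding="utf-8")
    count = text.count(old)
    if count:
        p.write_text(text.replace(old, new), encoding="utf-8")
    return count

# Example (uses sft_path from the cell above):
# n = patch_once(sft_path, 'log_with="tensorboard"', 'log_with=None')
# print("patched" if n else "pattern not found, check sft_12hz.py by hand")
```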

In [None]:
from huggingface_hub import snapshot_download
import os, json, librosa, soundfile as sf
from tqdm import tqdm

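# Download the base model
model_path = snapshot_download(
    "Qwen/Qwen3-TTS-12Hz-1.7B-Base",
    local_dir="/content/qwen3_tts_model"
)

# Resample Marlene's audio to 24 kHz
AUDIO_DIR = "/content/drive/MyDrive/marlene_tts_data/audio"
OUTPUT_DIR = "/content/audio_24k_marlene"
os.makedirs(OUTPUT_DIR, exist_ok=True)

files = [f for f in os.listdir(AUDIO_DIR) if f.endswith('.wav')]
failed = []
for fname in tqdm(files):
    try:
        # librosa.load resamples to sr and downmixes to mono by default
        audio, _ = librosa.load(os.path.join(AUDIO_DIR, fname), sr=24000)
        sf.write(os.path.join(OUTPUT_DIR, fname), audio, 24000)
    except Exception as e:
        # Skip unreadable files, but record why instead of silently ignoring them
        failed.append((fname, repr(e)))

print(f"Audio conversion complete: {len(os.listdir(OUTPUT_DIR))} files")
if failed:
    print(f"{len(failed)} files failed, e.g. {failed[:3]}")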
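
To confirm the conversion, the sketch below lists any `.wav` files in a folder whose sample rate is not 24 kHz. It uses Python's built-in `wave` module, which only reads PCM WAV; `sf.write` writes 16-bit PCM for `.wav` by default, so the converted files should be readable. The function name `wrong_rate_files` is hypothetical.

```python
import os
import wave

def wrong_rate_files(directory, expected_sr=24000):
    """Return (filename, sample_rate) for .wav files not at expected_sr.

    Uses the stdlib `wave` module (PCM WAV only); files it cannot parse
    are reported with sample_rate None.
    """
    bad = []
    for name in sorted(os.listdir(directory)):
        if not name.lower().endswith(".wav"):
            continue
        try:
            with wave.open(os.path.join(directory, name), "rb") as w:
                rate = w.getframerate()
        except (wave.Error, EOFError):
            rate = None
        if rate != expected_sr:
            bad.append((name, rate))
    return bad

# Example: an empty list means every converted file is 24 kHz.
# print(wrong_rate_files("/content/audio_24k_marlene"))
```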

In [None]:
# Build the training JSONL: keep only samples whose 24 kHz audio exists
with open("/content/drive/MyDrive/marlene_tts_data/marlene_finetune.jsonl", 'r', encoding='utf-8') as f:
    data = [json.loads(line) for line in f if line.strip()]

valid_files = set(os.listdir(OUTPUT_DIR))
# Reference clip passed as ref_audio for every sample
REF_AUDIO = os.path.join(OUTPUT_DIR, "Marlene_airSupply_1_01.wav")
assert os.path.exists(REF_AUDIO), f"Reference audio missing: {REF_AUDIO}"

filtered = []
for item in data:
    filename = os.path.basename(item['audio'])  # also accepts full paths in the source JSONL
    if filename in valid_files:
        filtered.append({
            "audio": os.path.join(OUTPUT_DIR, filename),
            "text": item['text'],
            "ref_audio": REF_AUDIO
        })

with open("/content/marlene_24k.jsonl", 'w', encoding='utf-8') as f:
    for item in filtered:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

print(f"JSONL ready: {len(filtered)} of {len(data)} samples")
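
Before tokenizing, it can help to validate the JSONL. This hypothetical `validate_jsonl` checks that each line parses, has the three keys written above (`audio`, `text`, `ref_audio`), points at existing files, and has non-empty text.

```python
import json
import os

REQUIRED_KEYS = ("audio", "text", "ref_audio")

def validate_jsonl(path):
    """Return a list of (line_number, problem) tuples; an empty list means OK."""
    problems = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                problems.append((lineno, f"invalid JSON: {e}"))
                continue
            for key in REQUIRED_KEYS:
                if key not in item:
                    problems.append((lineno, f"missing key {key!r}"))
            for key in ("audio", "ref_audio"):
                if key in item and not os.path.isfile(item[key]):
                    problems.append((lineno, f"{key} not found: {item[key]}"))
            if "text" in item and not str(item["text"]).strip():
                problems.append((lineno, "empty text"))
    return problems

# Example:
# for lineno, msg in validate_jsonl("/content/marlene_24k.jsonl")[:20]:
#     print(lineno, msg)
```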

In [None]:
# Tokenize: encode each clip into speech tokens with Qwen/Qwen3-TTS-Tokenizer-12Hz
!python /content/Qwen3-TTS-repo/finetuning/prepare_data.py \
    --device cuda:0 \
    --tokenizer_model_path Qwen/Qwen3-TTS-Tokenizer-12Hz \
    --input_jsonl /content/marlene_24k.jsonl \
    --output_jsonl /content/marlene_tokenized.jsonl
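
The schema that `prepare_data.py` writes is not described here, so this sketch does not assume one: it counts records and tallies top-level keys, which is enough to confirm that the record count matches the JSONL cell and that every record has the same fields. `summarize_jsonl` is a hypothetical helper.

```python
import json
from collections import Counter

def summarize_jsonl(path):
    """Return (record_count, Counter of top-level keys) for a JSONL file.

    Makes no assumption about which fields prepare_data.py writes.
    """
    count = 0
    keys = Counter()
    with open(path, encoding="utf-8") as f:
        for line in f:
            if line.strip():
                count += 1
                keys.update(json.loads(line).keys())
    return count, keys

# Example: the count should match the "JSONL ready" number, and each key
# should appear once per record.
# n, keys = summarize_jsonl("/content/marlene_tokenized.jsonl")
# print(n, dict(keys))
```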

---
# Batch 8, Epoch 50 Training
Checkpoint folders are zero-indexed: `checkpoint-epoch-9` is the state after epoch 10, and `checkpoint-epoch-49` after epoch 50.
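
The zero-based naming is easy to get off by one. A tiny helper (hypothetical) that maps a 1-based epoch count to the checkpoint folder, following this notebook's naming:

```python
def checkpoint_dir(run_dir: str, epoch: int) -> str:
    """Folder of the checkpoint saved after `epoch` epochs (1-based).

    Follows this notebook's naming, where indices are zero-based:
    epoch 10 -> checkpoint-epoch-9, epoch 50 -> checkpoint-epoch-49.
    """
    if epoch < 1:
        raise ValueError("epoch must be >= 1")
    return f"{run_dir.rstrip('/')}/checkpoint-epoch-{epoch - 1}"

print(checkpoint_dir("/content/drive/MyDrive/marlene_batch8_epoch50", 10))
# /content/drive/MyDrive/marlene_batch8_epoch50/checkpoint-epoch-9
```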

In [None]:
# Delete any previous run in this folder so old checkpoints are not mixed in
!rm -rf /content/drive/MyDrive/marlene_batch8_epoch50
!python /content/Qwen3-TTS-repo/finetuning/sft_12hz.py \
    --init_model_path /content/qwen3_tts_model \
    --train_jsonl /content/marlene_tokenized.jsonl \
    --output_model_path /content/drive/MyDrive/marlene_batch8_epoch50 \
    --batch_size 8 \
    --lr 1e-6 \
    --num_epochs 50 \
    --speaker_name marlene

---
# Batch 16, Epoch 50 Training
Checkpoint folders are zero-indexed: `checkpoint-epoch-9` is the state after epoch 10, and `checkpoint-epoch-49` after epoch 50.

In [None]:
# Delete any previous run in this folder so old checkpoints are not mixed in
!rm -rf /content/drive/MyDrive/marlene_batch16_epoch50
!python /content/Qwen3-TTS-repo/finetuning/sft_12hz.py \
    --init_model_path /content/qwen3_tts_model \
    --train_jsonl /content/marlene_tokenized.jsonl \
    --output_model_path /content/drive/MyDrive/marlene_batch16_epoch50 \
    --batch_size 16 \
    --lr 1e-6 \
    --num_epochs 50 \
    --speaker_name marlene
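
After both runs finish, a check that the four checkpoints used in the comparison actually exist. The paths match the two output folders above; the `checkpoint-epoch-{N-1}` naming is taken from this notebook's notes, not verified against `sft_12hz.py`.

```python
import os

# Output folders of the two training runs above
RUNS = [
    "/content/drive/MyDrive/marlene_batch8_epoch50",
    "/content/drive/MyDrive/marlene_batch16_epoch50",
]

def missing_checkpoints(run_dirs, epochs=(10, 50)):
    """Return expected checkpoint folders that do not exist.

    Uses the zero-based naming from this notebook: epoch N -> checkpoint-epoch-{N-1}.
    """
    missing = []
    for run_dir in run_dirs:
        for epoch in epochs:
            path = os.path.join(run_dir, f"checkpoint-epoch-{epoch - 1}")
            if not os.path.isdir(path):
                missing.append(path)
    return missing

missing = missing_checkpoints(RUNS)
print("All 4 checkpoints found" if not missing else f"Missing: {missing}")
```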

---
# Comparison Test (4 Models)

In [None]:
import gc
import os
import torch
from IPython.display import Audio, display
from qwen_tts import Qwen3TTSModel

TEST_TEXTS = [
    "뭐야, 왜 이렇게 늦은 거야?",  # "What? Why are you so late?"
    "드디어 해냈다! 이겼어!",      # "I finally did it! I won!"
    "정말 고마워, 잊지 않을게.",    # "Thank you so much, I won't forget it."
]

# (label, checkpoint path); checkpoint indices are zero-based
MODELS = [
    ("batch8-epoch10", "/content/drive/MyDrive/marlene_batch8_epoch50/checkpoint-epoch-9"),
    ("batch8-epoch50", "/content/drive/MyDrive/marlene_batch8_epoch50/checkpoint-epoch-49"),
    ("batch16-epoch10", "/content/drive/MyDrive/marlene_batch16_epoch50/checkpoint-epoch-9"),
    ("batch16-epoch50", "/content/drive/MyDrive/marlene_batch16_epoch50/checkpoint-epoch-49"),
]

for model_name, ckpt_path in MODELS:
    print("=" * 50)
    print(model_name)
    print("=" * 50)

    if not os.path.isdir(ckpt_path):
        print(f"Checkpoint not found, skipping: {ckpt_path}")
        continue

    tts = Qwen3TTSModel.from_pretrained(
        ckpt_path,
        device_map="cuda:0",
        dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )

    for i, text in enumerate(TEST_TEXTS):
        print(f"\n[{i+1}] {text}")
        wavs, sr = tts.generate_custom_voice(text=text, speaker="marlene")
        display(Audio(wavs[0], rate=sr))

    # Free GPU memory before loading the next checkpoint
    del tts
    gc.collect()
    torch.cuda.empty_cache()
    print("\n")
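
`display(Audio(...))` only plays the clips inline. To keep them for later side-by-side listening, each waveform can be written to disk; in this notebook `sf.write(path, wavs[0], sr)` does that in one line. The standard-library sketch below does the same as 16-bit PCM, assuming `wavs[0]` is a mono float waveform in [-1, 1]. The helper name and the output folder in the example are hypothetical.

```python
import array
import os
import sys
import wave

def save_wav_pcm16(path, samples, sample_rate):
    """Write mono float samples (expected range [-1, 1]) as a 16-bit PCM WAV.

    Out-of-range values are clipped. `samples` may be any iterable of numbers,
    such as a 1-D NumPy array.
    """
    pcm = array.array("h", (int(max(-1.0, min(1.0, float(s))) * 32767) for s in samples))
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with wave.open(path, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)
        w.setframerate(sample_rate)
        w.writeframes(pcm.tobytes())

# Example inside the inner loop of the comparison cell (hypothetical output folder):
# save_wav_pcm16(f"/content/drive/MyDrive/marlene_compare/{model_name}_{i+1}.wav", wavs[0], sr)
```

Naming files by `model_name` and sentence index keeps the four variants of each sentence next to each other when sorted.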

---
# Results Summary

| Model | batch | epoch | Result |
|------|-------|-------|------|
| A | 8 | 10 | |
| B | 8 | 50 | |
| C | 16 | 10 | |
| D | 16 | 50 | |