# LoRA Rank Ablation Study

对 `diagnosis_generator`（Qwen3-1.7B）做 LoRA rank 消融实验：
- 对比 rank = 8, 16, 32, 64, 128
- alpha = rank × 2（保持比值 2.0 不变）
- 其余超参固定（lr=1e-4, batch=2, grad_acc=8, epochs=2）

评估指标：
1. Training loss 曲线
2. JSON 格式正确率
3. 字段完整性（results/recommendations/recomm_short）
4. 概率合理性（3 个疾病概率之和 ≈ 1.0）
5. 推理速度
6. 训练显存占用

**注意**: 每个 rank 训练完后需要 **重启 Runtime** 释放显存，再训练下一个。
建议按 rank 从小到大依次训练。

In [None]:
# 0. Setup
import os
repo_dir = '/content/Intel_Health'
# Clone the project on the first run; on later runs just pull the latest changes.
if not os.path.exists(repo_dir):
    !git clone https://github.com/DemonRain7/Intel_Health.git {repo_dir}
else:
    !git -C {repo_dir} pull
# Work from the repo root: later cells launch scripts via relative paths (training/...).
%cd {repo_dir}

# Colab ships with torch preinstalled; do not reinstall it, otherwise a circular-import error occurs.
!pip -q install "transformers>=4.46" datasets peft accelerate bitsandbytes sentencepiece loguru

# Prefer the HF token stored in Colab Secrets; on any failure fall back to an
# interactive Hugging Face login.
try:
    from google.colab import userdata
    os.environ["HF_TOKEN"] = userdata.get("HF_TOKEN")
    print("HF_TOKEN loaded from Colab Secrets")
# NOTE(review): Exception already covers ImportError, so the tuple is redundant.
except (ImportError, Exception):
    from huggingface_hub import login
    login()

In [None]:
# 1. Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

# All persistent inputs and outputs of the study live under this Drive folder.
DRIVE_ROOT = '/content/drive/MyDrive/Code_Project/IntelHealth'
SFT_DATA_DIR = f'{DRIVE_ROOT}/datasets/agent_sft/diagnosis_generator'
ABLATION_OUTPUT_ROOT = f'{DRIVE_ROOT}/models/adapters/ablation'
ABLATION_MERGED_ROOT = f'{DRIVE_ROOT}/models/merged/ablation'

for _root in (ABLATION_OUTPUT_ROOT, ABLATION_MERGED_ROOT):
    os.makedirs(_root, exist_ok=True)

# Sanity-check the training data: list every .jsonl file with its line count.
data_files = [name for name in os.listdir(SFT_DATA_DIR) if name.endswith('.jsonl')]
print(f'SFT data dir: {SFT_DATA_DIR}')
print(f'Data files: {data_files}')
for name in data_files:
    with open(os.path.join(SFT_DATA_DIR, name)) as handle:
        n_lines = sum(1 for _ in handle)
    print(f'  {name}: {n_lines} samples')

## Part A: Training（每个 rank 跑一次，跑完重启 Runtime）

**修改下面的 `CURRENT_RANK` 后，依次运行 cell 2（配置）、cell 3（训练）和 cell 4（保存 GPU 统计）。**

训练顺序建议：8 → 16 → 32 → 64 → 128（每个跑完后重启 Runtime 再跑下一个）

In [None]:
# 2. Pick the rank to train in this session
# ============================
# Train exactly one rank per session: after training, restart the Runtime,
# change this value, and run again.
CURRENT_RANK = 64  # <-- edit here: 8, 16, 32, 64, 128
# ============================

MODEL_NAME = 'Qwen/Qwen3-1.7B'
CURRENT_ALPHA = 2 * CURRENT_RANK
OUTPUT_DIR = f'{ABLATION_OUTPUT_ROOT}/rank{CURRENT_RANK}'

# Warn when the output directory already holds files from an earlier run;
# otherwise make sure it exists.
_existing = os.listdir(OUTPUT_DIR) if os.path.isdir(OUTPUT_DIR) else []
if not _existing:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
else:
    print(f'WARNING: {OUTPUT_DIR} already exists and is not empty!')
    print('Contents:', _existing)
    print('If you want to retrain, delete the directory first.')

print(f'\n=== Ablation Config ===')
print(f'Rank:       {CURRENT_RANK}')
print(f'Alpha:      {CURRENT_ALPHA}')
print(f'Alpha/Rank: {CURRENT_ALPHA / CURRENT_RANK}')
print(f'Model:      {MODEL_NAME}')
print(f'Data:       {SFT_DATA_DIR}')
print(f'Output:     {OUTPUT_DIR}')

In [None]:
# 3. Train
import subprocess, shlex, sys, time as _time, os

# Launch the training script with the interpreter that runs this kernel, so the
# child process is guaranteed to see the packages installed in the setup cell.
cmd = [
    sys.executable, 'training/supervised_finetuning.py',
    '--model_name_or_path',          MODEL_NAME,
    '--tokenizer_name_or_path',      MODEL_NAME,
    '--train_file_dir',              SFT_DATA_DIR,
    '--output_dir',                  OUTPUT_DIR,
    '--template_name',               'qwen',
    '--do_train',
    '--fp16',
    '--gradient_checkpointing',
    '--per_device_train_batch_size', '2',
    '--gradient_accumulation_steps',  '8',
    '--num_train_epochs',            '2',
    '--learning_rate',               '1e-4',
    '--lora_rank',                   str(CURRENT_RANK),
    '--lora_alpha',                  str(CURRENT_ALPHA),
    '--lora_dropout',                '0.05',
    '--model_max_length',            '256',
    '--logging_steps',               '5',
    '--save_strategy',               'epoch',
]

print(f'Training rank={CURRENT_RANK}, alpha={CURRENT_ALPHA}...')
print(f'Command: {" ".join(shlex.quote(x) for x in cmd)}')
print('=' * 60)

# List the JSONL files in the training directory and warn when any of them
# looks like a dapt file (name contains "dapt").
train_jsonl_files = sorted([f for f in os.listdir(SFT_DATA_DIR) if f.endswith('.jsonl')])
print(f'Training JSONL files ({len(train_jsonl_files)}): {train_jsonl_files}')
if any('dapt' in f.lower() for f in train_jsonl_files):
    print('[WARN] dapt file(s) found in the SFT data directory; they may be included in this rank ablation run')

# Stream the child's combined stdout/stderr line by line so progress is visible
# in the notebook while training runs.
_t0 = _time.time()
proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
for line in proc.stdout:
    print(line, end='', flush=True)
ret = proc.wait()
_elapsed = _time.time() - _t0  # wall-clock seconds; read again by the GPU-stats cell

if ret != 0:
    raise RuntimeError(f'Training failed with exit code {ret}')
print(f'\nTraining rank={CURRENT_RANK} complete in {_elapsed/60:.1f} min!')
print(f'Adapter saved to: {OUTPUT_DIR}')
child_gpu_stats_path = f'{OUTPUT_DIR}/gpu_stats.json'
if os.path.exists(child_gpu_stats_path):
    print(f'Child GPU stats found: {child_gpu_stats_path}')
else:
    print(f'[WARN] Child GPU stats not found: {child_gpu_stats_path}')

In [None]:
# 4. Save GPU stats for this rank (taken from the gpu_stats.json found in the
#    output directory, when present)
import json, os

child_stats_path = f'{OUTPUT_DIR}/gpu_stats.json'
gpu_stats = {
    'rank': CURRENT_RANK,
    'alpha': CURRENT_ALPHA,
    # _elapsed is set by the training cell, so that cell must have run in this session.
    'training_time_min': round(_elapsed / 60, 1),
}

if os.path.exists(child_stats_path):
    with open(child_stats_path, encoding='utf-8') as f:
        child_stats = json.load(f)
    # Missing keys are stored as null rather than raising.
    gpu_stats.update({
        'gpu_name': child_stats.get('gpu_name'),
        'total_gb': child_stats.get('total_gb'),
        'peak_allocated_gb': child_stats.get('peak_allocated_gb'),
        'peak_reserved_gb': child_stats.get('peak_reserved_gb'),
        'stats_source': 'training/supervised_finetuning.py',
    })
else:
    print(f'[WARN] Child GPU stats file not found: {child_stats_path}')
    gpu_stats['stats_source'] = 'fallback_no_child_stats'

stats_path = f'{OUTPUT_DIR}/ablation_gpu_stats.json'
with open(stats_path, 'w', encoding='utf-8') as f:
    json.dump(gpu_stats, f, indent=2)
print(f'GPU stats saved to {stats_path}')
print(json.dumps(gpu_stats, indent=2))

print('\n>>> Next: Runtime > Restart runtime, then change CURRENT_RANK and run Part A again.')
print('>>> Once all 5 ranks are trained, continue with Part B (merge + evaluation).')

---

## Part B: Merge + Evaluate（全部 rank 训练完后运行）

确保 5 个 rank（8/16/32/64/128）都训练完毕，adapter 保存在 Google Drive 上。

以下步骤：
1. 逐个 merge adapter 到 base model
2. 对每个 merged 模型做推理评估
3. 生成对比报告和可视化

In [None]:
# 5. Check which ranks have finished training
import os, json

# Re-declared here because Part B runs in a fresh Runtime.
DRIVE_ROOT = '/content/drive/MyDrive/Code_Project/IntelHealth'
ABLATION_OUTPUT_ROOT = f'{DRIVE_ROOT}/models/adapters/ablation'
ABLATION_MERGED_ROOT = f'{DRIVE_ROOT}/models/merged/ablation'

RANKS = [8, 16, 32, 64, 128]
MODEL_NAME = 'Qwen/Qwen3-1.7B'

print('Checking trained adapters...')
trained_ranks = []
for rank in RANKS:
    adapter_dir = f'{ABLATION_OUTPUT_ROOT}/rank{rank}'
    # A rank counts as trained when its directory holds a file whose name starts with "adapter".
    if not os.path.isdir(adapter_dir):
        status = 'NOT FOUND'
    elif any(name.startswith('adapter') for name in os.listdir(adapter_dir)):
        status = 'READY'
        trained_ranks.append(rank)
    else:
        status = 'NO ADAPTER'
    print(f'  rank={rank:>3}: {status}  ({adapter_dir})')

print(f'\nReady for merge: {trained_ranks}')
if len(trained_ranks) < len(RANKS):
    missing = set(RANKS).difference(trained_ranks)
    print(f'WARNING: Missing ranks: {missing}. Go back to Part A to train them.')

In [None]:
# 6. Merge adapters (one merge per trained rank)
import os, shutil, subprocess

FORCE_REMERGE = True  # True: always re-merge, deleting any existing merged directory first

for rank in trained_ranks:
    adapter_dir = f'{ABLATION_OUTPUT_ROOT}/rank{rank}'
    merged_dir = f'{ABLATION_MERGED_ROOT}/rank{rank}'

    # Destructive: removes the previous merged output before re-merging.
    if FORCE_REMERGE and os.path.isdir(merged_dir):
        shutil.rmtree(merged_dir)

    # Skip ranks whose merged directory already contains weight files
    # (.safetensors / .bin); only reachable when FORCE_REMERGE is False.
    if os.path.isdir(merged_dir) and any(
        f.endswith('.safetensors') or f.endswith('.bin')
        for f in os.listdir(merged_dir)
    ):
        print(f'rank={rank}: already merged, skipping')
        continue

    print(f'\nMerging rank={rank}...')
    cmd = [
        'python', 'training/merge_peft_adapter.py',
        '--base_model', MODEL_NAME,
        '--lora_model', adapter_dir,
        '--output_dir', merged_dir,
    ]
    # Output is captured; on failure only the last 500 characters of stderr are shown.
    ret = subprocess.run(cmd, capture_output=True, text=True)
    if ret.returncode != 0:
        print(f'  ERROR: {ret.stderr[-500:]}')
    else:
        print(f'  Merged to {merged_dir}')

# NOTE(review): printed unconditionally, even when one or more merges failed above.
print('\nMerge complete!')

In [None]:
# 7. Prepare the evaluation inputs
import json, random

SFT_DATA_DIR = f'{DRIVE_ROOT}/datasets/agent_sft/diagnosis_generator'

# Prefer diagnosis_sft.jsonl; when it is absent fall back to every .jsonl file
# in the directory (which may then include dapt files).
preferred_file = os.path.join(SFT_DATA_DIR, 'diagnosis_sft.jsonl')
if os.path.exists(preferred_file):
    eval_files = [preferred_file]
else:
    eval_files = [
        os.path.join(SFT_DATA_DIR, f)
        for f in sorted(os.listdir(SFT_DATA_DIR))
        if f.endswith('.jsonl')
    ]

print(f'Eval JSONL files ({len(eval_files)}): {[os.path.basename(f) for f in eval_files]}')
if any('dapt' in os.path.basename(f).lower() for f in eval_files):
    print('[WARN] Eval set includes dapt file(s); this may skew the JSON format metrics')

# Load every non-blank line as one candidate sample. The encoding is explicit
# so decoding does not depend on the platform default.
all_samples = []
for fp in eval_files:
    with open(fp, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if line:
                all_samples.append(json.loads(line))

print(f'Total candidate samples: {len(all_samples)}')

# Draw up to 20 samples with a fixed seed (reproducible across ranks) and keep
# the first human turn of each conversation as the test prompt.
random.seed(42)
test_samples = random.sample(all_samples, min(20, len(all_samples)))

test_inputs = []
for sample in test_samples:
    convs = sample.get('conversations', [])
    # .get() lets a turn without 'from'/'value' be skipped instead of raising KeyError.
    user_msg = next((c.get('value', '') for c in convs if c.get('from') == 'human'), '')
    if user_msg:
        test_inputs.append(user_msg)

print(f'Test inputs prepared: {len(test_inputs)}')
if not test_inputs:
    # Fail here with a clear message rather than with an IndexError below or a
    # division by zero during evaluation.
    raise ValueError(f'No usable human prompts found in {eval_files}')
print(f'Sample input (truncated): {test_inputs[0][:200]}...')

In [None]:
# 8. Evaluation function: score one rank's merged model on output format and speed
import gc, time, re, json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): this prompt is stored in the file as literal '?' characters,
# which looks like text lost to an encoding error. It is kept byte-identical so
# results stay comparable with earlier runs; pass `system_prompt` to
# evaluate_model() to use the intended wording. TODO: restore the original text.
_DEFAULT_SYSTEM_PROMPT = '??????????????????JSON????JSON?????????'


def _probability_of(entry):
    """Return entry['probability'] when it is a number, otherwise 0."""
    value = entry.get('probability', 0)
    # Strings such as "60%" or a JSON null would make sum() raise TypeError.
    return value if isinstance(value, (int, float)) else 0


def _score_generation(output_text, gen_tokens, max_new_tokens, elapsed_s):
    """Build the per-sample record: timing plus JSON validity and field checks."""
    # Drop <think>...</think> blocks and markdown code fences, then take the
    # span from the first '{' to the last '}' as the JSON candidate.
    cleaned = re.sub(r'<think>[\s\S]*?</think>\s*', '', output_text)
    cleaned = re.sub(r'```(?:json)?|```', '', cleaned).strip()
    json_match = re.search(r'\{[\s\S]*\}', cleaned)

    result = {
        'raw_output': output_text[:500],
        'gen_tokens': gen_tokens,
        'hit_max_new_tokens': gen_tokens >= max_new_tokens,
        'time_s': round(elapsed_s, 2),
        'json_valid': False,
        'has_results': False,
        'has_recommendations': False,
        'has_recomm_short': False,
        'results_count': 0,
        'prob_sum': 0.0,
    }
    if json_match is None:
        return result

    try:
        parsed = json.loads(json_match.group())
    except json.JSONDecodeError:
        return result
    result['json_valid'] = True

    diagnoses = parsed.get('results')
    if isinstance(diagnoses, list) and diagnoses:
        result['has_results'] = True
        result['results_count'] = len(diagnoses)
        result['prob_sum'] = sum(
            _probability_of(item) for item in diagnoses if isinstance(item, dict)
        )

    for field, flag in (
        ('recommendations', 'has_recommendations'),
        ('recomm_short', 'has_recomm_short'),
    ):
        value = parsed.get(field)
        if isinstance(value, list) and value:
            result[flag] = True

    return result


def evaluate_model(model_path, test_inputs, max_new_tokens=1024, system_prompt=None):
    """Generate a reply for every prompt and report format and speed metrics.

    Args:
        model_path: Directory or hub id handed to ``from_pretrained``.
        test_inputs: User prompts to evaluate; must contain at least one item.
        max_new_tokens: Generation budget per prompt.
        system_prompt: System message for the chat template. ``None`` keeps
            the prompt that has always been hard-coded in this notebook.

    Returns:
        A dict of aggregate rates (``json_valid_rate``, ``has_results_rate``,
        ``results_3_rate``, ``has_recs_rate``, ``has_short_rate``,
        ``prob_close_to_1_rate``, ``hit_max_new_tokens_rate``), speed figures
        (``avg_time_s``, ``avg_tokens``, ``tok_per_s``), CUDA memory in GiB
        (``vram_loaded_gb``, ``vram_peak_gb``) and the per-sample records
        under ``details``.

    Raises:
        ValueError: If ``test_inputs`` is empty.
    """
    test_inputs = list(test_inputs)
    if not test_inputs:
        # Checked before loading the model; every rate below divides by the sample count.
        raise ValueError('test_inputs is empty; nothing to evaluate')
    if system_prompt is None:
        system_prompt = _DEFAULT_SYSTEM_PROMPT

    # Start from a clean allocator state so the peak-memory figure belongs to this model only.
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map='cuda', trust_remote_code=True
    )

    results = []
    total_time = 0
    total_tokens = 0
    try:
        model.eval()
        mem_loaded = torch.cuda.memory_allocated() / 1024**3

        eos_id = tokenizer.eos_token_id
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

        for user_input in test_inputs:
            messages = [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': user_input},
            ]
            text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(text, return_tensors='pt')
            inputs = {k: v.to('cuda') for k, v in inputs.items()}
            input_len = inputs['input_ids'].shape[1]

            # Greedy decoding keeps the comparison between ranks deterministic.
            t0 = time.time()
            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,
                    num_beams=1,
                    eos_token_id=eos_id,
                    pad_token_id=pad_id,
                )
            dt = time.time() - t0
            total_time += dt

            gen_tokens = output_ids.shape[1] - input_len
            total_tokens += gen_tokens

            # Decode only the newly generated tokens, not the prompt.
            output_text = tokenizer.decode(output_ids[0][input_len:], skip_special_tokens=True)
            results.append(_score_generation(output_text, gen_tokens, max_new_tokens, dt))

        mem_peak = torch.cuda.max_memory_allocated() / 1024**3
    finally:
        # Release the model even when generation fails, so the next rank can load.
        del model, tokenizer
        gc.collect()
        torch.cuda.empty_cache()

    # Aggregate the per-sample records.
    n = len(results)
    json_valid_rate = sum(1 for r in results if r['json_valid']) / n
    has_results_rate = sum(1 for r in results if r['has_results']) / n
    has_recs_rate = sum(1 for r in results if r['has_recommendations']) / n
    has_short_rate = sum(1 for r in results if r['has_recomm_short']) / n
    results_3_rate = sum(1 for r in results if r['results_count'] >= 3) / n
    hit_max_new_tokens_rate = sum(1 for r in results if r['hit_max_new_tokens']) / n

    # The probability check only considers samples that reported a positive sum.
    prob_valid = [r for r in results if r['has_results'] and r['prob_sum'] > 0]
    prob_close_to_1 = sum(1 for r in prob_valid if abs(r['prob_sum'] - 1.0) <= 0.05) / max(len(prob_valid), 1)

    return {
        'json_valid_rate': round(json_valid_rate, 3),
        'has_results_rate': round(has_results_rate, 3),
        'results_3_rate': round(results_3_rate, 3),
        'has_recs_rate': round(has_recs_rate, 3),
        'has_short_rate': round(has_short_rate, 3),
        'prob_close_to_1_rate': round(prob_close_to_1, 3),
        'hit_max_new_tokens_rate': round(hit_max_new_tokens_rate, 3),
        'avg_time_s': round(total_time / n, 2),
        'avg_tokens': round(total_tokens / n, 1),
        'tok_per_s': round(total_tokens / max(total_time, 0.01), 1),
        'vram_loaded_gb': round(mem_loaded, 2),
        'vram_peak_gb': round(mem_peak, 2),
        'details': results,
    }

In [None]:
# 9. Evaluate every trained rank
eval_results = {}

for rank in trained_ranks:
    merged_dir = f'{ABLATION_MERGED_ROOT}/rank{rank}'
    # A rank whose merge failed or was never run has no merged directory.
    if not os.path.isdir(merged_dir):
        print(f'rank={rank}: merged model not found, skipping')
        continue

    print('\n' + '=' * 50)
    print(f'Evaluating rank={rank}')
    print('=' * 50)

    metrics = evaluate_model(merged_dir, test_inputs, max_new_tokens=1024)
    metrics['rank'] = rank
    metrics['alpha'] = rank * 2  # alpha is fixed at 2 x rank throughout the study
    eval_results[rank] = metrics

    print(f'  JSON valid:   {metrics["json_valid_rate"]*100:.0f}%')
    print(f'  Has results:  {metrics["has_results_rate"]*100:.0f}%')
    print(f'  Results≥3:    {metrics["results_3_rate"]*100:.0f}%')
    print(f'  Has recs:     {metrics["has_recs_rate"]*100:.0f}%')
    print(f'  Prob≈1.0:     {metrics["prob_close_to_1_rate"]*100:.0f}%')
    print(f'  Hit max tok:  {metrics["hit_max_new_tokens_rate"]*100:.0f}%')
    print(f'  Avg time:     {metrics["avg_time_s"]}s ({metrics["tok_per_s"]} tok/s)')
    print(f'  VRAM peak:    {metrics["vram_peak_gb"]} GB')

print('\n' + '=' * 50)
print(f'Evaluation complete! ({len(eval_results)} ranks)')

In [None]:
# 10. Summary table
import pandas as pd

rows = []
for rank in sorted(eval_results.keys()):
    m = eval_results[rank]
    # GPU stats recorded at training time (absent when the stats cell was never run).
    train_gpu = {}
    gpu_stats_path = f'{ABLATION_OUTPUT_ROOT}/rank{rank}/ablation_gpu_stats.json'
    if os.path.exists(gpu_stats_path):
        with open(gpu_stats_path, encoding='utf-8') as f:
            train_gpu = json.load(f)

    # Final training loss = last logged entry that carries a 'loss' value.
    final_loss = None
    state_path = f'{ABLATION_OUTPUT_ROOT}/rank{rank}/trainer_state.json'
    if os.path.exists(state_path):
        with open(state_path, encoding='utf-8') as f:
            state = json.load(f)
        losses = [e['loss'] for e in state.get('log_history', []) if 'loss' in e]
        if losses:
            final_loss = losses[-1]

    # The stats file may hold null for the peak when the training script did
    # not report it; show that as N/A rather than None.
    train_peak_gb = train_gpu.get('peak_allocated_gb')

    rows.append({
        'Rank': rank,
        'Alpha': rank * 2,
        # `is not None` so that a legitimate loss of 0.0 is not reported as missing.
        'Final Loss': round(final_loss, 4) if final_loss is not None else 'N/A',
        'JSON Valid%': f'{m["json_valid_rate"]*100:.0f}%',
        'Results≥3%': f'{m["results_3_rate"]*100:.0f}%',
        'Has Recs%': f'{m["has_recs_rate"]*100:.0f}%',
        'Prob≈1.0%': f'{m["prob_close_to_1_rate"]*100:.0f}%',
        'Infer(s)': m['avg_time_s'],
        'Tok/s': m['tok_per_s'],
        'Train VRAM(GB)': train_peak_gb if train_peak_gb is not None else 'N/A',
        'Infer VRAM(GB)': m['vram_peak_gb'],
        'Train Time(min)': train_gpu.get('training_time_min', 'N/A'),
    })

df = pd.DataFrame(rows)
print('\n=== LoRA Rank Ablation Results ===')
display(df)

In [None]:
# 11. Visualization
import matplotlib.pyplot as plt
import numpy as np


def _load_json_or_empty(path):
    """Return the parsed JSON file at `path`, or {} when the file does not exist."""
    if not os.path.exists(path):
        return {}
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _bar_panel(ax, positions, tick_labels, values, color, title, ylabel,
               label_fmt, label_offset, ylim=None, label_zero=True):
    """Draw one bar chart of per-rank values on `ax` and annotate each bar."""
    bars = ax.bar(positions, values, color=color)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels)
    ax.set_xlabel('LoRA Rank')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if ylim is not None:
        ax.set_ylim(*ylim)
    for bar, value in zip(bars, values):
        if value > 0 or label_zero:
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                    label_fmt.format(value), ha='center', fontsize=10)


ranks = sorted(eval_results.keys())
x = np.arange(len(ranks))
labels = [str(r) for r in ranks]

fig, axes = plt.subplots(2, 3, figsize=(18, 10))

# (1)-(3) Percentage metrics: JSON validity, >= 3 results, probability sum ~ 1.0
for ax, key, color, title in (
    (axes[0, 0], 'json_valid_rate', '#4CAF50', 'JSON Format Valid Rate'),
    (axes[0, 1], 'results_3_rate', '#2196F3', 'Results ≥ 3 Conditions Rate'),
    (axes[0, 2], 'prob_close_to_1_rate', '#FF9800', 'Probability Sum ≈ 1.0 Rate (±0.05)'),
):
    vals = [eval_results[r][key] * 100 for r in ranks]
    _bar_panel(ax, x, labels, vals, color, title, '%', '{:.0f}%', 1, ylim=(0, 105))

# (4) Training loss curves, one line per rank that has a trainer_state.json
ax = axes[1, 0]
for rank in ranks:
    state = _load_json_or_empty(f'{ABLATION_OUTPUT_ROOT}/rank{rank}/trainer_state.json')
    logs = [e for e in state.get('log_history', []) if 'loss' in e]
    if logs:
        steps = [e['step'] for e in logs]
        losses = [e['loss'] for e in logs]
        ax.plot(steps, losses, linewidth=1.5, label=f'rank={rank}')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Curves')
ax.legend()
ax.grid(True, alpha=0.3)

# (5) Training peak VRAM. A missing file, a missing key or a null value all
# plot as 0 and get no label; `or 0` covers the null case, which would
# otherwise break both the bar call and the `> 0` comparison.
train_vrams = [
    _load_json_or_empty(
        f'{ABLATION_OUTPUT_ROOT}/rank{rank}/ablation_gpu_stats.json'
    ).get('peak_allocated_gb') or 0
    for rank in ranks
]
_bar_panel(axes[1, 1], x, labels, train_vrams, '#9C27B0', 'Training Peak VRAM (GB)',
           'GB', '{:.1f}', 0.1, label_zero=False)

# (6) Inference throughput
vals = [eval_results[r]['tok_per_s'] for r in ranks]
_bar_panel(axes[1, 2], x, labels, vals, '#F44336', 'Inference Throughput (tok/s)',
           'Tokens/s', '{:.0f}', 0.5)

plt.suptitle('LoRA Rank Ablation Study — diagnosis_generator (Qwen3-1.7B)', fontsize=14, y=1.02)
plt.tight_layout()
# Saved relative to the current working directory; the report cell copies it to Drive.
plt.savefig('ablation_rank_results.png', dpi=150, bbox_inches='tight')
plt.show()
print('Chart saved to ablation_rank_results.png')

In [None]:
# 12. Save the full report to Google Drive
import json, shutil

report_dir = f'{DRIVE_ROOT}/docs/ablation'
os.makedirs(report_dir, exist_ok=True)

# Raw metrics as JSON, without the bulky per-sample 'details' list.
report_data = {
    rank: {k: v for k, v in m.items() if k != 'details'}
    for rank, m in eval_results.items()
}

json_path = f'{report_dir}/ablation_results.json'
# Explicit UTF-8: ensure_ascii=False may emit non-ASCII characters.
with open(json_path, 'w', encoding='utf-8') as f:
    json.dump(report_data, f, indent=2, ensure_ascii=False)
print(f'JSON saved to {json_path}')

# Markdown report
md_lines = [
    '# LoRA Rank Ablation Results',
    '',
    '## 实验配置',
    f'- Agent: diagnosis_generator (Qwen3-1.7B)',
    f'- Data: ~800 SFT samples',
    f'- Epochs: 2, LR: 1e-4, Batch: 2×8=16',
    f'- Alpha = Rank × 2 (固定比值 2.0)',
    f'- 测试样本: {len(test_inputs)} 条',
    '',
    '## 结果对比',
    '',
    '| Rank | Alpha | Final Loss | JSON Valid | Results≥3 | Prob≈1.0 | Tok/s | Train VRAM(GB) |',
    '|------|-------|-----------|-----------|----------|---------|-------|---------------|',
]

for rank in sorted(eval_results.keys()):
    m = eval_results[rank]
    # Final loss = last logged entry that carries a 'loss' value.
    fl = 'N/A'
    sp = f'{ABLATION_OUTPUT_ROOT}/rank{rank}/trainer_state.json'
    if os.path.exists(sp):
        with open(sp, encoding='utf-8') as f:
            st = json.load(f)
        ls = [e['loss'] for e in st.get('log_history', []) if 'loss' in e]
        if ls:
            fl = f'{ls[-1]:.4f}'
    # Training peak VRAM; a missing key or a null value is reported as N/A,
    # matching the placeholder used for a missing final loss.
    tv = 'N/A'
    gp = f'{ABLATION_OUTPUT_ROOT}/rank{rank}/ablation_gpu_stats.json'
    if os.path.exists(gp):
        with open(gp, encoding='utf-8') as f:
            gs = json.load(f)
        peak_gb = gs.get('peak_allocated_gb')
        if peak_gb is not None:
            tv = f'{peak_gb}'

    md_lines.append(
        f'| {rank} | {rank*2} | {fl} | '
        f'{m["json_valid_rate"]*100:.0f}% | '
        f'{m["results_3_rate"]*100:.0f}% | '
        f'{m["prob_close_to_1_rate"]*100:.0f}% | '
        f'{m["tok_per_s"]} | {tv} |'
    )

md_path = f'{report_dir}/ablation_results.md'
# Explicit UTF-8: the report contains Chinese headings and symbols such as ≥ and ≈.
with open(md_path, 'w', encoding='utf-8') as f:
    f.write('\n'.join(md_lines))
print(f'Markdown saved to {md_path}')


# Copy the chart produced by the visualization cell, when it exists.
if os.path.exists('ablation_rank_results.png'):
    shutil.copy('ablation_rank_results.png', f'{report_dir}/ablation_rank_results.png')
    print(f'Chart copied to {report_dir}/')

print('\nAll done! Results saved to Google Drive.')