<a href="https://colab.research.google.com/github/ArnyWu/DeepGenerativeModels_HW5/blob/main/HW9.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install the notebook's dependencies.
# `%pip` (instead of `!pip`) installs into the environment of the running kernel,
# so the packages are importable from the cells below.
%pip install datasets transformers peft bitsandbytes scikit-learn seaborn matplotlib torch accelerate

Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [None]:
# Force a clean reinstall of bitsandbytes.
# `-y` answers pip's "Proceed (Y/n)?" confirmation, so the cell does not block
# waiting for keyboard input during "Restart & Run All".
%pip uninstall -y bitsandbytes
%pip install -U bitsandbytes

Found existing installation: bitsandbytes 0.48.2
Uninstalling bitsandbytes-0.48.2:
  Would remove:
    /usr/local/lib/python3.12/dist-packages/bitsandbytes-0.48.2.dist-info/*
    /usr/local/lib/python3.12/dist-packages/bitsandbytes/*
Proceed (Y/n)? y
  Successfully uninstalled bitsandbytes-0.48.2
Collecting bitsandbytes
  Using cached bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Using cached bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.48.2


In [None]:
import warnings
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    pipeline,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType
from sklearn.metrics import (
    f1_score,
    roc_auc_score,
    average_precision_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from sklearn.preprocessing import label_binarize
import re # 用於解析

# Runtime switches: turn off parallelism in the tokenizers library and silence
# every Python warning for the rest of the session.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# --- Configuration ---
# Gemma 2 (2B) checkpoints used by this notebook.
GEMMA_MODEL_NAME = "google/gemma-2-2b"  # base checkpoint, fine-tuned as a sequence classifier
GEMMA_CHAT_MODEL_NAME = "google/gemma-2-2b-it"  # instruction-tuned checkpoint for zero/few-shot prompts
TRAIN_EPOCHS = 1  # a single epoch keeps the run short

# Risk classes; a class id is the position of its name in LABELS.
LABELS = ["low_risk", "mid_risk", "high_risk"]
NUM_LABELS = len(LABELS)
ID2LABEL = dict(enumerate(LABELS))
LABEL2ID = {name: class_id for class_id, name in ID2LABEL.items()}

# Pick the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type == "cpu":
    print("Warning: GPU not available. QLoRA fine-tuning requires a GPU.")

Using device: cuda


In [None]:
# --- 1. Preprocessing ---
print("\n--- Step 1: Preprocessing ---")

# Load the dair-ai/emotion dataset and show a summary of what was loaded.
dataset = load_dataset("dair-ai/emotion")
print("Dataset loaded:", dataset, sep="\n")

# 定義風險映射函式
def map_emotion_to_risk(example):
    """Add a ``risk_label`` to one example, derived from its emotion id in ``label``.

    Mapping (emotion id -> risk_label):
        1 (joy), 2 (love), 5 (surprise) -> 0 (low_risk)
        3 (anger), 4 (fear)             -> 1 (mid_risk)
        0 (sadness)                     -> 2 (high_risk)
        any other value                 -> -1 (sentinel for "unmapped")

    The example dict is modified in place and the same object is returned.
    """
    emotion_id = example['label']
    # (risk class, emotion ids belonging to it), checked in this order.
    risk_groups = (
        (0, (1, 2, 5)),
        (1, (3, 4)),
        (2, (0,)),
    )
    risk = -1
    for risk_id, emotion_ids in risk_groups:
        if emotion_id in emotion_ids:
            risk = risk_id
            break
    example['risk_label'] = risk
    return example

# Attach the risk class to every example, drop rows flagged with the -1 sentinel,
# and expose the risk class under the column name "labels".
dataset = dataset.map(map_emotion_to_risk)
dataset = dataset.filter(lambda x: x['risk_label'] != -1)
dataset = dataset.rename_column("risk_label", "labels")

print("Dataset after risk mapping:")
print(dataset['train'][0])

# Tokenizer of the base Gemma checkpoint (used for the QLoRA classifier).
tokenizer = AutoTokenizer.from_pretrained(GEMMA_MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # Fall back to EOS as the padding token when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of examples.

    Reads ``examples['text']`` and truncates each sequence to at most 512 tokens.
    No padding is applied here (``padding=False``).
    """
    return tokenizer(examples['text'], truncation=True, padding=False, max_length=512)


tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Remove the raw text and the original emotion id from the tokenized copy.
tokenized_datasets = tokenized_datasets.remove_columns(["text", "label"])
tokenized_datasets.set_format("torch")

# Full test split, shared by every evaluation in the notebook.
test_dataset_full_tokenized = tokenized_datasets["test"]
original_test_set_full = dataset['test']  # still contains the raw text
# Materialise the ground-truth labels as a plain list: depending on the installed
# `datasets` version, column access may return a lazy Column wrapper rather than a
# list, and the metric helpers should receive a standard sequence.
y_true_test_full = list(original_test_set_full['labels'])

print(f"Full test set size: {len(original_test_set_full)}")

# --- 評估用的輔助函式 ---
def evaluate_model(y_true, y_pred, y_probs, title_prefix=""):
    """Print classification metrics, save a confusion-matrix plot, and return the metrics.

    Args:
        y_true: Ground-truth class ids (0, 1 or 2), one per sample.
        y_pred: Predicted class ids, same length as ``y_true``.
        y_probs: Array-like of per-class scores with one row per sample and one
            column per class, or ``None`` when only hard predictions exist.
            AUROC / PR-AUC are computed only when this is given and ``y_true``
            contains more than one class.
        title_prefix: Text used in the printed header, the plot title and the
            file name of the saved confusion matrix.

    Returns:
        dict with keys ``"f1_weighted"``, ``"confusion_matrix"``,
        ``"auroc_weighted"`` and ``"pr_auc_weighted"``. Entries that were skipped
        or could not be computed are ``None``.
    """
    # Normalise inputs once so lists and other sequences behave like arrays below.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    metrics = {
        "f1_weighted": None,
        "confusion_matrix": None,
        "auroc_weighted": None,
        "pr_auc_weighted": None,
    }

    print(f"\n--- {title_prefix} Evaluation ---")

    # Weighted F1
    f1 = f1_score(y_true, y_pred, average='weighted')
    metrics["f1_weighted"] = f1
    print(f"F1 Score (Weighted): {f1:.4f}")

    # Confusion matrix (best effort: a plotting failure must not abort the evaluation)
    fig = None
    try:
        cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
        metrics["confusion_matrix"] = cm
        print("Confusion Matrix:")
        print(cm)

        fig, ax = plt.subplots()
        disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=LABELS)
        disp.plot(ax=ax, cmap='Blues')
        plt.title(f"{title_prefix} Confusion Matrix")
        plt.xticks(rotation=45)
        plt.tight_layout()
        cm_filename = f"plot_cm_{title_prefix.lower().replace(' ', '_').replace('-', '_')}.png"
        plt.savefig(cm_filename)
        print(f"Saved confusion matrix to {cm_filename}")
    except Exception as e:
        print(f"Error plotting confusion matrix: {e}")
    finally:
        # Release the figure even when plotting or saving failed.
        if fig is not None:
            plt.close(fig)

    # Metrics that need per-class scores
    if y_probs is not None and len(np.unique(y_true)) > 1:
        try:
            y_probs = np.asarray(y_probs)
            if y_probs.ndim != 2:
                raise ValueError(f"y_probs must be 2-dimensional, got shape {y_probs.shape}")

            # One-vs-rest indicator matrix of the true labels
            y_true_bin = label_binarize(y_true, classes=[0, 1, 2])
            n_classes = y_true_bin.shape[1]

            if y_probs.shape[1] != n_classes:
                print(f"Warning: Probability shape mismatch. Adjusting.")
                # Keep the columns both sides share; missing classes get score 0.
                adjusted = np.zeros((y_probs.shape[0], n_classes))
                n_shared = min(y_probs.shape[1], n_classes)
                adjusted[:, :n_shared] = y_probs[:, :n_shared]
                y_probs = adjusted

            # AUROC
            auroc = roc_auc_score(y_true_bin, y_probs, average='weighted', multi_class='ovr')
            metrics["auroc_weighted"] = auroc
            print(f"AUROC (Weighted, OVR): {auroc:.4f}")

            # PR-AUC
            pr_auc = average_precision_score(y_true_bin, y_probs, average='weighted')
            metrics["pr_auc_weighted"] = pr_auc
            print(f"PR-AUC (Weighted): {pr_auc:.4f}")

        except ValueError as e:
            print(f"Could not calculate AUROC/PR-AUC: {e}")
    else:
        print("Skipping AUROC/PR-AUC (no probabilities or only one class present).")

    return metrics


# --- 共用的輔助函式 ---
def parse_response(response):
    """Map a free-text model reply to a risk class id (0, 1 or 2).

    Matching is case-insensitive and proceeds in this order:
      1. A label name or its parenthesised id, checked as high_risk/"(2)",
         then mid_risk/"(1)", then low_risk/"(0)". Label names are also
         recognised when written with spaces, hyphens or no separator
         ("high risk", "High-Risk", "highrisk").
      2. The first standalone digit 0, 1 or 2.
      3. Fallback: 0 (low_risk). Note that replies that cannot be parsed are
         therefore counted as low_risk predictions.

    Args:
        response: The decoded text generated by the model.

    Returns:
        int: 0, 1 or 2.
    """
    text = response.lower()
    # Canonicalise separator variants so "high risk" / "high-risk" match "high_risk".
    text = re.sub(r'\b(low|mid|high)[\s_\-]*risk', r'\1_risk', text)

    for class_id, label_name in ((2, "high_risk"), (1, "mid_risk"), (0, "low_risk")):
        if label_name in text or f"({class_id})" in text:
            return class_id

    # No label name found: look for a bare class id.
    match = re.search(r'\b([012])\b', text)
    if match:
        return int(match.group(1))

    return 0  # nothing recognisable: default to low_risk


--- Step 1: Preprocessing ---
Dataset loaded:
DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 16000
    })
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 2000
    })
})
Dataset after risk mapping:
{'text': 'i didnt feel humiliated', 'label': 0, 'labels': 2}


Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Full test set size: 2000


In [None]:
# --- 2. Zero-shot inference (Gemma chat model) ---
print(
    "\n--- Step 2: Zero-shot Inference (Gemma-Chat, Full 2000) ---",
    "WARNING: This step will run on all 2,000 test samples and will be EXTREMELY SLOW (potentially hours).",
    sep="\n",
)

def create_zero_shot_prompt(new_text):
    """Build the zero-shot prompt (task instruction only, no worked examples).

    The prompt is the instruction, a ``===`` separator, and then the text to
    classify followed by an open ``Risk:`` field for the model to complete.
    """
    segments = (
        "This is a text classification task. Classify the text into one of three risk categories: low_risk (0), mid_risk (1), or high_risk (2).\n\n",
        "===\n\n",
        f"Text: {new_text}\nRisk:",
    )
    return "".join(segments)

try:
    # Load the instruction-tuned Gemma checkpoint and its tokenizer.
    chat_tokenizer = AutoTokenizer.from_pretrained(GEMMA_CHAT_MODEL_NAME, trust_remote_code=True)
    chat_model = AutoModelForCausalLM.from_pretrained(
        GEMMA_CHAT_MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    if chat_tokenizer.pad_token is None:
        chat_tokenizer.pad_token = chat_tokenizer.eos_token

    print(f"Zero-shot model loaded on device: {chat_model.device}")

    y_pred_zero_shot = []

    # Classify every sample of the full test split, one prompt at a time.
    for i, test_sample in enumerate(original_test_set_full):
        if (i+1) % 50 == 0 or i == 0:
            print(f"Running Zero-shot sample {i+1}/{len(original_test_set_full)}...")

        prompt_text = create_zero_shot_prompt(test_sample['text'])

        # Wrap the prompt as a single user turn in the model's chat format.
        messages = [
            {"role": "user", "content": prompt_text}
        ]
        text = chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered chat template already carries the special tokens it needs,
        # so the tokenizer must not add them again (that would duplicate BOS).
        model_inputs = chat_tokenizer(
            [text], return_tensors="pt", add_special_tokens=False
        ).to(chat_model.device)

        # Pass the whole encoding so generate() also receives the attention mask;
        # decode greedily so repeated runs give the same predictions.
        generated_ids = chat_model.generate(
            **model_inputs,
            max_new_tokens=10,  # only the label is needed
            do_sample=False,
        )
        # generate() returns prompt + continuation; keep only the continuation.
        prompt_len = model_inputs["input_ids"].shape[1]
        response = chat_tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

        y_pred_zero_shot.append(parse_response(response))

    # Evaluate (hard predictions only, no probabilities)
    evaluate_model(y_true_test_full, y_pred_zero_shot, y_probs=None, title_prefix="Zero-Shot (Gemma-Chat Full)")

except Exception as e:
    print(f"Error during Zero-shot inference: {e}. Skipping.")
finally:
    # Free VRAM for the next step, whether or not inference succeeded.
    if 'chat_model' in locals(): del chat_model
    if 'chat_tokenizer' in locals(): del chat_tokenizer
    torch.cuda.empty_cache()


--- Step 2: Zero-shot Inference (Gemma-Chat, Full 2000) ---


tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

Zero-shot model loaded on device: cuda:0
Running Zero-shot sample 1/2000...
Running Zero-shot sample 50/2000...
Running Zero-shot sample 100/2000...
Running Zero-shot sample 150/2000...
Running Zero-shot sample 200/2000...
Running Zero-shot sample 250/2000...
Running Zero-shot sample 300/2000...
Running Zero-shot sample 350/2000...
Running Zero-shot sample 400/2000...
Running Zero-shot sample 450/2000...
Running Zero-shot sample 500/2000...
Running Zero-shot sample 550/2000...
Running Zero-shot sample 600/2000...
Running Zero-shot sample 650/2000...
Running Zero-shot sample 700/2000...
Running Zero-shot sample 750/2000...
Running Zero-shot sample 800/2000...
Running Zero-shot sample 850/2000...
Running Zero-shot sample 900/2000...
Running Zero-shot sample 950/2000...
Running Zero-shot sample 1000/2000...
Running Zero-shot sample 1050/2000...
Running Zero-shot sample 1100/2000...
Running Zero-shot sample 1150/2000...
Running Zero-shot sample 1200/2000...
Running Zero-shot sample 1250/20

In [None]:
# --- 3. Few-shot inference (Gemma chat model, balanced examples) ---
print("\n--- Step 3: Few-shot Inference (Gemma-Chat, Full 2000) ---")
print("INFO: Using balanced 6-shot examples (one from each emotion).")
print("WARNING: This step will also be EXTREMELY SLOW (potentially hours).")

# NOTE(review): the selection below no longer calls concatenate_datasets; the import
# is kept in case other cells rely on the name — confirm and remove if unused.
from datasets import concatenate_datasets

# Work on the raw-text training split (it still has the original emotion id in 'label').
original_train_set = dataset['train']

# Balanced sampling: take the first training example of each original emotion.
# Emotion ids: 0: sadness, 1: joy, 2: love, 3: anger, 4: fear, 5: surprise
print("Selecting 6 balanced few-shot examples...")
# A single scan over the label column replaces six full filter passes.
first_row_by_emotion = {}
for row_idx, emotion_id in enumerate(original_train_set['label']):
    if emotion_id not in first_row_by_emotion:
        first_row_by_emotion[emotion_id] = row_idx
        if all(e in first_row_by_emotion for e in range(6)):
            break  # every emotion has its example; stop scanning

missing_emotions = [e for e in range(6) if e not in first_row_by_emotion]
if missing_emotions:
    raise ValueError(f"No training example found for emotion id(s): {missing_emotions}")

# Rows in emotion-id order 0..5, matching the order of the id comment above.
few_shot_examples = original_train_set.select([first_row_by_emotion[e] for e in range(6)])
print("Balanced examples selected:")
# Print plain lists so all six entries are shown in full.
print(list(few_shot_examples['text']))
print(list(few_shot_examples['labels']))  # expected risk ids: [2, 0, 0, 1, 1, 0]

def create_few_shot_prompt(examples, new_text):
    """Build the few-shot prompt: instruction, worked examples, then the query.

    Args:
        examples: Iterable of rows, each providing ``'text'`` and ``'labels'``
            (the risk class id, used as a key into ``ID2LABEL``).
        new_text: The text the model should classify.

    Returns:
        The prompt string, ending with an open ``Risk:`` field.
    """
    instruction = "This is a text classification task. Classify the text into one of three risk categories: low_risk (0), mid_risk (1), or high_risk (2).\n\n"
    # One "Text / Risk" pair per example, in the order given.
    demonstrations = [
        f"Text: {ex['text']}\nRisk: {ID2LABEL[ex['labels']]} ({ex['labels']})\n\n"
        for ex in examples
    ]
    query = f"Text: {new_text}\nRisk:"
    return instruction + "".join(demonstrations) + "===\n\n" + query

try:
    # Load the instruction-tuned Gemma checkpoint and its tokenizer.
    chat_tokenizer = AutoTokenizer.from_pretrained(GEMMA_CHAT_MODEL_NAME, trust_remote_code=True)
    chat_model = AutoModelForCausalLM.from_pretrained(
        GEMMA_CHAT_MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    if chat_tokenizer.pad_token is None:
        chat_tokenizer.pad_token = chat_tokenizer.eos_token

    print(f"Few-shot model loaded on device: {chat_model.device}")

    y_pred_few_shot = []

    # Classify every sample of the full test split, one prompt at a time.
    for i, test_sample in enumerate(original_test_set_full):
        if (i+1) % 50 == 0 or i == 0:
            print(f"Running Few-shot sample {i+1}/{len(original_test_set_full)}...")

        prompt_text = create_few_shot_prompt(few_shot_examples, test_sample['text'])

        # Wrap the prompt as a single user turn in the model's chat format.
        messages = [
            {"role": "user", "content": prompt_text}
        ]
        text = chat_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered chat template already carries the special tokens it needs,
        # so the tokenizer must not add them again (that would duplicate BOS).
        model_inputs = chat_tokenizer(
            [text], return_tensors="pt", add_special_tokens=False
        ).to(chat_model.device)

        # Pass the whole encoding so generate() also receives the attention mask;
        # decode greedily so repeated runs give the same predictions.
        generated_ids = chat_model.generate(
            **model_inputs,
            max_new_tokens=10,  # only the label is needed
            do_sample=False,
        )
        # generate() returns prompt + continuation; keep only the continuation.
        prompt_len = model_inputs["input_ids"].shape[1]
        response = chat_tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

        y_pred_few_shot.append(parse_response(response))

    # Evaluate (hard predictions only, no probabilities)
    evaluate_model(y_true_test_full, y_pred_few_shot, y_probs=None, title_prefix="Few-Shot (Gemma-Chat Full 6-Shot Balanced)")

except Exception as e:
    print(f"Error during Few-shot inference: {e}. Skipping.")
finally:
    # Free VRAM, whether or not inference succeeded.
    if 'chat_model' in locals(): del chat_model
    if 'chat_tokenizer' in locals(): del chat_tokenizer
    torch.cuda.empty_cache()


--- Step 3: Few-shot Inference (Gemma-Chat, Full 2000) ---
INFO: Using balanced 6-shot examples (one from each emotion).
Selecting 6 balanced few-shot examples...


Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

Balanced examples selected:
Column(['i didnt feel humiliated', 'i have been with petronas for years i feel that petronas has performed well and made a huge profit', 'i am ever feeling nostalgic about the fireplace i will know that it is still on the property', 'im grabbing a minute to post i feel greedy wrong', 'i feel as confused about life as a teenager or as jaded as a year old man'])
Column([2, 0, 0, 1, 1])


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Few-shot model loaded on device: cuda:0
Running Few-shot sample 1/2000...
Running Few-shot sample 50/2000...
Running Few-shot sample 100/2000...
Running Few-shot sample 150/2000...
Running Few-shot sample 200/2000...
Running Few-shot sample 250/2000...
Running Few-shot sample 300/2000...
Running Few-shot sample 350/2000...
Running Few-shot sample 400/2000...
Running Few-shot sample 450/2000...
Running Few-shot sample 500/2000...
Running Few-shot sample 550/2000...
Running Few-shot sample 600/2000...
Running Few-shot sample 650/2000...
Running Few-shot sample 700/2000...
Running Few-shot sample 750/2000...
Running Few-shot sample 800/2000...
Running Few-shot sample 850/2000...
Running Few-shot sample 900/2000...
Running Few-shot sample 950/2000...
Running Few-shot sample 1000/2000...
Running Few-shot sample 1050/2000...
Running Few-shot sample 1100/2000...
Running Few-shot sample 1150/2000...
Running Few-shot sample 1200/2000...
Running Few-shot sample 1250/2000...
Running Few-shot samp

In [None]:
# --- 4. LoRA / QLoRA fine-tuning (Gemma base checkpoint) ---
print("\n--- Step 4: LoRA / QLoRA Fine-tuning (Gemma-Base) ---")
try:
    # QLoRA quantization: 4-bit NF4 weights with double quantization, bfloat16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Base Gemma checkpoint with a sequence-classification head for the 3 risk classes.
    qlora_model = AutoModelForSequenceClassification.from_pretrained(
        GEMMA_MODEL_NAME,
        num_labels=NUM_LABELS,
        quantization_config=bnb_config,
        device_map="auto",
        id2label=ID2LABEL,
        label2id=LABEL2ID,
        trust_remote_code=True
    )

    # Tell the model which token id is padding, consistent with the tokenizer.
    qlora_model.config.pad_token_id = tokenizer.pad_token_id

    # LoRA adapters on the attention projections (q/k/v/o) and the MLP projections
    # (gate/up/down). The former "gemma_act" entry was dropped: it does not appear
    # to name any Gemma submodule, so it should have added no adapter.
    lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )

    # Wrap the quantized model with the LoRA adapters.
    peft_model = get_peft_model(qlora_model, lora_config)
    peft_model.print_trainable_parameters()

    # Pads each batch dynamically (the dataset was tokenized without padding).
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Training configuration: evaluate and checkpoint once per epoch, then reload the best checkpoint.
    training_args = TrainingArguments(
        output_dir="./gemma-lora-emotion-risk",
        learning_rate=2e-4,  # LoRA tolerates a comparatively high learning rate
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=TRAIN_EPOCHS,
        weight_decay=0.01,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        logging_steps=100,
        report_to="none",  # no external experiment tracker (e.g. wandb)
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    )

    # `processing_class` is the current name of the argument formerly called
    # `tokenizer`, which is deprecated in recent transformers releases.
    trainer = Trainer(
        model=peft_model,
        args=training_args,
        train_dataset=tokenized_datasets["train"],
        eval_dataset=tokenized_datasets["validation"],
        processing_class=tokenizer,
        data_collator=data_collator,
    )

    # --- 5. (implicit) & 6. Training and evaluation ---
    print("Starting QLoRA fine-tuning...")
    trainer.train()
    print("Fine-tuning complete.")

    print("\n--- Step 6: Evaluating QLoRA Model ---")

    # Predict on the tokenized full test split.
    predictions = trainer.predict(test_dataset_full_tokenized)

    # predictions.predictions holds the logits: argmax -> class id, softmax -> probabilities.
    y_pred_lora = np.argmax(predictions.predictions, axis=1)
    y_probs_lora_logits = torch.from_numpy(predictions.predictions)
    y_probs_lora = torch.nn.functional.softmax(y_probs_lora_logits, dim=1).numpy()
    y_true_lora = predictions.label_ids  # labels of the same test split as y_true_test_full

    # Evaluate
    evaluate_model(y_true_lora, y_pred_lora, y_probs_lora, title_prefix="QLoRA Fine-Tune (Gemma-Base)")


    # --- 7. Visualisation (QLoRA model outputs) ---
    print("\n--- Step 7: Visualizing QLoRA Results ---")

    # Probability assigned to high_risk (class id 2) for every test sample.
    p_high_risk = y_probs_lora[:, 2]

    # Plot 1: P(high_risk) per test sample
    try:
        plt.figure(figsize=(15, 5))
        plt.plot(p_high_risk, alpha=0.7, label="P(high_risk)")
        plt.title("High Risk Probability (P(high_risk)) Trend (Gemma-Base QLoRA Model)")
        plt.xlabel("Test Sample Index")
        plt.ylabel("P(high_risk)")
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig("plot_high_risk_trend_gemma.png")
        print("Saved high-risk trend plot to plot_high_risk_trend_gemma.png")
        plt.close()
    except Exception as e:
        print(f"Error plotting trend: {e}")

    # Plot 2.1: rolling mean of P(high_risk)
    try:
        # pandas rolling window of 50 samples; min_periods=1 keeps the first points.
        rolling_avg = pd.Series(p_high_risk).rolling(window=50, min_periods=1).mean()
        plt.figure(figsize=(15, 5))
        plt.plot(rolling_avg, color='red')
        plt.title("Rolling Average (Window=50) of P(high_risk) (Gemma-Base QLoRA Model)")
        plt.xlabel("Test Sample Index")
        plt.ylabel("Rolling Avg P(high_risk)")
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.savefig("plot_high_risk_rolling_avg_gemma.png")
        print("Saved high-risk rolling average plot to plot_high_risk_rolling_avg_gemma.png")
        plt.close()
    except Exception as e:
        print(f"Error plotting rolling average: {e}")

    # Plot 2.2: one-row heatmap of P(high_risk)
    try:
        plt.figure(figsize=(18, 2))  # wide and short
        sns.heatmap(
            [p_high_risk],
            cmap="rocket",
            cbar=True,
            cbar_kws={'label': 'P(high_risk)', 'orientation': 'horizontal', 'pad': 0.3},
            xticklabels=False,
            yticklabels=False
        )
        plt.title("High Risk Probability Concentration (Gemma-Base QLoRA Model)")
        plt.xlabel("Test Sample Index")
        plt.tight_layout()
        plt.savefig("plot_high_risk_concentration_heatmap_gemma.png")
        print("Saved high-risk concentration heatmap to plot_high_risk_concentration_heatmap_gemma.png")
        plt.close()
    except Exception as e:
        print(f"Error plotting heatmap: {e}")

except Exception as e:
    print(f"Error during QLoRA fine-tuning or evaluation: {e}")
    print("This may be due to CUDA OOM or other resource constraints.")

print("\n--- Script Finished ---")


--- Step 4: LoRA / QLoRA Fine-tuning (Gemma-Base) ---


model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/481M [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Some weights of Gemma2ForSequenceClassification were not initialized from the model checkpoint at google/gemma-2-2b and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 20,773,632 || all params: 2,635,122,432 || trainable%: 0.7883
Starting QLoRA fine-tuning...


Epoch,Training Loss,Validation Loss
1,0.0664,0.06377


Fine-tuning complete.

--- Step 6: Evaluating QLoRA Model ---



--- QLoRA Fine-Tune (Gemma-Base) Evaluation ---
F1 Score (Weighted): 0.9695
Confusion Matrix:
[[901  15   4]
 [ 11 471  17]
 [  2  12 567]]
Saved confusion matrix to plot_cm_qlora_fine_tune_(gemma_base).png
AUROC (Weighted, OVR): 0.9988
PR-AUC (Weighted): 0.9972

--- Step 7: Visualizing QLoRA Results ---
Saved high-risk trend plot to plot_high_risk_trend_gemma.png
Saved high-risk rolling average plot to plot_high_risk_rolling_avg_gemma.png
Saved high-risk concentration heatmap to plot_high_risk_concentration_heatmap_gemma.png

--- Script Finished ---
