# Knowledge Distillation with CLS Vector Matching
本 notebook 參考 `knowledge-distill.ipynb`，但將蒸餾目標從 soft label 機率改為教師模型的 [CLS] 向量。學生模型將學習模仿教師模型的 [CLS] 表示。

## 01. 環境安裝（Kaggle GPU, CUDA 12.1 相容）

In [1]:
# 安裝必要套件
!pip install -q torch torchvision torchaudio torchmetrics transformers datasets

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m84.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency

## 02. 套件匯入與全域參數

In [2]:
import os, json, random, math, time
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix
from tqdm import tqdm
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torch import nn
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification, get_scheduler
)
from torchmetrics.classification import (
    MulticlassAccuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 03. 公用函式

In [3]:
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_metrics(num_classes, device):
    return {
        'acc': MulticlassAccuracy(num_classes=num_classes, average='micro').to(device),
        'f1_macro': MulticlassF1Score(num_classes=num_classes, average='macro').to(device),
        'f1_weighted': MulticlassF1Score(num_classes=num_classes, average='weighted').to(device),
        'prec_macro': MulticlassPrecision(num_classes=num_classes, average='macro').to(device),
        'recall_macro': MulticlassRecall(num_classes=num_classes, average='macro').to(device),
    }

@torch.no_grad()
def evaluate(model, dataloader, metrics):
    model.eval()
    for metric in metrics.values():
        metric.reset()
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        preds = torch.argmax(outputs.logits, dim=1)
        for name, metric in metrics.items():
            metric.update(preds, labels)
    return {name: metric.compute().item() for name, metric in metrics.items()}

## 04. 載入與準備資料，並產生教師 CLS 向量

In [4]:
df = pd.read_csv('/kaggle/input/taiwan-political-news-dataset/news_training_with_translations.csv')
df['title_content'] = df['title'] + df['content']
df['title_content_en'] = df['title_en'] + df['content_en']

# 載入 Teacher 模型 (取 [CLS] 向量)
teacher_model_path = 'launch/POLITICS'
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_path)
teacher_model = AutoModel.from_pretrained(teacher_model_path).to(device)
teacher_model.eval()

def get_teacher_cls(text):
    inputs = teacher_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = teacher_model(**inputs)
        cls_vec = outputs.last_hidden_state[:, 0, :]  # 取 [CLS]
    return cls_vec.squeeze().cpu().numpy()

# 對整份資料產生 teacher_cls
df['teacher_cls'] = df['title_content_en'].map(lambda x: get_teacher_cls(x).tolist())

tokenizer_config.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/672 [00:00<?, ?B/s]

2025-06-03 12:38:17.476725: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748954297.695736      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748954297.756688      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at launch/POLITICS and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

## 05. 自訂 Dataset 類別 (含 teacher_cls)

In [5]:
class DistillCLSDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_len
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        enc = self.tokenizer(
            row['title_content'], max_length=self.max_length,
            padding='max_length', truncation=True, return_tensors='pt'
        )
        item = {key: val.squeeze() for key, val in enc.items()}
        item['label'] = int(row['label_encoded'])
        item['teacher_cls'] = torch.tensor(row['teacher_cls'], dtype=torch.float)
        return item

## 06. 訓練與蒸餾函式 (MSE loss for CLS)

## 07. 進行 5-Fold Cross-Validation (CLS Vector 蒸餾)

In [6]:
def train(model, dataloader, optimizer, scheduler, alpha=0.5, class_weights=None):
    model.train()
    total_loss = 0
    mse_loss_fn = nn.MSELoss()
    if class_weights is not None:
        class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
        ce_loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor)
    else:
        ce_loss_fn = nn.CrossEntropyLoss()
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        teacher_cls = batch['teacher_cls'].to(device)
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)
        student_cls = outputs.hidden_states[-1][:, 0, :]  # 取 [CLS]
        ce_loss = ce_loss_fn(outputs.logits, labels)
        mse_loss = mse_loss_fn(student_cls, teacher_cls)
        loss = alpha * mse_loss + (1 - alpha) * ce_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

In [7]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_scheduler
from torch.optim import AdamW
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
import torch.nn as nn

# 超參數設定
tag = 'student_ckip_cls'
mname = 'ckiplab/bert-base-chinese'
use_content = True
max_len = 512 if use_content else 128
batch_size = 16
epochs = 5
lr = 2e-5
weight_decay = 0.01
warmup_ratio = 0.1
dropout = 0.1
patience = 2
num_classes = 3

tokenizer = AutoTokenizer.from_pretrained(mname)
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
fold_results = []

for fold, (train_idx, val_idx) in enumerate(kf.split(df, df['label_encoded'])):
    print(f'===== Fold {fold + 1} =====')
    train_df, val_df = df.iloc[train_idx], df.iloc[val_idx]
    labels_in_fold = train_df['label_encoded'].tolist()
    class_weights = compute_class_weight(class_weight='balanced', classes=list(range(num_classes)), y=labels_in_fold)
    print(f'Class weights for fold {fold + 1}: {class_weights}')
    train_dataset = DistillCLSDataset(train_df, tokenizer, max_len=max_len)
    val_dataset = DistillCLSDataset(val_df, tokenizer, max_len=max_len)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    model = AutoModelForSequenceClassification.from_pretrained(
        mname,
        num_labels=num_classes,
        hidden_dropout_prob=dropout,
        attention_probs_dropout_prob=dropout,
    ).to(device)
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    total_steps = len(train_loader) * epochs
    warmup_steps = int(total_steps * warmup_ratio)
    scheduler = get_scheduler(
        'linear',
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )
    metrics = get_metrics(num_classes=num_classes, device=device)
    best_score = None
    patience_counter = 0
    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}')
        train_loss = train(model, train_loader, optimizer, scheduler, class_weights=class_weights)
        eval_result = evaluate(model, val_loader, metrics)
        print(f'Train loss: {train_loss:.4f}, Eval: {eval_result}')
        current_score = eval_result['f1_macro']
        if best_score is None or current_score > best_score:
            best_score = current_score
            patience_counter = 0
            model.save_pretrained(f'./{tag}_fold{fold}')
            tokenizer.save_pretrained(f'./{tag}_fold{fold}')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping at epoch {epoch + 1}')
                break
    fold_results.append(eval_result)

tokenizer_config.json:   0%|          | 0.00/174 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

===== Fold 1 =====
Class weights for fold 1: [1.5234657  0.70686767 1.07653061]


pytorch_model.bin:   0%|          | 0.00/409M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1


  0%|          | 0/159 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/409M [00:00<?, ?B/s]

100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.7675, Eval: {'acc': 0.6403785347938538, 'f1_macro': 0.60112464427948, 'f1_weighted': 0.6284379363059998, 'prec_macro': 0.6356787085533142, 'recall_macro': 0.59917813539505}
Epoch 2


100%|██████████| 159/159 [02:44<00:00,  1.03s/it]


Train loss: 0.4648, Eval: {'acc': 0.7113564610481262, 'f1_macro': 0.7026959657669067, 'f1_weighted': 0.7147173881530762, 'prec_macro': 0.6979424953460693, 'recall_macro': 0.7185636758804321}
Epoch 3


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.3619, Eval: {'acc': 0.7350157499313354, 'f1_macro': 0.7178173065185547, 'f1_weighted': 0.7342883348464966, 'prec_macro': 0.7164022326469421, 'recall_macro': 0.7234266400337219}
Epoch 4


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.2841, Eval: {'acc': 0.7350157499313354, 'f1_macro': 0.723602294921875, 'f1_weighted': 0.7375081777572632, 'prec_macro': 0.7184553146362305, 'recall_macro': 0.7332620620727539}
Epoch 5


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.2214, Eval: {'acc': 0.7334384918212891, 'f1_macro': 0.7257581949234009, 'f1_weighted': 0.7371573448181152, 'prec_macro': 0.7214058041572571, 'recall_macro': 0.7431294918060303}
===== Fold 2 =====
Class weights for fold 2: [1.52406739 0.70714685 1.07558386]


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.7132, Eval: {'acc': 0.5418641567230225, 'f1_macro': 0.5336098670959473, 'f1_weighted': 0.5000643134117126, 'prec_macro': 0.6543775796890259, 'recall_macro': 0.6412322521209717}
Epoch 2


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.4208, Eval: {'acc': 0.7677724957466125, 'f1_macro': 0.7659031748771667, 'f1_weighted': 0.7676311731338501, 'prec_macro': 0.7596344947814941, 'recall_macro': 0.7789400219917297}
Epoch 3


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.3103, Eval: {'acc': 0.7266982793807983, 'f1_macro': 0.726544976234436, 'f1_weighted': 0.726726770401001, 'prec_macro': 0.7352079153060913, 'recall_macro': 0.7695929408073425}
Epoch 4


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.2224, Eval: {'acc': 0.7551342844963074, 'f1_macro': 0.7497968673706055, 'f1_weighted': 0.7548509240150452, 'prec_macro': 0.7487032413482666, 'recall_macro': 0.7576009035110474}
Early stopping at epoch 4
===== Fold 3 =====
Class weights for fold 3: [1.52406739 0.70714685 1.07558386]


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.7145, Eval: {'acc': 0.6793048977851868, 'f1_macro': 0.6313817501068115, 'f1_weighted': 0.6622200608253479, 'prec_macro': 0.7327060103416443, 'recall_macro': 0.6096360683441162}
Epoch 2


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.4199, Eval: {'acc': 0.7203791737556458, 'f1_macro': 0.7152518033981323, 'f1_weighted': 0.7180999517440796, 'prec_macro': 0.7353862524032593, 'recall_macro': 0.7269697785377502}
Epoch 3


100%|██████████| 159/159 [02:44<00:00,  1.04s/it]


Train loss: 0.3099, Eval: {'acc': 0.7677724957466125, 'f1_macro': 0.759808361530304, 'f1_weighted': 0.7685845494270325, 'prec_macro': 0.7544221878051758, 'recall_macro': 0.7670352458953857}
Epoch 4


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.2208, Eval: {'acc': 0.7677724957466125, 'f1_macro': 0.7593727111816406, 'f1_weighted': 0.768185555934906, 'prec_macro': 0.7551373243331909, 'recall_macro': 0.7661914825439453}
Epoch 5


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.1682, Eval: {'acc': 0.7551342844963074, 'f1_macro': 0.7476608157157898, 'f1_weighted': 0.7564740180969238, 'prec_macro': 0.7408839464187622, 'recall_macro': 0.7589603662490845}
Early stopping at epoch 5
===== Fold 4 =====
Class weights for fold 4: [1.52682339 0.70655509 1.07558386]


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.7188, Eval: {'acc': 0.551342785358429, 'f1_macro': 0.5578203201293945, 'f1_weighted': 0.5430013537406921, 'prec_macro': 0.6019439697265625, 'recall_macro': 0.6187837719917297}
Epoch 2


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.3951, Eval: {'acc': 0.6413902044296265, 'f1_macro': 0.6399282813072205, 'f1_weighted': 0.6355011463165283, 'prec_macro': 0.6776048541069031, 'recall_macro': 0.6646516919136047}
Epoch 3


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.2933, Eval: {'acc': 0.7393364906311035, 'f1_macro': 0.7314118146896362, 'f1_weighted': 0.7401139736175537, 'prec_macro': 0.7270792126655579, 'recall_macro': 0.737011730670929}
Epoch 4


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.2096, Eval: {'acc': 0.7440758347511292, 'f1_macro': 0.7321245670318604, 'f1_weighted': 0.7438106536865234, 'prec_macro': 0.7330218553543091, 'recall_macro': 0.7334973812103271}
Epoch 5


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.1624, Eval: {'acc': 0.7503949403762817, 'f1_macro': 0.7403512001037598, 'f1_weighted': 0.7504347562789917, 'prec_macro': 0.7401707768440247, 'recall_macro': 0.7405363321304321}
===== Fold 5 =====
Class weights for fold 5: [1.52682339 0.70655509 1.07558386]


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ckiplab/bert-base-chinese and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.7170, Eval: {'acc': 0.6919431090354919, 'f1_macro': 0.6879408359527588, 'f1_weighted': 0.6913821697235107, 'prec_macro': 0.6807147264480591, 'recall_macro': 0.7035756707191467}
Epoch 2


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.4145, Eval: {'acc': 0.7424960732460022, 'f1_macro': 0.7344478368759155, 'f1_weighted': 0.7425670623779297, 'prec_macro': 0.738959789276123, 'recall_macro': 0.7320330142974854}
Epoch 3


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.3094, Eval: {'acc': 0.7646129727363586, 'f1_macro': 0.7580678462982178, 'f1_weighted': 0.7653297185897827, 'prec_macro': 0.759589433670044, 'recall_macro': 0.7664188146591187}
Epoch 4


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.2267, Eval: {'acc': 0.7740916013717651, 'f1_macro': 0.7673786878585815, 'f1_weighted': 0.7736951112747192, 'prec_macro': 0.7754917144775391, 'recall_macro': 0.76812744140625}
Epoch 5


100%|██████████| 159/159 [02:45<00:00,  1.04s/it]


Train loss: 0.1817, Eval: {'acc': 0.7819905281066895, 'f1_macro': 0.7776935696601868, 'f1_weighted': 0.7828954458236694, 'prec_macro': 0.769996166229248, 'recall_macro': 0.7919963598251343}


## 08. 匯總結果

In [8]:
results_df = pd.DataFrame(fold_results)
print('平均指標：')
print(results_df.mean())
results_df.to_csv('student_ckip_cls_5fold_metrics.csv', index=False)

平均指標：
acc             0.755219
f1_macro        0.748252
f1_weighted     0.756362
prec_macro      0.744232
recall_macro    0.758445
dtype: float64


## 09. 混淆矩陣與預測儲存

In [9]:
import seaborn as sns
import matplotlib.pyplot as plt
for fold, (train_idx, val_idx) in enumerate(kf.split(df, df['label_encoded'])):
    print(f'===== Fold {fold+1} 評估與儲存預測 =====')
    val_df = df.iloc[val_idx].copy()
    val_dataset = DistillCLSDataset(val_df, tokenizer)
    val_loader = DataLoader(val_dataset, batch_size=16)
    model_path = f'./student_ckip_cls_fold{fold}'
    model = AutoModelForSequenceClassification.from_pretrained(model_path).to(device)
    metrics = get_metrics(num_classes=3, device=device)
    # 取得預測
    all_preds, all_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            preds = torch.argmax(outputs.logits, dim=1)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    preds = torch.cat(all_preds)
    labels = torch.cat(all_labels)
    val_df['pred'] = preds.numpy()
    val_df['label_encoded'] = labels.numpy()
    val_df.to_csv(f'predictions_cls_fold{fold}.csv', index=False)
    # 混淆矩陣繪圖
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(5, 4))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['偏藍', '偏綠', '中立'], yticklabels=['偏藍', '偏綠', '中立'])
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title(f'Confusion Matrix - Fold {fold}')
    plt.tight_layout()
    plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
    plt.close()

===== Fold 1 評估與儲存預測 =====


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')


===== Fold 2 評估與儲存預測 =====


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')


===== Fold 3 評估與儲存預測 =====


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')


===== Fold 4 評估與儲存預測 =====


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')


===== Fold 5 評估與儲存預測 =====


  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  fig.canvas.draw()
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
  plt.savefig(f'confusion_matrix_cls_fold{fold}.png')
