# Fine-tuning on the reference dataset (Transformer)
本 notebook 使用 Transformer 替换 ResNet CNN，对 30 种菌株的 Raman 光谱进行微调演示，与原始 notebook 步骤一致，仅修改模型实现。

In [1]:
from time import time
# Wall-clock start of the notebook run; the fine-tuning cell subtracts this
# value from time() to report the total elapsed seconds.
t00 = time()
import numpy as np

In [2]:
import torch

# Environment check: report the PyTorch build and whether a CUDA device is usable.
print(torch.__version__)                 # a '+cuXXX' suffix indicates a CUDA-enabled build
print(torch.cuda.is_available())         # True when a CUDA device can be used
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))     # name of GPU 0, e.g. 'NVIDIA GeForce RTX 5090'
else:
    # get_device_name() raises when no CUDA device is usable, so report the
    # situation instead of aborting the notebook on CPU-only machines.
    print('No CUDA device available; running on CPU')

2.7.1+cu128
True
NVIDIA GeForce RTX 5090


## Loading data

In [3]:
# Paths of the fine-tuning inputs (X) and their labels (y), stored as .npy files.
X_fn = './data/X_finetune.npy'
y_fn = './data/y_finetune.npy'
X = np.load(X_fn)
y = np.load(y_fn)
# Every row of X needs exactly one label; fail here rather than deep inside
# the dataloader if the two files do not belong together.
assert len(X) == len(y), f'X has {len(X)} rows but y has {len(y)} labels'
print(X.shape, y.shape)

(3000, 1000) (3000,)


## Loading pre-trained Transformer

In [4]:
from transformer import SpectraTransformer
import os, torch

In [5]:
# Transformer hyper-parameters.
input_dim = 1000   # matches the second dimension of X printed above
n_classes = 30     # one output class per strain
d_model = 128
nhead = 4
num_layers = 4
# NOTE(review): CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA
# is initialised in this process. An earlier cell already queries the GPU, so
# this assignment is most likely a no-op here -- confirm, and move it to the
# very first cell if device selection matters.
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(0)
cuda = torch.cuda.is_available()

model = SpectraTransformer(input_dim=input_dim, d_model=d_model, nhead=nhead,
    num_layers=num_layers, n_classes=n_classes)
if cuda:
    model.cuda()

# Load a pre-trained checkpoint when one is available.
ckpt_path = './pretrained_transformer_model.ckpt'
if os.path.exists(ckpt_path):
    # Deserialize onto the CPU; load_state_dict copies the tensors into the
    # model's existing parameters on whatever device they live on.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    # strict=False tolerates key mismatches, so report them explicitly:
    # otherwise a checkpoint that matches nothing "loads" without any error.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f'WARNING: checkpoint {ckpt_path} only partially matches the model')
        print(f'  missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys}')
        print(f'  unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys}')
    else:
        print(f'Loaded pre-trained weights from {ckpt_path}')
else:
    # Make the fallback visible: without a checkpoint the model keeps its
    # random initialisation and the run is training from scratch.
    print(f'WARNING: {ckpt_path} not found; the model keeps its random initialisation')



## Fine-tuning

In [6]:
from datasets import spectral_dataloader
from training import run_epoch
from torch import optim

# Train/val split: hold out a fraction p_val of the samples for validation.
p_val = 0.1
n_val = int(len(y) * p_val)
# Seed the shuffle so the split -- and therefore the reported validation
# accuracy and the checkpoint that gets saved -- is the same after
# Restart & Run All.
split_seed = 0
idxs = np.random.default_rng(split_seed).permutation(len(y))
idx_val, idx_tr = idxs[:n_val], idxs[n_val:]

# NOTE(review): the previous comment on this line suggested ~30 epochs to
# reach the paper's accuracy, while the value is 60 -- confirm which is intended.
epochs = 60
batch_size = 32
optimizer = optim.Adam(model.parameters(), lr=5e-4, betas=(0.5, 0.999))

# Shuffle only the training loader; the validation loader keeps a fixed order.
dl_tr = spectral_dataloader(X, y, idxs=idx_tr, batch_size=batch_size, shuffle=True)
dl_val = spectral_dataloader(X, y, idxs=idx_val, batch_size=batch_size, shuffle=False)

best_val = 0
print('Starting fine-tuning!')
for epoch in range(epochs):
    # run_epoch returns a pair whose first element is used as the accuracy
    # (presumably in percent, given how it is printed); the second is ignored.
    acc_tr, _ = run_epoch(epoch, model, dl_tr, cuda, training=True, optimizer=optimizer)
    acc_val, _ = run_epoch(epoch, model, dl_val, cuda, training=False)
    print(f'Epoch {epoch+1}: Train {acc_tr:.2f}%  Val {acc_val:.2f}%')
    # Overwrite the checkpoint whenever validation accuracy improves, so the
    # file on disk always holds the best weights seen so far.
    if acc_val > best_val:
        best_val = acc_val
        torch.save(model.state_dict(), 'finetuned_transformer_model.ckpt')

# t00 was recorded in the first cell, so this is the time since the notebook started.
print(f'Finished in {time() - t00:.2f}s, best val acc={best_val:.2f}%')

Starting fine-tuning!
Epoch 1: Train 3.37%  Val 3.33%
Epoch 2: Train 3.56%  Val 1.67%
Epoch 3: Train 4.22%  Val 5.00%
Epoch 4: Train 6.33%  Val 4.67%
Epoch 5: Train 11.04%  Val 15.33%
Epoch 6: Train 18.56%  Val 15.00%
Epoch 7: Train 24.26%  Val 24.00%
Epoch 8: Train 28.41%  Val 23.33%
Epoch 9: Train 32.93%  Val 35.67%
Epoch 10: Train 39.96%  Val 44.67%
Epoch 11: Train 46.63%  Val 30.33%
Epoch 12: Train 49.41%  Val 52.67%
Epoch 13: Train 53.81%  Val 56.00%
Epoch 14: Train 56.07%  Val 55.00%
Epoch 15: Train 60.74%  Val 64.33%
Epoch 16: Train 62.52%  Val 61.33%
Epoch 17: Train 63.33%  Val 58.67%
Epoch 18: Train 65.56%  Val 58.00%
Epoch 19: Train 67.33%  Val 67.33%
Epoch 20: Train 68.41%  Val 67.00%
Epoch 21: Train 69.85%  Val 63.00%
Epoch 22: Train 72.19%  Val 71.33%
Epoch 23: Train 73.63%  Val 71.67%
Epoch 24: Train 73.78%  Val 76.33%
Epoch 25: Train 74.59%  Val 72.33%
Epoch 26: Train 75.00%  Val 75.33%
Epoch 27: Train 76.85%  Val 69.00%
Epoch 28: Train 78.00%  Val 71.33%
Epoch 29: Train