## 1. Import Modules and Data
BERT can be fine-tuned on the Stanford Sentiment Treebank-2 (SST-2) dataset for a text classification task. More information about SST-2 can be found [here](https://huggingface.co/datasets/stanfordnlp/sst2).

In [3]:
import torch
from transformers import BertTokenizer, AdamW
from data import load_data
from modules.optim_schedule import ScheduledOptim
from modules.bert import BERTTextClassifier
import config 
# Load the SST-2 dataset
tokenizer, train_dataloader, valid_dataloader = load_data(
    name="sst2",
    splits=["train", "validation"],  # load the training and validation splits
    loading_ratio=0.1,  # load 10% of the data
    num_proc=4,  # use 4 processes for preprocessing
)

# Peek at the first training batch and print the shape of each tensor in it
for batch in train_dataloader:
    input_ids, attention_mask, labels = batch
    print(*(t.shape for t in (input_ids, attention_mask, labels)))
    break


torch.Size([16, 128]) torch.Size([16, 128]) torch.Size([16])


## 2. Build Model and Load from Pre-trained

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained model and move it onto the selected device
model = BERTTextClassifier.from_pretrained(
    config.pretrained_path,  # pretrained model directory from the config file
    num_frozen_layers=0,  # number of frozen layers
).to(device)

print(model)


Number of trainable parameters: 108.50M
BERTTextClassifier(
  (embedding): BERTEmbedding(
    (token): TokenEmbedding(30522, 768, padding_idx=0)
    (position): PositionalEmbedding()
    (segment): SegmentEmbedding(2, 768, padding_idx=0)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): ModuleList(
    (0): TransformerBlock(
      (self_attn): MultiheadAttention(
        (attention): ScaledDotProductAttention()
        (w_q): Linear(in_features=768, out_features=768, bias=True)
        (w_k): Linear(in_features=768, out_features=768, bias=True)
        (w_v): Linear(in_features=768, out_features=768, bias=True)
        (w_concat): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_1): LayerNorm()
      (ffn): PositionwiseFeedForward(
        (linear1): Linear(in_features=768, out_features=3072, bias=True)
        (linear2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout):

## 3. Train Model

### 3.1 Optimizer and Scheduler
The `ScheduledOptim` class wraps an optimizer and implements a learning-rate schedule inspired by the Transformer paper (*Attention Is All You Need*). It adjusts the learning rate with a warm-up phase followed by decay to stabilize training. See `modules/optim_schedule.py` for the implementation details.


### 3.2 Train Loop

In [5]:
import torch
from transformers import BertTokenizer
from torch.optim import AdamW

from data import load_data
from modules.optim_schedule import ScheduledOptim
from modules.bert import BERTTextClassifier
import config  

# finetune
pad_idx = 0  # token id used for padding


def split_batch(batch):
    """Unpack a batch and move its tensors to the module-level ``device``.

    Args:
        batch: a 3-item sequence ``(tokens, attention_mask, labels)``.

    Returns:
        ``(input_ids, attention_mask, labels)``, all on ``device``.

    The attention mask that arrives in ``batch`` is discarded; a fresh mask
    is derived from the token ids instead (1 where the token differs from
    ``pad_idx``, 0 elsewhere) and cast to long.
    """
    token_ids, _, targets = batch
    pad_mask = token_ids.ne(pad_idx).long()
    return token_ids.to(device), pad_mask.to(device), targets.to(device)

from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
@torch.no_grad()
def evaluate(model, clf_criterion, dataloader):
    """Score the classifier on ``dataloader`` with gradient tracking disabled.

    Args:
        model: classifier; it is switched to eval mode and called with the
            input ids only, and its output is treated as class logits.
        clf_criterion: loss callable invoked as ``clf_criterion(logits, labels)``.
        dataloader: iterable of batches accepted by ``split_batch``.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the sum of per-batch
        losses divided by ``len(dataloader)`` and ``accuracy`` is
        ``accuracy_score`` computed over every evaluated example.
    """
    model.eval()
    loss_sum = 0
    predicted, expected = [], []

    for batch in tqdm(dataloader, desc="Evaluating"):
        # split_batch also returns an attention mask; it is not passed to the model here.
        input_ids, _, gt = split_batch(batch)

        clf_logits = model(input_ids)  # classifier logits
        loss_sum += clf_criterion(clf_logits, gt).item()

        # Predicted class = index of the largest logit along dim 1.
        predicted.extend(clf_logits.argmax(dim=1).cpu().numpy())
        expected.extend(gt.cpu().numpy())

    return loss_sum / len(dataloader), accuracy_score(expected, predicted)


In [6]:
def train(epoch, model, scheduled_optimizer, dataloader):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        epoch: epoch number, used only in the progress-bar description.
        model: classifier; called with input ids, ``segment_info`` and
            ``labels`` and expected to return ``(loss, logits)``.
        scheduled_optimizer: optimizer wrapper exposing ``zero_grad()`` and
            ``step_and_update_lr()``.
        dataloader: iterable of batches accepted by ``split_batch``.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc=f"Training Epoch {epoch}"):
        # split_batch already moves every tensor to the target device; the
        # attention mask it returns is not consumed by the model call below.
        input_ids, _, labels = split_batch(batch)
        input_ids = input_ids.long()

        # Single sentence, so every token belongs to segment 0.
        segment_info = torch.zeros(
            input_ids.size(0),
            input_ids.size(1),
            dtype=torch.long,
            device=input_ids.device,
        )

        # Reset gradients on every step. Calling zero_grad() only once before
        # the loop would let gradients from all earlier batches accumulate
        # into each later update.
        scheduled_optimizer.zero_grad()

        loss, _ = model(input_ids, segment_info=segment_info, labels=labels)
        loss.backward()

        # Clip the global gradient norm to 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Apply the parameter update and advance the learning-rate schedule.
        scheduled_optimizer.step_and_update_lr()

        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    return avg_loss

In [None]:
def training_loop(model, train_dataloader, valid_dataloader, optimizer, criterion, num_epochs, warmup_steps):
    """Fine-tune ``model`` for ``num_epochs`` epochs, checkpointing after each.

    Args:
        model: classifier to train.
        train_dataloader: batches passed to ``train``.
        valid_dataloader: batches passed to ``evaluate``.
        optimizer: base optimizer, wrapped in ``ScheduledOptim``.
        criterion: loss callable used for validation.
        num_epochs: number of epochs to run.
        warmup_steps: forwarded to ``ScheduledOptim`` as ``n_warmup_steps``.

    After every epoch the training loss, validation loss and validation
    accuracy are printed, and a checkpoint named ``bert_clf_<epoch>.pth`` is
    written under ``config.checkpoint_dir``.
    """
    scheduled_optimizer = ScheduledOptim(
        optimizer=optimizer,
        d_model=768,  # presumably the model's hidden size - TODO confirm
        n_warmup_steps=warmup_steps,
    )

    # Create the checkpoint directory before training starts, so a missing
    # directory cannot make torch.save fail after a full epoch of work.
    # NOTE(review): assumes config.checkpoint_dir is a pathlib.Path - confirm.
    config.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(num_epochs):
        # train
        avg_train_loss = train(
            epoch + 1, model, scheduled_optimizer, train_dataloader
        )

        # evaluate
        avg_valid_loss, avg_acc = evaluate(model, criterion, valid_dataloader)

        print(
            f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {avg_train_loss:.4f},",
            f"Validation Loss: {avg_valid_loss:.4f}, Accuracy: {avg_acc * 100:.2f}",
        )

        # save checkpoint
        checkpoint_path = config.checkpoint_dir / f"bert_clf_{epoch + 1}.pth"
        torch.save(
            {
                "epoch": epoch + 1,
                "model": model.state_dict(),
                "optimizer": scheduled_optimizer._optimizer.state_dict(),
                "scheduler": scheduled_optimizer.n_current_steps,  # save the current step count
            },
            checkpoint_path,
        )


# AdamW optimizer over all model parameters
optimizer = AdamW(
    model.parameters(),
    lr=config.FinetuningConfig.lr,
    eps=1e-8,
)

# Launch fine-tuning with a cross-entropy loss and the configured schedule
training_loop(
    model,
    train_dataloader,
    valid_dataloader,
    optimizer,
    criterion=torch.nn.CrossEntropyLoss(),
    num_epochs=config.FinetuningConfig.n_epoch,
    warmup_steps=config.FinetuningConfig.warmup_steps,
)
