In [1]:
import torch
import math
import preprocess
import perturbation
from torch.nn import MSELoss
import torch.nn.functional as F
from tqdm.auto import tqdm

In [2]:
# 设置任务和模型
task_name = "SST-2" # ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI-m", "MNLI-mm", "QNLI", "RTE", "WNLI"]
model_name = "models/bert-base-uncased" # [bert-base-uncased, roberta-base]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# 设置超参数
# 训练参数
SEED = 42
BATCH_SIZE = 32
EPOCH = 3
# 对抗训练参数
Adv_epsilon = 1e-3 # 初始化扰动
Adv_init_type = "zero" # ["zero","rand","randn"]
lambda_s = 1
beta = 0
sampling_times = 10
sampling_epsilon = 0
sampling_step = 3e-5

In [4]:
model, dataloader, metric= preprocess.preprocess(task_name, model_name, BATCH_SIZE, SEED)
train_dataloader,eval_dataloader,test_dataloader = dataloader

Reusing dataset glue (/home/ailab/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


------------------ load dataset ------------------
[Notice]: loading dataset...


  0%|          | 0/3 [00:00<?, ?it/s]

[Notice]: dataset sst2 loaded.
--------------------------------------------------
------------ load tokenizer and model ------------
[Notice]: loading tokenizer and model...


Some weights of the model checkpoint at models/bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkp

[Notice]: tokenizer and model loaded.
--------------------------------------------------
------------- tokenize the dataset -------------
[Notice]: tokenizing the dataset...
[Notice]: the dataset is tokenized.
--------------------------------------------------
----------------- make dataloader ------------------
[Notice]: making dataloader...
[Notice]: make dataloader is done.
--------------------------------------------------


In [5]:
def ls(P, Q):
    task_type = "classification" if task_name != "STS-B" else "regression"
    if(task_type == "classification"):
        return F.kl_div(P.softmax(dim=-1).log(), Q.softmax(dim=-1), reduction='batchmean') + F.kl_div(Q.softmax(dim=-1).log(), P.softmax(dim=-1), reduction='batchmean')
    elif(task_type == "regression"):
        return MSELoss(P, Q, reduction="sum")

In [6]:
def SGLD(z_data,z_grad,z_size,step,epsilon):
    noise = perturbation.init_delta(z_size, epsilon=epsilon, init_type="randn").to(device)
    z_data = z_data - step * z_grad + math.sqrt(2 * step) * noise
    return z_data

In [7]:
# Training
print("*"*20, "Training", "*"*20)  # 训练任务
print("TASK:", task_name)
print("MODEL:", model_name)
print("DEVICE:", device)
print("="*16, "General Training", "="*16)  # 常规训练参数
print("EPOCH_NUM:", EPOCH)
print("BATCH_SIZE:", BATCH_SIZE)
print("="*14, "MAT Training", "="*14)  # MAT训练参数
print("Adversarial_Training_type:", "MAT")
print("lambda_s:", lambda_s)
print("beta:", beta)
print("*"*50)
model.to(device)
progress_bar = tqdm(range(EPOCH * len(train_dataloader)))
eval_metric_list = []
for i in range(EPOCH):
    print("-"*20, "EPOCH:", i, "-"*20)
    print("Training...", end='')
    model.train()

    for batch in train_dataloader:
        for t in batch:
            batch[t] = batch[t].to(device)

        # [begin] MAT Training
        # 1.init delta & inputs
        ## 1.1 获得batch的word_embedding
        if "bert-" in model_name:
            word_embedding = model.bert.embeddings.word_embeddings(batch["input_ids"])
        elif "roberta-" in model_name:
            word_embedding = model.roberta.embeddings.word_embeddings(batch["input_ids"])
        
        ## 1.2 初始化[抽样扰动delta_k]和[分布扰动mean_delta]
        delta_k = perturbation.init_delta(word_embedding.size(), epsilon=Adv_epsilon, init_type=Adv_init_type).to(device)
        delta_k.requires_grad = True
        mean_delta = delta_k  # 初始化delta的分布均值mean_delta

        ## 1.3 初始化模型输入inputs
        if "bert-" in model_name:  # bert模型输入inputs: "attention_mask","labels","token_type_ids",「inputs_embeds」和「input_ids」参数二选一
            inputs = {"attention_mask": batch["attention_mask"],"labels": batch["labels"], "token_type_ids": batch["token_type_ids"]}
        elif "roberta-" in model_name:  # roberta模型输入inputs: "attention_mask","labels",「inputs_embeds」和「input_ids」参数二选一
            inputs = {"attention_mask": batch["attention_mask"], "labels": batch["labels"]}

        ## 1.4 备份模型参数
        back_parameters = model.parameters()
        mean_theta = model.parameters()

        # 2.stochastic gradient langevin dynamics sampling
        ## 2.1 sampling perturbation (delta)
        for k in range(sampling_times):
            ### 构造带有扰动的输入
            inputs["inputs_embeds"] = delta_k + word_embedding.detach()
            ### 前向传播
            loss_adv = ls(model(**inputs).logits, model(**batch).logits)
            ### 反向传播
            loss_adv.backward()
            ### SGLD采样
            delta_k.data = SGLD(delta_k.data, delta_k.grad.data, delta_k.size(), sampling_step, sampling_epsilon)
            delta_k.requires_grad = True
            ### 更新扰动的分布均值
            mean_delta = beta * mean_delta + (1 - beta) * delta_k
        
        ## 2.2 sampling model parameters (theta)
        for k in range(sampling_times):
            ### 清空模型参数的梯度
            for p in model.parameters():
                p.grad.zero_()
            ### 构造带有扰动的输入
            inputs["inputs_embeds"] = mean_delta.detach() + word_embedding.detach()
            ### 前向传播
            loss_sum = model(**batch).loss + lambda_s * ls(model(**inputs).logits, model(**batch).logits)
            ### 反向传播
            loss_sum.backward()
            ### SGLD采样
            for p in model.parameters():
                p.data = SGLD(p.data, p.grad.data, p.size(), sampling_step, sampling_epsilon)
            ### 更新模型参数的分布均值
            for p,q in zip(mean_theta, model.parameters()):
                p.data = beta * p.data + (1 - beta) * q.data

        # 3.update model parameters
        for p,q in zip(model.parameters(), back_parameters): # 还原模型参数为上一次迭代
            p.data = q.data
        for p,q in zip(model.parameters(), mean_theta): # 更新这次迭代的模型参数
            p.data = beta * p.data + (1 - beta) * q.data
        # [end] MAT Training
        progress_bar.update(1)

    print("\rEvaling...", end='')
    model.eval()
    for batch in eval_dataloader:
        for t in batch:
            batch[t] = batch[t].to(device)
        with torch.no_grad():
            outputs = model(**batch)
        predictions = outputs.logits.argmax(dim=-1) if task_name != "STS-B" else outputs.logits.squeeze()
        metric.add_batch(predictions=predictions, references=batch["labels"])
    score = metric.compute()
    eval_metric_list.append(score)
    print("\rMetric:", score)
    print("-"*50)

# Best score in eval
score_list = []
for m in eval_metric_list:
    score_list.append(list(m.values())[0])
print("*"*19, "Best Score", "*"*19)
print("EPOCH:", score_list.index(max(score_list)))
print("Metric:", eval_metric_list[score_list.index(max(score_list))])
print("*"*50)


******************** Training ********************
TASK: SST-2
MODEL: models/bert-base-uncased
DEVICE: cuda:0
EPOCH_NUM: 3
BATCH_SIZE: 32
Adversarial_Training_type: MAT
lambda_s: 1
beta: 0
**************************************************


  0%|          | 0/6315 [00:00<?, ?it/s]

-------------------- EPOCH: 0 --------------------
Metric: {'accuracy': 0.8394495412844036}
--------------------------------------------------
-------------------- EPOCH: 1 --------------------
Metric: {'accuracy': 0.8795871559633027}
--------------------------------------------------
-------------------- EPOCH: 2 --------------------
Metric: {'accuracy': 0.8784403669724771}
--------------------------------------------------
******************* Best Score *******************
EPOCH: 1
Metric: {'accuracy': 0.8795871559633027}
**************************************************
