In [1]:
import math
import time
import torch
import preprocess
import perturbation
from torch.nn import MSELoss
import torch.nn.functional as F
from tqdm.auto import tqdm

In [2]:
# Task and model selection.
# task_name choices: "CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI-m", "MNLI-mm", "QNLI", "RTE", "WNLI"
task_name = "SST-2"
# model_name: checkpoint directory, e.g. models/bert-base-uncased or models/roberta-base
model_name = "models/bert-base-uncased"
# Use the first GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# ---- Hyperparameters ----
# General training
seed = 42                   # seed forwarded to preprocess.preprocess
batch_size = 32
epochs = 2

# MAT adversarial training
adv_init_epsilon = 1e-2     # epsilon passed to perturbation.init_delta for the initial delta
adv_init_type = "zero"      # init_type for the initial delta: "zero", "rand" or "randn"
sampling_times_theta = 5    # SGLD steps on the model parameters (theta) per batch
sampling_times_delta = 3    # SGLD steps on the input perturbation (delta) per batch
sampling_noise_theta = 0    # SGLD noise epsilon for theta
sampling_noise_delta = 0    # SGLD noise epsilon for delta
sampling_step_theta = 3e-5  # SGLD step size for theta
sampling_step_delta = 1e-2  # SGLD step size for delta
lambda_s = 1                # lambda: weight of the ls(perturbed, clean) regularizer
beta = 0.1                  # beta: weight of the old value in the running means (old*beta + new*(1-beta))

In [4]:
# Build model, dataloaders and metric for the selected task; `dataloader` is a
# (train, eval, test) triple that is unpacked right below.
model, dataloader, metric = preprocess.preprocess(
    task_name,
    model_name,
    batch_size,
    seed,
)
train_dataloader, eval_dataloader, test_dataloader = dataloader

Reusing dataset glue (/home/ailab/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


---------------- load the dataset ----------------
[Notice]: loading dataset...


  0%|          | 0/3 [00:00<?, ?it/s]

[Notice]: dataset sst2 is loaded.
--------------------------------------------------
-------- load the tokenizer and the model --------
[Notice]: loading tokenizer and model...


Some weights of the model checkpoint at models/bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkp

[Notice]: tokenizer and model are loaded.
--------------------------------------------------
-------------- tokenize the dataset --------------
[Notice]: tokenizing the dataset...
[Notice]: the dataset is tokenized.
--------------------------------------------------
-------------- make the dataloader ---------------
[Notice]: making dataloader...
[Notice]: the dataloader is made.
--------------------------------------------------


In [5]:
def ls(P, Q, task=None):
    """Symmetric divergence between two batches of model outputs (the MAT ``ls`` term).

    Classification tasks: ``KL(Q || P) + KL(P || Q)`` between the softmax
    distributions over the last dimension, averaged over the batch
    (``reduction='batchmean'``).
    Regression (``"STS-B"``): summed squared error between ``P`` and ``Q``.

    Args:
        P: logits of the first forward pass.
        Q: logits of the second forward pass, same shape as ``P``.
        task: GLUE task name; ``None`` (default) uses the notebook-level ``task_name``.

    Returns:
        A scalar tensor.
    """
    task = task_name if task is None else task
    if task == "STS-B":
        # Functional form actually computes the loss; calling the MSELoss *class*
        # with tensors only constructs a module.
        return F.mse_loss(P, Q, reduction="sum")
    # log_softmax stays finite for saturated logits, whereas softmax(...).log()
    # gives -inf once a probability underflows to 0 and the KL becomes inf/NaN.
    log_p = F.log_softmax(P, dim=-1)
    log_q = F.log_softmax(Q, dim=-1)
    kl_q_p = F.kl_div(log_p, log_q, reduction="batchmean", log_target=True)  # KL(Q || P)
    kl_p_q = F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)  # KL(P || Q)
    return kl_q_p + kl_p_q

In [6]:
def SGLD(z, grad, step, epsilon):
    """One stochastic gradient Langevin dynamics update.

    Returns ``z - step * grad + sqrt(2 * step) * noise``, where ``noise`` is drawn by
    ``perturbation.init_delta(z.size(), epsilon=epsilon, init_type="randn")``.
    ``z`` itself is not modified; a new tensor is returned.

    Args:
        z: tensor to update.
        grad: gradient to descend (pass ``-grad`` to ascend). ``None`` is treated as
            a zero gradient, i.e. a noise-only step.
        step: step size.
        epsilon: noise scale forwarded to ``init_delta``.
    """
    noise = perturbation.init_delta(z.size(), epsilon=epsilon, init_type="randn")
    # init_delta is only given a size, so its device/dtype need not match z;
    # align them explicitly (no-op when they already match).
    noise = noise.to(device=z.device, dtype=z.dtype)
    drift = 0 if grad is None else step * grad
    return z - drift + math.sqrt(2 * step) * noise

In [7]:
# Destination of the training report: None prints to stdout; assign an open text
# file object here to write the report to that file instead.
file = None
# Wall-clock timestamp taken just before training starts.
train_start = time.time()

In [8]:
# Training
# ---- Run report: configuration (file=None prints to stdout) ----
print(time.ctime(), file=file)
print("*"*20, "Training", "*"*20, file=file)  # task being trained
print("TASK:", task_name, file=file)
print("MODEL:", model_name, file=file)
print("DEVICE:", device, file=file)
print("="*16, "General Training", "="*16, file=file)  # general training parameters
print("EPOCH_NUM:", epochs, file=file)
print("BATCH_SIZE:", batch_size, file=file)
print("="*18, "MAT Training", "="*18, file=file)  # MAT training parameters
print("Adversarial_Training_type:", "MAT", file=file)
print("Adversarial_init_epsilon:", adv_init_epsilon, file=file)
print("Adversarial_init_type:", adv_init_type, file=file)
print("Sampling_times_theta:", sampling_times_theta, file=file)
print("Sampling_times_delta:", sampling_times_delta, file=file)
print("Sampling_noise_theta:", sampling_noise_theta, file=file)
print("Sampling_noise_delta:", sampling_noise_delta, file=file)
print("Sampling_step_theta:", sampling_step_theta, file=file)
print("Sampling_step_delta:", sampling_step_delta, file=file)
print("lambda:", lambda_s, file=file)
print("beta:", beta, file=file)
print("*"*50, file=file)


def _embed_tokens(batch):
    """Word-embedding lookup of ``batch["input_ids"]`` through the model's input-embedding module.

    ``get_input_embeddings()`` returns the word-embedding layer for both BERT and
    RoBERTa, so this does not depend on ``model_name`` containing "bert-" or
    "roberta-" (with any other name the old lookup left ``word_embedding`` unbound).
    """
    return model.get_input_embeddings()(batch["input_ids"])


def _model_inputs(batch):
    """Copy of ``batch`` without ``input_ids``, ready to receive ``inputs_embeds``.

    Keeps every other tensor the batch carries (attention_mask, labels, and
    token_type_ids when present), so perturbed and clean passes get the same inputs.
    """
    return {key: value for key, value in batch.items() if key != "input_ids"}


def _snapshot_state():
    """Detached, cloned copy of ``model.state_dict()`` that later updates cannot alias."""
    return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}


def _blend(old, new):
    """Running mean ``beta * old + (1 - beta) * new`` for floating-point tensors.

    Non-floating tensors (integer buffers in the state dict) are returned unchanged:
    blending would promote them to float, and ``load_state_dict`` would copy the result
    back into the integer buffer with truncation (e.g. 2.9999998 -> 2).
    """
    if not old.is_floating_point():
        return old
    return beta * old + (1 - beta) * new


model.to(device)
progress_bar = tqdm(range(epochs * len(train_dataloader)))
progress_bar.set_description("Training...")
eval_metric_list = []
for i in range(epochs):
    print("-"*20, "EPOCH:", i, "-"*20, file=file)
    print("Training...", end='', file=file)
    model.train()

    for batch in train_dataloader:
        batch = {key: batch[key].to(device) for key in batch}

        # [begin] MAT Training
        # 1. Initialise delta and the model inputs
        ## 1.1 Word embeddings of the batch (held constant while sampling delta)
        with torch.no_grad():
            word_embedding = _embed_tokens(batch)

        ## 1.2 Perturbation sample delta_k and its running mean mean_delta
        delta = perturbation.init_delta(word_embedding.size(), epsilon=adv_init_epsilon, init_type=adv_init_type)
        # init_delta only receives a size: align device/dtype with the embeddings
        # (no-op when they already match) before making delta a gradient leaf.
        delta = delta.detach().to(word_embedding).requires_grad_(True)
        mean_delta = delta.detach().clone()

        ## 1.3 Model inputs: the batch minus input_ids (replaced by inputs_embeds)
        inputs = _model_inputs(batch)

        ## 1.4 Snapshot the parameters the sampling starts from
        back_parameters = _snapshot_state()
        mean_theta = dict(back_parameters)  # _blend never works in place, so sharing tensors is safe

        # 2. Stochastic gradient Langevin dynamics sampling
        ## 2.1 Sample the perturbation (delta): ascend ls(perturbed, clean)
        for k in range(sampling_times_delta):
            inputs["inputs_embeds"] = delta + word_embedding
            # The clean branch does not depend on delta, so it needs no autograd graph.
            with torch.no_grad():
                clean_logits = model(**batch).logits
            loss_adv = ls(model(**inputs).logits, clean_logits)
            # Gradient w.r.t. delta only; parameter .grad buffers are left untouched.
            (delta_grad,) = torch.autograd.grad(loss_adv, delta)
            delta.data = SGLD(delta.data, -delta_grad, sampling_step_delta, sampling_noise_delta)
            mean_delta = beta * mean_delta + (1 - beta) * delta.detach()

        ## 2.2 Sample the model parameters (theta)
        for k in range(sampling_times_theta):
            model.zero_grad(set_to_none=True)
            # Re-embed with autograd enabled so the embedding matrix also gets a gradient.
            inputs["inputs_embeds"] = mean_delta + _embed_tokens(batch)
            # A single clean forward pass feeds both the task loss and the regulariser,
            # so both terms use the same dropout sample of f(x; theta).
            clean_outputs = model(**batch)
            loss_sum = clean_outputs.loss + lambda_s * ls(model(**inputs).logits, clean_outputs.logits)
            loss_sum.backward()
            for p in model.parameters():
                if p.grad is not None:  # frozen / unused parameters carry no gradient
                    p.data = SGLD(p.data, p.grad, sampling_step_theta, sampling_noise_theta)
            current_state = model.state_dict()
            mean_theta = {name: _blend(value, current_state[name]) for name, value in mean_theta.items()}

        # 3. Update the model: blend the snapshot with the running mean of the samples
        back_parameters = {name: _blend(value, mean_theta[name]) for name, value in back_parameters.items()}
        model.load_state_dict(back_parameters)
        # [end] MAT Training
        progress_bar.update(1)

    print("\rEvaling...", end='', file=file)
    model.eval()
    for batch in eval_dataloader:
        batch = {key: batch[key].to(device) for key in batch}
        with torch.no_grad():
            outputs = model(**batch)
        if task_name == "STS-B":
            # squeeze(-1) keeps a 1-D tensor even for a final batch of size 1.
            predictions = outputs.logits.squeeze(-1)
        else:
            predictions = outputs.logits.argmax(dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    score = metric.compute()
    eval_metric_list.append(score)
    print("\rMetric:", score, file=file)
    print("-"*50, file=file)
progress_bar.close()

# Best score in eval, ranked by the first value of each metric dict
score_list = [list(m.values())[0] for m in eval_metric_list]
if score_list:  # epochs == 0 leaves nothing to rank
    best_epoch = score_list.index(max(score_list))
    print("*"*19, "Best Score", "*"*19, file=file)
    print("EPOCH:", best_epoch, file=file)
    print("Metric:", eval_metric_list[best_epoch], file=file)
    print("*"*50, file=file)
# Elapsed wall-clock time since the `train_start` timestamp was taken.
print("TRAIN_TIME(s):", round(time.time() - train_start, 1), file=file)

Tue May 17 14:29:51 2022
******************** Training ********************
TASK: SST-2
MODEL: models/bert-base-uncased
DEVICE: cuda:0
EPOCH_NUM: 2
BATCH_SIZE: 32
Adversarial_Training_type: MAT
Adversarial_init_epsilon: 0.01
Adversarial_init_type: zero
Sampling_times_theta: 5
Sampling_times_delta: 3
Sampling_noise_theta: 0
Sampling_noise_delta: 0
Sampling_step_theta: 3e-05
Sampling_step_delta: 0.01
lambda: 1
beta: 0.1
**************************************************


  0%|          | 0/4210 [00:00<?, ?it/s]

-------------------- EPOCH: 0 --------------------
Training...139747238954384
139747238954384
139747238954384


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
# Debug inspection of the gradient stored on the last perturbation `delta`.
# The delta-sampling loop resets `delta.grad` to None after each SGLD step, so after a
# completed loop this is expected to display None.
delta.grad

In [None]:
# Close the training report file, if one was opened.
# `file = None` means the report was printed to stdout and there is nothing to close.
if file is not None:
    file.close()