In [None]:
import torch
import preprocess
import perturbation
from torch.optim import AdamW
from transformers import get_scheduler
from torch.nn import MSELoss
import torch.nn.functional as F
from tqdm.auto import tqdm

In [None]:
# Task and model selection.
# task_name options: CoLA, SST-2, MRPC, STS-B, QQP, MNLI-m, MNLI-mm, QNLI, RTE, WNLI
task_name = "SST-2"
# model_name options: bert-base-uncased, roberta-base (path to a local checkpoint directory)
model_name = "models/bert-base-uncased"
# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Hyperparameters.
# --- general training ---
SEED = 42                  # forwarded to preprocess.preprocess
BATCH_SIZE = 32
EPOCH = 10                 # number of passes over train_dataloader
LEARNING_RATE = 3e-5       # AdamW learning rate (linearly decayed by the scheduler)
# --- adversarial (SMART) training ---
Adv_step = 3               # perturbation-update steps per training batch
Adv_epsilon = 1e-2         # epsilon given to perturbation.init_delta / update_delta
Adv_max_norm = 2e-2        # norm bound for delta; per the original note, 0 disables the bound (TODO confirm in perturbation.update_delta)
Adv_init_type = "zero"     # delta initialisation: "zero", "rand" or "randn"
lambda_s = 1               # weight of the smoothness (R) loss
mu = 1                     # weight of the Bregman proximal (D) loss
beta = 0.999               # momentum of the teacher-parameter moving average

In [None]:
# Build the model, the loaders and the evaluation metric for the selected task.
model, dataloader, metric = preprocess.preprocess(task_name, model_name, BATCH_SIZE, SEED)
# `dataloader` holds three loaders, unpacked as (train, eval, test).
(train_dataloader, eval_dataloader, test_dataloader) = dataloader

In [None]:
# Optimizer and learning-rate schedule.
# Available scheduler names: "linear", "cosine", "cosine_with_restarts",
# "polynomial", "constant", "constant_with_warmup".
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    # one scheduler step per training batch, over every epoch
    num_training_steps=len(train_dataloader) * EPOCH,
)

In [None]:
def ls(P, Q, task=None):
    """Symmetric divergence between two sets of model outputs (SMART's ``l_s``).

    Classification tasks: symmetric KL divergence
    ``KL(softmax(Q) || softmax(P)) + KL(softmax(P) || softmax(Q))``, each term
    averaged over the batch (``reduction='batchmean'``).
    Regression task (STS-B): summed squared error between the raw outputs.

    Args:
        P: logits from one forward pass; classes along the last dimension.
        Q: logits from another forward pass, same shape as ``P``.
        task: GLUE task name. Defaults to the notebook-level ``task_name``.

    Returns:
        A scalar tensor. Gradients flow into both ``P`` and ``Q``.
    """
    if task is None:
        task = task_name
    if task == "STS-B":
        # Regression: squared loss between the predicted scores.
        # (The former `MSELoss(P, Q, ...)` built a module with P/Q as ctor args
        # instead of computing a loss.)
        return F.mse_loss(P, Q, reduction="sum")
    # log_softmax is numerically stable. softmax(...).log() gives -inf (and then
    # inf/NaN losses) as soon as one probability underflows to 0.
    log_p = F.log_softmax(P, dim=-1)
    log_q = F.log_softmax(Q, dim=-1)
    # With log_target=True, kl_div(input=log a, target=log b) == KL(b || a).
    return (F.kl_div(log_p, log_q, reduction="batchmean", log_target=True)
            + F.kl_div(log_q, log_p, reduction="batchmean", log_target=True))

In [None]:
# Training: echo the run configuration before the loop starts.
print("*" * 20, "Training", "*" * 20)  # task / model / device
for _label, _value in (("TASK:", task_name), ("MODEL:", model_name), ("DEVICE:", device)):
    print(_label, _value)

print("=" * 16, "General Training", "=" * 16)  # ordinary training hyperparameters
for _label, _value in (
    ("EPOCH_NUM:", EPOCH),
    ("BATCH_SIZE:", BATCH_SIZE),
    ("LEARNING_RATE:", LEARNING_RATE),
):
    print(_label, _value)

print("=" * 14, "SMART Training", "=" * 14)  # SMART-specific hyperparameters
for _label, _value in (
    ("Adversarial_Training_type:", "SMART"),
    ("Adversarial_step:", Adv_step),
    ("Adversarial_epsilon:", Adv_epsilon),
    ("Adversarial_init_type", Adv_init_type),
    ("Adversarial_max_norm:", Adv_max_norm),
    ("lambda_s:", lambda_s),
    ("mu:", mu),
    ("beta:", beta),
):
    print(_label, _value)
print("*" * 50)
def _is_classifier_param(name):
    """Return True for task-head parameters, which the teacher swap/EMA must leave alone.

    Matches every parameter under the ``classifier`` module. This covers
    ``classifier.weight``/``classifier.bias`` and presumably also a nested head
    such as RoBERTa's ``classifier.dense.*`` / ``classifier.out_proj.*``
    (verify against model.named_parameters()).
    """
    return name.startswith("classifier.")


model.to(device)
# Snapshot the starting (pretrained) weights as the teacher. state_dict() returns
# tensors that share storage with the live parameters, so they must be cloned.
# Otherwise optimizer.step() updates the "teacher" in place together with the student.
Teacher_parameters = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
progress_bar = tqdm(range(EPOCH * len(train_dataloader)))
eval_metric_list = []
for i in range(EPOCH):
    print("-"*20, "EPOCH:", i, "-"*20)
    print("Training...", end='')
    model.train()
    for batch in train_dataloader:
        batch = {key: batch[key].to(device) for key in batch}

        # [begin] SMART Training
        # 1. L-loss: ordinary supervised loss on the clean input.
        L_outputs = model(**batch)
        L_loss = L_outputs.loss

        # 2. R-loss (smoothness regulariser), with a PGD-style search for the perturbation.
        ## 2.1 Embed the tokens and initialise the perturbation delta.
        # get_input_embeddings() is the word-embedding layer for any HF model
        # (the same module as model.bert/roberta.embeddings.word_embeddings).
        word_embedding = model.get_input_embeddings()(batch["input_ids"])
        delta = perturbation.init_delta(word_embedding.size(), epsilon=Adv_epsilon, init_type=Adv_init_type)
        delta.requires_grad = True
        # The model takes input_ids OR inputs_embeds, not both. Every other batch field
        # (attention_mask, labels, token_type_ids when present) is forwarded unchanged.
        inputs = {key: value for key, value in batch.items() if key != "input_ids"}

        ## 2.2 Update delta.
        for _ in range(Adv_step):
            # .detach() keeps these ascent steps out of the embedding graph. Do NOT use
            # .detach_(): the final loss still needs gradients into the embedding weights.
            inputs["inputs_embeds"] = delta + word_embedding.detach()
            outputs = model(**inputs)
            loss = outputs.loss
            # NOTE(review): this ascends the supervised task loss (PGD/FreeLB style). The SMART
            # paper ascends ls(adv_logits, clean_logits) instead. Kept unchanged.
            loss.backward()  # also accumulates into parameter .grad; cleared by zero_grad() below
            delta = perturbation.update_delta(delta, Adv_epsilon, Adv_max_norm).detach()
            delta.requires_grad = True

        ## 2.3 R-loss with the final delta. No .detach() here, so the word-embedding
        ## weights receive gradients.
        inputs["inputs_embeds"] = delta + word_embedding
        R_outputs = model(**inputs)
        R_loss = ls(R_outputs.logits, L_outputs.logits)

        # 3. D-loss (Bregman proximal term): distance to the teacher's predictions.
        # The teacher is a fixed target, so its forward pass runs without autograd. Otherwise
        # D-loss gradients would be routed into the (swapped) student parameters through
        # teacher activations. Only L_outputs carries gradients into D_loss.
        back_parameters = {}  # student tensors, restored after the teacher pass
        with torch.no_grad():
            for name, p in model.named_parameters():
                if not _is_classifier_param(name):  # the classifier is not swapped
                    back_parameters[name] = p.data
                    p.data = Teacher_parameters[name]
            try:
                D_outputs = model(**batch)
            finally:
                # Always restore the student weights, even if the teacher pass raises.
                for name, p in model.named_parameters():
                    if name in back_parameters:
                        p.data = back_parameters[name]
        D_loss = ls(D_outputs.logits, L_outputs.logits)

        # 4. Optimise the student on the combined objective.
        loss = L_loss + lambda_s * R_loss + mu * D_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        # 5. Momentum update of the teacher: θ_t = (1-β)*θ_s + β*θ_(t-1).
        # Each teacher tensor is a private clone (never aliased to a live parameter
        # after the restore above), so it is safe to update in place.
        with torch.no_grad():
            for name, p in model.named_parameters():
                if not _is_classifier_param(name):
                    Teacher_parameters[name].mul_(beta).add_(p.detach(), alpha=1 - beta)

        progress_bar.update(1)
        # [end] SMART Training

    print("\rEvaling...", end='')
    model.eval()
    for batch in eval_dataloader:
        batch = {key: batch[key].to(device) for key in batch}
        with torch.no_grad():
            outputs = model(**batch)
        if task_name != "STS-B":
            predictions = outputs.logits.argmax(dim=-1)
        else:
            # Squeeze only the trailing singleton dim, so a final batch of size 1 stays 1-D.
            predictions = outputs.logits.squeeze(-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    score = metric.compute()
    eval_metric_list.append(score)
    print("\rMetric:", score)
    print("-"*50)

# Report the epoch with the best evaluation score. The ranking key is the first
# value in each epoch's metric dict.
score_list = [list(epoch_metric.values())[0] for epoch_metric in eval_metric_list]
print("*" * 19, "Best Score", "*" * 19)
best_epoch = score_list.index(max(score_list))  # earliest epoch wins ties
print("EPOCH:", best_epoch)
print("Metric:", eval_metric_list[best_epoch])
print("*" * 50)