In [None]:
import copy

import torch
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import get_scheduler

import preprocess

In [None]:
# Task and model selection.
# Supported GLUE tasks: CoLA, SST-2, MRPC, STS-B, QQP, MNLI-m, MNLI-mm, QNLI, RTE, WNLI
task_name = "CoLA"
# Model path/identifier handed to preprocess.preprocess (bert-base-uncased or roberta-base).
model_name = "models/bert-base-uncased"
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# 设置超参数
# 训练参数
SEED = 42 # I don't understand why 42 , but everyone sets 42.
BATCH_SIZE = 32
EPOCH = 20
LEARNING_RATE = 1e-5
# 对抗训练参数
Adv_step = 3
Adv_epsilon = 1e-2
Adv_max_norm = 2e-2  # 0表示不限制扰动大小
lambda_s = 1
mu = 1
beta = 0.99

In [None]:
# preprocess.preprocess returns the model, a (train, eval, test) dataloader triple and the task metric.
model, dataloader, metric = preprocess.preprocess(
    task_name, model_name, BATCH_SIZE, SEED
)
(train_dataloader, eval_dataloader, test_dataloader) = dataloader

In [None]:
# Optimizer: AdamW over every model parameter.
optimizer = AdamW(params=model.parameters(), lr=LEARNING_RATE)
# Linear schedule with no warmup, spanning every training step of every epoch.
# Other names accepted by get_scheduler include "cosine", "cosine_with_restarts",
# "polynomial", "constant" and "constant_with_warmup".
lr_scheduler = get_scheduler(
    "linear",
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * EPOCH,
)

In [None]:
import torch.nn.functional as F
from torch.nn import MSELoss

def ls(P, Q, task_type):
    """Symmetric smoothness loss l_s between two sets of model outputs (SMART).

    Args:
        P, Q: logits / predictions of identical shape; for classification the
            class dimension is the last one.
        task_type: ``"classification"`` for the symmetric KL divergence
            ``KL(softmax(Q) || softmax(P)) + KL(softmax(P) || softmax(Q))``,
            or ``"regression"`` for the squared error.

    Returns:
        Scalar tensor, summed (not averaged) over all elements.

    Raises:
        ValueError: if ``task_type`` is neither of the supported values.
    """
    if task_type == "classification":
        # log_softmax is the numerically stable form of softmax().log(); with
        # log_target=True the target is also given in log space.
        log_p = F.log_softmax(P, dim=-1)
        log_q = F.log_softmax(Q, dim=-1)
        return (F.kl_div(log_p, log_q, reduction="sum", log_target=True)
                + F.kl_div(log_q, log_p, reduction="sum", log_target=True))
    if task_type == "regression":
        # The functional form computes the loss; MSELoss(...) would only build a module.
        return F.mse_loss(P, Q, reduction="sum")
    raise ValueError(f"Unknown task_type: {task_type!r}")

In [None]:
# Training
# SMART (Jiang et al., 2020): every step minimizes
#     L_loss + lambda_s * R_loss + mu * D_loss
# R_loss is a smoothness regularizer between clean and embedding-perturbed
# predictions. D_loss is a proximal term keeping the model close to a momentum
# ("mean teacher") copy of itself.


def _perturb_embeddings(model, adv_inputs, embeds, steps, epsilon, max_norm):
    """Find an additive perturbation of ``embeds`` by normalized gradient ascent.

    Args:
        model: called as ``model(**adv_inputs, inputs_embeds=...)``; its output must
            expose ``.loss`` (so ``adv_inputs`` needs ``labels``).
        adv_inputs: model keyword arguments other than ``input_ids``/``inputs_embeds``.
        embeds: detached clean word embeddings; assumed (batch, seq, hidden).
        steps: number of ascent steps.
        epsilon: step size; each step adds ``epsilon * g / ||g||`` (per-example L2 norm).
        max_norm: if > 0, each example's perturbation is rescaled back onto the L2 ball
            of this radius after every step; 0 disables the bound.

    Returns:
        Detached perturbation with the same shape as ``embeds``.
    """
    # Shape that broadcasts one scalar per example over the remaining dims.
    per_example = (embeds.size(0),) + (1,) * (embeds.dim() - 1)
    delta = torch.zeros_like(embeds)
    for _ in range(steps):
        delta.requires_grad_(True)
        # NOTE(review): the ascent direction uses the task loss on the labels; the
        # SMART paper maximizes ls(f(x + delta), f(x)) instead.
        adv_loss = model(**adv_inputs, inputs_embeds=embeds + delta).loss
        # Only d(loss)/d(delta) is needed; autograd.grad leaves parameter .grad untouched.
        (delta_grad,) = torch.autograd.grad(adv_loss, delta)
        grad_norm = delta_grad.flatten(1).norm(dim=1).clamp(min=1e-8)  # avoid dividing by 0
        delta = (delta + epsilon * delta_grad / grad_norm.view(per_example)).detach()

        if max_norm > 0:
            delta_norm = delta.flatten(1).float().norm(p=2, dim=1)
            # Shrink only the examples outside the ball; torch.where keeps the
            # max_norm / 0 of an all-zero delta from turning into NaN.
            scale = torch.where(
                delta_norm > max_norm,
                max_norm / delta_norm,
                torch.ones_like(delta_norm),
            )
            delta = (delta * scale.view(per_example).to(delta.dtype)).detach()
    return delta


def _ema_update(teacher, student, beta):
    """In place: teacher <- (1 - beta) * student + beta * teacher, parameter by parameter."""
    with torch.no_grad():
        for t_param, s_param in zip(teacher.parameters(), student.parameters()):
            t_param.mul_(beta).add_(s_param, alpha=1 - beta)


def _evaluate(model, dataloader, metric, device, is_regression):
    """Run ``model`` over ``dataloader`` without gradients and return ``metric.compute()``.

    Classification predictions are the arg-max class. Regression predictions drop the
    trailing output dimension (assumes a (batch, 1) regression head).
    """
    model.eval()
    for batch in dataloader:
        batch = {key: value.to(device) for key, value in batch.items()}
        with torch.no_grad():
            logits = model(**batch).logits
        # squeeze(-1) rather than squeeze(): a batch of one must stay 1-D.
        predictions = logits.squeeze(-1) if is_regression else logits.argmax(dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    return metric.compute()


is_regression = task_name == "STS-B"
task_type = "regression" if is_regression else "classification"

print("*"*20, "Training", "*"*20)  # task
print("TASK:", task_name)
print("MODEL:", model_name)
print("DEVICE:", device)
print("-"*16, "General Training", "-"*16)  # general parameters
print("EPOCH_NUM:", EPOCH)
print("BATCH_SIZE:", BATCH_SIZE)
print("LEARNING_RATE:", LEARNING_RATE)
print("="*14, "Adversarial Training", "="*14)  # adversarial training parameters
print("Adversarial_Training_type:","SMART")
print("Adversarial_step:", Adv_step)
print("Adversarial_epsilon:", Adv_epsilon)
print("Adversarial_max_norm:", Adv_max_norm)
print("lambda_s:", lambda_s)
print("mu:", mu)
print("beta:", beta)
print("*"*50)

model.to(device)
# Momentum copy of the model (theta-tilde in the paper), used as the proximal target.
# A deep copy is required: model.parameters() yields live references through a
# one-shot generator, so it cannot act as a snapshot of earlier weights.
teacher_model = copy.deepcopy(model)
teacher_model.requires_grad_(False)
# The teacher runs without dropout so D_loss reflects parameter drift, not dropout noise.
teacher_model.eval()

progress_bar = tqdm(range(EPOCH * len(train_dataloader)))
eval_metric_list = []
for i in range(EPOCH):
    print("-"*20, "EPOCH:", i, "-"*20)
    print("Training...", end='')
    model.train()
    for batch in train_dataloader:
        for t in batch:
            batch[t] = batch[t].to(device)

        # [begin] Adversarial Training
        # 1. L-loss: ordinary task loss on the clean batch.
        L_outputs = model(**batch)
        L_loss = L_outputs.loss

        # 2. R-loss: smoothness between clean and perturbed predictions.
        # Forward every batch field except input_ids (replaced by inputs_embeds), so
        # token_type_ids is passed only when the tokenizer produced it.
        adv_inputs = {key: value for key, value in batch.items() if key != "input_ids"}
        word_embedding_init = model.get_input_embeddings()(batch["input_ids"]).detach()
        delta = _perturb_embeddings(
            model, adv_inputs, word_embedding_init, Adv_step, Adv_epsilon, Adv_max_norm
        )
        R_outputs = model(**adv_inputs, inputs_embeds=word_embedding_init + delta)
        R_loss = ls(R_outputs.logits, L_outputs.logits, task_type=task_type)

        # 3. D-loss: proximal term against the momentum teacher (a fixed target).
        with torch.no_grad():
            D_outputs = teacher_model(**batch)
        D_loss = ls(L_outputs.logits, D_outputs.logits, task_type=task_type)

        # 4. Optimize the model parameters, then pull the teacher toward the model.
        loss = L_loss + lambda_s * R_loss + mu * D_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        _ema_update(teacher_model, model, beta)
        progress_bar.update(1)
        # [end] Adversarial Training

    print("\rEvaling...", end='')
    score = _evaluate(model, eval_dataloader, metric, device, is_regression)
    eval_metric_list.append(score)
    print("\rMetric:", score)
    print("-"*50)

# Best score in eval, ranked by the first value metric.compute() returned.
score_list = [next(iter(m.values())) for m in eval_metric_list]
best_epoch = score_list.index(max(score_list))
print("*"*19, "Best Score", "*"*19)
print("EPOCH:", best_epoch)
print("Metric:", eval_metric_list[best_epoch])
print("*"*50)
