In [None]:
import torch
import preprocess
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm

In [None]:
# Select the task and the model checkpoint.
# Supported tasks: ["CoLA", "SST-2", "MRPC", "STS-B", "QQP", "MNLI-m", "MNLI-mm", "QNLI", "RTE", "WNLI"]
task_name = "STS-B"
# Candidate checkpoints: [bert-base-cased, bert-base-uncased, roberta-base]
model_name = "models/bert-base-uncased"
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Hyperparameters
# Training parameters
SEED = 0  # forwarded to preprocess.preprocess(...)
BATCH_SIZE = 8  # forwarded to preprocess.preprocess(...)
EPOCH = 10  # number of passes of the outer training loop
LEARNING_RATE = 1e-5
# Adversarial-training parameters
Adv_step = 5
Adv_epsilon = 5e-2
Adv_max_norm = 0  # 0 means the perturbation size is not limited

In [None]:
# Unpack the model, the data-loader bundle and the metric object returned by
# the project's preprocess helper for the selected task and checkpoint.
model, dataloader, metric = preprocess.preprocess(
    task_name,
    model_name,
    BATCH_SIZE,
    SEED,
)
# The bundle unpacks into the train / eval / test loaders, in that order.
(train_dataloader, eval_dataloader, test_dataloader) = dataloader

In [None]:
# Set up the optimizer (TODO: not implemented — this cell currently defines nothing)


In [None]:
# Training: print a banner that summarises the run configuration.
print(f"{'*' * 20} Training {'*' * 20}")  # training parameters
print(
    f"TASK: {task_name}",
    f"MODEL: {model_name}",
    f"DEVICE: {device}",
    "-" * 50,
    f"EPOCH_NUM: {EPOCH}",
    f"BATCH_SIZE: {BATCH_SIZE}",
    f"LEARNING_RATE: {LEARNING_RATE}",
    sep="\n",
)
print(f"{'=' * 14} Adversarial Training {'=' * 14}")  # adversarial-training parameters
print(
    "Adversarial_Training_type: SMART",
    f"Adversarial_step: {Adv_step}",
    f"Adversarial_epsilon: {Adv_epsilon}",
    f"Adversarial_max_norm: {Adv_max_norm}",
    "*" * 50,
    sep="\n",
)
# Move the model to the selected device before any batch is fed through it.
model.to(device)
# One tick per training batch, summed over all epochs.
progress_bar = tqdm(range(EPOCH * len(train_dataloader)))
for i in range(EPOCH):
    print("-"*20, "EPOCH:", i, "-"*20)
    print("Training...")
    model.train()
    for batch in train_dataloader:
        # Move every tensor of the batch onto the model's device.
        for key in batch:
            batch[key] = batch[key].to(device)

        # [begin] Adversarial Training
        # NOTE(review): this step is not implemented. No forward pass, loss,
        # backward pass or optimizer step runs here, and no optimizer is
        # created anywhere in the visible cells, so as written the model's
        # weights are never updated. The SMART update (using Adv_step,
        # Adv_epsilon and Adv_max_norm) still has to be filled in.

        # [end] Adversarial Training

        # Advance the bar once per batch; it was sized to
        # EPOCH * len(train_dataloader) but previously never moved.
        progress_bar.update(1)

    print("Eval...")
    model.eval()
    for batch in eval_dataloader:
        for key in batch:
            batch[key] = batch[key].to(device)
        # Gradients are not needed for evaluation.
        with torch.no_grad():
            outputs = model(**batch)
        if task_name == "STS-B":
            # Regression task: drop only the trailing logit dimension. A bare
            # squeeze() would also drop the batch dimension when the last
            # batch holds a single example, yielding a 0-dim tensor that no
            # longer lines up with the 1-D labels.
            predictions = outputs.logits.squeeze(-1)
        else:
            # Classification task: pick the highest-scoring class.
            predictions = outputs.logits.argmax(dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])
    print("Metric:", metric.compute())
    print("-"*50)
# Release the progress bar once all epochs are done.
progress_bar.close()
