In [None]:
def read_log(path):
    """Extract original/adversarial sentences from one attack log file.

    Lines beginning with ``"ORI: "`` go into the first list and lines beginning
    with ``"ADV: "`` into the second, in file order. The 5-character prefix and
    the trailing newline are removed; all other lines are ignored. The path and
    both counts are printed as a quick sanity check.

    Args:
        path: path to a UTF-8 text log file.

    Returns:
        (ori, adv): two lists of str.
    """
    ori, adv = [], []
    with open(path, encoding="utf-8") as f:
        # Iterate the file lazily; there is no need to hold every line in memory.
        for line in f:
            # Strip only the newline. The old `line[5:-1]` also chopped the last
            # real character when the final line had no trailing newline.
            text = line[5:].rstrip("\n")
            if line.startswith("ORI: "):
                ori.append(text)
            elif line.startswith("ADV: "):
                adv.append(text)
    print(path)
    print(f"ORI{len(ori)}, ADV{len(adv)}")
    return ori, adv

def collect(paths):
    """Read several attack logs and concatenate their ORI/ADV sentences.

    Each path is parsed with ``read_log``. The per-file lists are appended
    file by file, in the order the paths are given.

    Args:
        paths: iterable of log file paths.

    Returns:
        (ori, adv): the concatenated lists of sentences.
    """
    merged_ori, merged_adv = [], []
    for log_path in paths:
        file_ori, file_adv = read_log(log_path)
        merged_ori.extend(file_ori)
        merged_adv.extend(file_adv)
    return merged_ori, merged_adv

In [None]:
from utils import Metrics
# Sentence-similarity scorer from the project's `utils.Metrics` (USE — presumably
# Universal Sentence Encoder; confirm). It is called later as use(original, adversarial)
# to score each logged pair. The meaning of the argument 0 is not visible here —
# presumably a device index; TODO confirm.
use = Metrics.USE(0)

In [None]:
import datasets
import pandas
ori, adv = collect(['/home/phantivia/lab/PandRlib/advLog/BertOnBert-SST2-Layer10-Adv5-Dw0.5-USE-FULL-515-0.out',
                    '/home/phantivia/lab/PandRlib/advLog/BertOnBert-SST2-Layer10-Adv5-Dw0.5-USE-FULL-515-1.out'])
# ORI and ADV lines are paired purely by position. With unequal counts the pairing
# is unreliable: adv[i] below would raise IndexError or extra ADV lines would be ignored.
if len(ori) != len(adv):
    raise ValueError(f"ORI/ADV count mismatch: {len(ori)} ORI vs {len(adv)} ADV lines")

units = []
ori_datasets = datasets.Dataset.load_from_disk('/home/phantivia/datasets/sst2_train')
from tqdm import tqdm
# Scan the SST-2 train split in order and align it with the logged sentences.
# The dataset sentence is lower-cased before comparing, so the logged ORI text is
# presumably lower-case (TODO confirm against the attack script). `i` only advances
# on a match, so dataset rows that were not attacked are skipped.
i = 0
for od in tqdm(ori_datasets):
    if i == len(ori): break
    _ori_sentence = od['sentence'].lower()
    if _ori_sentence == ori[i]:
        # One similarity score per pair, stored on both rows of the pair.
        u = use(ori[i], adv[i])
        ori_unit = {
            'label':od['label'],
            'sentence':od['sentence'],
            'use':u,
            'type':'ori',
        }
        # The adversarial row keeps the original label.
        adv_unit = {
            'label':od['label'],
            'sentence':adv[i],
            'use':u,
            'type':'adv',
        }
        units.append(ori_unit)
        units.append(adv_unit)
        i += 1

# If a logged sentence is never found, the sequential scan exhausts the dataset
# and every pair from that point on is lost; report it instead of dropping silently.
if i < len(ori):
    import warnings
    warnings.warn(f"Only {i} of {len(ori)} logged sentences were matched in the dataset; "
                  "all pairs after the first unmatched sentence were dropped.")

df = pandas.DataFrame(units)

In [None]:
# Assemble the augmented training set: keep every original sentence, plus only
# those adversarial sentences whose USE score lies strictly between `bar` and `_bar`.
# The result is then shuffled reproducibly.
seed = 114514
_bar = 0.9   # upper USE bound (exclusive)
bar = 0.7    # lower USE bound (exclusive)
ori_df = df.loc[df['type'].eq('ori')]
adv_df = df.loc[df['type'].eq('adv') & df['use'].gt(bar) & df['use'].lt(_bar)]
udf = pandas.concat([ori_df, adv_df]).sample(frac=1, random_state=seed)
print(len(udf))

In [None]:
# Persist only the columns the classifier needs. `udf` was filtered and shuffled,
# so its index is not a RangeIndex. preserve_index=False stops `from_pandas` from
# storing that index as an extra `__index_level_0__` column in the saved dataset.
train_ds = datasets.Dataset.from_pandas(udf[['label', 'sentence']], preserve_index=False)
train_ds.save_to_disk('/home/phantivia/datasets/sst2-adv-bert')

In [None]:

from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset,load_metric, load_from_disk

# Task / model configuration for fine-tuning.
task = "sst2"
num_labels = 2
model_checkpoint = 'bert-base-uncased'

train_dataset = train_ds
valid_dataset = load_from_disk('/home/phantivia/datasets/sst2-valid')

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# Evaluation split name per task (not referenced elsewhere in this notebook section).
task_valid_keys = dict(sst2='validation', ag_news='test', mnli='validation_matched')
# Text column(s) fed to the tokenizer per task; None marks a single-sentence task.
task_to_keys = dict(
    sst2=("sentence", None),
    ag_news=("text", None),
    mnli=("premise", "hypothesis"),
)

sentence1_key, sentence2_key = task_to_keys[task]

def preprocess_function(examples):
    """Tokenize a batch of examples for the configured task.

    The column names come from the globals ``sentence1_key`` and
    ``sentence2_key``. When a second column is configured, the two texts are
    passed to the tokenizer as a pair. Over-long inputs are truncated; no
    padding is applied here.
    """
    fields = [examples[sentence1_key]]
    if sentence2_key is not None:
        fields.append(examples[sentence2_key])
    return tokenizer(*fields, truncation=True)

# Tokenize the train split, then the validation split, in batched mode.
encoded_train_dataset, encoded_valid_dataset = (
    split.map(preprocess_function, batched=True) for split in (train_dataset, valid_dataset)
)

In [None]:
# Seed all RNGs *before* building the model. bert-base-uncased ships no
# classification head, so from_pretrained initializes it randomly, and Trainer only
# reseeds when it is constructed (after this cell). Without this, the head's initial
# weights differ between runs despite `seed`.
from transformers import set_seed
set_seed(seed)
# Use the configured label count instead of repeating the literal.
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

In [None]:
def compute_metrics(eval_pred):
    """Accuracy metric for ``Trainer``.

    Args:
        eval_pred: an ``EvalPrediction`` or plain tuple. Element 0 is the model
            predictions (presumably logits of shape (batch, num_labels)) and
            element 1 is the gold label ids. Any further elements (e.g. inputs,
            when the Trainer is configured to pass them) are ignored.

    Returns:
        dict with a single ``"accuracy"`` key.
    """
    # Index 1, not -1: when extra elements follow the labels, -1 is not the labels.
    predictions, labels = eval_pred[0], eval_pred[1]
    # A model returning several outputs yields a tuple; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = predictions.argmax(axis=-1)
    return {"accuracy": (predictions == labels).mean()}

batch_size = 128

# Train for 5 epochs. Evaluate and checkpoint at the end of every epoch; when
# training finishes, reload the checkpoint with the highest validation accuracy.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and Trainer's `tokenizer` to `processing_class`; kept as-is for
# the version this notebook targets.
args = TrainingArguments(
    output_dir=task,
    seed=seed,
    num_train_epochs=5,
    learning_rate=1e-5,
    weight_decay=0.05,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=encoded_train_dataset,
    eval_dataset=encoded_valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Baseline: validation metrics of the model before any fine-tuning in this notebook.
trainer.evaluate()

In [None]:
# Fine-tune on the original + filtered adversarial SST-2 sentences.
trainer.train()

In [None]:
import torch
# Pickle the whole fine-tuned module. With load_best_model_at_end=True the
# in-memory model should hold the best checkpoint's weights — verify.
# NOTE(review): torch.save on a full nn.Module pickles the class by reference, so
# loading needs the same transformers code; model.save_pretrained is more portable.
# The accuracy figures in the filename are hard-coded and will not track a re-run.
torch.save(model, '/home/phantivia/models/bert-base-uncased-adv-90.48-89.9.model')