In [1]:
# Reload edited local modules (e.g. finetune_data, opt) automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# NOTE(review): no cell in this notebook draws a plot, so this magic appears to be unused.
%matplotlib inline

In [2]:
class Args:
    """Attribute container for this run's hyper-parameters and file paths.

    The single instance ``args`` created below is read throughout the notebook and is
    passed to helpers such as ``read_all_sequences``.
    """

    def __init__(self):
        settings = (
            ("summary_method", "none"),
            ("model_dir", "fusing_finetune_ckpt.pkl"),   # file the best model weights are saved to
            ("log_dir", "fusing_finetune_log.pkl"),      # file the list of evaluation logs is saved to
            ("use_pretrain", True),                      # load checkpoint.pkl into the encoder first
            ("log_step", 200),                           # evaluate + log every this many steps
            ("toy", True),                               # presumably enables subsampling in read_all_sequences
            ("toy_size", 80000),                         # presumably the subsample size used when toy=True
            ("batch_size", 16),
            ("num_neg", 4),                              # not read in this notebook
            ("max_len", 200),                            # token length of the full-text input
            ("lr", 1e-5),
            ("steps", 50000),                            # total optimisation steps (also the LR schedule length)
            ("clip", 1.0),                               # passed to the optimizer as max_grad_norm
            ("dist_func", "cosin"),                      # NOTE(review): possibly a typo for "cosine"; not read here
            ("local_rank", 0),                           # not read in this notebook
            ("gpu_ids", 0),                              # not read in this notebook (devices are hard-coded)
        )
        for name, value in settings:
            setattr(self, name, value)


args = Args()

In [3]:
import os
# Work from the project directory so the relative data/checkpoint paths used below resolve
# and, presumably, so local modules such as finetune_data and opt can be imported.
# The default is the original author's machine; set CONTRAS_SUM_ROOT to run elsewhere.
os.chdir(os.environ.get("CONTRAS_SUM_ROOT", "/home/newdisk/yangkai/contras_sum"))
import torch
import tqdm
from finetune_data import read_all_sequences
import os
from transformers import RobertaTokenizer, RobertaConfig, RobertaForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader

In [15]:
# Both splits are lists of (text, label) pairs, as unpacked by the summarisation cell below.
(train_seqs, test_seqs) = read_all_sequences(args)

 29%|██▊       | 65/227 [00:00<00:00, 606.29it/s]Read all sequences begin=====
100%|██████████| 227/227 [00:00<00:00, 751.49it/s]


In [16]:
import tqdm
from tqdm.notebook import tqdm as notebook_tqdm

# NOTE(review): `summarizer` is not defined anywhere in this notebook. It must come from a
# deleted or out-of-order cell, so this cell fails on a fresh kernel until `summarizer`
# (an object with a `sum_text(text)` method) is created first.


def summarize_split(seqs):
    """Turn ``(text, label)`` pairs into ``(text, summary, label)`` triples, with a progress bar."""
    # tqdm.notebook.tqdm replaces tqdm.tqdm_notebook, which emits a deprecation warning.
    return [(text, summarizer.sum_text(text), label) for text, label in notebook_tqdm(seqs)]


train_sum_seqs = summarize_split(train_seqs)
test_sum_seqs = summarize_split(test_seqs)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=80000.0), HTML(value='')))


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=9165.0), HTML(value='')))




In [21]:
# Tokenizer shared by make_fusing_dataset; it must match the "roberta-base" checkpoint the models use.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
def make_fusing_dataset(args, seqs, tok=None, sum_max_len=50):
    """Tokenize ``(text, summary, label)`` triples into a ``TensorDataset``.

    Args:
        args: run config; ``args.max_len`` is the length the full text is padded/truncated to.
        seqs: sized sequence of ``(text, summary, label)`` triples.
        tok: tokenizer to apply; defaults to the notebook-level ``tokenizer``.
        sum_max_len: length the summary is padded/truncated to (previously hard-coded 50).

    Returns:
        ``TensorDataset(x_ids, x_sum_ids, labels)``: the tokenizer's ``input_ids`` for the
        texts and the summaries, and a 1-D int64 label tensor, in ``seqs`` order.
    """
    if tok is None:
        tok = tokenizer
    texts = [text for text, _, _ in seqs]
    summaries = [summary for _, summary, _ in seqs]
    # Built directly from the labels instead of filling an uninitialised torch.LongTensor(n).
    labels = torch.tensor([label for _, _, label in seqs], dtype=torch.long)

    x_ids = tok(texts, padding='max_length', max_length=args.max_len,
                truncation=True, return_tensors="pt")["input_ids"]
    x_sum_ids = tok(summaries, padding='max_length', max_length=sum_max_len,
                    truncation=True, return_tensors="pt")["input_ids"]
    return TensorDataset(x_ids, x_sum_ids, labels)

# Tokenize both splits into (full-text ids, summary ids, label) datasets.
train_dataset, test_dataset = [make_fusing_dataset(args, split) for split in (train_sum_seqs, test_sum_seqs)]


In [4]:
# Cache of the tokenized datasets, so later runs can reuse them.
# torch.load unpickles arbitrary objects: only ever point this at a file this project wrote.
DATASET_CACHE = "dataset/processed_finetune_fusing_dataset.pkl"
if os.path.exists(DATASET_CACHE):
    dataset = torch.load(DATASET_CACHE)
else:
    # First run on this machine: persist the datasets built by the cells above instead of
    # failing on a missing file.
    dataset = {"train": train_dataset, "test": test_dataset}
    os.makedirs(os.path.dirname(DATASET_CACHE), exist_ok=True)
    torch.save(dataset, DATASET_CACHE)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [5]:
# Training drops the ragged final batch so every optimisation step sees args.batch_size examples.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, num_workers=3)
# Evaluation keeps the final partial batch: with drop_last=True the metrics silently ignored
# up to batch_size - 1 test examples.
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=3)

In [6]:
from copy import deepcopy

# Two-class RoBERTa classifier; a classification head is added on top of roberta-base.
config = RobertaConfig.from_pretrained("roberta-base")
config.num_labels = 2
model = RobertaForSequenceClassification.from_pretrained("roberta-base", config=config)

if args.use_pretrain:
    # Checkpoint tensors, all moved to CPU.
    pretained_weight = {
        name: tensor.cpu()
        for name, tensor in torch.load("checkpoint.pkl", map_location='cpu').items()
    }
    # Checkpoint keys prefixed "module.base_model" are renamed into this model's "roberta"
    # namespace (presumably the checkpoint came from a wrapped model — TODO confirm).
    # Pooler weights are skipped; the loading warning shows this classifier does not use
    # roberta.pooler.*.
    model_weight = model.state_dict()
    model_weight.update(
        (name.replace("module.base_model", "roberta"), deepcopy(tensor))
        for name, tensor in pretained_weight.items()
        if "pooler" not in name
    )
    model.load_state_dict(model_weight)

# Independent copy with identical initial weights; it becomes the summary branch.
sum_model = deepcopy(model)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.weight', 'classifie

In [10]:
# Manual two-GPU placement: full-text model on cuda:1, summary model on cuda:2.
# NOTE(review): args.gpu_ids is not consulted; these device ids are hard-coded.
device1, device2 = torch.device("cuda:1"), torch.device("cuda:2")
model, sum_model = model.to(device1), sum_model.to(device2)

In [11]:
from opt import OpenAIAdam


def make_openai_adam(params):
    """Build the OpenAIAdam optimizer used by both branches of the fused model.

    The full-text and summary models share exactly the same optimiser settings. Keeping a
    single definition prevents the two copies from drifting apart.

    Args:
        params: iterable of parameters to optimise (e.g. ``model.parameters()``).
    """
    return OpenAIAdam(params,
                      lr=args.lr,
                      schedule='warmup_linear',
                      warmup=0.002,             # presumably a fraction of t_total — confirm in opt.py
                      t_total=args.steps,       # schedule length; training should not run past it
                      b1=0.9,
                      b2=0.999,
                      e=1e-08,
                      l2=0.01,
                      vector_l2=True,
                      max_grad_norm=args.clip)  # presumably gradient clipping inside step()


optimizer1 = make_openai_adam(model.parameters())
optimizer2 = make_openai_adam(sum_model.parameters())

In [None]:
# Evaluation helpers are defined BEFORE the training loop: the loop calls evaluate_model at its
# first logging step, so defining it in a later cell breaks "Restart & Run All".


def _model_device(net):
    """Return the device that holds ``net``'s parameters."""
    return next(net.parameters()).device


def _padding_mask(ids, net):
    """1 for real tokens, 0 where ``ids`` equals ``net.config.pad_token_id``."""
    return ids.ne(net.config.pad_token_id).long()


def _fused_logits(text_model, summary_model, x_ids, x_sum_ids):
    """Sum the class logits of the full-text model and the summary model.

    Each input is moved to the device of the model that consumes it. The fused logits are
    returned on ``summary_model``'s device.
    """
    x_ids = x_ids.to(_model_device(text_model))
    x_sum_ids = x_sum_ids.to(_model_device(summary_model))
    # Inputs are padded to a fixed length, so pass an attention mask; otherwise RoBERTa
    # attends to padding. No labels are passed (the per-model losses were never used), so
    # the logits are element [0] of the output.
    logits1 = text_model(x_ids, attention_mask=_padding_mask(x_ids, text_model))[0]
    logits2 = summary_model(x_sum_ids, attention_mask=_padding_mask(x_sum_ids, summary_model))[0]
    return logits1.to(logits2.device) + logits2


def _safe_div(num, den):
    """Return ``num / den``, or 0.0 when ``den`` is zero (e.g. no positive predictions)."""
    return num / den if den else 0.0


def evaluate_model(model, test_loader, log, summary_model=None):
    """Evaluate the fused classifier and record recall/precision/F1/accuracy in ``log``.

    Args:
        model: full-text classifier.
        test_loader: yields ``(x_ids, x_sum_ids, labels)`` batches.
        log: dict that receives the "recall", "precision", "F1" and "acc" entries; it is
            also returned.
        summary_model: summary classifier; defaults to the notebook-level ``sum_model``.

    Both models run in eval mode (no dropout) and are restored to their previous mode
    afterwards. Label 1 is the positive class. A metric whose denominator is zero is
    reported as 0.0 rather than raising.
    """
    if summary_model is None:
        summary_model = sum_model
    print("Evaluation Start======")
    was_training = (model.training, summary_model.training)
    model.eval()
    summary_model.eval()
    TP, TN, FN, FP = 0, 0, 0, 0

    try:
        with torch.no_grad():
            for batch in test_loader:
                logits = _fused_logits(model, summary_model, batch[0], batch[1])
                labels = batch[2].to(logits.device)
                prediction = torch.argmax(logits, dim=1)
                TP += ((prediction == 1) & (labels == 1)).sum().item()
                # TN: prediction and label are both 0
                TN += ((prediction == 0) & (labels == 0)).sum().item()
                # FN: predicted 0, label 1
                FN += ((prediction == 0) & (labels == 1)).sum().item()
                # FP: predicted 1, label 0
                FP += ((prediction == 1) & (labels == 0)).sum().item()
    finally:
        model.train(was_training[0])
        summary_model.train(was_training[1])

    p = _safe_div(TP, TP + FP)
    r = _safe_div(TP, TP + FN)
    F1 = _safe_div(2 * r * p, r + p)
    acc = _safe_div(TP + TN, TP + TN + FP + FN)
    print("recall: ", r)
    print("precision: ", p)
    print("F1: ", F1)
    print("Acc: ", acc)

    log["recall"] = r
    log["precision"] = p
    log["F1"] = F1
    log["acc"] = acc
    return log


critirion = torch.nn.CrossEntropyLoss()
# The best summary-branch weights are written next to args.model_dir
# (e.g. "fusing_finetune_ckpt_sum.pkl"); the fused prediction needs both models.
_ckpt_root, _ckpt_ext = os.path.splitext(args.model_dir)
sum_model_dir = _ckpt_root + "_sum" + _ckpt_ext

if len(train_loader) == 0:
    # Without batches the `while` loop below would never advance `step` and would spin forever.
    raise ValueError("train_loader yields no batches (fewer examples than batch_size with drop_last=True?)")

step = 0
bar = tqdm.tqdm(total=args.steps)
loss_list = []
best_acc = 0
logs = []

while step < args.steps:
    for batch in train_loader:
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        logits = _fused_logits(model, sum_model, batch[0], batch[1])
        labels = batch[2].to(logits.device)
        loss = critirion(logits, labels)
        loss.backward()
        loss_list.append(loss.item())

        optimizer1.step()
        optimizer2.step()
        step += 1
        if step % 10 == 0:
            bar.update(10)

        if step % args.log_step == 0:
            mean_loss = sum(loss_list) / step
            print("step: ", step)
            print("loss: ", mean_loss)
            log = {"step": step, "loss": mean_loss}
            log = evaluate_model(model, test_loader, log, sum_model)
            logs.append(log)
            torch.save(logs, args.log_dir)

            if log["acc"] > best_acc:
                best_acc = log["acc"]
                torch.save(model.state_dict(), args.model_dir)
                torch.save(sum_model.state_dict(), sum_model_dir)
            model.train()
            sum_model.train()

        # Stop exactly at args.steps (the optimizers' t_total) instead of finishing the epoch.
        if step >= args.steps:
            break
bar.close()