In [1]:
import os
import logging
from tqdm import tqdm, trange

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers import BertConfig, AdamW, get_linear_schedule_with_warmup
from seqeval.metrics import precision_score, recall_score, f1_score

%cd ../

from bert_finetune_cls.utils import MODEL_CLASSES, compute_metrics, get_intent_labels

logger = logging.getLogger(__name__)

C:\Users\威威的小荔枝\Desktop\第五课_代码


In [2]:
# 计算评价指标

def compute_metrics(intent_preds, intent_labels):
    """
        计算metrics
    """
    assert len(intent_preds) == len(intent_labels)
    results = {}
    intent_result = get_intent_acc(intent_preds, intent_labels)

    results.update(intent_result)

    return results

def get_intent_acc(preds, labels):
    acc = (preds == labels).mean()
    return {
        "intent_acc": acc
    }

In [3]:



class Trainer(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset
        
        # 加载模型，标签名称到编号的映射 (label maps):
        self.intent_label_lst = get_intent_labels(args)
        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        
        # 加载模型的config，model本身
        self.config_class, self.model_class, _ = MODEL_CLASSES[args.model_type]
        self.config = self.config_class.from_pretrained(args.model_name_or_path, finetuning_task=args.task)
        self.model = self.model_class.from_pretrained(args.model_name_or_path,
                                                      config=self.config,
                                                      args=args,
                                                      intent_label_lst=self.intent_label_lst,)
        
        # 将模型放到GPU，如果有的话
        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

    def train(self):
        
        # 加载训练数据
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size=self.args.train_batch_size)
        
        # 计算训练的总的更新步数，用于learning rate的schedule (不是迭代步数)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            self.args.num_train_epochs = self.args.max_steps // (len(train_dataloader) // self.args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs
        
        # 打印一下参数，看看都有哪些
        for n, p in self.model.named_parameters():
            print(n)

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']  # bias和归一化操作中的参数是做weight decay; 
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.args.weight_decay},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.args.learning_rate, eps=self.args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.args.warmup_steps, num_training_steps=t_total)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(self.train_dataset))
        logger.info("  Num Epochs = %d", self.args.num_train_epochs)
        logger.info("  Total train batch size = %d", self.args.train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)
        logger.info("  Logging steps = %d", self.args.logging_steps)   # 计算dev performance；
        logger.info("  Save steps = %d", self.args.save_steps)         # 保存model checkpoint；

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()  # 清空梯度；

        train_iterator = trange(int(self.args.num_train_epochs), desc="Epoch")

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)  # 将数据传到设备上面：GPU or CPU

                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'intent_label_ids': batch[3],
                         }
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)
                loss = outputs[0]

                if self.args.gradient_accumulation_steps > 1:
                    loss = loss / self.args.gradient_accumulation_steps

                loss.backward()   # 求梯度

                tr_loss += loss.item()
                if (step + 1) % self.args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.max_grad_norm)  # 

                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule
                    self.model.zero_grad()  # 清空梯度；
                    global_step += 1

                    if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                        self.evaluate("dev")

                    if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                        self.save_model()

                if 0 < self.args.max_steps < global_step:
                    epoch_iterator.close()
                    break

            if 0 < self.args.max_steps < global_step:
                train_iterator.close()
                break

        return global_step, tr_loss / global_step

    def evaluate(self, mode):
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        else:
            raise Exception("Only dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.args.eval_batch_size)

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d", len(dataset))
        logger.info("  Batch size = %d", self.args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        intent_preds = None
        out_intent_label_ids = None

        self.model.eval()

        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {'input_ids': batch[0],
                          'attention_mask': batch[1],
                          'intent_label_ids': batch[3],
                         }
                if self.args.model_type != 'distilbert':
                    inputs['token_type_ids'] = batch[2]
                outputs = self.model(**inputs)
                tmp_eval_loss, intent_logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            # Intent prediction
            if intent_preds is None:
                intent_preds = intent_logits.detach().cpu().numpy()
                out_intent_label_ids = inputs['intent_label_ids'].detach().cpu().numpy()
            else:
                intent_preds = np.append(intent_preds, intent_logits.detach().cpu().numpy(), axis=0)
                out_intent_label_ids = np.append(
                    out_intent_label_ids, inputs['intent_label_ids'].detach().cpu().numpy(), axis=0)

            
        eval_loss = eval_loss / nb_eval_steps
        results = {
            "loss": eval_loss
        }

        # Intent result
        intent_preds = np.argmax(intent_preds, axis=1)

        total_result = compute_metrics(intent_preds, out_intent_label_ids)
        results.update(total_result)

        logger.info("***** Eval results *****")
        for key in sorted(results.keys()):
            logger.info("  %s = %s", key, str(results[key]))

        return results

    def save_model(self):
        # Save model checkpoint (Overwrite)
        if not os.path.exists(self.args.model_dir):
            os.makedirs(self.args.model_dir)
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(self.args.model_dir)

        # Save training arguments together with the trained model
        torch.save(self.args, os.path.join(self.args.model_dir, 'training_args.bin'))
        logger.info("Saving model checkpoint to %s", self.args.model_dir)

    def load_model(self):
        # Check whether model exists
        if not os.path.exists(self.args.model_dir):
            raise Exception("Model doesn't exists! Train first!")

        try:
            self.model = self.model_class.from_pretrained(self.args.model_dir,
                                                          args=self.args,
                                                          intent_label_lst=self.intent_label_lst)
            self.model.to(self.device)
            logger.info("***** Model Loaded *****")
        except:
            raise Exception("Some model files might be missing...")

### 举例查看

In [4]:
from transformers import BertConfig, DistilBertConfig, AlbertConfig
from transformers import BertTokenizer, DistilBertTokenizer, AlbertTokenizer

from bert_finetune_cls.model import ClsBERT
from bert_finetune_cls.utils import init_logger, load_tokenizer, get_intent_labels, set_seed
from bert_finetune_cls.data_loader import load_and_cache_examples

MODEL_CLASSES = {
    'bert': (BertConfig, ClsBERT, BertTokenizer),
}

MODEL_PATH_MAP = {
    'bert': './bert_finetune_cls/resources/uncased_L-2_H-128_A-2',
}

# 先构建参数
class Args():
    task =  None
    data_dir =  None
    intent_label_file =  None


args = Args()
args.seed = 1991
args.no_cuda = True
args.task = "atis"
args.data_dir = "./bert_finetune_cls/data"
args.intent_label_file = "intent_label.txt"
args.max_seq_len = 50
args.model_type = "bert"
args.model_dir = "bert_finetune_cls/experiments/outputs/clsbert_0"
args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

args.train_batch_size = 8
args.eval_batch_size = 16
args.dropout_rate = 0.1

args.max_steps = 1000
args.num_train_epochs = 1
args.gradient_accumulation_steps = 1
args.weight_decay = 1e-5
args.learning_rate = 1e-5
args.adam_epsilon = 1e-8
args.max_grad_norm = 1.0
args.warmup_steps = 100

args.logging_steps = 100
args.save_steps = 200

# 设置随机种子
set_seed(args)

# 加载tokenizer
tokenizer = load_tokenizer(args)

# 加载数据集
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")
dev_dataset = load_and_cache_examples(args, tokenizer, mode="dev")
test_dataset = load_and_cache_examples(args, tokenizer, mode="test")

# 加载trainer
trainer = Trainer(args, train_dataset, dev_dataset, test_dataset)


Some weights of the model checkpoint at ./bert_finetune_cls/resources/uncased_L-2_H-128_A-2 were not used when initializing ClsBERT: ['cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing ClsBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ClsBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ClsBERT were not initialized from the model checkpoint at ./bert_finetune_cls/r

In [5]:
# 训练
trainer.train()

bert.embeddings.word_embeddings.weight
bert.embeddings.position_embeddings.weight
bert.embeddings.token_type_embeddings.weight
bert.embeddings.LayerNorm.weight
bert.embeddings.LayerNorm.bias
bert.encoder.layer.0.attention.self.query.weight
bert.encoder.layer.0.attention.self.query.bias
bert.encoder.layer.0.attention.self.key.weight
bert.encoder.layer.0.attention.self.key.bias
bert.encoder.layer.0.attention.self.value.weight
bert.encoder.layer.0.attention.self.value.bias
bert.encoder.layer.0.attention.output.dense.weight
bert.encoder.layer.0.attention.output.dense.bias
bert.encoder.layer.0.attention.output.LayerNorm.weight
bert.encoder.layer.0.attention.output.LayerNorm.bias
bert.encoder.layer.0.intermediate.dense.weight
bert.encoder.layer.0.intermediate.dense.bias
bert.encoder.layer.0.output.dense.weight
bert.encoder.layer.0.output.dense.bias
bert.encoder.layer.0.output.LayerNorm.weight
bert.encoder.layer.0.output.LayerNorm.bias
bert.encoder.layer.1.attention.self.query.weight
bert.enc

Epoch:   0%|                                                                                                                                 | 0/2 [00:00<?, ?it/s]
Iteration:   0%|                                                                                                                           | 0/560 [00:00<?, ?it/s]
Iteration:   0%|▏                                                                                                                  | 1/560 [00:00<01:49,  5.13it/s]
Iteration:   1%|▌                                                                                                                  | 3/560 [00:00<01:31,  6.09it/s]
Iteration:   1%|█                                                                                                                  | 5/560 [00:00<01:19,  7.02it/s]
Iteration:   1%|█▍                                                                                                                 | 7/560 [00:00<01:09,  7.95it/s]
Iteration:   2%|

Iteration:  17%|██████████████████▉                                                                                               | 93/560 [00:08<00:44, 10.53it/s]
Iteration:  17%|███████████████████▎                                                                                              | 95/560 [00:08<00:43, 10.67it/s]
Iteration:  17%|███████████████████▋                                                                                              | 97/560 [00:08<00:43, 10.54it/s]
Iteration:  18%|████████████████████▏                                                                                             | 99/560 [00:09<00:43, 10.55it/s]

Evaluating:   0%|                                                                                                                           | 0/32 [00:00<?, ?it/s]

Evaluating:  19%|█████████████████████▌                                                                                             | 6/32 [00:00<00:00, 54.05it/s]

Evaluating:  

Iteration:  31%|███████████████████████████████████▌                                                                             | 176/560 [00:16<00:34, 11.11it/s]
Iteration:  32%|███████████████████████████████████▉                                                                             | 178/560 [00:17<00:34, 11.18it/s]
Iteration:  32%|████████████████████████████████████▎                                                                            | 180/560 [00:17<00:33, 11.35it/s]
Iteration:  32%|████████████████████████████████████▋                                                                            | 182/560 [00:17<00:33, 11.24it/s]
Iteration:  33%|█████████████████████████████████████▏                                                                           | 184/560 [00:17<00:32, 11.45it/s]
Iteration:  33%|█████████████████████████████████████▌                                                                           | 186/560 [00:17<00:32, 11.54it/s]
Iteration:  34%|

Iteration:  47%|████████████████████████████████████████████████████▊                                                            | 262/560 [00:25<00:27, 10.85it/s]
Iteration:  47%|█████████████████████████████████████████████████████▎                                                           | 264/560 [00:25<00:27, 10.84it/s]
Iteration:  48%|█████████████████████████████████████████████████████▋                                                           | 266/560 [00:25<00:27, 10.79it/s]
Iteration:  48%|██████████████████████████████████████████████████████                                                           | 268/560 [00:25<00:26, 10.92it/s]
Iteration:  48%|██████████████████████████████████████████████████████▍                                                          | 270/560 [00:26<00:26, 10.94it/s]
Iteration:  49%|██████████████████████████████████████████████████████▉                                                          | 272/560 [00:26<00:26, 10.85it/s]
Iteration:  49%|

Iteration:  60%|███████████████████████████████████████████████████████████████████▊                                             | 336/560 [00:33<00:22, 10.13it/s]
Iteration:  60%|████████████████████████████████████████████████████████████████████▏                                            | 338/560 [00:33<00:22, 10.01it/s]
Iteration:  61%|████████████████████████████████████████████████████████████████████▌                                            | 340/560 [00:33<00:21, 10.00it/s]
Iteration:  61%|█████████████████████████████████████████████████████████████████████                                            | 342/560 [00:33<00:21, 10.17it/s]
Iteration:  61%|█████████████████████████████████████████████████████████████████████▍                                           | 344/560 [00:33<00:21, 10.15it/s]
Iteration:  62%|█████████████████████████████████████████████████████████████████████▊                                           | 346/560 [00:34<00:21, 10.12it/s]
Iteration:  62%|

Iteration:  73%|██████████████████████████████████████████████████████████████████████████████████▉                              | 411/560 [00:41<00:17,  8.71it/s]
Iteration:  74%|███████████████████████████████████████████████████████████████████████████████████▎                             | 413/560 [00:41<00:16,  9.08it/s]
Iteration:  74%|███████████████████████████████████████████████████████████████████████████████████▌                             | 414/560 [00:41<00:15,  9.29it/s]
Iteration:  74%|███████████████████████████████████████████████████████████████████████████████████▋                             | 415/560 [00:41<00:15,  9.44it/s]
Iteration:  74%|███████████████████████████████████████████████████████████████████████████████████▉                             | 416/560 [00:41<00:15,  9.25it/s]
Iteration:  74%|████████████████████████████████████████████████████████████████████████████████████▏                            | 417/560 [00:41<00:15,  9.03it/s]
Iteration:  75%|

Iteration:  89%|████████████████████████████████████████████████████████████████████████████████████████████████████             | 496/560 [00:49<00:06, 10.41it/s]
Iteration:  89%|████████████████████████████████████████████████████████████████████████████████████████████████████▍            | 498/560 [00:49<00:06, 10.16it/s]

Evaluating:   0%|                                                                                                                           | 0/32 [00:00<?, ?it/s]

Evaluating:  16%|█████████████████▉                                                                                                 | 5/32 [00:00<00:00, 48.08it/s]

Evaluating:  31%|███████████████████████████████████▋                                                                              | 10/32 [00:00<00:00, 47.66it/s]

Evaluating:  50%|█████████████████████████████████████████████████████████                                                         | 16/32 [00:00<00:00, 48.69it/s]

Evaluating:

Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌| 558/560 [00:56<00:00,  9.40it/s]
Iteration: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊| 559/560 [00:56<00:00,  9.47it/s]
Epoch:  50%|████████████████████████████████████████████████████████████▌                                                            | 1/2 [00:56<00:56, 56.44s/it]
Iteration:   0%|                                                                                                                           | 0/560 [00:00<?, ?it/s]
Iteration:   0%|▏                                                                                                                  | 1/560 [00:00<00:57,  9.71it/s]
Iteration:   0%|▍                                                                                                                  | 2/560 [00:00<00:57,  9.74it/s]
Iteration:   1%|

Iteration:  12%|█████████████▍                                                                                                    | 66/560 [00:07<00:47, 10.39it/s]
Iteration:  12%|█████████████▊                                                                                                    | 68/560 [00:07<00:46, 10.53it/s]
Iteration:  12%|██████████████▎                                                                                                   | 70/560 [00:07<00:46, 10.63it/s]
Iteration:  13%|██████████████▋                                                                                                   | 72/560 [00:07<00:45, 10.65it/s]
Iteration:  13%|███████████████                                                                                                   | 74/560 [00:07<00:44, 10.85it/s]
Iteration:  14%|███████████████▍                                                                                                  | 76/560 [00:07<00:44, 10.88it/s]
Iteration:  14%|

Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 43.24it/s]
Iteration:  25%|████████████████████████████▎                                                                                    | 140/560 [00:15<02:18,  3.03it/s]
Iteration:  25%|████████████████████████████▍                                                                                    | 141/560 [00:15<01:51,  3.76it/s]
Iteration:  25%|████████████████████████████▋                                                                                    | 142/560 [00:15<01:33,  4.49it/s]
Iteration:  26%|████████████████████████████▊                                                                                    | 143/560 [00:15<01:23,  5.00it/s]
Iteration:  26%|█████████████████████████████                                                                                    | 144/560 [00:15<01:15,  5.52it/s]
Iteration:  26%|

Iteration:  40%|█████████████████████████████████████████████▍                                                                   | 225/560 [00:23<00:30, 11.05it/s]
Iteration:  41%|█████████████████████████████████████████████▊                                                                   | 227/560 [00:23<00:30, 10.80it/s]
Iteration:  41%|██████████████████████████████████████████████▏                                                                  | 229/560 [00:23<00:30, 10.89it/s]
Iteration:  41%|██████████████████████████████████████████████▌                                                                  | 231/560 [00:23<00:31, 10.49it/s]
Iteration:  42%|███████████████████████████████████████████████                                                                  | 233/560 [00:23<00:30, 10.67it/s]
Iteration:  42%|███████████████████████████████████████████████▍                                                                 | 235/560 [00:24<00:29, 10.85it/s]
Iteration:  42%|

Iteration:  52%|██████████████████████████████████████████████████████████▋                                                      | 291/560 [00:30<00:34,  7.83it/s]
Iteration:  52%|██████████████████████████████████████████████████████████▉                                                      | 292/560 [00:30<00:32,  8.25it/s]
Iteration:  52%|███████████████████████████████████████████████████████████                                                      | 293/560 [00:30<00:30,  8.64it/s]
Iteration:  53%|███████████████████████████████████████████████████████████▌                                                     | 295/560 [00:31<00:28,  9.22it/s]
Iteration:  53%|███████████████████████████████████████████████████████████▉                                                     | 297/560 [00:31<00:27,  9.40it/s]
Iteration:  53%|████████████████████████████████████████████████████████████▏                                                    | 298/560 [00:31<00:28,  9.26it/s]
Iteration:  54%|

Iteration:  65%|█████████████████████████████████████████████████████████████████████████▊                                       | 366/560 [00:38<00:23,  8.21it/s]
Iteration:  66%|██████████████████████████████████████████████████████████████████████████                                       | 367/560 [00:38<00:22,  8.57it/s]
Iteration:  66%|██████████████████████████████████████████████████████████████████████████▎                                      | 368/560 [00:38<00:22,  8.37it/s]
Iteration:  66%|██████████████████████████████████████████████████████████████████████████▋                                      | 370/560 [00:39<00:21,  8.96it/s]
Iteration:  66%|██████████████████████████████████████████████████████████████████████████▊                                      | 371/560 [00:39<00:20,  9.12it/s]
Iteration:  66%|███████████████████████████████████████████████████████████████████████████                                      | 372/560 [00:39<00:21,  8.82it/s]
Iteration:  67%|

Iteration:  78%|████████████████████████████████████████████████████████████████████████████████████████▏                        | 437/560 [00:45<00:14,  8.63it/s]
Iteration:  78%|████████████████████████████████████████████████████████████████████████████████████████▍                        | 438/560 [00:46<00:14,  8.18it/s]
Iteration:  78%|████████████████████████████████████████████████████████████████████████████████████████▌                        | 439/560 [00:46<00:14,  8.07it/s]

Evaluating:   0%|                                                                                                                           | 0/32 [00:00<?, ?it/s]

Evaluating:  16%|█████████████████▉                                                                                                 | 5/32 [00:00<00:00, 46.74it/s]

Evaluating:  28%|████████████████████████████████▎                                                                                  | 9/32 [00:00<00:00, 43.76it/s]

Evaluating: 

(1001, 2.1433121363718906)

In [7]:
# 评估

trainer.load_model()
trainer.evaluate("dev")


Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 64.26it/s]


{'loss': 1.774277739226818, 'intent_acc': 0.714}

In [8]:
trainer.evaluate("test")

Evaluating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:00<00:00, 58.21it/s]


{'loss': 1.807569729430335, 'intent_acc': 0.7077267637178052}