## 1. 准备工作

### 1.1 环境准备


In [None]:
#关闭安装的输出
%%capture 
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install torchinfo


### 1.2 函数库与gpu使用

In [None]:
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
from datasets import load_dataset
from datasets import load_from_disk
from transformers import BertTokenizer, BertModel, AutoTokenizer
from transformers import TrainingArguments, Trainer
from datasets import load_metric
from transformers.trainer_utils import EvalPrediction
import evaluate

# 使用gpu进行训练
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
# gpu型号
!nvidia-smi


Using cuda device
Tue Mar 14 08:46:50 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P0    32W /  70W |   6147MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------

### 1.3 Google Drive 连接

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. 数据预处理

### 2.1 数据集导入

In [None]:

class Dataset(torch.utils.data.Dataset):
  """Map-style dataset that yields ``(text, label)`` pairs from one loaded split.

  Args:
    split: Name of the split, forwarded to ``load_dataset`` as ``split``.
    filepath: Dataset path or identifier, forwarded to ``load_dataset`` as ``path``.
  """

  def __init__(self, split, filepath):
    self.dataset = load_dataset(path=filepath, split=split)

  def __len__(self):
    """Return the number of examples in the loaded split."""
    return len(self.dataset)

  def __getitem__(self, i):
    """Return the ``(text, label)`` pair stored at index ``i``."""
    # Look the row up once and read both fields from it, rather than indexing
    # the underlying dataset separately for every field.
    row = self.dataset[i]
    return row['text'], row['label']


# Build the training and validation datasets from the same dataset identifier,
# one per split (train first, then validation).
train_data, val_data = (
  Dataset(split=split_name, filepath='seamew/ChnSentiCorp')
  for split_name in ('train', 'validation')
)



### 2.2 分词和编码

In [None]:
# Tokenizer that belongs to the Langboat/mengzi-bert-base-fin checkpoint.
mengzi_token = BertTokenizer.from_pretrained(
  pretrained_model_name_or_path="Langboat/mengzi-bert-base-fin",
)

# Tokenizer that belongs to the bert-base-chinese checkpoint.
bert_bc_token = BertTokenizer.from_pretrained(
  'bert-base-chinese',
  cache_dir=None,
  force_download=False,
)

# Inspect the vocabulary:
# token_dict = tokenizer.get_vocab()
# type(token_dict), len(token_dict), '月光' in token_dict
# Add new tokens:
# tokenizer.add_tokens(new_tokens=['月光','希望'])
# Add a new special token:
# tokenizer.add_special_tokens({'eos_token':'[EOS]'})

### 2.3 定义批处理函数

In [None]:
def _encode_batch(tokenizer, data):
  """Shared implementation of the collate functions.

  Args:
    tokenizer: Tokenizer used to encode the texts of the batch.
    data: List of samples where ``sample[0]`` is the text and ``sample[1]``
      is the integer class label.

  Returns:
    Tuple ``(input_ids, attention_mask, token_type_ids, labels)`` of tensors
    that have already been moved to ``device``.
  """
  sents = [sample[0] for sample in data]
  labels = [sample[1] for sample in data]

  # Every text is truncated / padded to exactly 512 tokens.
  encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                        truncation=True,
                                        padding='max_length',
                                        max_length=512,
                                        return_tensors='pt',
                                        return_length=True)

  # input_ids: the token ids produced by the tokenizer.
  # attention_mask: 0 at padded positions, 1 everywhere else.
  input_ids = encoded['input_ids'].to(device)
  attention_mask = encoded['attention_mask'].to(device)
  token_type_ids = encoded['token_type_ids'].to(device)
  # Labels are class indices, so they must be int64: `torch.Tensor(labels)` would
  # build a float32 tensor, which nn.CrossEntropyLoss rejects as class targets.
  labels = torch.tensor(labels, dtype=torch.long).to(device)

  return input_ids, attention_mask, token_type_ids, labels


def bc_collate_fn(data):
  """Collate ``(text, label)`` samples into a batch encoded with ``bert_bc_token``.

  Returns ``(input_ids, attention_mask, token_type_ids, labels)`` on ``device``.
  """
  return _encode_batch(bert_bc_token, data)


def mengzi_collate_fn(data):
  """Collate ``(text, label)`` samples into a batch encoded with ``mengzi_token``.

  Returns ``(input_ids, attention_mask, token_type_ids, labels)`` on ``device``.
  """
  return _encode_batch(mengzi_token, data)

# Data loaders for manual training / evaluation loops. (Original note: this step
# was folded into the Trainer setup.)
# Training loaders: reshuffled on every pass; the final incomplete batch is dropped.
bc_loader = DataLoader(dataset=train_data,
                       batch_size=64,
                       collate_fn=bc_collate_fn,
                       shuffle=True,
                       drop_last=True)

mengzi_loader = DataLoader(dataset=train_data,
                           batch_size=64,
                           collate_fn=mengzi_collate_fn,
                           shuffle=True,
                           drop_last=True)

# Evaluation loader: every validation example has to be scored exactly once, so do
# not shuffle and keep the final partial batch. With drop_last=True the examples
# that do not fill a whole batch of 64 would silently be left out of the metric.
test_loader = DataLoader(dataset=val_data,
                         batch_size=64,
                         collate_fn=mengzi_collate_fn,
                         shuffle=False,
                         drop_last=False)

# for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(mengzi_loader):
#   break
# print(len(mengzi_loader))
# input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape

## 3. 模型


### 3.1 加载预训练模型

In [None]:
# Load both pretrained encoders and move them onto the selected device.
bert_bc_pretrained = BertModel.from_pretrained("bert-base-chinese")
bert_bc_pretrained = bert_bc_pretrained.to(device)
mengzi_pretrained = BertModel.from_pretrained("Langboat/mengzi-bert-base-fin")
mengzi_pretrained = mengzi_pretrained.to(device)

# No fine-tuning: freeze every parameter of the bert-base-chinese encoder.
# (Only this encoder is frozen here; mengzi_pretrained is left untouched.)
for weight in bert_bc_pretrained.parameters():
  weight.requires_grad_(False)

# Trial forward pass of the encoder:
# out = bert_bc_pretrained(input_ids=input_ids,
#                  attention_mask=attention_mask,
#                  token_type_ids=token_type_ids)

# out.last_hidden_state.shape

# Parameter count of the mengzi encoder, printed in units of 10,000 parameters.
print('param_num: {}'.format(sum(weight.nelement() for weight in mengzi_pretrained.parameters()) / 10000))

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at Langboat/mengzi-bert-base-fin were not used when initializing BertModel: ['cls.predictions.bias', 'sop.cls.weight',

param_num: 10226.7648


### 3.2 定义下游任务模型

In [None]:
class Model(nn.Module):
  """Linear classification head on top of the ``bert_bc_pretrained`` encoder.

  The encoder is read from the notebook-level ``bert_bc_pretrained`` variable and
  is always called under ``torch.no_grad()``, so only ``self.fc`` receives gradients.
  """

  def __init__(self):
    super().__init__()
    # Maps a 768-dimensional hidden vector to 2 class scores.
    self.fc = nn.Linear(768, 2)

  def forward(self, input_ids, attention_mask, token_type_ids):
    """Return unnormalized class scores (logits), one row of 2 scores per input.

    Args:
      input_ids: Token ids of the batch.
      attention_mask: Mask that marks the non-padded positions.
      token_type_ids: Segment ids of the batch.

    Returns:
      The output of ``self.fc`` applied to the hidden state at sequence position 0.
    """
    with torch.no_grad():
      encoded = bert_bc_pretrained(input_ids=input_ids,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids)

    # Hidden state at position 0 (the [CLS] position for BERT-style inputs).
    first_token_state = encoded.last_hidden_state[:, 0]

    # Return raw logits. nn.CrossEntropyLoss applies log-softmax itself, so a
    # softmax here would normalize twice and flatten the gradients. argmax over the
    # logits gives the same prediction; call `.softmax(dim=1)` on the result when
    # probabilities are needed.
    return self.fc(first_token_state)

# Instantiate the downstream classifier, then move it onto the selected device.
model = Model()
model = model.to(device)

## 4. 训练下游任务模型

In [None]:
# NOTE(review): this whole Trainer-based setup is commented out, so this cell does
# not define `optimizer`, `loss_fn`, `metric`, `args` or `trainer`. The output
# recorded below for this cell ends in a RuntimeError.
# NOTE(review): `compute_metrics` receives the metric object itself; presumably it
# should be a function that takes an EvalPrediction -- confirm before re-enabling.

# # Define the optimizer, the loss function and the evaluation metric
# optimizer = AdamW(model.parameters(), lr=5e-4)
# loss_fn = nn.CrossEntropyLoss()
# metric = load_metric('accuracy')

# # Initialize the training arguments
# args = TrainingArguments(output_dir='./output_dir',
#                          overwrite_output_dir = False,
#                          evaluation_strategy='epoch',
#                          num_train_epochs = 10,
#                          learning_rate = 1e-4, # the optimizer defaults to AdamW
#                          adam_beta1 = 0.9,
#                          adam_beta2 = 0.999,
#                          adam_epsilon = 1e-8,
#                          weight_decay = 1e-2, # weight decay applied to the layers
#                          max_grad_norm = 1.0, # gradient clipping
#                          per_device_eval_batch_size = 64,
#                          per_device_train_batch_size = 64,
#                          lr_scheduler_type = 'linear',
#                          save_strategy = 'epoch',
#                          no_cuda = False,
#                          seed = 1024,
#                          data_seed = 1024,
#                          load_best_model_at_end = False,
#                          metric_for_best_model = 'loss',
#                          greater_is_better = False
#                          )

# # Initialize the trainer
# trainer = Trainer(
#     model = model,
#     args = args,
#     data_collator = mengzi_collate_fn, # builds the batches
#     train_dataset = train_data,
#     eval_dataset = val_data,
#     compute_metrics = metric,
#     tokenizer = mengzi_token
#     #callbacks = 
#     #optimizers = 
# )

# trainer.train()


PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
***** Running training *****
  Num examples = 9600
  Num Epochs = 10
  Instantaneous batch size per device = 64
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 1500
  Number of trainable parameters = 1538


RuntimeError: ignored

In [None]:
# `trainer` is only created by the Trainer setup code, which is currently commented
# out. Guard the call so that Restart & Run All does not stop on a NameError here.
if "trainer" in globals():
  eval_metrics = trainer.evaluate()
else:
  eval_metrics = None
  print("Skipping evaluation: `trainer` is not defined (the Trainer setup is commented out).")

# Display the evaluation result as the cell output (nothing is shown when skipped).
eval_metrics

#trainer.save_model(output_dir='./output_dir')

## 5. 测试模型效果

In [None]:
# def test(loader_test, model, loss_fn):
#   size = len(loader_test.dataset)
#   num_batches = len(loader_test)
#   model.eval()
#   correct = 0
#   total = 0
#   for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
#     with torch.no_grad():
#       out = model(input_ids=input_ids,
#                         attention_mask=attention_mask,
#                         token_type_ids=token_type_ids)
    
#     out = out.argmax(dim=1)
#     correct += (out == labels).sum().item()
#     total += len(labels)
#   print("Test Accuracy:{accuracy:.3f}".format(accuracy=correct / total))



# test(test_loader, model, loss_fn)