In [27]:
import pandas as pd

def load_lcqmc(base_url: str = 'https://mirror.coggle.club/dataset'):
    '''Load the LCQMC text-matching dataset.

    Args:
        base_url: URL prefix or local directory containing
            ``LCQMC.train.data.zip``, ``LCQMC.valid.data.zip`` and
            ``LCQMC.test.data.zip``. Defaults to the Coggle mirror; pass a local
            directory to avoid downloading again.

    Returns:
        ``(train, valid, test)`` DataFrames with columns query1, query2, label.
    '''
    base = base_url.rstrip('/')

    def _read_split(split: str) -> pd.DataFrame:
        # Header-less, tab-separated files; pandas infers zip compression from the suffix.
        # NOTE(review): default quoting treats ASCII '"' specially; if row counts look off,
        # consider quoting=csv.QUOTE_NONE.
        return pd.read_csv(f'{base}/LCQMC.{split}.data.zip',
                           sep='\t', names=['query1', 'query2', 'label'])

    train, valid, test = (_read_split(split) for split in ('train', 'valid', 'test'))
    return train, valid, test

In [44]:
from typing import Dict, List
import numpy as np
import torch
import torch.nn as nn
from loguru import logger
import torch.nn.functional as F
from scipy.stats import spearmanr
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertConfig, BertModel, BertTokenizer

In [29]:
from tqdm.notebook import tqdm
# NOTE(review): this rebinds `tqdm` to tqdm.notebook, while the imports cell also does
# `from tqdm import tqdm`; whichever of the two cells executed last decides which
# progress bar train() uses (the execution counts show the imports cell ran later).
# tqdm.pandas() presumably registers pandas' `progress_apply` (per the tqdm docs);
# no `progress_apply` call is visible in this notebook.
tqdm.pandas()

In [30]:
# Load the three LCQMC splits as DataFrames.
(train, valid, test) = load_lcqmc()

In [31]:
# Sanity-check the parsed columns on a single training pair.
train.iloc[1, :]

query1    我手机丢了，我想换个手机
query2     我想买个新手机，求推荐
label                1
Name: 1, dtype: object

In [32]:
def load_data(df):
    """Convert a query-pair DataFrame into a list of ``(query1, query2, label)`` tuples.

    Rows come out in positional order, regardless of the DataFrame's index labels.
    The three columns are zipped directly instead of materialising every row with
    ``df.iloc[i]``, which builds a new Series per row and is very slow on large splits.
    """
    return list(zip(df["query1"], df["query2"], df["label"]))

In [33]:
# Materialise every split as a list of (query1, query2, label) tuples.
train_data, valid_data, test_data = (load_data(split_df) for split_df in (train, valid, test))

In [66]:
import random
SAMPLES = 10000
BATCH_SIZE = 32
EPOCHS = 1
LR = 1e-5
DROPOUT = 0.3
MAXLEN = 64
POOLING = 'cls'
SEED = 42
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed Python's and torch's RNGs so the training subsample is reproducible and
# torch-side randomness (e.g. dropout) starts from a fixed state.
random.seed(SEED)
torch.manual_seed(SEED)
# random.sample raises ValueError when asked for more items than exist, so cap the request.
# NOTE(review): this rebinds train_data, so re-running this cell alone re-samples the
# already-subsampled list; re-run the load_data cell first for a fresh draw.
train_data = random.sample(train_data, min(SAMPLES, len(train_data)))

In [35]:
class TrainDataset(Dataset):
    """Training set for unsupervised SimCSE.

    Every item is a single sentence (the first field of a record) tokenised twice
    as a 2-element batch; after encoding, the two copies act as each other's
    positive sample.
    """

    def __init__(self, data: List):
        # Records are indexable; only element 0 (query1) is used for training.
        self.data = data

    def __len__(self):
        return len(self.data)

    def text_2_id(self, text: str):
        # Duplicate the sentence so both copies are encoded side by side.
        duplicated = [text] * 2
        return tokenizer(
            duplicated,
            max_length=MAXLEN,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

    def __getitem__(self, index: int):
        first_query = self.data[index][0]
        return self.text_2_id(first_query)
    
class TestDataset(Dataset):
    """Evaluation set of ``(query1, query2, label)`` records.

    Each item is ``(encoded query1, encoded query2, int(label))``. Each sentence is
    tokenised as a one-element batch, so its tensors carry a leading dimension of
    size 1 (squeezed away by the evaluation loop).
    """

    def __init__(self, data: List):
        self.data = data

    def __len__(self):
        return len(self.data)

    def text_2_id(self, text: List[str]):
        # `text` arrives as a one-element list from __getitem__.
        return tokenizer(
            text,
            max_length=MAXLEN,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

    def __getitem__(self, index: int):
        record = self.data[index]
        encoded_first = self.text_2_id([record[0]])
        encoded_second = self.text_2_id([record[1]])
        return encoded_first, encoded_second, int(record[2])

In [36]:
# Reshuffle training sentences every epoch so batch composition (and hence the
# in-batch negatives) varies between epochs; Spearman evaluation is order-independent,
# so the valid/test loaders keep a fixed order.
train_dataloader = DataLoader(TrainDataset(train_data), batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader = DataLoader(TestDataset(valid_data), batch_size=BATCH_SIZE)
test_dataloader = DataLoader(TestDataset(test_data), batch_size=BATCH_SIZE)

In [37]:
class SimcseModel(nn.Module):
    """Unsupervised SimCSE encoder: a BERT backbone followed by a pooling strategy.

    Args:
        pretrained_model: name or path passed to ``BertConfig/BertModel.from_pretrained``.
        pooling: one of ``POOLINGS``:
            'cls'            - last-layer hidden state of the first token;
            'pooler'         - the backbone's ``pooler_output``;
            'last-avg'       - mean of last-layer token states over non-padding tokens;
            'first-last-avg' - average of the masked means of ``hidden_states[1]``
                               and the last layer.

    Raises:
        ValueError: if ``pooling`` is not one of ``POOLINGS``.
    """

    POOLINGS = ('cls', 'pooler', 'last-avg', 'first-last-avg')

    def __init__(self, pretrained_model, pooling):
        super().__init__()
        if pooling not in self.POOLINGS:
            raise ValueError(f'unknown pooling {pooling!r}; expected one of {self.POOLINGS}')
        config = BertConfig.from_pretrained(pretrained_model)
        # The two encodings of each training sentence differ only through dropout,
        # so both dropout rates are overridden with DROPOUT.
        config.attention_probs_dropout_prob = DROPOUT
        config.hidden_dropout_prob = DROPOUT
        self.bert = BertModel.from_pretrained(pretrained_model, config=config)
        self.pooling = pooling

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` [batch, seq, dim] over the sequence axis, skipping padding."""
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # [batch, seq, 1]
        # clamp avoids division by zero for an all-padding row.
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Encode a batch and return pooled sentence embeddings of shape [batch, hidden].

        Raises:
            ValueError: if ``self.pooling`` is not a supported strategy.
        """
        # Intermediate layers are only needed by 'first-last-avg'.
        need_hidden = self.pooling == 'first-last-avg'
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                        token_type_ids=token_type_ids, output_hidden_states=need_hidden)

        if self.pooling == 'cls':
            return out.last_hidden_state[:, 0]
        if self.pooling == 'pooler':
            return out.pooler_output
        # Inputs are padded to a fixed length, so averages must exclude padding positions.
        if self.pooling == 'last-avg':
            return self._masked_mean(out.last_hidden_state, attention_mask)
        if self.pooling == 'first-last-avg':
            # hidden_states[0] is presumably the embedding output, so [1] is the first layer.
            first_avg = self._masked_mean(out.hidden_states[1], attention_mask)
            last_avg = self._masked_mean(out.hidden_states[-1], attention_mask)
            return (first_avg + last_avg) / 2
        raise ValueError(f'unknown pooling {self.pooling!r}; expected one of {self.POOLINGS}')

In [38]:
def simcse_unsup_loss(y_pred: 'tensor', temperature: float = 0.05) -> 'tensor':
    """Unsupervised SimCSE (InfoNCE) loss.

    Args:
        y_pred: encoder outputs of shape [batch_size * 2, hidden]; rows 2k and 2k+1
            are two encodings of the same sentence and are each other's positive.
            All other rows in the batch act as negatives.
        temperature: value the cosine similarities are divided by before the softmax.

    Returns:
        Scalar cross-entropy loss, computed on ``y_pred``'s device.

    Raises:
        ValueError: if ``y_pred`` has an odd number of rows (rows cannot be paired).
    """
    n = y_pred.shape[0]
    if n % 2:
        raise ValueError(f'y_pred must have an even number of rows, got {n}')
    # Positive partner of each row: [1, 0, 3, 2, ...] (flip the lowest bit of the index).
    y_true = torch.arange(n, device=y_pred.device) ^ 1
    # Pairwise cosine similarity matrix, [n, n].
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    # Exclude self-similarity so a row can never select itself.
    self_mask = torch.eye(n, dtype=torch.bool, device=y_pred.device)
    sim = sim.masked_fill(self_mask, float('-inf'))
    return F.cross_entropy(sim / temperature, y_true)

In [39]:
def eval(model, dataloader, device=None) -> float:
    """Score sentence pairs and correlate the scores with the gold labels.

    Both sides of every pair are encoded, scored with cosine similarity, and the
    Spearman rank correlation between all scores and labels is computed once over
    the whole dataloader.

    Args:
        model: encoder called as ``model(input_ids, attention_mask, token_type_ids)``.
            It is switched to eval mode and left that way on return.
        dataloader: yields ``(source, target, label)`` batches; source/target are
            mappings whose tensors have shape [batch, 1, seq_len].
        device: device to run on; defaults to the module-level ``DEVICE``.

    Returns:
        Spearman correlation, or ``nan`` if the dataloader yields no batches.
    """
    device = DEVICE if device is None else device

    def _encode(features):
        # [batch, 1, seq_len] -> [batch, seq_len] for each model input.
        inputs = [features.get(key).squeeze(1).to(device)
                  for key in ('input_ids', 'attention_mask', 'token_type_ids')]
        return model(*inputs)

    model.eval()
    sims, labels = [], []
    with torch.no_grad():
        for source, target, label in dataloader:
            sims.append(F.cosine_similarity(_encode(source), _encode(target), dim=-1))
            labels.append(np.asarray(label))
    if not sims:
        return float('nan')
    # Concatenate once at the end instead of re-copying growing arrays every batch.
    all_sims = torch.cat(sims).cpu().numpy()
    all_labels = np.concatenate(labels)
    return spearmanr(all_labels, all_sims).correlation

In [69]:
def train(model, train_dl, dev_dl, optimizer, eval_every: int = 10) -> None:
    """Run one epoch of unsupervised SimCSE training with periodic checkpointing.

    Every ``eval_every`` batches, and after the final batch, the model is scored with
    ``eval(model, dev_dl)``. Whenever the score beats the module-level ``best``,
    ``best`` is updated and ``model.state_dict()`` is saved to the module-level
    ``save_path``.

    Args:
        model: encoder called as ``model(input_ids, attention_mask, token_type_ids)``.
        train_dl: yields tokenizer outputs whose tensors have shape [batch, 2, seq_len].
        dev_dl: evaluation dataloader passed to ``eval``.
        optimizer: optimizer over ``model``'s parameters.
        eval_every: evaluate (and possibly checkpoint) every this many batches.

    Raises:
        ValueError: if ``eval_every`` is less than 1.
    """
    if eval_every < 1:
        raise ValueError(f'eval_every must be >= 1, got {eval_every}')
    global best
    model.train()
    num_batches = len(train_dl)
    for batch_idx, source in enumerate(tqdm(train_dl), start=1):
        # [batch, 2, seq_len] -> [batch * 2, seq_len]; both copies of a sentence stay in
        # adjacent rows, which is the pairing simcse_unsup_loss expects.
        real_batch_num = source.get('input_ids').shape[0]
        input_ids = source.get('input_ids').view(real_batch_num * 2, -1).to(DEVICE)
        attention_mask = source.get('attention_mask').view(real_batch_num * 2, -1).to(DEVICE)
        token_type_ids = source.get('token_type_ids').view(real_batch_num * 2, -1).to(DEVICE)

        out = model(input_ids, attention_mask, token_type_ids)
        loss = simcse_unsup_loss(out)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Also evaluate after the last batch, otherwise updates made after the last
        # multiple of eval_every could never be checkpointed.
        if batch_idx % eval_every == 0 or batch_idx == num_batches:
            logger.info(f'loss: {loss.item():.4f}')
            corrcoef = eval(model, dev_dl)
            # eval() switches the model to eval mode; restore training mode (dropout).
            model.train()
            if best < corrcoef:
                best = corrcoef
                torch.save(model.state_dict(), save_path)
                logger.info(f"higher corrcoef: {best:.4f} in batch: {batch_idx}, save model")

In [72]:
import os

# Local paths; override them with environment variables instead of editing the notebook.
model_path = os.environ.get('SIMCSE_MODEL_PATH', 'G:\\deep_learning\\models\\chinese_wwm_ext_pytorch')
save_path = os.environ.get(
    'SIMCSE_SAVE_PATH',
    'G:\\deep_learning\\Coggle\\202301\\models\\simcse\\chinese_bert_wwm_ext\\simcse_unsup.pt',
)
# torch.save does not create missing parent directories, so create the checkpoint folder up front.
_save_dir = os.path.dirname(save_path)
if _save_dir:
    os.makedirs(_save_dir, exist_ok=True)

In [None]:
logger.info(f'device: {DEVICE}, pooling: {POOLING}, model path: {model_path}')
# The datasets built earlier look up this global `tokenizer` lazily when items are
# fetched, so it must exist before any dataloader is iterated.
tokenizer = BertTokenizer.from_pretrained(model_path)
# load model; explicit check instead of `assert`, which is stripped under `python -O`
if POOLING not in ['cls', 'pooler', 'last-avg', 'first-last-avg']:
    raise ValueError(f'unsupported POOLING: {POOLING!r}')
model = SimcseModel(pretrained_model=model_path, pooling=POOLING).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
# train; train() reads and updates this module-level `best` via `global best`
best = 0
for epoch in range(EPOCHS):
    logger.info(f'epoch: {epoch}')
    train(model, train_dataloader, valid_dataloader, optimizer)
# eval
if best > 0:
    # `best` only rises above its initial 0 when train() wrote a checkpoint during this run.
    logger.info(f'train is finished, best model is saved at {save_path}')
    # map_location lets a checkpoint written on one device be restored onto DEVICE.
    model.load_state_dict(torch.load(save_path, map_location=DEVICE))
else:
    # No checkpoint from this run: loading would fail or pick up a stale file.
    logger.warning('train is finished, but no checkpoint was saved; evaluating the final in-memory model')
valid_corrcoef = eval(model, valid_dataloader)
test_corrcoef = eval(model, test_dataloader)
logger.info(f'dev_corrcoef: {valid_corrcoef:.4f}')
logger.info(f'test_corrcoef: {test_corrcoef:.4f}')

2023-01-30 23:03:34.144 | INFO     | __main__:<module>:1 - device: cuda, pooling: cls, model path: G:\deep_learning\models\chinese_wwm_ext_pytorch
Some weights of the model checkpoint at G:\deep_learning\models\chinese_wwm_ext_pytorch were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceC

 86%|████████████████████████████████████████████████████████████████████▊           | 269/313 [11:01<00:23,  1.87it/s]2023-01-30 23:14:36.855 | INFO     | __main__:train:19 - loss: 0.0275
2023-01-30 23:15:01.358 | INFO     | __main__:train:25 - higher corrcoef: 0.6266 in batch: 270, save model
 89%|███████████████████████████████████████████████████████████████████████▎        | 279/313 [11:28<00:18,  1.80it/s]2023-01-30 23:15:03.902 | INFO     | __main__:train:19 - loss: 0.0474
 92%|█████████████████████████████████████████████████████████████████████████▊      | 289/313 [11:54<00:12,  1.86it/s]2023-01-30 23:15:29.839 | INFO     | __main__:train:19 - loss: 0.0469
2023-01-30 23:15:52.829 | INFO     | __main__:train:25 - higher corrcoef: 0.6273 in batch: 290, save model
 96%|████████████████████████████████████████████████████████████████████████████▍   | 299/313 [12:19<00:07,  1.88it/s]2023-01-30 23:15:55.274 | INFO     | __main__:train:19 - loss: 0.0816
 99%|█████████████████████████