In [1]:
import math
import json
import glob
import collections
import random
from pathlib import Path
import pandas as pd
import numpy as np
import os
import copy
from tqdm.auto import tqdm
import pickle
import gc
from sklearn.model_selection import StratifiedKFold,KFold,GroupKFold
import torch
# pip install prefetch_generator
from prefetch_generator import BackgroundGenerator

os.environ["TOKENIZERS_PARALLELISM"] = "true"
from  transformers import AdamW, AutoTokenizer,AutoModel
import torch.nn as nn
import torch.nn.functional as F
import warnings
warnings.filterwarnings("ignore")
def seed_everything(seed):
    """
    Seeds basic parameters for reproductibility of results
    Arguments:
        seed {int} -- Number of the seed
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
SEED=2020
seed_everything(SEED)

  from .autonotebook import tqdm as notebook_tqdm


# 加载数据

In [2]:
DATA_PATH = "../data/"
BERT_PATH = "sentence-transformers/all-MiniLM-L6-v2"
MODEL_PATH = "./save/recall/2023_recall_v1_add_text_nice_valid.pkl"
PROMPT_LEN = 512
WIKI_LEN = 512
MAX_LEN = 512
BATCH_SIZE = 128
DEVICE = 'cuda'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

In [3]:
prompt = pd.read_csv('./data/crawl_context.csv')
prompt.loc[0]

id                                                            0
prompt                            What is physical mathematics?
A                 The study of physically motivated mathematics
B                             The study of mathematical physics
C                 The study of mathematics in physical contexts
D                           The study of mathematical equations
E                          The study of mathematical operations
answer                                                        A
wiki_text     The subject of physical mathematics is concern...
page_id                                                32439784
page_title                                 Physical mathematics
stem_label                                                    M
Name: 0, dtype: object

In [4]:
prompt.loc[0,'wiki_text']

'The subject of physical mathematics is concerned with physically motivated mathematics and is considered by some as a subfield of mathematical physics.\nAccording to Margaret Osler the simple machines of Hero of Alexandria and the ray tracing of Alhazen did not refer to causality or forces. Accordingly these early expressions of kinematics and optics do not rise to the level of mathematical physics as practiced by Galileo and Newton.\nThe details of physical units and their manipulation were addressed by Alexander Macfarlane in Physical Arithmetic in 1885. The science of kinematics created a need for mathematical representation of motion and has found expression with complex numbers, quaternions, and linear algebra.\nAt Cambridge University the Mathematical Tripos tested students on their knowledge of "mixed mathematics". "... [N]ew books which appeared  in the mid-eighteenth century offered a systematic introduction to the fundamental  operations of the fluxional calculus and showed 

In [5]:
prompt.loc[0,'A']

'The study of physically motivated mathematics'

In [8]:
prompt = pd.read_csv('./data/crawl_context.csv')
print(prompt.shape)
prompt['prompt_answer'] = prompt.apply(lambda row: ' '.join(str(row[field]) for field in ['prompt', 'A', 'B', 'C', 'D', 'E']), axis=1)
wiki = []
with open('./data/wiki_data.json', 'r',encoding='utf8') as f:
    lines = f.readlines()
    for line in lines:
        wiki.append(json.loads(line))
wiki = pd.DataFrame(wiki)
wiki['title_text'] = wiki.apply(lambda row : ' '.join(str(row[field]) for field in ['title','content']),axis=1)
wiki = wiki[['page_id', 'title_text']]
wiki.drop_duplicates(inplace=True)
wiki = wiki.reset_index(drop=True)
prompt.drop_duplicates(inplace=True)
prompt = prompt.reset_index(drop=True)
wiki['page_id'] = wiki['page_id'].apply(lambda x : int(x))
prompt = pd.merge(prompt, wiki, on='page_id')
prompt = prompt[['id','page_id','prompt_answer','wiki_text','title_text']]

(49284, 12)


In [9]:
prompt.loc[0,'title_text']

'Physical mathematics The subject of physical mathematics is concerned with physically motivated mathematics and is considered by some as a subfield of mathematical physics.\nAccording to Margaret Osler the simple machines of Hero of Alexandria and the ray tracing of Alhazen did not refer to causality or forces. Accordingly these early expressions of kinematics and optics do not rise to the level of mathematical physics as practiced by Galileo and Newton.\nThe details of physical units and their manipulation were addressed by Alexander Macfarlane in Physical Arithmetic in 1885. The science of kinematics created a need for mathematical representation of motion and has found expression with complex numbers, quaternions, and linear algebra.\nAt Cambridge University the Mathematical Tripos tested students on their knowledge of "mixed mathematics". "... [N]ew books which appeared  in the mid-eighteenth century offered a systematic introduction to the fundamental  operations of the fluxional

In [10]:
prompt['prompt_answer'].apply(lambda x: len(x) if str(x)!='nan' else 0 ).describe([0.98,0.99])

count    49284.000000
mean       224.731089
std        117.116009
min         36.000000
50%        202.000000
98%        524.000000
99%        600.000000
max       2129.000000
Name: prompt_answer, dtype: float64

In [11]:
prompt['title_text'].apply(lambda x: len(x) if str(x)!='nan' else 0 ).describe([0.7,0.98,0.99])

count     49284.000000
mean      10037.972770
std       13972.651941
min           5.000000
50%        5239.000000
70%        9664.300000
98%       50666.000000
99%       66664.000000
max      229856.000000
Name: title_text, dtype: float64

In [12]:
class LLMRecallDataSet(torch.utils.data.Dataset):
    def __init__(self, data):
        self.tokenizer = AutoTokenizer.from_pretrained(BERT_PATH, use_fast=True)
        self.query = []
        self.answer = []
        print('加载数据集中')
        for i in tqdm(range(len(data))):
            query = data.loc[i, 'prompt_answer']
            answer = data.loc[i, 'title_text']
            query_id = self.tokenizer.encode(query, add_special_tokens=False)
            answer_id = self.tokenizer.encode(answer, add_special_tokens=False)
            import pdb
            pdb.set_trace()
            if len(query_id) > 510:
                query_id = [101] + query_id[:510] + [102]
            else:
                query_id = [101] + query_id + [102]
            if len(answer_id) > 510:
                answer_id = [101] + answer_id[:510] + [102]
            else:
                answer_id = [101] + answer_id + [102]
            self.query.append(query_id)
            self.answer.append(answer_id)
    def __len__(self):
        return len(self.query) 
    
    def __getitem__(self,index):
        return self.query[index], self.answer[index]
    
    def collate_fn(self, batch):
        def sequence_padding(inputs, length=None, padding=0):
            """
            Numpy函数，将序列padding到同一长度
            """
            if length is None:
                length = max([len(x) for x in inputs])

            pad_width = [(0, 0) for _ in np.shape(inputs[0])]
            outputs = []
            for x in inputs:
                x = x[:length]
                pad_width[0] = (0, length - len(x))
                x = np.pad(x, pad_width, 'constant', constant_values=padding)
                outputs.append(x)

            return np.array(outputs, dtype='int64')
        batch_query, batch_answer = [], []
        
        for item in batch:
            query, answer = item
            batch_query.append(query)
            batch_answer.append(answer)
        batch_query = torch.tensor(sequence_padding(batch_query), dtype=torch.long)
        batch_answer = torch.tensor(sequence_padding(batch_answer), dtype=torch.long)
        
        return batch_query, batch_answer

        
class DataLoaderX(torch.utils.data.DataLoader):
    '''
        replace DataLoader with PrefetchDataLoader
    '''
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())  

    
def get_loader(prompt,batch_size,train_mode=True,num_workers=4):
    ds_df = LLMRecallDataSet(prompt)
    loader = DataLoaderX(ds_df, batch_size=batch_size if train_mode else batch_size//2, shuffle=train_mode, num_workers=num_workers,pin_memory=True,
                                         collate_fn=ds_df.collate_fn, drop_last=train_mode)
    loader.num = len(ds_df)
    return loader

def debug_loader(prompt, batch_size):
    loader=get_loader(prompt,batch_size,train_mode=True,num_workers=0)
    for token_ids,labels in loader:
        print(token_ids)
        print(labels)
        break
    return loader

# define recall model

In [8]:
class RecallModel(nn.Module):
    def __init__(self):
        super(RecallModel, self).__init__()
        self.bert_model = AutoModel.from_pretrained(BERT_PATH)
    
    def mask_mean(self, x, mask=None):
        if mask != None:
            mask_x = x * (mask.unsqueeze(-1))
            x_sum = torch.sum(mask_x, dim=1)
            re_x = torch.div(x_sum, torch.sum(mask, dim=1).unsqueeze(-1))
        else:
            x_sum = torch.sum(x, dim=1)
            re_x = torch.div(x_sum, x.size()[1])
        return re_x
    
    def forward(self,input_ids):
        attention_mask = input_ids > 0
        out = self.bert_model(input_ids, attention_mask=attention_mask).last_hidden_state
        x = out[:,0,:]
        return x

def debug_label():
    loader=get_loader(prompt,batch_size=2,train_mode=True,num_workers=0)
    model= RecallModel()
    print('models paramters:', sum(p.numel() for p in model.parameters()))
    for token_ids,labels in loader:
        # print(token_ids)
        # print(labels)
        prob=model(token_ids)
        print(prob)
        break
    

In [10]:
def SimCSE_loss(topic_pred,content_pred,tau=0.05):
    similarities = F.cosine_similarity(topic_pred.unsqueeze(1), content_pred.unsqueeze(0), dim=2) # B,B
    y_true = torch.arange(0,topic_pred.size(0)).to(DEVICE)
    # similarities = similarities - torch.eye(pred.shape[0]) * 1e12
    similarities = similarities / tau
    loss=F.cross_entropy(similarities, y_true)
    return torch.mean(loss)
from torch.cuda.amp import autocast, GradScaler
def trainer(train_dataloader,val_dataloader,model, epochs, model_save_path,
            accumulation_steps=1, early_stop_epochs=5, device='cpu'):
    ########早停
    no_improve_epochs = 0

    ########优化器 学习率
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    crf_p=[n for n, p in param_optimizer if str(n).find('crf')!=-1]
    print(crf_p)
    optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and n not in crf_p], 'weight_decay': 0.8},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)  and n not in crf_p], 'weight_decay': 0.0},
            {'params': [p for n, p in param_optimizer if n in crf_p], 'lr': 2e-3, 'weight_decay': 0.8},
            ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=2e-5, eps=1e-8)
    
    scaler = GradScaler()
    criterion = nn.BCEWithLogitsLoss()

    train_len = len(train_dataloader)

    best_score = 100000000
    losses = []
    for epoch in range(1, epochs + 1):
        model.train()
        bar = tqdm(train_dataloader)
        for i, inputs in enumerate(bar):
            with autocast():
                topic_inputs,content_inputs = (_.to(device) for _ in inputs)
                # print(topic_inputs.size())
                # print(content_inputs.size())
                topic_pred = model(topic_inputs)
                content_pred = model(content_inputs)
                # print(topic_pred.size())
                # print(content_pred.size())
                loss = SimCSE_loss(topic_pred,content_pred)
            scaler.scale(loss).backward()
            losses.append(loss.item())
            if (i + 1) % accumulation_steps == 0 or (i + 1) == train_len:
                scaler.step(optimizer)
                optimizer.zero_grad()
                scaler.update()
            bar.set_postfix(loss_mean=np.array(losses).mean(), epoch=epoch)
        if epoch % 20 == 0:
            torch.save(model.state_dict(), f'./save/recall/recall_epoch{epoch}.bin')
    return losses

In [11]:
train = prompt.loc[:47000].reset_index(drop=True)
val = prompt.loc[47000:].reset_index(drop=True)
train.to_csv('./data/recall_train.csv',index=False)
val.to_csv('./data/recall_val.csv',index=False)

train_loader=get_loader(train,
                      batch_size=BATCH_SIZE,
                      train_mode=True,
                      num_workers=2)
val_loader=get_loader(val, batch_size=BATCH_SIZE,
                     train_mode=False,
                     num_workers=2)
model= RecallModel().to(DEVICE)
trainer(train_loader,val_loader,model, 
            epochs=100, 
            model_save_path = './save/recall/recall.bin',
            accumulation_steps=1,
            early_stop_epochs=5, device=DEVICE)

加载数据集中


  0%|          | 0/47001 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (676 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 47001/47001 [05:47<00:00, 135.08it/s]


加载数据集中


  0%|          | 0/2284 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1052 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 2284/2284 [00:13<00:00, 166.98it/s]


[]


100%|██████████| 367/367 [01:43<00:00,  3.54it/s, epoch=1, loss_mean=0.219]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=2, loss_mean=0.154]
100%|██████████| 367/367 [01:42<00:00,  3.59it/s, epoch=3, loss_mean=0.125]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=4, loss_mean=0.107]
100%|██████████| 367/367 [01:42<00:00,  3.57it/s, epoch=5, loss_mean=0.0951]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=6, loss_mean=0.0859]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=7, loss_mean=0.0788]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=8, loss_mean=0.0731]
100%|██████████| 367/367 [01:42<00:00,  3.57it/s, epoch=9, loss_mean=0.0684]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=10, loss_mean=0.0644]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=11, loss_mean=0.0609]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=12, loss_mean=0.0579]
100%|██████████| 367/367 [01:42<00:00,  3.58it/s, epoch=13, loss_mean=0.0552]

[1.9519906044006348,
 1.8216842412948608,
 1.792602777481079,
 1.6925252676010132,
 1.3921329975128174,
 1.309242606163025,
 1.0707662105560303,
 0.9850286245346069,
 0.7344101667404175,
 0.7770193815231323,
 0.7651268243789673,
 0.7516265511512756,
 0.5482092499732971,
 0.5104630589485168,
 0.5680720210075378,
 0.5906234383583069,
 0.6495450139045715,
 0.46312934160232544,
 0.529392421245575,
 0.39064404368400574,
 0.3021797239780426,
 0.3385663330554962,
 0.34318941831588745,
 0.38456010818481445,
 0.4530014395713806,
 0.4653345048427582,
 0.4079946279525757,
 0.37095069885253906,
 0.3217395842075348,
 0.2944546639919281,
 0.43370068073272705,
 0.3861745297908783,
 0.4166041910648346,
 0.37937766313552856,
 0.2808482348918915,
 0.33334943652153015,
 0.4106549024581909,
 0.36930418014526367,
 0.23981483280658722,
 0.24083928763866425,
 0.5208274126052856,
 0.41099750995635986,
 0.33048388361930847,
 0.3470776677131653,
 0.4313044846057892,
 0.32513928413391113,
 0.30208033323287964,
 