In [6]:
import torch
import torch.nn as nn
import os
import numpy as np
import random
import json
import jsonlines
import csv
import re
import time
import argparse
import sys
import sklearn
import traceback

from torch.utils import data
from tqdm import tqdm
# from apex import amp
from scipy.special import softmax

from ditto_light.ditto import evaluate, DittoModel
from ditto_light.exceptions import ModelNotFoundError
from ditto_light.dataset import DittoDataset
from ditto_light.summarize import Summarizer
from ditto_light.knowledge import *

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package stopwords to /home/cxz/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [10]:
import pandas as pd

# 假设 ent1 和 ent2 是 DataFrame 数据
# 创建示例 DataFrame
# Two small example frames standing in for the left and right entity tables.
ent1 = pd.DataFrame({'attr1': [1, 2, 3], 'attr2': ['A', 'B', 'C']})
ent2 = pd.DataFrame({'attr1': [4, 5, 6], 'attr2': ['D', 'E', 'F']})

# Collect the serialized fragments and join them once at the end.
fragments = []
for ent in (ent1, ent2):
    if isinstance(ent, pd.DataFrame):
        # Every row of a frame becomes a run of "COL <name> VAL <value> "
        # fragments followed by a tab.
        for index, row in ent.iterrows():
            fragments.extend('COL %s VAL %s ' % (attr, row[attr])
                             for attr in ent.columns)
            fragments.append('\t')
        continue

    # Strings are passed through verbatim; anything else is read as a mapping.
    if isinstance(ent, str):
        fragments.append(ent)
    else:
        fragments.extend('COL %s VAL %s ' % (attr, ent[attr])
                         for attr in ent.keys())
    fragments.append('\t')

# A literal "0" closes the serialized string.
fragments.append('0')
content = ''.join(fragments)


In [14]:
ent1, ent2

(   attr1 attr2
 0      1     A
 1      2     B
 2      3     C,
    attr1 attr2
 0      4     D
 1      5     E
 2      6     F)

In [13]:
content

'COL attr1 VAL 1 COL attr2 VAL A \tCOL attr1 VAL 2 COL attr2 VAL B \tCOL attr1 VAL 3 COL attr2 VAL C \tCOL attr1 VAL 4 COL attr2 VAL D \tCOL attr1 VAL 5 COL attr2 VAL E \tCOL attr1 VAL 6 COL attr2 VAL F \t0'

In [7]:
def set_seed(seed: int):
    """
    Seed the ``random``, ``numpy`` and ``torch`` generators with the same
    value so that repeated runs behave reproducibly.
    """
    # Order: Python's random, NumPy, torch.manual_seed, then the CUDA seeds.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [None]:

# 支持多种数据输入
# txt文件
# jsonl文件
# string -pairs
# csv 文件

# 序列化 一对 数据条目
def to_str(ent1, ent2):
    """Serialize one pair of data entries into a single line.

    An entry that is already a string is used verbatim. Any other entry is
    read through ``keys()`` / item access and rendered as
    ``COL <attr> VAL <value> `` for each key, e.g.
    ``{'title': 'x'}`` becomes ``'COL title VAL x '``.

    Returns:
        str: ``<left><TAB><right><TAB>0``.

    Raises:
        ValueError: from the three-way unpacking below, when a rendered
            entry itself contains a tab character.
    """
    def render(ent):
        # Strings pass through untouched; mappings are flattened attribute by attribute.
        if isinstance(ent, str):
            return ent
        return ''.join('COL %s VAL %s ' % (attr, ent[attr]) for attr in ent.keys())

    # Attributes are separated by spaces, the two entries by a tab, and a
    # literal "0" closes the line.
    content = '\t'.join([render(ent1), render(ent2), '0'])

    # Exactly three tab-separated fields are expected here.
    new_ent1, new_ent2, _ = content.split('\t')

    return '\t'.join((new_ent1, new_ent2, '0'))

In [9]:
# Scratch check: after appending '.jsonl', slicing the suffix back off
# recovers the original file name.
suffix = '.jsonl'
input_path = '1.txt' + suffix

print('remove', input_path[:-len(suffix)])

remove 1.txt


In [None]:
# Classify sentence pairs with the matching model; returns the predicted labels and the corresponding logits.
def classify(sentence_pairs, model,
             lm='distilbert',
             max_len=256,
             threshold=None):
    """Apply the matching model to a list of serialized pairs.

    Args:
        sentence_pairs (list of str): the sequence pairs
        model: the pytorch model; it is called on each collated batch and its
            output is treated as per-class logits (softmax over dim 1,
            column 1 being the positive class)
        lm (str, optional): language-model name forwarded to DittoDataset
        max_len (int, optional): the max sequence length
        threshold (float, optional): a pair is predicted 1 when its class-1
            probability is strictly greater than this value; defaults to 0.5

    Returns:
        tuple: ``(pred, all_logits)`` where ``pred`` is a list of 0/1 labels
        and ``all_logits`` is a list with the raw logits of every pair.
    """
    # The loader below uses batch_size=len(dataset); DataLoader rejects a
    # batch size of 0, so an empty input is answered directly.
    if len(sentence_pairs) == 0:
        return [], []

    dataset = DittoDataset(sentence_pairs,
                           max_len=max_len,
                           lm=lm)
    # The whole input is scored as a single batch; predict() is what chunks
    # the work into batches before calling this function.
    iterator = data.DataLoader(dataset=dataset,
                               batch_size=len(dataset),
                               shuffle=False,
                               num_workers=0,
                               collate_fn=DittoDataset.pad)

    if threshold is None:
        threshold = 0.5

    # prediction
    all_probs = []
    all_logits = []
    with torch.no_grad():
        for x, _ in iterator:
            logits = model(x)
            probs = logits.softmax(dim=1)[:, 1]
            all_probs += probs.cpu().numpy().tolist()
            all_logits += logits.cpu().numpy().tolist()

    pred = [1 if p > threshold else 0 for p in all_probs]
    return pred, all_logits

In [None]:
#　模型预测，并将预测结果写入　输出文件
def predict(input_rows, output_path,
            model,
            batch_size=1024,
            lm='distilbert',
            threshold=None):
    """Run the matcher over candidate pairs and write one JSON line per pair.

    Args:
        input_rows (iterable): the candidate pairs. A row is either an
            indexable ``[left, right, ...]`` or a single string holding the
            left and right entries separated by a tab.
        output_path (str): file opened with ``jsonlines`` in mode ``'w'``
        model: the model forwarded to ``classify``
        batch_size (int, optional): number of pairs classified per call
        lm (str, optional): language-model name forwarded to ``classify``
        threshold (float, optional): decision threshold forwarded to ``classify``

    Each output record has the keys ``left``, ``right``, ``match`` and
    ``match_confidence`` (the softmax score of the predicted class).

    Raises:
        ValueError: if a string row does not contain a tab.
    """
    def split_row(row):
        # get_input's 'string' and CSV modes return "left<TAB> right" strings
        # rather than [left, right] lists; indexing such a string would pick
        # single characters, so split it on the tab instead.
        if isinstance(row, str):
            parts = row.split('\t')
            if len(parts) < 2:
                raise ValueError('expected a tab-separated pair, got: %r' % (row,))
            return parts[0], parts[1]
        return row[0], row[1]

    def process_batch(entities, pairs, writer):
        # Classify one batch and write its results straight to the writer.
        predictions, logits = classify(pairs, model, lm=lm,
                                       threshold=threshold)
        # logits holds the raw per-class scores of every pair; softmax turns
        # them into a probability distribution per pair.
        scores = softmax(logits, axis=1)
        for (left, right), pred, score in zip(entities, predictions, scores):
            # Report the original entries, not characters of the serialized pair.
            output = {'left': left, 'right': right,
                      'match': pred,
                      'match_confidence': float(score[int(pred)])}
            writer.write(output)

    entities = []  # original (left, right) entries of the current batch
    pairs = []     # their serialized form, in the same order
    with jsonlines.open(output_path, mode='w') as writer:
        for row in input_rows:
            left, right = split_row(row)
            entities.append((left, right))
            pairs.append(to_str(left, right))

            # A full batch is classified and flushed immediately.
            if len(pairs) == batch_size:
                process_batch(entities, pairs, writer)
                entities.clear()
                pairs.clear()

        # Remaining pairs that did not fill a whole batch.
        if pairs:
            process_batch(entities, pairs, writer)

In [None]:

# 加载模型
# 默认使用GPU 
def load_model(task, save_pth, lm, use_gpu=True):
    """Load a trained model from ``<save_pth>/<task>/model.pt``.

    Args:
        task (str): task name, used as a sub-directory of ``save_pth``
        save_pth (str): root directory of the checkpoints
        lm (str): language-model name passed to ``DittoModel``
        use_gpu (bool, optional): use CUDA when it is available; otherwise
            (or when False) the model stays on the CPU

    Returns:
        DittoModel: the model with the saved weights, moved to the chosen
        device and switched to evaluation mode.

    Raises:
        ModelNotFoundError: if the checkpoint file does not exist.
    """
    checkpoint = os.path.join(save_pth, task, 'model.pt')
    if not os.path.exists(checkpoint):
        raise ModelNotFoundError(checkpoint)

    device = 'cuda' if use_gpu and torch.cuda.is_available() else 'cpu'

    model = DittoModel(device=device, lm=lm)

    # The identity map_location keeps the loaded storages where torch.load
    # first puts them; the model is moved to the target device afterwards.
    saved_state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(saved_state['model'])
    model = model.to(device)

    # Modules start in training mode; switch to evaluation mode so that
    # dropout layers do not perturb the predictions made with this model.
    model.eval()

    return model

In [78]:
import jsonlines
import csv
import os
import pandas as pd

In [14]:
def combine_dataframes(ent1, ent2):
    """Build every left/right combination of the rows of two DataFrames.

    Each row is rendered as ``COL <column> VAL <value> `` for all of its
    columns. The result lists, for every row of ``ent1`` in order, its
    pairing with every row of ``ent2``, joined by a tab followed by a space.

    Returns:
        list of str: ``len(ent1) * len(ent2)`` combined rows.
    """
    def serialize(frame):
        rendered = []
        for _, record in frame.iterrows():
            text = ''.join(
                'COL %s VAL %s ' % (str(attr), str(val))
                for attr, val in zip(record.index, record.values)
            )
            # Only the part before the first tab is kept, so a tab inside a
            # value cuts the serialized row short at that point.
            rendered.append(text.split('\t')[0])
        return rendered

    lefts = serialize(ent1)
    rights = serialize(ent2)

    return [left + '\t ' + right for left in lefts for right in rights]

def combine_string(rows_l, rows_r):
    """Pair every left row with every right row.

    Returns:
        list of str: for each left row in order, its concatenation with each
        right row, separated by a tab followed by a space.
    """
    return [left + '\t ' + right for left in rows_l for right in rows_r]



# 处理多样化的输入数据：
# 
def get_input(content, type='string'):
    """Turn one of the supported input forms into a list of candidate rows.

    Args:
        content: what to read, depending on ``type``:
            * ``'file'``: a path ending in ``.txt`` (one tab-separated pair
              per line), any other path (read as JSON lines), or a sequence
              of two ``.csv`` paths;
            * ``'dataframe'``: a sequence holding two DataFrames;
            * ``'string'``: a sequence holding two newline-separated strings.
        type (str, optional): one of ``'file'``, ``'dataframe'``, ``'string'``

    Returns:
        list: for ``.txt`` input, one ``[left, right]`` list per line; for
        JSON-lines input, the rows as read; for CSV, DataFrame and string
        input, the cross product of left and right rows as
        ``"left<TAB> right"`` strings. ``None`` (after printing
        ``TYPE ERROR``) for an unknown ``type``.
    """
    if type == 'file':
        # Two CSV files: both paths must be CSVs, not just the first one.
        if (isinstance(content, (list, tuple)) and len(content) == 2
                and all(str(path).endswith('.csv') for path in content)):
            df1 = pd.read_csv(content[0])
            df2 = pd.read_csv(content[1])
            return combine_dataframes(df1, df2)

        # Anything else must be a single path (TypeError otherwise).
        input_path = os.fspath(content)

        if input_path.endswith('.txt'):
            # Keep the first two tab-separated fields of every line. The rows
            # are built directly, so no temporary .jsonl file has to be
            # written, re-read and removed afterwards.
            with open(input_path) as lines:
                return [line.split('\t')[:2] for line in lines]

        with jsonlines.open(input_path) as reader:
            return list(reader)

    elif type == 'dataframe':
        # content holds the left and the right DataFrame.
        return combine_dataframes(content[0], content[1])
    elif type == 'string':
        # content holds two strings with one entry per line.
        rows_l = content[0].split('\n')
        rows_r = content[1].split('\n')
        return combine_string(rows_l, rows_r)
    else:
        print('TYPE ERROR')
        return None

In [None]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, default='Structured/Beer')
    parser.add_argument("--input_path", type=str, default='input/candidates.jsonl')
    parser.add_argument("--output_path", type=str, default='output/match_candidates.jsonl')
    parser.add_argument("--lm", type=str, default='distilbert')
    parser.add_argument("--use_gpu", dest="use_gpu", action="store_true")

    parser.add_argument("--checkpoint_path", type=str, default='checkpoints/')
    # --dk, --summarize and --max_len are still accepted on the command line,
    # but predict() in this notebook has no parameters to receive them.
    parser.add_argument("--dk", type=str, default=None)
    parser.add_argument("--summarize", dest="summarize", action="store_true")
    parser.add_argument("--max_len", type=int, default=256)
    # parse_known_args instead of parse_args: inside a notebook the kernel is
    # started with its own arguments (e.g. ``-f <connection file>``), which
    # parse_args would reject as unrecognized.
    hp, _ = parser.parse_known_args()

    # load the models
    set_seed(123)

    model = load_model(hp.task, hp.checkpoint_path,
                       hp.lm, hp.use_gpu)

    # TODO: the interactive (gradio) front end is meant to supply the content
    # and its type; until that is wired in, read the candidates from --input_path.
    input_rows = get_input(hp.input_path, 'file')

    # Only arguments that predict() accepts are passed; the threshold is left
    # at its default.
    predict(input_rows, hp.output_path, model,
            lm=hp.lm)

In [17]:
# path = 'candidate.txt'
# content = ''
# with open(path, 'r') as file:
#     for line in file:
#         content += line
#         # print(line)
        
# content1 = content
# content2 = content

# len(get_input([content1, content2], 'string')) # 8*8


64

In [9]:
content.split('\n')[0].split('\t')[:2]

['COL title VAL secure transaction processing in firm real-time database systems binto george , jayant r. haritsa sigmod conference COL authors VAL  COL venue VAL  COL year VAL 1997.0 ',
 'COL title VAL secure buffering in firm real-time database systems 2000 COL authors VAL binto george , jayant r. haritsa COL venue VAL the vldb journal -- the international journal on very large data bases COL year VAL  ']

In [17]:
# content = 'candidate.jsonl'
# rows = get_input(content, 'file')

# content = 'candidate.txt'
# rows_1 = get_input(content, 'file')
# rows_1[0]

# content = ['can_a.csv', 'can_b.csv']
# rows_1 = get_input(content, 'file')
# rows_1[0]

In [23]:
# import pandas as pd

# # 假设 ent1 和 ent2 是 DataFrame 数据
# # 创建示例 DataFrame
# ent1 = pd.DataFrame({'attr1': [1, 2, 3], 'attr2': ['A', 'B', 'C']})
# ent2 = pd.DataFrame({'attr1': [4, 5, 6], 'attr2': ['D', 'E', 'F']})


In [75]:
# def get_dataframe(ent1, ent2):
#     rows_1 = []
#     rows_2 = []
#     rows = []

#     content = ''
#     for _, row in ent1.iterrows():
#         # 读取属性
        
#         for attr,val in zip(row.index, row.values):
#             print(attr,val)
#             content += 'COL %s VAL %s ' % (str(attr), str(val))
        
#         content += '\t'
#         rows_1.append(content)
#         content = ''

#     for _, row in ent2.iterrows():
#         # 读取属性
        
#         for attr,val in zip(row.index, row.values):
#             content += 'COL %s VAL %s ' % (str(attr), str(val))
        
#         content += '\t'
#         rows_2.append(content)
#         content = ''
        
#     for row_l in rows_1:
#         left = row_l.split('\t')[0]
#         for row_r in rows_2:
#             right = row_r.split('\t')[0]
#             rows.append(left + '\t ' + right)

#     return rows


# get_dataframe(ent1, ent2)

In [None]:

# content = ''
# for ent in [ent1, ent2]:
#     if isinstance(ent, pd.DataFrame):
#         for index, row in ent.iterrows():
#             for attr in ent.columns:
#                 content += 'COL %s VAL %s ' % (attr, row[attr])
#             content += '\t'
#     else:
#         if isinstance(ent, str):
#             content += ent
#         else:
#             for attr in ent.keys():
#                 content += 'COL %s VAL %s ' % (attr, ent[attr])
#         content += '\t'

# content += '0'