In [1]:
import os
import os.path as osp
import numpy as np
import json
import torch
import re
from tqdm import tqdm
from icecream import ic
from transformers import AutoTokenizer
# Tokenizer loaded from a local checkpoint directory; the path is relative to
# the notebook's working directory.
tokenizer = AutoTokenizer.from_pretrained("../models/chinese-roberta-wwm-ext")
# Captures (non-greedily) the text between "已构成" and the next "罪";
# preprocess_data() uses it to pull a crime name out of each candidate's text.
crime_pattern = re.compile(r'已构成(.*?)罪')
from summary import unique_sentences, split_sentence, score_sentences, get_text_after_last_startword

In [2]:
def stat_list(l):
    """Summarise a sequence of numbers.

    Returns a dict with the keys ``'mean'``, ``'std'``, ``'max'``, ``'min'``
    (computed with the NumPy functions of the same name) and ``'len'``
    (the built-in ``len``), in that order.
    """
    reducers = (
        ('mean', np.mean),
        ('std', np.std),
        ('max', np.max),
        ('min', np.min),
        ('len', len),
    )
    return {label: reducer(l) for label, reducer in reducers}

In [3]:
# Sentence counts per query / per candidate, appended to by preprocess_data().
# They are module-level, so they keep growing across successive calls
# (e.g. train, then test) until this cell is re-run.
query_sent_len = []
candidate_sent_len = []


def _read_json(path):
    """Load one JSON document from ``path`` (read as UTF-8) and close the file."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def preprocess_data(data_path, has_label=False):
    """Convert one dataset directory into flat, graph-indexed model inputs.

    Files read under ``data_path``:

    * ``query.json`` - one JSON object per line; the keys ``'crime'``,
      ``'q'`` and ``'ridx'`` are used. Blank lines are ignored.
    * ``candidates/<query ridx>/<candidate ridx>.json`` - one JSON object per
      candidate; all values are concatenated to search for a crime name and
      the ``'ajjbqk'`` field supplies the body text.
    * ``label_top30_dict.json`` (only when ``has_label``) - maps
      query ridx -> {candidate ridx -> label}.

    Every query starts a new graph. The query is one node, each of its
    candidates is another node, and each candidate contributes one edge
    ``[query_node, candidate_node]``. Node indices are global across graphs.
    Candidates are visited in sorted filename order so the output is
    reproducible (``os.listdir`` itself returns entries in arbitrary order).

    For each graph the sentences of all its texts are scored together with
    ``score_sentences``; each text is then rebuilt with its own sentences
    ordered by descending score and tokenized.
    NOTE(review): the tokenizer is called without truncation/max_length, so
    long texts produce long sequences - confirm the consumer truncates.

    Side effect: appends sentence counts to the module-level lists
    ``query_sent_len`` and ``candidate_sent_len``.

    Args:
        data_path: dataset root directory.
        has_label: when true, also load labels and return them.

    Returns:
        ``(inputs, edges, query_ridxs, node_graph_ids, edge_graph_ids,
        candidate_ridxs)`` and, when ``has_label`` is true, ``labels`` as a
        seventh element. ``inputs`` and ``node_graph_ids`` have one entry per
        node; ``edges``, ``edge_graph_ids``, ``candidate_ridxs`` and
        ``labels`` have one entry per candidate. A candidate missing from a
        query's label dict gets label 0.
    """
    query_path = osp.join(data_path, 'query.json')
    all_candidates_path = osp.join(data_path, 'candidates')
    label = None
    if has_label:
        label = _read_json(osp.join(data_path, 'label_top30_dict.json'))
    labels = []
    edges, inputs, query_ridxs = [], [], []
    node_graph_ids, edge_graph_ids, candidate_ridxs = [], [], []

    with open(query_path, encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        query_lines = [line for line in f if line.strip()]

    node_idx = 0
    for graph_idx, query_line in enumerate(tqdm(query_lines)):
        # sentences_zoo holds every sentence of this graph; sentences_ptr[i]
        # and sentences_ptr[i + 1] delimit the sentences of the i-th node.
        sentences_zoo = []
        sentences_ptr = [0]

        query_dict = json.loads(query_line.strip())
        query_text = "。".join(query_dict['crime']) + \
            "。" + get_text_after_last_startword(query_dict['q'])
        unique_query_sentences = unique_sentences(
            list(split_sentence(query_text)))
        sentences_zoo += unique_query_sentences
        sentences_ptr.append(len(sentences_zoo))
        query_sent_len.append(len(unique_query_sentences))

        node_graph_ids.append(graph_idx)
        query_ridxs.append(query_dict['ridx'])
        query_ridx = str(query_dict['ridx'])
        query_idx = node_idx
        node_idx += 1

        candidates_path = osp.join(all_candidates_path, query_ridx)
        for candidate in sorted(os.listdir(candidates_path)):
            # Only "<ridx>.json" files are candidates; ignore anything else
            # (hidden files, sub-directories) rather than mis-slicing its name.
            if not candidate.endswith('.json'):
                continue
            candidate_ridx = candidate[:-len('.json')]
            candidate_dict = _read_json(osp.join(candidates_path, candidate))

            # The crime name is searched for across all fields of the candidate.
            all_text = ''.join(candidate_dict.values())
            crime_match = crime_pattern.search(all_text)
            crime_name = crime_match.group(1) + '罪' if crime_match else ''

            candidate_text = '。'.join(
                [crime_name, get_text_after_last_startword(candidate_dict['ajjbqk'])])
            unique_candidate_sentences = unique_sentences(
                list(split_sentence(candidate_text)))
            sentences_zoo += unique_candidate_sentences
            sentences_ptr.append(len(sentences_zoo))
            candidate_sent_len.append(len(unique_candidate_sentences))

            candidate_idx = node_idx
            node_idx += 1
            node_graph_ids.append(graph_idx)
            edge_graph_ids.append(graph_idx)
            edges.append([query_idx, candidate_idx])
            candidate_ridxs.append(int(candidate_ridx))
            if has_label:
                labels.append(label[query_ridx].get(candidate_ridx, 0))

        # Score all sentences of the graph jointly, then rebuild each node's
        # text with its own sentences in descending score order. Nodes are
        # emitted in the same order their node indices were assigned.
        sentence_scores = score_sentences(sentences_zoo)
        for l, r in zip(sentences_ptr[:-1], sentences_ptr[1:]):
            scored = zip(sentence_scores[l:r], sentences_zoo[l:r])
            ranked = [sentence for _, sentence in sorted(
                scored, key=lambda pair: pair[0], reverse=True)]
            inputs.append(tokenizer(''.join(ranked), return_tensors="pt"))

    if has_label:
        return inputs, edges, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs, labels
    return inputs, edges, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs


# Build the full labelled training set; has_label=True makes preprocess_data
# return the extra `labels` list as the seventh element.
train_path = '../data/summary/train'
train_inputs, train_edges, train_query_ridxs, train_node_graph_ids, train_edge_graph_ids, train_candidate_ridxs, train_labels = preprocess_data(
    train_path, has_label=True)
# Sanity check: the per-node lists (inputs, node_graph_ids) should share one
# length and the per-edge lists (edges, edge_graph_ids, candidate_ridxs,
# labels) another; query_ridxs has one entry per query.
ic(len(train_inputs), len(train_edges), len(train_query_ridxs), len(train_node_graph_ids),
   len(train_edge_graph_ids), len(train_candidate_ridxs), len(train_labels))


  0%|          | 0/197 [00:00<?, ?it/s]Building prefix dict from the default dictionary ...
Loading model from cache /tmp/jieba.cache
Loading model cost 0.578 seconds.
Prefix dict has been built successfully.
100%|██████████| 197/197 [07:33<00:00,  2.30s/it]
ic| len(train_inputs): 19915
    len(train_edges): 19718
    len(train_query_ridxs): 197
    len(train_node_graph_ids): 19915
    len(train_edge_graph_ids): 19718
    len(train_candidate_ridxs): 19718
    len(train_labels): 19718


(19915, 19718, 197, 19915, 19718, 19718, 19718)

In [4]:
# Summarise the per-text sentence counts collected by preprocess_data().
# NOTE(review): the recorded output below reports 10 queries / 1003 candidates,
# which does not match the 197-query run recorded above - it looks stale from
# an earlier partial run; re-run this cell to confirm.
ic(stat_list(query_sent_len), stat_list(candidate_sent_len))

ic| stat_list(query_sent_len): {'len': 10, 'max': 14, 'mean': 6.9, 'min': 4, 'std': 3.014962686336267}
    stat_list(candidate_sent_len): {'len': 1003,
                                    'max': 768,
                                    'mean': 84.7308075772682,
                                    'min': 2,
                                    'std': 105.39177879196771}


({'mean': 6.9, 'std': 3.014962686336267, 'max': 14, 'min': 4, 'len': 10},
 {'mean': 84.7308075772682,
  'std': 105.39177879196771,
  'max': 768,
  'min': 2,
  'len': 1003})

In [4]:
def save(processed_path, edges, inputs, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs, labels=None):
    """Write one preprocessed split to ``processed_path`` with ``torch.save``.

    The directory is created if it does not exist yet, so callers do not have
    to call ``os.makedirs`` themselves.

    Files written: ``inputs.pt``, ``edges.pt``, ``query_ridxs.pt``,
    ``candidate_ridxs.pt``, ``node_graph_ids.pt``, ``edge_graph_ids.pt`` and,
    only when ``labels`` is not ``None``, ``labels.pt``.

    Args:
        processed_path: output directory.
        edges, inputs, query_ridxs, node_graph_ids, edge_graph_ids,
        candidate_ridxs: objects to serialise, one file each.
        labels: optional labels; skipped when ``None`` (unlabelled split).
    """
    os.makedirs(processed_path, exist_ok=True)
    torch.save(inputs, osp.join(processed_path, 'inputs.pt'))
    torch.save(edges, osp.join(processed_path, 'edges.pt'))
    torch.save(query_ridxs, osp.join(processed_path, 'query_ridxs.pt'))
    torch.save(candidate_ridxs, osp.join(processed_path, 'candidate_ridxs.pt'))
    torch.save(node_graph_ids, osp.join(processed_path, 'node_graph_ids.pt'))
    torch.save(edge_graph_ids, osp.join(processed_path, 'edge_graph_ids.pt'))
    if labels is not None:
        torch.save(labels, osp.join(processed_path, 'labels.pt'))
# Persist the full training split (labels included).
train_processed_path = '../data/summary/train/processed'
save(
    processed_path=train_processed_path,
    edges=train_edges,
    inputs=train_inputs,
    query_ridxs=train_query_ridxs,
    node_graph_ids=train_node_graph_ids,
    edge_graph_ids=train_edge_graph_ids,
    candidate_ridxs=train_candidate_ridxs,
    labels=train_labels,
)

In [5]:
# Preprocess and persist the unlabelled test split; without labels
# preprocess_data returns six values and save() writes no labels.pt.
test_path = '../data/summary/test'
test_processed_path = '../data/summary/test/processed'
(
    test_inputs,
    test_edges,
    test_query_ridxs,
    test_node_graph_ids,
    test_edge_graph_ids,
    test_candidate_ridxs,
) = preprocess_data(test_path, has_label=False)
save(
    processed_path=test_processed_path,
    edges=test_edges,
    inputs=test_inputs,
    query_ridxs=test_query_ridxs,
    node_graph_ids=test_node_graph_ids,
    edge_graph_ids=test_edge_graph_ids,
    candidate_ridxs=test_candidate_ridxs,
)

100%|██████████| 40/40 [00:30<00:00,  1.31it/s]


In [7]:
# Validation split: the first 17 graphs (queries) of the training data.
val_graph_count = 17

# Graph ids are appended in increasing order by preprocess_data, so counting
# the ids below the cut-off gives the slice boundary for per-node and per-edge
# lists respectively.
node_split_idx = (np.array(train_node_graph_ids) < val_graph_count).sum()
edge_split_idx = (np.array(train_edge_graph_ids) < val_graph_count).sum()

val_query_ridxs = train_query_ridxs[:val_graph_count]

# Per-node lists.
val_inputs = train_inputs[:node_split_idx]
val_node_graph_ids = train_node_graph_ids[:node_split_idx]

# Per-edge lists.
val_edges = train_edges[:edge_split_idx]
val_edge_graph_ids = train_edge_graph_ids[:edge_split_idx]
val_candidate_ridxs = train_candidate_ridxs[:edge_split_idx]
val_labels = train_labels[:edge_split_idx]

val_processed_path = '../data/summary/val/processed'
os.makedirs(val_processed_path, exist_ok=True)
save(
    processed_path=val_processed_path,
    edges=val_edges,
    inputs=val_inputs,
    query_ridxs=val_query_ridxs,
    node_graph_ids=val_node_graph_ids,
    edge_graph_ids=val_edge_graph_ids,
    candidate_ridxs=val_candidate_ridxs,
    labels=val_labels,
)

In [8]:
# Train-dev split: everything after the first 17 graphs (the validation
# graphs), re-based so node indices and graph ids start at 0 again.
# Relies on node_split_idx / edge_split_idx from the validation-split cell;
# the literal 17 must match the cut-off used there.
train_dev_query_ridxs = train_query_ridxs[17:]
train_dev_inputs = train_inputs[node_split_idx:]
# Edge endpoints are global node indices, so shift them by the number of
# validation nodes that were dropped.
train_dev_edges = (
    np.array(train_edges[edge_split_idx:]) - node_split_idx).tolist()
train_dev_labels = train_labels[edge_split_idx:]
train_dev_candidate_ridxs = train_candidate_ridxs[edge_split_idx:]
# Convert back to plain lists so this split is saved in the same form as the
# other splits. The original stored the .tolist() results under the misspelled
# names *_graph_idx and passed the NumPy arrays to save() instead.
train_dev_node_graph_ids = (
    np.array(train_node_graph_ids[node_split_idx:]) - 17).tolist()
train_dev_edge_graph_ids = (
    np.array(train_edge_graph_ids[edge_split_idx:]) - 17).tolist()
# Keep the old misspelled names as aliases in case later cells reference them.
train_dev_node_graph_idx = train_dev_node_graph_ids
train_dev_edge_graph_idx = train_dev_edge_graph_ids
train_dev_processed_path = '../data/summary/train_dev/processed'
os.makedirs(train_dev_processed_path, exist_ok=True)
save(train_dev_processed_path, train_dev_edges, train_dev_inputs,
     train_dev_query_ridxs, train_dev_node_graph_ids, train_dev_edge_graph_ids,
     train_dev_candidate_ridxs, train_dev_labels)