In [1]:
import os
import os.path as osp
import numpy as np
import json
import torch
import re
from tqdm import tqdm
from icecream import ic
from transformers import AutoTokenizer
# Tokenizer used by preprocess_data for both queries and candidates.
# NOTE(review): the argument looks like a local checkpoint directory relative to
# the notebook's working directory -- confirm it exists before running.
tokenizer = AutoTokenizer.from_pretrained("../models/chinese-roberta-wwm-ext")
# Non-greedy capture of the text between "已构成" and the nearest following "罪";
# preprocess_data uses group(1) + '罪' as the crime name of a candidate document.
crime_pattern = re.compile(r'已构成(.*?)罪')

In [3]:
def preprocess_data(data_path, has_label=False):
    """Tokenize every query and its candidate documents and describe them as graphs.

    Expected layout under ``data_path`` (as read by this function):

    * ``query.json`` -- one JSON object per line with the keys ``crime`` (an
      iterable of strings), ``q`` (query text) and ``ridx`` (query id).
    * ``candidates/<ridx>/<candidate_ridx>.json`` -- one JSON object per
      candidate with at least the key ``ajjbqk`` and optionally ``ajName``.
    * ``label_top30_dict.json`` (only read when ``has_label`` is true) -- a
      mapping ``str(query ridx) -> {str(candidate ridx): label}``.

    Each query forms one graph: the query is a node, every candidate is a
    node, and there is one edge ``[query_idx, candidate_idx]`` per candidate.

    Args:
        data_path: Directory containing the files described above.
        has_label: When true, also load the label file and return labels.

    Returns:
        A tuple ``(inputs, edges, query_ridxs, node_graph_ids, edge_graph_ids,
        candidate_ridxs)`` and, when ``has_label`` is true, ``labels`` as a
        seventh element:

        * ``inputs`` -- tokenizer outputs, one per node (each query followed
          by its candidates).
        * ``edges`` -- ``[query_idx, candidate_idx]`` pairs indexing ``inputs``.
        * ``query_ridxs`` -- the ``ridx`` value of each query.
        * ``node_graph_ids`` -- for each node, the position of its query in
          ``query_ridxs``.
        * ``edge_graph_ids`` -- the same graph id, one per edge.
        * ``candidate_ridxs`` -- candidate ids (file stem converted to int),
          one per edge.
        * ``labels`` -- one per edge; 0 when the candidate is not listed for
          its query in the label file.

    Raises:
        KeyError: If a required key is missing from a query or candidate, or
            (with ``has_label``) a query with candidates is absent from the
            label file.
    """
    query_path = osp.join(data_path, 'query.json')
    all_candidates_path = osp.join(data_path, 'candidates')

    # Always defined so the function body never depends on branch-local names.
    label = {}
    labels = []
    if has_label:
        label_path = osp.join(data_path, 'label_top30_dict.json')
        # Explicit UTF-8 and a context manager: no leaked handle, no dependence
        # on the platform's default encoding.
        with open(label_path, encoding='utf-8') as f:
            label = json.load(f)

    edges, inputs, query_ridxs = [], [], []
    node_graph_ids, edge_graph_ids, candidate_ridxs = [], [], []

    with open(query_path, encoding='utf-8') as f:
        # Skip blank lines (e.g. a trailing newline) that json.loads would reject.
        query_lines = [line.strip() for line in f if line.strip()]

    for query_line in tqdm(query_lines):
        query_dict = json.loads(query_line)
        input_str = "。".join(query_dict['crime']) + "。" + query_dict['q']
        # NOTE(review): no truncation is requested here, so long texts keep
        # their full token length -- confirm the consumer handles that.
        tokenized_inputs = tokenizer(input_str, return_tensors="pt")

        query_idx = len(inputs)
        graph_id = len(query_ridxs)
        inputs.append(tokenized_inputs)
        node_graph_ids.append(graph_id)
        query_ridxs.append(query_dict['ridx'])

        query_ridx = str(query_dict['ridx'])
        candidates_path = osp.join(all_candidates_path, query_ridx)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # node/edge order reproducible across runs and machines.
        for candidate in sorted(os.listdir(candidates_path)):
            candidate_ridx, extension = osp.splitext(candidate)
            if extension != '.json':
                # Ignore stray entries such as .ipynb_checkpoints or .DS_Store.
                continue
            candidate_path = osp.join(candidates_path, candidate)
            with open(candidate_path, encoding='utf-8') as f:
                candidate_dict = json.load(f)

            # Search all textual fields for the crime name; non-string values
            # are skipped because str.join would raise on them.
            all_text = ''.join(
                value for value in candidate_dict.values()
                if isinstance(value, str))
            crime_match = crime_pattern.search(all_text)
            crime_name = '' if crime_match is None else crime_match.group(1) + '罪'

            candidate_text = candidate_dict['ajjbqk']
            if 'ajName' in candidate_dict:
                candidate_text = candidate_dict['ajName'] + '。' + candidate_text
            candidate_text = '。'.join([crime_name, candidate_text])
            tokenized_candidate = tokenizer(
                candidate_text, return_tensors="pt")

            candidate_idx = len(inputs)
            inputs.append(tokenized_candidate)
            node_graph_ids.append(graph_id)
            edge_graph_ids.append(graph_id)
            edges.append([query_idx, candidate_idx])
            candidate_ridxs.append(int(candidate_ridx))
            if has_label:
                # Candidates not listed for this query get label 0.
                labels.append(label[query_ridx].get(candidate_ridx, 0))

    if has_label:
        return inputs, edges, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs, labels
    return inputs, edges, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs


# Build the labelled graph data for the training set.
train_path = '../data/origin/train'
(
    train_inputs,
    train_edges,
    train_query_ridxs,
    train_node_graph_ids,
    train_edge_graph_ids,
    train_candidate_ridxs,
    train_labels,
) = preprocess_data(train_path, has_label=True)


  0%|          | 0/197 [00:00<?, ?it/s]


UnboundLocalError: local variable 'candidate_text' referenced before assignment

In [5]:
def save(processed_path, edges, inputs, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs, labels=None):
    """Write each preprocessing result to its own ``.pt`` file with torch.save.

    Note that ``edges`` comes before ``inputs`` in this signature.

    Args:
        processed_path: Output directory; created (with parents) if missing.
        edges: Stored as ``edges.pt``.
        inputs: Stored as ``inputs.pt``.
        query_ridxs: Stored as ``query_ridxs.pt``.
        node_graph_ids: Stored as ``node_graph_ids.pt``.
        edge_graph_ids: Stored as ``edge_graph_ids.pt``.
        candidate_ridxs: Stored as ``candidate_ridxs.pt``.
        labels: Stored as ``labels.pt``; nothing is written when ``None``.
    """
    # torch.save does not create missing directories, so make sure the target
    # exists before writing the first file.
    os.makedirs(processed_path, exist_ok=True)

    artifacts = {
        'inputs.pt': inputs,
        'edges.pt': edges,
        'query_ridxs.pt': query_ridxs,
        'candidate_ridxs.pt': candidate_ridxs,
        'node_graph_ids.pt': node_graph_ids,
        'edge_graph_ids.pt': edge_graph_ids,
    }
    if labels is not None:
        artifacts['labels.pt'] = labels

    for file_name, payload in artifacts.items():
        torch.save(payload, osp.join(processed_path, file_name))

In [6]:
# Sanity check: the node-level lists should share one length and the
# edge-level lists another.
ic(
    len(train_inputs),
    len(train_edges),
    len(train_query_ridxs),
    len(train_node_graph_ids),
    len(train_edge_graph_ids),
    len(train_candidate_ridxs),
    len(train_labels),
)

ic| len(train_inputs): 19915
    len(train_edges): 19718
    len(train_query_ridxs): 197
    len(train_node_graph_ids): 19915
    len(train_edge_graph_ids): 19718
    len(train_candidate_ridxs): 19718
    len(train_labels): 19718


(19915, 19718, 197, 19915, 19718, 19718, 19718)

In [7]:
# Persist the full training set. Keyword arguments are used because save()
# takes `edges` before `inputs`, the reverse of preprocess_data's return order.
train_processed_path = '../data/origin/train/processed'
save(
    train_processed_path,
    edges=train_edges,
    inputs=train_inputs,
    query_ridxs=train_query_ridxs,
    node_graph_ids=train_node_graph_ids,
    edge_graph_ids=train_edge_graph_ids,
    candidate_ridxs=train_candidate_ridxs,
    labels=train_labels,
)

In [8]:
# Preprocess the unlabelled test set and persist it (no labels.pt is written).
test_path = '../data/origin/test'
(
    test_inputs,
    test_edges,
    test_query_ridxs,
    test_node_graph_ids,
    test_edge_graph_ids,
    test_candidate_ridxs,
) = preprocess_data(test_path, has_label=False)

test_processed_path = '../data/origin/test/processed'
save(
    test_processed_path,
    edges=test_edges,
    inputs=test_inputs,
    query_ridxs=test_query_ridxs,
    node_graph_ids=test_node_graph_ids,
    edge_graph_ids=test_edge_graph_ids,
    candidate_ridxs=test_candidate_ridxs,
)

100%|██████████| 40/40 [00:09<00:00,  4.18it/s]


In [9]:
# Validation split: the first 17 training queries plus all of their candidates.
val_query_ridxs = train_query_ridxs[:17]

# Graph ids are assigned in query order, so every node/edge belonging to the
# validation queries sits at the front of the lists; counting the ids below the
# number of validation queries gives the slice boundaries.
node_split_idx = (np.array(train_node_graph_ids) < len(val_query_ridxs)).sum()
edge_split_idx = (np.array(train_edge_graph_ids) < len(val_query_ridxs)).sum()

# Node-level slices.
val_inputs = train_inputs[:node_split_idx]
val_node_graph_ids = train_node_graph_ids[:node_split_idx]

# Edge-level slices. Edge endpoints need no shift because the split starts at 0.
val_edges = train_edges[:edge_split_idx]
val_edge_graph_ids = train_edge_graph_ids[:edge_split_idx]
val_candidate_ridxs = train_candidate_ridxs[:edge_split_idx]
val_labels = train_labels[:edge_split_idx]

val_processed_path = '../data/origin/val/processed'
save(
    val_processed_path,
    edges=val_edges,
    inputs=val_inputs,
    query_ridxs=val_query_ridxs,
    node_graph_ids=val_node_graph_ids,
    edge_graph_ids=val_edge_graph_ids,
    candidate_ridxs=val_candidate_ridxs,
    labels=val_labels,
)

In [10]:
# Train-dev split: everything after the first 17 queries (those 17 form the
# validation split). node_split_idx / edge_split_idx come from the validation
# split cell.
train_dev_query_ridxs = train_query_ridxs[17:]
train_dev_inputs = train_inputs[node_split_idx:]

# Edge endpoints index into the full input list; shift them so they index into
# train_dev_inputs instead.
train_dev_edges = (
    np.array(train_edges[edge_split_idx:]) - node_split_idx).tolist()
train_dev_labels = train_labels[edge_split_idx:]
train_dev_candidate_ridxs = train_candidate_ridxs[edge_split_idx:]

# Re-base graph ids so the first train-dev query is graph 0. They are converted
# back to plain lists so the saved files have the same type as the other splits
# (previously the NumPy arrays were saved because the .tolist() results were
# bound to differently spelled names).
train_dev_node_graph_ids = (
    np.array(train_node_graph_ids[node_split_idx:]) - 17).tolist()
train_dev_edge_graph_ids = (
    np.array(train_edge_graph_ids[edge_split_idx:]) - 17).tolist()

# Legacy names kept as aliases in case other cells still refer to them.
train_dev_node_graph_idx = train_dev_node_graph_ids
train_dev_edge_graph_idx = train_dev_edge_graph_ids

train_dev_processed_path = '../data/origin/train_dev/processed'
save(train_dev_processed_path, train_dev_edges, train_dev_inputs,
     train_dev_query_ridxs, train_dev_node_graph_ids, train_dev_edge_graph_ids,
     train_dev_candidate_ridxs, train_dev_labels)
