In [1]:
import os
import os.path as osp
import numpy as np
import json
import torch
import re
from tqdm import tqdm
from icecream import ic

In [2]:
SEED = 42
np.random.seed(SEED)  # seed NumPy's legacy global RNG so any NumPy sampling is reproducible

In [3]:
# Locations of the raw training data.
train_path = '../data/train'
train_candidates_path = osp.join(train_path, 'candidates')  # holds one sub-directory per query ridx
train_query_path = osp.join(train_path, 'query.json')
train_label_path = osp.join(train_path, 'label_top30_dict.json')
# `inputs` duplicates train_candidates_path. NOTE: the same name is rebound further
# down (tokenizer output, then a list), so prefer train_candidates_path.
inputs = train_candidates_path

In [4]:
# Sub-directory names of the candidate folder, parsed as integer ridx.
# Use the explicit path variable: `inputs` is rebound later in the notebook
# (tokenizer output, then a list), so listing `inputs` breaks on an out-of-order re-run.
train_candidates = [int(name) for name in os.listdir(train_candidates_path)]
num_candidates = len(train_candidates)
num_candidates

237

In [5]:
# Relevance labels: {query ridx (str): {candidate ridx (str): relevance level (int)}}.
# Read through a context manager so the handle is closed; UTF-8 explicitly because the
# dataset is Chinese text and the platform default encoding may differ.
with open(train_label_path, encoding='utf-8') as f:
    train_label = json.load(f)

In [12]:
import typing_extensions
from transformers import AutoModel, AutoTokenizer

# The tokenizer is loaded from the Chinese RoBERTa-wwm-ext checkpoint while the encoder
# weights come from Lawformer (presumably Lawformer shares that vocabulary — confirm).
tokenizer = AutoTokenizer.from_pretrained("../models/chinese-roberta-wwm-ext")
model = AutoModel.from_pretrained('../models/Lawformer')
inputs = tokenizer("任某提起诉讼，请求判令解除婚姻关系并对夫妻共同财产进行分割。" , return_tensors="pt")

# Smoke-test forward pass on one sentence: inference only, so don't build an autograd graph.
with torch.no_grad():
    outputs = model(**inputs)

Some weights of the model checkpoint at ../models/Lawformer were not used when initializing LongformerModel: ['lm_head.dense.bias', 'lm_head.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing LongformerModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerModel were not initialized from the model checkpoint at ../models/Lawformer and are newly initialized: ['longformer.pooler.dense.weight', 'longformer.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to

In [7]:
# Inspect the tokenized example: input_ids, token_type_ids and attention_mask tensors.
inputs

{'input_ids': tensor([[ 101,  818, 3378, 2990, 6629, 6401, 6390, 8024, 6435, 3724, 1161,  808,
         6237, 7370, 2042, 2012, 1068, 5143, 2400, 2190, 1923, 1988, 1066, 1398,
         6568,  772, 6822, 6121, 1146, 1200,  511,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]])}

In [8]:
# Shape of the pooled sentence embedding for the single example (recorded run: [1, 768]).
outputs['pooler_output'].shape

torch.Size([1, 768])

In [9]:
# Tokenize every training query as "<crime 1>[SEP]<crime 2>...[SEP]<query text>",
# order the queries by ridx, and save them together with an ridx -> position lookup.
# query.json holds one JSON object per line with 'ridx', 'crime' (list) and 'q' fields.
train_query = []
with open(train_query_path, encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:  # tolerate blank lines (e.g. an extra newline at EOF)
            continue
        record = json.loads(line)
        input_str = "[SEP]".join(record['crime']) + "[SEP]" + record['q']
        train_query.append({
            "ridx": int(record['ridx']),
            "inputs": tokenizer(input_str, return_tensors="pt"),
        })
train_query.sort(key=lambda x: x['ridx'])
train_query_ridx = [x['ridx'] for x in train_query]
ridx_to_idx = {ridx: idx for idx, ridx in enumerate(train_query_ridx)}
torch.save(train_query, osp.join(train_path, 'query.pt'))
torch.save(ridx_to_idx, osp.join(train_path, 'query_ridx_to_idx.pt'))


In [10]:
# Crime names are pulled out of judgment text with the lazy pattern "已构成(.*?)罪",
# i.e. the shortest text between "已构成" ("constitutes") and the next "罪" ("crime").
crime_pattern = re.compile(r'已构成(.*?)罪')
example_string = "，其行为已构成开设赌场罪，属情节严重。"
crime_match = next(crime_pattern.finditer(example_string))  # first match == search()
crime_match[1]

'开设赌场'

In [23]:
# Tokenize every candidate document found under any query directory as
# "<extracted crime name>[SEP]<ajjbqk text>". The crime name is the first "已构成...罪"
# match in the concatenated document values ('' if there is none).
# NOTE: the same candidate ridx can appear under several query directories, so
# candidates_zoo may contain duplicates (a recorded run had 23718 entries vs 19959
# unique ridx); candidates_ridx_to_idx keeps the last position of each ridx.
candidates_zoo = []
# Sorted listings make the order of duplicate entries (and thus the saved file) deterministic.
for query_dir in sorted(os.listdir(train_candidates_path)):
    candidates_path = osp.join(train_candidates_path, query_dir)
    for candidate in sorted(os.listdir(candidates_path)):
        # Candidate files are named "<ridx>.json"; name kept distinct from the outer loop variable.
        ridx = int(osp.splitext(candidate)[0])
        with open(osp.join(candidates_path, candidate), encoding='utf-8') as f:
            candidate_dict = json.load(f)
        found = crime_pattern.search(''.join(candidate_dict.values()))
        crime_name = '' if found is None else found.group(1) + '罪'
        candidate_text = '[SEP]'.join([crime_name, candidate_dict['ajjbqk']])
        tokenized_candidate = tokenizer(candidate_text, return_tensors="pt")
        candidates_zoo.append({'ridx': ridx, 'inputs': tokenized_candidate})
candidates_zoo.sort(key=lambda x: x['ridx'])
candidates_ridx = [x['ridx'] for x in candidates_zoo]
candidates_ridx_to_idx = {ridx: idx for idx, ridx in enumerate(candidates_ridx)}
torch.save(candidates_zoo, osp.join(train_path, 'candidates.pt'))
torch.save(candidates_ridx_to_idx, osp.join(
    train_path, 'candidates_ridx_to_idx.pt'))

In [24]:
# Merge queries and candidates into one list ordered by ridx (stable sort), plus a
# ridx -> position lookup in which a repeated ridx maps to its last position.
all_inputs = sorted([*train_query, *candidates_zoo], key=lambda item: item['ridx'])
all_inputs_ridx = list(map(lambda item: item['ridx'], all_inputs))
all_inputs_ridx_to_idx = dict(zip(all_inputs_ridx, range(len(all_inputs_ridx))))

In [30]:
# Count how often each ridx occurs across queries + candidates (unique values come back sorted).
unique_ridx, ridx_count = np.unique(all_inputs_ridx, return_counts=True)

In [37]:
# ridx values occurring more than once in all_inputs (presumably candidates shared by
# several queries, or a candidate whose ridx collides with a query's — verify).
is_duplicate = ridx_count > 1
dup_ridx = unique_ridx[is_duplicate]
print(dup_ridx)

[    6     8    13 ... 99892 99918 99995]


In [None]:
# Peek at two neighbouring entries of the ridx-sorted list (e.g. to inspect a duplicated ridx).
all_inputs[11:13]

In [40]:
# Total candidate entries vs. distinct candidate ridx — the gap is the number of duplicates.
for stat in (len(candidates_ridx), np.unique(candidates_ridx).shape):
    print(stat)

23718
(19959,)


In [41]:
# Inspect the label dict: query ridx -> {candidate ridx: relevance level}.
train_label

{'5156': {'38633': 2,
  '38632': 2,
  '32518': 3,
  '36655': 3,
  '501': 3,
  '32377': 3,
  '4348': 3,
  '17848': 2,
  '27033': 3,
  '24364': 2,
  '28530': 3,
  '21312': 3,
  '31607': 3,
  '7859': 3,
  '34099': 3,
  '11977': 2,
  '1970': 3,
  '11940': 3,
  '42565': 3,
  '12976': 2,
  '28331': 2,
  '33175': 2,
  '18097': 3,
  '39991': 3,
  '38445': 2,
  '24091': 2,
  '14776': 3,
  '39608': 3,
  '20875': 3,
  '28626': 2},
 '4891': {'24048': 1,
  '412': 2,
  '30682': 0,
  '30491': 2,
  '2091': 3,
  '8281': 0,
  '20587': 3,
  '483': 3,
  '8883': 2,
  '43366': 3,
  '37626': 2,
  '43270': 2,
  '33780': 2,
  '39700': 2,
  '24091': 2,
  '7786': 2,
  '4069': 3,
  '34210': 3,
  '206': 3,
  '31106': 3,
  '19303': 2,
  '38091': 1,
  '28626': 2,
  '1217': 2,
  '20875': 2,
  '8573': 2,
  '3990': 3,
  '4697': 3,
  '40810': 2,
  '34186': 2},
 '5187': {'43487': 0,
  '22069': 1,
  '41975': 0,
  '14624': 0,
  '26190': 1,
  '13008': 0,
  '16336': 1,
  '23097': 1,
  '33306': 1,
  '42909': 0,
  '33100': 1,


In [11]:
# Build one star-shaped graph per training query. The query is appended to `inputs`
# first, followed by its candidates, and each edge links [query position, candidate
# position] in the flat `inputs` list. Graph id k corresponds to query_ridxs[k].
#   inputs, node_graph_ids                          – one entry per node (query or candidate)
#   edges, labels, candidate_ridxs, edge_graph_ids  – one entry per candidate
#   labels – relevance level from label_top30_dict.json, 0 when the candidate is unlabeled
edges, inputs, labels, query_ridxs, node_graph_ids, edge_graph_ids, candidate_ridxs = [], [], [], [], [], [], []
with open(train_label_path, encoding='utf-8') as f:
    train_label = json.load(f)
with open(train_query_path, encoding='utf-8') as f:
    train_query_lines = f.readlines()
for train_query_line in tqdm(train_query_lines):
    train_query_line = train_query_line.strip()
    if not train_query_line:  # tolerate blank lines
        continue
    train_query_dict = json.loads(train_query_line)
    input_str = "[SEP]".join(train_query_dict['crime']) + \
        "[SEP]"+train_query_dict['q']
    tokenized_inputs = tokenizer(input_str, return_tensors="pt")
    query_idx = len(inputs)          # position of this query node in `inputs`
    graph_id = len(query_ridxs)      # index this query will get in query_ridxs
    inputs.append(tokenized_inputs)
    node_graph_ids.append(graph_id)
    query_ridxs.append(train_query_dict['ridx'])
    query_ridx = str(train_query_dict['ridx'])  # label and directory keys are strings
    candidates_path = osp.join(train_candidates_path, query_ridx)
    # Sorted so node/edge order within a graph does not depend on filesystem order.
    for candidate in sorted(os.listdir(candidates_path)):
        candidate_ridx = osp.splitext(candidate)[0]  # file names are "<ridx>.json"
        with open(osp.join(candidates_path, candidate), encoding='utf-8') as f:
            candidate_dict = json.load(f)
        found = crime_pattern.search(''.join(candidate_dict.values()))
        crime_name = '' if found is None else found.group(1) + '罪'
        candidate_text = '[SEP]'.join([crime_name, candidate_dict['ajjbqk']])
        tokenized_candidate = tokenizer(candidate_text, return_tensors="pt")
        candidate_idx = len(inputs)
        inputs.append(tokenized_candidate)
        node_graph_ids.append(graph_id)
        edge_graph_ids.append(graph_id)
        edges.append([query_idx, candidate_idx])
        candidate_ridxs.append(int(candidate_ridx))
        # Still raises KeyError if the query itself has no entry in the label file.
        labels.append(train_label[query_ridx].get(candidate_ridx, 0))

 20%|█▉        | 39/197 [00:40<02:42,  1.03s/it]


KeyboardInterrupt: 

In [None]:
# Sanity check: per-candidate lists (edges, labels, candidate_ridxs, edge_graph_ids) should
# share one length, and per-node lists (inputs, node_graph_ids) should equal that length
# plus the number of queries (recorded run: 19718 + 197 = 19915).
ic(len(edges), len(inputs), len(labels), len(query_ridxs), len(candidate_ridxs), len(node_graph_ids), len(edge_graph_ids))

ic| len(edges): 19718
    len(inputs): 19915
    len(labels): 19718
    len(query_ridxs): 197
    len(candidate_ridxs): 19718
    len(node_graph_ids): 19915
    len(edge_graph_ids): 19718


(19718, 19915, 19718, 197, 19718, 19915, 19718)

In [None]:
# Persist the full training graph data, one .pt file per list (same order as before).
for filename, obj in {
    'inputs.pt': inputs,
    'edges.pt': edges,
    'labels.pt': labels,
    'query_ridxs.pt': query_ridxs,
    'candidate_ridxs.pt': candidate_ridxs,
    'node_graph_ids.pt': node_graph_ids,
    'edge_graph_ids.pt': edge_graph_ids,
}.items():
    torch.save(obj, osp.join(train_path, filename))

In [None]:
# Hold out the first `num_val_queries` training queries (graph ids 0..num_val_queries-1)
# as the validation split. Prefix slicing is valid only because graph ids were appended
# in non-decreasing order, so every node/edge of a held-out graph sits at the front of
# its list; the assertions guard that assumption. Edge endpoints are positions in
# `inputs`, and since val_inputs is a prefix of `inputs` they stay valid unchanged.
val_path = '../data/val/processed'
num_val_queries = 17
assert np.all(np.diff(node_graph_ids) >= 0), 'node_graph_ids must be non-decreasing'
assert np.all(np.diff(edge_graph_ids) >= 0), 'edge_graph_ids must be non-decreasing'
val_query_ridxs = query_ridxs[:num_val_queries]
edge_split_idx = int((np.array(edge_graph_ids) < num_val_queries).sum())
node_split_idx = int((np.array(node_graph_ids) < num_val_queries).sum())
val_inputs = inputs[:node_split_idx]
val_edges = edges[:edge_split_idx]
val_labels = labels[:edge_split_idx]
val_candidate_ridxs = candidate_ridxs[:edge_split_idx]
val_node_graph_ids = node_graph_ids[:node_split_idx]
val_edge_graph_ids = edge_graph_ids[:edge_split_idx]
os.makedirs(val_path, exist_ok=True)  # torch.save does not create missing directories
for filename, obj in {
    'inputs.pt': val_inputs,
    'edges.pt': val_edges,
    'labels.pt': val_labels,
    'query_ridxs.pt': val_query_ridxs,
    'candidate_ridxs.pt': val_candidate_ridxs,
    'node_graph_ids.pt': val_node_graph_ids,
    'edge_graph_ids.pt': val_edge_graph_ids,
}.items():
    torch.save(obj, osp.join(val_path, filename))

In [None]:
# The remaining queries form the train_dev split (uses node_split_idx / edge_split_idx /
# val_query_ridxs from the validation cell). `edges` store positions in the full `inputs`
# list and graph ids count from 0 over all queries, so both are re-based here. Without
# that, edge endpoints would point past the sliced train_dev_inputs, and graph ids would
# start at num_val_queries instead of lining up with train_dev_query_ridxs.
train_dev_path = '../data/train_dev/processed'
num_val_queries = len(val_query_ridxs)
node_offset = int(node_split_idx)
train_dev_query_ridxs = query_ridxs[num_val_queries:]
train_dev_inputs = inputs[node_offset:]
train_dev_edges = [[src - node_offset, dst - node_offset]
                   for src, dst in edges[edge_split_idx:]]
train_dev_labels = labels[edge_split_idx:]
train_dev_candidate_ridxs = candidate_ridxs[edge_split_idx:]
train_dev_node_graph_ids = [g - num_val_queries for g in node_graph_ids[node_offset:]]
train_dev_edge_graph_ids = [g - num_val_queries for g in edge_graph_ids[edge_split_idx:]]
os.makedirs(train_dev_path, exist_ok=True)  # torch.save does not create missing directories
torch.save(train_dev_inputs, osp.join(train_dev_path, 'inputs.pt'))
torch.save(train_dev_edges, osp.join(train_dev_path, 'edges.pt'))
torch.save(train_dev_labels, osp.join(train_dev_path, 'labels.pt'))
torch.save(train_dev_query_ridxs, osp.join(train_dev_path, 'query_ridxs.pt'))
torch.save(train_dev_candidate_ridxs, osp.join(train_dev_path, 'candidate_ridxs.pt'))
torch.save(train_dev_node_graph_ids, osp.join(train_dev_path, 'node_graph_ids.pt'))
torch.save(train_dev_edge_graph_ids, osp.join(train_dev_path, 'edge_graph_ids.pt'))

In [None]:
# Locations of the raw test data (same layout as the training data).
test_path = '../data/test'
test_query_path, test_candidates_path = (
    osp.join(test_path, name) for name in ('query.json', 'candidates')
)

In [81]:
# Build one star-shaped graph per test query, mirroring the training cell but without
# labels. All indices refer to the *test* lists, not to the training `inputs` /
# `node_graph_ids`:
#   test_edges        – [query position, candidate position] within test_inputs
#   graph id k        – corresponds to test_query_ridxs[k]
test_edges, test_inputs, test_query_ridxs, test_node_graph_ids, test_edge_graph_ids, test_candidate_ridxs = [
], [], [], [], [], []
with open(test_query_path, encoding='utf-8') as f:
    test_query_lines = f.readlines()
for test_query_line in tqdm(test_query_lines):
    test_query_line = test_query_line.strip()
    if not test_query_line:  # tolerate blank lines
        continue
    test_query_dict = json.loads(test_query_line)
    input_str = "[SEP]".join(test_query_dict['crime']) + \
        "[SEP]"+test_query_dict['q']
    tokenized_inputs = tokenizer(input_str, return_tensors="pt")
    query_idx = len(test_inputs)       # position of this query node in test_inputs
    graph_id = len(test_query_ridxs)   # index this query will get in test_query_ridxs
    test_inputs.append(tokenized_inputs)
    test_node_graph_ids.append(graph_id)
    test_query_ridxs.append(test_query_dict['ridx'])
    query_ridx = str(test_query_dict['ridx'])
    candidates_path = osp.join(test_candidates_path, query_ridx)
    # Sorted so node/edge order within a graph does not depend on filesystem order.
    for candidate in sorted(os.listdir(candidates_path)):
        candidate_ridx = osp.splitext(candidate)[0]  # file names are "<ridx>.json"
        with open(osp.join(candidates_path, candidate), encoding='utf-8') as f:
            candidate_dict = json.load(f)
        found = crime_pattern.search(''.join(candidate_dict.values()))
        crime_name = '' if found is None else found.group(1) + '罪'
        candidate_text = '[SEP]'.join([crime_name, candidate_dict['ajjbqk']])
        tokenized_candidate = tokenizer(candidate_text, return_tensors="pt")
        candidate_idx = len(test_inputs)
        test_inputs.append(tokenized_candidate)
        test_node_graph_ids.append(graph_id)
        test_edge_graph_ids.append(graph_id)
        test_edges.append([query_idx, candidate_idx])
        test_candidate_ridxs.append(int(candidate_ridx))
test_processed_path = '../data/test/processed'
os.makedirs(test_processed_path, exist_ok=True)  # torch.save does not create missing directories
torch.save(test_inputs, osp.join(test_processed_path, 'inputs.pt'))
torch.save(test_edges, osp.join(test_processed_path, 'edges.pt'))
torch.save(test_query_ridxs, osp.join(test_processed_path, 'query_ridxs.pt'))
torch.save(test_candidate_ridxs, osp.join(test_processed_path, 'candidate_ridxs.pt'))
torch.save(test_node_graph_ids, osp.join(test_processed_path, 'node_graph_ids.pt'))
torch.save(test_edge_graph_ids, osp.join(test_processed_path, 'edge_graph_ids.pt'))

100%|██████████| 40/40 [00:09<00:00,  4.19it/s]
