In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: RenzeLou
# Use this script to convert every .tsv file to .json format and generate .source_vocab for each split (train/dev/test).
# We do not lowercase the text because we use bert-base-cased.

import json,csv
import numpy as np
import pandas as pd
import re

In [2]:
# Accumulators filled while reading the .tsv splits: one list of instance
# dicts per split, plus the label strings in order of first appearance.
train_file, dev_file, test_file, labels = [], [], [], []

In [3]:
# Peek at the test split: print the header row and the ten rows after it.
with open("./test.tsv", "r", encoding="utf-8") as tsv_handle:
    preview_rows = csv.reader(tsv_handle, delimiter='\t')
    for row_no, row in enumerate(preview_rows):
        print(list(row))
        if row_no == 10:
            break

['index', 'sentence1', 'sentence2']
['0', 'Maude and Dora had seen the trains rushing across the prairie, with long, rolling puffs of black smoke streaming back from the engine. Their roars and their wild, clear whistles could be heard from far away. Horses ran away when they came in sight.', 'Horses ran away when Maude and Dora came in sight.']
['1', 'Maude and Dora had seen the trains rushing across the prairie, with long, rolling puffs of black smoke streaming back from the engine. Their roars and their wild, clear whistles could be heard from far away. Horses ran away when they came in sight.', 'Horses ran away when the trains came in sight.']
['2', 'Maude and Dora had seen the trains rushing across the prairie, with long, rolling puffs of black smoke streaming back from the engine. Their roars and their wild, clear whistles could be heard from far away. Horses ran away when they came in sight.', 'Horses ran away when the puffs came in sight.']
['3', 'Maude and Dora had seen the tr

In [4]:
# Peek at the training split: print the header row and the ten rows after it.
with open("./train.tsv", "r", encoding="utf-8") as tsv_handle:
    preview_rows = csv.reader(tsv_handle, delimiter='\t')
    for row_no, row in enumerate(preview_rows):
        print(list(row))
        if row_no == 10:
            break

['index', 'sentence1', 'sentence2', 'label']
['0', 'I stuck a pin through a carrot. When I pulled the pin out, it had a hole.', 'The carrot had a hole.', '1']
['1', "John couldn't see the stage with Billy in front of him because he is so short.", 'John is so short.', '1']
['2', 'The police arrested all of the gang members. They were trying to stop the drug trade in the neighborhood.', 'The police were trying to stop the drug trade in the neighborhood.', '1']
['3', "Steve follows Fred's example in everything. He influences him hugely.", 'Steve influences him hugely.', '0']
['4', 'When Tatyana reached the cabin, her mother was sleeping. She was careful not to disturb her, undressing and climbing back into her berth.', 'mother was careful not to disturb her, undressing and climbing back into her berth.', '0']
['5', 'George got free tickets to the play, but he gave them to Eric, because he was particularly eager to see it.', 'George was particularly eager to see it.', '0']
['6', 'John was 

In [5]:
# Read the training split into `train_file`. Each label string is mapped to an
# integer id: its position in the shared `labels` list, which grows in order of
# first appearance.
# NOTE(review): csv.reader is used with its default quoting; if fields can
# contain double quotes, consider quoting=csv.QUOTE_NONE -- confirm against the data.
with open("./train.tsv","r",encoding="utf-8") as train:
    t=csv.reader(train,delimiter='\t')
    for i,line in enumerate(t):
        if i==0:
            print("skip first line")
            continue
        # Explicit length check instead of assert + bare except: asserts are
        # stripped under `python -O`, and a bare except hides unrelated errors.
        if len(line) != 4:
            print("error")
            print(line)
            continue
        lb=line[3]
        if lb not in labels:
            labels.append(lb)
        label = labels.index(lb)
        instance = {"index": int(line[0]), "seq1": line[1], "seq2": line[2], "label": label}
        train_file.append(instance)

skip first line


In [6]:
# Read the dev split into `dev_file`. Labels share the `labels` list with the
# other splits, so a label string seen earlier keeps the same integer id.
# NOTE(review): csv.reader is used with its default quoting; if fields can
# contain double quotes, consider quoting=csv.QUOTE_NONE -- confirm against the data.
with open("./dev.tsv","r",encoding="utf-8") as dev:
    t=csv.reader(dev,delimiter='\t')
    for i,line in enumerate(t):
        if i==0:
            print("skip first line")
            continue
        # Explicit length check instead of assert + bare except: asserts are
        # stripped under `python -O`, and a bare except hides unrelated errors.
        if len(line) != 4:
            print("error")
            print(line)
            continue
        lb=line[3]
        if lb not in labels:
            labels.append(lb)
        label = labels.index(lb)
        instance = {"index": int(line[0]), "seq1": line[1], "seq2": line[2], "label": label}
        dev_file.append(instance)

skip first line


In [7]:
# Read the test split into `test_file`. The test file has no label column
# (three fields per row), so every instance gets label None.
# NOTE(review): csv.reader is used with its default quoting; if fields can
# contain double quotes, consider quoting=csv.QUOTE_NONE -- confirm against the data.
with open("./test.tsv","r",encoding="utf-8") as test:
    t=csv.reader(test,delimiter='\t')
    for i,line in enumerate(t):
        if i==0:
            print("skip first line")
            continue
        # Explicit length check instead of assert + bare except: asserts are
        # stripped under `python -O`, and a bare except hides unrelated errors.
        if len(line) != 3:
            print("error")
            print(line)
            continue
        instance = {"index": int(line[0]), "seq1": line[1], "seq2": line[2], "label": None}
        test_file.append(instance)

skip first line


In [8]:
def simple_filt(seq:str):
    """Collapse every run of whitespace in ``seq`` into a single space.

    Leading and trailing whitespace is removed. ``None`` is passed through
    unchanged.
    """
    if seq is not None:
        # split() with no separator never yields empty strings, so the pieces
        # can be joined directly.
        return " ".join(piece for piece in seq.split())
    return None

In [9]:
def write_to_json(list_files:list,file_names:list,labels:list,data_info_name:str="./data_info.json"):
    ''' Dump each split to a JSON file and write data statistics.

    For every instance, ``seq1`` and ``seq2`` are whitespace-normalised with
    ``simple_filt`` (the instance dicts are updated in place). Each list in
    ``list_files`` is then dumped to the path at the same position in
    ``file_names``. Finally a summary is written to ``data_info_name`` holding
    the index-to-label mapping and, for train/dev/test, the number of
    instances and the longest sequence length in whitespace-separated tokens.

    Args:
        list_files: lists of instance dicts; the first three must be ordered
            train, dev, test.
        file_names: output path for each entry of ``list_files``.
        labels: label values; position ``i`` becomes label index ``i``.
        data_info_name: output path of the statistics file.

    Raises:
        IndexError: if fewer than three splits are given or ``file_names`` is
            shorter than ``list_files`` (raised after earlier files have
            already been written).
    '''
    def seq_len(seq):
        # A missing sequence (None) has length 0. split() without a separator
        # is used so an empty string also counts as 0 tokens; split(" ")
        # would return [""] and report 1.
        return 0 if seq is None else len(seq.split())

    max_lens = []
    for i,list_file in enumerate(list_files):
        max_len = 0
        for item in list_file:
            item['seq1'],item['seq2'] = simple_filt(item["seq1"]),simple_filt(item['seq2'])
            max_len = max(max_len,seq_len(item['seq1']),seq_len(item['seq2']))
        # Explicit encoding so the output does not depend on the platform default.
        with open(file_names[i],"w",encoding="utf-8") as f:
            json.dump(list(list_file),f)

        print("successfully dump %s file" %file_names[i])
        max_lens.append(max_len)

    # json.dump turns the integer keys into strings ("0", "1", ...).
    idx2lb = dict(enumerate(labels))
    print("index to label:",idx2lb)

    data_info = {"labels": idx2lb}
    for pos,split_name in enumerate(("train","dev","test")):
        data_info[split_name] = {
            "instance_num": len(list_files[pos]),
            "instance_max_len": max_lens[pos],
        }

    with open(data_info_name,"w",encoding="utf-8") as f:
        json.dump(data_info,f)
    print("successfully dump %s file" %data_info_name )

In [10]:
# Show the label strings collected from train/dev; a label's position in this list is its integer id.
labels

['1', '0']

In [11]:
# Dump the three splits (order must be train, dev, test) together with the data statistics.
write_to_json(
    list_files=[train_file, dev_file, test_file],
    file_names=["train.json", "dev.json", "test.json"],
    labels=labels,
)

successfully dump train.json file
successfully dump dev.json file
successfully dump test.json file
index to label: {0: '1', 1: '0'}
successfully dump ./data_info.json file
