In [1]:
import os
import pandas as pd
import json
import numpy as np

from pprint import pprint
from sklearn.model_selection import train_test_split

In [2]:
def preprocess_wanli (json_file_name, data_split):
    """Turn a JSON Lines NLI file into binary-labelled, header-less TSV files.

    Each record's ``gold`` class is mapped to an integer ``label``:
    ``entailment`` -> 1, ``contradiction`` -> 0, ``neutral`` -> 0.
    Output files are written next to ``json_file_name``:

    * ``data_split == "train"``: the rows are split 90/10 (stratified on
      ``label``, ``random_state=111``) and written as
      ``processed_train.tsv`` / ``train.tsv`` and
      ``processed_valid.tsv`` / ``valid.tsv``.
    * ``data_split == "test"``: all rows are written as
      ``processed_test.tsv`` / ``test.tsv``.

    The ``processed_*`` files keep every column; the plain files keep only
    ``premise``, ``hypothesis`` and ``label``.

    Args:
        json_file_name: Path to the ``.jsonl`` input. Every record must have
            ``gold``, ``premise`` and ``hypothesis`` keys.
        data_split: Either ``"train"`` or ``"test"``.

    Raises:
        ValueError: If ``data_split`` is neither ``"train"`` nor ``"test"``.
        KeyError: If a record's ``gold`` value is not one of the three
            known classes, or a required column is missing.
    """
    # Fail fast: previously an unknown split printed stats and then wrote nothing.
    if data_split not in ("train", "test"):
        raise ValueError(
            f"data_split must be 'train' or 'test', got {data_split!r}"
        )

    save_file_path, _ = os.path.split(json_file_name)

    print (f"save_file_path : {save_file_path}")

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding; skip blank lines, which would
    # otherwise make json.loads raise.
    with open(json_file_name, "r", encoding="utf-8") as jfile:
        all_data = [json.loads(line) for line in jfile if line.strip()]

    label_dict = {'entailment' : 1, 'contradiction': 0, 'neutral': 0}

    data_df = pd.DataFrame(all_data)

    print (data_df.columns)

    # Dict lookup (not .map(dict)) so an unexpected class raises KeyError
    # instead of silently becoming NaN.
    data_df["label"] = data_df["gold"].apply(
        lambda gold_class: int(label_dict[gold_class])
    )

    print (data_df.shape)
    print (data_df["label"].value_counts())
    print()

    save_cols = ["premise", "hypothesis", "label"]

    if data_split == "train":

        X_col = [col for col in data_df.columns if col != "label"]
        y_col = ["label"]

        X = data_df[X_col]
        y = data_df[y_col]

        # Stratify so train and valid keep the same label ratio.
        X_train, X_valid, y_train, y_valid = train_test_split(
            X, y, random_state=111, test_size=0.10, shuffle=True, stratify=y
        )

        print ()
        print ("Train Data :", X_train.shape, y_train.shape)
        print ("Valid Data :", X_valid.shape, y_valid.shape)

        train_df = _write_split_files(
            pd.concat((X_train, y_train), axis=1), save_file_path, data_split, save_cols
        )
        valid_df = _write_split_files(
            pd.concat((X_valid, y_valid), axis=1), save_file_path, "valid", save_cols
        )

        print (f"Train DF Shape : {train_df.shape}")
        print (f"Train Label Count : {train_df['label'].value_counts()}")

        print ()
        print (f"Valid DF Shape : {valid_df.shape}")
        print (f"Valid Label Count : {valid_df['label'].value_counts()}")

    else:  # data_split == "test" (validated above)

        test_df = _write_split_files(data_df, save_file_path, data_split, save_cols)

        print ()
        print (f"Test DF Shape : {test_df.shape}")
        print (f"Test Label Count : {test_df['label'].value_counts()}")


def _write_split_files(split_df, save_dir, split_name, save_cols):
    """Write one split as two header-less TSVs and return the reduced frame.

    ``processed_<split_name>.tsv`` receives every column of ``split_df``;
    ``<split_name>.tsv`` receives only ``save_cols``.
    """
    processed_file_path = os.path.join(save_dir, f"processed_{split_name}.tsv")
    split_df.to_csv(processed_file_path, sep="\t", header=False, index=False)

    final_df = split_df[save_cols]

    final_file_path = os.path.join(save_dir, f"{split_name}.tsv")
    final_df.to_csv(final_file_path, sep="\t", header=False, index=False)

    return final_df

            
# Run the preprocessing once per raw WANLI file: the train file also yields
# the validation split, the test file is converted as-is.
for wanli_jsonl, wanli_split in (
    ("./../../data/train_data/je_con_prop/wanli/train.jsonl", "train"),
    ("./../../data/train_data/je_con_prop/wanli/test.jsonl", "test"),
):
    preprocess_wanli(json_file_name=wanli_jsonl, data_split=wanli_split)


save_file_path : ./../../data/train_data/je_con_prop/wanli
Index(['id', 'premise', 'hypothesis', 'gold', 'genre', 'pairID'], dtype='object')
(102885, 7)
0    64374
1    38511
Name: label, dtype: int64


Train Data : (92596, 6) (92596, 1)
Valid Data : (10289, 6) (10289, 1)
Train DF Shape : (92596, 3)
Train Label Count : 0    57936
1    34660
Name: label, dtype: int64

Valid DF Shape : (10289, 3)
Valid Label Count : 0    6438
1    3851
Name: label, dtype: int64
save_file_path : ./../../data/train_data/je_con_prop/wanli
Index(['id', 'premise', 'hypothesis', 'gold', 'genre', 'pairID'], dtype='object')
(5000, 7)
0    3142
1    1858
Name: label, dtype: int64


Test DF Shape : (5000, 3)
Test Label Count : 0    3142
1    1858
Name: label, dtype: int64
