In [1]:
import glob
import pandas as pd
import numpy as np
import json
import pickle
import string
import copy
from collections import defaultdict
from collections import Counter

import spacy
import networkx as nx
# NOTE(review): absolute, machine-specific path to the spaCy model directory
# (handed to spacy.load further down) — must be adjusted on any other machine.
model_dir = '/Users/talhindi/miniconda3/lib/python3.7/site-packages/en_core_web_sm/en_core_web_sm-2.1.0'

## Reading Data

In [2]:
# Semicolon-separated split table; later cells iterate over its SET column
# and compare each value against "TRAIN".
split_csv_path = '../data/SG2017/train-test-split.csv'
train_test_split = pd.read_csv(split_csv_path, sep=';')

In [3]:
# Essay texts, one list of lines per .txt file. Files are visited in sorted
# filename order, and every handle is closed as soon as it has been read.
essays_txt_prg_list = []
for file in sorted(glob.glob("../data/SG2017/*.txt")):
    with open(file) as essay_file:
        essays_txt_prg_list.append(essay_file.readlines())

# The same essays as single strings (lines joined back together unchanged).
essay_txt_str = [''.join(essay) for essay in essays_txt_prg_list]

# Annotation files, one list of lines per .ann file, also in sorted filename
# order. Later cells zip these with the essay texts purely by position.
essays_ann = []
for file in sorted(glob.glob("../data/SG2017/*.ann")):
    with open(file) as ann_file:
        essays_ann.append(ann_file.readlines())

In [4]:
# For every annotation file keep only the lines whose first character is 'T'
# and turn each into a (label, start, end, text) tuple, ordered by start offset.
essays_segments = []

for ann_lines in essays_ann:
    parsed = []

    for ann_line in ann_lines:
        if ann_line[0] != 'T':
            continue
        # a 'T' line has three tab-separated fields: id, "label start end", text
        _, span_info, span_text = ann_line.rstrip().split('\t')
        kind, begin, finish = span_info.split()
        parsed.append((kind, int(begin), int(finish), span_text))

    essays_segments.append(sorted(parsed, key=lambda seg: seg[1]))

## Labels

In [15]:
def get_labels(essay_spacy, segments):
    '''Assign one BIO-style label to every token of a document.

    Parameters
    ----------
    essay_spacy : sized iterable of tokens; each token must expose ``.idx``
        (start offset), ``.text`` and support ``len()``.
    segments : iterable of ``(arg_type, start, end, text)`` tuples.

    Returns
    -------
    (tokens, labels) : two lists of equal length. ``labels[i]`` is
        'Arg-B' if token i starts exactly at a segment start,
        'Arg-I' if it starts after a segment start and ends at or before
        that segment's end, and 'O' otherwise.

    Raises
    ------
    AssertionError if a token sitting on a segment start is not contained in
    that segment's text.
    '''

    doc_len = len(essay_spacy)

    labels = []
    tokens = []

    # start offset -> text of the first segment beginning there
    # (same "first match" rule as list.index, but with O(1) lookups)
    seg_text_by_start = {}
    for _, start, _, text in segments:
        seg_text_by_start.setdefault(start, text)

    for token in essay_spacy:
        tokens.append(token.text)

        if token.idx in seg_text_by_start:
            labels.append('Arg-B')
            assert token.text in seg_text_by_start[token.idx]
            continue

        token_end = token.idx + len(token)
        # any() stops at the first enclosing segment, so a token covered by
        # overlapping or duplicated segments still receives exactly one label
        inside = any(start < token.idx and token_end <= end
                     for _, start, end, _ in segments)
        labels.append('Arg-I' if inside else 'O')

    assert len(labels) == doc_len
    return tokens, labels

In [8]:
# distinct annotation types that occur in the first essay
{segment[0] for segment in essays_segments[0]}

{'Claim', 'MajorClaim', 'Premise'}

## Spacy

In [12]:
# Load the spaCy pipeline from model_dir and run it over every whole-essay string.
nlp = spacy.load(model_dir)

essay_spacy = [nlp(essay_text) for essay_text in essay_txt_str]

In [13]:
# Same pipeline, applied line by line: one list of parsed docs per essay.
essay_prg_spacy = [
    [nlp(prg) for prg in essay_prgs]
    for essay_prgs in essays_txt_prg_list
]

In [16]:
# Tally how many tokens carry each label, separately for essays whose split
# value is "TRAIN" and for all remaining essays.
token_labels = []
train_BIO = defaultdict(int)
test_BIO = defaultdict(int)

for doc, segments, group in zip(essay_spacy, essays_segments, train_test_split.SET):
    _, doc_labels = get_labels(doc, segments)

    # pick the counter once per essay instead of duplicating the counting loop
    counts = train_BIO if group == "TRAIN" else test_BIO
    for tag in doc_labels:
        counts[tag] += 1


train_BIO,test_BIO

(defaultdict(int, {'O': 39617, 'Arg-B': 4823, 'Arg-I': 75312}),
 defaultdict(int, {'O': 9801, 'Arg-B': 1266, 'Arg-I': 18748}))

## Exporting Tokenized and Labeled Essays

In [21]:
# claim + premise merged
# Output: one TSV per essay (every token, line breaks shown as _NEW_LINE_) and
# two aggregate "token label" files with a blank line after every sentence.
# The aggregate files live in a `with` block so they are flushed and closed
# when the cell finishes instead of staying open in the kernel.
with open('../data/SG2017_tok/train.txt', 'w') as train_file, \
        open('../data/SG2017_tok/test.txt', 'w') as test_file:

    for essay_id, (doc, segments, group) in enumerate(zip(essay_spacy, essays_segments, train_test_split.SET)):

        tokens, labels = get_labels(doc, segments)
        labeled_token_id = 0

        # 1-based essay number, zero-padded to three digits (7 -> '007')
        essay_3digit_id = '{:03d}'.format(essay_id + 1)

        # every split value other than "TRAIN" is exported as test data
        if group == "TRAIN":
            split_name, split_file = 'train', train_file
        else:
            split_name, split_file = 'test', test_file

        with open('../data/SG2017_tok/{}/essay{}.tsv'.format(split_name, essay_3digit_id), 'w') as file:
            file.write('sentence_id\ttoken_id\ttoken\tlabel\n')
            for sent_id, sent in enumerate(doc.sents):
                for token_id, token in enumerate(sent):
                    # labels were computed over the whole doc; check we are still aligned
                    assert token.text == tokens[labeled_token_id]
                    file.write('{}\t{}\t{}\t{}\n'.format(sent_id, token_id, token.text.replace('\n','_NEW_LINE_'), labels[labeled_token_id]))

                    # tokens containing a line break are left out of the aggregate file
                    if '\n' not in token.text:
                        split_file.write('{} {}\n'.format(token.text, labels[labeled_token_id]))

                    labeled_token_id += 1

                # blank line marks the end of the sentence
                split_file.write('\n')
                

In [None]:
# claim + premise merged with dep relation
# Same export as the plain merged one, except that each aggregate line is
# "<token>_<dependency label> <label>". The aggregate files are opened in a
# `with` block so they are flushed and closed when the cell finishes.
with open('../data/SG2017_tok_dep/train.txt', 'w') as train_dep_file, \
        open('../data/SG2017_tok_dep/test.txt', 'w') as test_dep_file:

    for essay_id, (doc, segments, group) in enumerate(zip(essay_spacy, essays_segments, train_test_split.SET)):

        tokens, labels = get_labels(doc, segments)
        labeled_token_id = 0

        # 1-based essay number, zero-padded to three digits (7 -> '007')
        essay_3digit_id = '{:03d}'.format(essay_id + 1)

        # every split value other than "TRAIN" is exported as test data
        if group == "TRAIN":
            split_name, split_file = 'train', train_dep_file
        else:
            split_name, split_file = 'test', test_dep_file

        with open('../data/SG2017_tok_dep/{}/essay{}.tsv'.format(split_name, essay_3digit_id), 'w') as file:
            file.write('sentence_id\ttoken_id\ttoken\tlabel\n')
            for sent_id, sent in enumerate(doc.sents):
                for token_id, token in enumerate(sent):
                    # labels were computed over the whole doc; check we are still aligned
                    assert token.text == tokens[labeled_token_id]
                    file.write('{}\t{}\t{}\t{}\n'.format(sent_id, token_id, token.text.replace('\n','_NEW_LINE_'), labels[labeled_token_id]))

                    # tokens containing a line break are left out of the aggregate file
                    if '\n' not in token.text:
                        split_file.write('{}_{} {}\n'.format(token.text, token.dep_, labels[labeled_token_id]))

                    labeled_token_id += 1

                # blank line marks the end of the sentence
                split_file.write('\n')
                

# Separating Claim and Premise

In [24]:
def get_labels_claim_premise(essay_spacy, segments, mode='claim'):
    '''Assign one BIO label per token, restricted to the segment types of `mode`.

    Parameters
    ----------
    essay_spacy : sized iterable of tokens; each token must expose ``.idx``
        (start offset), ``.text`` and support ``len()``.
    segments : iterable of ``(arg_type, start, end, text)`` tuples.
    mode : 'claim'   -> only Claim / MajorClaim segments are labelled,
           'premise' -> only Premise segments are labelled,
           'all'     -> all three types are labelled.

    Returns
    -------
    (tokens, labels) : two lists of equal length. A token starting exactly at
        a selected segment start gets 'B-claim' / 'B-premise'; a token that
        starts after a selected segment's start and ends at or before its end
        gets 'I-claim' / 'I-premise'; every other token gets 'O-claim',
        'O-premise' or 'O' for modes 'claim', 'premise' and 'all'.

    Raises
    ------
    AssertionError for an unknown mode, or if a token sitting on a segment
    start is not contained in that segment's text.
    '''

    assert mode in ['claim', 'premise', 'all']

    # every value is a list so that `in` is a membership test, never a
    # substring test against a bare string
    mode_labels = {'claim': ['Claim', 'MajorClaim'],
                   'premise': ['Premise'],
                   'all': ['Claim', 'MajorClaim', 'Premise']}

    arg_type_to_tag = {'MajorClaim': 'claim',
                       'Claim': 'claim',
                       'Premise': 'premise'}

    outside_label = {'claim': 'O-claim',
                     'premise': 'O-premise',
                     'all': 'O'}[mode]

    doc_len = len(essay_spacy)

    labels = []
    tokens = []

    # only segments of the requested type(s) take part in the labelling
    selected = [seg for seg in segments if seg[0] in mode_labels[mode]]

    # start offset -> (arg_type, text) of the first selected segment starting there
    first_seg_by_start = {}
    for arg_type, start, _, text in selected:
        first_seg_by_start.setdefault(start, (arg_type, text))

    for token in essay_spacy:
        tokens.append(token.text)

        if token.idx in first_seg_by_start:
            arg_type, seg_text = first_seg_by_start[token.idx]
            labels.append('B-' + arg_type_to_tag[arg_type])
            assert token.text in seg_text
            continue

        token_end = token.idx + len(token)
        for arg_type, start, end, _ in selected:
            if start < token.idx and token_end <= end:
                labels.append('I-' + arg_type_to_tag[arg_type])
                # stop at the first enclosing segment: exactly one label per token
                break
        else:
            labels.append(outside_label)

    assert len(labels) == doc_len
    return tokens, labels

## exporting files

In [26]:
# claim only
# Output: one TSV per essay (every token, line breaks shown as _NEW_LINE_) and
# two aggregate "token label" files with a blank line after every sentence,
# using the labels produced for mode='claim'.
# The aggregate files live in a `with` block so they are flushed and closed
# when the cell finishes instead of staying open in the kernel.
with open('../data/SG2017_claim/train.txt', 'w') as train_file, \
        open('../data/SG2017_claim/test.txt', 'w') as test_file:

    for essay_id, (doc, segments, group) in enumerate(zip(essay_spacy, essays_segments, train_test_split.SET)):

        tokens, labels = get_labels_claim_premise(doc, segments, mode='claim')
        labeled_token_id = 0

        # 1-based essay number, zero-padded to three digits (7 -> '007')
        essay_3digit_id = '{:03d}'.format(essay_id + 1)

        # every split value other than "TRAIN" is exported as test data
        if group == "TRAIN":
            split_name, split_file = 'train', train_file
        else:
            split_name, split_file = 'test', test_file

        with open('../data/SG2017_claim/{}/essay{}.tsv'.format(split_name, essay_3digit_id), 'w') as file:
            file.write('sentence_id\ttoken_id\ttoken\tlabel\n')
            for sent_id, sent in enumerate(doc.sents):
                for token_id, token in enumerate(sent):
                    # labels were computed over the whole doc; check we are still aligned
                    assert token.text == tokens[labeled_token_id]
                    file.write('{}\t{}\t{}\t{}\n'.format(sent_id, token_id, token.text.replace('\n','_NEW_LINE_'), labels[labeled_token_id]))

                    # tokens containing a line break are left out of the aggregate file
                    if '\n' not in token.text:
                        split_file.write('{} {}\n'.format(token.text, labels[labeled_token_id]))

                    labeled_token_id += 1

                # blank line marks the end of the sentence
                split_file.write('\n')
                

In [27]:
# premise only
# Output: one TSV per essay (every token, line breaks shown as _NEW_LINE_) and
# two aggregate "token label" files with a blank line after every sentence,
# using the labels produced for mode='premise'.
# The aggregate files live in a `with` block so they are flushed and closed
# when the cell finishes instead of staying open in the kernel.
with open('../data/SG2017_premise/train.txt', 'w') as train_file, \
        open('../data/SG2017_premise/test.txt', 'w') as test_file:

    for essay_id, (doc, segments, group) in enumerate(zip(essay_spacy, essays_segments, train_test_split.SET)):

        tokens, labels = get_labels_claim_premise(doc, segments, mode='premise')
        labeled_token_id = 0

        # 1-based essay number, zero-padded to three digits (7 -> '007')
        essay_3digit_id = '{:03d}'.format(essay_id + 1)

        # every split value other than "TRAIN" is exported as test data
        if group == "TRAIN":
            split_name, split_file = 'train', train_file
        else:
            split_name, split_file = 'test', test_file

        with open('../data/SG2017_premise/{}/essay{}.tsv'.format(split_name, essay_3digit_id), 'w') as file:
            file.write('sentence_id\ttoken_id\ttoken\tlabel\n')
            for sent_id, sent in enumerate(doc.sents):
                for token_id, token in enumerate(sent):
                    # labels were computed over the whole doc; check we are still aligned
                    assert token.text == tokens[labeled_token_id]
                    file.write('{}\t{}\t{}\t{}\n'.format(sent_id, token_id, token.text.replace('\n','_NEW_LINE_'), labels[labeled_token_id]))

                    # tokens containing a line break are left out of the aggregate file
                    if '\n' not in token.text:
                        split_file.write('{} {}\n'.format(token.text, labels[labeled_token_id]))

                    labeled_token_id += 1

                # blank line marks the end of the sentence
                split_file.write('\n')
                

In [25]:
# claim and premise separated
# Output: one TSV per essay (every token, line breaks shown as _NEW_LINE_) and
# two aggregate "token label" files with a blank line after every sentence,
# using the labels produced for mode='all'.
# The aggregate files live in a `with` block so they are flushed and closed
# when the cell finishes instead of staying open in the kernel.
with open('../data/SG2017_claim_premise/train.txt', 'w') as train_file, \
        open('../data/SG2017_claim_premise/test.txt', 'w') as test_file:

    for essay_id, (doc, segments, group) in enumerate(zip(essay_spacy, essays_segments, train_test_split.SET)):

        tokens, labels = get_labels_claim_premise(doc, segments, mode='all')
        labeled_token_id = 0

        # 1-based essay number, zero-padded to three digits (7 -> '007')
        essay_3digit_id = '{:03d}'.format(essay_id + 1)

        # every split value other than "TRAIN" is exported as test data
        if group == "TRAIN":
            split_name, split_file = 'train', train_file
        else:
            split_name, split_file = 'test', test_file

        with open('../data/SG2017_claim_premise/{}/essay{}.tsv'.format(split_name, essay_3digit_id), 'w') as file:
            file.write('sentence_id\ttoken_id\ttoken\tlabel\n')
            for sent_id, sent in enumerate(doc.sents):
                for token_id, token in enumerate(sent):
                    # labels were computed over the whole doc; check we are still aligned
                    assert token.text == tokens[labeled_token_id]
                    file.write('{}\t{}\t{}\t{}\n'.format(sent_id, token_id, token.text.replace('\n','_NEW_LINE_'), labels[labeled_token_id]))

                    # tokens containing a line break are left out of the aggregate file
                    if '\n' not in token.text:
                        split_file.write('{} {}\n'.format(token.text, labels[labeled_token_id]))

                    labeled_token_id += 1

                # blank line marks the end of the sentence
                split_file.write('\n')
                