In [1]:
import sys
import glob
from termcolor import colored
import random

In [20]:
def wm_tool_eval(gold_sent_labels, wm_sent_labels):
    '''
    gold_sent_labels = list of sentences where each sentence is a list of gold labels of its tokens
    labels should be in 'B-claim', 'I-claim', 'O-claim'
    
    wm_sent_labels = list of sentences where each sentence is a list of prediction labels of its tokens (WM tool)
    labels should be in 'claim', 'no-claim'
    '''
    # total number of claim and no-claim segments in all sentences, with max of 1 no-claim segment per sentence
    claims_segment_count, claims_correct, claims_mistake = 0, 0, 0
    noclaims_segment_count, noclaims_correct, noclaims_mistake = 0, 0, 0
    
    # looping through all sentences
    for gold_sent, wm_tagged_sent in zip(gold_sent_labels, wm_sent_labels):
        no_claim_tokens, all_claim_tokens, single_claim_tokens = [], [], []
        
        # looping through token_labels for a single sentence
        for gold_token, wm_tagged_token in zip(gold_sent, wm_tagged_sent):
            
            if gold_token == 'O-claim':
                no_claim_tokens.append(wm_tagged_token)
            
            else:
                if gold_token == 'B-claim':
                    if len(single_claim_tokens) == 0:
                        single_claim_tokens.append(wm_tagged_token)
                    else:
                        all_claim_tokens.append(single_claim_tokens)
                        single_claim_tokens = []
                
                else: #gold_token == 'I-claim'
                    single_claim_tokens.append(wm_tagged_token)
        
        # adding the last claim in the sentence
        if len(single_claim_tokens) > 0:
            all_claim_tokens.append(single_claim_tokens)
                        
        
        # checking if all of the 'O-claim' tokens are labels as 'no-claim' by the WM tool
        if len(no_claim_tokens) > 0:
            noclaims_segment_count += 1
            if all([label == 'no-claim' for label in no_claim_tokens]):
                noclaims_correct += 1
            else:
                noclaims_mistake += 1
                
        # for each claim segment in the sentence
        #     checking that it has at least token tagged as 'claim' by the WM tool
        for claim in all_claim_tokens:
            claims_segment_count += 1
            if any([label == 'claim' for label in claim]):
                claims_correct += 1
            else:
                claims_mistake += 1
    
    # returning the total number of segments with the counts of correct and wrong tagged segments by the WM tool
    return noclaims_segment_count, noclaims_correct, noclaims_mistake, claims_segment_count, claims_correct, claims_mistake

In [21]:
wm_tool_eval([['O-claim', 'O-claim', 'O-claim', 'O-claim', 'O-claim', 'O-claim'], 
              ['B-claim', 'I-claim', 'I-claim', 'O-claim', 'B-claim', 'I-claim'],
              ['O-claim', 'B-claim', 'I-claim', 'I-claim', 'I-claim', 'O-claim']],
             
             [['no-claim', 'no-claim', 'no-claim', 'no-claim', 'no-claim', 'no-claim'], 
              ['no-claim', 'claim', 'no-claim', 'no-claim', 'no-claim', 'no-claim'], 
              ['no-claim', 'no-claim', 'claim', 'no-claim', 'no-claim', 'no-claim']])

(3, 3, 0, 3, 2, 1)

In [2]:
def read_wm_essays(essays_dir, header=False):
     # read files
    data = []
    for file in sorted(glob.glob(essays_dir+'*.tsv')):
        data.append(open(file).readlines())


    essays_sent_token_label, tokens, labels = [], [], []

    for essay_id, essay in enumerate(data):
        prev_sent_id = '0'
        essay_sents, sent_token, sent_label = [], [], []
        doc_tokens, doc_labels = [], []
        
        if header:
            essay = essay[1:]
        
        for i,line in enumerate(essay):
            if '_NEW_LINE_' not in line:
                sent_id, token_id, token, label = line.rstrip().split()

                if sent_id != prev_sent_id:
                    essay_sents.append((sent_token, sent_label))
                    sent_token, sent_label = [], []

                if len(token) < 25 and 'www' not in token:
                    doc_tokens.append(token)
                    doc_labels.append('{}-claim'.format(label.split('-')[0]))
                    sent_token.append(token)
                    sent_label.append('{}-claim'.format(label.split('-')[0]))

                prev_sent_id = sent_id
        
        essay_sents.append((sent_token, sent_label))
        essays_sent_token_label.append(essay_sents)
        tokens.append(doc_tokens)
        labels.append(doc_labels)

    essay_str, essay_str_sent = [], []
    for essay in essays_sent_token_label:
        
        sentences = []
        for essay_sent_tokens, essay_sent_labels in essay:
            sent = ' '.join(essay_sent_tokens)
    #         sent = sent.replace(" ' m", "'m")
    #         sent = sent.replace(" ' s", "'s")
    #         sent = sent.replace(" : ", ": ")
            sentences.append(sent)
        
        essay_str_sent.append(sentences)
        essay_str.append(' '.join(sentences))

    return {'essay_sent_token_label': essays_sent_token_label, 'tokens': tokens, 'labels': labels,
            'essay': essay_str, 'essay_sent': essay_str_sent}

In [3]:
wm1 = read_wm_essays('/Users/talhindi/Documents/data_wm/arg_clean_45_1/')
wm2 = read_wm_essays('/Users/talhindi/Documents/data_wm/arg_clean_45_2/')
wm_nr = read_wm_essays('/Users/talhindi/Documents/data_wm/wm_narrative/', header=True)
sg = read_wm_essays('/Users/talhindi/Documents/claim_detection/data/SG2017_claim/train/', header=True)

In [19]:
wm_sent = [sent for essay_sent in wm2['essay_sent'] for sent in essay_sent]
sg_sent = [sent for essay_sent in sg['essay_sent'] for sent in essay_sent]
wm_nr_sent = [sent for essay_sent in wm_nr['essay_sent'] for sent in essay_sent]

random.seed(2453)
random.shuffle(wm_sent)
random.shuffle(sg_sent)
random.shuffle(wm_nr_sent)

## SG vs WM

In [6]:
mixed_sent = []
for sent1, sent2 in zip(wm_sent, sg_sent[:len(wm_sent)]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

for sent1, sent2 in zip(wm_sent, sg_sent[len(wm_sent):2*len(wm_sent)]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

for sent1, sent2 in zip(wm_sent, sg_sent[2*len(wm_sent):3*len(wm_sent)]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

len(mixed_sent)

5586

In [34]:
wm_same_sent = []
for i in range(1,len(wm_sent)):
    wm_same_sent.append('{}\t{}\t1\n'.format(wm_sent[i-1], wm_sent[i]))
    if i-2 >= 0:
        wm_same_sent.append('{}\t{}\t1\n'.format(wm_sent[i-1], wm_sent[i]))

sg_same_sent = []
for i in range(1,len(sg_sent)-1):
    sg_same_sent.append('{}\t{}\t1\n'.format(sg_sent[i-1], sg_sent[i]))
    if i-2 >= 0:
        sg_same_sent.append('{}\t{}\t1\n'.format(sg_sent[i-1], sg_sent[i]))

random.seed(2453)
random.shuffle(wm_same_sent)
random.shuffle(sg_same_sent)

len(wm_same_sent), len(sg_same_sent)

(3721, 11273)

In [14]:
all_sent = mixed_sent[:5000] + wm_same_sent[:2500] + sg_same_sent[:2500]
random.seed(2453)
random.shuffle(all_sent)
print(len(all_sent))

with open('/Users/talhindi/Documents/claim_detection/data/sg_wm_sent.tsv', 'w') as writer:
    writer.writelines(all_sent)

10000

In [18]:
with open('/Users/talhindi/Documents/claim_detection/data/sg_wm_sent.tsv', 'w') as writer:
    writer.writelines(all_sent)

## WM-arg VS WM-nr

In [37]:
mixed_sent = []
for sent1, sent2 in zip(wm_nr_sent, wm_sent[:len(wm_nr_sent)]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

for sent1, sent2 in zip(wm_nr_sent[:len(wm_sent)-len(wm_nr_sent)], wm_sent[len(wm_nr_sent):]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

#shift by 5       
for sent1, sent2 in zip(wm_nr_sent, wm_sent[5:len(wm_nr_sent)+5]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

for sent1, sent2 in zip(wm_nr_sent[:len(wm_sent)-len(wm_nr_sent)], wm_sent[len(wm_nr_sent)+5:-5]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

#shift by 10
for sent1, sent2 in zip(wm_nr_sent, wm_sent[10:len(wm_nr_sent)+10]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))

for sent1, sent2 in zip(wm_nr_sent[:len(wm_sent)-len(wm_nr_sent)], wm_sent[len(wm_nr_sent)+10:-10]):
    if int(10*random.random()) % 2 == 0:
        mixed_sent.append('{}\t{}\t0\n'.format(sent1, sent2))
    else:
        mixed_sent.append('{}\t{}\t0\n'.format(sent2, sent1))


len(mixed_sent)

5556

In [38]:
wmarg_same_sent = []
for i in range(1,len(wm_sent)):
    wmarg_same_sent.append('{}\t{}\t1\n'.format(wm_sent[i-1], wm_sent[i]))
    if i-2 >= 0:
        wmarg_same_sent.append('{}\t{}\t1\n'.format(wm_sent[i-1], wm_sent[i]))

wmnr_same_sent = []
for i in range(1,len(wm_nr_sent)-1):
    wmnr_same_sent.append('{}\t{}\t1\n'.format(wm_nr_sent[i-1], wm_nr_sent[i]))
    if i-2 >= 0:
        wmnr_same_sent.append('{}\t{}\t1\n'.format(wm_nr_sent[i-1], wm_nr_sent[i]))

random.seed(2453)
random.shuffle(wmarg_same_sent)
random.shuffle(wmnr_same_sent)

len(wmarg_same_sent), len(wmnr_same_sent)

(3721, 2659)

In [39]:
all_sent = mixed_sent[:5000] + wmarg_same_sent[:2500] + wmnr_same_sent[:2500]
random.seed(2453)
random.shuffle(all_sent)
print(len(all_sent))

with open('/Users/talhindi/Documents/claim_detection/data/wmarg_wmnr_sent.tsv', 'w') as writer:
    writer.writelines(all_sent)

10000
