In [14]:
import os
import re
import json
import pandas as pd

from collections import defaultdict
from tqdm.notebook import tqdm

from allennlp.predictors.predictor import Predictor
import allennlp_models.tagging
# POS

def get_predictor(model):
    if model == 'pos':
        url = "https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz"
    elif model == 'ner':
        url = "https://storage.googleapis.com/allennlp-public-models/ner-elmo.2021-02-12.tar.gz"
    
    return Predictor.from_path(url)

In [15]:
def show_dialogue_and_summaries(dialogsum_df, summaries, i):
    for dlg in dialogsum_df.loc[i]['dialogue'].split('\n'):
        print(dlg)

    print('\n')
    for _, (summ, idx) in enumerate(summaries):
        if idx == i:
            print(summ)
            break

def prepare_dataframe():
    data_path = os.path.join('..', 'data')
    dialogsum_path = os.path.join(data_path, 'dialogsum', 'DialogSum_Data')
    test_path = os.path.join(dialogsum_path, 'dialogsum.test.jsonl')
    dialogsum_test_df = pd.read_json(test_path, lines=True)
    dialogsum_test_df = dialogsum_test_df.rename(columns={"summary1": "summary"})
    dialogsum_test_df['split'] = 'test'
    
    
    dev_path = os.path.join(dialogsum_path, 'dialogsum.dev.jsonl')
    dialogsum_dev_df = pd.read_json(dev_path, lines=True)
    dialogsum_dev_df['split'] = 'val'

    train_path = os.path.join(dialogsum_path, 'dialogsum.train.jsonl')
    dialogsum_train_df = pd.read_json(train_path, lines=True)
    dialogsum_train_df['split'] = 'train'
    
    dialogsum_df = pd.concat([dialogsum_train_df, dialogsum_dev_df, dialogsum_test_df])
    dialogsum_df.reset_index(inplace=True)
    
    return dialogsum_df

def get_summs_w_2_persons(dialogsum_df):
    summaries = []
    two_person_dlg_f = dialogsum_df[dialogsum_df['dialogue'].str.contains('#Person3#') == False]
    indices = dialogsum_df.index[dialogsum_df['dialogue'].str.contains('#Person3#') == False].tolist()
    summaries = [(dialogsum_df.loc[idx]['summary'], idx) for idx in indices]
    
    return summaries, two_person_dlg_f

def split_sentence(pred):
    sents = []
    sent = []
    for word, token in zip(pred['tokens'], pred['pos_tags']):
        if token == '.':
            sent[-1] += word
            sents.append(sent)
            sent = []
        else:
            sent.append(word)
            
    return [" ".join(s) for s in sents]

def get_sentences(pred):
    root = pred['hierplane_tree']['root']
    sents = []
    for child in pred['hierplane_tree']['root']['children']:
        if child['nodeType'] == 'S':
            sents.append(child)    
    
    
    return sents if sents else [root]

def split_summaries(summaries, predictor):
    summaries_split = {}
    for summ, i in tqdm(summaries):
        out = predictor.predict(summ)
        summaries_split[i] = split_sentence(out)

    return summaries_split

def get_split_summaries(summ_split_path, summaries, predictor, force_rerun=False):
    if os.path.exists(summ_split_path) and not force_rerun:
        with open(summ_split_path, 'r') as json_file:
            splits = json.load(json_file)
            summaries_split = {int(k): splits[k] for k in splits}
    else:
        summaries_split = split_summaries(summaries, predictor)
        with open(summ_split_path, 'w') as json_file:
            json.dump(summaries_split, json_file)
    
    return summaries_split

def cleanup_summs(summaries_split):
    for k in tqdm(summaries_split):
        for i in range(len(summaries_split[k])):
            summaries_split[k][i] = summaries_split[k][i].replace('#Person1#', 'XYZ1')
            summaries_split[k][i] = summaries_split[k][i].replace('# Person1#', 'XYZ1')
            summaries_split[k][i] = summaries_split[k][i].replace('# Person1 #', 'XYZ1')
            summaries_split[k][i] = summaries_split[k][i].replace('#Person1 #', 'XYZ1')

            summaries_split[k][i] = summaries_split[k][i].replace('#Person2#', 'XYZ2')
            summaries_split[k][i] = summaries_split[k][i].replace('# Person2#', 'XYZ2')
            summaries_split[k][i] = summaries_split[k][i].replace('# Person2 #', 'XYZ2')
            summaries_split[k][i] = summaries_split[k][i].replace('#Person2 #', 'XYZ2')

            summaries_split[k][i] = summaries_split[k][i].replace(" 'll", "'ll")
            summaries_split[k][i] = summaries_split[k][i].replace(" 's", "'s")
            summaries_split[k][i] = summaries_split[k][i].replace(" n't", "n't")
            summaries_split[k][i] = summaries_split[k][i].replace(" - ", "-")
            summaries_split[k][i] = summaries_split[k][i].replace(" ,", ",")
            summaries_split[k][i] = summaries_split[k][i].replace(" 've'", "'ve'")
            summaries_split[k][i] = summaries_split[k][i].replace("..", ".")
            summaries_split[k][i] = summaries_split[k][i].replace('. ', ' ')
            
def tag_summs(summaries_split, predictor):
    tagged_summs = {}
    curr = None

    for k in tqdm(summaries_split):
        curr = k
        split = summaries_split[k]
        tagged_summs[k] = []
        for summ in split:
            ptree = predict_and_get_tree(summ, predictor)
            subs = get_subsentences(ptree)
            #subs_2 = []
            #for sub in subs:
            #    subs_2 += get_subsentences(sub) 

            clauses = get_clauses(subs)
            tagged_summs[k].extend(clauses)
    
    return tagged_summs

def correct_tags(tagged_summs, summaries_split):
    count = 0
    for k in tagged_summs:
        if len(summaries_split[k]) > len(tagged_summs[k]):
            count += 1
            tagged_summs[k] = summaries_split[k][::]
        
def cleanup_tags(tagged_summs):
    for k in tqdm(tagged_summs):
        i = 0
        while i < len(tagged_summs[k]):
            if tagged_summs[k][i] == '.':
                del tagged_summs[k][i]
            else:
                if isinstance(tagged_summs[k][i], list):
                    tagged_summs[k][i] = ' '.join(tagged_summs[k][i])
                tagged_summs[k][i] = tagged_summs[k][i].replace(" 'll", "'ll")
                tagged_summs[k][i] = tagged_summs[k][i].replace(" 's", "'s")
                tagged_summs[k][i] = tagged_summs[k][i].replace(" ,", ",")
                tagged_summs[k][i] = tagged_summs[k][i].replace(" '", "'")
                tagged_summs[k][i] = tagged_summs[k][i].replace(" - ", "-")
                tagged_summs[k][i] = tagged_summs[k][i].replace(". ", " ")
                if len(tagged_summs[k][i]) > 0:
                    if tagged_summs[k][i][-1] != '.':
                        tagged_summs[k][i] = tagged_summs[k][i] + '.'
                tagged_summs[k][i] = tagged_summs[k][i].replace("Person1", "XYZ1")
                tagged_summs[k][i] = tagged_summs[k][i].replace("Person2", "XYZ2")
                i += 1
                
    # hard coded corrections
    tagged_summs[259][0] = 'Alice wants to apply for a scholarship offered by the American Minority Students Scholarship Association since she is eligible for it that she is Asian American, a student in junior year and has GPA 3.92.'
    tagged_summs[259].pop(1)
    tagged_summs[298][1] = 'They both play bridge'
    tagged_summs[2016] = ["Edward Smith wants to book a flight to New York on July 21st but it isn't available.", "he takes another flight on July 22nd."]
    tagged_summs[744] = ["XYZ1 sends a necklace to Mom on Mother's Day"] + tagged_summs[744]
    tagged_summs[4446][1] = 'XYZ1 thinks working overtime is not always pleasant.'
    tagged_summs[4250][1] = 'Jason comforts her.'
    tagged_summs[3363][2] = "there'll be more collections of his works."
    tagged_summs[6934] = ['XYZ1 is going to buy bicycle A5, FOB Qingdao from Mr Smith.', 'they agree on 3.5%.']
    tagged_summs[8760][1] = "They 've got meat, utensils and paper plates, and are going to buy some buns and ketchup."
    tagged_summs[10203][2] = "they've made a room reservation."
    tagged_summs[11213][1] = "They've been dating for three years."
    tagged_summs[11832][0] = "Say forgets to take Melber's book and suggest they pick it up after the show."
    
            
def get_tagged_summs(tagged_summs_path, summaries_split, predictor, force_rerun=False):
    if os.path.exists(tagged_summs_path) and not force_rerun:
        with open(tagged_summs_path, 'r') as json_file:
            json_data = json.load(json_file)
            tagged_summs = {int(k): json_data[k] for k in json_data}
    else:
        tagged_summs = tag_summs(summaries_split, predictor)
        with open(tagged_summs_path, 'w') as json_file:
            json.dump(tagged_summs, json_file)

    
    correct_tags(tagged_summs, summaries_split)
    cleanup_tags(tagged_summs)
    remove_dots_and_empty(tagged_summs)
    
    return tagged_summs

In [16]:
from difflib import SequenceMatcher
from nltk import Tree
from nltk.tree.parented import ParentedTree

preposition_dependent = set([
    'if', 'though', 'before', 'although', 'beside', 'besides', 'despite', 'during',
    'unless', 'until', 'via', 'vs', 'upon', 'unlike', 'like', 'with', 'within', 'without', 'because'
])

noun_tags = set(['NP', 'NN', 'NNP'])
verb_tags = set(['VP', 'VBP'])

def similar(a, b):
    return SequenceMatcher(None, a, b).ratio()


def get_clauses(subsentences):
    clauses = []
    for subsent in subsentences:
        clause = []
        i = 0
        while i < len(subsent):
            child = subsent[i]
            if child[0] == 'and':
                clause.extend(child.leaves())
                i += 1
                child = subsent[i]
                clause.extend(child.leaves())
            elif child.label() == 'SBAR':
                if child[0].label() == 'IN' and child[0][0].lower() not in preposition_dependent:
                    clauses.append(clause)
                    clause = []
                    clauses.append(child.leaves()[1:])
                else:
                    clause.extend(child.leaves())

            elif child.right_sibling():
                if child.label() == 'PP' and child.right_sibling().label() in noun_tags:
                    clause.extend(child.leaves())
                    
                if child.label() in noun_tags:
                    clause.extend(child.leaves())
                    if child.right_sibling().label() != 'SBAR':
                        clause.extend(child.right_sibling().leaves())
                    if child.right_sibling().right_sibling():                        
                        if child.right_sibling().label() == 'ADVP' :
                            clause.extend(child.right_sibling().right_sibling().leaves())
                        if child.right_sibling().label() == 'JJ' and child.right_sibling().right_sibling().label() == 'PP':
                            clause.extend(child.right_sibling().right_sibling().leaves())
                        elif child.right_sibling().right_sibling().label() == 'VP':
                            clause.extend(child.right_sibling().right_sibling().leaves())

            i += 1
        if clause:
            clauses.append(clause)

    return clauses

def get_subsentences(tree):
    subs = []
    tmp = []
    for i, child in enumerate(tree):
        if child.label() != 'CC':
            tmp.append(child)
        if child.label() == 'S':
            if tmp:
                subs.append(tmp)
                tmp = []
            subs.append(child)

    return subs if subs else [tree]

def contains_title(subject):
    lower_string = [string.lower() for string in subject]
    titles = ['mr', 'mrs', 'ms', 'mister', 'miss', 'misses', 'dr', 'doctor']
    for title in titles:
        if title in lower_string:
            return True
        
    return False
    
def get_subject(tree):
    output = []
    for child in tree:
        if child.label() in noun_tags:
            output = child.leaves()
            break
        elif child.label() == 'S' or child.label() == 'SBAR':
            output = get_subject(child)
            if output:
                break
        elif child.label() == 'VP':
            if child.left_sibling():
                output = child.left_sibling().leaves()
            else:
                output = get_subject(child)
            break
    
    if len(output) >= 2:
        if 'and' in output:
            output = output
        elif contains_title(output):

            output = [output[1]]
        else:
            output = [output[0]]
        
        return output
     
    return output

def predict_and_get_tree(summ, predictor=None):
    if predictor is None:
        predictor = get_predictor('pos')
    pred = predictor.predict(summ)
    t = Tree.fromstring(pred['trees'])
    ptree = ParentedTree.convert(t)

    return ptree

def has_top_level_NP(tree):
    noun_tags = set(['NP', 'NN', 'NNP'])
    for child in tree:
        if child.label() == 'S' or child.label() == 'SBAR':
            return has_top_level_NP(child)
        if child.label() == 'VP' or child.label() in noun_tags:
            return True
        
    return False

def get_names(pred):
    names = []
    for word, tag in zip(pred['words'], pred['tags']):
        if 'PER' in tag:
            names.append(word.lower())
    
    return names

def get_all_names(tagged_summ):
    names = []
    for k in tqdm(tagged_summs):
        for sent in tagged_summs[k]:
            pred = predictor.predict(sent)
            names += get_names(pred)
    
    names_set = set(names)
    names_set.add('xyz1')
    names_set.add('xyz2')
    return names_set

def get_summs_w_they(tagged_summs):
    with_they = []
    for k in tagged_summs:
        for i in range(len(tagged_summs[k])):
            if 'they' in tagged_summs[k][i]:
                with_they.append((k,i))
    
    return with_they

def get_none_and_theyNP(with_they, tagged_summs, predictor):
    they_NP = []
    is_none = []
    for k, i in tqdm(with_they):
        pair = (k, i)
        summ = tagged_summs[k][i]
        pred = predictor.predict(summ)
        t = Tree.fromstring(pred['trees'])
        ptree = ParentedTree.convert(t)
        subject = get_subject(ptree)
        if subject:
            if 'they' in subject.lower():
                they_NP.append(pair)
        else:
            is_none.append(pair)
            
    return they_NP, is_none

def distinct_prepositions(summaries, predictor):
    preps = set()
    for k in tqdm(summaries):
        split = summaries[k]
        for sent in split:
            out = predictor.predict(sentence=sent)
            t = Tree.fromstring(out['trees'])
            ptree = ParentedTree.convert(t)
            for subtree in ptree.subtrees(filter=lambda x: x.label() == 'IN'):
                preps.add(subtree[0])
    return preps

def remove_dots_and_empty(tagged_summs):
    for k in tagged_summs:
        i = 0
        while i < len(tagged_summs[k]):
            if tagged_summs[k][i] == '.' or tagged_summs[k][i] == '':
                del tagged_summs[k][i]
            else:
                i += 1

In [17]:
dialogsum_df = prepare_dataframe()
summaries, dialogsum_df = get_summs_w_2_persons(dialogsum_df)

In [20]:
predictor = get_predictor('pos')

  "AllenNLP Tango is an experimental API and parts of it might change or disappear "
2022-06-26 22:32:58,104 - INFO - allennlp.common.plugins - Plugin allennlp_models available
2022-06-26 22:32:58,363 - INFO - allennlp.common.file_utils - cache of https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz is up-to-date
2022-06-26 22:32:58,364 - INFO - allennlp.models.archival - loading archive file https://storage.googleapis.com/allennlp-public-models/elmo-constituency-parser-2020.02.10.tar.gz from cache at /home/tnguyen/.allennlp/cache/653d0c5a1fb85ac98e84e332fa2a2c0596d9c86a2f38189886d65a422dabe1e9.8cfb67d64c5824347f7328a0f84e46d2e74f9d9bb1aba6441b313d5aaccdea4d
2022-06-26 22:32:58,366 - INFO - allennlp.models.archival - extracting archive file /home/tnguyen/.allennlp/cache/653d0c5a1fb85ac98e84e332fa2a2c0596d9c86a2f38189886d65a422dabe1e9.8cfb67d64c5824347f7328a0f84e46d2e74f9d9bb1aba6441b313d5aaccdea4d to temp dir /tmp/tmpizb9ta_v
2022-06-26 22:33

2022-06-26 22:33:14,005 - INFO - allennlp.common.file_utils - cache of https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_options.json is up-to-date
2022-06-26 22:33:17,236 - INFO - allennlp.common.file_utils - cache of https://allennlp.s3.amazonaws.com/models/elmo/2x4096_512_2048cnn_2xhighway/elmo_2x4096_512_2048cnn_2xhighway_weights.hdf5 is up-to-date
2022-06-26 22:33:18,386 - INFO - allennlp.common.params - model.span_extractor.type = bidirectional_endpoint
2022-06-26 22:33:18,387 - INFO - allennlp.common.params - model.span_extractor.input_dim = 500
2022-06-26 22:33:18,388 - INFO - allennlp.common.params - model.span_extractor.forward_combination = y-x
2022-06-26 22:33:18,388 - INFO - allennlp.common.params - model.span_extractor.backward_combination = x-y
2022-06-26 22:33:18,388 - INFO - allennlp.common.params - model.span_extractor.num_width_embeddings = None
2022-06-26 22:33:18,389 - INFO - allennlp.common.params - model.

2022-06-26 22:33:18,448 - INFO - allennlp.nn.initializers -    text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_0.state_projection.weight
2022-06-26 22:33:18,449 - INFO - allennlp.nn.initializers -    text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.input_linearity.weight
2022-06-26 22:33:18,449 - INFO - allennlp.nn.initializers -    text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.bias
2022-06-26 22:33:18,449 - INFO - allennlp.nn.initializers -    text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_linearity.weight
2022-06-26 22:33:18,450 - INFO - allennlp.nn.initializers -    text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._elmo_lstm.forward_layer_1.state_projection.weight
2022-06-26 22:33:18,450 - INFO - allennlp.nn.initializers -    text_field_embedder.token_embedder_elmo._elmo._elmo_lstm._token_embedder._char_embeddin

In [21]:
summ_split_path = 'summaries_split.json'
summaries_split = get_split_summaries(summ_split_path, summaries, predictor, force_rerun=True)
cleanup_summs(summaries_split)

tagged_summs_path = 'tagged_summs.json'
tagged_summs = get_tagged_summs(tagged_summs_path, summaries_split, predictor, force_rerun=True)

  0%|          | 0/13324 [00:00<?, ?it/s]



  0%|          | 0/13324 [00:00<?, ?it/s]

  0%|          | 0/13324 [00:00<?, ?it/s]

  0%|          | 0/13324 [00:00<?, ?it/s]

In [1896]:
tagged_summs_path = 'tagged_summs.json'
tagged_summs = get_tagged_summs(tagged_summs_path, summaries_split, predictor, force_rerun=True)

  0%|          | 0/13324 [00:00<?, ?it/s]

  0%|          | 0/13324 [00:00<?, ?it/s]

In [1970]:
tagged_summs_path = 'tagged_summs.json'
tagged_summs = get_tagged_summs(tagged_summs_path, summaries_split, predictor)

  0%|          | 0/13324 [00:00<?, ?it/s]

In [1957]:
tagged_summs[40]

['XYZ1 made a bargain to buy a new dress.',
 'XYZ2 watched TV, read a boring book, and took a shower at home.']

In [1922]:
summaries_split[259]

['Alice wants to apply for a scholarship offered by the American Minority Students Scholarship Association since she is eligible for it that she is Asian American, a student in junior year and has GPA 3.',
 '92.',
 'To get the scholarship, Alice must write an essay on the topic -- The Place of Ethnic Minorities in a Democratic Society.',
 'XYZ1 is helping her write a letter of recommendation, read her essay, and give some suggestions.']

In [1011]:
#summ = summaries_split[258][0]
NP_first = True
for k in tqdm(range(11213, 13325)):
    if k not in tagged_summs:
        continue
    for i in range(len(tagged_summs[k])):
        summ = tagged_summs[k][i]
        pred = predictor.predict(summ)
        t = Tree.fromstring(pred['trees'])
        ptree = ParentedTree.convert(t)
        NP_first = has_top_level_NP(ptree)
        if (not NP_first):
            print(ptree)
            break
              
    if (not NP_first):
        print(k, i)
        break

  0%|          | 0/2112 [00:00<?, ?it/s]

In [883]:
tagged_summs[165][2] = ' '.join(clauses[0]) + '.'

In [1260]:
i = 11213
show_dialogue_and_summaries(dialogsum_df, summaries, i)

#Person1#: Hey Kevin, what are you doing here? Don't you usually spend Tuesday nights at home studying?
#Person2#: I needed to get out of the house. My parents just went ballistic over something my older sister told them.
#Person1#: What did she tell them? Is she dropping out of college?
#Person2#: Nothing that serious. She finally told them that she moved out of the dormitory a few months ago and has been living with her boyfriend.
#Person1#: And your parents took it badly?
#Person2#: That's putting it mildly. My father started shouting at my sister and my mother just glared at her.
#Person1#: Ouch, that sounds bad. What did your sister do?
#Person2#: She started arguing back to my dad that how much she loves her boyfriend, how they're in love and it's not hurting anybody, and so on. My dad said she's too young to do this, and that she should move out right away.
#Person1#: How long has your sister been with her boyfriend?
#Person2#: Three years. They've been dating since freshman yea

In [1279]:
summ = tagged_summs[11213][0]
ptree = predict_and_get_tree(summ, predictor)
subject = get_subject(ptree)
print(subject)

Kevin


In [1314]:
count = 0
two_person_dlg_df = dialogsum_df[dialogsum_df['dialogue'].str.contains('#Person3#') == False]
indices = dialogsum_df.index[dialogsum_df['dialogue'].str.contains('#Person3#') == False].tolist()

In [1323]:
dialogsum_df = prepare_dataframe()
summaries, dialogsum_df = get_summs_w_2_persons(dialogsum_df)

In [1343]:
def find_exceptions(tagged_summs):
    predictor = get_predictor('pos')
    they_NP
    for k in tqdm(tagged_summs):

{(11213, 1)}

In [1467]:
show_dialogue_and_summaries(dialogsum_df, summaries, 11832)

#Person1#: They should be a great show. Let's go in.
#Person2#: Sure. Say, did you bring my book?
#Person1#: Oh, I completely forgot it.
#Person2#: You forgot? But you promised. I needed to study for the test. Oh, I knew I never should have lent it to you.
#Person1#: Calm down, Melber. After the show, we can drive by my house and pick it up.
#Person2#: It's pretty far out of the way. But I guess we'll have to.
#Person1#: Don't worry. I'll treat you to an ice cream to make it up to you.
#Person2#: OK.


Say forgets to take Melber's book and suggest they pick it up after the show.


In [2098]:
def assign_exceptions(k, i):
    p_1 = (True, 'Person1')
    p_2 = (True, 'Person2')
    exceptions = {
        (0, 0): p_1,
        (7, 4): p_1,
        (950, 0): p_1,
        (11213, 1): p_2,
        (7233, 1): p_1,
        (11832, 0): p_1,
    }
    is_exception, key = exceptions.get((k, i), (False, None))
    
    return is_exception, key

def assign_labels(tagged_summs, names_set, dialogsum_df, labeled_summs, none_subjects, predictor=None, start_idx=0, end_idx=13460):
    #none_subjects = []
    #labeled_summs = {}
    if predictor is None:
        predictor = get_predictor('pos')

    for k in tqdm(range(start_idx, end_idx)):
        curr[0] = k
        if k not in tagged_summs:
            continue
        labeled_summs[k] = {'Person1': [], 'Person2': []}
        prev = None
        for i in range(len(tagged_summs[k])):
            summ = tagged_summs[k][i]
            
            is_exception, key = assign_exceptions(k, i)
            if is_exception:
                labeled_summs[k][key].append(summ)
                continue
            
            tree = predict_and_get_tree(summ, predictor)
            
            subject = get_subject(tree)
            if subject:
                subject = ' '.join(subject).lower()
                for suffix in ["'s", "'ll", "'ve", "'"]:
                    subject = subject.replace(suffix, "")

                if 'and' in subject:
                    target_keys = []
                    subjects = subject.split('and')
                    for subj in subjects:
                        target_keys += define_speaker(subj.strip(), names_set, dialogsum_df, k, prev)
                else:
                    target_keys = define_speaker(subject.lower(), names_set, dialogsum_df, k, prev)
                
                if not target_keys:
                    none_subjects.append((k, i))
                    continue
                
                for key in target_keys:
                    labeled_summs[k][key].append(summ)
            else:
                none_subjects.append((k, i))
            prev = key
                
    return labeled_summs, none_subjects

def define_speaker(subject, names_set, dialogsum_df, idx, prev):
    if subject == 'xyz1':
        keys = ['Person1']
    elif subject == 'xyz2':
        keys = ['Person2']
    elif subject in names_set:
        keys = search_speaker(subject, dialogsum_df.loc[idx]['dialogue'])
    else:
        keys = pronoun_distinction(subject, prev)
    
    return keys

def search_speaker(name, dialogue):
    person = None
    for utt in dialogue.split('\n'):
        if name in utt.lower():
            person = get_speaker(utt)
            break
        else:
            for word in utt.split(' '):
                if similar(word.lower(), name) >= 0.75:
                    person = get_speaker(utt)
                    return [person]
                    
    
    return [person] if person else []


def get_speaker(utt):
    intro_sents = ["i am", "i'm", "name is", "name's", "this is", "that is", "that's"]
    introduces = False
    for sent in intro_sents:
        if sent in utt:
            introduces = True
            break
    speaker = utt.split(' ')[0]

    person = 'Person2' if '1' in speaker else 'Person1'
    if introduces:
        person = 'Person1' if '1' in speaker else 'Person2'

    return person

def pronoun_distinction(pronoun, prev_label):
    singular = set(['he', 'she', 'his', 'her', 'him'])
    if pronoun in singular:
        return [prev_label] if prev_label else []
    
    return ['Person1', 'Person2']
    

In [2097]:
i = curr[0]
tagged_summs_copy = {i: tagged_summs[i]}
tagged_summs_copy

{6934: ['XYZ1 is going to buy bicycle A5, FOB Qingdao from Mr Smith.',
  'they agree on 3.5%.']}

In [2096]:
tagged_summs[6934] = ['XYZ1 is going to buy bicycle A5, FOB Qingdao from Mr Smith.', 'they agree on 3.5%.']

In [2065]:
none_subjects = []
labeled_summs = {}

In [2079]:
none_subjects_test = []
labeled_summs_test = {}

In [2101]:
curr = [0]
label_assignments, errors = assign_labels(tagged_summs, names_set, dialogsum_df, labeled_summs, none_subjects, predictor, start_idx=6934)
#label_assignments_test, errors_test = assign_labels(tagged_summs_copy, names_set, dialogsum_df, labeled_summs_test, none_subjects_test, predictor, start_idx=i, end_idx=i+1)

  0%|          | 0/6526 [00:00<?, ?it/s]

In [2105]:
len(none_subjects)

41

In [2107]:
with open('labels.json', 'w') as json_file:
    json.dump(label_assignments, json_file)

In [2111]:
count = 0
for k in label_assignments:
    if len(label_assignments[k]['Person1']) > 0 or len(label_assignments[k]['Person2']) > 0:
        count += 1

print(count)

13283


In [133]:
with open('labels.json', 'r') as json_file:
    labels = json.load(json_file)
    labels = {int(key): labels[key] for key in labels}

In [134]:
labels[max(labels)]

{'Person1': ['Frank invites Besty to the party to celebrate his big promotion.'],
 'Person2': []}

In [135]:
for key in labels:
    for person in labels[key]:
        for i in range(len(labels[key][person])):
            labels[key][person][i] = labels[key][person][i].replace('XYZ1', '#Person1#')
            labels[key][person][i] = labels[key][person][i].replace('XYZ2', '#Person2#')

In [136]:
empties = []
for key in labels:
    for person in labels[key]:
        if len(labels[key][person]) == 0:
            empties.append(key)
            labels[key][person].append(dialogsum_df.loc[key]['summary'] + '#TOREMOVE#')

In [137]:
labels[12]

{'Person1': ["#Person2# suggests #Person1# use a spam filter to reject Bean's pornographic stuff.#TOREMOVE#"],
 'Person2': ["#Person2# suggests #Person1# use a spam filter to reject Bean's pornographic stuff."]}

In [138]:
len(empties)

5615

In [139]:
faulties = set()
for k in empties:
    if labels[k]['Person1'] == labels[k]['Person2']:
        faulties.add(k)

faulties = list(faulties)
iter_faults = iter(faulties)
print(len(faulties))

41


In [200]:
faulties.index(index)

32

In [213]:
index = next(iter_faults)
print(labels[index]['Person1'])
print(dialogsum_df.loc[index]['summary'])

StopIteration: 

In [216]:
for key in labels:
    for person in labels[key]:
        if '#TOREMOVE#' in labels[key][person][0]:
            labels[key][person][0] = labels[key][person][0].replace('#TOREMOVE#', '')
        for i in range(len(labels[key][person])):
            if labels[key][person][i][-1] != '.':
                labels[key][person][i] += '.'

In [218]:
labels[max(labels)]

{'Person1': ['Frank invites Besty to the party to celebrate his big promotion.'],
 'Person2': ["Frank invites Besty to the party to celebrate his big promotion. Besty couldn't wait for the party."]}

In [219]:
with open('labels_corrected.json', 'w') as json_file:
    json.dump(labels, json_file)

In [4]:
with open('labels_corrected.json', 'r') as json_file:
    data = json.load(json_file)
    labels = {int(k): data[k] for k in data}

In [13]:
import torch
from sklearn import preprocessing
le = preprocessing.LabelEncoder()
p_1 = le.fit_transform(labels[12]['Person1'])
torch.Tensor(p_1)

tensor([0.])

In [None]:
import torch
for k in labels:
    p_1 = labels[k]['Person1']
    p_2 = labels[k]['Person2']

In [240]:
dialogsum_new_labels = {
    'test': {}, 'val': {} , 'train': {}
}

for key in labels:
    split = dialogsum_df.loc[key]['split']
    dialogsum_new_labels[split][key] = labels[key]
    

In [264]:
dialogsum_new_labels['test'][12960]

{'Person1': ['Ms. Dawson helps #Person1# to write a memo to inform every employee that they have to change the communication method and should not use Instant Messaging anymore.'],
 'Person2': ['Ms Dawson helps #Person1# to write a memo to inform every employee that they have to change the communication method and should not use Instant Messaging anymore.']}

In [243]:
def write_labels_to_json(split, data):
    with open(f'new_dialogsum_labels_{split}.json', 'w') as json_file:
        json.dump(data, json_file)

In [250]:
for split in dialogsum_new_labels:
    data = dialogsum_new_labels[split]
    print(split, len(data))
    write_labels_to_json(split, data)

test 496
val 495
train 12333


In [253]:
total = len(labels)
print('total', total)

test_percentage = 496 / total * 100
print('test percentage', test_percentage, '%')

val_percentage = 495 / total * 100
print('val percentage', val_percentage, '%')

train_percentage = 12333 / total * 100
print('train percentage', train_percentage, '%')

total 13324
test percentage 3.7226058240768536 %
val percentage 3.7151005703992794 %
train percentage 92.56229360552388 %


In [254]:
data_path = os.path.join('..', 'data')
dialogsum_path = os.path.join(data_path, 'dialogsum', 'DialogSum_Data')
test_path = os.path.join(dialogsum_path, 'dialogsum.test.jsonl')
dialogsum_test_df = pd.read_json(test_path, lines=True)
dialogsum_test_df = dialogsum_test_df.rename(columns={"summary1": "summary"})
dialogsum_test_df['split'] = 'test'


dev_path = os.path.join(dialogsum_path, 'dialogsum.dev.jsonl')
dialogsum_dev_df = pd.read_json(dev_path, lines=True)
dialogsum_dev_df['split'] = 'val'

train_path = os.path.join(dialogsum_path, 'dialogsum.train.jsonl')
dialogsum_train_df = pd.read_json(train_path, lines=True)
dialogsum_train_df['split'] = 'train'

In [263]:
def get_percentage(amount, total):
    return amount / total * 100

len_test = len(dialogsum_test_df)
len_val = len(dialogsum_dev_df)
len_train = len(dialogsum_train_df)
total = len_test + len_val + len_train

print('total', total)
print('test', len_test)
print('test percentage', get_percentage(len_test, total))

print('val', len_val)
print('val percentage', get_percentage(len_val, total))

print('train', len_train)
print('train percentage', get_percentage(len_train, total))


total 13460
test 500
test percentage 3.7147102526002973
val 500
val percentage 3.7147102526002973
train 12460
train percentage 92.5705794947994
