# Experiment: Modification using synonyms for LV changes evaluation

## Imports

In [1]:
import logging
import pickle
import spacy
from essay_evaluation.classifier import Classifier
from essay_evaluation.corpus import read_csv
from essay_evaluation.word_substitution import WordSubstitution
from essay_evaluation.pipeline import Pipeline
from essay_evaluation.legacy.contractions import expand_contractions
from collections import defaultdict
from tqdm.notebook import tqdm

Importing and building WordSubstitution functionality, this might take a while the first time due to the dependencies...
Please be patient!


## Setup: Paths and data read

In [2]:
# --- evaluation data -----------------------------------------------------------------------
dataset_path = '/home/simon/Downloads/flip_new.csv'
texts, levels = read_csv(dataset_path)
# number of essays, taken from the start of the dataset, used in the evaluation loop below
test_size = 98

# --- grading model -------------------------------------------------------------------------
# NOTE(review): pickle.load can execute arbitrary code - only load model files you trust.
model_regressor_path = '/home/simon/Downloads/model_reg.pkl'
with open(model_regressor_path, 'rb') as model_file:
    model_regressor = pickle.load(model_file)

# look for repetition in the worst X sentences (X = entries taken from rank_sentences' output)
repetition_sentence_number = 5

### Our sample sentence:
Let's use a text with a lot of repetition. (That's the only thing we can revise at the moment.)

In [3]:
# Sample essay with many repetitions. Kept on one line (no embedded line break) so it matches the
# sample input documented in the markdown below.
text = (
    "This is nice. This is also nice. Here comes another nice sentence. "
    "How about a third one? Let's not stop here! One last sentence."
)

## Experiment Overview
### Steps to revise an essay and give feedback
<ol>
    <li>Rank sentences based on grade impact</li>
    <li>Find candidates for substitution
        <ol>
            <li>Filter stopwords</li>
            <li>Lemmatize, lowercase, and collect each token (and thereby its index in the essay) in a dict keyed by lemma to count repetitions</li>
            <li>Identify repetitions, filter tokens that are coreferences</li>
        </ol>
    </li>
    <li>Generate revised essays
        <ol>
            <li>Mark each remaining repeated word in its sentence with the "čš" marker on both sides</li>
            <li>Retrieve synonyms for the marked words</li>
            <li>Substitute each marked word with one of its top 5 synonyms, in a round-robin fashion</li>
        </ol>
    </li> 
    <li> Re-grade the revised essay</li>
    <li>Give feedback based on which revision technique worked best (e.g. repetition removal) (Not part of this 
    notebook!)</li> 
</ol>

#### 1. Rank sentences based on grade impact
This helps us to
- find out where revisions are needed
- find out where to look for mistakes

In [4]:
# Grading pipeline: lexical-variation (TAALED) and coreference components, followed by the
# regression classifier, which presumably stores its prediction in `doc._.score_regression`
# (the attribute read by rank_sentences).
pipline_sentence_ranking = Pipeline().lexical_variation_taaled().neuralcoref().get_pipe()
regression_classifier = Classifier(model_regressor, 'score_regression')
pipline_sentence_ranking.add_pipe(
    regression_classifier,
    name=Classifier.name + 'reg',
    last=True,
)

def remove_sentences_one_by_one(doc):
    """
    Lazily yield one variant of ``doc``'s text per sentence, with that sentence left out.

    The k-th yielded string is the text of every sentence except sentence k, joined with single
    spaces. Variants come out in sentence order.

    :param doc: object with a ``sents`` iterable whose items expose ``.text`` (e.g. a spaCy Doc)
    """
    all_sents = list(doc.sents)
    for skipped in range(len(all_sents)):
        kept_texts = [sentence.text for pos, sentence in enumerate(all_sents) if pos != skipped]
        yield ' '.join(kept_texts)

def rank_sentences(doc):
    """
    Rank every sentence of ``doc`` by its influence on the regression score.

    Each sentence is left out in turn (via ``remove_sentences_one_by_one``). The remaining text is
    re-graded with ``pipline_sentence_ranking`` and compared to the score of the full document:

        score_diff = score(full doc) - score(doc without the sentence)
        lower:  removing the sentence raises the score -> the sentence has a bad impact on the
                grade (revision necessary)
        higher: the sentence has a good impact on the grade (no revision necessary)

    :param doc: Doc already processed by ``pipline_sentence_ranking`` (must carry
                ``doc._.score_regression``)
    :return: list of ``{'index': <sentence position in doc>, 'score_diff': <score difference>}``
             dicts for ALL sentences, sorted ascending by ``score_diff`` (worst sentence first)
    :raises ValueError: if the ``score_regression`` extension is not available
                        (ValueError is a subclass of the previously raised Exception)
    """
    # Doc.has_extension checks that the extension is registered.
    if not doc.has_extension('score_regression'):
        raise ValueError("document has no regression score")
    base_score = doc._.score_regression
    n_sentences = len(list(doc.sents))
    alt_texts = remove_sentences_one_by_one(doc)
    result = []
    for sent_i, alt_text in enumerate(tqdm(alt_texts, total=n_sentences)):
        # re-grade the essay without sentence `sent_i`
        doc_alt = pipline_sentence_ranking(alt_text)
        result.append({
            'index': sent_i,
            'score_diff': base_score - doc_alt._.score_regression,
        })
    return sorted(result, key=lambda sent: sent['score_diff'])

We use the following sample input:
`This is nice. This is also nice. Here comes another nice sentence. How about a third one? Let's not stop here! One last sentence.`

The output gives us a sentence ranking. A lower `score_diff` means that the sentence has a bad impact on the grade and 
should be revised: 

In [5]:
# Grade the sample essay, then rank its sentences by their effect on the score (worst first).
doc = pipline_sentence_ranking(text)
ranked_sentences = rank_sentences(doc)
ranked_sentences  # last expression: shown as the cell output

HBox(children=(FloatProgress(value=0.0, max=6.0), HTML(value='')))




[{'index': 2, 'score_diff': -1.13},
 {'index': 0, 'score_diff': -1.1099999999999999},
 {'index': 1, 'score_diff': -1.0899999999999999},
 {'index': 5, 'score_diff': 0.06000000000000005},
 {'index': 3, 'score_diff': 0.25},
 {'index': 4, 'score_diff': 0.40000000000000013}]

#### 2. Find candidates for substitution

In [6]:
def get_word_counts(sentences):
    """
    Group the content words of ``sentences`` by lemma and keep only lemmas that repeat.

    A token is considered when its fine-grained tag is a noun, verb, adjective or adverb tag
    (NN*, VB*, JJ*, RB*) and ``token._.in_coref`` is falsy. The token objects themselves are kept
    (rather than plain counts, e.g. a Counter), so every occurrence can later be located in the
    document.

    :param sentences: iterable of sentences, each an iterable of spaCy tokens
    :return: list of ``(lemma, [tokens])`` tuples for lemmas occurring more than once, sorted by
             number of occurrences, most frequent first (ties keep first-seen order)
    """
    content_tags = ('NN', 'NNS', 'NNP', 'NNPS',
                    'VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ',
                    'JJ', 'JJR', 'JJS',
                    'RB', 'RBR', 'RBS')
    occurrences = defaultdict(list)
    for token in (tok for sentence in sentences for tok in sentence):
        if token.tag_ not in content_tags or token._.in_coref:
            continue
        occurrences[token.lemma_].append(token)

    repeated = [(lemma, tokens) for lemma, tokens in occurrences.items() if len(tokens) > 1]
    # a negated key keeps the sort stable exactly like reverse=True does
    return sorted(repeated, key=lambda pair: -len(pair[1]))

This function returns a list of tuples `(lemma, [list of tokens])`. Because we keep the spaCy token objects, we can find out each
token's position inside the document.
We only look for repetitions inside the worst X sentences (e.g. 5, defined as `repetition_sentence_number`).

In [7]:
# Look for repetitions only in the `repetition_sentence_number` worst-ranked sentences.
sample_sents = list(doc.sents)
worst_sentences = [sample_sents[rs['index']] for rs in ranked_sentences[:repetition_sentence_number]]
wc_bad_sentences = get_word_counts(worst_sentences)
wc_bad_sentences

([('nice', [nice, nice, nice]),
  ('be', [is, is]),
  ('here', [Here, here]),
  ('sentence', [sentence, sentence])],
 [('nice', [nice, nice, nice]),
  ('sentence', [sentence, sentence]),
  ('be', [is, is])])

#### 3. Revise the essay
We replace the detected repetitions throughout the whole essay. We never use the same synonym twice for a word, so each
repetition is replaced with a different synonym.

In [8]:
# Plain spaCy English model handed to the WordSubstitution component; `ws` is called by
# substitute_repetitions with a "čš"-marked sentence to look up ranked synonyms.
nlp = spacy.load("en_core_web_sm")
ws = WordSubstitution(nlp)



def substitute_repetitions(sentences, substitution_indices, substitutor=None):
    """
    Rebuild each sentence, replacing the tokens in ``substitution_indices`` with synonyms.

    For every token to replace, the sentence is passed to the word-substitution component with
    that token wrapped in the "čš" marker (e.g. ``This is čšnicečš .``). The highest-ranked
    synonym that has not yet been used for the token's lemma is chosen, so repeated words get
    different synonyms.

    The original word is kept in two cases:
    - no unused synonym is available (previously the token was silently dropped);
    - the substitution raises ``ValueError`` (logged as a warning).

    :param sentences: iterable of sentence spans (e.g. ``doc.sents``); each needs ``.start``
    :param substitution_indices: set of tokens that should be replaced
    :param substitutor: callable taking the marked sentence string and returning a dict with a
                        ``'rankedSynonyms'`` list; defaults to the module-level ``ws``
    :return: list of revised sentence strings (tokens joined with single spaces)
    """
    if substitutor is None:
        substitutor = ws
    used_synonyms = defaultdict(set)
    result = []
    for sent in sentences:
        words = [str(tok) for tok in sent]
        revised_sent = []
        for t in sent:
            if t not in substitution_indices:
                revised_sent.append(str(t))
                continue
            i = t.i - sent.start  # token position inside this sentence
            marked_sentence = ' '.join(words[:i] + ['čš' + str(t) + 'čš'] + words[i + 1:])
            try:
                synonyms = substitutor(marked_sentence)
            except ValueError:
                logging.warning("word substitution failed - " + marked_sentence)
                revised_sent.append(str(t))
                continue
            # do not replace the same word with the same synonym twice
            # todo: maybe allow to use the same synonym if there was a different substitution in between
            for synonym in synonyms['rankedSynonyms']:
                if synonym not in used_synonyms[t.lemma_]:
                    used_synonyms[t.lemma_].add(synonym)
                    revised_sent.append(synonym)
                    break
            else:
                # no unused synonym left: keep the original word instead of dropping it
                revised_sent.append(str(t))
        result.append(' '.join(revised_sent))
    return result

def get_revised_essay(doc, word_counts):
    """
    Return the essay text with repeated words replaced by synonyms.

    For every repeated lemma in ``word_counts``, the occurrence that comes first in the document
    (lowest ``token.i``) is left unchanged. All later occurrences are replaced via
    ``substitute_repetitions``. ``word_counts`` is not modified.

    :param doc: spaCy Doc of the essay
    :param word_counts: ``(lemma, [tokens])`` pairs as returned by ``get_word_counts``
    :return: the revised sentences joined with single spaces
    """
    # The token lists are in sentence-rank order, not document order, so sort by token index
    # before skipping the first occurrence. Slicing (instead of list.pop(), which dropped the
    # LAST element) also leaves the caller's lists untouched.
    substitution_indices = {
        token
        for _lemma, tokens in word_counts
        for token in sorted(tokens, key=lambda tok: tok.i)[1:]
    }
    return " ".join(substitute_repetitions(doc.sents, substitution_indices))


#### 4. Regrade the revised essay

In [9]:
# only in the worse 3 sentences
#revision = get_revised_essay(doc, wc_bad_sentences)
#revision

#doc_alt = pipline_sentence_ranking(revision)

#print("before", doc._.score_regression)
#print("after", doc_alt._.score_regression)


### Run this on the FLIP-English test set (the first `test_size` essays)

In [10]:
# Revise every essay of the test split and re-grade the revision.
# The loop uses its own variable names so the sample `text`, `doc`, `ranked_sentences` and
# `wc_bad_sentences` from the walkthrough above are not silently overwritten. Previously,
# re-running an earlier cell after this one used the last essay instead of the sample.
result = []
test_texts = texts[:test_size]
for essay_text in tqdm(test_texts, total=len(test_texts)):
    essay_doc = pipline_sentence_ranking(expand_contractions(essay_text))
    essay_ranking = rank_sentences(essay_doc)
    essay_sents = list(essay_doc.sents)
    worst_sentences = [essay_sents[rs['index']] for rs in essay_ranking[:repetition_sentence_number]]
    essay_word_counts = get_word_counts(worst_sentences)
    essay_text_revised = get_revised_essay(essay_doc, essay_word_counts)
    result.append({
        'doc': essay_doc,
        'ranked_sentences': essay_ranking,
        'text_revised': essay_text_revised,
        'doc_revised': pipline_sentence_ranking(essay_text_revised),
    })
    

HBox(children=(FloatProgress(value=0.0, max=98.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))








































































































HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=31.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=82.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=70.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=27.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=42.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=77.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=34.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=91.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=18.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=24.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=29.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=26.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=28.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=7.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=33.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=11.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=13.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=19.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=115.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=65.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=92.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=27.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=23.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=17.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=8.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=12.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=66.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=22.0), HTML(value='')))

In [14]:
import csv

# Per-essay debug table: scores before/after revision, gold label, original and revised text,
# sentence ranking, and every feature of the original / revised document
# (prefixed ORIGINAL_ / REVISION_).
debug_data = []
for essay_i, element in enumerate(result):
    item = {
        'index': essay_i,
        'score_before_revision': element['doc']._.score_regression,
        'score_after_revision': element['doc_revised']._.score_regression,
        'gold_label': levels[essay_i],
        'text': element['doc'].text,
        # the revised essay text (this previously repeated element['doc'].text, the original)
        'revision': element['text_revised'],
        'ranked_sentences': element['ranked_sentences'],
    }
    for key, val in element['doc_revised']._.features.items():
        item['REVISION_' + key] = val
    for key, val in element['doc']._.features.items():
        item['ORIGINAL_' + key] = val
    debug_data.append(item)

# Union of all column names in first-seen order. If two essays expose different feature sets,
# DictWriter would otherwise raise on unexpected keys; missing values become empty cells.
fieldnames = list(dict.fromkeys(key for row in debug_data for key in row))
# newline='' is what the csv module expects for files it writes
with open('/home/simon/Downloads/result_revision_debug.csv', 'w', newline='', encoding='utf-8') as fh:
    writer = csv.DictWriter(fh, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(debug_data)