In [1]:
import sys

In [2]:
sys.path.append('../rover_based_aggregation/')

In [3]:
import utils.word_transition_network as wtn_module
from utils.rover import RoverVotingScheme
from utils.word_transition_network import *
from collections import Counter

In [4]:
import gensim

In [5]:
import yt.wrapper as yt
yt.config.set_proxy("hahn")

In [6]:
data_table = yt.read_table(
    "//home/voice/edvls/tickets/VA-442_ideal_testsets/assistant_ideal_annotations_2019-02-16__2019-02-25"
)
data_table = list(data_table)

In [7]:
AggregationResult = collections.namedtuple('AggregationResult', 'text confidence cost')

In [8]:
def aggregate_prod(raw_data):
    """
       aggregation from prod
    """
    cost = 2
    while cost < 5:
        cost += 1
        answers = [(x["text"], x["speech"]) for x in raw_data[:cost]]
        answers = Counter(answers)
        if answers.most_common(1)[0][1] >= 3:
            break

    texts = Counter()
    speechs = Counter()
    for text, speech in [(x["text"], x["speech"]) for x in raw_data[:cost]]:
        if speech != "BAD" and text:
            text = text.lower().replace('ё', 'е')
        else:
            text = ""
        speechs.update([speech])
        texts.update([text])
    
    
    text, text_rate = max(texts.items(), key=lambda x: (x[1], x[0] != ""))
    if text != "" and text_rate >= 2:
        conf = text_rate * 1.0 / sum(texts.values())
    else:
        text = None
        conf = 0
    common = speechs.most_common(2)
    speech, speech_rate = common[0]
    if speech == "BAD" and len(common) >= 2 and common[1][1] == speech_rate:
        speech = common[1][0]

    # conf = text_rate / sum(texts.values())
    return AggregationResult(text, conf, cost)

In [9]:
def evaluate_metrics(data, field, algorithm, treshhold=0, cluster_refernces=None, print_=True):
    errors = 0
    total_length = 0
    aggregated = 0
    total_items = 0
    correct = 0
    cost = 0
    not_aggregated_correct=0
    aggregated_correct=[]
    not_aggregated_incorrect=0
    aggregated_incorrect=[]
    agg_incorrect_hyps=[]
    
    for row in data:
        if row["mark"] != "TEST":
            continue
        total_items += 1
        
        hyp = algorithm(sorted(row[field], key=lambda x: x["submit_ts"]))
        cost += hyp.cost
        
        if (hyp.text is None) or (hyp.confidence < treshhold):
            if row['text'] in [obj['text'] for obj in row[field]]:
                not_aggregated_correct+=1
            else:
                not_aggregated_incorrect+=1
            continue
            
        hyp = hyp.text
        aggregated += 1
        _, e, l = calculate_wer(row["text"], hyp, cluster_refernces)
        errors += e
        if e == 0:
            correct += 1
        elif row['text'] in [obj['text'] for obj in row[field]]:
            agg_incorrect_hyps.append(hyp)
            aggregated_incorrect.append(row)
        
        total_length += l

    accuracy = correct / aggregated
    wer = errors / total_length
    aggregated_part = aggregated / total_items
    cost = cost / total_items
    if print_:
        print("Aggregated: {:.4%}\nWER: {:.4%}\nAccuracy: {:.4%}\nMean overlap: {:.4}".format(
            aggregated_part, wer, accuracy, cost
        ))
    return aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps

In [10]:
for i in data_table[0].keys():
    print(i)
col_name='toloka_assignments_repeat_11_selected_workers_with_pitch'

speech
toloka_assignments_repeat_8_selected_workers_with_chorus_and_pitch
audio
toloka_text
date
number_of_speakers
toloka_assignments
toloka_assignments_repeat_4_with_bend
toloka_assignments_repeat_6_with_chorus
_other
linguists_sugested_text
toloka_assignments_repeat_10
raw_text_linguists
toloka_assignments_repeat_9_selected_workers_with_chorus_and_pitch
check_in_yang_results
toloka_assignments_repeat_5_with_chorus_and_pitch
url
toloka_assignments_repeat_3_with_pitch
toloka_assignments_repeat_11_selected_workers_with_pitch
yang_assignments_repeat_1
mark
mds_key
toloka_number_of_speakers
text
toloka_assignments_repeat_2_with_pitch
toloka_assignments_repeat_7_with_chorus_and_pitch
toloka_speech
linguists_worker_id
toloka_assignments_repeat_1
linguists_comment


# Baseline: ~production method quality

In [11]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table, col_name, aggregate_prod)

Aggregated: 67.9862%
WER: 7.2892%
Accuracy: 78.9873%
Mean overlap: 3.833


In [12]:
import numpy as np

The code below saves all errors (aggregated,but incorrect) of the production method in the following format: linguist answer, selected hypothesis, list of all annotations

In [13]:
np.random.seed(0)
with open('errors.txt','w+',encoding='utf8') as f:
    hyps_list=np.array(list(zip(aggregated_incorrect,agg_incorrect_hyps)))
    np.random.shuffle(hyps_list)
    for i,(row,hyp) in enumerate(hyps_list):
        print(i,file=f)
        print(row['text'],file=f)
        print('---',file=f)
        print(hyp,file=f)
        print('--',file=f)
        for annot in sorted(row[col_name],key=lambda x:x["submit_ts"]):
            print(annot['text'],file=f)
        print('',file=f)

In [14]:
from joblib import Parallel,delayed,effective_n_jobs
from functools import partial
from tqdm import tqdm_notebook
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry

def fix_misspellings(text):
    headers = {
        'Connection': 'close'
    }
    params={
        'text':text,
        'lang':'ru',
        'options':'512',
    }
    s = requests.Session()
    retries = Retry(total=10, backoff_factor=1, status_forcelist=[ 502, 503, 504 ])
    s.mount('http://', HTTPAdapter(max_retries=retries))
    try:
        resp=s.get('https://speller.yandex.net/services/spellservice.json/checkText',params=params,headers=headers)
        resp=resp.json()
    except ValueError:
        print(resp)
    if not resp:
        return text
    first_pos=0
    new_text=[]
    for fix in resp:
        new_text.append(text[first_pos:fix['pos']].strip())
        new_text.append(fix['s'][0].strip())
        first_pos=fix['pos']+fix['len']
    new_text.append(text[first_pos:].strip())
    return ' '.join(new_text).strip()

def process_entry(entry):
    new_entry={'text':entry['text'],'mark':entry['mark']}
    new_entry[col_name]=[]
    for elem in entry[col_name]:
        elem['text']=fix_misspellings(elem['text'].lower().replace('ё', 'е'))
        new_entry[col_name].append(elem)
    return new_entry

# Worker skills
Valya's idea: we can weigh answers based on worker skill, which can be estimated from the training part of the dataset by computing average value of some metric (accuracy, WER, etc.). Here I chose WER for simplicity

In [15]:
from collections import defaultdict
import numpy as np
def compute_worker_skills(data_table,column,metric):
    worker_stats=defaultdict(list)
    for entry in data_table:
        if entry['mark']!='TEST':
            for annot in entry[column]:
                if 'speech'!='BAD' and annot['text'] is not None:
                    worker_stats[annot['worker_id']].append(metric(entry['text'],annot['text']))
    skill_for_worker={worker_id:np.mean(stats) for worker_id,stats in worker_stats.items()}
    l=list(worker_stats.values())
    average_skill=np.mean(np.concatenate(l))
    return skill_for_worker,average_skill

In [16]:
def wer_metric(ref,hyp):
    r,e,l=calculate_wer(ref,hyp)
    return 1-r if r is not None else 1

# Fixes for the production method:
1. Use skills (defined above), divide by max for normalization (apparently works best)
2. Use more than 5 annotations (all of them, if we can); confidence threshold is adaptive to number of annotations
3. Expand question marks as "wildcard symbols": if a sentence without question marks matches a sentence with question marks, where each "?" gets replaced by 0-3 words, it gains 0.5 of a vote. We also remove "?"'s and add 0.1 to the scores of resulting sentences

In [17]:
import re

def aggregate_prod_fixed(raw_data,num_annot=10):
    for overlap in range(3,num_annot+1):
        fixed_text_speech=[elem['text'].lower().replace('ё', 'е') for elem in raw_data[:overlap]]
        worker_ids=[elem['worker_id'] for elem in raw_data[:overlap]]
        text_without_quest_mark=[elem for elem in fixed_text_speech if '?' not in elem]
        text_with_quest_mark=[elem for elem in fixed_text_speech if '?' in elem]
        texts = Counter()

        for text in text_without_quest_mark:
            texts[text]+=1
        for text_q in text_with_quest_mark:
            pattern=re.compile(text_q.replace('?','(\w+){0,3}?'))
            for text in text_without_quest_mark:
                if re.match(pattern,text):
                    texts[text]+=0.3
        for text_q in text_with_quest_mark:
            clean=' '.join(text_q.replace('?','').split())
            texts[clean]+=0.1
        text, text_rate = max(texts.items(), key=lambda x: x[1])
        if text != "" and text_rate >= 0.5*overlap:
            conf = text_rate * 1.0 / sum(texts.values())
            used = overlap
            break
    else:
        if text != "" and text_rate >= 0.375*overlap+0.1:
            conf = text_rate * 1.0 / sum(texts.values())
        else:
            text = None
            conf = 0
        used = num_annot
    return AggregationResult(text, conf, used)

def aggregate_prod_fixed_with_skills(raw_data,skills,num_annot=10):
    for overlap in range(3,num_annot+1):
        fixed_text_speech=[elem['text'].lower().replace('ё', 'е') for elem in raw_data[:overlap]]
        worker_ids=[elem['worker_id'] for elem in raw_data[:overlap]]

        per_worker,average=skills
        skills_for_workers=[per_worker.get(elem['worker_id'],average) for elem in raw_data[:overlap]]
        
        text_without_quest_mark=[(text,skill) for text,skill in zip(fixed_text_speech,skills_for_workers) if '?' not in text]
        text_with_quest_mark=[(text,skill) for text,skill in zip(fixed_text_speech,skills_for_workers) if '?' in text]
        texts = Counter()
        
        for text,skill in text_without_quest_mark:
            texts[text]+=1*skill
        for text_q,skill in text_with_quest_mark:
            pattern=re.compile(text_q.replace('?','(\w+){0,3}?'))
            for text,skill_other in text_without_quest_mark:
                if re.match(pattern,text):
                    texts[text]+=0.3*skill
        for text_q,skill in text_with_quest_mark:
            clean=' '.join(text_q.replace('?','').split())
            texts[clean]+=0.1*skill
        text, text_rate = max(texts.items(), key=lambda x: x[1])
        if text != "" and text_rate >= 0.5*overlap:
            conf = text_rate * 1.0 / sum(texts.values())
            used = overlap
            break
    else:
        if text != "" and text_rate >= 0.375*overlap:
            conf = text_rate * 1.0 / sum(texts.values())
        else:
            text = None
            conf = 0
        used = num_annot
    return AggregationResult(text, conf, used)

Small ablation study: we show impact of spelling correction, using skills, accounting for question marks, using all annotations.

In [18]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table,
                     col_name,
                     aggregate_prod)

Aggregated: 67.9862%
WER: 7.2892%
Accuracy: 78.9873%
Mean overlap: 3.833


In [19]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table,
                     col_name,
                     aggregate_prod_fixed)

Aggregated: 66.4372%
WER: 5.5237%
Accuracy: 82.8584%
Mean overlap: 5.691


^ This one actually yields the best results, which shows that spelling correction might be far from perfect

In [20]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table,
                     col_name,
                     partial(aggregate_prod_fixed,num_annot=5))

Aggregated: 64.7447%
WER: 5.6927%
Accuracy: 82.4546%
Mean overlap: 3.855


Using skills actually gives slightly worse results:

In [21]:
skill_for_worker_wer,average_skill_wer=compute_worker_skills(data_table,col_name,wer_metric)

In [22]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table,
                     col_name,
                     partial(aggregate_prod_fixed_with_skills,skills=(skill_for_worker_wer,average_skill_wer)))

Aggregated: 62.3924%
WER: 4.7223%
Accuracy: 85.4253%
Mean overlap: 6.006


In [23]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table,
                     col_name,
                     partial(aggregate_prod_fixed_with_skills,skills=(skill_for_worker_wer,average_skill_wer),num_annot=5))

Aggregated: 58.9214%
WER: 4.7108%
Accuracy: 85.8325%
Mean overlap: 3.928


# Spelling correction
We try to fix misspellings and orthographic errors by feeding annotations from Toloka to Speller API (NB: it does not always work, which explains some errors). We also normalize texts before feeding them to Counter and not after (something that could be fixed in the production method)

In [24]:
data_table_with_spell_corr=Parallel(n_jobs=effective_n_jobs(),batch_size=10)(delayed(process_entry)(entry) for entry in tqdm_notebook(data_table))

HBox(children=(IntProgress(value=0, max=7022), HTML(value='')))




In [25]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table_with_spell_corr,
                     col_name,
                     aggregate_prod_fixed)

Aggregated: 67.4412%
WER: 5.8910%
Accuracy: 81.4547%
Mean overlap: 5.591


To show the impact of transliteration, we use a very quick&dirty solution and map all words to Cyrillic characters. Nevertheless, it shows that a certain part of the metric can be explained and improved by consistency in writing (and perhaps rephrased worker/linguist task explanations)

In [26]:
import transliterate

In [27]:
def translit_to_cyr(entry):
    new_entry={'text':transliterate.translit(entry['text'],'ru'),'mark':entry['mark']}
    new_entry[col_name]=[]
    for elem in entry[col_name]:
        elem['text']=transliterate.translit(elem['text'],'ru')
        new_entry[col_name].append(elem)
    return new_entry
data_table_transl=Parallel(effective_n_jobs())(delayed(translit_to_cyr)(entry) for entry in tqdm_notebook(data_table))

HBox(children=(IntProgress(value=0, max=7022), HTML(value='')))




In [28]:
aggregated_part, wer, accuracy, cost,aggregated_incorrect,agg_incorrect_hyps=\
    evaluate_metrics(data_table_transl,
                     col_name,
                     partial(aggregate_prod_fixed,num_annot=5))

Aggregated: 65.0316%
WER: 5.6310%
Accuracy: 82.7525%
Mean overlap: 3.845


In [29]:
np.random.seed(0)
with open('errors_fixed_alg.txt','w+',encoding='utf8') as f:
    hyps_list=np.array(list(zip(aggregated_incorrect,agg_incorrect_hyps)))
    np.random.shuffle(hyps_list)
    for i,(row,hyp) in enumerate(hyps_list):
        print(i,file=f)
        print(row['text'],file=f)
        print('---',file=f)
        print(hyp,file=f)
        print('--',file=f)
        for annot in sorted(row[col_name],key=lambda x:x["submit_ts"]):
            print(annot['text'],file=f)
        print('',file=f)