In [None]:
pip install pymystem3 ruwordnet pyaspeller

In [None]:
!ruwordnet download

In [None]:
import re
import numpy as np
import pandas as pd

import pymystem3

from ruwordnet import RuWordNet

from pyaspeller import YandexSpeller

from tqdm.notebook import tqdm
# Register tqdm's pandas integration; the `progress_apply` calls in the classes below rely on it.
tqdm.pandas()

from sys import setrecursionlimit
# NOTE(review): 1000 is normally CPython's default limit, so this call is presumably a
# placeholder to be raised if the recursive hypernym traversal overflows -- confirm.
setrecursionlimit(1000)

In [None]:
class Spellchecker():
  """Base class for context-aware spell correction of a single answer.

  Subclasses are expected to provide two attributes:

  * ``speller`` -- backend object whose ``spell(text)`` produces dicts that
    carry at least the keys ``'word'`` (the flagged token) and ``'s'`` (the
    list of suggestions for it);
  * ``spell`` -- callable that maps a text to its corrected version.
  """

  speller: object
  spell: object

  def __init__(self, **kwargs):
    pass

  def _spell_in_context(self, s, w):
    """Correct ``s + ' ' + w`` and return only the words that follow the context ``s``."""
    return ' '.join(self.spell(s+' '+w).split()[len(s.split()):])

  def fix_sent(self, s, w):
    """Correct ``w`` given the preceding context ``s``; best effort.

    Args:
      s: context that precedes the answer (may be an empty string).
      w: the answer to correct.

    Returns:
      The lower-cased corrected answer, or ``w`` unchanged when the speller fails.
    """
    try:
      return self._spell_in_context(s, w).lower()
    except Exception:
      # Deliberate best effort, but only for ordinary errors: a bare ``except``
      # would also swallow KeyboardInterrupt / SystemExit.
      return w

  def fix_tuned(self, s, w):
    """Correct ``w`` conservatively, using both isolated and in-context spelling.

    ``w`` is kept when neither correction changes it, or when ``w`` itself is
    listed among the speller's suggestions. Otherwise the in-context correction
    wins over the isolated one.

    Args:
      s: context that precedes the answer (may be an empty string).
      w: the answer to correct.

    Returns:
      The chosen spelling of ``w``.
    """
    word = self.spell(w)
    sent = self._spell_in_context(s, w)

    # Suggestions for the answer checked on its own (first reported issue only).
    # ``iter`` keeps this working when the backend returns a list, not a generator.
    w_vars = next(iter(self.speller.spell(w)), {'s': []})['s']

    # Suggestions for the answer checked inside its context (last matching issue).
    s_hits = [x for x in self.speller.spell(s+' '+w) if x['word'] == w]
    s_vars = s_hits[-1]['s'] if s_hits else []

    if word == w and sent == w:
      return w
    if w in w_vars+s_vars:
      return w
    return sent if sent != w else word

class YaSpeller(Spellchecker):
  """Spellchecker whose backend is pyaspeller's ``YandexSpeller``."""

  def __init__(self, max_requests=100):
    """Create the backend and expose its ``spelled`` method as ``spell``.

    Args:
      max_requests: forwarded unchanged to ``YandexSpeller``.
    """
    backend = YandexSpeller(lang='ru', ignore_capitalization=True, max_requests=max_requests)
    self.speller = backend
    self.spell = backend.spelled

In [None]:
class PredictabilityDataProcessing:
  def __init__(self):
    """Create the RuWordNet, Mystem and YaSpeller helpers plus empty caches."""
    # External analysers.
    self.wn = RuWordNet()
    self.morph = pymystem3.Mystem(disambiguation=True)
    self.speller = YaSpeller()

    # Per-instance memoisation tables, filled lazily by the methods below.
    self.preprocess_cache = {}
    self.sem_cache = {}
    self.rel_cache = {}
    self.spell_cache = {}
  
  def analyse_grammar(self, data):
    items = data.replace(to_replace=np.nan, value='')['item']
    texts = '\n'.join([(i.strip()+' '+a).strip() if a else (i.strip()+' DUMMY').strip() for i, a in zip(items, data.answer)])+'\n'
    
    analyses = np.array(self.morph.analyze(texts))
    n_indexes = np.where((analyses == {'text': '\n'}) | (analyses == {'text': '-\n'}))[0]
    words = [analyses[i-1] for i in n_indexes]

    print('Collecting lemmas...')
    lemma = [word['analysis'][0]['lex'] if ('analysis' in word and word['analysis']) else '' for word in words]

    print('Collecting POS tags...')
    pos = [word['analysis'][0]['gr'].split('=')[0].split(',')[0] if ('analysis' in word and word['analysis']) else '' for word in words]
    
    print('Collecting grammatical features...')
    gram = []
    for word in words:
      if ('analysis' in word and word['analysis']):
        try:
          gram.append(word['analysis'][0]['gr'].split('=')[0].split(',')[1]+','+word['analysis'][0]['gr'].split('=')[1])
        except:
          gram.append(word['analysis'][0]['gr'].split('=')[1])
      else:
        gram.append('')
    
    gram = ['|'.join( [g.split('(')[0]+x.strip(')') for x in g.split('(')[1].split('|')] ) if '(' in g else g for g in gram]
    
    return {'lemma': lemma, 'pos': pos, 'gram': gram}
  
  def analyse_sem(self, lemma):
    sem = []
    
    for lem in tqdm(lemma):
      if lem:
        if lem in self.sem_cache:
          sem.append(self.sem_cache[lem])
        else:
          try:
            sem.append({chain.split('>')[0].lower(): chain.lower() for chain 
                        in self.__printer__(self.__get_hypernyms_chains__(lem)) if chain})
            self.sem_cache[lem] = sem[-1]
          except:
            sem.append('')
            self.sem_cache[lem] = ''
      else:
        sem.append('')
    
    return [s if s else '' for s in sem]

  def analyse_rel(self, lemma):
    rel = []
    
    for lem in tqdm(lemma):
      if lem:
        if lem in self.rel_cache:
          rel.append(self.rel_cache[lem])
        else:
          rel.append(self.__other_sem__(lem))
          self.rel_cache[lem] = rel[-1]
      else:
        rel.append('')
    
    return rel

  def fix_spelling(self, line):
    if line.rel:
      return line.answer
    else:
      if int(line['number'])>1:
        s = line['item'].strip()
      else:
        s = ''
      if isinstance(line.answer, str):
        if s+' '+line.answer in self.spell_cache:
          return self.spell_cache[s+' '+line.answer]
        else:
          sp = self.speller.fix_tuned(s, line.answer)
          self.spell_cache[s+' '+line.answer] = sp
          return sp
      else:
        return ''


  def process_data(self, data_to_process, process_answers=True, process_grammar=True, 
                   process_sem=True, process_rel=True, process_spelling=True):
    
    data = data_to_process.copy()

    if process_answers:
      print('Preprocessing data...')
      if 'raw_answer' not in data:
        data['raw_answer'] = data['answer']
      answers = data.progress_apply(self.preprocessing, axis=1)
      data['answer'] = answers
    
    if process_grammar:
      print('Analysing grammar...')
      grammar = self.analyse_grammar(data)
      data['lemma'] = grammar['lemma']
      data['pos'] = grammar['pos']
      data['gram'] = grammar['gram']

    if process_sem:
      print('Collecting semantic chains...')
      sems = self.analyse_sem(data['lemma'])
      data['sems'] = sems
    
    if process_rel:
      print('Collecting other semantic relations...')
      rel = self.analyse_rel(data['lemma'])
      data['rel'] = rel

    if process_spelling:
      print('Respelling...')
      spellings = data.progress_apply(self.fix_spelling, axis=1)
      data['processed_answer'] = data['answer']
      data['answer'] = spellings
      data_same = data[data.answer==data.processed_answer]
      data_corrected = data[data.answer != data.processed_answer]
      if not data_corrected.empty:
        print('Reanalysing corrected data...')
        data_corrected = self.process_data(data_corrected, process_answers=False, process_spelling=False)
        return pd.concat([data_same, data_corrected])
    
    return data

  def __get_hypernyms_chain__(self, synset):
    hypernyms = {hyper:{} for hyper in synset.hypernyms}
    if hypernyms:
      for hyper in hypernyms:
        hypernyms[hyper] = self.__get_hypernyms_chain__(hyper)
    return hypernyms
  
  def __printer__(self, razbor):
    razbor_p = []
    if not razbor:
        return ['']
    else:
        for x in razbor:
            razbor_p += [x.title+'>'+stroka for stroka in self.__printer__(razbor[x])]
    return razbor_p
  
  def __get_hypernyms_chains__(self, word):
    try:
      synsets = {homo.synset:{} for homo in self.wn[word]}
    except KeyError:
      return ''
    for synset in synsets:
      synsets[synset] = self.__get_hypernyms_chain__(synset)
    return synsets
  
  def __other_sem__(self, word):
    try:
      synsets = {homo.synset:{} for homo in self.wn[word]}
    except KeyError:
      return ''
    for synset in synsets:
      synsets[synset] = {
          'antonyms':[x.title.lower() for x in synset.antonyms],
          'domains':[x.title.lower() for x in synset.domains],
          'domain_items':[x.title.lower() for x in synset.domain_items],
          'meronyms':[x.title.lower() for x in synset.meronyms],
          'holonyms':[x.title.lower() for x in synset.holonyms],
          'premises':[x.title.lower() for x in synset.premises],
          'conclusions':[x.title.lower() for x in synset.conclusions],
          'causes':[x.title.lower() for x in synset.causes],
          'effects':[x.title.lower() for x in synset.effects],
          'related':[x.title.lower() for x in synset.related]
          }
    return {k.title.lower(): v for k, v in synsets.items()}
  
  def preprocessing(self, line):

    if isinstance(line['raw_answer'], str):

      if int(line['number'])>1:
          stimul = str(line['item'])+' '+str(line.raw_answer).lower()
      else:
          stimul = str(line.raw_answer).lower()

      if stimul in self.preprocess_cache:
        return self.preprocess_cache[stimul]

      if isinstance(line['item'], str):
        c = re.sub('[^а-яА-ЯёЁ\-\s]', '', line['item'])
        i = c.lower().split()
      else:
        c = ''
        i = ['введите', 'первое', 'слово']

      s = re.sub('[^а-яА-ЯёЁ\-\s]', '', line['raw_answer'])
      s = re.sub('\s+', ' ', s)
      s = s.strip().lower()

      if s == c:
        return ''

      if ' ' in s:
        if not sum([re.sub('ё', 'е', x)!=re.sub('ё', 'е', y) for x, y in zip(s.split(), i)]):
          s = ' '.join(s.split()[len(i):])

        if ' ' in s:
          ns = self.speller.fix_tuned(c, ' '.join(s.split()[:2]))
          if ' ' in ns:
            s = s.split()[0]
          else:
            s = ns

      self.preprocess_cache[stimul] = s.strip()

      return s.strip()
    
    else:
      return ''

In [None]:
class AccuracyTables():
  
    def __init__(self, original):

        self.original = original
    
    def process_data(self, data):

        table = pd.merge(left=data, right=self.original, how='left', on=['original', 'number'],
               suffixes=('_given', '_true'))

        table = table.replace(np.nan, '')

        table[['lex_accuracy', 'lemma_accuracy', 'pos_accuracy', 'sem_accuracy', 'gram_accuracy']] = table.progress_apply(self.get_accuracies, axis=1, result_type='expand')

        table = table.drop(columns=['item_true'])

        return table

    def get_accuracies(self, line):

        if line.answer_given == line.answer_true:
            lex = 1
        else:
            lex = 0

        if line.pos_given==line.pos_true:
            if line.pos_given!='':
                pos=1
            else:
                pos=np.nan
        else:
            pos=0

        if line.lemma_given==line.lemma_true:
            if line.lemma_given!='':
                lemma=1
            else:
                lemma=np.nan
        else:
            lemma=0

        if line.sems_true and line.sems_given:
            sems_given = [set([sem for sem in s.split('>') if sem]) for s in line.sems_given.values()]
            sems_true = [set([sem for sem in s.split('>') if sem]) for s in line.sems_true.values()]

            sem_inter = []
            for g in sems_given:
                for t in sems_true:
                    sem_inter.append(len(g&t)/len(t))
            sem = max(sem_inter)
        else:
            sem = np.nan

        if line.gram_true and line.sems_given:
            gram_given = [set([gram for gram in g.split(',') if gram]) for g in line.gram_given.split('|')]
            gram_true = [set([gram for gram in g.split(',') if gram]) for g in line.gram_true.split('|')]

            gram_inter = []
            for g in gram_given:
                for t in gram_true:
                    gram_inter.append(len(g&t)/len(t))
                gram = max(gram_inter)
        else:
            gram = np.nan

        return lex, lemma, pos, sem, gram