## Частичное обучение

In [19]:
def read_file(path, word):
    """Load labelled and unlabelled sentences for *word* from a TSV file.

    The file ``path + word + '.txt'`` is read as tab-separated UTF-8 text
    with CSV quoting disabled. Each row is classified by its column count:

    * two or more columns -> labelled example: column 0 is the label,
      column 1 the sentence (any further columns are ignored);
    * exactly one column  -> unlabelled sentence;
    * no columns (blank line) -> reported on stdout and skipped.

    Args:
        path: directory prefix; it is concatenated directly with *word*,
            so it must already end with a path separator.
        word: file stem (the target word).

    Returns:
        ``(trainList, targetList, textList)``: labelled sentences, their
        labels (parallel to ``trainList``), and unlabelled sentences.
    """
    trainList = []
    targetList = []
    textList = []
    filename = path + word + '.txt'

    with open(filename, 'r', encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter='\t', quoting=csv.QUOTE_NONE)
        for index, row in enumerate(reader):
            if len(row) > 1:
                targetList.append(row[0])
                trainList.append(row[1])
            elif row:
                textList.append(row[0])
            else:
                # Blank line: nothing usable. Report it explicitly instead of
                # relying on a bare ``except`` that could hide other errors.
                print('СМОТРИ:', row, index)
    return trainList, targetList, textList

In [2]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline

def learn_clf(model, trainList, targetList):
    """Fit a TF-IDF (unigram + bigram) -> *model* pipeline and return it.

    Args:
        model: scikit-learn estimator used as-is as the final pipeline step
            (named ``'clf'``).
        trainList: training sentences.
        targetList: labels parallel to *trainList*.

    Returns:
        The fitted ``Pipeline``.
    """
    vectorizer = TfidfVectorizer(ngram_range=(1, 2))
    # The step name is kept verbatim so ``named_steps`` lookups keep working.
    steps = [('tdidfvect', vectorizer), ('clf', model)]
    pipeline = Pipeline(steps)
    pipeline.fit(trainList, targetList)
    return pipeline

In [3]:
import numpy as np

def semi_learn(clf, predicted, trainList, targetList, textList, threshold=0.9):
    """Move confidently predicted unlabelled sentences into the training set.

    This is one self-training step. Take every sentence ``textList[i]`` whose
    highest probability in ``predicted[i]`` is at least *threshold*. Label it
    with ``clf.classes_[argmax(predicted[i])]`` and append it to
    *trainList* / *targetList*. Accepted sentences are then removed from
    *textList*.

    Args:
        clf: fitted classifier exposing ``classes_`` (the column order of
            *predicted*).
        predicted: per-sentence class probabilities; row ``i`` belongs to
            ``textList[i]`` (e.g. ``clf.predict_proba(textList)``).
        trainList: labelled sentences; extended in place.
        targetList: labels parallel to *trainList*; extended in place.
        textList: unlabelled sentences; pruned in place.
        threshold: minimum top-class probability needed to accept a
            pseudo-label. Defaults to 0.9 (the previously hard-coded value).

    Returns:
        ``(trainList, targetList, textList, count)``, where ``count`` is the
        number of sentences moved from *textList* into the training set.
    """
    accepted = set()
    for index, probs in enumerate(predicted):
        if np.max(probs) >= threshold:
            targetList.append(clf.classes_[np.argmax(probs)])
            trainList.append(textList[index])
            accepted.add(index)

    # Filter by position rather than calling ``list.remove`` once per accepted
    # sentence: O(n) instead of O(n^2). Slice assignment keeps the caller's
    # list object updated in place, as before.
    remaining = [s for i, s in enumerate(textList) if i not in accepted]
    count = len(textList) - len(remaining)
    textList[:] = remaining

    return trainList, targetList, textList, count

In [4]:
import csv

def write_learn_result(path, word, trainList, targetList):
    """Write the (label, sentence) training pairs for *word* to a file.

    The output file is ``path + word + '.csv'``, and each pair becomes one
    line ``<label><TAB><sentence>``. The line is emitted as a single CSV field
    with quoting disabled. Characters the ``csv`` writer treats as special
    (e.g. the default ``,`` delimiter) are therefore escaped with a backslash.

    Args:
        path: directory prefix; must already end with a path separator.
        word: file stem.
        trainList: sentences.
        targetList: labels, parallel to *trainList*.
    """
    outputName = path + word + '.csv'

    with open(outputName, 'w', encoding='utf-8', newline='') as myfile:
        wr = csv.writer(myfile, quoting=csv.QUOTE_NONE, escapechar='\\')
        # Iterate over every pair. The former
        # ``range(0, len(trainList) - 1)`` silently dropped the last one.
        for sentence, label in zip(trainList, targetList):
            wr.writerow([label + '\t' + sentence])

In [21]:
import os
from collections import Counter

def result_func(dataset, mode, wordList, estimator=None):
    """Self-train a classifier per word and save the enlarged training sets.

    For every word in *wordList*:

    1. Labelled and unlabelled sentences are loaded with ``read_file`` from
       ``Input/marked txt/<dataset>(<mode>)/<word>.txt``.
    2. A classifier is fitted (``learn_clf``) and used to pseudo-label
       unlabelled sentences (``semi_learn``). This repeats until an iteration
       adds nothing or no unlabelled sentences are left.
    3. The resulting (label, sentence) pairs are written with
       ``write_learn_result`` into
       ``Input/full txt/<dataset>(<mode>)/<EstimatorClassName>/``.

    If fitting or prediction raises, the error is printed and nothing is
    written for that word. Processing continues with the next word.

    Args:
        dataset: dataset name used in directory names (e.g. ``'bts-rnc'``).
        mode: split name used in directory names (e.g. ``'test'``).
        wordList: words (file stems) to process.
        estimator: classifier handed to ``learn_clf``. Defaults to the
            module-level ``model``, so existing calls keep working.
    """
    if estimator is None:
        # Backward compatibility: the original read the global ``model``.
        estimator = model

    # os.path.join uses the platform separator (the old hard-coded '\\' only
    # worked on Windows). The trailing '' keeps the final separator that
    # read_file / write_learn_result expect at the end of ``path``.
    subdir = dataset + '(' + mode + ')'
    path = os.path.join('Input', 'marked txt', subdir, '')
    savepath = os.path.join('Input', 'full txt', subdir,
                            type(estimator).__name__, '')

    for word in wordList:
        trainList, targetList, textList = read_file(path, word)
        print(word, len(targetList), Counter(targetList))

        failed = False
        count = -1  # sentinel so at least one iteration runs
        while count != 0 and textList:
            try:
                clf = learn_clf(estimator, trainList, targetList)
                predicted = clf.predict_proba(textList)
                trainList, targetList, textList, count = semi_learn(
                    clf, predicted, trainList, targetList, textList)
            except Exception as e:
                # Batch boundary: report and skip this word instead of
                # aborting the remaining words.
                print('Error', word)
                print(getattr(e, 'message', e))
                failed = True
                break

        if failed:
            continue

        os.makedirs(savepath, exist_ok=True)
        write_learn_result(savepath, word, trainList, targetList)
        print('-----')

    print('Finish')

# Начало работы

In [17]:
# Alternative word list (another set of target words), kept for reference:
#   балка вид винт горн губа жаба клетка крыло купюра курица
#   лавка лайка лев лира мина мишень обед оклад опушка полис
#   пост поток проказа пропасть проспект пытка рысь среда хвост штамп

# Target words to process; each one is a file stem under the dataset folder.
wordList = '''
    акция баба байка бум бычок вал газ гвоздика гипербола
    град гусеница дождь домино забой икра кабачок капот
    карьер кличка ключ кок кольцо концерт котелок крона
    круп кулак лейка лук мандарин ножка опора патрон
    печать пол полоз почерк пробка рак рок свет
    секрет скат слог стан стопка таз такса тюрьма
    шах шашка
'''.split()

# Dataset name and split; together they form the '<dataset>(<mode>)' folder.
dataset = 'bts-rnc'
mode = 'test'

In [22]:
# Classifier used for self-training. Other estimators tried previously:
#   sklearn.naive_bayes.MultinomialNB(alpha=2.0, fit_prior=True)
#   sklearn.naive_bayes.BernoulliNB()
#   sklearn.ensemble.GradientBoostingClassifier()
from sklearn.neighbors import KNeighborsClassifier

model = KNeighborsClassifier(weights='uniform')

# Run the whole pipeline for every configured word.
result_func(dataset, mode, wordList)

акция 134 Counter({'29853': 69, '28176': 38, '15738': 27})
-----
баба 111 Counter({'29090': 42, '38405': 24, '21130': 23, '1': 13, '0': 9})
-----
байка 79 Counter({'16141': 75, '39858': 4})
-----
бум 56 Counter({'41843': 33, '32940': 19, '18362': 4})
-----
бычок 66 Counter({'27009': 34, '0': 28, '26270': 4})
-----
вал 109 Counter({'33648': 34, '0': 27, '41024': 25, '1': 23})
-----
газ 61 Counter({'23414': 28, '17569': 21, '16756': 12})
-----
гвоздика 75 Counter({'31219': 35, '0': 22, '26662': 18})
-----
гипербола 38 Counter({'1': 24, '0': 14})
-----
град 82 Counter({'13861': 30, '29527': 21, '0': 16, '35134': 15})
-----
гусеница 43 Counter({'19345': 24, '21860': 19})
-----
дождь 56 Counter({'40011': 35, '22422': 21})
-----
домино 91 Counter({'38622': 37, '13675': 30, '1': 24})
-----
забой 44 Counter({'36412': 29, '15050': 15})
-----
икра 74 Counter({'19490': 27, '28149': 24, '15898': 23})
-----
кабачок 60 Counter({'0': 30, '37286': 30})
-----
капот 64 Counter({'15899': 35, '13120': 29}