In [1]:
import os
import time
from typing import List
import numpy as np
import pandas as pd
import pymorphy2
import matplotlib.pyplot as plt
%matplotlib inline
from nltk.tokenize import word_tokenize, sent_tokenize
import nltk
nltk.download("stopwords")
from nltk.corpus import stopwords
from pymystem3 import Mystem
from string import punctuation
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from catboost import CatBoost, CatBoostRegressor, Pool

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/aleksandr.khvorov/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
mystem = Mystem() 
russian_stopwords = stopwords.words("russian")
morph = pymorphy2.MorphAnalyzer()

In [8]:
def preprocess(text):
    tokens = mystem.lemmatize(text.lower())
    tokens = [token for token in tokens if token not in russian_stopwords\
              and token != " " \
              and token.strip() not in punctuation]
    
    return " ".join(tokens)

def get_data():
    data = []
    q_ids = []
    with open("train_qa.csv") as f:
        for line in list(f)[1:]:
            tokens = line.strip().split('","')
            data.append(tuple(map(lambda x: list(filter(lambda w: w != " " and w.strip() not in punctuation, 
                                            word_tokenize(x.strip(' "')))), tokens[2:])))
            q_ids.append(int(tokens[1]))
    return data, q_ids
            
data, q_ids = get_data()

In [9]:
len(data)

50364

In [10]:
# data_train, data_test, _, _ = train_test_split(data, [0] * len(data), train_size=0.7)

In [11]:
from collections import Counter
Counter(sorted([len(x[2]) for x in data if len(x[2]) < 10]))
# np.mean(sorted([len(x[2]) for x in data if len(x[2]) < 10]))

Counter({1: 8223,
         2: 11170,
         3: 10785,
         4: 7057,
         5: 4495,
         6: 2939,
         7: 2131,
         8: 1389,
         9: 764})

In [None]:
# simple solution

In [21]:
def avg_position_of_question(text, question):
    ind_sum = 0
    ind_count = 0
    inds = []
    for word in question:
#         print(word)
        ind = 0
        while word in text[ind:]:
#             print(ind, word)
            ind = text.index(word, ind)
            inds.append(ind)
            ind_sum += ind
            ind_count += 1
            ind += 1
#     return ind_sum // ind_count
    if len(inds) == 0:
        return 0
    return inds[len(inds) // 2]

def window_around_question_with_width(text, question, width=3):
    position = avg_position_of_question(text, question)
    start_width = len(question) // 2
#     print(text)
#     print(question)
#     print("Position:", position, "Start width:", start_width)
    for i in range(1, 4 * width):
        first = text[max(0, position - start_width - i) : max(0, position - start_width)]
        second = text[min(position + start_width, len(text)) : min(position + start_width + i, len(text))]
        sent = first + second
#         print(start_width, i)
#         print("First", first)
#         print(second)
        ans = []
        for w in sent:
            if w not in question and w not in punctuation:
                ans.append(w)
#         print(ans)
#         print()
        if len(ans) >= width:
            return ans
        
# window_around_question_with_width(data[0][0], data[0][1])
for i in range(5):
#     print(data[i][0])
#     print(data[i][1])
    print(data[i][2])
    print(window_around_question_with_width(data[i][0], data[i][1]))
    print()

['в', 'Древнем', 'Египте']
['в', 'Древнем', 'Египте']

['COSTAR']
['Телескоп', 'имеет', 'модульную']

['теория', 'дрейфа', 'материков']
['Альфреда', 'маргинальной', 'науки', 'и']

['изделиям', 'из', 'монолитных', 'камней']
['различных', 'фракций', 'для']

['оральные', 'и', 'назальные']
['встречаются', 'в', 'слабым', 'гласным']



In [22]:
def process_file():
    data = []
    w = open("ans.txt", 'w')
    with open("dataset_281937_1.txt") as f:
        for line in list(f)[1:]:
            tokens = line.strip().split('\t')
            text, q = map(lambda x: word_tokenize(x.strip(' "')), tokens[2:])
            ans = window_around_question_with_width(text, q) or []
            w.write(tokens[1] + "\t" + " ".join(ans) + "\n")
            
process_file()

In [12]:
texts, qs, anss = {}, {}, {}
norm_words = {}

In [13]:
def update_dicts(data, q_ids):
    for d, q_id in zip(data, q_ids):
        if len(d) == 3:
            text, q, ans = d
        else:
            text, q = d
        texts[q_id] = text
        qs[q_id] = q
        for word in text + q:
            if word not in norm_words:
                w = morph.parse('стали')[0].normal_form
                norm_words[word] = w
        if len(d) == 3:
            anss[q_id] = ans

In [14]:
update_dicts(data, q_ids)

In [15]:
class Point:
    def __init__(self, q_id, sample_ind, sample_len):
        self.q_id = q_id
        self.sample_ind = sample_ind
        self.sample_len = sample_len
        self._answer = None
        self._norm_text = None
        self._norm_q = None
        self._norm_ans = None
    
    def text(self):
        return texts[self.q_id]
    
    def question(self):
        return qs[self.q_id]
        
    def answer(self):
        if self._answer is None:
            self._answer = self.text()[self.sample_ind : self.sample_ind + self.sample_len]
        return self._answer
    
    def norm_text(self):
        if self._norm_text is None:
            self._norm_text = [norm_words[w] for w in self.text()]
        return self._norm_text
    
    def norm_q(self):
        if self._norm_q is None:
            self._norm_q = [norm_words[w] for w in self.question()]
        return self._norm_q
    
    def norm_ans(self):
        if self._norm_ans is None:
            self._norm_ans = [norm_words[w] for w in self.answer()]
        return self._norm_ans
        

def samples_from_example(text, question, answer, q_id, ranges=[2, 3]):
    samples = []
    ans = []
    true_ind = 0
    true_ans_len = len(answer)
    for i in range(len(text) - true_ans_len):
        if text[i : i + true_ans_len] == answer:
            true_ind = i
            break
    for ans_len in ranges:
        for i in range(len(text) - ans_len):
            samples.append(Point(q_id, i, ans_len))
            
            def target1():
                return min(abs(true_ind - i), true_ans_len) / max(ans_len, true_ans_len)
            
            def target2():
                l = max(ans_len, true_ans_len)
                intersect_len = max(0, (abs(true_ind - i) - l))
                return intersect_len / l
            
            def target3():
                return len(set(answer) & set(text[i : i + ans_len])) / true_ans_len
            
            def target4():
                ans = text[i : i + ans_len]
                precision = len([for w in answer if w in ans]) / ans_len
                recall = len([for w in ans if w in answer]) / true_ans_len
                return 2 * precision * recall / (precision + recall)
            
            ans.append(target4())
    return samples, ans

def create_dataset(data, q_ids, ranges=[2, 3]):
    X, y = [], []
    new_q_ids = []
    for p, q_id in zip(data, q_ids):
        samples, ans = samples_from_example(p[0], p[1], p[2], q_id, ranges)
        X += samples
        y += ans
        new_q_ids += [q_id] * len(ans)
    return X, y, new_q_ids

In [16]:
# text, question, sample_ind, sample_len

def intersected_words_num(point):
    intersect_words_num = len(set(point.question()) & set(point.answer()))
    return [intersect_words_num, intersect_words_num / len(point.question()), intersect_words_num / point.sample_len]

def intersected_norm_words_num(point):
    intersect_words_num = len(set(point.norm_q()) & set(point.norm_ans()))
    return [intersect_words_num, intersect_words_num / len(point.question()), intersect_words_num / point.sample_len]
    
def stop_words_in_answer_ratio(point):
    stop_num = len([w for w in point.norm_ans() if w in russian_stopwords])
    return [stop_num, stop_num / point.sample_len]

def query_len(point):
    return [len(point.question())]

def answer_len(point):
    return [point.sample_len]

def position_dist(p):
    q_pos = avg_position_of_question(p.norm_text(), p.norm_q())
    return [abs(q_pos - p.sample_ind)]

def answer_punctuation(p):
    n = len([w for w in p.answer() if w in punctuation])
    return [n, n / p.sample_len]

In [17]:
def features(points):
    features_list = [intersected_words_num, intersected_norm_words_num, stop_words_in_answer_ratio,
                     query_len, answer_len, position_dist, answer_punctuation]
    data = []
    for p in points:
        x = []
        for f in features_list:
            x += f(p)
        data.append(x)
    return data

In [18]:
data_train, data_test, q_id_tr, q_id_te = train_test_split(data, q_ids, train_size=0.2)
ranges = [3]
points_tr, y_tr, q_id_tr = create_dataset(data_train, q_id_tr, ranges=ranges)
# points_te, y_te, q_id_te = create_dataset(data_test, q_id_te, ranges=ranges)

In [19]:
# print(data[0])
# pss, yss, qss = create_dataset([data[0]], [q_ids[0]], ranges=[3])
# for i in range(len(pss)):
#     print(pss[i].answer())
#     print(yss[i])
#     print()

In [23]:
X_tr = features(points_tr)
# X_te = features(points_te)

In [None]:
len(X_tr)

In [24]:
train = Pool(
    data=X_tr,
    label=y_tr,
    group_id=q_id_tr
)

# test = Pool(
#     data=X_te,
#     label=y_te,
#     group_id=q_id_te
# )

In [34]:
model = CatBoost({'loss_function': 'PairLogit:max_pairs=1000', 'iterations': 200, 'metric_period': 25, 'random_seed': 0})
model.fit(train, verbose=1)

0:	learn: 0.6904709	total: 542ms	remaining: 1m 47s
25:	learn: 0.6572281	total: 15.4s	remaining: 1m 42s
50:	learn: 0.6501086	total: 31.6s	remaining: 1m 32s
75:	learn: 0.6482070	total: 47.2s	remaining: 1m 16s
100:	learn: 0.6473885	total: 1m 4s	remaining: 1m 2s
125:	learn: 0.6470456	total: 1m 19s	remaining: 46.9s
150:	learn: 0.6467218	total: 1m 36s	remaining: 31.3s
175:	learn: 0.6465335	total: 1m 52s	remaining: 15.3s
199:	learn: 0.6463876	total: 2m 7s	remaining: 0us


<catboost.core.CatBoost at 0x1a93c1d8d0>

In [26]:
model_reg = CatBoostRegressor(iterations=200, metric_period=25, random_seed=0)
model_reg.fit(train, verbose=1)

0:	learn: 0.1502644	total: 121ms	remaining: 24s
25:	learn: 0.1490274	total: 1.79s	remaining: 12s
50:	learn: 0.1487333	total: 3.3s	remaining: 9.64s
75:	learn: 0.1486537	total: 4.72s	remaining: 7.71s
100:	learn: 0.1486254	total: 6.14s	remaining: 6.02s
125:	learn: 0.1486081	total: 7.53s	remaining: 4.42s
150:	learn: 0.1485968	total: 8.9s	remaining: 2.89s
175:	learn: 0.1485853	total: 10.3s	remaining: 1.41s
199:	learn: 0.1485757	total: 11.7s	remaining: 0us


<catboost.core.CatBoostRegressor at 0x1a2df82a90>

In [27]:
model_reg.get_feature_importance(train)

array([32.06656718,  8.4165228 , 42.65196945,  0.        ,  3.34250813,
        0.        ,  0.        ,  0.        ,  3.02541846,  0.        ,
       10.49701398,  0.        ,  0.        ])

In [28]:
def read_eval_file():
    data = []
    q_ids = []
    with open("dataset_281937_1.txt") as f:
        for line in list(f)[1:]:
            tokens = line.strip().split('\t')
            text, q = map(lambda x: word_tokenize(x.strip(' "')), tokens[2:])
            data.append((text, q, ['привет']))
            q_id = int(tokens[1])
            q_ids.append(q_id)
    return data, q_ids
            
eval_data, eval_q_ids = read_eval_file()
update_dicts(eval_data, eval_q_ids)

In [29]:
points_eval, y_eval, q_id_eval = create_dataset(eval_data, eval_q_ids, ranges=ranges)
X_eval = features(points_eval)
eval_pool = Pool(data=X_eval, group_id=q_id_eval)

In [30]:
y_pred = model_reg.predict(eval_pool)

In [31]:
def ans_from_groups(y_pred, q_id_eval):
    q_id_ans = {}
    for i, (y, q_id) in enumerate(zip(y_pred, q_id_eval)):
        if q_id not in q_id_ans:
            q_id_ans[q_id] = i
        if y > y_pred[q_id_ans[q_id]]:
            q_id_ans[q_id] = i
    return q_id_ans

def write_ans(points, q_id_ans, eval_q_ids):
    w = open("ans.txt", 'w')
    for q_id in eval_q_ids:
        p = points[q_id_ans[q_id]]
        ans = p.answer()
        w.write(str(q_id) + "\t" + " ".join(ans) + "\n")
        
write_ans(points_eval, ans_from_groups(y_pred, q_id_eval), eval_q_ids)