In [None]:
!pip install pyonmttok fasttext

In [None]:
!git clone https://github.com/facebookresearch/fastText.git
!cd fastText && mkdir build && cd build && cmake .. && make && make install

In [None]:
# Download and unpack the training documents. ru_tg_train.json is what
# convert_to_ft later reads to restore original urls; presumably it is the
# archive's content (it is removed before extraction so a stale copy cannot
# survive). Archive and old outputs are deleted so a re-run starts clean.
!rm -f ru_tg_train.tar.gz
!wget https://www.dropbox.com/s/1ecl9orr2tagcgi/ru_tg_train.tar.gz
!rm -f ru_tg_train.json
!tar -xzvf ru_tg_train.tar.gz
!rm ru_tg_train.tar.gz

In [None]:
# Same as the training download, for the test documents (ru_tg_test.json is
# read by convert_to_ft for the test markup). Stale files are removed first
# so a re-run starts clean.
!rm -f ru_tg_test.tar.gz
!wget https://www.dropbox.com/s/gvfk6t4g7kxw9ae/ru_tg_test.tar.gz
!rm -f ru_tg_test.json
!tar -xzvf ru_tg_test.tar.gz
!rm ru_tg_test.tar.gz

In [None]:
!wget https://www.dropbox.com/s/amua7p1rt1dcvy0/ru_cat_train_raw_markup.tsv
!wget https://www.dropbox.com/s/xia50d1h28e87x4/ru_cat_test_raw_markup.tsv
!head -n 2 ru_cat_train_raw_markup.tsv

In [None]:
import pyonmttok
# Module-level tokenizer used by `preprocess`.
# NOTE(review): "conservative" mode with joiner_annotate=False is presumably
# chosen so tokens come out as plain words without joiner markers; confirm
# against the pyonmttok documentation.
tokenizer = pyonmttok.Tokenizer("conservative", joiner_annotate=False)

def preprocess(text):
    """Return ``text`` as lower-cased, space-separated tokens.

    The value is converted with ``str``, trimmed, has newlines and
    non-breaking spaces turned into ordinary spaces, is lower-cased, and is
    then split by the module-level ``tokenizer``.
    """
    normalized = str(text).strip()
    for char in ("\n", "\xa0"):
        normalized = normalized.replace(char, " ")
    tokens, _features = tokenizer.tokenize(normalized.lower())
    return " ".join(tokens)

In [None]:
import json
from collections import Counter
from sklearn.metrics import cohen_kappa_score

# `random` and `defaultdict` are used below; import them here so this cell
# does not depend on other cells having been run first.
import random
from collections import defaultdict
from itertools import combinations


def normalize(text):
    """Canonicalise a title so it can be used as a lookup key.

    Tabs, newlines and non-breaking spaces become plain spaces and double
    quotes are dropped, so a title from the TSV markup matches the same title
    from the original JSON dump.
    """
    return text.replace("\t", " ").replace("\n", " ").replace('"', '').replace("\xa0", " ")


def _read_markup_tsv(path):
    """Load a tab-separated markup export as a list of ``{column: value}`` dicts.

    The first line is the header. For data rows only the line terminator is
    removed, because ``str.strip()`` would also eat the tabs around empty
    leading/trailing columns and break the column count. Blank lines are skipped.

    Raises:
        ValueError: if the file is empty or a row has the wrong number of columns.
    """
    with open(path, "r", encoding="utf-8") as f:
        header_line = f.readline()
        if not header_line:
            raise ValueError("{}: empty file, expected a header line".format(path))
        header = header_line.strip().split("\t")
        records = []
        for line_no, line in enumerate(f, start=2):
            line = line.rstrip("\r\n")
            if not line:
                continue
            fields = line.split("\t")
            if len(fields) != len(header):
                raise ValueError("{}:{}: expected {} columns, got {}: {!r}".format(
                    path, line_no, len(header), len(fields), fields))
            records.append(dict(zip(header, fields)))
    return records


def _restore_urls(records, original_json):
    """Overwrite each record's ``url`` with the url of the same-titled document.

    Works around bad urls in the markup export. ``original_json`` is a JSON
    list of objects with ``title`` and ``url`` keys. Records whose normalised
    title is not found keep their current url.
    """
    with open(original_json, "r", encoding="utf-8") as f:
        documents = json.load(f)
    title2url = {normalize(d["title"]): d["url"] for d in documents}
    for record in records:
        title = normalize(record["title"])
        if title in title2url:
            record["url"] = title2url[title]


def _average_pairwise_kappa(records):
    """Mean Cohen's kappa over annotator pairs that labelled at least one common url.

    Each unordered pair is scored once; kappa is symmetric, so the mean equals
    the mean over ordered pairs. Undefined scores (NaN, which sklearn can
    return for degenerate inputs) are skipped. Returns None if no pair overlaps.
    """
    labels_by_worker = defaultdict(dict)
    for record in records:
        labels_by_worker[record["worker_id"]][record["url"]] = record["res"]
    scores = []
    for first, second in combinations(labels_by_worker.values(), 2):
        shared_urls = sorted(first.keys() & second.keys())
        if not shared_urls:
            continue
        score = cohen_kappa_score([first[u] for u in shared_urls],
                                  [second[u] for u in shared_urls])
        # The chained comparison is False for NaN, so only defined scores survive.
        if -1.0 <= score <= 1.0:
            scores.append(score)
    return sum(scores) / len(scores) if scores else None


def _majority_vote(records, min_votes):
    """Return ``{url: record}`` for urls whose most popular label got >= min_votes votes.

    Each value is a copy of the url's last record with ``res`` replaced by the
    winning label. Ties go to the label seen first (Counter.most_common order).
    """
    votes = defaultdict(list)
    last_record = {}
    for record in records:
        votes[record["url"]].append(record["res"])
        last_record[record["url"]] = record
    aggregated = {}
    for url, labels in votes.items():
        label, count = Counter(labels).most_common(1)[0]
        if count >= min_votes:
            # The last annotator's own answer may disagree with the majority,
            # so the label must come from the vote, not from the record.
            aggregated[url] = dict(last_record[url], res=label)
    return aggregated


def convert_to_ft(answers_file_name, original_json, output_file_name, min_votes=3,
                  use_preprocess=True, seed=None):
    """Turn crowd-sourced rubric markup into a fastText supervised training file.

    Steps: read the TSV export, drop honeypot (golden) tasks, strip the
    ``prefix:`` part of column names, restore urls from ``original_json``,
    print the average inter-annotator kappa, keep urls whose most popular
    label has at least ``min_votes`` votes, print the per-label counts, and
    write the survivors in random order as ``__label__<res> <title> <text>``.

    Args:
        answers_file_name: tab-separated markup export with a header row.
        original_json: JSON list of documents with ``title`` and ``url`` keys.
        output_file_name: destination fastText file (overwritten).
        min_votes: minimum number of votes the winning label needs.
        use_preprocess: run ``preprocess`` on title and text before writing.
        seed: if given, shuffle with ``random.Random(seed)`` for a reproducible
            order; otherwise the global ``random`` state is used.
    """
    records = _read_markup_tsv(answers_file_name)

    # Filter honeypots out: rows with a known answer in GOLDEN:res are
    # quality checks for annotators, not data.
    records = [r for r in records if not r.get("GOLDEN:res")]

    # "<prefix>:title" -> "title", etc. Build new dicts instead of renaming
    # keys while iterating over the same dict.
    records = [
        {key.split(":")[-1]: value
         for key, value in r.items() if key not in ("GOLDEN:res", "HINT:text")}
        for r in records
    ]

    _restore_urls(records, original_json)

    avg_kappa = _average_pairwise_kappa(records)
    if avg_kappa is None:
        print("Avg kappa score: n/a (no annotator pairs share an item)")
    else:
        print("Avg kappa score: {}".format(avg_kappa))

    data = _majority_vote(records, min_votes)
    rub_cnt = Counter(d["res"] for d in data.values())
    print(rub_cnt.most_common())

    examples = list(data.values())
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(examples)
    with open(output_file_name, "w", encoding="utf-8") as out:
        for d in examples:
            title = preprocess(d["title"]) if use_preprocess else d["title"]
            text = preprocess(d["text"]) if use_preprocess else d["text"]
            out.write("__label__{} {} {}\n".format(d["res"], title, text))

# Build the fastText-format train and test files from the crowd markup.
# The test split requires more agreeing votes (4) than the train split (2).
convert_to_ft(
    answers_file_name="ru_cat_train_raw_markup.tsv",
    original_json="ru_tg_train.json",
    output_file_name="ru_cat_train_markup.txt",
    min_votes=2,
    use_preprocess=True,
)
convert_to_ft(
    answers_file_name="ru_cat_test_raw_markup.tsv",
    original_json="ru_tg_test.json",
    output_file_name="ru_cat_test_markup.txt",
    min_votes=4,
    use_preprocess=True,
)

In [None]:
!cat ru_cat_train_markup.txt | wc -l
!cat ru_cat_test_markup.txt | wc -l

In [None]:
# Download the Lenta.ru news dump (gzip-compressed CSV) and decompress it to
# lenta-ru-news.csv, which parse_lenta reads. Old copies are removed first so
# neither wget nor gzip trips over existing files on a re-run.
!rm -f lenta-ru-news.csv.gz
!wget https://github.com/yutkin/Lenta.Ru-News-Dataset/releases/download/v1.0/lenta-ru-news.csv.gz
!rm -f lenta-ru-news.csv
!gzip -d lenta-ru-news.csv.gz

In [None]:
import csv
import random
from collections import Counter

# Lenta.ru rubric -> our label. Keys are either a bare topic (matches every
# tag of that topic) or a (topic, tag) pair; bare topics are checked first.
LENTA_TOPICS_MAPPING = {
    "Экономика": "economy",
    "Спорт": "sports",
    "Силовые структуры": "society",
    "Бизнес": "economy",
    "Культпросвет": "entertainment",
    ("Наука и техника", "Игры"): "entertainment",
    ("Наука и техника", "Наука"): "science",
    ("Наука и техника", "Космос"): "science",
    ("Наука и техника", "Жизнь"): "science",
    ("Наука и техника", "История"): "science",
    ("Наука и техника", "Оружие"): "society",
    ("Наука и техника", "Гаджеты"): "technology",
    ("Наука и техника", "Софт"): "technology",
    ("Наука и техника", "Техника"): "technology",
    ("Мир", "Общество"): "society",
    ("Мир", "Политика"): "society",
    ("Мир", "Происшествия"): "society",
    ("Мир", "Конфликты"): "society",
    ("Мир", "Преступность"): "society",
    ("Россия", "Политика"): "society",
    ("Россия", "Общество"): "society",
    ("Россия", "Происшествия"): "society",
    ("Интернет и СМИ", "Мемы"): "technology",
    ("Интернет и СМИ", "Киберпреступность"): "technology",
    ("Интернет и СМИ", "Интернет"): "technology",
    ("Интернет и СМИ", "Вирусные ролики"): "technology",
    ("Ценности", "Стиль"): "other",
    ("Ценности", "Явления"): "other",
    ("Ценности", "Внешний вид"): "other",
    ("Ценности", "Движение"): "technology",
    ("Из жизни", "Происшествия"): "society",
    ("Путешествия", "Происшествия"): "society",
}

# Default fraction of each label's articles that is written out.
LENTA_DEFAULT_SAMPLE_RATE = 0.02


def _lenta_label(topic, tag):
    """Return our label for a Lenta.ru (topic, tag) pair, or None if unmapped."""
    if topic in LENTA_TOPICS_MAPPING:
        return LENTA_TOPICS_MAPPING[topic]
    return LENTA_TOPICS_MAPPING.get((topic, tag))


def _single_line(text):
    """Replace line breaks so a raw field cannot split a fastText example across lines."""
    return text.replace("\r", " ").replace("\n", " ")


def parse_lenta(input_file, output_file, use_preprocess=True, parts=None, seed=None):
    """Sample Lenta.ru news into a fastText training file with our rubric labels.

    Each CSV row is (url, title, text, topic, tags). Rows whose topic, or
    (topic, tag) pair, is in LENTA_TOPICS_MAPPING get that label; other rows
    are skipped. Each labelled row is kept with probability ``parts[label]``
    and written as ``__label__<label> <title> <text>``. The file is streamed,
    so the full corpus is never held in memory. Afterwards the number of
    labelled rows and the per-label counts are printed.

    Args:
        input_file: Lenta.ru news CSV with a header row.
        output_file: destination fastText file (overwritten).
        use_preprocess: run ``preprocess`` on title and text; otherwise only
            line breaks are replaced so each example stays on one line.
        parts: optional ``{label: keep_probability}``; defaults to
            LENTA_DEFAULT_SAMPLE_RATE for every mapped label. Labels missing
            from it are never written.
        seed: if given, sample with ``random.Random(seed)`` for a reproducible
            subset; otherwise the global ``random`` state is used.
    """
    if parts is None:
        parts = {label: LENTA_DEFAULT_SAMPLE_RATE for label in LENTA_TOPICS_MAPPING.values()}
    rng = random if seed is None else random.Random(seed)
    label_counts = Counter()
    # newline="" lets the csv module handle line breaks inside quoted fields.
    with open(input_file, "r", encoding="utf-8", newline="") as src, \
            open(output_file, "w", encoding="utf-8") as dst:
        reader = csv.reader(src, delimiter=",")
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:
                continue
            _url, title, text, topic, tag = row
            label = _lenta_label(topic.strip(), tag.strip())
            if label is None:
                continue
            label_counts[label] += 1
            if rng.random() > parts.get(label, 0.0):
                continue
            if use_preprocess:
                title, text = preprocess(title), preprocess(text)
            else:
                title, text = _single_line(title), _single_line(text)
            dst.write("__label__{} {} {}\n".format(label, title, text))
    print(sum(label_counts.values()))
    print(label_counts.most_common())

parse_lenta("lenta-ru-news.csv", "lenta_markup.txt")
!cat lenta_markup.txt | wc -l

In [None]:
# Download the not-news corpus; the next cell rewrites it as fastText examples.
# Any old copy is removed first so wget does not save ru_not_news.txt.1.
!rm -f ru_not_news.txt
!wget https://www.dropbox.com/s/wwptzqhgxvtjhbd/ru_not_news.txt

In [None]:
# Rewrite the not-news corpus as fastText examples labelled "not_news".
# The first space-separated token of each source line is discarded
# (NOTE(review): presumably an existing label/id column; confirm).
# UTF-8 is explicit because the text is Russian and the platform default
# encoding is not guaranteed to be UTF-8.
with open("ru_not_news.txt", "r", encoding="utf-8") as src, \
        open("ru_not_news_fixed.txt", "w", encoding="utf-8") as dst:
    for line in src:
        words = line.strip().split(" ")
        text = preprocess(" ".join(words[1:]))
        # A line with nothing after its first token would become an empty example.
        if not text:
            continue
        dst.write("__label__{} {}\n".format("not_news", text))

In [None]:
!wget https://www.dropbox.com/s/2nx97d8nzbzusee/ru_vectors_v2.bin

In [None]:
!wget https://raw.githubusercontent.com/facebookresearch/fastText/master/python/doc/examples/bin_to_vec.py
!python bin_to_vec.py ru_vectors_v2.bin > ru_vectors_v2.vec

In [None]:
!cat ru_cat_train_markup.txt > ru_cat_train_all.txt
!cat lenta_markup.txt >> ru_cat_train_all.txt
!cat ru_not_news_fixed.txt >> ru_cat_train_all.txt
!shuf ru_cat_train_all.txt > ru_cat_train_shuf.txt

In [None]:
import random
with open("ru_cat_train_shuf.txt", "r") as r, open("ru_cat_train_train.txt", "w") as train, open("ru_cat_train_val.txt", "w") as val:
    for line in r:
        if random.random() < 0.1:
            val.write(line)
        else:
            train.write(line)
!cat ru_cat_train_val.txt | wc -l

In [None]:
# Train the rubric classifier with fastText autotune. Hyperparameters are
# searched against the held-out ru_cat_train_val.txt, and the size limit
# (-autotune-modelsize 10M) makes fastText save a quantized model, ru_cat.ftz,
# which the cells below load. -dim must match the pretrained vectors.
# NOTE(review): ru_vectors_v2.vec is assumed to be 50-dimensional; confirm.
!fasttext supervised -input ru_cat_train_train.txt -pretrainedVectors ru_vectors_v2.vec -dim 50 -autotune-validation ru_cat_train_val.txt -output ru_cat -autotune-modelsize 10M

In [None]:
# Evaluate the trained model on the held-out test markup (fastText prints the
# number of examples, precision@1 and recall@1).
!fasttext test ru_cat.ftz ru_cat_test_markup.txt

In [None]:
import fasttext

# fastText marks labels with this prefix; it is 9 characters long.
LABEL_PREFIX = "__label__"

model = fasttext.load_model("ru_cat.ftz")

# Each test line is "__label__<label> <text...>".
with open("ru_cat_test_markup.txt", "r") as test_file:
    test_rows = [line.strip().split(" ") for line in test_file]

true_labels = [words[0][len(LABEL_PREFIX):] for words in test_rows]
test_texts = [" ".join(words[1:]) for words in test_rows]

# One batched predict call; element i holds the label tuple for test_texts[i].
batch_labels = model.predict(test_texts)[0] if test_texts else []
predicted_labels = [labels[0][len(LABEL_PREFIX):] for labels in batch_labels]

# Misclassified examples as (true, predicted, first 100 chars of text).
errors = [
    (label, predicted, text[:100])
    for label, predicted, text in zip(true_labels, predicted_labels, test_texts)
    if label != predicted
]
for label, predicted, text in errors:
    print("T: {} P: {} | {}".format(label, predicted, text))