## Task 2 [2p*]

Your task is to train embeddings for Simple Wikipedia titles using the gensim library. As the example below shows, training is really simple:

```python
from gensim.test.utils import common_texts
from gensim.models import Word2Vec
model = Word2Vec(sentences=common_texts, vector_size=100, window=5, min_count=1, workers=4)
model.save("word2vec.model")
```
*sentences* can be a list of lists of tokens; you can also use *gensim.models.word2vec.LineSentence(source)* to create a restartable iterator from a file. First, use [this file](https://drive.google.com/file/d/1H0ChgZjcbW7x3Gy_9RK0CoduP5M8WscP/view?usp=drive_link), which contains pairs of titles such that one article links to the other.

We say that two titles are *related* if they both contain a word (or a word bigram) which is not very popular (it occurs in only a few titles). Make this definition more precise, and create a corpus which contains pairs of related titles. Make a mixture of the original corpus and the new one, then train the title vectors again.

Compare these two approaches using code similar to the code from Task 1.

In [None]:
from gensim.models.word2vec import LineSentence
from gensim.test.utils import datapath

# Stream the link corpus from disk instead of loading it into memory.
sentences = LineSentence(source="simple.wiki.links.txt")


In [None]:
from gensim.models import Word2Vec

# Baseline model: trained on the `sentences` corpus only, then saved to disk.
model_basic = Word2Vec(
    sentences=sentences,
    vector_size=100,
    window=5,
    min_count=1,
    workers=6,
)
model_basic.save("model_basic_links_only.w2v")


KeyboardInterrupt: 

___

In [5]:
# Read the link file: every line is lower-cased and split on whitespace,
# and only lines that yield exactly two fields are kept as pairs.
with open("simple.wiki.links.txt", "r", encoding="utf-8") as f:
    original_pairs = [
        fields
        for fields in (line.strip().lower().split() for line in f)
        if len(fields) == 2
    ]

def split_title(title):
    """Return the lower-cased word tokens of *title*.

    Underscores are treated as word separators, and the result is split on
    any run of whitespace, so an empty title yields an empty list.
    """
    normalized = title.lower().replace('_', ' ')
    return normalized.split()

# One training sentence per link: the words of the first title followed by
# the words of the second title.
sentences = [[*split_title(src), *split_title(dst)] for src, dst in original_pairs]

In [None]:
from collections import Counter
import itertools


# Tokenize every *distinct* title exactly once.  Iterating over all link
# endpoints instead would weight each title by how often it is linked, so a
# word that occurs in a single, heavily linked title would not look rare.
title_tokens = [
    split_title(t) for t in set(itertools.chain.from_iterable(original_pairs))
]
# Number of occurrences of each word across the distinct titles.
word_counter = Counter(itertools.chain.from_iterable(title_tokens))

def get_bigrams(tokens):
    """Return the list of adjacent token pairs of *tokens*, in order.

    A sequence with fewer than two tokens has no bigrams and yields ``[]``.
    """
    return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]

# Number of occurrences of each word bigram over all entries of title_tokens.
bigram_counter = Counter(
    bigram for tokens in title_tokens for bigram in get_bigrams(tokens)
)


In [None]:
# A word or bigram counts as "rare" when its count does not exceed this value.
RARE_THRESHOLD = 10
rare_words = {word for word, count in word_counter.items() if count <= RARE_THRESHOLD}
rare_bigrams = {bigram for bigram, count in bigram_counter.items() if count <= RARE_THRESHOLD}


In [None]:
from itertools import combinations

# Map every distinct title of the link corpus to its word tokens.
title_to_tokens = {t: split_title(t) for t in set(itertools.chain.from_iterable(original_pairs))}
title_list = list(title_to_tokens.keys())

# Inverted indexes: rare word / rare bigram -> set of titles containing it.
# Both are filled in a single pass over the titles.
word_index = {}
bigram_index = {}
for title, tokens in title_to_tokens.items():
    for w in rare_words.intersection(tokens):
        word_index.setdefault(w, set()).add(title)
    for b in rare_bigrams.intersection(get_bigrams(tokens)):
        bigram_index.setdefault(b, set()).add(title)

# Two titles are related when they share at least one rare feature.
related_pairs = set()
for index in (word_index, bigram_index):
    for titles in index.values():
        # Sort before pairing so each pair has one canonical orientation.
        # Pairing an unordered set could yield (a, b) for one feature and
        # (b, a) for another, and the set would then keep both, silently
        # over-weighting that pair in the training corpus.
        related_pairs.update(combinations(sorted(titles), 2))


In [None]:
# Mixed corpus: all original link pairs first, then the related-title pairs.
combined_pairs = [*original_pairs, *related_pairs]
# Each pair becomes one sentence: words of the first title, then of the second.
combined_sentences = [[*split_title(src), *split_title(dst)] for src, dst in combined_pairs]


In [None]:
# Model trained on the mixed corpus (combined_sentences), then saved to disk.
model_mixed = Word2Vec(
    sentences=combined_sentences,
    vector_size=100,
    window=5,
    min_count=1,
    workers=4,
)
model_mixed.save("model_mixed_links_plus_related.w2v")


___

In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt

def visualize(model, title_tokens, num=100):
    """Plot a 2-D PCA projection of the vectors of up to *num* titles.

    Args:
        model: Trained model whose ``wv`` attribute supports ``in`` tests and
            item lookup by token string.
        title_tokens: Mapping whose keys are titles; only the first *num*
            keys (in iteration order) are considered.
        num: Maximum number of titles taken from *title_tokens*.

    Raises:
        ValueError: If fewer than two of the selected titles have a vector
            in ``model.wv``, since a 2-component PCA cannot be fitted then.
    """
    labels = []
    vectors = []
    for title in list(title_tokens.keys())[:num]:
        key = ' '.join(split_title(title))
        # Titles without a vector are skipped.  Label and vector are dropped
        # together so that point i is always annotated with its own title;
        # filtering only the vectors would shift every later label and run
        # past the end of the projected array.
        if key in model.wv:
            labels.append(title)
            vectors.append(model.wv[key])

    if len(vectors) < 2:
        raise ValueError(
            "need at least 2 titles present in the model vocabulary to plot, "
            f"found {len(vectors)}"
        )

    reduced = PCA(n_components=2).fit_transform(vectors)

    plt.figure(figsize=(10, 10))
    for (x, y), label in zip(reduced, labels):
        plt.scatter(x, y)
        plt.text(x, y, label, fontsize=9)
    plt.title("PCA of title vectors")
    plt.grid(True)
    plt.show()
