In [None]:
# download wikipedia dump
! wget https://dumps.wikimedia.org/plwiki/20210801/plwiki-20210801-pages-articles-multistream.xml.bz2

# extract raw passages with wikiextractor
! pip install wikiextractor
! mkdir wiki/
! wikiextractor -o wiki/ -b 50M plwiki-20210801-pages-articles-multistream.xml.bz2

In [None]:
import os
import re
import pickle
from copy import copy

import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

## Parse Wikipedia

In [None]:
def load_part(part_path):
    """Parse one wikiextractor output file into a list of article dicts.

    The file is expected to contain articles delimited by a
    ``<doc id="..." url="..." title="...">`` header line and a ``</doc>`` line.

    Args:
        part_path: path of a single wikiextractor output file.

    Returns:
        List of dicts, one per ``<doc>`` header, in file order. Each dict has:

        - ``'id'``, ``'url'`` and ``'title'``: strings captured from the header.
        - ``'content'``: every stripped line between the header and ``</doc>``,
          blank lines included.

        Blank lines before the first header are ignored.

    Raises:
        ValueError: if a line starting with ``<doc `` does not match the header
            pattern, or if non-blank text appears before the first header.
    """
    data = []
    # Explicit UTF-8: the text is Polish, and a non-UTF-8 platform default
    # encoding (e.g. on Windows) would garble it or raise UnicodeDecodeError.
    with open(part_path, encoding='utf-8') as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if line.startswith('<doc '):
                m = re.match(r'<doc id="(.*)" url="(.*)" title="(.*)">', line)
                if m is None:
                    raise ValueError(
                        f'{part_path}:{lineno}: malformed <doc> header: {line!r}')
                data.append({
                    'id': m.group(1),
                    'url': m.group(2),
                    'title': m.group(3),
                    'content': [],
                })
            elif line.startswith('</doc>'):
                continue
            elif data:
                data[-1]['content'].append(line)
            elif line:
                # No header seen yet, so there is no article to attach this text to.
                raise ValueError(
                    f'{part_path}:{lineno}: text before the first <doc> header')

    return data

In [None]:
def filter_content(data, min_tokens=10, max_tokens=250):
    """Keep only substantial paragraphs of each article, truncated to ``max_tokens``.

    For every article the first two content lines are dropped. For wikiextractor
    output these are presumably the title line and a blank separator (TODO confirm).

    The third line (index 2) is always kept. Any later line is kept only if it has at
    least ``min_tokens`` space-separated tokens. Kept lines are cut to their first
    ``max_tokens`` tokens.

    Returns a new list of shallow-copied article dicts whose ``'content'`` is the
    filtered list; the input articles are left untouched.
    """
    result = []
    for article in data:
        kept_lines = []
        for position, text in enumerate(article['content'][2:], start=2):
            tokens = text.split(' ')
            is_lead = position == 2
            if is_lead or len(tokens) >= min_tokens:
                kept_lines.append(' '.join(tokens[:max_tokens]))

        new_article = copy(article)
        new_article['content'] = kept_lines
        result.append(new_article)

    return result

In [None]:
def flat_content(data):
    """Flatten articles into one record per content line.

    Each record carries the parent article's ``id``, ``url`` and ``title``, the
    line's position within the article's content as ``idx``, and the line itself
    as ``text``.
    """
    return [
        {
            'id': doc['id'],
            'url': doc['url'],
            'title': doc['title'],
            'idx': position,
            'text': paragraph,
        }
        for doc in data
        for position, paragraph in enumerate(doc['content'])
    ]

In [None]:
def encode_texts(texts, multi=True, normalize=True, model=None):
    """Embed ``texts`` with a SentenceTransformer, optionally L2-normalising rows.

    Args:
        texts: list of strings to embed.
        multi: if True, encode with a multi-process pool
            (``start_multi_process_pool`` / ``encode_multi_process``). Otherwise a
            single-process ``encode(..., convert_to_numpy=True)`` call is used.
        normalize: if True, divide each embedding row by its L2 norm so that dot
            products equal cosine similarities. Rows whose norm is zero are left as
            zeros instead of becoming NaN.
        model: encoder to use. Defaults to the module-level ``encoder`` created in
            the "Encode Paragraphs" section; it is looked up at call time.

    Returns:
        numpy array of embeddings, one row per input text. An empty input yields an
        empty array.
    """
    if model is None:
        model = encoder

    if multi:
        pool = model.start_multi_process_pool()
        try:
            emb = model.encode_multi_process(texts, pool)
        finally:
            # Always tear the worker processes down, even if encoding fails,
            # so one bad part does not leak a pool of workers.
            model.stop_multi_process_pool(pool)
    else:
        emb = model.encode(texts, convert_to_numpy=True)

    emb = np.asarray(emb)
    # emb.size guard: an empty result has no axis 1 to normalise over.
    if normalize and emb.size:
        norms = np.sqrt(np.sum(emb ** 2, axis=1, keepdims=True))
        # An all-zero row would give 0/0 = NaN and poison similarity search.
        norms[norms == 0] = 1.0
        emb = emb / norms

    return emb

def encode_paragraph(paragraphs):
    """Return copies of the paragraph records with an ``'emb'`` vector attached.

    Each record is embedded from the string ``title + ' | ' + text``, using
    ``encode_texts`` with its default arguments. Every output dict is a shallow copy
    of the corresponding input with an extra ``'emb'`` key; the inputs are not
    modified.
    """
    def _with_embedding(record, vector):
        enriched = copy(record)
        enriched['emb'] = vector
        return enriched

    inputs = [rec['title'] + ' | ' + rec['text'] for rec in paragraphs]
    vectors = encode_texts(inputs)
    return [_with_embedding(rec, vec) for rec, vec in zip(paragraphs, vectors)]

In [9]:
def parse(part_path):
    """Load, filter, flatten, embed and pickle one wikiextractor output file.

    Prints the path, then a count after each stage, in this order:

    1. content lines loaded;
    2. content lines after filtering;
    3. flattened paragraph records;
    4. encoded records.

    The encoded records are written to ``part_path + '.pkl'``.
    """
    print(part_path)

    def _line_count(articles):
        return sum(len(article['content']) for article in articles)

    raw_articles = load_part(part_path)
    print(_line_count(raw_articles))

    kept_articles = filter_content(raw_articles)
    print(_line_count(kept_articles))

    records = flat_content(kept_articles)
    print(len(records))

    embedded = encode_paragraph(records)
    print(len(embedded))

    out_path = part_path + '.pkl'
    with open(out_path, 'wb') as fh:
        pickle.dump(embedded, fh)


def parse_txt(part_path):
    """Embed a plain-text file with one passage per line and pickle the result.

    Every non-blank line becomes a paragraph record:

    - ``'title'`` is ``'Przysłowie'`` (Polish for "proverb"); presumably the file
      holds proverbs — confirm against the data.
    - ``'id'`` and ``'url'`` are both ``"<part_path>::<line number>"``. The line
      number is 0-based and counts blank lines, so ids stay tied to file lines.
    - ``'idx'`` is always 0.

    The records are embedded with ``encode_paragraph`` and pickled to
    ``part_path + '.pkl'``. The path and the record count are printed before and
    after encoding.
    """
    print(part_path)

    paragraphs = []
    # Explicit UTF-8: the passages are Polish and the platform default may not be.
    with open(part_path, encoding='utf-8') as f:
        for i, line in enumerate(f):
            text = line.strip()
            if not text:
                # A blank line (e.g. a trailing newline) would otherwise become an
                # empty passage embedded from the title alone.
                continue
            key = part_path + '::' + str(i)
            paragraphs.append({
                'id': key,
                'url': key,
                # Was the mojibake 'Przys≈Çowie' (UTF-8 'ł' mis-decoded as Mac Roman).
                'title': 'Przysłowie',
                'idx': 0,
                'text': text,
            })
    print(len(paragraphs))

    encoded_paragraphs = encode_paragraph(paragraphs)
    print(len(encoded_paragraphs))

    with open(part_path + '.pkl', 'wb') as f:
        pickle.dump(encoded_paragraphs, f)

## Encode Paragraphs

In [None]:
encoder = SentenceTransformer('piotr-rybak/poleval2021-task4-herbert-large-encoder')

# wikiextractor writes its output into two-letter sub-directories
# (wiki/AA/wiki_00, wiki/AA/wiki_01, ..., wiki/AB/...), so walk the whole tree
# instead of listing only the top level (which would yield directories).
# The .pkl files this loop produces are skipped, so the cell can be re-run
# without trying to parse its own output. Sorting keeps the order deterministic.
part_paths = sorted(
    os.path.join(root, name)
    for root, _dirs, names in os.walk('wiki/')
    for name in names
    if not name.endswith('.pkl')
)

for part_path in tqdm(part_paths, desc='encoding'):
    parse(part_path)