# Notebook for preprocessing the Albanian Wikipedia (sqwiki) dataset (adapted from the English Wikipedia recipe)

### Initializing the phonemizer and tokenizer

In [None]:
!wget https://huggingface.co/datasets/fatlonder/sq-wiki/resolve/main/sqwiki-20241001.parquet_clean.parquet .

In [None]:
# System package: the espeak-ng engine that phonemizer's EspeakBackend drives.
!sudo apt-get install espeak-ng -y
# Python packages used by this notebook. transformers is pinned below 4.33.3,
# presumably so TransfoXLTokenizer is still available — TODO confirm.
!python -m pip install pandas pyarrow mwparserfromhell singleton-decorator datasets "transformers<4.33.3" accelerate nltk phonemizer sacremoses pebble espeakng

In [None]:
import os
import pyarrow.parquet as pq
import pyarrow as pa
import pandas as pd
from pebble import ProcessPool
from concurrent.futures import TimeoutError
import yaml

from datasets import load_dataset

from phonemize import phonemize
import phonemizer
from transformers import TransfoXLTokenizer

In [None]:
config_path = "Configs/config.yml" # you can change it to anything else
# Parse the YAML config; `with` closes the file handle that the original
# bare open() left dangling, and the encoding no longer depends on the locale.
with open(config_path, encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

In [None]:
# espeak-ng phonemizer for Albanian ('sq'), keeping punctuation and stress marks.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='sq',
    preserve_punctuation=True,
    with_stress=True,
)
# Word tokenizer named in the config; you can use any other tokenizers if you want to.
tokenizer = TransfoXLTokenizer.from_pretrained(
    config['dataset_params']['tokenizer']
)

### Process dataset

In [None]:
# Paths and sharding parameters for the multi-process phonemization step.
# (The English recipe this was adapted from used:
#  load_dataset("wikipedia", "20220301.en")['train'].)
input_file = '/kaggle/working/sqwiki-20241001.parquet_clean.parquet'
root_directory = "/kaggle/working/prcesedsq2" # set up root directory for multiprocessor processing
output_file = '/kaggle/working/merged_dataset.parquet'
num_shards = 100
max_workers = 4

# Loaded through `datasets`; note the sharding code below reads `input_file`
# directly with pandas and does not use this object.
dataset = load_dataset("parquet", data_files={'train': input_file})

# Create the shard output directory at the absolute path the workers write to.
# `!mkdir prcesedsq2` was relative to the current directory and errored when
# the directory already existed (i.e. on every re-run of this cell).
os.makedirs(root_directory, exist_ok=True)

In [None]:
def process_shard(i):
    """Phonemize shard ``i`` of ``input_file`` and save it under ``root_directory``.

    The input rows are split into ``num_shards`` contiguous ranges of
    ``len(df) // num_shards`` rows; the last shard also takes the remainder.
    Every ``text`` value is passed through ``phonemize`` and the result stored
    in a new ``processed_text`` column, and the shard is written to
    ``<root_directory>/shard_<i>/processed.parquet``.

    Args:
        i: Zero-based shard index in ``range(num_shards)``.

    Returns:
        A status message (success or failure). The message is returned rather
        than printed so the parent process shows it; output printed inside
        pool worker processes may not be displayed in the notebook.
    """
    try:
        # NOTE: every call re-reads the whole input file; this is the main
        # per-shard I/O cost.
        df = pd.read_parquet(input_file)

        shard_size = len(df) // num_shards
        start_idx = i * shard_size
        end_idx = start_idx + shard_size if i != num_shards - 1 else len(df)

        # .copy() so the new column is added to an independent frame rather
        # than to a slice of df (avoids pandas' SettingWithCopyWarning).
        shard = df.iloc[start_idx:end_idx].copy()
        shard['processed_text'] = shard['text'].apply(
            lambda text: phonemize(text, global_phonemizer, tokenizer)
        )

        directory = os.path.join(root_directory, f"shard_{i}")
        os.makedirs(directory, exist_ok=True)

        # Write to a temporary file and atomically rename it into place, so a
        # worker killed mid-write never leaves a truncated processed.parquet
        # behind for the merge step to pick up.
        final_path = os.path.join(directory, 'processed.parquet')
        tmp_path = final_path + '.tmp'
        shard.to_parquet(tmp_path)
        os.replace(tmp_path, final_path)

        return f"Shard {i} processed and saved."

    except Exception as e:  # best-effort: one bad shard must not stop the pool
        return f"Failed to process shard {i}: {e}"

#### Note: You may need to run the following cell multiple times to process all shards, because some shards can fail. Depending on how long each shard takes to process, you may need a longer timeout (the `timeout` argument of `pool.map`; `None` means no limit) so that more shards finish before their workers are killed.


In [None]:
# Run process_shard over every shard index in a pebble process pool and print
# each shard's outcome. A failing shard is reported and the loop moves on.
from pebble import ProcessExpired  # raised when a worker process dies

# Seconds before a shard's worker is killed; None means no limit.
shard_timeout = None

with ProcessPool(max_workers=max_workers) as pool:
    future = pool.map(process_shard, range(num_shards), timeout=shard_timeout)

    iterator = future.result()
    while True:
        try:
            result = next(iterator)
            print(result)
        except StopIteration:
            break
        except TimeoutError as error:
            print(f"Function took longer than {error.args[1]} seconds")
        except ProcessExpired as error:
            # Worker crashed or was killed; keep collecting the other shards.
            print(f"{error}. Exit code: {error.exitcode}")
        except Exception as error:
            print(f"Function raised {error}")

### Collect all shards to form the processed dataset

In [None]:
def merge_parquet_files(root_directory, output_file):
    """Concatenate each shard's ``processed.parquet`` into one parquet file.

    Every sub-directory of ``root_directory`` is treated as a shard. Shards
    are visited in numeric order (``shard_0``, ``shard_1``, ..., ``shard_10``)
    so the merged rows follow the contiguous row ranges the shards were cut
    from. ``os.listdir`` alone returns names in arbitrary order. Shards whose
    file is missing or unreadable are reported and skipped.

    Args:
        root_directory: Directory containing one sub-directory per shard.
        output_file: Path of the merged parquet file to write.

    Returns:
        The number of shard files merged (0 when nothing was written).
    """
    def _shard_sort_key(name):
        # "shard_<n>" sorts by integer n; any other name sorts after those.
        prefix, _, suffix = name.partition('_')
        if prefix == 'shard' and suffix.isdecimal():
            return (0, int(suffix), name)
        return (1, 0, name)

    shard_names = sorted(
        (d for d in os.listdir(root_directory)
         if os.path.isdir(os.path.join(root_directory, d))),
        key=_shard_sort_key,
    )

    tables = []
    for name in shard_names:
        shard_file = os.path.join(root_directory, name, 'processed.parquet')
        try:
            table = pq.read_table(shard_file)
        except Exception as e:  # report and skip unreadable/missing shards
            print(f"Failed to load {shard_file}: {e}")
            continue
        tables.append(table)
        print(f"Loaded {shard_file}")

    if not tables:
        print("No tables were loaded successfully.")
        return 0

    combined_table = pa.concat_tables(tables)
    pq.write_table(combined_table, output_file)
    print(f"Merged dataset saved to {output_file}")
    return len(tables)

In [None]:
# Merge the shards, then load the merged table back into pandas.
merge_parquet_files(root_directory, output_file)
# `df` must be loaded here: no top-level `df` exists earlier in this notebook
# (the one in process_shard is local), so using it directly raised NameError.
df = pd.read_parquet(output_file)

# Expand each processed_text value into two columns. Assumes each value
# expands to exactly two fields, input_ids then phonemes — TODO confirm
# against phonemize().
df[['input_ids', 'phonemes']] = df['processed_text'].apply(pd.Series)
# Select the columns directly; `df.apply(pd.Series)` only rebuilt every
# column unchanged before the selection.
new_df = df[['id', 'url', 'title', 'input_ids', 'phonemes']]
new_df.to_parquet('sq-wiki-text-phonem-training-20241001.parquet')

In [None]:
from simple_loader import FilePathDataset, build_dataloader

# Read back the training file produced by the previous cell. It is written
# with a relative name (i.e. into the current working directory), so read it
# the same way. The former '/content/...' path is a Colab location that does
# not exist in the /kaggle/working setup used above.
dataset = load_dataset("parquet", data_files={'train': 'sq-wiki-text-phonem-training-20241001.parquet'})['train']

### Remove unnecessary tokens from the pre-trained tokenizer
The pre-trained tokenizer contains many tokens that never occur in our dataset, so we remove them. We also want to predict words in lowercase, because letter case does not matter much for TTS. Pruning the tokenizer is much faster than training a new tokenizer from scratch.

In [None]:
# Wrap the training split for simple_loader and batch it (8 rows/batch, 2 workers).
file_data = FilePathDataset(dataset)
loader = build_dataloader(file_data, batch_size=8, num_workers=2)
# Word-separator entry from the config (presumably a token id, since it is
# mixed with dataset token ids below).
special_token = config["dataset_params"]["word_separator"]

In [None]:
# Word-separator entry from the dataset config (same value as set in the cell above).
special_token = config["dataset_params"]["word_separator"]

In [None]:
# Get every distinct token that occurs in the dataset, plus the word separator,
# so the tokenizer can later be pruned down to just these tokens.

from tqdm import tqdm

# Accumulate in a set instead of re-deduplicating a growing list after every
# batch, which re-hashed all tokens seen so far on each iteration.
unique_ids = {special_token}
for batch in tqdm(loader):
    unique_ids.update(batch)
unique_index = list(unique_ids)

In [None]:
# For each collected token, keep its id if the decoded word is already
# lowercase; otherwise look up the id of the lowercased word.

lower_tokens = []
for token_id in tqdm(unique_index):
    decoded = tokenizer.decode([token_id])
    lowered = decoded.lower()
    if lowered == decoded:
        lower_tokens.append(token_id)
        continue
    lower_tokens.append(tokenizer.encode([lowered])[0])

In [None]:
# Drop duplicates (several cased variants can map to the same lowercase id).
lower_tokens = list(set(lower_tokens))

In [None]:
# Map every original token id to its lowercased word and that word's token id,
# reducing the number of distinct target tokens.

token_maps = {}
for token_id in tqdm(unique_index):
    lowered = tokenizer.decode([token_id]).lower()
    token_maps[token_id] = {
        'word': lowered,
        'token': tokenizer.encode([lowered])[0],
    }

In [None]:
import pickle

# Persist the token_maps dict to the path configured under dataset_params.token_maps.
with open(config['dataset_params']['token_maps'], mode='wb') as token_maps_file:
    pickle.dump(token_maps, token_maps_file)
print(f"Token mapper saved to {config['dataset_params']['token_maps']}")

### Test the dataset with the dataloader


In [None]:
# Training dataloader from dataloader.py; this import rebinds the name
# build_dataloader, replacing the simple_loader version imported earlier.
from dataloader import build_dataloader

train_loader = build_dataloader(
    dataset,
    batch_size=32,
    num_workers=0,
    dataset_config=config['dataset_params'],
)

In [None]:
# Pull one batch to check that the loader output unpacks into these five fields.
words, labels, phonemes, input_lengths, masked_indices = next(iter(train_loader))