In [1]:
# !python -m spacy download fr_core_news_sm
# !python -m spacy download en_core_web_sm

import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import load_dataset, Dataset, DatasetDict
import spacy
# import gc

In [2]:
# Download only the WMT14 fr-en validation parquet (3,000 sentence pairs) and
# re-split it locally: 60% train, then the remaining 40% halved into test and
# validation (20% each). `random_state=0` keeps both splits reproducible.
# NOTE(review): `trust_remote_code=True` lets `load_dataset` execute a loading
# script from the Hub repo; a direct parquet load through `data_files` should
# not need it — consider removing it.
data_files = {'validation': 'fr-en/validation-00000-of-00001.parquet'}
dataset = load_dataset(
    path='wmt/wmt14',
    data_files=data_files,
    trust_remote_code=True,
)
data = pd.DataFrame(dataset['validation'])

train, temp = train_test_split(data, random_state=0, test_size=0.4)
test, validation = train_test_split(temp, random_state=0, test_size=0.5)

def process_translations(df):
    """Flatten a frame whose `translation` column holds {'en': ..., 'fr': ...} dicts.

    Returns a new DataFrame with plain string columns `en` and `fr` (in that
    order) and a fresh 0..n-1 index; the input frame is not modified.
    """
    pairs = df['translation']
    return pd.DataFrame({
        'en': [pair['en'] for pair in pairs],
        'fr': [pair['fr'] for pair in pairs],
    })

# Flatten each pandas split into parallel `en` / `fr` string columns.
train_processed, test_processed, validation_processed = (
    process_translations(split) for split in (train, test, validation)
)

# Convert to Hugging Face datasets. `reset_index(drop=True)` is defensive:
# `process_translations` already returns a frame with a fresh 0..n-1 index.
train_dataset, test_dataset, validation_dataset = (
    Dataset.from_pandas(frame.reset_index(drop=True))
    for frame in (train_processed, test_processed, validation_processed)
)

ds = DatasetDict({
    'train': train_dataset,
    'test': test_dataset,
    'validation': validation_dataset,
})

ds

DatasetDict({
    train: Dataset({
        features: ['en', 'fr'],
        num_rows: 1800
    })
    test: Dataset({
        features: ['en', 'fr'],
        num_rows: 600
    })
    validation: Dataset({
        features: ['en', 'fr'],
        num_rows: 600
    })
})

In [3]:
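# NOTE(review): before uncommenting, check that no later cell still reads these
# names (e.g. `data` or `process_translations`), otherwise Restart & Run All breaks.
# Also note this list frees `validation_dataset` but not `train_dataset` / `test_dataset`.
# del dataset, data
# del train, temp, test, validation
# del process_translations, train_processed, test_processed, validation_dataset
# gc.collect()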

In [4]:
# English first, then French. Both models must already be installed (see the
# commented `spacy download` commands at the top). In this notebook only
# their `.tokenizer` is used.
en_nlp, fr_nlp = spacy.load('en_core_web_sm'), spacy.load('fr_core_news_sm')

def tokenize_example(example, en_nlp, fr_nlp, max_length, sos_token, eos_token):
    """Tokenise one en/fr sentence pair for `Dataset.map`.

    Each side is split with its pipeline's tokenizer, lower-cased, cut to the
    first `max_length` tokens, and then wrapped in `sos_token` ... `eos_token`
    (so a result holds at most `max_length + 2` tokens).

    Returns a dict with new `en_tokens` and `fr_tokens` list columns.
    """
    def _lowered(nlp, text):
        # Truncation happens before the boundary markers are added.
        return [tok.text.lower() for tok in nlp.tokenizer(text)][:max_length]

    return {
        'en_tokens': [sos_token, *_lowered(en_nlp, example['en']), eos_token],
        'fr_tokens': [sos_token, *_lowered(fr_nlp, example['fr']), eos_token],
    }

    
# Tokenisation settings shared by every split.
max_length = 1000  # tokens kept per sentence, before <sos>/<eos> are added
sos_token, eos_token = '<sos>', '<eos>'

fn_kwargs = dict(
    en_nlp=en_nlp,
    fr_nlp=fr_nlp,
    max_length=max_length,
    sos_token=sos_token,
    eos_token=eos_token,
)

# `Dataset.map` returns a new dataset with the added `en_tokens` / `fr_tokens`
# columns; `ds` itself is left unchanged.
train_data, test_data, validation_data = (
    ds[split].map(tokenize_example, fn_kwargs=fn_kwargs)
    for split in ('train', 'test', 'validation')
)

Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

In [5]:
from collections import Counter

def lang_str_int(lang, nlp, min_freq=2,
                 special_tokens=('<unk>', '<pad>', '<sos>', '<eos>')):
    """Build string<->id lookup tables for one language.

    Parameters
    ----------
    lang : iterable of str
        Raw sentences. Each is split with ``nlp.tokenizer`` and lower-cased,
        mirroring ``tokenize_example``.
    nlp : spaCy pipeline
        Only its ``tokenizer`` is used.
    min_freq : int, default 2
        Corpus words seen fewer than this many times are left out of the
        vocabulary (``token_to_int`` maps them to ``'<unk>'``).
    special_tokens : sequence of str
        Reserved tokens placed first, so they get ids ``0 .. len-1``.

    Returns
    -------
    (dict[str, int], dict[int, str])
        ``str2int`` and its exact inverse ``int2str``. Non-special words follow
        first-occurrence order in ``lang``.
    """
    counts = Counter(
        token.text.lower() for sentence in lang for token in nlp.tokenizer(sentence)
    )

    # Skip corpus words that collide with a special token. Otherwise the token
    # would get a second id: str2int would silently point it at the later id
    # while int2str kept both.
    reserved = set(special_tokens)
    words = [w for w, freq in counts.items() if freq >= min_freq and w not in reserved]

    vocab = list(special_tokens) + words
    str2int = {tok: i for i, tok in enumerate(vocab)}
    int2str = dict(enumerate(vocab))
    return str2int, int2str

# Build the vocabularies from the *training* split only. The previous version
# used `data` (all 3,000 pairs), which leaked test/validation words into the
# vocabulary and understated the <unk> rate at evaluation time.
# `list(...)` works whether column access returns a list or a lazy column object.
en = list(train_data['en'])
fr = list(train_data['fr'])

fr_str2int, fr_int2str = lang_str_int(fr, fr_nlp)
en_str2int, en_int2str = lang_str_int(en, en_nlp)

In [6]:
import torch
import numpy as np

def token_to_int(example, str2int):
    """Map a sequence of token strings to integer ids.

    `example` is a token list (e.g. one row's `en_tokens`), not a whole row.
    Tokens missing from `str2int` get the id of `'<unk>'`.
    """
    lookup = str2int.get
    return [lookup(token, str2int['<unk>']) for token in example]

def tokens_to_ids(example):
    """Add `en_ids` / `fr_ids` to one dataset row by looking up its token lists.

    Uses the module-level `en_str2int` / `fr_str2int` vocabularies. The row
    dict is updated in place and also returned, for use with `Dataset.map`.
    """
    for ids_key, tokens_key, vocab in (
        ('en_ids', 'en_tokens', en_str2int),
        ('fr_ids', 'fr_tokens', fr_str2int),
    ):
        example[ids_key] = token_to_int(example[tokens_key], vocab)
    return example

# `Dataset.map` returns a new dataset, so each split is rebound to the version
# carrying the extra `en_ids` / `fr_ids` columns.
train_data, test_data, validation_data = (
    split.map(tokens_to_ids) for split in (train_data, test_data, validation_data)
)

Map:   0%|          | 0/1800 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

Map:   0%|          | 0/600 [00:00<?, ? examples/s]

In [7]:
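# NOTE(review): if revived, be aware of two things. First, `Dataset.map` returns
# a new dataset, so the result of the call below is discarded. Second, values
# written by `map` are stored back in Arrow, so tensors come back as Python
# lists. To get tensors at access time, use e.g.
# `train_data.with_format(type='torch', columns=['en_ids', 'fr_ids'])`.
# def cast_torch(example):
#     example['en_ids'] = torch.tensor(example['en_ids'], dtype=torch.long)
#     return example

# train_data.map(cast_torch)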