Load all texts from Spanish Wikipedia

In [1]:
# load texts (all wikipedia articles in spanish) using pickle
import pickle

with open("wiki_texts_list.pkl", "rb") as f:
    texts = pickle.load(f)

Cut texts to the first 1000, so that the training does not take prohibitly long

In [2]:
texts = texts[:1000]

Chunk Wikipedia texts into chunks of sequence length 1024 for EBAE (using Tokenizer spezified in model_name) and the next sentences for EBAR and dump it to a pickle file for later use

In [3]:
from tqdm import tqdm


def create_ebae_ebar_chunks(texts, tokenizer, pre_seq_length=1000, train_seq_len=1024):
    """
    Creates chunks for EBAE and EBAR by separating prompts and next sentences.
    
    :param texts: List of texts to process.
    :param tokenizer: Tokenizer object for tokenizing the text.
    :param pre_seq_length: Maximum length of the prompt chunk before adding the next sentence.
    :param train_seq_len: Maximum sequence length for training.
    :return: Two lists: prompts (EBAE input) and next sentences (EBAR input).
    """
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # Add padding token
    prompts = []
    next_sentences = []
    
    for text in tqdm(texts):
        sentences = text.split(".")  # Split text into sentences
        input_buffer = []
        token_count = 0

        # Batch tokenize sentences
        batch_tokens = tokenizer(
            sentences,
            return_tensors="pt",
            truncation=True,
            padding=True
        )["input_ids"]

        for i, (tokens, sentence) in enumerate(zip(batch_tokens, sentences)):
            current_token_length = (tokens != tokenizer.pad_token_id).sum().item()  # Count non-padding tokens

            # Skip sentences that are too long individually
            if current_token_length > pre_seq_length:
                print("Skipping sentence as it exceeds seq_length")
                continue

            if token_count + current_token_length <= pre_seq_length:
                # Add the sentence if it fits
                input_buffer.append(sentence)
                token_count += current_token_length
            else:
                # Add the current chunk to the list and reset buffer
                if input_buffer and i < len(sentences) - 1:  # Ensure there's a next sentence
                    prompts.append(" ".join(input_buffer))
                    next_sentences.append(sentences[i])  # Use the next sentence for EBAR
                input_buffer = [sentence]
                token_count = current_token_length

        # Handle leftover sentences in the buffer
        if input_buffer and len(input_buffer) < len(sentences):
            prompts.append(" ".join(input_buffer))
            next_sentences.append(sentences[len(input_buffer)])  # Use the next available sentence

    # Validate all chunks and remove invalid pairs
    valid_prompts, valid_next_sentences = [], []
    for idx in range(len(prompts)):
        prompt_tokens = tokenizer(prompts[idx])["input_ids"]
        next_tokens = tokenizer(next_sentences[idx])["input_ids"]

        if len(prompt_tokens) <= train_seq_len - 30 and len(next_tokens) <= train_seq_len - 30:
            valid_prompts.append(prompts[idx])
            valid_next_sentences.append(next_sentences[idx])
            print(f"Pair {idx} (prompt: {len(prompt_tokens)} tokens, next: {len(next_tokens)} tokens).")
        else:
            print(f"Pair {idx} is too long after processing (prompt: {len(prompt_tokens)} tokens, next: {len(next_tokens)} tokens).")

    return valid_prompts, valid_next_sentences


In [4]:
from transformers import AutoTokenizer

model_name = 'Qwen/Qwen2.5-0.5B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_name)

seq_len = 1024

# Generate sentence pairs
chunks, next_sentences = create_ebae_ebar_chunks(texts, tokenizer, pre_seq_length=seq_len-100, train_seq_len=seq_len)
print(f"Number of chunks: {len(chunks)}")
print(f"Number of next sentences: {len(next_sentences)}")

# save chunks using pickle
import pickle

with open("wiki_chunks_list_ebae_ebar.pkl", "wb") as f:
    pickle.dump(chunks, f)

with open("wiki_next_sentences_list_ebae_ebar.pkl", "wb") as f:
    pickle.dump(next_sentences, f)

  from .autonotebook import tqdm as notebook_tqdm
  4%|▍         | 44/1000 [00:01<00:34, 27.82it/s]

Skipping sentence as it exceeds seq_length


 15%|█▌        | 152/1000 [00:03<00:32, 26.28it/s]

Skipping sentence as it exceeds seq_length


 16%|█▌        | 157/1000 [00:03<00:32, 25.77it/s]

Skipping sentence as it exceeds seq_length


 17%|█▋        | 171/1000 [00:04<00:33, 24.67it/s]

Skipping sentence as it exceeds seq_length


 67%|██████▋   | 669/1000 [00:16<00:07, 45.03it/s]

Skipping sentence as it exceeds seq_length


 70%|███████   | 700/1000 [00:16<00:07, 39.31it/s]

Skipping sentence as it exceeds seq_length


 77%|███████▋  | 772/1000 [00:18<00:06, 36.26it/s]

Skipping sentence as it exceeds seq_length


 83%|████████▎ | 834/1000 [00:20<00:05, 32.48it/s]

Skipping sentence as it exceeds seq_length


 92%|█████████▏| 915/1000 [00:21<00:01, 71.35it/s]

Skipping sentence as it exceeds seq_length


 94%|█████████▍| 942/1000 [00:22<00:01, 42.71it/s]

Skipping sentence as it exceeds seq_length


100%|██████████| 1000/1000 [00:24<00:00, 41.47it/s]


Skipping sentence as it exceeds seq_length
Pair 0 (prompt: 889 tokens, next: 47 tokens).
Pair 1 (prompt: 927 tokens, next: 60 tokens).
Pair 2 (prompt: 877 tokens, next: 231 tokens).
Pair 3 (prompt: 789 tokens, next: 144 tokens).
Pair 4 (prompt: 905 tokens, next: 55 tokens).
Pair 5 (prompt: 877 tokens, next: 167 tokens).
Pair 6 (prompt: 914 tokens, next: 20 tokens).
Pair 7 (prompt: 922 tokens, next: 28 tokens).
Pair 8 (prompt: 899 tokens, next: 142 tokens).
Pair 9 (prompt: 929 tokens, next: 57 tokens).
Pair 10 (prompt: 916 tokens, next: 37 tokens).
Pair 11 (prompt: 876 tokens, next: 66 tokens).
Pair 12 (prompt: 906 tokens, next: 55 tokens).
Pair 13 (prompt: 886 tokens, next: 75 tokens).
Pair 14 (prompt: 911 tokens, next: 36 tokens).
Pair 15 (prompt: 899 tokens, next: 37 tokens).
Pair 16 (prompt: 923 tokens, next: 41 tokens).
Pair 17 (prompt: 899 tokens, next: 70 tokens).
Pair 18 (prompt: 284 tokens, next: 1 tokens).
Pair 19 (prompt: 907 tokens, next: 38 tokens).
Pair 20 (prompt: 910 tok