Load all texts from Spanish Wikipedia

In [1]:
# load texts (all wikipedia articles in spanish) using pickle
import pickle

with open("wiki_texts_list.pkl", "rb") as f:
    texts = pickle.load(f)

Cut texts to the first 1000, so that the training does not take prohibitly long

In [2]:
texts = texts[:1000]

Chunk Wikipedia texts into chunks of sequence length 1024 for EBAE (using Tokenizer spezified in model_name) and the next sentences for EBAR and dump it to a pickle file for later use

In [None]:
from tqdm import tqdm


def create_ebae_ebar_chunks(texts, tokenizer, pre_seq_length=1000, train_seq_len=1024):
    """
    Creates chunks for EBAE and EBAR by separating prompts and next sentences.
    
    :param texts: List of texts to process.
    :param tokenizer: Tokenizer object for tokenizing the text.
    :param pre_seq_length: Maximum length of the prompt chunk before adding the next sentence.
    :param train_seq_len: Maximum sequence length for training.
    :return: Two lists: prompts (EBAE input) and next sentences (EBAR input).
    """
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})  # Add padding token
    prompts = []
    next_sentences = []
    
    for text in tqdm(texts):
        sentences = text.split(".")  # Split text into sentences
        input_buffer = []
        token_count = 0

        # Batch tokenize sentences
        batch_tokens = tokenizer(
            sentences,
            return_tensors="pt",
            truncation=True,
            padding=True
        )["input_ids"]

        for i, (tokens, sentence) in enumerate(zip(batch_tokens, sentences)):
            current_token_length = (tokens != tokenizer.pad_token_id).sum().item()  # Count non-padding tokens

            # Skip sentences that are too long individually
            if current_token_length > pre_seq_length:
                print("Skipping sentence as it exceeds seq_length")
                continue

            if token_count + current_token_length <= pre_seq_length:
                # Add the sentence if it fits
                input_buffer.append(sentence)
                token_count += current_token_length
            else:
                # Add the current chunk to the list and reset buffer
                if input_buffer and i < len(sentences) - 1:  # Ensure there's a next sentence
                    prompts.append(" ".join(input_buffer))
                    next_sentences.append(sentences[i])  # Use the next sentence for EBAR
                input_buffer = [sentence]
                token_count = current_token_length

        # Handle leftover sentences in the buffer
        if input_buffer and len(input_buffer) < len(sentences):
            prompts.append(" ".join(input_buffer))
            next_sentences.append(sentences[len(input_buffer)])  # Use the next available sentence

    # Validate all chunks and remove invalid pairs
    valid_prompts, valid_next_sentences = [], []
    for idx in range(len(prompts)):
        prompt_tokens = tokenizer(prompts[idx])["input_ids"]
        next_tokens = tokenizer(next_sentences[idx])["input_ids"]

        if len(prompt_tokens) <= train_seq_len - 30 and len(next_tokens) <= train_seq_len - 30:
            valid_prompts.append(prompts[idx])
            valid_next_sentences.append(next_sentences[idx])
        else:
            print(f"Pair {idx} is too long after processing (prompt: {len(prompt_tokens)} tokens, next: {len(next_tokens)} tokens).")

    return valid_prompts, valid_next_sentences


In [4]:
from transformers import AutoTokenizer

model_name = 'Qwen/Qwen2.5-0.5B-Instruct'

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Generate sentence pairs
chunks, next_sentences = create_ebae_ebar_chunks(texts, tokenizer, pre_seq_length=900, train_seq_len=1024)
print(f"Number of chunks: {len(chunks)}")
print(f"Number of next sentences: {len(next_sentences)}")

# save chunks using pickle
import pickle

with open("wiki_chunks_list_ebae_ebar.pkl", "wb") as f:
    pickle.dump(chunks, f)

with open("wiki_next_sentences_list_ebae_ebar.pkl", "wb") as f:
    pickle.dump(next_sentences, f)

  from .autonotebook import tqdm as notebook_tqdm
  4%|▍         | 44/1000 [00:01<00:36, 26.26it/s]

Skipping sentence as it exceeds seq_length


 15%|█▍        | 146/1000 [00:03<00:15, 56.53it/s]

Skipping sentence as it exceeds seq_length


 16%|█▌        | 158/1000 [00:03<00:32, 26.19it/s]

Skipping sentence as it exceeds seq_length


 17%|█▋        | 167/1000 [00:04<00:38, 21.57it/s]

Skipping sentence as it exceeds seq_length


 67%|██████▋   | 671/1000 [00:15<00:07, 46.55it/s]

Skipping sentence as it exceeds seq_length


 70%|███████   | 704/1000 [00:16<00:06, 43.92it/s]

Skipping sentence as it exceeds seq_length


 77%|███████▋  | 769/1000 [00:17<00:05, 39.01it/s]

Skipping sentence as it exceeds seq_length


 83%|████████▎ | 834/1000 [00:19<00:04, 35.30it/s]

Skipping sentence as it exceeds seq_length


 92%|█████████▏| 917/1000 [00:20<00:01, 74.77it/s]

Skipping sentence as it exceeds seq_length


 94%|█████████▍| 939/1000 [00:21<00:01, 39.35it/s]

Skipping sentence as it exceeds seq_length


100%|██████████| 1000/1000 [00:22<00:00, 43.85it/s]


Skipping sentence as it exceeds seq_length
Error: Pair 453 is too long after processing (prompt: 996 tokens, next: 3 tokens).
Error: Pair 543 is too long after processing (prompt: 996 tokens, next: 4 tokens).
Error: Pair 1812 is too long after processing (prompt: 996 tokens, next: 13 tokens).
Error: Pair 1825 is too long after processing (prompt: 999 tokens, next: 6 tokens).
Error: Pair 1826 is too long after processing (prompt: 1000 tokens, next: 18 tokens).
Error: Pair 2058 is too long after processing (prompt: 1013 tokens, next: 16 tokens).
Error: Pair 2069 is too long after processing (prompt: 1029 tokens, next: 2 tokens).
Error: Pair 2765 is too long after processing (prompt: 1002 tokens, next: 7 tokens).
Error: Pair 3092 is too long after processing (prompt: 1015 tokens, next: 6 tokens).
Error: Pair 3093 is too long after processing (prompt: 1051 tokens, next: 6 tokens).
Error: Pair 3094 is too long after processing (prompt: 1068 tokens, next: 3 tokens).
Error: Pair 3095 is too l