# Data preparation for BERT pretraining (Wolof)

This notebook prepares a DataLoader for BERT pretraining (MLM + NSP) from a CSV file containing a text column.
Each step is separated into its own cell: data loading, pair creation for NSP, tokenization, masking for MLM, dataset and DataLoader.

## 1) Imports and configuration

Import dependencies and set global constants (e.g. CSV path, sequence length).

In [None]:
# Imports
import os
import random
import math
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast

# ---- Configuration: edit these values for your own setup ----

# Input data
CSV_PATH = 'wolof_data.csv'   # path to the input CSV file
TEXT_COL = 'texte_wolof'      # name of the CSV column holding the raw text

# Sequence length and batching
PRE_TRAIN_SEQ_LEN = 128       # max tokens per "[CLS] A [SEP] B [SEP]" sequence
BATCH_SIZE = 16

# Masked-LM settings
MASKED_LM_PROB = 0.15         # fraction of the sequence length selected for MLM prediction
MAX_PREDICTIONS_PER_SEQ = 20  # upper bound on masked positions per sequence

# Tokenizer (replace if you have a custom tokenizer)
# NOTE(review): 'bert-base-uncased' is an English, lower-casing vocabulary; it may
# lose Wolof diacritics and fragment Wolof words. A tokenizer trained on Wolof is
# likely a better fit — confirm for your corpus.
TOKENIZER_NAME = 'bert-base-uncased'


## 2) Load CSV and preview

Read the CSV file and inspect a quick preview of the data.

In [None]:
# Load the raw CSV (set CSV_PATH in the configuration cell) and show a preview.
df = pd.read_csv(CSV_PATH)
print('Total rows in CSV:', df.shape[0])
display(df.head(5))  # `display` is the IPython/Jupyter rich-output helper

## 3) Prepare sentence list

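Extract the text column and clean it: drop missing values, trim surrounding whitespace, and drop rows that are empty afterwards. Each row is treated as one candidate sentence; split rows further if they contain multiple sentences. NSP positives are built from consecutive rows, so the rows should follow the original text order.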

In [None]:
# Non-null values of the text column, as strings.
texts = df[TEXT_COL].dropna().astype(str).tolist()
# Trim surrounding whitespace and keep only rows that are non-empty afterwards;
# `filter(None, ...)` drops the empty strings.
sentences = list(filter(None, map(str.strip, texts)))
print('Prepared', len(sentences), 'candidate sentences')

## 4) Create pairs for NSP (Next Sentence Prediction)

Create pairs (A, B). With probability `dup_ratio` (default 0.5), B is the true next sentence (is_next=1); otherwise B is a randomly chosen sentence (is_next=0).

In [None]:
def create_sentence_pairs(sentences, dup_ratio=0.5, rng=None):
    """Build ``(a, b, is_next)`` triples for Next Sentence Prediction.

    Every sentence except the last becomes the A side of exactly one pair.
    With probability ``dup_ratio`` B is the sentence that immediately follows
    A (``is_next=1``). Otherwise B is drawn uniformly from all sentences
    except the one at the true-next index (``is_next=0``).

    Args:
        sentences: ordered sequence of sentences; consecutive entries are
            assumed to be consecutive in the source text.
        dup_ratio: probability of producing a positive (true next) pair.
        rng: optional ``random.Random``-like object (needs ``random()`` and
            ``randrange()``) for reproducible sampling; defaults to the
            module-level ``random`` functions.

    Returns:
        A list of ``len(sentences) - 1`` tuples (empty when fewer than two
        sentences are given).
    """
    rng = random if rng is None else rng
    pairs = []
    for i, (a, true_next) in enumerate(zip(sentences, sentences[1:])):
        if rng.random() < dup_ratio:
            pairs.append((a, true_next, 1))
            continue
        # Draw uniformly from every index except i + 1. Picking the real next
        # sentence here would create a mislabeled negative example.
        j = rng.randrange(len(sentences) - 1)
        if j >= i + 1:
            j += 1
        pairs.append((a, sentences[j], 0))
    return pairs

# Build the NSP pairs from the cleaned sentences and show a few examples.
pairs = create_sentence_pairs(sentences)
print('Created', len(pairs), 'sentence pairs')
print(f'Examples: {pairs[:3]}')

## 5) Tokenizer and parameters

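Initialize the BERT tokenizer (or your own trained tokenizer) and read its special token IDs. Note that `bert-base-uncased` is an English, lower-casing vocabulary; a tokenizer trained on Wolof text will usually fit this data better.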

In [None]:
# Load the tokenizer. BertTokenizerFast is already the fast implementation, so
# the `use_fast` flag (an AutoTokenizer option) is not needed here.
tokenizer = BertTokenizerFast.from_pretrained(TOKENIZER_NAME)

# Special-token IDs used for pair encoding, masking and padding.
# Padding falls back to 0 when the tokenizer defines no pad token.
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
mask_id = tokenizer.mask_token_id
cls_id = tokenizer.cls_token_id
sep_id = tokenizer.sep_token_id

# Fail early: a missing ID would otherwise only surface later as an obscure
# error when building tensors that contain None.
_missing = [name for name, value in (('[MASK]', mask_id), ('[CLS]', cls_id), ('[SEP]', sep_id))
            if value is None]
if _missing:
    raise ValueError(f'Tokenizer {TOKENIZER_NAME!r} defines no {", ".join(_missing)} token')
print('Vocab size:', tokenizer.vocab_size)


## 6) Encoding pairs (helper)

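Tokenize A and B without special tokens. Truncate the longer one, token by token from the end, until the pair fits in `max_len`. Then build `[CLS] A [SEP] B [SEP]` and the segment IDs: 0 for `[CLS] A [SEP]` and 1 for `B [SEP]`.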

In [None]:
def encode_pair(tokenizer, a, b, max_len):
    """Encode a sentence pair as ``[CLS] A [SEP] B [SEP]`` plus segment IDs.

    A and B are tokenized without special tokens. Three slots are reserved
    for the special tokens. While the pair still does not fit, one token is
    dropped from the end of the longer side (from B when both have the same
    length).

    Args:
        tokenizer: object providing ``encode(text, add_special_tokens=False)``
            and the ``cls_token_id`` / ``sep_token_id`` attributes (e.g. a
            Hugging Face BERT tokenizer).
        a: first sentence (segment 0).
        b: second sentence (segment 1).
        max_len: maximum total length including the special tokens; must be >= 3.

    Returns:
        ``(input_ids, seg_ids)``: lists of equal length. ``seg_ids`` is 0 over
        ``[CLS] A [SEP]`` and 1 over ``B [SEP]``.

    Raises:
        ValueError: if ``max_len`` < 3 (no room for the special tokens).
    """
    max_total = max_len - 3  # reserve [CLS], [SEP], [SEP]
    if max_total < 0:
        # Previously this made the truncation loop pop() from an empty list.
        raise ValueError(f'max_len must be at least 3, got {max_len}')
    a_ids = list(tokenizer.encode(a, add_special_tokens=False))
    b_ids = list(tokenizer.encode(b, add_special_tokens=False))
    while len(a_ids) + len(b_ids) > max_total:
        if len(a_ids) > len(b_ids):
            a_ids.pop()
        else:
            b_ids.pop()
    cls, sep = tokenizer.cls_token_id, tokenizer.sep_token_id
    input_ids = [cls, *a_ids, sep, *b_ids, sep]
    seg_ids = [0] * (len(a_ids) + 2) + [1] * (len(b_ids) + 1)
    return input_ids, seg_ids


## 7) Masking for MLM (BERT style)

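Select non-special positions for prediction: about `masked_lm_prob` × sequence length, at least 1 and at most `max_predictions_per_seq`. Apply the 80/10/10 rule to each selected position: replace it with `[MASK]` 80% of the time, with a random token 10% of the time, and keep the original token 10% of the time. `masked_lm_labels` holds the original token at selected positions and 0 elsewhere. Here 0 is the `ignore_index` value (you may change this to -100).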

In [None]:
def create_masked_lm_labels(input_ids, tokenizer, masked_lm_prob=0.15, max_predictions_per_seq=20,
                            rng=None):
    """Apply BERT-style masked-LM corruption to a sequence of token IDs.

    Positions holding the CLS, SEP or pad token are never selected. The number
    of selected positions is ``round(len(input_ids) * masked_lm_prob)``. That
    count is clamped to at least 1 and at most ``max_predictions_per_seq``,
    then to the number of eligible positions. Each selected position is
    changed according to the 80/10/10 rule:
      * 80% of the time it is replaced with the mask token;
      * 10% of the time it is replaced with a random ID from ``range(vocab_size)``;
      * 10% of the time it keeps its original token.

    Args:
        input_ids: sequence of token IDs (e.g. from ``encode_pair``). It is
            not modified.
        tokenizer: object with ``cls_token_id``, ``sep_token_id``,
            ``pad_token_id``, ``mask_token_id`` and ``vocab_size``.
        masked_lm_prob: fraction of the sequence length to select.
        max_predictions_per_seq: cap on the number of selected positions.
        rng: optional ``random.Random``-like object for reproducible masking;
            defaults to the module-level ``random`` functions.

    Returns:
        ``(masked_ids, labels)``: new lists of the same length as
        ``input_ids``. ``labels`` holds the original token at each selected
        position and 0 (the ignore value) everywhere else.
    """
    rng = random if rng is None else rng
    pad = tokenizer.pad_token_id
    if pad is None:
        pad = 0  # fall back to 0 when the tokenizer defines no pad token
    special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id, pad}

    masked = list(input_ids)  # work on a copy; never mutate the caller's list
    cand_indexes = [i for i, tok in enumerate(masked) if tok not in special_ids]
    num_to_mask = min(max_predictions_per_seq, max(1, int(round(len(masked) * masked_lm_prob))))
    num_to_mask = min(num_to_mask, len(cand_indexes))  # short sequences: mask every candidate

    # 0 = ignore (the trainer uses ignore_index=0). This is only unambiguous
    # while token ID 0 is the pad token, which is never selected above.
    labels = [0] * len(masked)
    for pos in rng.sample(cand_indexes, num_to_mask):
        orig = masked[pos]
        prob = rng.random()
        if prob < 0.8:
            masked[pos] = tokenizer.mask_token_id
        elif prob < 0.9:
            masked[pos] = rng.randrange(tokenizer.vocab_size)
        # else: keep the original token (the model must still predict it)
        labels[pos] = orig
    return masked, labels


## 8) PyTorch Dataset for pretraining

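Dataset class that returns a dict with `input_ids`, `segment_labels`, `is_next`, `masked_lm_labels`. Masking is applied inside `__getitem__`, so every access draws a fresh random mask (dynamic masking).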

In [None]:
class BertPretrainingDataset(Dataset):
    """Map-style dataset yielding BERT pretraining examples (MLM + NSP).

    Each item is built when it is accessed, from one ``(a, b, is_next)`` pair.
    The pair is encoded with ``encode_pair`` and then masked with
    ``create_masked_lm_labels``, so a fresh random mask is drawn every time an
    index is read.
    """

    def __init__(self, pairs, tokenizer, max_len=128, masked_lm_prob=0.15, max_predictions=20):
        # Raw NSP triples and the tokenizer used to encode them.
        self.pairs = pairs
        self.tokenizer = tokenizer
        # Encoding / masking hyper-parameters forwarded to the helpers.
        self.max_len = max_len
        self.masked_lm_prob = masked_lm_prob
        self.max_predictions = max_predictions

    def __len__(self):
        """Return the number of sentence pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return the tensors for pair ``idx`` as a dict of long tensors.

        ``input_ids``, ``segment_labels`` and ``masked_lm_labels`` are 1-D;
        ``is_next`` is a 0-d tensor.
        """
        sent_a, sent_b, nsp_label = self.pairs[idx]
        token_ids, segment_ids = encode_pair(self.tokenizer, sent_a, sent_b, self.max_len)
        # Hand the masking helper a copy so token_ids is never altered in place.
        masked_ids, lm_labels = create_masked_lm_labels(
            list(token_ids), self.tokenizer, self.masked_lm_prob, self.max_predictions
        )
        fields = (
            ('input_ids', masked_ids),
            ('segment_labels', segment_ids),
            ('is_next', nsp_label),
            ('masked_lm_labels', lm_labels),
        )
        return {key: torch.tensor(value, dtype=torch.long) for key, value in fields}


## 9) Collate function (padding)

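Custom collate function that pads every sequence to the longest one in the batch. `input_ids` are padded with `pad_id`; `segment_labels` and `masked_lm_labels` are padded with 0. Padding `masked_lm_labels` with 0 matches the ignore value used during masking (adjust it if you change `ignore_index`). No attention mask is produced.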

In [None]:
from typing import List
def collate_fn(batch: List[dict]):
    """Pad a list of dataset items into one batch dict.

    Variable-length fields are right-padded to the longest sequence in the
    batch: ``input_ids`` with ``pad_id``, and ``segment_labels`` and
    ``masked_lm_labels`` with 0 (0 is also the MLM ignore value). The scalar
    ``is_next`` labels are stacked into a 1-D tensor.
    """
    def _pad(key, fill):
        # Stack field `key` of every sample into a (batch, max_len) tensor.
        return torch.nn.utils.rnn.pad_sequence(
            [sample[key] for sample in batch], batch_first=True, padding_value=fill
        )

    return {
        'input_ids': _pad('input_ids', pad_id),
        'segment_labels': _pad('segment_labels', 0),
        'is_next': torch.stack([sample['is_next'] for sample in batch]),
        'masked_lm_labels': _pad('masked_lm_labels', 0),
    }


## 10) Build Dataset and DataLoader

Create the PyTorch dataset, DataLoader and test one batch to verify shapes.

In [None]:
dataset = BertPretrainingDataset(
    pairs,
    tokenizer,
    max_len=PRE_TRAIN_SEQ_LEN,
    masked_lm_prob=MASKED_LM_PROB,
    max_predictions=MAX_PREDICTIONS_PER_SEQ,
)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# Smoke test: pull a single batch and print the shape of every tensor in it.
batch = next(iter(dataloader))
print({name: tensor.shape for name, tensor in batch.items()})

## 11) Notes / Next steps

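- To follow the Hugging Face convention, use `-100` as the ignore value for `masked_lm_labels`. Change it both where the labels are initialized in `create_masked_lm_labels` and in the padding value in `collate_fn`. Then update `bert/train.py` to use `nn.CrossEntropyLoss(ignore_index=-100)`.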
- To train: instantiate a pretraining model (e.g. your repo's BERT pretraining model), move it to `device`, then use the trainer with the created `dataloader`.
- For a large corpus, consider caching the tokenized pairs to disk so they are not re-tokenized every epoch. Masking can still be applied on the fly.