I created MarkedUpTextChunks quotes for all of Rashi on Tanakh.
Now I want to create a dataset for the labse-cnn quote segmentation model.
In addition, to train the model to recognize "Yeshibish" language, I composed a synthetic dataset from Talmud and Tosafot texts in Hebrew and English.

In [None]:
import json
import pickle
import random
from itertools import islice
from typing import Any, Dict, List, Tuple

from tokenizers import Tokenizer
from tqdm import tqdm

In [None]:
# Module-level tokenizer shared by every helper below; it is loaded from the
# serialized tokenizer definition stored in 'labse_tokenizer.json'.
TOKENIZER = Tokenizer.from_file('labse_tokenizer.json')

In [30]:
# Load the sentence data that is later passed to make_dataset_from_raw_sent.
# presumably a sequence of groups, each holding parallel variants of one
# sentence (the consumer is annotated List[List[str]]) — TODO confirm.
# NOTE(review): unpickling can execute arbitrary code; only load a
# "tuples.pkl" that comes from a trusted source.
with open("tuples.pkl", "rb") as f:
    raw_sents = pickle.load(f)

In [None]:
def sample_span(ids: List[int], min_len=3, max_len=12) -> Tuple[int,int]:
    """Choose a random contiguous window of ``ids`` as half-open ``(start, end)``.

    An empty ``ids`` yields ``(0, 0)``. A sequence with at most ``min_len``
    items is returned whole as ``(0, len(ids))``. Otherwise the window width
    is drawn uniformly from ``min_len`` up to ``min(max_len, len(ids))`` and
    the start offset uniformly among the positions where the window fits.
    Both draws use the module-level ``random`` generator.
    """
    total = len(ids)
    if not total:
        return 0, 0
    if total <= min_len:
        return 0, total

    widest = min(max_len, total)
    width = random.randint(min_len, widest)
    start = random.randint(0, total - width)
    end = start + width
    return start, end

def gather_background_ids(raw_sent: List[List[str]], min_tokens: int) -> List[int]:
    """Concatenate random sentence variants until tokenized length >= min_tokens.

    On every round a random group is picked from ``raw_sent`` and a random
    variant from that group; blank variants and empty groups are skipped.
    The accepted variants are joined with single spaces and the whole string
    is re-encoded with ``TOKENIZER`` (no special tokens), so the returned ids
    are the tokenization of the concatenated text.

    Args:
        raw_sent: groups of sentence variants to sample from.
        min_tokens: minimum number of token ids to collect.

    Returns:
        Token ids of the accumulated background text; an empty list when
        ``min_tokens`` is not positive.

    Raises:
        ValueError: if ``min_tokens`` is positive but ``raw_sent`` holds no
            non-blank variant, in which case the sampling loop could never
            terminate.
    """
    # Without at least one usable variant the loop below would spin forever
    # (or fail on an empty choice), so fail fast with a clear message.
    # any() stops at the first usable variant, so this is cheap on real data.
    if min_tokens > 0 and not any(
        variant and variant.strip()
        for group in raw_sent
        for variant in (group or ())
    ):
        raise ValueError("raw_sent contains no non-empty sentence variant")

    acc: List[str] = []
    ids: List[int] = []
    while len(ids) < min_tokens:
        group = random.choice(raw_sent)  # pick a parallel sentence set
        if not group:
            continue
        variant = (random.choice(group) or "").strip()  # any language variant allowed
        if not variant:
            continue
        acc.append(variant)
        # Re-encode the full text instead of appending ids: tokenizing the
        # joined string is what defines the background token sequence.
        ids = TOKENIZER.encode(" ".join(acc), add_special_tokens=False).ids
    return ids

def insert_tokens(bg_ids: List[int], span_ids: List[int]) -> Tuple[List[int], int]:
    """Splice ``span_ids`` into ``bg_ids`` at a random offset.

    Returns the combined token list together with the index at which the
    span begins. With an empty background the result is a copy of the span
    starting at index 0 and no random number is drawn. Otherwise the offset
    is drawn uniformly from ``0`` to ``len(bg_ids)`` inclusive, so the span
    may also land before the first or after the last background token.
    """
    if len(bg_ids) == 0:
        return list(span_ids), 0

    cut = random.randint(0, len(bg_ids))
    head, tail = bg_ids[:cut], bg_ids[cut:]
    merged = [*head, *span_ids, *tail]
    return merged, cut

def _gather_hebrew_background_ids(raw_sent: List[List[str]], min_tokens: int) -> List[int]:
    """Concatenate random Hebrew variants (index 0 of a group) until >= min_tokens tokens."""
    # Fail fast instead of looping forever when no group has usable Hebrew text.
    if min_tokens > 0 and not any(g and g[0] and g[0].strip() for g in raw_sent):
        raise ValueError("raw_sent contains no non-empty Hebrew (index 0) variant")

    acc: List[str] = []
    ids: List[int] = []
    while len(ids) < min_tokens:
        g = random.choice(raw_sent)
        he = ((g[0] if g else "") or "").strip()
        if not he:
            continue
        acc.append(he)
        ids = TOKENIZER.encode(" ".join(acc), add_special_tokens=False).ids
    return ids


def make_dataset_from_raw_sent(
    raw_sent: List[List[str]],
    n_samples: int = 20000,
    max_query_len: int = 128,
    max_target_len: int = 480,
    ensure_bg_tokens: int = 200,
    neg_ratio: float = 0.2,
    seed: int = 42,
    per_query_multiplier: int = 1,
    targets_hebrew_only: bool = False
) -> List[Dict[str, Any]]:
    """
    Build synthetic (query, target, mask) triples:
      - query = full sentence (any language) from raw_sent, tokenized and
        truncated to max_query_len
      - choose a random subspan of the query tokens (3 to 12 tokens)
      - build a random background (any language mix by default, Hebrew only
        when targets_hebrew_only is set) and insert the span (unless negative)

    Args:
        raw_sent: groups of sentence variants; index 0 of a group is treated
            as the Hebrew variant when targets_hebrew_only is set.
        n_samples: number of query draws. Draws that hit an empty group, a
            blank sentence or an empty tokenization are skipped, so fewer
            queries may be produced.
        max_query_len: maximum number of query token ids kept.
        max_target_len: maximum number of target token ids kept; a span that
            falls beyond this limit is cut off together with its mask.
        ensure_bg_tokens: minimum background length in tokens; the actual
            minimum is max(ensure_bg_tokens, 3 * len(query ids)).
        neg_ratio: probability that a sample is a negative, i.e. plain
            background with an all-zero mask.
        seed: value passed to random.seed(); note that this reseeds the
            module-level generator for the whole process.
        per_query_multiplier: number of samples generated per accepted query.
        targets_hebrew_only: build backgrounds from Hebrew variants only.

    Returns:
        A list of dicts with the keys "query_tokenized", "target_tokenized"
        and "target_mask"; the mask has one 0/1 entry per target token and
        marks the inserted span with 1.

    Raises:
        ValueError: if a background cannot be built because raw_sent has no
            usable text for the requested mode.
    """
    random.seed(seed)
    out: List[Dict[str, Any]] = []
    if not raw_sent:
        return out

    for _ in tqdm(range(n_samples)):
        group = random.choice(raw_sent)
        if not group:
            continue

        # pick ANY language sentence as query
        q_text = random.choice(group)
        if not q_text or not q_text.strip():
            continue

        q_ids: List[int] = TOKENIZER.encode(q_text, add_special_tokens=False).ids[:max_query_len]
        if not q_ids:
            continue

        # Same for every repetition of this query, so compute it once.
        bg_min_tokens = max(ensure_bg_tokens, len(q_ids) * 3)

        for _rep in range(per_query_multiplier):
            # background ids
            if targets_hebrew_only:
                bg_ids = _gather_hebrew_background_ids(raw_sent, bg_min_tokens)
            else:
                bg_ids = gather_background_ids(raw_sent, min_tokens=bg_min_tokens)

            # quote span (from the query tokens); the fallback is defensive,
            # sample_span returns a non-empty range for non-empty q_ids.
            s, e = sample_span(q_ids, min_len=3, max_len=12)
            span_ids = q_ids[s:e] if e > s else q_ids[:min(6, len(q_ids))]

            # negatives some of the time (the random draw happens for every
            # sample so the RNG sequence does not depend on the outcome)
            is_negative = (random.random() < neg_ratio or not span_ids)

            if is_negative:
                tgt_ids = bg_ids[:max_target_len]
                mask = [0] * len(tgt_ids)
            else:
                tgt_all, insert_pos = insert_tokens(bg_ids, span_ids)
                tgt_ids = tgt_all[:max_target_len]
                kept = len(tgt_ids)
                # Clamp the span to what survived truncation; a span cut off
                # entirely leaves an all-zero mask.
                mask_start = min(insert_pos, kept)
                mask_end = min(insert_pos + len(span_ids), kept)
                mask = (
                    [0] * mask_start
                    + [1] * (mask_end - mask_start)
                    + [0] * (kept - mask_end)
                )

            out.append({
                # copy so samples of the same query do not share one list
                "query_tokenized": list(q_ids),
                "target_tokenized": tgt_ids,
                "target_mask": mask
            })

    return out

In [36]:
# Build the synthetic dataset from the pickled sentence groups.
# Each accepted query contributes per_query_multiplier samples; draws whose
# query turns out empty are skipped, so the result can be shorter than
# n_samples * per_query_multiplier.
dataset = make_dataset_from_raw_sent(
    raw_sents,
    n_samples=50001,
    max_query_len=512,
    max_target_len=512,
    ensure_bg_tokens=200,
    neg_ratio=0.3,
    seed=123,
    per_query_multiplier=3,
    targets_hebrew_only=False
)

# Show the dataset size and the first sample.
len(dataset), dataset[0]

100%|██████████| 50001/50001 [04:31<00:00, 184.15it/s] 


(150000,
 {'query_tokenized': [18225,
   39706,
   2976,
   189,
   15096,
   85687,
   208251,
   14978,
   15127,
   17115,
   14999,
   424228,
   73038,
   15121,
   15179,
   15384,
   19097,
   27467,
   49504,
   111527,
   16652,
   41781,
   19559,
   15384,
   102383,
   111527,
   16652,
   15424,
   117,
   15595,
   16068,
   252273,
   16143,
   16207,
   15002,
   179469,
   119],
  'target_tokenized': [17192,
   14985,
   133670,
   14978,
   117,
   19576,
   14986,
   15100,
   15179,
   34908,
   119,
   15385,
   118,
   42574,
   14986,
   15121,
   170,
   158122,
   188838,
   22684,
   15131,
   170,
   153027,
   15595,
   27067,
   400431,
   15221,
   118,
   230098,
   15294,
   391170,
   355774,
   15058,
   17192,
   456093,
   117,
   208772,
   57459,
   15002,
   15713,
   14999,
   424228,
   73038,
   15121,
   15179,
   15384,
   19097,
   27467,
   101700,
   119,
   17192,
   15294,
   16651,
   15100,
   400431,
   14986,
   15308,
   14997,
   1

In [38]:
# Read the first 50,000 samples of the previously built Rashi-on-Tanakh
# dataset (one JSON object per line). The explicit encoding matches the one
# used when this notebook writes its own JSONL file.
with open("dataset-tanakh-rashi-labse-tokenized.jsonl", "r", encoding="utf-8") as f:
    loaded_data = [json.loads(doc) for doc in islice(f, 50000)]

loaded_data[0]

{'query_tokenized': [138,
  15160,
  48350,
  15202,
  144806,
  23797,
  20800,
  91999,
  45741,
  19143,
  26567,
  18491,
  50962,
  17728,
  15897,
  195,
  99550,
  155527,
  283209,
  19143,
  434340,
  15897,
  15214,
  50896,
  170478,
  14991,
  15214,
  99550,
  27579,
  16163,
  178,
  175593,
  29381,
  18902,
  15214,
  99550,
  33920,
  286186,
  178,
  21664,
  245355,
  379,
  224346,
  500744,
  15105,
  15214,
  99550,
  29661,
  83016,
  15160,
  499627,
  18588,
  17728,
  170,
  14977,
  29661,
  183476,
  85433,
  15160,
  159281,
  171717,
  352365],
 'target_tokenized': [919,
  211283,
  48637,
  31993,
  16153,
  926,
  116109,
  18721,
  48824,
  78521,
  15578,
  919,
  85778,
  43015,
  923,
  382504,
  85845,
  23552,
  27972,
  459698,
  211283,
  15578,
  24737,
  18240,
  33641,
  17233,
  42408,
  240460,
  255988,
  33270],
 'target_mask': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,

In [41]:
# Append the samples loaded from the existing JSONL file to the synthetic
# ones (in place) and show the combined size.
dataset.extend(loaded_data)
len(dataset)

200000

In [43]:
# Print the first sample before shuffling, shuffle in place, then display the
# new first sample.
# NOTE(review): random.shuffle uses the module-level generator, which
# make_dataset_from_raw_sent seeded and then advanced; the resulting order
# therefore depends on the cells run before this one. Re-seed here if a
# reproducible shuffle is needed.
print(dataset[0])
random.shuffle(dataset)
dataset[0]


{'query_tokenized': [332915, 118, 939, 44358, 156681, 939, 15600, 117040, 317370, 15769, 33552, 939, 246039, 45130, 320554, 15769, 119], 'target_tokenized': [22810, 14985, 26160, 15121, 14985, 189386, 22426, 14986, 30479, 107, 15071, 112, 64174, 18227, 289016, 14978, 107, 14986, 21985, 153767, 15121, 112, 14985, 63730, 15053, 22954, 15096, 170, 289016, 14978, 117, 14999, 15384, 61448, 15096, 170, 112, 289016, 14978, 112, 119, 119, 119, 17192, 360536, 15015, 170, 33408, 117, 16125, 15751, 39267, 15015, 14985, 79498, 127414, 14981, 17054, 46941, 117, 14999, 16290, 15179, 424526, 16068, 15179, 62528, 15294, 15595, 14986, 14985, 16122, 117, 15751, 99061, 117, 25549, 15751, 33608, 180756, 15438, 15165, 112, 16013, 29992, 15294, 14985, 71158, 14997, 14985, 463551, 17697, 15235, 119, 19895, 15036, 499562, 310378, 21881, 117, 18557, 15017, 16728, 166656, 92806, 15000, 15050, 142296, 117, 16728, 166656, 92806, 15000, 15050, 28701, 123899, 15000, 15899, 117, 18426, 15006, 24064, 166656, 138820, 

{'query_tokenized': [35379,
  916,
  211327,
  20259,
  937,
  163588,
  31520,
  15600,
  136204,
  15578,
  302010,
  926,
  108654,
  365776,
  55514,
  379816,
  108654,
  35379,
  931,
  15600,
  55514,
  74624,
  408773,
  20557,
  55514,
  165316,
  143486,
  136204,
  15578,
  302010,
  74624,
  408773,
  35379,
  931,
  15600,
  55514,
  926,
  108654],
 'target_tokenized': [20557,
  35930,
  939,
  375957,
  918,
  391272,
  15578,
  930,
  28621,
  236697,
  15001,
  41469,
  17417,
  133670,
  2769,
  268011,
  16446,
  408112,
  30946,
  22738,
  131,
  15088,
  39205,
  14997,
  133670,
  24096,
  15011,
  32250,
  14981,
  170,
  407621,
  21065,
  29455,
  15222,
  15001,
  471073,
  235947,
  61436,
  15058,
  29455,
  117,
  14999,
  14985,
  39205,
  14997,
  14985,
  133670,
  20557,
  55514,
  165316,
  143486,
  136204,
  15578,
  302010,
  74624,
  14978,
  14981,
  99555,
  119,
  919,
  172742,
  39699,
  31520,
  915,
  117040,
  134304,
  33331,
  210256,
  5

Then I saved the dataset as JSONL.

In [44]:
# Serialize the combined dataset as JSON Lines: one sample per line.
with open("dataset-yeshibish-labse.jsonl", "w", encoding="utf-8") as f:
    f.writelines(
        json.dumps(sample, ensure_ascii=False) + "\n" for sample in dataset
    )

The dataset consists of already-tokenized sentences for the following reason. After training, we want to export the model to ONNX format for a fast cold start on the server, so that we neither have to import torch and transformers (which takes a long time) nor keep the model warm (which takes up memory on the server). For that, the tokenizer has to be kept outside the model: it can be saved to a JSON file and loaded in Python with tokenizers.Tokenizer in milliseconds, and the same goes for the ONNX model.

In [7]:
# Export the tokenizer definition to a standalone JSON file.
# NOTE(review): TOKENIZER is created earlier in this notebook with
# Tokenizer.from_file(...); `backend_tokenizer` looks like an attribute of a
# transformers fast tokenizer, so this cell presumably ran while TOKENIZER
# was such an object. With a tokenizers.Tokenizer this would likely have to
# be TOKENIZER.save(...) — TODO confirm.
# NOTE(review): the output name differs from 'labse_tokenizer.json', the file
# loaded earlier — confirm which file name is intended.
TOKENIZER.backend_tokenizer.save("xlm-roberta-tokenizer.json")