We strongly encourage users to perform model training using our script `tissuenarrator/train.py`.  
However, a Jupyter notebook version is also provided for exploratory and demonstration purposes.


In [5]:
import json
import os
import re

import pandas as pd
from datasets import Dataset, Features, Sequence, Value, load_from_disk
from tqdm import tqdm

# Keep unsloth imported before transformers (Unsloth recommends this so its patches apply).
from unsloth import FastLanguageModel, UnslothTrainer, UnslothTrainingArguments
from transformers import (
    DataCollatorForSeq2Seq,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)

In [9]:
# Context window (in tokens) shared by model loading, chunking and the trainer.
max_seq_length = 32_000

# Load the Qwen3-4B base checkpoint and its tokenizer through Unsloth without
# 4-bit or 8-bit quantization and with full fine-tuning disabled (LoRA
# adapters are attached in the following cell).
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Qwen3-4B-Base",
    load_in_4bit=False,
    load_in_8bit=False,
    full_finetuning=False,
    max_seq_length=max_seq_length,
)

==((====))==  Unsloth 2025.10.9: Fast Qwen3 patching. Transformers: 4.55.2. vLLM: 0.10.1.1.
   \\   /|    NVIDIA RTX A6000. Num GPUs = 1. Max memory: 47.428 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.1+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.31. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.41it/s]


In [10]:
# Attach LoRA adapters (rank 32, alpha 32, no dropout, no bias terms) to the
# attention (q/k/v/o) and MLP (gate/up/down) projections. Unsloth's own
# gradient-checkpointing mode is used to save memory at this long context
# length; rank-stabilized LoRA and LoftQ initialization are left off, and a
# fixed random_state keeps the adapter setup reproducible.
model = FastLanguageModel.get_peft_model(
    model,
    r=32,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
        "gate_proj", "up_proj", "down_proj",     # MLP projections
    ],
    use_gradient_checkpointing="unsloth",
    use_rslora=False,
    loftq_config=None,
    random_state=3407,
)

In [11]:
# Parquet file holding the serialized spatial data; it must provide the
# "sentence" and "split" columns consumed by the chunking cell. Set the
# MERFISH_SPATIAL_DF environment variable to point elsewhere instead of
# editing this cell.
DATA_PATH = os.environ.get(
    "MERFISH_SPATIAL_DF",
    "/home/sizheliu/spatial-text/data/merfish/merfish_all_spatial_df.parquet",
)
# Demo subsample: only the first N_SAMPLES rows are used so the notebook runs
# quickly. Set to None to use the whole dataframe.
N_SAMPLES = 100

df = pd.read_parquet(DATA_PATH)
if N_SAMPLES is not None:
    df = df.head(N_SAMPLES)

In [None]:
# Matches a coordinate field such as "X: 12" or "Y:-3.5". The named group
# "num" spans only the numeric value (sign included), so it can be masked
# directly no matter how much whitespace follows the colon.
COORD_RE = re.compile(r'(X|Y):\s*(?P<num>-?\d+(?:\.\d+)?)')


def _pack_sentences(sent_lens, start, budget):
    """Return the exclusive end index of the greedy run of sentences beginning
    at ``start`` whose summed token counts stay within ``budget``."""
    end = start
    used = 0
    while end < len(sent_lens) and used + sent_lens[end] <= budget:
        used += sent_lens[end]
        end += 1
    return end


def _mask_coordinate_labels(chunk_text, input_ids, offsets):
    """Copy ``input_ids`` into a label list in which every token whose
    character span overlaps a coordinate number is replaced by -100."""
    # Per-character flag: 1 where the character belongs to a coordinate value.
    in_coord = bytearray(len(chunk_text))
    for m in COORD_RE.finditer(chunk_text):
        s, e = m.span("num")
        in_coord[s:e] = b"\x01" * (e - s)

    labels = list(input_ids)
    for k, (s_char, e_char) in enumerate(offsets):
        # Any overlap masks the token, so tokens that merge a digit with a
        # neighbouring character cannot leak part of a coordinate.
        # Zero-width (special) tokens have an empty slice and are kept.
        if any(in_coord[s_char:e_char]):
            labels[k] = -100
    return labels


def split_and_mask(
    text,
    tokenizer,
    max_seq_length=32000,
    overlap=2,          # number of sentences to repeat between chunks
    min_length=100        # minimum token count to keep a chunk (0 = keep all)
):
    """Split ``text`` into token-budgeted, overlapping chunks with coordinate labels masked.

    Sentences are the non-greedy ``<pos>...</cs>`` spans found in ``text``.
    Consecutive sentences are packed greedily until their summed token counts
    (no special tokens) would exceed ``max_seq_length``. Each new chunk
    re-includes the last ``overlap`` sentences of the previous chunk. If
    those sentences alone fill the budget, the overlap is dropped so that
    every chunk contributes at least one new sentence. A single sentence
    longer than ``max_seq_length`` becomes its own chunk and is truncated by
    the tokenizer.

    Args:
        text: Serialized input string; ``None`` or ``""`` yields ``[]``.
        tokenizer: Callable tokenizer that supports ``add_special_tokens``,
            ``return_offsets_mapping``, ``truncation`` and ``max_length``
            (a fast Hugging Face tokenizer).
        max_seq_length: Token budget per chunk; also the truncation length.
        overlap: Sentences repeated between consecutive chunks
            (negative values are treated as 0).
        min_length: Chunks with fewer tokens (no special tokens) are
            discarded; ``0`` or less keeps every chunk.

    Returns:
        A list of dicts, each with ``input_ids``, ``labels`` (coordinate-value
        tokens set to -100 so no loss is computed on them) and
        ``attention_mask`` (all ones).
    """
    if not text:
        return []
    sentences = re.findall(r"<pos>.*?</cs>", text, flags=re.DOTALL)
    n = len(sentences)
    # Tokenize each sentence once; overlapping chunks revisit the same sentences.
    sent_lens = [
        len(tokenizer(sent, add_special_tokens=False)["input_ids"])
        for sent in sentences
    ]
    overlap = max(0, overlap)

    results = []
    i = 0  # index of the first sentence not yet covered by any chunk
    while i < n:
        start = max(0, i - overlap)
        end = _pack_sentences(sent_lens, start, max_seq_length)
        if end <= i and start < i:
            # The overlap sentences alone exhaust the budget: drop the overlap
            # so this chunk still makes progress.
            start = i
            end = _pack_sentences(sent_lens, start, max_seq_length)
        if end <= i:
            # Sentence i by itself exceeds max_seq_length; emit it alone and
            # let truncation below cut it (otherwise the loop would never end).
            end = i + 1
        i = end  # advance before any `continue` so the loop always terminates

        chunk_text = " ".join(sentences[start:end])
        if min_length > 0:
            enc_len = len(tokenizer(chunk_text, add_special_tokens=False)["input_ids"])
            if enc_len < min_length:
                continue

        enc = tokenizer(
            chunk_text,
            return_offsets_mapping=True,
            add_special_tokens=True,
            truncation=True,
            max_length=max_seq_length,
        )
        input_ids = enc["input_ids"]
        results.append({
            "input_ids": input_ids,
            "labels": _mask_coordinate_labels(chunk_text, input_ids, enc["offset_mapping"]),
            "attention_mask": [1] * len(input_ids),
        })

    return results

In [13]:
# Chunk every row's sentence text and tag each resulting chunk with the
# row's split label.
all_records = []
for sentence_text, split_name in tqdm(
    zip(df["sentence"], df["split"]), total=len(df), desc="Splitting & Masking"
):
    for chunk in split_and_mask(sentence_text, tokenizer, max_seq_length=max_seq_length):
        chunk["split"] = split_name
        all_records.append(chunk)

# Explicit schema: token-level columns as integer sequences, split as a string.
features = Features({
    "input_ids": Sequence(Value("int32")),
    "labels": Sequence(Value("int32")),
    "attention_mask": Sequence(Value("int8")),
    "split": Value("string"),
})

# Column-oriented view of the records, in the schema's column order.
hf_dataset = Dataset.from_dict(
    {column: [record[column] for record in all_records] for column in features},
    features=features,
)

Splitting & Masking: 100%|██████████| 100/100 [00:18<00:00,  5.41it/s]


In [14]:
# Keep only the chunks whose "split" column is "train".
train_dataset = hf_dataset.filter(
    lambda split_value: split_value == "train",
    input_columns="split",
)

Filter: 100%|██████████| 142/142 [00:02<00:00, 49.28 examples/s]


In [16]:
# Train the LoRA adapters on the "train" chunks.
# The dataset is already tokenized and carries its own `labels`, in which
# coordinate tokens are -100. DataCollatorForSeq2Seq pads those labels with
# -100 and keeps them as provided. A default language-modeling collator may
# instead rebuild labels from `input_ids`, which would discard the coordinate
# masking. The string "split" column is dropped so the collator only sees
# token features.
trainer = UnslothTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset.remove_columns("split"),
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    args=UnslothTrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch of 16 sequences per device
        warmup_ratio=0.01,
        num_train_epochs=1,
        learning_rate=2e-4,             # consider ~2e-5 for long training runs
        logging_steps=100,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        seed=3407,
        output_dir="./test_train",
        report_to="none",               # disable WandB / other experiment trackers
        # Checkpoint and logging schedule
        save_strategy="steps",
        save_steps=500,
        save_total_limit=50,
        logging_strategy="steps",
    ),
)
stats = trainer.train()