# Setup environment

In [None]:
# Clone the tutorial repository and build its uv virtual environment.
!git clone https://github.com/MLP-Lab/KORMo-tutorial.git
# Use `&&`, not `&`: a single `&` backgrounds the `cd`, so the setup script would
# run from the wrong directory. `&&` runs it inside the repo, and only if `cd` succeeded.
!cd KORMo-tutorial && bash setup/create_uv_venv.sh
# NOTE(review): each `!` command runs in its own subshell, so this activation is
# not expected to carry over to later cells or to the running kernel — confirm,
# and select the virtual environment as the notebook kernel if needed.
!source .venv_kormo/bin/activate

In [None]:
import os
# Make only GPU index 0 visible to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# 1. Preparing the Dataset for Pretraining

This section builds the dataset pipeline for language model pretraining.

1. Load the raw dataset
2. Tokenize the text
3. Pack sequences
4. Build the data collator
5. Verify the processed data with a DataLoader

### Raw text $\to$ Tokenize $\to$ Sequence packing $\to$ Data Collation

In [None]:
# Load the 'pretrain' configuration (train split) of the tutorial dataset
# repository 'KORMo-Team/KORMo-tutorial-datasets'.
import datasets

dataset_repo_id = "KORMo-Team/KORMo-tutorial-datasets"

pt_dataset = datasets.load_dataset(dataset_repo_id, name="pretrain", split="train")

# Shuffle with a fixed seed so the example order is the same on every run.
train_ds = pt_dataset.shuffle(seed=42)
print(train_ds)

In [None]:
# Load the tokenizer that will encode the corpus.
from transformers import AutoTokenizer

tokenizer_repo_id = "KORMo-Team/KORMo-tokenizer"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo_id)

In [None]:
# Tokenize the entire dataset

def _tokenize(examples, tokenizer):
    """Encode a batch of raw texts into token-id lists, each ending with EOS.

    Args:
        examples: Batch mapping whose ``'text'`` entry is an iterable of strings.
        tokenizer: Object providing ``encode(text)`` and ``eos_token_id``.

    Returns:
        Dict with a single key ``'input_ids'`` holding one token-id list per
        input text, in the same order.
    """
    return {
        "input_ids": [
            # Append EOS so document boundaries survive later concatenation.
            tokenizer.encode(text) + [tokenizer.eos_token_id]
            for text in examples["text"]
        ]
    }

# Run `_tokenize` over the shuffled corpus in batches across worker processes.
# The original columns are removed, so only 'input_ids' remains.
tokenized_ds = train_ds.map(
    _tokenize,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=train_ds.column_names,
    batched=True,
    num_proc=48,
)
print(tokenized_ds)

In [None]:
# Round-trip check: decode one tokenized example back into text.
print(tokenizer.decode(tokenized_ds[3]["input_ids"]))

In [None]:
from itertools import chain

def _pack_dataset(examples, seq_len):
    """Concatenate a batch of token sequences and cut it into fixed-length rows.

    Every list in ``examples["input_ids"]`` is joined end to end, and the
    resulting stream is split into consecutive chunks of exactly ``seq_len``
    tokens. Trailing tokens that cannot fill a whole chunk are dropped.

    Args:
        examples: Batch mapping with an ``"input_ids"`` list of token-id lists.
        seq_len: Number of tokens in every packed row.

    Returns:
        Dict with key ``"input_ids"`` holding the list of packed rows.
    """
    token_stream = list(chain.from_iterable(examples["input_ids"]))
    num_rows = len(token_stream) // seq_len
    # Walk the stream in strides of seq_len, stopping before the partial tail.
    row_starts = range(0, num_rows * seq_len, seq_len)
    return {"input_ids": [token_stream[start:start + seq_len] for start in row_starts]}

def pack_dataset(ds, seq_len, batch_size=100_000, num_proc=128):
    """Pack a tokenized dataset into rows of exactly ``seq_len`` tokens.

    Applies ``_pack_dataset`` to ``ds`` in batches via ``ds.map`` and removes
    all original columns, so the result only contains the packed
    ``"input_ids"`` rows.

    Args:
        ds: Dataset exposing ``map`` and ``column_names`` whose examples carry
            an ``"input_ids"`` token list.
        seq_len: Number of tokens in every packed row.
        batch_size: Number of examples concatenated per packing call. The
            tokens left over at the end of each batch are dropped by
            ``_pack_dataset``, so larger batches discard fewer tokens.
        num_proc: Number of worker processes handed to ``ds.map``. The default
            keeps the original value; lower it on machines with fewer cores.

    Returns:
        The packed dataset returned by ``ds.map``.
    """
    return ds.map(
        _pack_dataset,
        batched=True,
        batch_size=batch_size,
        remove_columns=ds.column_names,
        num_proc=num_proc,
        fn_kwargs={"seq_len": seq_len},
    )

# Pack the tokenized corpus into 4096-token rows and have the dataset return
# torch tensors when indexed.
packed_ds = pack_dataset(tokenized_ds, seq_len=4096)
packed_ds.set_format("torch")
print(packed_ds)

In [None]:
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import torch

K = 1024

@dataclass
class DataCollatorForCausalLM:
    """Collate token sequences into a right-padded causal-LM batch.

    Each instance's ``"input_ids"`` tensor is truncated to ``max_length``
    tokens, and the batch is padded on the right to its longest sequence.
    ``labels`` equal ``input_ids`` except at padded positions, which hold
    ``ignore_index``.

    Attributes:
        tokenizer: Supplies ``pad_token_id``, the fill value for ``input_ids``.
        max_length: Maximum number of tokens kept per instance (default 4096).
        ignore_index: Label value written at padded positions. The default
            -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    """

    # Quoted annotation: it is only a type hint, so defining the class does not
    # need the name to be resolvable at runtime.
    tokenizer: "PreTrainedTokenizer"
    max_length: int = 4 * 1024
    ignore_index: int = -100

    def __call__(self, instances):
        """Build the ``input_ids`` and ``labels`` tensors for one batch.

        Args:
            instances: Sequence of mappings, each with a 1-D ``"input_ids"``
                tensor.

        Returns:
            Dict with ``input_ids`` and ``labels``, both shaped
            (batch, longest truncated sequence).
        """
        sequences = [instance["input_ids"][: self.max_length] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # Pad the labels positionally with ignore_index rather than masking by
        # token value: a value-based mask would also hide genuine tokens whose
        # id equals pad_token_id (for example when pad and EOS share an id).
        labels = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=self.ignore_index
        )

        return dict(
            input_ids=input_ids,
            labels=labels,
        )

In [None]:
from torch.utils.data import DataLoader

# Sanity check: push the packed dataset through the causal-LM collator and
# display the first batch.
collator = DataCollatorForCausalLM(tokenizer=tokenizer)
data_loader = DataLoader(dataset=packed_ds, batch_size=4, collate_fn=collator)
next(iter(data_loader))

# 2. Build Intra-document Attention Mask (Using Flex-attention)

![attention_mask.png](./attachment/attention_mask.png)

In [None]:
from torch.nn.attention.flex_attention import create_block_mask, and_masks
import torch
# Take the first packed row and give it a leading batch dimension on the GPU.
input_ids = packed_ds[0]["input_ids"]
input_ids_2d = torch.unsqueeze(input_ids, 0).to("cuda")

def _intra_doc_mask(input_ids, bos_token_id):
    is_bos = (input_ids == bos_token_id)
    is_bos_flat = is_bos.flatten()
    flat_doc_ids = torch.cumsum(is_bos_flat.long(), 0)
    doc_ids = flat_doc_ids.view_as(input_ids)
    
    def intra_doc_mask(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc
    
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    return and_masks(intra_doc_mask, causal_mask)

def create_intra_doc_mask(input_ids, tokenizer, device="cuda"):
    """Create a flex-attention block mask for causal, intra-document attention.

    Args:
        input_ids: 2-D tensor of token ids shaped (batch, sequence length).
        tokenizer: Object whose ``bos_token_id`` marks document starts.
        device: Device on which the mask is built. Defaults to ``"cuda"``,
            the device the original implementation always used.

    Returns:
        The block mask produced by ``create_block_mask``, with equal query and
        key/value lengths.

    Raises:
        ValueError: If the tokenizer has no ``bos_token_id``, since document
            boundaries cannot be located without it.
    """
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        raise ValueError(
            "tokenizer.bos_token_id is None; cannot locate document boundaries"
        )

    batch_size, seq_len = input_ids.shape
    mask_mod_func = _intra_doc_mask(input_ids.to(device), bos_token_id)

    return create_block_mask(
        mask_mod=mask_mod_func,
        B=batch_size,
        H=None,  # the mask functions never read the head index
        Q_LEN=seq_len,
        KV_LEN=seq_len,
        device=device,
    )

# Show the block mask built for the single-row batch prepared above.
print(create_intra_doc_mask(input_ids=input_ids_2d, tokenizer=tokenizer))

In [None]:
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

@dataclass
class DataCollatorIntraDocMask:
    """Collate token sequences and attach an intra-document attention mask.

    Each instance's ``"input_ids"`` tensor is truncated to ``max_length``
    tokens, and the batch is padded on the right to its longest sequence.
    ``labels`` equal ``input_ids`` except at padded positions, which hold
    ``ignore_index``. The batch also carries the block mask returned by
    ``create_intra_doc_mask`` under ``attention_mask``.

    Attributes:
        tokenizer: Supplies ``pad_token_id`` for padding and is passed on to
            ``create_intra_doc_mask``.
        max_length: Maximum number of tokens kept per instance (default 4096).
        ignore_index: Label value written at padded positions. The default
            -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
    """

    # Quoted annotation: it is only a type hint, so defining the class does not
    # need the name to be resolvable at runtime.
    tokenizer: "PreTrainedTokenizer"
    max_length: int = 4 * 1024
    ignore_index: int = -100

    def __call__(self, instances):
        """Build ``input_ids``, ``labels`` and ``attention_mask`` for one batch.

        Args:
            instances: Sequence of mappings, each with a 1-D ``"input_ids"``
                tensor.

        Returns:
            Dict with ``input_ids`` and ``labels`` tensors shaped
            (batch, longest truncated sequence) and the ``attention_mask``
            block mask built from ``input_ids``.
        """
        sequences = [instance["input_ids"][: self.max_length] for instance in instances]
        input_ids = pad_sequence(
            sequences, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        # Pad the labels positionally with ignore_index rather than masking by
        # token value: a value-based mask would also hide genuine tokens whose
        # id equals pad_token_id (for example when pad and EOS share an id).
        labels = pad_sequence(
            sequences, batch_first=True, padding_value=self.ignore_index
        )

        block_mask = create_intra_doc_mask(input_ids, self.tokenizer)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=block_mask,
        )


In [None]:
# Build one batch with the intra-document collator and print its block mask.
collator = DataCollatorIntraDocMask(tokenizer=tokenizer)
data_loader = DataLoader(dataset=packed_ds, batch_size=2, collate_fn=collator)
batch = next(iter(data_loader))
print("Visualize Attention Masks\n", batch["attention_mask"])

# 3. Pretraining Setup with KORMoTrainer

In [None]:
from kormo.train.arguments import KORMoTrainingArguments
from kormo.train.trainer import KORMoTrainer
from kormo.modeling_configs.load_model import load_model_from_config

# Build the '1B' model configuration with the flex-attention implementation,
# move it to the GPU, and report which attention implementation is active.
model, _ = load_model_from_config("1B", _attn_implementation="flex_attention")
model.to("cuda")
print("Attention implementation: ", model.config._attn_implementation)

In [None]:
# Training hyperparameters; any option not listed here keeps the
# KORMoTrainingArguments default.
training_arguments = KORMoTrainingArguments(
    output_dir="./kormo-1B-PT",
    per_device_train_batch_size=4,
    lr_scheduler_type="linear",
    save_strategy="epoch",
    logging_steps=10,
)

# Wire the model, the packed dataset and the intra-document collator into the trainer.
trainer = KORMoTrainer(
    model=model,
    args=training_arguments,
    train_dataset=packed_ds,
    data_collator=DataCollatorIntraDocMask(tokenizer),
    processing_class=tokenizer,
)

In [None]:
# Start pretraining with the trainer configured above.
trainer.train()