# Setup environment

In [None]:
!git clone https://github.com/MLP-Lab/KORMo-tutorial.git
!cd KORMo-tutorial & bash setup/create_uv_venv.sh
!source .venv_kormo/bin/activate

In [1]:
import os
import sys
# Expose only GPU 0 to this process. This has to run before any library
# initializes CUDA, which is why it sits in the first code cell.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Location of the tutorial's `src` directory (Colab layout), so `kormo` is importable.
KORMO_SRC_DIR = '/content/KORMo-tutorial/src'
# Guarded so that re-running this cell does not keep appending duplicates.
if KORMO_SRC_DIR not in sys.path:
    sys.path.append(KORMO_SRC_DIR)

# 1. Preparing the Dataset for Pretraining

This section builds the dataset pipeline for language model pretraining.

1. Load the raw dataset
2. Tokenize the text
3. Pack sequences
4. Build the data collator
5. Verify the processed data with a DataLoader

### Raw text $\to$ Tokenize $\to$ Sequence packing $\to$ Data Collation

In [2]:
# Load the raw dataset from 'KORMo-Team/KORMo-tutorial-datasets'.
import datasets 

# Hub repository id of the tutorial datasets.
dataset_repo_id = 'KORMo-Team/KORMo-tutorial-datasets'

# Load the 'train' split of the 'pretrain' configuration.
pt_dataset = datasets.load_dataset(path=dataset_repo_id, name='pretrain', split='train')

# Shuffle with a fixed seed so the row order is the same on every run.
train_ds = pt_dataset.shuffle(seed=42)
print(train_ds)

Dataset({
    features: ['text'],
    num_rows: 16000
})


In [3]:
# Load tokenizer
from transformers import AutoTokenizer 

# Hub repository id of the tokenizer to download.
tokenizer_repo_id = 'KORMo-Team/KORMo-tokenizer'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_repo_id)

In [4]:
# Tokenize dataset

def _tokenize(examples, tokenizer):
    input_ids = []
    for text in examples['text']:
        input_ids.append(tokenizer.encode(text) + [tokenizer.eos_token_id])
    return{
        'input_ids': input_ids
    }

# Tokenize the whole corpus. The raw 'text' column is removed so that only
# 'input_ids' remains. The worker count is capped at the number of CPUs: a
# fixed 48 oversubscribes small machines (e.g. a 2-core Colab runtime).
tokenized_ds = train_ds.map(
    _tokenize,
    batched=True,
    num_proc=min(48, os.cpu_count() or 1),
    remove_columns=train_ds.column_names,
    fn_kwargs={'tokenizer': tokenizer},
)
print(tokenized_ds)

Dataset({
    features: ['input_ids'],
    num_rows: 16000
})


In [5]:
print(tokenizer.decode(tokenized_ds[3]['input_ids']))

<|BOS|> Abstract Reasoning in Business and Management: An In-Depth Analysis

Introduction

In the world of business and management, abstract reasoning plays a crucial role in making informed decisions, solving complex problems, and navigating the ever-changing landscape of the corporate world. It involves the ability to identify patterns, logical rules, and structures underlying different situations, independent of concrete experience with those situations. This skillset enables managers and leaders to think strategically, analyze data effectively, and adapt to new challenges with ease. In this chapter, we delve deep into the concept of abstract reasoning, its significance in business and management, and provide practical examples to illustrate its applications.

Understanding Abstract Reasoning

At its core, abstract reasoning refers to the cognitive process of recognizing and manipulating abstract patterns, relationships, and structures. Unlike other forms of thinking, which often re

In [6]:
from itertools import chain

def _pack_dataset(examples, seq_len):
    flat = list(chain.from_iterable(examples["input_ids"]))
    n_full  = len(flat) // seq_len
    chunks  = [flat[i*seq_len:(i+1)*seq_len] for i in range(n_full)]

    return {"input_ids": chunks}

def pack_dataset(ds, seq_len, batch_size=100_000, num_proc=128):
    """Pack a tokenized dataset into rows of exactly ``seq_len`` tokens.

    Args:
        ds: Dataset with an "input_ids" column. All existing columns are
            replaced by the packed "input_ids".
        seq_len: Length of every packed row; must be a positive integer.
        batch_size: Number of input rows handed to one ``_pack_dataset`` call.
        num_proc: Number of worker processes used by ``ds.map``.

    Returns:
        The mapped dataset containing only fixed-length "input_ids" rows.

    Raises:
        ValueError: If ``seq_len`` is not a positive integer.

    Note:
        ``_pack_dataset`` drops the incomplete tail of every batch it receives.
        Smaller batches therefore discard more tokens. The same applies to a
        larger ``num_proc`` if ``map`` splits the data across workers —
        TODO confirm against the `datasets` docs. The defaults keep the
        original behaviour.
    """
    # Fail early and clearly, rather than with a ZeroDivisionError inside a worker.
    if not isinstance(seq_len, int) or seq_len <= 0:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}")

    return ds.map(
        _pack_dataset,
        batched=True,
        batch_size=batch_size,
        remove_columns=ds.column_names,
        num_proc=num_proc,
        fn_kwargs={'seq_len': seq_len},
    )

# Pack into 4096-token rows, then have the dataset return torch tensors.
packed_ds = pack_dataset(tokenized_ds, seq_len=4096)
packed_ds.set_format(type='torch')
print(packed_ds)

Dataset({
    features: ['input_ids'],
    num_rows: 2326
})


In [7]:
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import torch

K = 1024  # token-count unit; 4 * K == 4096


@dataclass
class DataCollatorForCausalLM:
    """Collate tokenized instances into a padded causal-LM batch.

    Attributes:
        tokenizer: Supplies ``pad_token_id`` for padding.
        max_length: Each instance is truncated to this many tokens
            (default ``4 * K``).
    """

    # Forward reference: the annotation is not evaluated when the class is created.
    tokenizer: "PreTrainedTokenizer"
    max_length: int = 4 * K

    def __call__(self, instances):
        """Build ``input_ids`` and ``labels`` for a list of instances.

        Args:
            instances: Sequence of mappings, each with a 1-D "input_ids"
                tensor (or list of ids).

        Returns:
            dict with ``input_ids`` (right-padded with the pad id) and
            ``labels`` (a copy of ``input_ids`` with padded positions set to -100).

        Raises:
            ValueError: If the tokenizer has no pad token id.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("tokenizer.pad_token_id is None; set a pad token before collating.")

        sequences = [
            torch.as_tensor(instance["input_ids"][: self.max_length])
            for instance in instances
        ]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=pad_id
        )

        # Mask padding by position rather than by token value, so that genuine
        # tokens equal to the pad id (e.g. when pad == eos) still get a loss.
        lengths = torch.tensor(
            [seq.size(0) for seq in sequences], device=input_ids.device
        )
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        is_padding = positions.unsqueeze(0) >= lengths.unsqueeze(1)

        labels = input_ids.clone()
        labels[is_padding] = -100

        return dict(
            input_ids=input_ids,
            labels=labels,
        )

In [8]:
from torch.utils.data import DataLoader

# Smoke test: draw one batch of 4 packed rows through the collator.
collator = DataCollatorForCausalLM(tokenizer=tokenizer)
data_loader = DataLoader(dataset=packed_ds, batch_size=4, collate_fn=collator)
next(iter(data_loader))

{'input_ids': tensor([[125030,  11228,  22576,  ...,   1358,    263,  12536],
         [   285,    514,   1700,  ...,     13, 125000,   3430],
         [  1238,   7625,   8660,  ...,     14,     17,    771],
         [  2299,   3697,  74836,  ...,  10571,   4543,   4297]]),
 'labels': tensor([[125030,  11228,  22576,  ...,   1358,    263,  12536],
         [   285,    514,   1700,  ...,     13, 125000,   3430],
         [  1238,   7625,   8660,  ...,     14,     17,    771],
         [  2299,   3697,  74836,  ...,  10571,   4543,   4297]])}

# 2. Build Intra-document Attention Mask (Using Flex-attention)

![attention_mask.png](./attachment/attention_mask.png)

In [9]:
from torch.nn.attention.flex_attention import create_block_mask, and_masks
import torch
# Take the first packed row and add a leading batch axis, so it can be
# masked as a batch of one on the GPU.
input_ids = packed_ds[0]['input_ids']
input_ids_2d = input_ids[None, :].to('cuda')

def _intra_doc_mask(input_ids, bos_token_id):
    is_bos = (input_ids == bos_token_id)
    is_bos_flat = is_bos.flatten()
    flat_doc_ids = torch.cumsum(is_bos_flat.long(), 0)
    doc_ids = flat_doc_ids.view_as(input_ids)
    
    def intra_doc_mask(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc
    
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    return and_masks(intra_doc_mask, causal_mask)

def create_intra_doc_mask(input_ids, tokenizer, device='cuda'):
    """Create a FlexAttention ``BlockMask`` for causal, intra-document attention.

    Args:
        input_ids: 2-D tensor of token ids with shape ``(batch, seq_len)``.
        tokenizer: Supplies ``bos_token_id``, the document-start marker.
        device: Device on which the ids are placed and the mask is built.
            Defaults to ``'cuda'``, which is the original behaviour.

    Returns:
        The ``BlockMask`` returned by ``create_block_mask``. Query and
        key/value lengths both equal ``seq_len``.

    Raises:
        ValueError: If the tokenizer has no BOS token id, or if ``input_ids``
            is not 2-D (raised by the shape unpacking).
    """
    bos_token_id = tokenizer.bos_token_id
    if bos_token_id is None:
        raise ValueError("tokenizer.bos_token_id is None; a BOS token is required to delimit documents.")

    batch_size, seq_len = input_ids.shape

    mask_mod_func = _intra_doc_mask(input_ids.to(device), bos_token_id)

    return create_block_mask(
        mask_mod=mask_mod_func,
        B=batch_size,
        H=None,  # no per-head mask; the printed BlockMask shows a head dim of 1
        Q_LEN=seq_len,
        KV_LEN=seq_len,
        device=device,
    )

# Print the block mask built for the single example sequence.
print(create_intra_doc_mask(input_ids_2d, tokenizer))

BlockMask(shape=(1, 1, 4096, 4096), sparsity=87.40%, 
(0, 0)
░░                              
██░░                            
████░░                          
░░░░░░░░                        
      ██░░                      
      ████░░                    
          ░░░░                  
            ░░░░                
            ░░██░░              
            ░░████░░            
            ░░░░░░░░░░          
                    ██░░        
                    ████░░      
                    ░░░░░░░░    
                          ░░░░  
                            ██░░
)


In [10]:
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

@dataclass
class DataCollatorIntraDocMask:
    """Collate instances into a padded batch plus an intra-document block mask.

    Attributes:
        tokenizer: Supplies ``pad_token_id`` for padding, and is passed to
            ``create_intra_doc_mask`` to find document boundaries.
        max_length: Each instance is truncated to this many tokens
            (default ``4 * K``).
    """

    # Forward reference: the annotation is not evaluated when the class is created.
    tokenizer: "PreTrainedTokenizer"
    max_length: int = 4 * K

    def __call__(self, instances):
        """Build ``input_ids``, ``labels`` and ``attention_mask`` for a batch.

        Args:
            instances: Sequence of mappings, each with a 1-D "input_ids" tensor.

        Returns:
            dict with ``input_ids`` (right-padded), ``labels`` (copy of
            ``input_ids`` with padded positions set to -100) and
            ``attention_mask`` (the result of ``create_intra_doc_mask``).

        Raises:
            ValueError: If the tokenizer has no pad token id.
        """
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("tokenizer.pad_token_id is None; set a pad token before collating.")

        sequences = [
            torch.as_tensor(instance["input_ids"][: self.max_length])
            for instance in instances
        ]
        input_ids = pad_sequence(
            sequences, batch_first=True, padding_value=pad_id
        )

        # Mask padding by position rather than by token value, so that genuine
        # tokens equal to the pad id still contribute to the loss.
        lengths = torch.tensor(
            [seq.size(0) for seq in sequences], device=input_ids.device
        )
        positions = torch.arange(input_ids.size(1), device=input_ids.device)
        labels = input_ids.clone()
        labels[positions.unsqueeze(0) >= lengths.unsqueeze(1)] = -100

        # NOTE(review): create_intra_doc_mask places the ids on an accelerator
        # device. Running this collator in DataLoader worker processes
        # (num_workers > 0) may therefore be problematic — confirm before enabling.
        block_mask = create_intra_doc_mask(input_ids, self.tokenizer)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=block_mask,
        )


In [11]:
# Draw one batch of 2 packed rows and print the block mask the collator built.
collator = DataCollatorIntraDocMask(tokenizer=tokenizer)
data_loader = DataLoader(dataset=packed_ds, batch_size=2, collate_fn=collator)
batch = next(iter(data_loader))
print("Visualize Attention Masks\n", batch['attention_mask'])

Visualize Attention Masks
 BlockMask(shape=(2, 1, 4096, 4096), sparsity=86.23%, 
(0, 0)
░░                              
██░░                            
████░░                          
░░░░░░░░                        
      ██░░                      
      ████░░                    
          ░░░░                  
            ░░░░                
            ░░██░░              
            ░░████░░            
            ░░░░░░░░░░          
                    ██░░        
                    ████░░      
                    ░░░░░░░░    
                          ░░░░  
                            ██░░

(1, 0)
░░                              
██░░                            
████░░                          
    ░░░░                        
    ░░██░░                      
    ░░████░░                    
    ░░██████░░                  
    ░░████████░░                
              ░░░░              
              ░░██░░            
              ░░░░░░░░          
             

# 3. Pretraining Setup with KORMoTrainer

In [12]:
from kormo.train.arguments import KORMoTrainingArguments
from kormo.train.trainer import KORMoTrainer
from kormo.modeling_configs.load_model import load_model_from_config

# Build the '1B' model with FlexAttention as the attention implementation.
# The collator used for training passes a BlockMask as `attention_mask`.
# NOTE(review): the second return value is discarded here — confirm what
# load_model_from_config returns alongside the model.
model, _ = load_model_from_config('1B', _attn_implementation='flex_attention')
model.to('cuda')
# Confirm which attention implementation the model config reports.
print("Attention implementation: ", model.config._attn_implementation)

Attention implementation:  flex_attention


In [13]:
# Training hyper-parameters. Anything not listed keeps the defaults of
# KORMoTrainingArguments.
training_arguments = KORMoTrainingArguments(
    output_dir='./kormo-1B-PT',
    save_strategy='epoch',
    logging_steps=10,
    lr_scheduler_type='linear',
    per_device_train_batch_size=4,
)

# The packed rows are collated with the intra-document mask collator, so each
# batch carries a block mask as its `attention_mask`.
trainer = KORMoTrainer(
    args=training_arguments,
    model=model,
    processing_class=tokenizer,
    train_dataset=packed_ds,
    data_collator=DataCollatorIntraDocMask(tokenizer=tokenizer),
)

In [14]:
# Run the training loop; the returned TrainOutput is displayed as the cell result.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 125040, 'pad_token_id': 125032}.
[34m[1mwandb[0m: Currently logged in as: [33mchoics2623[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
1,12.119
10,11.1309
20,8.7687
30,8.3785
40,7.7405
50,7.4842
60,7.2249
70,7.0791
80,6.8208
90,6.6924


TrainOutput(global_step=582, training_loss=5.700584742621458, metrics={'train_runtime': 232.0052, 'train_samples_per_second': 10.026, 'train_steps_per_second': 2.509, 'total_flos': 6.069368398440038e+16, 'train_loss': 5.700584742621458, 'mean_token_accuracy': 0.2748473882675171, 'num_input_tokens_seen': '9.53M', 'epoch': 1.0})