In [1]:
import os
# Restrict this process to GPU index 0. This is set in the first cell, before
# any of the CUDA-using libraries below are imported.
os.environ["CUDA_VISIBLE_DEVICES"]='0'

# 1. Preparing the Dataset for Pretraining

This section builds the dataset pipeline for language model pretraining.

1. Load the raw dataset
2. Tokenize the text
3. Pack sequences
4. Define the data collator
5. Verify the processed data with a DataLoader

### Raw text $\to$ Tokenize $\to$ Sequence packing $\to$ Data Collation

In [None]:
# Load the raw dataset from 'KORMo-Team/KORMo-tutorial-datasets'.
# Every config of the repository is loaded (train split, 'text' column only),
# then all configs are concatenated and shuffled with a fixed seed.
import datasets 

dataset_repo_id = 'KORMo-Team/KORMo-tutorial-datasets'
# Names of all configs (subsets) published under the repository.
config_names = datasets.get_dataset_config_names(dataset_repo_id)

dataset = []
for name in config_names:
    print(f"Load \"{name}\"...")
    # Keep only the 'text' column so every config has the same schema before concatenation.
    text_dataset = datasets.load_dataset(dataset_repo_id, name=name, split='train').select_columns(['text'])
    dataset.append(text_dataset)

train_ds = datasets.concatenate_datasets(dataset)
# Fixed seed so the shuffled order is the same on every run.
train_ds = train_ds.shuffle(seed=42)
print(train_ds)

Load "cosmopedia_auto_math_text"...
Load "cosmopedia_khanacademy"...
Load "cosmopedia_openstax"...
Load "cosmopedia_stanford"...
Load "cosmopedia_stories"...
Load "cosmopedia_web_samples_v1"...
Load "cosmopedia_web_samples_v2"...
Load "cosmopedia_wikihow"...


In [4]:
# Load the pretrained tokenizer used to tokenize the dataset in the next cell.
from transformers import AutoTokenizer 

tokenizer_repo_id = 'KORMo-Team/KORMo-tokenizer'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_repo_id)

In [5]:
# Tokenize the entire dataset

def _tokenize(examples, tokenizer):
    input_ids = []
    for text in examples['text']:
        input_ids.append(tokenizer.encode(text) + [tokenizer.eos_token_id])
    return{
        'input_ids': input_ids
    }

# Run _tokenize over the whole training set in batches.
# The original columns are removed, so the result only carries 'input_ids'.
tokenized_ds = train_ds.map(
    _tokenize, 
    batched=True, 
    num_proc=48,  # number of worker processes; NOTE(review): hard-coded, adjust to the machine
    remove_columns=train_ds.column_names,
    fn_kwargs={'tokenizer': tokenizer},  # forwarded to _tokenize as keyword arguments
)
print(tokenized_ds)

Dataset({
    features: ['input_ids'],
    num_rows: 8000
})


In [6]:
# Sanity check: decode one tokenized example (index 3) back to text.
print(tokenizer.decode(tokenized_ds[3]['input_ids']))

<|BOS|> **The Multiplier Effect in the Keynesian Cross Model:**

The Keynesian perspective on macroeconomic theory posits that government spending can serve as a crucial tool for managing economic fluctuations and promoting full employment. However, the relationship between changes in government spending and their impact on output is nuanced and complex due to the presence of the multiplier effect. This phenomenon suggests that a given increase in government spending may lead to a larger overall shift in equilibrium national income.

To illustrate, consider an economy where the intersection of the aggregate expenditure function and the 45-degree line occurs at a GDP of $700, whereas the level of potential GDP equals $800. At first glance, it may appear logical to assume that increasing government spending by $100 would suffice to reach potential GDP. Nevertheless, such reasoning overlooks the intricate interplay of various components within the model, specifically the induced increases

In [None]:
from itertools import chain

def _pack_dataset(examples, seq_len):
    flat = list(chain.from_iterable(examples["input_ids"]))
    n_full  = len(flat) // seq_len
    chunks  = [flat[i*seq_len:(i+1)*seq_len] for i in range(n_full)]

    return {"input_ids": chunks}

def pack_dataset(ds, seq_len, batch_size=100_000, num_proc=128):
    """Pack a tokenized dataset into fixed-length sequences of ``seq_len`` tokens.

    Each batch of ``batch_size`` rows is concatenated and split by
    ``_pack_dataset``; tokens that do not fill a complete chunk at the end of
    a batch are dropped.

    Args:
        ds: Dataset exposing ``map`` and ``column_names`` whose rows carry
            ``"input_ids"``.
        seq_len: Number of tokens per packed sequence; must be positive.
        batch_size: Number of rows concatenated together before chunking.
            Defaults to the previously hard-coded 100_000.
        num_proc: Number of worker processes passed to ``ds.map``.
            Defaults to the previously hard-coded 128.

    Returns:
        The result of ``ds.map``: a dataset whose only column is the packed
        ``"input_ids"``.

    Raises:
        ValueError: If ``seq_len`` is not positive.
    """
    # Fail early: seq_len == 0 would raise ZeroDivisionError inside the map
    # workers, and a negative value would silently produce an empty dataset.
    if seq_len <= 0:
        raise ValueError(f"seq_len must be a positive integer, got {seq_len!r}")

    return ds.map(
        _pack_dataset,
        batched=True,
        batch_size=batch_size,
        remove_columns=ds.column_names,
        num_proc=num_proc,
        fn_kwargs={'seq_len': seq_len},
    )

# Pack the tokenized documents into sequences of 4096 tokens each.
packed_ds = pack_dataset(tokenized_ds, 4096)
# Switch the output format in place so indexing the dataset yields torch tensors.
packed_ds.set_format('torch')
print(packed_ds)

In [9]:
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
import torch

K = 1024

@dataclass
class DataCollatorForCausalLM:
    """Collate token sequences into a padded batch for causal LM training.

    Each instance's ``"input_ids"`` is truncated to ``max_length`` tokens and
    the batch is right-padded with ``tokenizer.pad_token_id``. ``labels`` is a
    copy of ``input_ids`` in which only the padded positions are set to -100.
    Padding is identified by sequence length, not by token value, so a real
    token whose id equals the pad id keeps its label.
    """

    # Tokenizer providing ``pad_token_id``.
    tokenizer: "PreTrainedTokenizer"
    # Maximum number of tokens kept per sequence (same value as the former 4*K).
    max_length: int = 4096

    def __call__(self, instances):
        """Build a batch from a list of instances.

        Args:
            instances: List of mappings, each with an ``"input_ids"`` 1-D
                tensor or sequence of token ids.

        Returns:
            dict with ``input_ids`` and ``labels`` tensors of identical shape
            ``(len(instances), longest_sequence_in_batch)``.
        """
        sequences = [
            torch.as_tensor(instance["input_ids"])[: self.max_length]
            for instance in instances
        ]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

        # Mark as padding every position at or beyond each sequence's true length.
        lengths = torch.tensor(
            [sequence.shape[0] for sequence in sequences], device=input_ids.device
        )
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        is_padding = positions.unsqueeze(0) >= lengths.unsqueeze(1)

        labels = input_ids.clone()
        labels[is_padding] = -100

        return dict(
            input_ids=input_ids,
            labels=labels,
        )

In [None]:
from torch.utils.data import DataLoader

# Verify the pipeline: collate the packed dataset in batches of 4 and
# display the first batch.
collator = DataCollatorForCausalLM(tokenizer)
data_loader = DataLoader(packed_ds, collate_fn=collator, batch_size=4)
next(iter(data_loader))

{'input_ids': tensor([[125030,  16627,   2312,  ...,   2094,    626,    601],
         [  9017,     13, 125000,  ...,    269,  12447,    960],
         [  1870,    285,    534,  ...,    626,   2889,    281],
         [   401,  10490,    437,  ...,     13,    832,   9333]]),
 'labels': tensor([[125030,  16627,   2312,  ...,   2094,    626,    601],
         [  9017,     13, 125000,  ...,    269,  12447,    960],
         [  1870,    285,    534,  ...,    626,   2889,    281],
         [   401,  10490,    437,  ...,     13,    832,   9333]])}

# 2. Build Intra-document Attention Mask (Using Flex-attention)

![attention_mask.png](./attachment/attention_mask.png)

In [None]:
from torch.nn.attention.flex_attention import create_block_mask, and_masks
import torch
# Take the first packed sequence as an example input.
input_ids = packed_ds[0]['input_ids']
# Add a leading batch dimension and move the tensor to the GPU.
input_ids_2d = input_ids.unsqueeze(0).to('cuda')

def _intra_doc_mask(input_ids, bos_token_id):
    is_bos = (input_ids == bos_token_id)
    is_bos_flat = is_bos.flatten()
    flat_doc_ids = torch.cumsum(is_bos_flat.long(), 0)
    doc_ids = flat_doc_ids.view_as(input_ids)
    
    def intra_doc_mask(b, h, q_idx, kv_idx):
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return same_doc
    
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    return and_masks(intra_doc_mask, causal_mask)

def create_intra_doc_mask(input_ids, tokenizer, device='cuda'):
    """Create a FlexAttention BlockMask for causal, intra-document attention.

    Args:
        input_ids: 2-D token-id tensor of shape ``(batch, sequence_length)``.
        tokenizer: Object exposing ``bos_token_id``; each BOS token marks the
            start of a new document.
        device: Device on which the document ids and the block mask are
            built. Defaults to ``'cuda'``, the previously hard-coded value.

    Returns:
        The ``BlockMask`` produced by ``create_block_mask`` for square
        attention (query length equals key/value length).

    Raises:
        ValueError: If the tokenizer has no BOS token id.
    """
    model_bos_token_id = tokenizer.bos_token_id
    # Without a BOS id, document boundaries cannot be located.
    if model_bos_token_id is None:
        raise ValueError(
            "tokenizer.bos_token_id is None; a BOS token is required to "
            "separate documents for intra-document masking"
        )

    B, Q_LEN = input_ids.shape
    H = None  # no per-head mask is requested
    KV_LEN = Q_LEN

    mask_mod_func = _intra_doc_mask(input_ids.to(device), model_bos_token_id)

    block_mask = create_block_mask(
        mask_mod=mask_mod_func,
        B=B,
        H=H,
        Q_LEN=Q_LEN,
        KV_LEN=KV_LEN,
        device=device,
    )
    return block_mask

# Print the block mask built for the single example sequence.
print(create_intra_doc_mask(input_ids_2d, tokenizer))

BlockMask(shape=(1, 1, 4096, 4096), sparsity=84.96%, 
(0, 0)
░░                              
██░░                            
████░░                          
    ░░░░                        
    ░░██░░                      
    ░░████░░                    
    ░░██████░░                  
    ░░░░░░░░░░░░                
              ██░░              
              ████░░            
              ██████░░          
                    ░░░░        
                    ░░██░░      
                    ░░████░░    
                    ░░░░░░░░░░  
                            ██░░
)


In [None]:
from dataclasses import dataclass
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizer

@dataclass
class DataCollatorIntraDocMask:
    """Collate token sequences and attach an intra-document attention mask.

    Each instance's ``"input_ids"`` is truncated to ``max_length`` tokens and
    the batch is right-padded with ``tokenizer.pad_token_id``. ``labels`` is a
    copy of ``input_ids`` in which only the padded positions are set to -100.
    Padding is identified by sequence length, not by token value, so a real
    token whose id equals the pad id keeps its label. The block mask built by
    ``create_intra_doc_mask`` is returned under ``attention_mask``.
    """

    # Tokenizer providing ``pad_token_id`` (and the BOS id used for the mask).
    tokenizer: PreTrainedTokenizer
    # Maximum number of tokens kept per sequence (same value as the former 4*K).
    max_length: int = 4096

    def __call__(self, instances):
        """Build a batch from a list of instances.

        Args:
            instances: List of mappings, each with an ``"input_ids"`` 1-D
                tensor or sequence of token ids.

        Returns:
            dict with ``input_ids``, ``labels`` and ``attention_mask``.
        """
        sequences = [
            torch.as_tensor(instance["input_ids"])[: self.max_length]
            for instance in instances
        ]
        input_ids = pad_sequence(
            sequences, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )

        # Mark as padding every position at or beyond each sequence's true length.
        lengths = torch.tensor(
            [sequence.shape[0] for sequence in sequences], device=input_ids.device
        )
        positions = torch.arange(input_ids.shape[1], device=input_ids.device)
        is_padding = positions.unsqueeze(0) >= lengths.unsqueeze(1)

        labels = input_ids.clone()
        labels[is_padding] = -100

        block_mask = create_intra_doc_mask(input_ids, self.tokenizer)

        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=block_mask,
        )


In [None]:
# Collate two packed sequences with the intra-document collator and print
# the resulting block mask of the first batch.
collator = DataCollatorIntraDocMask(tokenizer)
data_loader = DataLoader(packed_ds, collate_fn=collator, batch_size=2)
batch = next(iter(data_loader))
print("Visualize Attention Masks\n", batch['attention_mask'])

Visualize Attention Masks BlockMask(shape=(2, 1, 4096, 4096), sparsity=84.96%, 
(0, 0)
░░                              
██░░                            
████░░                          
    ░░░░                        
    ░░██░░                      
    ░░████░░                    
    ░░██████░░                  
    ░░░░░░░░░░░░                
              ██░░              
              ████░░            
              ██████░░          
                    ░░░░        
                    ░░██░░      
                    ░░████░░    
                    ░░░░░░░░░░  
                            ██░░

(1, 0)
░░                              
██░░                            
████░░                          
░░░░░░░░                        
      ░░░░                      
        ██░░                    
        ████░░                  
        ░░░░░░░░                
              ██░░              
              ████░░            
              ██████░░          
              

# 3. Pretraining Setup with KORMoTrainer

In [None]:
from kormo.train.arguments import KORMoTrainingArguments
from kormo.train.trainer import KORMoTrainer
from kormo.modeling_configs.load_model import load_model_from_config

# Build the '1B' model configured to use the flex_attention implementation.
# The second return value is discarded here (presumably a config or tokenizer — TODO confirm).
model, _ = load_model_from_config('1B', _attn_implementation='flex_attention')
model.to('cuda')
print("Attention implementation: ", model.config._attn_implementation)

In [None]:
# Training configuration.
# NOTE(review): argument meanings presumably follow Hugging Face TrainingArguments — confirm.
training_arguments = KORMoTrainingArguments(
    output_dir='./kormo-1B-PT',
    per_device_train_batch_size=4,
    lr_scheduler_type='linear',
    logging_steps=10,
    save_strategy='epoch'
)

# Train on the packed dataset; the collator adds the intra-document
# attention mask to every batch.
trainer = KORMoTrainer(
    model=model,
    args=training_arguments,
    train_dataset=packed_ds,
    processing_class=tokenizer,
    data_collator=DataCollatorIntraDocMask(tokenizer),
)

In [18]:
# Launch pretraining with the trainer configured above.
trainer.train()

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 125040, 'pad_token_id': 125032}.


Step,Training Loss
1,12.1489
10,11.0568
20,8.304
30,8.0389
40,7.7815
50,7.6191
60,7.3822
70,7.2081
80,6.9703
90,6.8597


TrainOutput(global_step=338, training_loss=6.432804906156641, metrics={'train_runtime': 136.8114, 'train_samples_per_second': 9.882, 'train_steps_per_second': 2.471, 'total_flos': 3.527852998577357e+16, 'train_loss': 6.432804906156641, 'mean_token_accuracy': 0.22595390863716602, 'num_input_tokens_seen': '5.54M', 'epoch': 1.0})