<a href="https://colab.research.google.com/github/1ucky40nc3/feed_h3/blob/main/examples/notebooks/run_clm_h3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Show the attached GPU(s); later cells enable fp16 mixed precision and move tensors to `accelerator.device`.
!nvidia-smi

In [None]:
# @title Install Dependencies

!pip install transformers
!pip install evaluate
!pip install datasets
!pip install accelerate
!pip install einops
!pip install pytorch-lightning
!pip install flash-attn
!pip install pykeops
!pip install flatten-dict

# Install [H3](https://github.com/HazyResearch/H3)

In [None]:
!git clone --recursive https://github.com/HazyResearch/H3.git
%cd H3

In [None]:
# @title Import Dependencies

from typing import (
    Optional,
    Tuple,
    Union,
    Dict,
    List,
    Any
)

import os
import sys
import math
import copy
import json
import random
from dataclasses import (
    field,
    dataclass
)
from itertools import chain
from datetime import datetime

import torch
from torch import nn

import datasets
from datasets import (
    load_dataset,
    DatasetDict,
    get_dataset_infos,
    get_dataset_config_names
)

import evaluate

import transformers
from transformers import (
    TrainingArguments,
    AutoTokenizer,
    default_data_collator,
    get_scheduler
)
from transformers.utils import ModelOutput

from accelerate import Accelerator
from accelerate.utils import set_seed

from tqdm.auto import tqdm

import flatten_dict

from flash_attn.utils.generation import InferenceParams

from src.models.ssm_seq import SSMLMHeadModel

# Prepare the Configs

In [None]:
# @title Implement Configs

@dataclass
class Dataclass:
    """Base dataclass whose instances also behave like a read/write mapping.

    ``keys()`` plus ``__getitem__`` let an instance be unpacked with ``**cfg``.
    ``cfg['name']`` reads the attribute ``name``, and ``cfg['name'] = value``
    sets it.
    """

    def keys(self):
        # All instance attributes, in insertion order (the dataclass fields
        # plus anything assigned later, e.g. in __post_init__).
        return self.__dict__.keys()

    def __setitem__(self, key: str, value: Any) -> None:
        # Python evaluates `obj[key] = value` as `obj.__setitem__(key, value)`,
        # so the key must come first.
        setattr(self, key, value)

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

@dataclass
class SSMConfig(Dataclass):
    """Settings for the H3 SSM layers, handed to the model as ``ssm_cfg``.

    NOTE(review): the meaning of each field is defined by H3's SSM layer
    implementation, which is not visible here — confirm there before changing.
    """
    head_dim: int = 1
    d_state: int = 64
    dropout: float = 0.0
    mode: str = 'diag'
    measure: str = 'diag-lin'
    use_fast_fftconv: bool = False

@dataclass
class AttnConfig(Dataclass):
    """Settings for the attention layers, handed to the model as ``attn_cfg``.

    ``rotary_emb_dim`` only accepts ``None`` or ``64``; ``None`` is normalised
    to the integer ``0`` after construction.
    """
    num_heads: int = 12
    bias: bool = True
    dropout: float = 0.0
    rotary_emb_dim: Optional[int] = None

    def __post_init__(self):
        allowed_rotary_dims = (None, 64)
        assert self.rotary_emb_dim in allowed_rotary_dims, \
            'The `rotary_emb_dim` can either be `None`/`64`.'

        # NOTE(review): presumably the downstream attention module reads 0 as
        # "no rotary embedding" — confirm in H3 / flash-attn.
        self.rotary_emb_dim = 0 if self.rotary_emb_dim is None else self.rotary_emb_dim

@dataclass
class SSMModelConfig(Dataclass):
    """Hyper-parameters for the H3 causal language model built in this notebook.

    ``d_inner`` defaults to ``4 * d_model`` when left as ``None``.
    """
    d_model: int = 768
    n_layer: int = 12
    # Nested configs must use `default_factory`. A plain instance default would
    # be a single object shared by every SSMModelConfig. Python >= 3.11 also
    # rejects it outright: dataclass instances are unhashable, which triggers
    # the "mutable default ... use default_factory" ValueError.
    ssm_cfg: SSMConfig = field(default_factory=SSMConfig)
    attn_cfg: AttnConfig = field(default_factory=AttnConfig)
    resid_dropout: float = 0.0
    embed_dropout: float = 0.1
    layer_norm_epsilon: float = 1e-5
    # Hidden width of the MLP; filled with 4 * d_model in __post_init__ if None.
    d_inner: Optional[int] = None
    # NOTE(review): presumably the indices of layers built as attention instead
    # of SSM (None -> no attention layers) — confirm against H3's SSMLMHeadModel.
    attn_layer_idx: Optional[List[int]] = field(
        default_factory=lambda: [6]
    )
    fused_mlp: bool = False
    fused_dropout_add_ln: bool = False

    def __post_init__(self):
        if self.d_inner is None:
            self.d_inner = 4 * self.d_model

@dataclass
class ModelArguments(Dataclass):
    """Model/tokenizer names.

    In this notebook only the tokenizer is loaded from these names
    (``tokenizer_name`` first, otherwise ``model_name_or_path``). The H3 model
    itself is built from ``SSMModelConfig``. No visible cell reads
    ``config_name``.
    """
    model_name_or_path: Optional[str] = None
    config_name: Optional[str] = None
    tokenizer_name: Optional[str] = None
    # Forwarded as `use_fast` to `AutoTokenizer.from_pretrained`.
    use_fast_tokenizer: bool = True

@dataclass
class DataTrainingArguments(Dataclass):
    """Dataset and preprocessing options for causal-LM training.

    Either ``dataset_name`` (a Hub dataset) or ``train_file`` (a local
    csv/json/txt file) must be given. When no validation data is available,
    ``validation_split_percentage`` percent of the training split is held out.

    Raises:
        ValueError: if no training data source is configured, or if a local
            file does not have a csv/json/txt extension.
    """
    dataset_name: Optional[str] = None
    dataset_config_name: Optional[str] = None
    train_file: Optional[str] = None
    validation_split_percentage: Optional[int] = 5
    max_seq_length: Optional[int] = None
    pad_to_max_length: bool = False
    max_train_samples: Optional[int] = None
    max_eval_samples: Optional[int] = None
    preprocessing_num_workers: Optional[int] = None
    block_size: Optional[int] = None
    keep_linebreaks: bool = True
    overwrite_cache: bool = False
    # Optional local validation file. It is appended last so the positional
    # order of the existing fields is unchanged.
    validation_file: Optional[str] = None

    def __post_init__(self):
        def ext(path: str) -> str:
            # 'data/train.json' -> 'json' (no leading dot).
            return os.path.splitext(path)[-1].lstrip('.')

        if self.dataset_name is None and self.train_file is None:
            raise ValueError('Need either a dataset name or a training file.')

        for name in ('train_file', 'validation_file'):
            path = getattr(self, name)
            if path is not None and ext(path) not in ('csv', 'json', 'txt'):
                raise ValueError(f'`{name}` should be a csv, a json or a txt file.')

In [None]:
# @title Implement Utils

def now(format='%Y%m%d%H%M%S'):
    """Return the current local time rendered with ``format``.

    The default pattern produces a sortable 14-digit timestamp
    (e.g. ``20240131235959``), used here to name run directories.
    """
    current_time = datetime.now()
    return current_time.strftime(format)

In [None]:
# @title Initialize the Config

# --- Custom SSMModel params (only used when CUSTOM_SSM_MODEL_CONFIG is True) ---
D_MODEL = 1024
D_INNER = 4 * D_MODEL
N_LAYER = 12
RESID_DROPOUT = 0.0
EMBED_DROPOUT = 0.1
LAYER_NORM_EPSILON = 1e-5
FUSED_MLP = False
FUSED_DROPOUT_ADD_LN = False
# SSM params
SSM_CFG_HEAD_DIM = 8
SSM_CFG_D_STATE = 64
SSM_CFG_DROPOUT = 0.0
SSM_CFG_USE_FAST_FFTCONV = False
# MHA params. The custom config keeps SSMModelConfig's default attention layer
# index, so D_MODEL must be divisible by ATTN_CFG_NUM_HEADS (checked below).
ATTN_CFG_NUM_HEADS = 12
ATTN_CFG_BIAS = True
ATTN_CFG_DROPOUT = 0.1
ATTN_CFG_ROTARY_EMB_DIM = None

# Select a preset SSMModel config (ignored when CUSTOM_SSM_MODEL_CONFIG is True).
# '125M' passes `attn_layer_idx=None`; '125M_hybrid' uses SSMModelConfig's defaults.
CUSTOM_SSM_MODEL_CONFIG = False
SSM_MODEL_CONFIGS = ['125M', '125M_hybrid']
SSM_MODEL_CONFIG = '125M'
assert SSM_MODEL_CONFIG in SSM_MODEL_CONFIGS, f'The `SSM_MODEL_CONFIG` is not in {SSM_MODEL_CONFIGS}!'

# Model args (only the tokenizer is loaded from these names)
MODEL_NAME_OR_PATH = 'google/byt5-small'
CONFIG_NAME = None
TOKENIZER_NAME = None

# Dataset args
DATASET_NAME = 'tiny_shakespeare'
TRAIN_FILE = None
VALIDATION_SPLIT_PERCENTAGE = 5
BLOCK_SIZE = 128
PAD_TO_MAX_LENGTH = False
MAX_TRAIN_SAMPLES = None
MAX_EVAL_SAMPLES = None
PREPROCESSING_NUM_WORKERS = None

# Training args (subset of transformers.TrainingArguments)
OUTPUT_DIR = f'./runs/{now()}'
DO_TRAIN = True
DO_EVAL = True
NUM_TRAIN_EPOCHS = 10
BATCH_SIZE = 64
GRADIENT_ACCUMULATION_STEPS = 1
LEARNING_RATE = 5e-5
FP16 = True
LOGGING_DIR = './runs/logs'
REPORT_TO = 'tensorboard'
SEED = 42
SAVE_TOTAL_LIMIT = 3


if CUSTOM_SSM_MODEL_CONFIG:
    model_config = SSMModelConfig(
        d_model=D_MODEL,
        d_inner=D_INNER,
        n_layer=N_LAYER,
        ssm_cfg=SSMConfig(
            head_dim=SSM_CFG_HEAD_DIM,
            d_state=SSM_CFG_D_STATE,
            dropout=SSM_CFG_DROPOUT,
            use_fast_fftconv=SSM_CFG_USE_FAST_FFTCONV
        ),
        attn_cfg=AttnConfig(
            num_heads=ATTN_CFG_NUM_HEADS,
            bias=ATTN_CFG_BIAS,
            dropout=ATTN_CFG_DROPOUT,
            rotary_emb_dim=ATTN_CFG_ROTARY_EMB_DIM
        ),
        resid_dropout=RESID_DROPOUT,
        embed_dropout=EMBED_DROPOUT,
        layer_norm_epsilon=LAYER_NORM_EPSILON,
        fused_mlp=FUSED_MLP,
        fused_dropout_add_ln=FUSED_DROPOUT_ADD_LN
    )
elif SSM_MODEL_CONFIG == '125M':
    model_config = SSMModelConfig(
        d_model=768,
        n_layer=12,
        ssm_cfg=SSMConfig(
            head_dim=8,
        ),
        attn_layer_idx=None,
        attn_cfg=AttnConfig(
            num_heads=12,
            rotary_emb_dim=None
        ),
    )
elif SSM_MODEL_CONFIG == '125M_hybrid':
    model_config = SSMModelConfig()
else:
    raise ValueError(f'The `SSM_MODEL_CONFIG` is not in {SSM_MODEL_CONFIGS}!')

# Fail fast with a clear message: multi-head attention splits d_model evenly
# across heads, so it must be divisible by num_heads when attention layers exist.
if model_config.attn_layer_idx:
    assert model_config.d_model % model_config.attn_cfg.num_heads == 0, (
        f'`d_model` ({model_config.d_model}) must be divisible by the number of '
        f'attention heads ({model_config.attn_cfg.num_heads}).'
    )

model_args = ModelArguments(
    model_name_or_path=MODEL_NAME_OR_PATH,
    config_name=CONFIG_NAME,
    tokenizer_name=TOKENIZER_NAME
)
data_args = DataTrainingArguments(
    dataset_name=DATASET_NAME,
    train_file=TRAIN_FILE,
    validation_split_percentage=VALIDATION_SPLIT_PERCENTAGE,
    block_size=BLOCK_SIZE,
    pad_to_max_length=PAD_TO_MAX_LENGTH,
    max_train_samples=MAX_TRAIN_SAMPLES,
    max_eval_samples=MAX_EVAL_SAMPLES,
    preprocessing_num_workers=PREPROCESSING_NUM_WORKERS
)
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    do_train=DO_TRAIN,
    do_eval=DO_EVAL,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    fp16=FP16,
    logging_dir=LOGGING_DIR,
    report_to=REPORT_TO,
    seed=SEED,
    data_seed=SEED,
    save_total_limit=SAVE_TOTAL_LIMIT
)
set_seed(training_args.seed)

# Initialize the Accelerator

In [None]:
# Translate the relevant TrainingArguments into Accelerator options:
# fp16 mixed precision, gradient accumulation, and the tracker/log directory.
accelerator_kwargs = dict(
    mixed_precision='fp16' if training_args.fp16 is not False else None,
    gradient_accumulation_steps=training_args.gradient_accumulation_steps,
    log_with=training_args.report_to,
    project_dir=training_args.logging_dir,
)
accelerator = Accelerator(**accelerator_kwargs)

# Load the data

In [None]:
if data_args.dataset_name is not None:
    # Hub dataset. Resolve the config name so its split info can be looked up;
    # this is automatic only for datasets with exactly one config.
    dataset_config_name = data_args.dataset_config_name
    if dataset_config_name is None:
        dataset_config_names = get_dataset_config_names(data_args.dataset_name)
        assert len(dataset_config_names) == 1, f'The dataset has multiple `configs`! Choose one of: {dataset_config_names}'
        dataset_config_name = dataset_config_names[0]
    infos = get_dataset_infos(data_args.dataset_name)[dataset_config_name]

    if 'validation' not in infos.splits.keys():
        # No validation split published: hold out the first N% of `train`.
        raw_datasets = DatasetDict()
        raw_datasets['validation'] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f'train[:{data_args.validation_split_percentage}%]',
        )
        raw_datasets['train'] = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            split=f'train[{data_args.validation_split_percentage}%:]',
        )
    else:
        raw_datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name)

else:
    # Local files. A training file is required because the following cells
    # read `raw_datasets['train']`.
    if data_args.train_file is None:
        raise ValueError('Set either `DATASET_NAME` or `TRAIN_FILE`: no training data is configured.')
    # `getattr` keeps this cell working even if the arguments object does not
    # define a `validation_file` field.
    validation_file = getattr(data_args, 'validation_file', None)

    data_files = {'train': data_args.train_file}
    dataset_args = {}
    if validation_file is not None:
        data_files['validation'] = validation_file

    # The `datasets` builder is chosen from the training file's extension.
    extension = data_args.train_file.split('.')[-1]
    if extension == 'txt':
        extension = 'text'
        dataset_args['keep_linebreaks'] = data_args.keep_linebreaks

    if 'validation' not in data_files.keys():
        # No validation file: hold out the first N% of the training file.
        raw_datasets = DatasetDict()
        raw_datasets['validation'] = load_dataset(
            extension,
            data_files=data_files,
            split=f'train[:{data_args.validation_split_percentage}%]',
            **dataset_args,
        )
        raw_datasets['train'] = load_dataset(
            extension,
            data_files=data_files,
            split=f'train[{data_args.validation_split_percentage}%:]',
            **dataset_args,
        )
    else:
        raw_datasets = load_dataset(extension, data_files=data_files, **dataset_args)

# Initialize the tokenizer

In [None]:
# Load the tokenizer by explicit name if given, otherwise from the checkpoint
# name. (Only the tokenizer comes from this checkpoint; the H3 model does not.)
tokenizer_name_or_path = model_args.tokenizer_name or model_args.model_name_or_path
if not tokenizer_name_or_path:
    raise ValueError(
        'You are instantiating a new tokenizer from scratch. This is not supported by this script. '
        'You can do it from another script, save it, and load it from here, using `TOKENIZER_NAME`.'
    )
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name_or_path,
    use_fast=model_args.use_fast_tokenizer
)

# Preprocess the datasets

In [None]:
# Preprocessing the datasets.
# First we tokenize all the texts: use the `text` column when the training
# split has one, otherwise fall back to its first column.
column_names = raw_datasets['train'].column_names
if 'text' in column_names:
    text_column_name = 'text'
else:
    text_column_name = column_names[0]

def tokenize_function(examples):
    """Tokenize the `text_column_name` column of a batch of raw examples.

    Used with `Dataset.map(batched=True)`, so the column value is presumably a
    list of strings, and the tokenizer returns one encoding per entry.
    """
    batch_texts = examples[text_column_name]
    return tokenizer(batch_texts)

# Tokenize on the main process first; the other processes enter the block
# only after it has finished.
with accelerator.main_process_first():
    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
        desc='Running tokenizer on dataset',
    )

# Chunk length used by `group_texts`:
# - BLOCK_SIZE if set, capped at the tokenizer's limit;
# - otherwise the tokenizer's limit, capped at 1024.
if data_args.block_size is None:
    block_size = tokenizer.model_max_length
    if block_size > 1024:
        print(
            'The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value'
            ' of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can'
            ' override this default with `BLOCK_SIZE`.'
        )
    # Cap at 1024 but never exceed what the tokenizer supports.
    block_size = min(block_size, 1024)
else:
    if data_args.block_size > tokenizer.model_max_length:
        print(
            f'The block_size passed ({data_args.block_size}) is larger than the maximum length for the model '
            f'({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}.'
        )
    block_size = min(data_args.block_size, tokenizer.model_max_length)

# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
def group_texts(examples):
    """Concatenate a batch of tokenized examples and re-split it into `block_size` chunks.

    Every column is flattened and cut at the same offsets. When the batch holds
    at least one full block, the trailing partial block is dropped. A batch
    shorter than `block_size` is kept as a single short chunk. `labels` is a
    copy of `input_ids`; the one-token shift is applied later, in the model's
    loss computation.
    """
    joined = {name: list(chain.from_iterable(values)) for name, values in examples.items()}
    first_column = next(iter(examples.keys()))
    n_tokens = len(joined[first_column])
    # Drop the remainder only when it would not discard everything; padding
    # could be used instead if the model supported it.
    if n_tokens >= block_size:
        n_tokens -= n_tokens % block_size
    starts = range(0, n_tokens, block_size)
    grouped = {
        name: [tokens[start:start + block_size] for start in starts]
        for name, tokens in joined.items()
    }
    grouped['labels'] = list(grouped['input_ids'])
    return grouped

# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
# to preprocess.
#
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

with accelerator.main_process_first():
    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        desc=f'Grouping texts in chunks of {block_size}',
    )

train_dataset = lm_datasets['train']
eval_dataset = lm_datasets['validation']

# Log a few random samples from the training set (fewer if the split is tiny;
# `random.sample` raises if asked for more items than exist).
num_samples_to_show = min(3, len(train_dataset))
for index in random.sample(range(len(train_dataset)), num_samples_to_show):
    print(f'Sample {index} of the training set: {train_dataset[index]}.')

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=training_args.per_device_train_batch_size
)
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    collate_fn=default_data_collator,
    batch_size=training_args.per_device_eval_batch_size
)

# Initialize the model

In [None]:
@dataclass
class CausalLMOutput(ModelOutput):
    """Return type of `SSMModelForCausalLM.forward`.

    `transformers.ModelOutput` subclasses must use `@dataclass`. Its
    `__post_init__` works on the dataclass fields, and recent `transformers`
    versions raise a TypeError for undecorated subclasses. The field order
    (`logits`, then `loss`) is unchanged for positional / `to_tuple()` use.
    """
    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None


class SSMModelForCausalLM(nn.Module):
    """Causal-LM wrapper around H3's `SSMLMHeadModel` that also computes the loss.

    Args:
        config: Model hyper-parameters (`SSMModelConfig`).
        vocab_size: Size of the LM vocabulary. Defaults to `len(tokenizer)`
            of the notebook-level `tokenizer`.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        config: SSMModelConfig,
        vocab_size: Optional[int] = None,
        **kwargs
    ) -> None:
        super().__init__()

        self.config = config
        if vocab_size is None:
            vocab_size = len(tokenizer)
        # Build from the `config` argument (not the global `model_config`) and
        # forward every config field, so custom settings actually take effect.
        self.model = SSMLMHeadModel(
            config.d_model,
            n_layer=config.n_layer,
            # SSMModelConfig fills d_inner with 4 * d_model when unset; the
            # fallback covers a config whose d_inner was reset to None later.
            d_inner=config.d_inner if config.d_inner is not None else 4 * config.d_model,
            vocab_size=vocab_size,
            ssm_cfg=config.ssm_cfg,
            attn_layer_idx=config.attn_layer_idx,
            attn_cfg=config.attn_cfg,
            resid_dropout=config.resid_dropout,
            embed_dropout=config.embed_dropout,
            layer_norm_epsilon=config.layer_norm_epsilon,
            fused_mlp=config.fused_mlp,
            fused_dropout_add_ln=config.fused_dropout_add_ln,
            pad_vocab_size_multiple=8
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        inference_params: Optional[InferenceParams] = None,
        labels: Optional[torch.Tensor] = None,
        **kwargs
    ) -> CausalLMOutput:
        """Compute logits and, when `labels` is given, the next-token loss.

        Args:
            input_ids: Token ids; assumed (batch, seq_len) — TODO confirm.
            position_ids: Forwarded to `SSMLMHeadModel`.
            inference_params: Forwarded to `SSMLMHeadModel`.
            labels: Target ids aligned with `input_ids`; they are shifted by
                one position here, so position t is trained to predict t + 1.
            **kwargs: Other batch entries (e.g. an `attention_mask` produced
                by the tokenizer) are accepted and ignored.

        Returns:
            `CausalLMOutput` with `logits` and `loss` (None without labels).
        """
        logits = self.model(
            input_ids,
            position_ids,
            inference_params
        ).logits

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

        return CausalLMOutput(
            loss=loss,
            logits=logits
        )

    def generate(self, *args, **kwargs) -> Any:
        """Delegate text generation to the wrapped `SSMLMHeadModel.generate`."""
        return self.model.generate(*args, **kwargs)


# Build the H3 causal-LM from the config selected above.
model = SSMModelForCausalLM(config=model_config)

# Initialize the optimizer

In [None]:
# Parameters whose names contain one of these substrings get no weight decay.
# NOTE(review): H3 / flash-attn may name their norm layers differently
# (not `layer_norm`), in which case those weights are decayed — check
# `model.named_parameters()` if weight_decay > 0.
no_decay = ['bias', 'layer_norm.weight']
optimizer_grouped_parameters = [
    {
        'params': [
            p
            for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': training_args.weight_decay,
    },
    {
        'params': [
            p
            for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0,
    },
]
# Take the Adam hyper-parameters from TrainingArguments instead of silently
# using torch's defaults (they coincide for default TrainingArguments).
optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters,
    lr=training_args.learning_rate,
    betas=(training_args.adam_beta1, training_args.adam_beta2),
    eps=training_args.adam_epsilon,
)

# Initialize the Learning-Rate Scheduler

In [None]:
# Derive the total number of optimizer updates from the epoch count unless
# `max_steps` was set explicitly; the flag lets a later cell redo this after
# `accelerator.prepare` has possibly changed the dataloader length.
accum_steps = training_args.gradient_accumulation_steps
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / accum_steps)
overwrote_max_train_steps = training_args.max_steps in (None, -1)
if overwrote_max_train_steps:
    training_args.max_steps = training_args.num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name=training_args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=training_args.warmup_steps * accum_steps,
    num_training_steps=training_args.max_steps * accum_steps,
)

# Put Everything on the Accelerator

In [None]:
# Let the Accelerator wrap/move every training object for the current device
# and mixed-precision setup; the returned objects replace the originals.
prepared = accelerator.prepare(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)
(model, optimizer, train_dataloader, eval_dataloader, lr_scheduler) = prepared
del prepared

# Prepare Training Parameters

In [None]:
# The prepared dataloader may be sharded across processes, so re-derive the
# number of optimizer updates per epoch from its (possibly new) length ...
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
if overwrote_max_train_steps:
    training_args.max_steps = int(training_args.num_train_epochs * num_update_steps_per_epoch)
# ... and then the epoch count back from the step budget.
training_args.num_train_epochs = math.ceil(training_args.max_steps / num_update_steps_per_epoch)

# `save_steps` sets how often the Accelerator state is checkpointed;
# numeric strings are converted to int.
checkpointing_steps = training_args.save_steps
if isinstance(checkpointing_steps, str) and checkpointing_steps.isdigit():
    checkpointing_steps = int(checkpointing_steps)

total_batch_size = (
    training_args.per_device_train_batch_size
    * accelerator.num_processes
    * training_args.gradient_accumulation_steps
)

# Merge every hyper-parameter (later sources override earlier keys).
experiment_config = {}
for source in (training_args, data_args, model_args, model_config):
    experiment_config.update(vars(source))
# TensorBoard cannot log Enums, need the raw value
experiment_config['lr_scheduler_type'] = experiment_config['lr_scheduler_type'].value
experiment_config = {k: v for k, v in experiment_config.items() if not k.startswith('_')}
experiment_config = flatten_dict.flatten(experiment_config, reducer='path')


def cast(a):
    """Keep scalar/string/tensor values as-is; stringify anything else for the tracker."""
    if isinstance(a, (int, float, str, bool, torch.Tensor)):
        return a
    return str(a)


experiment_config = {k: cast(v) for k, v in experiment_config.items() if not k.startswith('_')}
accelerator.init_trackers('clm_no_trainer', experiment_config)

# Train!

In [None]:
# Show TensorBoard inline for LOGGING_DIR, which is passed to the Accelerator
# as `project_dir` (the tracker's log location).
%reload_ext tensorboard
%tensorboard --logdir $LOGGING_DIR

In [None]:
print('***** Running training *****')
print(f'  Num examples = {len(train_dataset)}')
print(f'  Num Epochs = {training_args.num_train_epochs}')
print(f'  Instantaneous batch size per device = {training_args.per_device_train_batch_size}')
print(f'  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}')
print(f'  Gradient Accumulation steps = {training_args.gradient_accumulation_steps}')
print(f'  Total optimization steps = {training_args.max_steps}')
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(training_args.max_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0

for epoch in range(starting_epoch, training_args.num_train_epochs):
    # --- Train for one epoch ---
    model.train()
    total_loss = 0.0
    num_train_batches = 0
    for batch in train_dataloader:
        with accelerator.accumulate(model):
            output = model(**batch)
            loss = output.loss
            total_loss += loss.detach().float()
            num_train_batches += 1
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # `sync_gradients` is True only on the micro-batch that completes an
        # accumulation cycle (i.e. a real optimizer update). Counting and
        # checkpointing only there avoids re-saving the same `step_N`
        # directory (including `step_0`) on every micro-batch.
        if accelerator.sync_gradients:
            progress_bar.update(1)
            completed_steps += 1

            if (
                isinstance(checkpointing_steps, int)
                and checkpointing_steps > 0
                and completed_steps % checkpointing_steps == 0
            ):
                output_dir = f'step_{completed_steps}'
                if training_args.output_dir is not None:
                    output_dir = os.path.join(training_args.output_dir, output_dir)
                accelerator.save_state(output_dir)

        if completed_steps >= training_args.max_steps:
            break

    # --- Evaluate ---
    model.eval()
    losses = []
    for batch in eval_dataloader:
        with torch.no_grad():
            output = model(**batch)
        loss = output.loss
        # Repeat the batch-mean loss per sample so `gather_for_metrics` can
        # trim the padding added to the last batch.
        losses.append(
            accelerator.gather_for_metrics(
                loss.repeat(training_args.per_device_eval_batch_size)
            )
        )
    losses = torch.cat(losses)
    # Log plain Python floats: trackers such as TensorBoard skip 0-d tensors.
    eval_loss = torch.mean(losses).item()
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float('inf')
    # Average over the micro-batches actually run (the loop may break early).
    train_loss = float(total_loss) / max(num_train_batches, 1)

    print(f'epoch {epoch}: perplexity: {perplexity} eval_loss: {eval_loss}')

    accelerator.log(
        {
            'perplexity': perplexity,
            'eval_loss': eval_loss,
            'train_loss': train_loss,
            'epoch': epoch,
            'step': completed_steps,
        },
        step=completed_steps,
    )

progress_bar.close()

if training_args.output_dir is not None:
    output_dir = os.path.join(training_args.output_dir, f'step_{completed_steps}')
    accelerator.save_state(output_dir)
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        tokenizer.save_pretrained(training_args.output_dir)

        with open(os.path.join(training_args.output_dir, 'all_results.json'), 'w') as f:
            json.dump({'perplexity': perplexity}, f)

# Finish the trackers opened by `accelerator.init_trackers` (flushes logs).
accelerator.end_training()

# Run Inference

In [None]:
# Sample a continuation of a newline prompt from the trained model.
prompt = '\n'
inputs = tokenizer(prompt, return_tensors='pt')
input_ids = inputs.input_ids.to(accelerator.device)

# `accelerator.prepare` may wrap the model (e.g. in DDP), which would hide the
# custom `generate`. Unwrap it without stripping the mixed-precision forward
# wrapper (so training can still be re-run afterwards), and disable dropout.
generation_model = accelerator.unwrap_model(model, keep_fp32_wrapper=True)
generation_model.eval()

with torch.no_grad():
    output_ids = generation_model.generate(
        input_ids=input_ids,
        max_length=128,
        return_dict_in_generate=False,
        output_scores=False,
        timing=False,
        top_p=0.9,
        top_k=50,
        eos_token_id=tokenizer.eos_token_id
    )

print(tokenizer.batch_decode(output_ids)[0])

# Utils

In [None]:
# Release GPU memory: first collect Python garbage so unreferenced tensors are
# actually freed, then hand PyTorch's now-unused cached blocks back to the
# driver. (Emptying the cache first would miss memory freed by the collection.)
import gc
gc.collect()
torch.cuda.empty_cache()