diff --git a/model/model_training/configs/config.yaml b/model/model_training/configs/config.yaml
index a8fc2ea92a..a997c0d19a 100644
--- a/model/model_training/configs/config.yaml
+++ b/model/model_training/configs/config.yaml
@@ -779,3 +779,32 @@ debug:
   verbose: true
   num_train_epochs: 0.2
   dtype: fp32
+
+rope_scaling_test:
+  dtype: bf16
+  log_dir: "llama_log_7b"
+  learning_rate: 1e-5
+  model_name: "huggyllama/llama-7b"
+  deepspeed_config: configs/zero_config_falcon.json
+  output_dir: llama
+  weight_decay: 0.0
+  max_length: 4048
+  warmup_steps: 100
+  gradient_checkpointing: true
+  gradient_accumulation_steps: 2
+  per_device_train_batch_size: 1
+  per_device_eval_batch_size: 1
+  eval_steps: 100
+  save_steps: 500
+  num_train_epochs: 8
+  save_total_limit: 4
+  use_flash_attention: false
+  residual_dropout: 0.3
+  residual_dropout_lima: true
+  log_wandb: true
+  peft_model: true
+  peft_type: "lora"
+  superhot: true
+  superhot_config:
+    type: linear
+    scale: 2
diff --git a/model/model_training/models/patching.py b/model/model_training/models/patching.py
index c8757beb8f..9f97514d03 100644
--- a/model/model_training/models/patching.py
+++ b/model/model_training/models/patching.py
@@ -6,12 +6,13 @@
 
 import torch.nn as nn
 import transformers
-from transformers import GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
+from transformers import AutoConfig, GPTNeoXForCausalLM, GPTNeoXModel, LlamaForCausalLM, LlamaModel
 from trlx.models.modeling_ppo import AutoModelForCausalLMWithHydraValueHead
 
 from .patching_llama import llama_forward_with_flash_attn
 from .patching_neox import neox_forward_with_flash_attn
 from .reward_model import GPTNeoXRewardModel
+from .rope import LlamaDynamicScaledRotaryEmbedding, LlamaLinearScaledRope, LlamaNTKScaledRope, RWNTKScaledRope
 
 SUPPORTED_MODELS = [
     GPTNeoXModel,
@@ -176,3 +177,54 @@ def patch_model(
     if resid_pdrop is not None and resid_pdrop > 0:
         add_dropout(getattr(layer, attention_key), _patched_attn_forward, resid_pdrop)
         add_dropout(getattr(layer, mlp_key), _patched_mlp_forward, resid_pdrop)
+
+
+class RopePatch:
+    def __init__(self, model_name, **kwargs):
+        self.args = kwargs
+        rope_type = self.args.pop("type")
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+        architecture = config.architectures
+        if architecture:
+            self.model_name = architecture[0]
+            if "RWForCausalLM" in architecture:
+                self.architecture = "RWForCausalLM"
+                if rope_type == "ntk":
+                    self.patch_fun = RWNTKScaledRope
+                else:
+                    raise NotImplementedError()
+            elif "LlamaForCausalLM" in architecture:
+                self.architecture = "LlamaForCausalLM"
+                if rope_type == "linear":
+                    self.patch_fun = LlamaLinearScaledRope
+                elif rope_type == "ntk":
+                    self.patch_fun = LlamaNTKScaledRope
+                elif rope_type == "dynamic-ntk":
+                    self.patch_fun = LlamaDynamicScaledRotaryEmbedding
+                else:
+                    raise NotImplementedError()
+            else:
+                raise NotImplementedError()
+
+    @classmethod
+    def from_config(cls, config):
+        model_name = config.model_name
+        args = config.superhot_config
+        return cls(model_name, **args)
+
+    def patch(self, model):
+        if self.architecture == "RWForCausalLM":
+            self.patch_rw_model(model, **self.args)
+        elif self.architecture == "LlamaForCausalLM":
+            self.patch_llama_model(model, **self.args)
+        else:
+            raise NotImplementedError()
+
+    def patch_rw_model(self, model, **kwargs):
+        for each in model.transformer.h:
+            each.self_attention.maybe_rotary = self.patch_fun(model.config.head_dim, **kwargs)
+
+    def patch_llama_model(self, model, **kwargs):
+        kwargs.update({"device": model.device})
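+        # Assumes the Hugging Face Llama module layout: model.model.layers[i].self_attn.rotary_emb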
+        for each in model.model.layers:
+            each.self_attn.rotary_emb = self.patch_fun(each.self_attn.head_dim, **kwargs)
diff --git a/model/model_training/models/rope.py b/model/model_training/models/rope.py
new file mode 100644
index 0000000000..005a40c729
--- /dev/null
+++ b/model/model_training/models/rope.py
@@ -0,0 +1,187 @@
+import torch
+
+
+# rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
+def rotate_half(x):
+    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0
+
+
+class RWNTKScaledRope(torch.nn.Module):
+
+    """
+    NTK-Scaled RoPE for RefinedWebModel
+    """
+
+    def __init__(
+        self,
+        head_dim: int,
+        base=10000,
+        alpha: int = 2,
+    ):
+        super().__init__()
+        self.alpha = alpha
+        base = base * self.alpha ** (head_dim / (head_dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.head_dim = head_dim
+        self.seq_len_cached = None
+        self.batch_size_cached = None
+        self.cos_cached: torch.Tensor | None = None
+        self.sin_cached: torch.Tensor | None = None
+
+    def cos_sin(
+        self,
+        seq_len: int,
+        device="cuda",
+        dtype=torch.bfloat16,
+    ) -> torch.Tensor:
+        if seq_len != self.seq_len_cached:
+            self.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
+
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
+
+            self.cos_cached = emb.cos()[None, :, :]
+            self.sin_cached = emb.sin()[None, :, :]
+
+            self.cos_cached = self.cos_cached.type(dtype)
+            self.sin_cached = self.sin_cached.type(dtype)
+
+        return self.cos_cached, self.sin_cached
+
+    def forward(self, q, k):
+        batch, seq_len, head_dim = q.shape
+        cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
+        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
+
+
+class LlamaLinearScaledRope(torch.nn.Module):
+    """
+    reference: https://huggingface.co/kaiokendev/superhot-13b-8k-no-rlhf-test
+    """
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, scale=1, device=None):
+        super().__init__()
+        self.scale = 1 / scale
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        t *= self.scale
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            t *= self.scale
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+class LlamaNTKScaledRope(torch.nn.Module):
+
+    """
+    reference: https://github.com/jquesnelle/scaled-rope
+    """
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, alpha=1, device=None):
+        super().__init__()
+        base = base * alpha ** (dim / (dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
+
+
+class LlamaDynamicScaledRotaryEmbedding(torch.nn.Module):
+    """
+    reference: https://github.com/jquesnelle/scaled-rope
+    """
+
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, ntk=False, device=None):
+        super().__init__()
+        self.ntk = ntk
+        self.base = base
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+        # Build here to make `torch.jit.trace` work.
+        self.max_seq_len_cached = max_position_embeddings
+        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        dtype = torch.get_default_dtype()
+        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
+        if seq_len > self.max_seq_len_cached:
+            self.max_seq_len_cached = seq_len
+            if self.ntk:
+                base = self.base * ((self.ntk * seq_len / self.max_position_embeddings) - (self.ntk - 1)) ** (
+                    self.dim / (self.dim - 2)
+                )
+                inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+                self.register_buffer("inv_freq", inv_freq)
+            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+            if not self.ntk:
+                t *= self.max_position_embeddings / seq_len
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            # Different from paper, but it uses a different permutation in order to obtain the same calculation
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+            self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(x.dtype), persistent=False)
+            self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(x.dtype), persistent=False)
+        return (
+            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+        )
diff --git a/model/model_training/trainer_sft.py b/model/model_training/trainer_sft.py
index 547f1e5aed..158e1c621e 100755
--- a/model/model_training/trainer_sft.py
+++ b/model/model_training/trainer_sft.py
@@ -11,6 +11,7 @@
 # from model_training.custom_datasets.formatting import DatasetEntry
 from model_training.custom_datasets.dialogue_collator import DialogueDataCollator
 from model_training.efficiency_utils import fuse_gelu
+from model_training.models.patching import RopePatch
 from model_training.models.peft_modeling import peft_model
 from model_training.utils.utils import (
     PerDatasetSampler,
@@ -362,7 +363,6 @@ def main():
     )
 
     train, evals = get_dataset(training_conf)
-
     show_dataset_stats = (training_conf.verbose or training_conf.show_dataset_stats) and (
         not training_conf.deepspeed or training_conf.local_rank == 0
     )
@@ -416,9 +416,12 @@
         sampler = None
 
     metrics, preprocess_fns = get_metrics(training_conf, tokenizer)
-
     model = get_model(training_conf, tokenizer)
 
+    superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
+    if superhot:
+        superhot.patch(model)
+
     if training_conf.peft_model:
         print("Using PEFT model")
         model = peft_model(
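For reference, a minimal standalone sketch of how the pieces above fit together outside the trainer. The argparse.Namespace is only a stand-in for the parsed training config, and the values mirror the new rope_scaling_test entry; model loading via AutoModelForCausalLM is illustrative, not part of the diff.

from argparse import Namespace

from transformers import AutoModelForCausalLM

from model_training.models.patching import RopePatch

# Stand-in for the parsed training config; values mirror `rope_scaling_test` above.
training_conf = Namespace(
    model_name="huggyllama/llama-7b",
    superhot=True,
    superhot_config={"type": "linear", "scale": 2},
)

model = AutoModelForCausalLM.from_pretrained(training_conf.model_name)

# Same wiring as the trainer_sft.py change: build the patch from the config,
# then swap each layer's rotary embedding in place.
superhot = RopePatch.from_config(training_conf) if training_conf.superhot else None
if superhot:
    superhot.patch(model)  # Llama path: replaces self_attn.rotary_emb with LlamaLinearScaledRope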