In [1]:
# !sudo apt install libopenmpi-dev -y
# !pip3 install mpi4py --user
# !pip3 install deepspeed==0.12.3 --user

In [2]:
# !pip3 install accelerate transformers -U --user

In [3]:
!pip3 freeze

absl-py==2.0.0
accelerate==0.25.0
aiofiles==23.2.1
aiohttp==3.8.5
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
altair==5.1.2
anyio==3.7.1
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
asttokens==2.2.1
async-timeout==4.0.3
attributedict==0.3.0
attrs==23.1.0
autoawq @ https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp310-cp310-linux_x86_64.whl
azure-core==1.29.5
azure-identity==1.15.0
azure-storage-blob==12.18.3
azure-storage-file-datalake==12.13.2
backcall==0.2.0
bcrypt==4.0.1
beautifulsoup4==4.12.2
bitsandbytes==0.41.0
bleach==6.0.0
blessed==1.20.0
blessings==1.7
boto3==1.28.78
botocore==1.31.78
Brotli==1.1.0
cachetools==5.3.2
causal-conv1d==1.0.0
certifi==2022.12.7
cffi==1.15.1
chardet==5.2.0
charset-normalizer==2.1.1
circuitbreaker==1.4.0
click==8.1.7
cmake==3.27.7
codecov==2.1.13
colorama==0.4.6
coloredlogs==15.0.1
colorful==0.5.5
colour-runner==0.1.1
comm==0.1.4
contourpy=

In [4]:
!nvidia-smi

Tue Dec  5 08:32:01 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000001:00:00.0 Off |                    0 |
| N/A   38C    P0    63W / 300W |  33894MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [5]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [6]:
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial

from collections import namedtuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PretrainedConfig

from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None):
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        initializer_cfg=None,
        pad_vocab_size_multiple: int = 1,
        device=None,
        dtype=None,
        **backbone_kwargs,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            initializer_cfg=initializer_cfg,
            **backbone_kwargs,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()
        self.config = PretrainedConfig(
            d_model = d_model,
            n_layer = n_layer,
            vocab_size = vocab_size,
            hidden_size = d_model,
        )

    def tie_weights(self):
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, labels = None):
        """
        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            hidden_states = hidden_states[:, -num_last_tokens:]
        lm_logits = self.lm_head(hidden_states)
        
        loss = None
        if labels is not None:
            logits = lm_logits
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            print(loss, shift_logits, shift_logits.dtype, shift_labels, shift_labels.dtype)
            return (loss,)
            
        else:
            CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
            return CausalLMOutput(logits=lm_logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        config = load_config_hf(pretrained_model_name)
        model = cls(**config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

In [7]:
# !wget https://huggingface.co/state-spaces/mamba-130m/raw/main/config.json -O config-130m.json

In [8]:
import json

with open('config-130m.json') as fopen:
    config = json.load(fopen)

In [9]:
model = MambaLMHeadModel(**{**config, 'vocab_size': 32000}, dtype = torch.float32)

In [10]:
model.config.vocab_size

32000

In [11]:
from streaming import LocalDataset
import numpy as np
from streaming.base.format.mds.encodings import Encoding, _encodings

class UInt16(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.uint16)

_encodings['uint16'] = UInt16

In [12]:
# !git lfs clone https://huggingface.co/datasets/malaysia-ai/mosaic-instructions

In [13]:
class DatasetFixed(torch.utils.data.Dataset):
    def __init__(self, local):
        self.dataset = LocalDataset(local=local)

    def __getitem__(self, idx):
        print(idx)
        data = self.dataset[idx]
        data['labels'] = data['input_ids'].copy()

        data.pop('token_type_ids', None)
        for k in data.keys():
            data[k] = data[k].astype(np.int64)
        return data

    def __len__(self):
        return len(self.dataset)

train_dataset = DatasetFixed(local='mosaic-instructions')

In [14]:
from transformers import TrainingArguments, Trainer, default_data_collator

output_dir = 'test-130m'

training_args = TrainingArguments(
    output_dir,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    logging_steps=1,
    save_strategy='steps',
    save_steps=100,
    num_train_epochs=3,
    learning_rate=1e-4,
    weight_decay=0,
    warmup_steps=1000,
    bf16=False,
    fp16=False,
    tf32=True,
    gradient_checkpointing=False,
    save_total_limit=5,
    log_level='debug',
    max_steps=10,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [15]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=default_data_collator,
)

max_steps is given, it will override any value given in num_train_epochs


In [16]:
trainer.__dict__

{'args': TrainingArguments(
 _n_gpu=1,
 adafactor=False,
 adam_beta1=0.9,
 adam_beta2=0.999,
 adam_epsilon=1e-08,
 auto_find_batch_size=False,
 bf16=False,
 bf16_full_eval=False,
 data_seed=None,
 dataloader_drop_last=False,
 dataloader_num_workers=0,
 dataloader_pin_memory=True,
 ddp_backend=None,
 ddp_broadcast_buffers=None,
 ddp_bucket_cap_mb=None,
 ddp_find_unused_parameters=None,
 ddp_timeout=1800,
 debug=[],
 deepspeed=None,
 disable_tqdm=False,
 dispatch_batches=None,
 do_eval=False,
 do_predict=False,
 do_train=False,
 eval_accumulation_steps=None,
 eval_delay=0,
 eval_steps=None,
 evaluation_strategy=no,
 fp16=False,
 fp16_backend=auto,
 fp16_full_eval=False,
 fp16_opt_level=O1,
 fsdp=[],
 fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
 fsdp_min_num_params=0,
 fsdp_transformer_layer_cls_to_wrap=None,
 full_determinism=False,
 gradient_accumulation_steps=1,
 gradient_checkpointing=False,
 gradient_checkpointing_kwargs=None,
 greater_is_better=None

In [17]:
trainer.train()

Currently training with a batch size of: 2
***** Running training *****
  Num examples = 385,224
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 10
  Number of trainable parameters = 115,096,320


327342
345828
318900
69472
tensor(10.5369, device='cuda:0', grad_fn=<NllLossBackward0>) tensor([[-0.9099, -0.1115,  0.3501,  ..., -0.8214,  0.1393, -0.5609],
        [ 0.6366, -0.3612,  0.7958,  ...,  0.2413,  0.3881,  1.1808],
        [ 0.8470,  0.7424,  0.3130,  ..., -0.1876,  0.2822, -0.7302],
        ...,
        [ 0.5185, -0.7865,  0.5211,  ..., -0.7110,  0.3063,  0.0059],
        [ 0.3736, -0.0770,  0.3438,  ..., -0.2821, -0.5326,  0.3334],
        [-0.1783, -1.0785,  0.1424,  ..., -0.2196, -0.0319,  0.2402]],
       device='cuda:0', grad_fn=<ViewBackward0>) torch.float32 tensor([ 6845,  5341,  3474,  ...,    11, 15119,  4318], device='cuda:0') torch.int64


Step,Training Loss
1,10.5369
2,10.5383
3,10.3558
4,10.5414
5,10.5317
6,10.3537
7,10.5176
8,10.5213
9,10.2709
10,10.2941


109610
38573
tensor(10.5383, device='cuda:0', grad_fn=<NllLossBackward0>) tensor([[ 0.3465,  0.4351, -0.8140,  ..., -0.6726,  0.9307, -0.3702],
        [-1.3820, -0.1279, -0.2421,  ...,  0.1949, -0.2219,  0.1388],
        [ 0.7457, -0.2125,  0.8667,  ..., -0.3975,  0.4632,  0.3044],
        ...,
        [ 0.7236, -0.0046,  0.1512,  ..., -0.2837,  0.4481,  0.3537],
        [ 1.1117,  0.0273,  0.6761,  ..., -0.9181,  0.7765, -0.2324],
        [-0.5953,  0.1991,  0.1493,  ..., -0.1652,  0.7980,  0.3479]],
       device='cuda:0', grad_fn=<ViewBackward0>) torch.float32 tensor([7081,  313, 8563,  ..., 2228,   15,  436], device='cuda:0') torch.int64
273694
218013
tensor(10.3558, device='cuda:0', grad_fn=<NllLossBackward0>) tensor([[-1.4105, -0.2334,  0.8561,  ..., -0.6733, -0.1248, -0.5465],
        [ 0.4111, -0.2715, -0.4791,  ...,  0.7469, -0.4307,  0.7578],
        [-0.8519,  0.1634,  0.3254,  ..., -1.1628, -0.6447, -0.6094],
        ...,
        [-0.0055,  0.7690,  0.5926,  ..., -0.7533, 



Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=10, training_loss=10.44616870880127, metrics={'train_runtime': 16.0869, 'train_samples_per_second': 1.243, 'train_steps_per_second': 0.622, 'total_flos': 0.0, 'train_loss': 10.44616870880127, 'epoch': 0.0})