In [None]:
!pip install causal-conv1d>=1.1.0
!pip install mamba-ssm

In [None]:
!pip install --upgrade pip setuptools wheel
!pip install accelerate transformers wandb
!pip install apache-beam
!pip install numpy>=1.17 --ignore-installed
!pip install git+https://github.com/huggingface/datasets#egg=datasets

In [None]:
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForCausalLM
import torch
import os

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.models.config_mamba import MambaConfig

from datasets import load_dataset

from dataclasses import asdict
import json

In [None]:
class MambaConfigForTrainer:
    """Wrap a ``MambaConfig`` and add dict/JSON serialisation helpers.

    Keyword arguments are forwarded unchanged to ``MambaConfig``. Any
    attribute that is not defined on the wrapper itself is looked up on the
    wrapped config, so config fields can be read directly from the wrapper.
    """

    def __init__(self, **kwargs):
        self.config = MambaConfig(**kwargs)

    def to_dict(self):
        """Return the wrapped config's fields as a plain dictionary."""
        return asdict(self.config)

    def to_json_string(self):
        """Return the wrapped config as a JSON string indented by 4 spaces."""
        return json.dumps(self.to_dict(), indent=4)

    def __getattr__(self, item):
        """Delegate lookups that fail on the wrapper to the wrapped config.

        Raises:
            AttributeError: if neither the wrapper nor the wrapped config
                provides ``item``.
        """
        # __getattr__ only runs when normal lookup fails. If ``config`` itself
        # is missing (e.g. while copy/pickle rebuilds an empty instance),
        # reading ``self.config`` below would re-enter this method forever.
        if item == "config":
            raise AttributeError(f"'MambaConfigForTrainer' object has no attribute '{item}'")
        try:
            return getattr(self.config, item)
        except AttributeError:
            raise AttributeError(f"'MambaConfigForTrainer' object has no attribute '{item}'") from None

In [None]:
class MambaTrainer(Trainer):
    """``Trainer`` subclass that computes the next-token loss itself.

    The model is called with ``input_ids`` only and its first output is
    treated as the logits; the loss is cross-entropy between the logits at
    position ``t`` and the target token at position ``t + 1``.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the shifted cross-entropy language-modelling loss.

        Args:
            model: callable taking ``input_ids`` and returning a sequence whose
                first element is the logits tensor.
            inputs: mapping with an ``"input_ids"`` tensor and, optionally, a
                ``"labels"`` tensor of the same shape.
            return_outputs: when true, also return the logits.
            num_items_in_batch: accepted so that ``Trainer`` versions which
                pass this keyword can call the method; it is not used.

        Returns:
            The scalar loss, or ``(loss, {"logits": logits})`` when
            ``return_outputs`` is true.
        """
        input_ids = inputs["input_ids"]
        lm_logits = model(input_ids)[0]

        # Prefer collator-provided labels: positions set to -100 there are
        # skipped by CrossEntropyLoss (its default ignore_index), so padding
        # does not contribute to the loss. Fall back to the raw ids otherwise.
        labels = inputs.get("labels")
        if labels is None:
            labels = input_ids
        labels = labels.to(lm_logits.device)

        # Drop the last prediction and the first target so that the logits at
        # position t line up with the token at position t + 1.
        shift_logits = lm_logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss()
        lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        if return_outputs:
            return lm_loss, {"logits": lm_logits}
        return lm_loss

In [None]:
# Load the tokenizer stored under the "state-spaces/mamba-130m-hf" identifier.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="state-spaces/mamba-130m-hf",
)
# Use the end-of-sequence token as the padding token as well.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Model hyperparameters; the vocabulary size is taken from the tokenizer.
mamba_config = MambaConfigForTrainer(d_model=256, n_layer=8, vocab_size=len(tokenizer))

# Build the language model from that configuration on the "cuda" device.
model = MambaLMHeadModel(config=mamba_config, device="cuda")

In [None]:
# Load the "JeanKaddour/minipile" dataset (the variable name notwithstanding)
# and show it as the cell output.
wiki_dataset = load_dataset(path="JeanKaddour/minipile")
wiki_dataset

In [None]:
from transformers import DataCollatorForLanguageModeling

def tokenize_function(examples):
    """Tokenize the ``"text"`` entries of ``examples`` with the notebook tokenizer.

    The special-tokens mask is requested alongside the encoded output.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, return_special_tokens_mask=True)
    return encoded

# Batch collator with masked-language-modelling turned off.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)
# Run the tokenizer over the dataset, passing examples to it in batches.
tokenized_datasets = wiki_dataset.map(function=tokenize_function, batched=True)

In [None]:
# Training configuration: one epoch, batch size 1, a checkpoint at the end of
# each epoch (at most 10 kept), metrics reported to Weights & Biases.
args = TrainingArguments(
    output_dir="./checkpoints",
    report_to="wandb",
    save_strategy="epoch",
    save_total_limit=10,
    # NOTE(review): the model is a plain nn.Module, so the Trainer saves its raw
    # state dict. MambaLMHeadModel is understood to tie lm_head to the embedding
    # by default, and the safetensors writer rejects tensors that share memory,
    # so checkpoints are written with torch.save instead. Confirm against the
    # installed transformers / mamba-ssm versions.
    save_safetensors=False,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=10,
)

# Assemble the trainer from the model, the training arguments, the "train"
# split of the tokenized data, the collator and the tokenizer.
trainer = MambaTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_datasets["train"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
# Start training; the call's return value is displayed as the cell output.
trainer.train()

In [None]:
# Write the trained model to the "./mamba-1" directory.
trainer.save_model(output_dir="./mamba-1")