In [None]:
! pip install einops datasets

In [None]:
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
from typing import Union
from torch.utils.data import Dataset, DataLoader
from transformers import get_scheduler

In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters of a Mamba language model.

    Three values are derived/normalised after construction:

    * ``d_inner`` is set to ``int(expand * d_model)``.
    * ``dt_rank`` given as ``'auto'`` is replaced by ``ceil(d_model / 16)``.
    * ``vocab_size`` is rounded up to the next multiple of
      ``pad_vocab_size_multiple``.
    """

    d_model: int                          # width of the residual stream
    n_layer: int                          # number of residual blocks
    vocab_size: int                       # vocabulary size before padding
    d_state: int = 16                     # SSM state size per channel
    expand: int = 2                       # d_inner = expand * d_model
    dt_rank: Union[int, str] = 'auto'     # rank of the delta projection
    d_conv: int = 4                       # kernel size of the depthwise conv
    pad_vocab_size_multiple: int = 8      # vocab_size is padded to a multiple of this
    conv_bias: bool = True                # bias term for the conv layer
    bias: bool = False                    # bias term for in_proj / out_proj

    def __post_init__(self):
        """Fill in ``d_inner`` and resolve ``dt_rank`` / ``vocab_size``."""
        self.d_inner = int(self.d_model * self.expand)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        # Round the vocabulary up so it divides evenly by the padding multiple.
        remainder = self.vocab_size % self.pad_vocab_size_multiple
        if remainder != 0:
            self.vocab_size += self.pad_vocab_size_multiple - remainder

In [None]:
class Mamba(nn.Module):
    """Mamba language model: embedding -> residual Mamba blocks -> RMSNorm -> LM head.

    The output projection (``lm_head``) shares its weight tensor with the
    input embedding.
    """

    def __init__(self, args: ModelArgs):
        """Full Mamba model.

        Args:
            args: model hyper-parameters; ``args.vocab_size`` is the padded
                vocabulary size used for both the embedding and the LM head.
        """
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module; used as the callback for ``self.apply``.

        NOTE(review): ``lm_head`` is registered after ``embedding`` and shares
        its weight, so the Xavier initialisation applied to ``lm_head`` is the
        one that ends up in the tied tensor (it overwrites the normal(0, 0.02)
        initialisation done for the embedding). Confirm this is intended.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
        elif isinstance(module, nn.MultiheadAttention):
            nn.init.xavier_uniform_(module.in_proj_weight)
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
            nn.init.xavier_uniform_(module.out_proj.weight)
            if module.out_proj.bias is not None:
                nn.init.zeros_(module.out_proj.bias)

    def forward(self, input_ids):
        """Compute next-token logits.

        Args:
            input_ids: integer token ids, indexed into the embedding table.

        Returns:
            Logits with one extra trailing dimension of size ``vocab_size``
            compared to ``input_ids``.
        """
        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits

    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Build a model from a Hugging Face Hub checkpoint.

        Reads ``d_model``, ``n_layer`` and ``vocab_size`` from the checkpoint
        config (all other ``ModelArgs`` fields keep their defaults), then loads
        the weights after stripping the ``backbone.`` prefix from their keys.

        Args:
            pretrained_model_name: Hub repo id, e.g. ``'state-spaces/mamba-130m'``.

        Returns:
            A ``Mamba`` instance with the checkpoint weights loaded.

        Raises:
            FileNotFoundError: if the config or weights file cannot be resolved.
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def resolve(model_name, filename):
            # cached_file returns None (instead of raising) for a missing entry
            # because of the flag below, so fail here with a readable message.
            resolved_archive_file = cached_file(model_name, filename,
                                                _raise_exceptions_for_missing_entries=False)
            if resolved_archive_file is None:
                raise FileNotFoundError(
                    f"Could not find '{filename}' for model '{model_name}'."
                )
            return resolved_archive_file

        def load_config_hf(model_name):
            # Context manager so the config file handle is always closed.
            with open(resolve(model_name, CONFIG_NAME)) as config_file:
                return json.load(config_file)

        def load_state_dict_hf(model_name):
            return torch.load(resolve(model_name, WEIGHTS_NAME),
                              weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        # Checkpoint keys are prefixed with 'backbone.'; this class has no such wrapper.
        new_state_dict = {key.replace('backbone.', ''): value
                          for key, value in state_dict.items()}
        model.load_state_dict(new_state_dict)

        return model

In [None]:
class ResidualBlock(nn.Module):
    """One layer of the stack: normalise, run the Mamba mixer, add the input back."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        """Return ``mixer(norm(x)) + x``."""
        normed = self.norm(x)
        mixed = self.mixer(normed)
        return mixed + x

In [None]:
class MambaBlock(nn.Module):
    """A single Mamba mixer: gated depthwise convolution followed by a selective SSM scan.

    The scan in `selective_scan` is a plain Python loop over the sequence
    dimension and builds intermediates of shape (b, l, d_in, n), so its memory
    use grows with `d_inner * d_state`.
    """

    def __init__(self, args: ModelArgs):
        """Create the projections, the depthwise conv and the SSM parameters A_log and D."""
        super().__init__()
        self.args = args

        # Single projection producing both the SSM branch and the gating branch
        # (hence the factor of 2); they are split apart in `forward`.
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # groups == channels makes this a depthwise convolution; the padding of
        # d_conv - 1 together with the slice in `forward` keeps the output the
        # same length as the input.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # A is stored as log(1..d_state), repeated for each of the d_inner
        # channels; `ssm` recovers A = -exp(A_log).
        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)


    def forward(self, x):
        """Run the mixer on `x` of shape (b, l, d) and return the `out_proj` output.

        Steps: project and split into (x, res); depthwise conv + SiLU on x;
        selective SSM; multiply by SiLU(res); project with `out_proj`.
        """

        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        # Conv1d convolves over the last axis, so move the sequence axis last,
        # convolve, then keep only the first l positions of the padded output.
        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        # Gate the SSM output with the second half of the input projection.
        y = y * F.silu(res)

        output = self.out_proj(y)

        return output


    def ssm(self, x):
        """Compute the input-dependent Δ, B, C from `x` and run the selective scan.

        `x` is the (b, l, d_in) activation after the conv; the result has the
        same shape.
        """

        (d_in, n) = self.A_log.shape


        # A and D are taken in float32; A is negated after exponentiating A_log.
        A = -torch.exp(self.A_log.float())
        D = self.D.float()

        x_dbl = self.x_proj(x)

        # Split the projection into Δ (dt_rank wide) and B, C (n wide each);
        # Δ is then expanded to d_in channels and passed through softplus.
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
        delta = F.softplus(self.dt_proj(delta))

        y = self.selective_scan(x, delta, A, B, C, D)

        return y


    def selective_scan(self, u, delta, A, B, C, D):
        """Sequential scan of the discretised state-space recurrence.

        For each position i:
            x_i = exp(Δ_i A) * x_{i-1} + Δ_i B_i u_i
            y_i = sum_n(x_i * C_i) + D * u_i

        Shapes, as fixed by the einsum patterns below: u and delta are
        (b, l, d_in), A is (d_in, n), B and C are (b, l, n), and D broadcasts
        against u. Returns a tensor of shape (b, l, d_in).
        """

        (b, l, d_in) = u.shape
        n = A.shape[1]


        # Both intermediates have shape (b, l, d_in, n) and are materialised for
        # the whole sequence before the loop starts.
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')


        # Hidden state starts at zero and is updated one position at a time.
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        # Skip connection from the input, scaled per channel by D.
        y = y + u * D

        return y

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale."""

    def __init__(self, d_model: int, eps: float = 1e-5):
        """
        Args:
            d_model: size of the last dimension of the inputs.
            eps: constant added to the mean square before the reciprocal
                square root.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``x * rsqrt(mean(x**2, dim=-1) + eps) * weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

## From random weights



In [None]:
# Reference hyper-parameters of the released checkpoints.
#
#   mamba-370m: d_model=1024, n_layer=48, vocab_size=50280, d_state=16, expand=2, d_conv=4
#   mamba-130m: d_model=768,  n_layer=24, vocab_size=50280, d_state=16, expand=2, d_conv=4
#
# d_state is the per-channel SSM state size, not a multiple of d_model. The
# selective scan materialises tensors of shape (batch, seq, d_inner, d_state),
# so a value in the thousands makes a single forward pass need tens of GB.
# With d_state=16 and expand=2 the layer shapes match the ones printed for the
# pretrained 'state-spaces/mamba-130m' model (d_inner=1536, x_proj out=80).

args = ModelArgs(
    d_model=768,            # Hidden dimension size
    n_layer=24,             # Number of layers
    vocab_size=50280,       # Vocabulary size
    d_state=16,             # Latent state dimension (per channel)
    expand=2,               # Expansion factor: d_inner = expand * d_model
    dt_rank='auto',         # Rank of delta; 'auto' -> ceil(d_model / 16)
    d_conv=4,               # Convolution kernel size
    pad_vocab_size_multiple=8,
    conv_bias=True,
    bias=False
)
model = Mamba(args)

In [None]:
model

Mamba(
  (embedding): Embedding(50280, 768)
  (layers): ModuleList(
    (0-23): 24 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=768, out_features=6144, bias=False)
        (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072)
        (x_proj): Linear(in_features=3072, out_features=6192, bias=False)
        (dt_proj): Linear(in_features=48, out_features=3072, bias=True)
        (out_proj): Linear(in_features=3072, out_features=768, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)

## From a pre-trained model

In [None]:
# Rebinds `model`: any model created earlier in the notebook under this name
# is replaced by one loaded from the Hub checkpoint below.
pretrained_model_name = 'state-spaces/mamba-130m'
model = Mamba.from_pretrained(pretrained_model_name)

In [None]:
model

Mamba(
  (embedding): Embedding(50280, 768)
  (layers): ModuleList(
    (0-23): 24 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=768, out_features=3072, bias=False)
        (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
        (x_proj): Linear(in_features=1536, out_features=80, bias=False)
        (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
        (out_proj): Linear(in_features=1536, out_features=768, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)

## Tokenizer

In [None]:
from transformers import AutoTokenizer
from datasets import load_dataset

# Corpus generator used to train the tokenizer
def get_training_corpus():
    """Yield the lines of ``cleaned_data.txt`` in batches of up to 1000 strings."""
    dataset = load_dataset("text", data_files={"train": "cleaned_data.txt"})
    batch_size = 1000
    start = 0
    total = len(dataset["train"])
    while start < total:
        yield dataset["train"][start : start + batch_size]["text"]
        start += batch_size

# Initialize the base tokenizer
base_tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")

# Train the new tokenizer
new_tokenizer = base_tokenizer.train_new_from_iterator(get_training_corpus(), vocab_size=1000)

new_tokenizer.pad_token = new_tokenizer.eos_token

fim_prefix_token = "<fim_prefix>"
fim_middle_token = "<fim_middle_token>"
fim_suffix_token = "<fim_suffix_token>"
fim_pad_token = "<fim_pad>"

# Register each FIM token exactly once (prefix, middle, suffix, pad) and look
# up the id of each one under its own name.
new_tokenizer.add_tokens(
    [
        fim_prefix_token,
        fim_middle_token,
        fim_suffix_token,
        fim_pad_token,
    ]
)
prefix_tok_id = new_tokenizer.convert_tokens_to_ids(fim_prefix_token)
middle_tok_id = new_tokenizer.convert_tokens_to_ids(fim_middle_token)
suffix_tok_id = new_tokenizer.convert_tokens_to_ids(fim_suffix_token)
pad_tok_id = None

fim_tokens = [prefix_tok_id, middle_tok_id, suffix_tok_id]


# If truncate_or_pad is on, also get pad token id
truncate_or_pad = True
if truncate_or_pad:
    pad_tok_id = new_tokenizer.convert_tokens_to_ids(fim_pad_token)
    fim_tokens.append(pad_tok_id)

# Save only after the pad token and the FIM tokens are in place, so that the
# tokenizer reloaded from disk has the same vocabulary as the one used here.
new_tokenizer.save_pretrained("darija_tokenizer")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Generating train split: 0 examples [00:00, ? examples/s]

## TextDataset

In [None]:
# Custom Dataset Class
class TextDataset(Dataset):
    """Language-modelling dataset built from one raw text file.

    The whole file is tokenized once and cut into consecutive, non-overlapping
    windows of ``context_len`` tokens, so every part of the corpus is used.
    (Tokenizing the file as a single truncated sequence would keep only its
    first ``context_len`` tokens and yield a dataset of length one.)

    The last window is right-padded with the tokenizer's pad id (falling back
    to its eos id); the padded positions are set to -100 in ``labels``, which
    is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.

    Args:
        file_path: path to a UTF-8 text file.
        tokenizer: callable such that ``tokenizer(text, truncation=False)``
            returns a mapping whose ``'input_ids'`` entry is a flat sequence of
            token ids (as Hugging Face tokenizers do for a single string).
        context_len: number of tokens per example.

    Raises:
        ValueError: if the file produces no tokens, or if padding is needed but
            the tokenizer defines neither a pad nor an eos token id.
    """

    def __init__(self, file_path, tokenizer, context_len=384):
        self.tokenizer = tokenizer
        self.context_len = context_len

        # Load and tokenize data
        with open(file_path, 'r', encoding='utf-8') as f:
            self.data = f.read()

        token_ids = list(tokenizer(self.data, truncation=False)['input_ids'])
        if not token_ids:
            raise ValueError(f"No tokens were produced from '{file_path}'.")

        # Number of pad tokens needed to complete the final window.
        n_pad = (-len(token_ids)) % context_len
        if n_pad:
            pad_id = getattr(tokenizer, 'pad_token_id', None)
            if pad_id is None:
                pad_id = getattr(tokenizer, 'eos_token_id', None)
            if pad_id is None:
                raise ValueError("Tokenizer has no pad_token_id or eos_token_id to pad with.")
            token_ids = token_ids + [pad_id] * n_pad

        input_ids = torch.tensor(token_ids, dtype=torch.long).view(-1, context_len)
        labels = input_ids.clone()
        if n_pad:
            # Padding must not contribute to the loss.
            labels[-1, context_len - n_pad:] = -100

        self.tokens = {'input_ids': input_ids, 'labels': labels}

    def __len__(self):
        """Number of ``context_len``-token windows."""
        return self.tokens['input_ids'].size(0)

    def __getitem__(self, idx):
        """Return the ``input_ids`` and ``labels`` tensors of window ``idx``."""
        return {
            'input_ids': self.tokens['input_ids'][idx],
            'labels': self.tokens['labels'][idx]
        }

In [None]:
# Training configuration
class Args:
    """Namespace of training settings, read as class attributes (never instantiated here)."""

    # Data locations
    dataset_path = "/content/cleaned_data.txt"  # raw training text
    eval_path = "/content/validation.txt"       # raw validation text

    # Optimisation
    lr = 1e-4
    epochs = 1

    # Batching
    context_len = 384
    train_batch_size = 8
    valid_batch_size = 8

    # First CUDA device when one is available, otherwise the CPU.
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Datasets: one per split, tokenized with the newly trained tokenizer
train_dataset = TextDataset(Args.dataset_path, new_tokenizer, context_len=Args.context_len)
eval_dataset = TextDataset(Args.eval_path, new_tokenizer, context_len=Args.context_len)

# Loaders: training batches are shuffled, evaluation batches keep file order
train_dataloader = DataLoader(train_dataset, batch_size=Args.train_batch_size, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=Args.valid_batch_size, shuffle=False)

# AdamW with a cosine schedule that spans every optimizer step of the run
optimizer = torch.optim.AdamW(model.parameters(), lr=Args.lr)
scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=len(train_dataloader) * Args.epochs,
)

## Training Loop

In [None]:
def _shifted_lm_loss(logits, labels, loss_fct):
    """Next-token loss: the logits at position t are scored against the label at t + 1."""
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


# Built once and shared by the training and evaluation passes.
loss_fct = torch.nn.CrossEntropyLoss()
model_save_path = "MAMBA_MODEL.pt"

model.to(Args.device)
for epoch in range(Args.epochs):
    # ---- training pass ----
    model.train()
    total_loss = 0

    for batch in train_dataloader:
        batch = {k: v.to(Args.device) for k, v in batch.items()}

        outputs = model(batch['input_ids'])
        loss = _shifted_lm_loss(outputs, batch['labels'], loss_fct)

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    # max(..., 1) keeps the average defined when a loader yields no batches.
    print(f"Epoch {epoch+1}/{Args.epochs}, Loss: {total_loss / max(len(train_dataloader), 1)}")

    # ---- evaluation pass ----
    model.eval()
    total_eval_loss = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(Args.device) for k, v in batch.items()}

            outputs = model(batch['input_ids'])
            loss = _shifted_lm_loss(outputs, batch['labels'], loss_fct)

            total_eval_loss += loss.item()

    avg_eval_loss = total_eval_loss / max(len(eval_dataloader), 1)
    print(f"Epoch {epoch+1}/{Args.epochs}, Evaluation Loss: {avg_eval_loss}")

    # Checkpoint after every epoch; the file is overwritten each time.
    torch.save(model.state_dict(), model_save_path)

# Reported once, after the last epoch, rather than at the end of every epoch.
print("Training complete!")


## Inference

In [None]:
# Sample `num_return_sequences` continuations of a fixed prompt, each capped at
# `max_length` tokens (prompt included), using top-50 sampling.
max_length = 30
num_return_sequences = 10

prompt_ids = new_tokenizer.encode("You are  ")
tokens = torch.tensor(prompt_ids, dtype=torch.long).unsqueeze(0).repeat(num_return_sequences, 1)
x = tokens.to(Args.device)

with torch.no_grad():
    while x.size(1) < max_length:
        outputs = model(x)
        logits = outputs[0] if isinstance(outputs, tuple) else outputs

        # Distribution over the next token, taken from the last position only.
        probs = F.softmax(logits[:, -1, :], dim=-1)

        # Keep the 50 most likely tokens and draw one of them per sequence;
        # multinomial returns an index into the top-k list, gather maps it
        # back to a vocabulary id.
        topk_probs, topk_indices = torch.topk(probs, 50, dim=-1)
        ix = torch.multinomial(topk_probs, 1)
        xcol = torch.gather(topk_indices, -1, ix)

        x = torch.cat((x, xcol), dim=1)

# print the generated text
for tokens in x[:, :max_length].tolist():
    decoded = new_tokenizer.decode(tokens, skip_special_tokens=True)
    print(">", decoded)