In [1]:
!nvidia-smi

/bin/bash: nvidia-smi: command not found


In [2]:
!pip install -qU einops dataclasses typing datasets

In [3]:
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum
from typing import Union
from torch.utils.data import Dataset, DataLoader
from transformers import get_scheduler, AutoTokenizer, MambaForCausalLM
from datasets import load_dataset
import pprint as pp

In [4]:
def createPrompts(sample):
    """
    Reads our dataset, creates a column of prompts corresponding
    to each row, and returns a list of all the prompts.
    """

    formatString = """### Context:
{}

### Input:
{}

### Completion:
{}"""

    context = sample['context']
    input = sample['input']
    completion = sample['completion']

    prompt = formatString.format(context, input, completion)

    return {
        "prompt": prompt
    }

In [5]:
def loadStringsFromDataset(dataset):
    dataset = dataset.map(createPrompts)

    listOfStrings = []
    for sample in dataset:
        listOfStrings.append(sample['prompt'])

    return listOfStrings

In [6]:
# Custom Dataset Class
class TextDataset(Dataset):
    def __init__(self, dataset, tokenizer, context_len=384):
        self.tokenizer = tokenizer
        self.context_len = context_len

        # Load and tokenize data
        self.data = loadStringsFromDataset(dataset)

        self.tokens = tokenizer(self.data, return_tensors='pt', truncation=True, padding='max_length', max_length=context_len)

    def __len__(self):
        return len(self.tokens['input_ids'])

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokens['input_ids'][idx],
            'labels': self.tokens['input_ids'][idx]
        }

In [7]:
@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)

In [8]:
class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

In [9]:
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)


    def forward(self, x):
        output = self.mixer(self.norm(x)) + x

        return output

In [10]:
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)


    def forward(self, x):

        (b, l, d) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')

        x = F.silu(x)

        y = self.ssm(x)

        y = y * F.silu(res)

        output = self.out_proj(y)

        return output


    def ssm(self, x):

        (d_in, n) = self.A_log.shape

        # Compute ∆ A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)

        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)

        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)

        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]

        return y


    def selective_scan(self, u, delta, A, B, C, D):

        (b, l, d_in) = u.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)

        y = y + u * D

        return y

In [12]:
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)
        elif isinstance(module, nn.MultiheadAttention):
            nn.init.xavier_uniform_(module.in_proj_weight)
            if module.in_proj_bias is not None:
                nn.init.zeros_(module.in_proj_bias)
            nn.init.xavier_uniform_(module.out_proj.weight)
            if module.out_proj.bias is not None:
                nn.init.zeros_(module.out_proj.bias)


    def forward(self, input_ids):

        x = self.embedding(input_ids)

        for layer in self.layers:
            x = layer(x)

        x = self.norm_f(x)
        logits = self.lm_head(x)

        return logits


    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """Load pretrained weights from HuggingFace into model.

                * 'state-spaces/mamba-2.8b-slimpj'
                * 'state-spaces/mamba-2.8b'
                * 'state-spaces/mamba-1.4b'
                * 'state-spaces/mamba-790m'
                * 'state-spaces/mamba-370m'
                * 'state-spaces/mamba-130m'

        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))


        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                                _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
            d_model=config_data['d_model'],
            n_layer=config_data['n_layer'],
            vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model

In [13]:
args = ModelArgs(
    d_model=768,            # Hidden dimension size
    n_layer=24,             # Number of layers
    vocab_size=50280,       # Vocabulary size
    d_state=3072,           # Latent state dimension
    expand=4,             # Expansion factor
    dt_rank='auto',       # Rank of delta
    d_conv=4,             # Convolution kernel size
    pad_vocab_size_multiple=8,
    conv_bias=True,
    bias=False
)

model = Mamba(args)

In [14]:
model

Mamba(
  (embedding): Embedding(50280, 768)
  (layers): ModuleList(
    (0-23): 24 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=768, out_features=6144, bias=False)
        (conv1d): Conv1d(3072, 3072, kernel_size=(4,), stride=(1,), padding=(3,), groups=3072)
        (x_proj): Linear(in_features=3072, out_features=6192, bias=False)
        (dt_proj): Linear(in_features=48, out_features=3072, bias=True)
        (out_proj): Linear(in_features=3072, out_features=768, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)

In [15]:
model = Mamba.from_pretrained('state-spaces/mamba-130m')
tokenizer = AutoTokenizer.from_pretrained('state-spaces/mamba-130m-hf')

config.json:   0%|          | 0.00/199 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/517M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


tokenizer_config.json:   0%|          | 0.00/4.79k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [16]:
model

Mamba(
  (embedding): Embedding(50280, 768)
  (layers): ModuleList(
    (0-23): 24 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=768, out_features=3072, bias=False)
        (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
        (x_proj): Linear(in_features=1536, out_features=80, bias=False)
        (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
        (out_proj): Linear(in_features=1536, out_features=768, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)

In [17]:
dataset = load_dataset("neuralwork/fashion-style-instruct", split='train')
dataset = dataset.train_test_split(test_size=0.2)

Downloading readme:   0%|          | 0.00/882 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.64M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3193 [00:00<?, ? examples/s]

In [28]:
class Args:
    # you can change it to match your setup
    trainDataset = dataset['train']
    testDataset = dataset['test']
    lr = 1e-4
    epochs = 1
    context_len = 384
    train_batch_size = 2
    valid_batch_size = 2
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")

In [29]:
print(f'{Args.trainDataset}\n\n{Args.testDataset}')

Dataset({
    features: ['input', 'completion', 'context'],
    num_rows: 2554
})

Dataset({
    features: ['input', 'completion', 'context'],
    num_rows: 639
})


In [30]:
type(Args.trainDataset)

datasets.arrow_dataset.Dataset

In [31]:
# Load dataset
train_dataset = TextDataset(Args.trainDataset, tokenizer, context_len=Args.context_len)
eval_dataset = TextDataset(Args.testDataset, tokenizer, context_len=Args.context_len)

train_dataloader = DataLoader(train_dataset, batch_size=Args.train_batch_size, shuffle=False)
eval_dataloader = DataLoader(eval_dataset, batch_size=Args.valid_batch_size, shuffle=False)

# Optimizer and Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=Args.lr)
scheduler = get_scheduler(
    "cosine", optimizer=optimizer, num_warmup_steps=1, num_training_steps=len(train_dataloader) * Args.epochs
)

In [32]:
scheduler 

<torch.optim.lr_scheduler.LambdaLR at 0x7fd25b2dd3f0>

In [None]:
model.to(Args.device)
for epoch in range(Args.epochs):
    model.train()
    total_loss = 0

    for batch in train_dataloader:
        batch = {k: v.to(Args.device) for k, v in batch.items()}

        outputs = model(batch['input_ids'])

        # Compute the loss manually
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = batch['labels'][..., 1:].contiguous()
        loss_fct = torch.nn.CrossEntropyLoss()
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{Args.epochs}, Loss: {total_loss / len(train_dataloader)}")
    # Evaluation
    model.eval()
    total_eval_loss = 0

    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(Args.device) for k, v in batch.items()}

            outputs = model(batch['input_ids'])

            # Compute the loss manually
            shift_logits = outputs[..., :-1, :].contiguous()
            shift_labels = batch['labels'][..., 1:].contiguous()
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            total_eval_loss += loss.item()

    avg_eval_loss = total_eval_loss / len(eval_dataloader)
    print(f"Epoch {epoch+1}/{Args.epochs}, Evaluation Loss: {avg_eval_loss}")
    model_save_path = "mamba_darija.pt"
    torch.save(model.state_dict(), model_save_path)
    print("Training complete!")