该文件作为参考案例，使用Pytorch框架，复现Mamba
https://cloud.tencent.com/developer/article/2377967·

# 模型搭建

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from einops import rearrange

import math
import tqdm
import os

from urllib import request
from zipfile import ZipFile
from transformers import AutoTokenizer

torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x23538522630>

In [3]:
# 设置标志和超参数
USE_MAMBA = 1
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# 定义超参数和初始化
d_model = 8
state_size = 128 
seq_len = 100
batch_size = 256
last_batch_size = 81 # only for the very last batch of the dataset
current_batch_size = batch_size
different_batch_size = False
h_new = None
temp_buffer = None


### S6模块
S6模块是Mamba架构中的一个复杂组件，负责通过一系列**线性变换和离散化过程**处理输入序列。它在捕获序列的时间动态方面起着关键作用，这是序列建模任务(如语言建模)的一个关键方面。这里包括张量运算和自定义离散化方法来处理序列数据的复杂需求。
S4模型定义四个参数$(\Delta, A, B, C)$以及两个序列到序列的阶段：
$$
h'(t) = Ah(t)+Bx(t) (1a) \quad h_t = \hat{A} h_{t-1} + \hat{B}x_t (2a) \quad \hat{K} = (C\hat{B}, C\hat{AB},...,C\hat{A^k B},...) (3a)\\
y(t) = Ch(t) (1b) \quad y_t = Ch_t (2b) \quad y = x * \hat{K} (3b)
$$

In [21]:
class S6(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(S6,self).__init__()

        self.fc1 = nn.Linear(d_model, d_model, device=device)
        self.fc2 = nn.Linear(d_model, state_size, device=device)
        self.fc3 = nn.Linear(d_model, state_size, device=device)

        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size

        self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1)) #这里为什么要进行归一化呢？
        nn.init.xavier_uniform_(self.A)

        self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)

        self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
        self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)

        # h [B:batch_size, L:seq_len, D:d_model, S:state_size]
        self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)

    def discretization(self):
        self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)
        self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
        return self.dA, self.dB

    def forward(self, x):
        # Algorithm 2 in MAMBA paper

        self.B = self.fc2(x)
        self.C = self.fc3(x)
        self.delta = F.softplus(self.fc1(x))
        self.discretization()

        if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            global current_batch_size
            current_batch_size = x.shape[0]

            if self.h.shape[0] != current_batch_size:
                different_batch_size = True
                h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x,'b l d -> b l d 1') * self.dB
            else:
                different_batch_size = False
                h_new = torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x,'b l d -> b l d 1') * self.dB

            # y [B, L, D]
            self.y = torch.einsum('bln,bldn->bld', self.C, h_new)

            global temp_buffer
            temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()

            return self.y
        
        else:
            # h [B, L, D, S]
            h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=device)
            y = torch.zeros_like(x)

            h = torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, 'b l d -> b l d 1') * self.dB

            # y [B, L, D]
            y = torch.einsum('bln,bldn->bld', self.C, h)  

            return y

s6 = S6(seq_len=seq_len, d_model=d_model, state_size=state_size, device='cpu')

### RMSNorm
RMSNorm 是一个自定义规范化层，这一层用于规范神经网络的激活，这可以帮助稳定和加快训练。

In [22]:
class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device: str = 'cpu'):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model, device=device))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) * self.weight
        return output
    

In [23]:
class MambaBlock(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(MambaBlock, self).__init__()
        self.in_proj = nn.Linear(d_model, 2*d_model, device=device)
        self.out_proj = nn.Linear(2*d_model, d_model, device=device)

        # For residual skip connection
        self.D = nn.Linear(d_model, 2*d_model, device=device)

        # Set _no_weight_decay attribute on bias
        self.out_proj.bias._no_weight_decay = True

        # Initialize bias to a small constant value
        nn.init.constant_(self.out_proj.bias, 1.0)

        self.S6 = S6(seq_len, 2*d_model, state_size, device)

        # Add 1D convolution with kernel size 3
        self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)

        # Add linear layer for conv output
        self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)

        # rmsnorm
        self.norm = RMSNorm(d_model, device=device)

    def forward(self, x):
        '''
        x_proj.shape = torch.Size([B, L, 2*D])
        x_conv.shape = torch.Size([B, L, 2*D])
        x_conv_act.shape = torch.Size([B, L, 2*D])
        '''
        # Refer to Fig.3 in MAMBA paper

        x = self.norm(x)
        x_proj = self.in_proj(x)

        # Add 1 D convolution with kernel size 3
        x_conv = self.conv(x_proj)
        
        x_conv_act = F.silu(x_conv)

        # Add linear layer for conv output
        x_conv_out = self.conv_linear(x_conv_act)

        x_ssm = self.S6(x_conv_out)
        x_act = F.silu(x_ssm) # Switch activation can be implemented as x*sigmoid(x)

        # residual skip connection with nonlinearity introduced by multiplication
        x_residual = F.silu(self.D(x))

        x_combined = x_act * x_residual

        x_out = self.out_proj(x_combined)

        return x_out

In [24]:
class Mamba(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(Mamba, self).__init__()
        self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)

    def forward(self, x):
        x = self.mamba_block1(x)
        x = self.mamba_block2(x)
        x = self.mamba_block3(x)
        return x

In [25]:
# 用法
x = torch.rand(batch_size, seq_len, d_model, device=device)

# Create the Mamba model
mamba = Mamba(seq_len, d_model, state_size, device)

# rmsnorm
norm = RMSNorm(d_model)
x = norm(x)

# Forward pass
test_output = mamba(x)
print(f"test_output.shape = {test_output.shape}") # should be [B, L, D]

test_output.shape = torch.Size([256, 100, 8])


### 数据准备和训练

In [26]:
class Enwiki8Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['input_ids'])
    
    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.data.items()}
        return item

pad_sequences_3d用于将一批序列填充到统一的长度，确保批中的每个序列具有相同数量的元素(或时间步长)。这在许多机器学习任务中尤其重要，因为输入数据必须具有一致的形状。

In [27]:
# Define a function for padding
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
    # Assuming sequences is a tensor of shape (batch_size, seq_len, feature_size)

    batch_size, seq_len, feature_size = sequences.size()
    if max_len is None:
        max_len = seq_len + 1

    # Initialize padded_sequences with pad_value
    padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)
    # Pad each sequence to the max_len
    padded_sequences[:,:seq_len,:] = sequences

    return padded_sequences

In [28]:
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
    model.train()
    total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()

        input_data = batch['input_ids'].clone().to(device)
        attention_mask = batch['attention_mask'].clone().to(device)

        target = input_data[:,1:]
        input_data =  input_data[:,:-1]

        # Pad all the sequences in the batch:
        input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
        target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

        if USE_MAMBA:
            output = model(input_data)
            loss = criteerion(output, target)

        loss.backward(retain_graph=True)

        for name, param in model.named_parameters():
            if 'out_proj.bias' not in name:
                # clip weights but not bias for out_proj
                torch.nn.utils.clip_grad_norm_(param, max_grad_norm)

        if DEBUGGING_IS_ON:
            for name, parameter in model.named_parameters():
                if parameter.grad is not None:
                    print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                else:
                    print(f"{name} has no gradient")

        if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            model.S6.h[:current_batch_size, ...].copy_(temp_buffer)
        optimizer.step()
        total_loss += loss.item()
    
    return total_loss / len(data_loader)

In [29]:
# 评估函数
def evaluate(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in data_loader:
            input_data = batch['input_ids'].clone().detach().to(device)
            attention_mask = batch['attention_mask'].clone().detach().to(device)

            target = input_data[:, 1:]
            input_data = input_data[:,:-1]

            # Pad all the sequences in the batch:
            input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
            target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

            if USE_MAMBA:
                output = model(input_data)
                loss = criterion(output, target)

            total_loss += loss.item()
    
    return total_loss / len(data_loader)

In [30]:
# 最后，calculate_perplexity用于评估语言模型(如Mamba)的性能。
def calculate_perplexity(loss):
    return math.exp(loss)

In [31]:
# load_enwiki8_dataset函数用于下载和提取enwiki8数据集，该数据集通常用于对语言模型进行基准测试。
def load_enwiki8_dataset():
    print(f"Download and extract enwiki8 data")
    url = "http://mattmahoney.net/dc/enwik8.zip"
    request.urlretrieve(url, 'enwik8.zip')

    with ZipFile('enwik8.zip') as f:
        data = f.read('enwik8').decode('utf-8')

    return data

In [33]:
# encode_dataset函数设计用于标记和编码数据集，为神经网络模型(如Mamba)处理数据集做准备。

# Tokenizer and encode the dataset
def encode_dataset(tokenizer, text_data):
    def batch_encode(tokenizer, text_data, batch_size=1000):
        # Tokenizer in batches
        batched_input_ids = []
        for i in range(0, len(text_data), batch_size):
            batch = text_data[i:i+batch_size]
            inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                               padding='max_length', max_length=seq_len, return_tensors='pt')
            batched_input_ids.append(inputs['input_ids'])
        
        return torch.cat(batched_input_ids)
    
    # Assuming enwiki8_data is a list of sentences
    input_ids = batch_encode(tokenizer, enwiki8_data)

    # vocab_size is the number of unique tokens in the tokenizer's vocabulary
    global vocab_size
    vocab_size = len(tokenizer.vocab) # Note that for some tokenizers, we might access the vocab directly
    print(f'vocab_size = {vocab_size}')

    # Create an embedding layer
    # embedding_dim is the size of the embedding vectors (MAMBA model's D)
    embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    # Pass `input_ids` through the embedding layer
    # This will change `input_ids` from shape [B, L] to [B, L, D]
    def batch_embedding_calls(intput_ids, embedding_layer, batch_size=256):
        # Check if input_ids is already a tensor, if not convert it
        if not isinstance(intput_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)

        # Calculate the number of batches needed
        num_batches = math.ceil(input_ids.size(0) / batch_size)

        # List to hold the output embeddings
        output_embeddings = []

        # Process each batch
        for i in range(num_batches):
            # Calculate start and end indices for the current batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            # Get the batch
            input_id_batch = input_ids[start_idx:end_idx]

            # Call the embedding layer
            with torch.no_grad(): # No need gradients for this operation
                batch_embeddings = embedding_layer(input_id_batch)

            # Append the batch embeddings to the list 
            output_embeddings.append(batch_embeddings)

        # Concatenate the embeddings from each batch into a single tensor
        all_embeddings = torch.cat(output_embeddings, dim=0)

        return all_embeddings
    
    # `input_ids` is a list or tensor of the input IDs and `embedding_layer` is model's embedding layer
    if USE_MAMBA:
        # Set `batch_size` to a value that works for memory constraints
        encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=1).float()

    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)

    return encoded_inputs, attention_mask


### 训练代码

In [34]:
# Load a pretrained tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]
encoded_inputs_file = 'encoded_inputs_mamba.pt'

if os.path.exists(encoded_inputs_file):
    print("Loading pre-tokenized data...")
    encoded_inputs = torch.load(encoded_inputs_file)

else:
    print("Tokenizing raw data...")
    enwiki8_data = load_enwiki8_dataset()
    encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
    torch.save(encoded_inputs, encoded_inputs_file)
    print(f'finished tokenizing data')

# Combine into a single dictionary
data = {
    'input_ids': encoded_inputs,
   'attention_mask': attention_mask
}

# Split the data into train and validation sets
total_size = len(data['input_ids'])
train_size = int(total_size * 0.8)

train_data = {key: val[:train_size] for key, val in data.items()}
val_data = {key: val[train_size:] for key, val in data.items()}

train_dataset = Enwiki8Dataset(train_data)
val_dataset = Enwiki8Dataset(val_data)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize the model
model = Mamba(seq_len, d_model, state_size, device).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-6)

# Training loop
num_epochs = 25 # Number of epochs to train for

for epoch in tqdm(range(num_epochs)): # loop over the dataset multiple times
    train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10, DEBUGGING_IS_ON=False)
    val_loss = evaluate(model, val_loader, criterion, device)
    val_perplexity = calculate_perplexity(val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss:{val_loss:.4f}, Validation Perplexity:{val_perplexity:.4f}')



Downloading tokenizer_config.json: 100%|██████████| 48.0/48.0 [00:00<00:00, 15.5kB/s]
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Downloading config.json: 100%|██████████| 570/570 [00:00<00:00, 406kB/s]
Downloading vocab.txt: 100%|██████████| 232k/232k [00:00<00:00, 1.01MB/s]
Downloading tokenizer.json: 100%|██████████| 466k/466k [00:00<00:00, 2.05MB/s]


Tokenizing raw data...
Download and extract enwiki8 data
vocab_size = 30522


UnboundLocalError: local variable 'input_ids' referenced before assignment