## **1. Install and import libraries**


In [1]:
!pip install datasets evaluate accelerate
!pip install causal-conv1d>=1.1.0
!pip install mamba-ssm


!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

from IPython.display import clear_output
clear_output()

In [2]:
import os
import random
import json
import torch
import torch.nn as nn
from collections import namedtuple
from dataclasses import dataclass, field, asdict
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import Trainer
from transformers import AutoTokenizer, TrainingArguments

Log in to the Hugging Face Hub (`huggingface_hub`) to push the trained model.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## **2. Dataset**


In [4]:
import pandas as pd
from datasets import Dataset

In [None]:
# Load the tab-separated training file (no header row; two columns, text and label),
# then hold out 20% as a validation split. The fixed seed makes the split reproducible.
train_path = '/content/data/train.txt'

df_traindata = pd.read_csv(
    train_path,
    header=None,
    delimiter='\t',
    names=["text", "label"],
)

dataset = Dataset.from_pandas(df_traindata).train_test_split(train_size=0.8, seed=42)
# `train_test_split` calls the held-out part "test"; expose it as "validation" instead.
dataset["validation"] = dataset.pop("test")

dataset

In [7]:
# eval_path = '/content/data/test.txt'
# df_evaldata = pd.read_csv(eval_path, delimiter='\t', names=["text", "label"], header=None)
# datasettest = Dataset.from_pandas(df_evaldata)
# datasettest

In [None]:
# dataset["train"][0]

In [8]:
# BIO tag set: index 0 is the "outside" tag "O", followed by a B-/I- pair
# for each entity type, in the order listed below.
id2label = dict(enumerate(
    ['O'] + [f'{prefix}-{entity}'
             for entity in ('BODY', 'SYMP', 'INST', 'EXAM', 'CHEM',
                            'DISE', 'DRUG', 'SUPP', 'TREAT', 'TIME')
             for prefix in ('B', 'I')]
))

## **3. Build Custom Mamba Model for Text Classification**


In [13]:
# Mirrors Mamba's own config class: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/config_mamba.py
@dataclass
class MambaConfig:
    """Hyper-parameters of a Mamba language-model backbone.

    Every attribute is a real dataclass field (including ``tie_embeddings``),
    so all of them can be passed to the constructor and all of them appear in
    ``to_dict()`` / ``to_json_string()``.
    """

    d_model: int = 640  # hidden size (2560 in a larger configuration)
    d_intermediate: int = 0
    n_layer: int = 8  # number of layers (64 in a larger configuration)
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    # pad_vocab_size_multiple: int = 8
    pad_vocab_size_multiple: int = 16
    # Needs the annotation: without it this is only a class attribute, which
    # `asdict` skips and the generated __init__ rejects as a keyword argument.
    tie_embeddings: bool = True
    attn_layer_idx: list = field(default_factory=list)
    attn_cfg: dict = field(default_factory=dict)

    def to_json_string(self):
        """Serialize all config fields to a JSON string."""
        return json.dumps(asdict(self))

    def to_dict(self):
        """Return all config fields as a (deep-copied) plain dict."""
        return asdict(self)

In [10]:
# Classification head definition.
class MambaClassificationHead(nn.Module):
    """Linear classifier mapping hidden states to per-class logits.

    Args:
        d_model: size of the last dimension of the incoming hidden states.
        num_classes: number of output logits.
        **kwargs: forwarded unchanged to ``nn.Linear`` (e.g. ``bias``, ``device``, ``dtype``).
    """

    def __init__(self, d_model, num_classes, **kwargs):
        super().__init__()
        # A single linear layer projects d_model features onto num_classes logits.
        # The attribute name is part of the state-dict keys, so it is kept as is.
        self.classification_head = nn.Linear(in_features=d_model, out_features=num_classes, **kwargs)

    def forward(self, hidden_states):
        """Apply the linear layer to the last dimension of ``hidden_states``."""
        linear = self.classification_head
        return linear(hidden_states)

In [11]:
class MambaTextClassification(MambaLMHeadModel):
    """Sequence classifier built on a Mamba language-model backbone.

    Hidden states from ``self.backbone`` are mean-pooled over the sequence
    (ignoring padding when an ``attention_mask`` is given) and fed to a linear
    ``MambaClassificationHead`` with one output per entry of ``id2label``.
    The inherited language-model head is removed.
    """

    # Output containers, created once rather than on every forward call.
    _LogitsOutput = namedtuple("ClassificationOutput", ["logits"])
    _LossLogitsOutput = namedtuple("ClassificationOutput", ["loss", "logits"])

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(config, initializer_cfg, device, dtype)

        # Classifier with input size d_model and one logit per label in `id2label`.
        # device/dtype are forwarded so the head lives alongside the backbone.
        self.classification_head = MambaClassificationHead(
            d_model=config.d_model,
            num_classes=len(id2label),
            device=device,
            dtype=dtype,
        )

        # The model predicts classes, not tokens, so the LM head is unused.
        del self.lm_head

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Classify each sequence in the batch.

        Args:
            input_ids: token ids, passed straight to ``self.backbone``.
            attention_mask: optional mask with the same shape as ``input_ids``,
                assumed HF-style (1 = real token, 0 = padding). When given, padding
                positions are excluded from the mean pooling; when ``None`` every
                position is averaged.
            labels: optional class indices; when given, a cross-entropy loss is computed.

        Returns:
            A ``ClassificationOutput`` namedtuple holding ``logits``, preceded by
            ``loss`` when ``labels`` is given.
        """
        # Assumed shape (batch, seq_len, d_model); dim 1 is the pooled axis.
        hidden_states = self.backbone(input_ids)

        # Mean-pool token states into one representative ("[CLS]-like") vector per sequence.
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
            # clamp guards against division by zero for an all-padding row.
            token_counts = mask.sum(dim=1).clamp(min=1)
            pooled = (hidden_states * mask).sum(dim=1) / token_counts

        logits = self.classification_head(pooled)

        if labels is None:
            return self._LogitsOutput(logits=logits)

        loss = nn.CrossEntropyLoss()(logits, labels)
        return self._LossLogitsOutput(loss=loss, logits=logits)

    def predict(self, text, tokenizer, id2label=None):
        """Predict the class of a single ``text``.

        Returns the label name from ``id2label`` when a mapping is given,
        otherwise the integer class index.
        """
        # Build the batch on the same device as the model weights instead of
        # relying on a notebook-global `device`.
        model_device = next(self.parameters()).device
        input_ids = torch.tensor(tokenizer(text)['input_ids'], device=model_device)[None]
        with torch.no_grad():
            logits = self.forward(input_ids).logits[0]
        # torch.argmax works for every dtype (numpy has no bfloat16); the first
        # maximum wins on ties.
        label = int(torch.argmax(logits).item())

        if id2label is not None:
            return id2label[label]
        return label

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Build the classifier from a pretrained Mamba checkpoint.

        Backbone weights are loaded from the checkpoint. Parameters missing from it
        keep their fresh initialization and are printed.
        """
        import dataclasses

        # Load the configuration of the pretrained model. Keys MambaConfig does
        # not define are reported and skipped instead of raising TypeError.
        config_data = load_config_hf(pretrained_model_name)
        known_fields = {f.name for f in dataclasses.fields(MambaConfig)}
        ignored_keys = sorted(set(config_data) - known_fields)
        if ignored_keys:
            print("Ignoring unknown config keys:", ignored_keys)
        config = MambaConfig(**{k: v for k, v in config_data.items() if k in known_fields})

        # Initialize the model from the config on the requested device/dtype.
        model = cls(config, device=device, dtype=dtype, **kwargs)

        # strict=False: the checkpoint lacks the new classification head and
        # presumably still contains weights for the deleted LM head.
        model_state_dict = load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        model.load_state_dict(model_state_dict, strict=False)

        # Report parameters that were not in the checkpoint (freshly initialized).
        print("Newly initialized embedding:", set(model.state_dict().keys()) - set(model_state_dict.keys()))
        return model

In [None]:
# Load the pretrained Mamba weights into the classification model and move it to `device`.
# Keys missing from the checkpoint (expected: the classification head) keep their fresh init.
model = MambaTextClassification.from_pretrained("state-spaces/mamba-130m").to(device)

# Load the tokenizer used with Mamba, taken from the gpt-neox-20b model.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
# Alternative tokenizer:
# tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

# Pad with the end-of-sequence token (presumably this tokenizer has no dedicated pad token).
tokenizer.pad_token_id = tokenizer.eos_token_id

In [15]:
def count_parameters(model):
    """Return the total number of elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def format_parameters(num_params):
    """Render ``num_params`` with comma thousands separators, e.g. 1234567 -> '1,234,567'."""
    return f"{num_params:,}"

In [None]:
# Display the module tree of the classifier (backbone plus classification head).
model

In [None]:
# Report how many trainable parameters the classifier has, with thousands separators.
Mamba_params = count_parameters(model)
print("Mamba parameters: " + format_parameters(Mamba_params))