# import

In [1]:
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

import os
import time
import random
import warnings
import evaluate
import numpy as np
import pandas as pd
from torchcrf import CRF
import torch.optim as optim
from sklearn import metrics
from torch.utils import data
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch.utils.data import Dataset # from datasets import Dataset
from transformers import Trainer, BertConfig, AutoTokenizer, TrainingArguments, AdamW, get_linear_schedule_with_warmup

from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum


# import urllib.request
# from tqdm import tqdm
# from zipfile import ZipFile
# from torch.nn import functional as F

# torch.autograd.set_detect_anomaly(True)

warnings.filterwarnings("ignore", category=DeprecationWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # cuda
# os.environ['CUDA_VISIBLE_DEVICES'] = '1' # cpu



from IPython.display import clear_output
clear_output()

# Var

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
device

In [3]:
# Configuration flags and hyperparameters
USE_MAMBA = True # USE_MAMBA = 1
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = False # DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0

In [4]:
batch_size = 4
MAX_LEN = 256 - 2
train_path = 'data/train.txt'
test_path = 'data/test.txt'
valid_path = 'data/msra_eval.txt'
bert_model = 'bert-base-chinese'
tokenizer = AutoTokenizer.from_pretrained(bert_model)
VOCAB_size = len(tokenizer)

# args
d_model = 128
state_size = 128  # Example state size
seq_len = MAX_LEN  # Example sequence length
last_batch_size = batch_size  # only for the very last batch of the dataset
current_batch_size = batch_size
different_batch_size = False
h_new = None
temp_buffer = None

In [5]:
VOCABofTag = ('<PAD>', '[CLS]', '[SEP]', 'O', 'B-BODY', 'I-BODY',
              'B-SYMP', 'I-SYMP', 'B-INST', 'I-INST', 'B-EXAM', 'I-EXAM',
              'B-CHEM', 'I-CHEM','B-DISE', 'I-DISE', 'B-DRUG', 'I-DRUG',
              'B-SUPP', 'I-SUPP', 'B-TREAT', 'I-TREAT', 'B-TIME', 'I-TIME')

tag2idx = {tag: idx for idx, tag in enumerate(VOCABofTag)}
idx2tag = {idx: tag for idx, tag in enumerate(VOCABofTag)}

In [6]:
labels = ['O','B-BODY','I-BODY', 'B-SYMP', 'I-SYMP', 'B-INST', 'I-INST', 'B-EXAM', 'I-EXAM','B-CHEM', 'I-CHEM',
          'B-DISE', 'I-DISE', 'B-DRUG', 'I-DRUG', 'B-SUPP', 'I-SUPP', 'B-TREAT', 'I-TREAT', 'B-TIME', 'I-TIME']

In [None]:
len(VOCABofTag)

# Dataset

In [8]:
class NerDataset(Dataset):
    ''' Generate our dataset '''

    def __init__(self, f_path):
        self.sents = []
        self.tags_li = []

        with open(f_path, 'r', encoding='utf-8') as f:
            lines = [line.split('\n')[0] for line in f.readlines() if len(line.strip())!=0]

        tags =  [line.split('\t')[1] for line in lines]
        words = [line.split('\t')[0] for line in lines]

        word, tag = [], []
        for char, t in zip(words, tags):
            if char != '。':
                word.append(char)
                tag.append(t)
            else:
                if len(word) > MAX_LEN:
                  self.sents.append(['[CLS]'] + word[:MAX_LEN] + ['[SEP]'])
                  self.tags_li.append(['[CLS]'] + tag[:MAX_LEN] + ['[SEP]'])
                else:
                  self.sents.append(['[CLS]'] + word + ['[SEP]'])
                  self.tags_li.append(['[CLS]'] + tag + ['[SEP]'])
                word, tag = [], []

    def __getitem__(self, idx):
        words, tags = self.sents[idx], self.tags_li[idx]
        token_ids = tokenizer.convert_tokens_to_ids(words)
        laebl_ids = [tag2idx[tag] for tag in tags]
        seqlen = len(laebl_ids)
        return token_ids, laebl_ids, seqlen

    def __len__(self):
        return len(self.sents)

def PadBatch(batch):#[Pad]
    maxlen = max([i[2] for i in batch])
    token_tensors = torch.LongTensor([i[0] + [0] * (maxlen - len(i[0])) for i in batch])
    label_tensors = torch.LongTensor([i[1] + [0] * (maxlen - len(i[1])) for i in batch])
    # mask = (token_tensors > 0)
    mask = (token_tensors > 0).to(torch.bool)
    return token_tensors, label_tensors, mask

# Model

In [None]:
"""
    Simple, minimal implementation of Mamba in one file of PyTorch.

    Suggest reading the following before/while reading the code:
        [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
            https://arxiv.org/abs/2312.00752
        [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
            https://srush.github.io/annotated-s4

    Glossary:
        b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
        l: sequence length                  (`L` in [1] Algorithm 2)
        d or d_model: hidden dim
        n or d_state: latent state dim      (`N` in [1] Algorithm 2)
        expand: expansion factor            (`E` in [1] Section 3.4)
        d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
        A, B, C, D: state space parameters  (See any state space representation formula)
                                            (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
        Δ or delta: input-dependent step size
        dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")
"""

In [10]:
# 使用dataclass裝飾器自動產生初始化方法和類別的字串表示方法
@dataclass
class ModelArgs:
    # @dataclass 會自動為這個類別產生初始化方法和代表類別的字串形式的方法
    d_model: int # 定義模型的隱藏層維度
    n_layer: int # 定義模型的層數
    vocab_size: int # 定義詞彙表的大小
    d_state: int = 16 # 定義狀態空間的維度，預設為16
    expand: int = 2 # 定義擴展因子，預設為2
    dt_rank: Union[int, str] = 'auto' # 定義輸入依賴步長Δ的秩，'auto'表示自動設定
    d_conv: int = 4 # 定義卷積核的維度，預設為4
    pad_vocab_size_multiple: int = 8 # 定義詞彙表大小的最小公倍數，預設為8
    conv_bias: bool = True # 定義卷積層是否使用偏壓項
    bias: bool = False # 定義其他層（如線性層）是否使用偏移項

    def __post_init__(self):
    # 在__init__後自動被調用，用於執行初始化之後的額外設定或驗證
    # 計算內部維度，即擴展後的維度
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':# 如果dt_rank未指定，則自動計算設置
        # 根據隱藏層維度自動計算Δ的秩
            self.dt_rank = math.ceil(self.d_model / 16)
        # 確保vocab_size是pad_vocab_size_multiple的倍數
        # 如果不是，請調整為最近的倍數
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple - self.vocab_size % self.pad_vocab_size_multiple)

In [11]:
class MambaBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1]."""
        super().__init__()
        # 儲存模型參數
        self.args = args
        # 輸入線性變換層
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # 創建了一個所謂的“深度卷積”，其中每個輸入通道被單獨卷積到每個輸出通道。
        # 這意味著每個輸出通道的結果是透過僅與一個輸入通道卷積而得到的。
        self.conv1d = nn.Conv1d(
        in_channels=args.d_inner,
        out_channels=args.d_inner,
        bias=args.conv_bias,
        kernel_size=args.d_conv,
        groups=args.d_inner,
        padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific Δ, B, C
        # 將輸入x對應到狀態空間模型的參數Δ、B和C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)

        # dt_proj projects Δ from dt_rank to d_in
        # 將Δ從args.dt_rank維度映射到args.d_inner維度
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # 建立一個重複的序列，用於初始化狀態空間模型的矩陣A
        # n->dxn
        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        # 將矩陣A的對數值當作可訓練參數來儲存
        self.A_log = nn.Parameter(torch.log(A))
        # 初始化矩陣D為全1的可訓練參數
        self.D = nn.Parameter(torch.ones(args.d_inner))
        # 輸出線性變換層
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """
            MambaBlock的前向傳播函數，與Mamba論文圖3 Section 3.4相同.

            Args:
            x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)

            Returns:
            output: shape (b, l, d)

            Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        # 取得輸入x的維度
        # batchsize,seq_len,dim
        (b, l, d) = x.shape # 取得輸入x的維度
        # 應用輸入線性變換
        x_and_res = self.in_proj(x) # shape (b, l, 2 * d_in)
        # 將變換後的輸出分為兩部分x和res。
        # 所得的x分為兩個部分，一部分x繼續用於後續變換，產生所需的參數，res用於殘差部分
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)
        # 調整x的形狀
        x = rearrange(x, 'b l d_in -> b d_in l')
        # 施加深度卷積，然後截取前l個輸出
        x = self.conv1d(x)[:, :, :l]
        # 再次調整x的形狀
        x = rearrange(x, 'b d_in l -> b l d_in')
        # 應用SiLU激活函數
        x = F.silu(x)
        # 運行狀態空間模型
        y = self.ssm(x)
        # 將res的SiLU活化結果與y相乘
        y = y * F.silu(res)
        # 應用輸出線性變換
        output = self.out_proj(y)
        # 回傳輸出結果
        return output
    
    def ssm(self, x):
        """
            運行狀態空間模型，參考Mamba論文 Section 3.2與註釋[2]:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

            Args:
            x: shape (b, l, d_in) (See Glossary at top for definitions of b, l, d_in, n...)

            Returns:
            output: shape (b, l, d_in)

            Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311
        """
        # 取得A_log的維度
        # A在初始化時候經過如下賦值：
        # A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        # self.A_log = nn.Parameter(torch.log(A))
        # （args.d_inner, args.d_state）
        (d_in, n) = self.A_log.shape # 取得A_log的維度

        # 計算 ∆ A B C D, 這些屬於狀態空間參數.
        # A, D 是 與輸入無關的 (見Mamba論文Section 3.5.2 "Interpretation of A" for why A isn't selective)
        # ∆, B, C 與輸入有關(這是與線性是不變模型S4最大的不同,
        # 也是為什麼Mamba被稱為 “選擇性” 狀態空間的原因)

        # 計算矩陣A
        A = -torch.exp(self.A_log.float()) # shape (d_in, n)
        # 取D的值
        D = self.D.float()

        # 應用x的投影變換
        # ( b,l,d_in) -> (b, l, dt_rank + 2*n)
        x_dbl = self.x_proj(x) # (b, l, dt_rank + 2*n)

        # 分割delta, B, C
        # delta: (b, l, dt_rank). B, C: (b, l, n)
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)
        # 應用dt_proj併計算delta
        delta = F.softplus(self.dt_proj(delta)) # (b, l, d_in)
        # 應用選擇性掃描演算法
        y = self.selective_scan(x, delta, A, B, C, D)
        return y


    def selective_scan(self, u, delta, A, B, C, D):
        """
            執行選擇性掃描演算法，參考Mamba論文[1] Section 2和註釋[2]. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

            經典的離散狀態空間公式:
            x(t + 1) = Ax(t) + Bu(t)
            y(t) = Cx(t) + Du(t)
            除了B和C (以及step size delta用於離散化) 與輸入x(t)相關.

            參數:
            u: shape (b, l, d_in)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

            過程概述：

            Returns:
            output: shape (b, l, d_in)

            Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: I refactored some parts out of `selective_scan_ref` out, so the functionality doesn't match exactly.
        """
        # 取得輸入u的維度
        (b, l, d_in) = u.shape
        # 取得矩陣A的列數
        n = A.shape[1] # A: shape (d_in, n)

        # 離散化連續參數(A, B)
        # - A 使用 zero-order hold (ZOH) 離散化 (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is 使用一個簡化的Euler discretization而不是ZOH.根據作者的討論:
        # "A is the more important term and the performance doesn't change much with the simplification on B"

        # 計算離散化的A
        # 將delta和A進行點乘，將A沿著delta的最後一個維度進行廣播，然後執行逐元素乘法
        # A:(d_in, n),delta:(b, l, d_in)
        # A廣播拓展->(b,l,d_in, n)，deltaA對應原論文中的A_bar
        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        # delta、B和u,這個計算和原始論文不同
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).
        # 執行選擇性掃描,初始化狀態x為零
        x = torch.zeros((b, d_in, n), device=deltaA.device)
        # 初始化輸出列表ys
        ys = []
        for i in range(l):
            # 更新狀態x
            # deltaA:((b,l,d_in, n)
            # deltaB_u:( b,l,d_in,n)
            # x:
            x = deltaA[:, i] * x + deltaB_u[:, i]
            # 計算輸出y
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            # 將輸出y加入到列表ys中
            ys.append(y)
        # 將清單ys堆疊成張量y
        y = torch.stack(ys, dim=1) # shape (b, l, d_in)
        # 將輸入u乘以D並加到輸出y上
        y = y + u * D

        return y

In [12]:
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        """Simple block wrapping Mamba block with normalization and residual connection."""
        super().__init__()
        # 保存傳入的ModelArgs對象，包含模型的配置參數
        self.args = args
        # 建立一個MambaBlock，它是這個殘差區塊的核心元件
        self.mixer = MambaBlock(args)
        # 建立一個RMSNorm歸一化模組，用於歸一化操作
        self.norm = RMSNorm(args.d_model)


    def forward(self, x):
        """
            Args:
            x: shape (b, l, d) (See Glossary at top for definitions of b, l, d_in, n...)
            x (Tensor): 輸入張量，形狀為(batch_size, sequence_length, hidden_​​size)
            Returns:
            output: shape (b, l, d)
            輸出張量，形狀與輸入相同
            Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains residual blocks that look like
            [加 -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            where the first Add is a no-op. This is purely for performance reasons as this
            allows them to fuse the Add->Norm.

            We instead implement our blocks as the more familiar, simpler, and numerically equivalent
            [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ....
        """
        # 應用歸一化和MambaBlock，然後與輸入x進行殘差連接
        output = self.mixer(self.norm(x)) + x

        return output

In [13]:
class RMSNorm(nn.Module):
    """
        初始化RMSNorm模組，該模組實現了基於均方根的歸一化操作。

        參數:
        d_model (int): 模型的特徵維度。
        eps (float, 可選): 為了避免除以零，加到分母中的一個小的常數。
    """
    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps# 保存輸入的eps值，用於數值穩定性。
        # 建立一個可訓練的權重參數，初始值為全1，維度與輸入特徵維度d_model相同。
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        """
            計算輸入x的均方根值，用於後續的歸一化操作。
            x.pow(2) 計算x中每個元素的平方。
            mean(-1, keepdim=True) 將x的最後一個維度（特徵維度）平方和求平均，保持維度以便進行廣播操作。
            torch.rsqrt 對求得的平均值取倒數和平方根，得到每個特徵的均方根值的逆。
            + self.eps 加入一個小的常數eps以保持數值穩定性，防止除以零的情況發生。
            x * ... * self.weight 將輸入x與計算得到的歸一化因子和可訓練的權重相乘，得到最終的歸一化輸出。
        """
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

In [14]:
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs, num_class):
        """Full Mamba model."""
        super().__init__()
        # 儲存傳入的ModelArgs對象，包含模型的配置參數
        self.args = args
        # 建立一個嵌入層，將詞彙表中的單字轉換為對應的向量表示
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        # 建立一個包含多個殘差塊的模組列表，殘差塊的數量等於模型層數
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        # 建立一個RMSNorm模組，用於歸一化操作
        self.norm_f = RMSNorm(args.d_model)
        # 建立一個線性層，用於最終的輸出，將隱藏層的輸出對應回詞彙表的大小
        self.lm_head = nn.Linear(args.d_model, num_class, bias=False)
        # 將線性層的輸出權重與嵌入層的權重綁定，這是權重共享的一種形式，有助於減少參數數量並可能提高模型的泛化能力
        self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights. # See "Weight Tying" paper
    

    def forward(self, input_ids):
        """
            Args:
            input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)

            Returns:
            logits: shape (b, l, vocab_size)

            Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        # 將輸入ID轉換為向量表示
        x = self.embedding(input_ids)
        # 遍歷所有的殘差塊，並應用它們
        for layer in self.layers:
            x = layer(x)
        # 應用歸一化操作
        x = self.norm_f(x)
        # 透過線性層得到最終的logits輸出
        logits = self.lm_head(x)
        # 傳回模型的輸出
        return logits


    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """
            Load pretrained weights from HuggingFace into model.

            Args:
            pretrained_model_name: One of
            * 'state-spaces/mamba-2.8b-slimpj'
            * 'state-spaces/mamba-2.8b'
            * 'state-spaces/mamba-1.4b'
            * 'state-spaces/mamba-790m'
            * 'state-spaces/mamba-370m'
            * 'state-spaces/mamba-130m'

            Returns:
            model: Mamba model with weights loaded
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))


        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
        d_model=config_data['d_model'],
        n_layer=config_data['n_layer'],
        vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model

In [15]:
class MambaCRF(nn.Module):
    def __init__(self, args: ModelArgs, num_class):
        """Full Mamba + CRF model."""
        super().__init__()
        # 儲存傳入的ModelArgs對象，包含模型的配置參數
        self.args = args
        # 建立一個嵌入層，將詞彙表中的單字轉換為對應的向量表示
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        # 建立一個包含多個殘差塊的模組列表，殘差塊的數量等於模型層數
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        # 建立一個RMSNorm模組，用於歸一化操作
        self.norm_f = RMSNorm(args.d_model)
        # 建立一個線性層，用於最終的輸出，將隱藏層的輸出對應回詞彙表的大小
        self.lm_head = nn.Linear(args.d_model, num_class)
        # self.lm_head = nn.Linear(args.d_model, num_class, bias=False)
        # 將線性層的輸出權重與嵌入層的權重綁定，這是權重共享的一種形式，有助於減少參數數量並可能提高模型的泛化能力
        # self.lm_head.weight = self.embedding.weight # Tie output projection to embedding weights. # See "Weight Tying" paper
        # 將CRF加入
        self.crf = CRF(num_class, batch_first=True)  
        # self.crf = CRF(len(tag2idx), batch_first=True) 
    

    def forward(self, input_ids, tags, mask, is_test=False):
        """
            Args:
            input_ids (long tensor): shape (b, l) (See Glossary at top for definitions of b, l, d_in, n...)

            Returns:
            logits: shape (b, l, vocab_size)

            Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173
        """
        # 將輸入ID轉換為向量表示
        x = self.embedding(input_ids)
        # 遍歷所有的殘差塊，並應用它們
        for layer in self.layers:
            x = layer(x)
        # 應用歸一化操作
        x = self.norm_f(x)
        # 透過線性層得到最終的logits輸出
        logits = self.lm_head(x)
        # 傳回模型的輸出

        if not is_test: # Training，return loss
            loss=-self.crf.forward(logits, tags, mask, reduction='mean')
            return loss
        else: # Testing，return decoding
            decode=self.crf.decode(logits, mask)
            return decode


    @staticmethod
    def from_pretrained(pretrained_model_name: str):
        """
            Load pretrained weights from HuggingFace into model.

            Args:
            pretrained_model_name: One of
            * 'state-spaces/mamba-2.8b-slimpj'
            * 'state-spaces/mamba-2.8b'
            * 'state-spaces/mamba-1.4b'
            * 'state-spaces/mamba-790m'
            * 'state-spaces/mamba-370m'
            * 'state-spaces/mamba-130m'

            Returns:
            model: Mamba model with weights loaded
        """
        from transformers.utils import WEIGHTS_NAME, CONFIG_NAME
        from transformers.utils.hub import cached_file

        def load_config_hf(model_name):
            resolved_archive_file = cached_file(model_name, CONFIG_NAME, _raise_exceptions_for_missing_entries=False)
            return json.load(open(resolved_archive_file))


        def load_state_dict_hf(model_name, device=None, dtype=None):
            resolved_archive_file = cached_file(model_name, WEIGHTS_NAME, _raise_exceptions_for_missing_entries=False)
            return torch.load(resolved_archive_file, weights_only=True, map_location='cpu', mmap=True)

        config_data = load_config_hf(pretrained_model_name)
        args = ModelArgs(
        d_model=config_data['d_model'],
        n_layer=config_data['n_layer'],
        vocab_size=config_data['vocab_size']
        )
        model = Mamba(args)

        state_dict = load_state_dict_hf(pretrained_model_name)
        new_state_dict = {}
        for key in state_dict:
            new_key = key.replace('backbone.', '')
            new_state_dict[new_key] = state_dict[key]
        model.load_state_dict(new_state_dict)

        return model

In [None]:
# args = ModelArgs(d_model=320, n_layer=3, vocab_size=30522)
args = ModelArgs(d_model=320, n_layer=3, vocab_size=10000)

# model = Mamba(args, len(labels))
model = MambaCRF(args, len(tag2idx))
model

In [19]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def format_parameters(num_params):
    return "{:,}".format(num_params)

model_params = count_parameters(model)
print(f"mamba_model parameters: {format_parameters(model_params)}")

mamba_model parameters: 5,235,208


# Main.ipynb

In [17]:
def train(e, model, iterator, optimizer, scheduler, device):
    start_time = time.time()  # Record the start time

    model.train().to(device)
    losses = 0.0
    step = 0
    for i, batch in enumerate(iterator):
        step += 1
        x, y, z = batch
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)

        loss = model(x, y, z)
        losses += loss.item()
        """ Gradient Accumulation """
        '''
          full_loss = loss / 2                            # normalize loss
          full_loss.backward()                            # backward and accumulate gradient
          if step % 2 == 0:
              optimizer.step()                            # update optimizer
              scheduler.step()                            # update scheduler
              optimizer.zero_grad()                       # clear gradient
        '''
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    end_time = time.time()  # Record the end time
    epoch_time = end_time - start_time

    print("Epoch: {}, Loss:{:.4f}, epoch_time:{:.2f} sec".format(e, losses/step, epoch_time))

def validate(e, model, iterator, device):
    start_time = time.time()  # Record the start time

    model.eval()
    Y, Y_hat = [], []
    losses = 0
    step = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            step += 1

            x, y, z = batch
            x = x.to(device)
            y = y.to(device)
            z = z.to(device)

            y_hat = model(x, y, z, is_test=True)

            loss = model(x, y, z)
            losses += loss.item()
            # Save prediction
            for j in y_hat:
              Y_hat.extend(j)
            # Save labels
            mask = (z==1)
            y_orig = torch.masked_select(y, mask)
            Y.append(y_orig.cpu())

    Y = torch.cat(Y, dim=0).numpy()
    Y_hat = np.array(Y_hat)
    acc = (Y_hat == Y).mean()*100

    end_time = time.time()  # Record the end time
    epoch_time = end_time - start_time

    print("Epoch: {}, Val Loss:{:.4f}, Val Acc:{:.3f}, epoch_time:{:.2f} sec".format(e, losses/step, acc, epoch_time))
    return model, losses/step, acc

def test(model, iterator, device):
    model.eval()
    Y, Y_hat = [], []
    with torch.no_grad():
        for i, batch in enumerate(iterator):
            x, y, z = batch
            x = x.to(device)
            z = z.to(device)
            y_hat = model(x, y, z, is_test=True)
            # Save prediction
            for j in y_hat:
              Y_hat.extend(j)
            # Save labels
            mask = (z==1).cpu()
            y_orig = torch.masked_select(y, mask)
            Y.append(y_orig)

    Y = torch.cat(Y, dim=0).numpy()
    y_true = [idx2tag[i] for i in Y]
    y_pred = [idx2tag[i] for i in Y_hat]

    return y_true, y_pred

# def main(batch_size=64, lr=0.001, n_epochs=10, trainset="data/train.txt", validset="data/msra_eval.txt", testset="data/test.txt"):
def main(batch_size=batch_size, lr=0.001, n_epochs=10, trainset=train_path, validset=valid_path, testset=test_path):

    best_model = None
    _best_val_loss = 1e18
    _best_val_acc = 1e-18

    args = ModelArgs(d_model=768, n_layer=12, vocab_size=10000)
    model = MambaCRF(args, len(tag2idx))
    
    print('Initial model Done.')
    train_dataset = NerDataset(trainset)
    eval_dataset = NerDataset(validset)
    test_dataset = NerDataset(testset)
    print('Load Data Done.')

    train_iter = data.DataLoader(dataset=train_dataset,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 num_workers=4,
                                 collate_fn=PadBatch)

    eval_iter = data.DataLoader(dataset=eval_dataset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=4,
                                 collate_fn=PadBatch)

    test_iter = data.DataLoader(dataset=test_dataset,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=4,
                                collate_fn=PadBatch)

    #optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=0.01)
    optimizer = AdamW(model.parameters(), lr=lr, eps=1e-6)

    # Warmup
    len_dataset = len(train_dataset)
    epoch = n_epochs
    batch_size = batch_size
    total_steps = (len_dataset // batch_size) * epoch if len_dataset % batch_size == 0 else (len_dataset // batch_size + 1) * epoch

    warm_up_ratio = 0.1 # Define 10% steps
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warm_up_ratio * total_steps, num_training_steps = total_steps)

    print('Start Train...,')
    for epoch in range(1, n_epochs+1):

        train(epoch, model, train_iter, optimizer, scheduler, device)
        candidate_model, loss, acc = validate(epoch, model, eval_iter, device)

        save_dir = '/home/yenling/Code In Lunix/mamba/Save Model'
        os.makedirs(save_dir, exist_ok=True)

        if loss < _best_val_loss and acc > _best_val_acc:
          best_model = candidate_model
          _best_val_loss = loss
          _best_val_acc = acc

          save_path = os.path.join(save_dir, f'Mamba + CRF with FinetuneBlock best_model_epoch_{epoch}_loss_{loss:.4f}_acc_{acc:.4f}.pt')
          torch.save(best_model.state_dict(), save_path)
          print(f"Best model saved at epoch {epoch} with val_loss: {loss:.4f} and val_acc: {acc:.4f} to {save_path}")

        print("=============================================")

    y_test, y_pred = test(best_model, test_iter, device)
    print(metrics.classification_report(y_test, y_pred, labels=labels, digits=3))
    print(metrics.confusion_matrix(y_test, y_pred, labels=labels))

In [None]:
main(n_epochs=5)