# import

In [1]:
!export LC_ALL="en_US.UTF-8"
!export LD_LIBRARY_PATH="/usr/lib64-nvidia"
!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
!ldconfig /usr/lib64-nvidia

import os
import json
import time
import torch
import random
import warnings
import evaluate
import numpy as np
import pandas as pd
import torch.nn as nn
from torchcrf import CRF
import torch.optim as optim
from sklearn import metrics
from torch.utils import data
from datasets import load_dataset
from collections import namedtuple
from dataclasses import dataclass, field, asdict
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
from torch.utils.data import Dataset # from datasets import Dataset
from transformers import Trainer, BertConfig, AutoTokenizer, TrainingArguments, AdamW, get_linear_schedule_with_warmup


# Silence DeprecationWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=DeprecationWarning)
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # make only GPU index 0 visible to this process
# os.environ['CUDA_VISIBLE_DEVICES'] = '1' # labelled "cpu" by the author; NOTE(review): this selects GPU index 1, which presumably hides the only GPU on a single-GPU host - confirm



from IPython.display import clear_output
# Wipe whatever this cell has printed so far (setup/import output).
clear_output()

# Variables

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device  # echo the selected device as the cell output

In [3]:
# NOTE(review): this module-level batch_size is not passed to main() in the final cell,
# so main()'s own default batch size is what actually applies - confirm which is intended.
batch_size = 4
# Longest sentence (in tokens) kept by NerDataset; the 2 subtracted positions are the
# '[CLS]' and '[SEP]' markers NerDataset adds around every sentence.
MAX_LEN = 256 - 2
train_path = 'data/train.txt'  # NOTE(review): not referenced by any other visible cell
bert_model = 'bert-base-chinese'
# Tokenizer used by NerDataset.__getitem__ to convert tokens to vocabulary ids.
tokenizer = AutoTokenizer.from_pretrained(bert_model)

In [4]:
# Full tag inventory used by the model. The position in this tuple is the tag id,
# so '<PAD>' is id 0 (the value PadBatch uses to pad label sequences).
VOCAB = (
    '<PAD>', '[CLS]', '[SEP]', 'O',
    'B-BODY', 'I-BODY',
    'B-SYMP', 'I-SYMP',
    'B-INST', 'I-INST',
    'B-EXAM', 'I-EXAM',
    'B-CHEM', 'I-CHEM',
    'B-DISE', 'I-DISE',
    'B-DRUG', 'I-DRUG',
    'B-SUPP', 'I-SUPP',
    'B-TREAT', 'I-TREAT',
    'B-TIME', 'I-TIME',
)

# Forward (tag -> id) and reverse (id -> tag) lookup tables.
tag2idx = dict(zip(VOCAB, range(len(VOCAB))))
idx2tag = dict(enumerate(VOCAB))

In [5]:
# Tags listed in the final classification report: 'O' followed by a B-/I- pair
# for every entity type (the special '<PAD>', '[CLS]' and '[SEP]' tags are left out).
labels = ['O'] + [
    f'{prefix}-{entity}'
    for entity in ('BODY', 'SYMP', 'INST', 'EXAM', 'CHEM',
                   'DISE', 'DRUG', 'SUPP', 'TREAT', 'TIME')
    for prefix in ('B', 'I')
]

In [None]:
# Number of tags listed in the report: 21 ('O' plus a B-/I- pair for each of the 10 entity types).
len(labels)

# Dataset

In [7]:
class NerDataset(Dataset):
    """Character-level NER dataset read from a tab-separated file.

    Every non-blank line of the file is expected to hold ``token<TAB>tag``.
    Tokens are grouped into sentences; a '。' token closes the current sentence
    and is itself not stored. Each stored sentence is truncated to ``max_len``
    tokens and wrapped in '[CLS]' / '[SEP]' (the same markers are used in the
    tag sequence).

    Attributes:
        sents: list of token lists, one per sentence.
        tags_li: list of tag lists aligned with ``sents``.
    """

    def __init__(self, f_path, max_len=None):
        """Read ``f_path`` and build the sentence and tag lists.

        Args:
            f_path: path of the UTF-8 encoded ``token<TAB>tag`` file.
            max_len: maximum number of tokens kept per sentence, not counting
                '[CLS]' / '[SEP]'. Defaults to the module-level ``MAX_LEN``.

        Raises:
            IndexError: if a non-blank line contains no tab character.
        """
        if max_len is None:
            max_len = MAX_LEN
        self.sents = []
        self.tags_li = []

        with open(f_path, 'r', encoding='utf-8') as f:
            lines = [line.rstrip('\n') for line in f if len(line.strip()) != 0]

        word, tag = [], []
        for line in lines:
            fields = line.split('\t')
            char, t = fields[0], fields[1]
            if char != '。':
                word.append(char)
                tag.append(t)
            else:
                self._add_sentence(word, tag, max_len)
                word, tag = [], []

        # Tokens after the last '。' used to be dropped silently; keep them as a
        # final sentence so no annotated data is lost.
        if word:
            self._add_sentence(word, tag, max_len)

    def _add_sentence(self, word, tag, max_len):
        """Store one sentence, truncated to ``max_len`` and wrapped in [CLS]/[SEP]."""
        self.sents.append(['[CLS]'] + word[:max_len] + ['[SEP]'])
        self.tags_li.append(['[CLS]'] + tag[:max_len] + ['[SEP]'])

    def __getitem__(self, idx):
        """Return ``(token_ids, label_ids, seqlen)`` for sentence ``idx``.

        Token ids come from the module-level ``tokenizer``; label ids come from
        the module-level ``tag2idx``. ``seqlen`` is the number of labels.
        """
        words, tags = self.sents[idx], self.tags_li[idx]
        token_ids = tokenizer.convert_tokens_to_ids(words)
        label_ids = [tag2idx[tag] for tag in tags]
        seqlen = len(label_ids)
        return token_ids, label_ids, seqlen

    def __len__(self):
        """Number of stored sentences."""
        return len(self.sents)

def PadBatch(batch):
    """Collate ``(token_ids, label_ids, seqlen)`` samples into padded tensors.

    Every sequence is right-padded with 0 up to the largest ``seqlen`` in the
    batch.

    Returns:
        token_tensors: LongTensor of padded token ids.
        label_tensors: LongTensor of padded label ids.
        mask: bool tensor that is True wherever the token id is greater than 0.
    """
    longest = max(sample[2] for sample in batch)

    def _right_pad(seq):
        # A non-positive repeat count yields no padding at all.
        return seq + [0] * (longest - len(seq))

    token_tensors = torch.LongTensor([_right_pad(sample[0]) for sample in batch])
    label_tensors = torch.LongTensor([_right_pad(sample[1]) for sample in batch])
    mask = token_tensors.gt(0).to(torch.bool)
    return token_tensors, label_tensors, mask

# Model

In [8]:
# Local counterpart of Mamba's own config class:
# https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/config_mamba.py
@dataclass
class MambaConfig:
    """Hyper-parameters handed to the Mamba model constructor.

    All attributes are dataclass fields, so each one can be set through the
    constructor and appears in ``to_dict()`` / ``to_json_string()``.
    """
    d_model: int = 2560
    n_layer: int = 64
    vocab_size: int = 50277
    ssm_cfg: dict = field(default_factory=dict)
    rms_norm: bool = True
    residual_in_fp32: bool = True
    fused_add_norm: bool = True
    # The notebook also kept 8 as a commented-out alternative for this value.
    pad_vocab_size_multiple: int = 16
    # Must carry a type annotation: without one, @dataclass treats the name as a
    # plain class attribute, so it is missing from asdict() and cannot be set
    # through MambaConfig(**config_data).
    tie_embeddings: bool = True

    def to_json_string(self):
        """Return all fields serialized as a JSON string."""
        return json.dumps(asdict(self))

    def to_dict(self):
        """Return all fields as a plain dict."""
        return asdict(self)

In [9]:
# Head module that turns hidden states into class scores.
class MambaClassificationHead(nn.Module):
    """A single linear projection from ``d_model`` features to ``num_classes`` scores."""

    def __init__(self, d_model, num_classes, **kwargs):
        super().__init__()
        # Any extra keyword arguments are forwarded to nn.Linear unchanged.
        self.classification_head = nn.Linear(d_model, num_classes, **kwargs)

    def forward(self, hidden_states):
        """Apply the linear projection to ``hidden_states`` and return the scores."""
        projection = self.classification_head
        return projection(hidden_states)

In [10]:
class MambaTextClassification(MambaLMHeadModel):
    """Mamba backbone with a per-token linear head and a CRF layer for tagging.

    The language-model head of the parent class is removed; the backbone output
    is projected to ``len(tag2idx)`` scores per position, which serve as CRF
    emissions.
    """

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__(config, initializer_cfg, device, dtype)

        # Linear head and CRF both work over the full tag inventory (tag2idx),
        # which includes the '<PAD>', '[CLS]' and '[SEP]' tags.
        self.classification_head = MambaClassificationHead(d_model=config.d_model, num_classes=len(tag2idx))
        self.crf = CRF(len(tag2idx), batch_first=True)

        # The language-model head is not used by this model.
        del self.lm_head

    def forward(self, input_ids, tags=None, mask=None, is_test=False):
        """Compute the CRF training loss or decode the best tag sequences.

        Args:
            input_ids: batch of token ids fed to the backbone.
            tags: gold tag ids aligned with ``input_ids``; required unless
                ``is_test`` is True.
            mask: mask handed to the CRF to mark real (non-padding) positions;
                passed through as-is, including ``None``.
            is_test: when False, return the loss; when True, return decoded tags.

        Returns:
            The negated CRF log-likelihood (``reduction='mean'``) when training,
            otherwise the result of ``CRF.decode`` (one tag-id list per sequence).

        Raises:
            ValueError: if ``is_test`` is False and ``tags`` is None.
        """
        # Hidden states from the Mamba backbone.
        hidden_states = self.backbone(input_ids)

        # Per-position tag scores used as CRF emissions.
        logits_emissions = self.classification_head(hidden_states)

        if is_test:
            return self.crf.decode(logits_emissions, mask)

        if tags is None:
            raise ValueError("tags are required to compute the training loss")
        return -self.crf(logits_emissions, tags, mask, reduction='mean')

    def predict(self, text, tokenizer, id2label=None):
        """Tag a single text and return one prediction per token.

        The previous implementation called ``forward`` without ``tags``/``mask``
        and read a ``.logits`` attribute that ``forward`` never returns, so it
        always raised. It now decodes through the CRF like evaluation does.

        Args:
            text: input passed to ``tokenizer``.
            tokenizer: callable returning a mapping with an ``'input_ids'`` list.
            id2label: optional mapping from tag id to label.

        Returns:
            A list of tag ids, or a list of labels when ``id2label`` is given.
        """
        # Build the input on whatever device the model's parameters live on.
        model_device = next(self.parameters()).device
        input_ids = torch.tensor(tokenizer(text)['input_ids'], device=model_device)[None]
        # A single un-padded sequence: every position is real.
        mask = torch.ones_like(input_ids, dtype=torch.bool)
        with torch.no_grad():
            tag_ids = self.forward(input_ids, mask=mask, is_test=True)[0]

        if id2label is not None:
            return [id2label[tag_id] for tag_id in tag_ids]
        return tag_ids

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Build the model from a pretrained checkpoint name.

        Loads the config and state dict via ``load_config_hf`` /
        ``load_state_dict_hf``. Weights are loaded with ``strict=False``, so
        parameters absent from the checkpoint keep their fresh initialization;
        their names are printed.
        """
        # Load the configuration of the pretrained model.
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)

        # Instantiate the model from the config on the requested device/dtype.
        model = cls(config, device=device, dtype=dtype, **kwargs)

        # Load the pretrained weights; keys that do not match are skipped.
        model_state_dict = load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)
        model.load_state_dict(model_state_dict, strict=False)

        # Report the parameters that were not found in the checkpoint.
        print("Newly initialized embedding:", set(model.state_dict().keys()) - set(model_state_dict.keys()))
        return model

# Main: training, validation, and testing

In [11]:
def train(e, model, iterator, optimizer, scheduler, device):
    """Run one training epoch and print the mean loss and elapsed time.

    Args:
        e: epoch number, used only in the printed summary.
        model: called as ``model(x, y, z)`` and expected to return a scalar loss.
        iterator: yields ``(x, y, z)`` batches of tensors.
        optimizer: stepped once per batch.
        scheduler: stepped once per batch, right after the optimizer.
        device: device the model and every batch tensor are moved to.
    """
    started = time.time()

    model.train().to(device)
    running_loss = 0.0
    n_batches = 0
    for batch in iterator:
        n_batches += 1
        x, y, z = (tensor.to(device) for tensor in batch)

        loss = model(x, y, z)
        running_loss += loss.item()

        # Plain per-batch update: backward, optimizer, scheduler, then reset gradients.
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    epoch_time = time.time() - started

    print("Epoch: {}, Loss:{:.4f}, epoch_time:{:.2f} sec".format(e, running_loss / n_batches, epoch_time))

def validate(e, model, iterator, device):
    """Evaluate the model on a validation iterator.

    For every batch the model is called twice: once with ``is_test=True`` to get
    decoded tag sequences and once without it to get the loss. Accuracy is the
    percentage of masked positions whose decoded tag equals the gold tag.

    Args:
        e: epoch number, used only in the printed summary.
        model: model under evaluation (switched to eval mode).
        iterator: yields ``(x, y, z)`` batches; ``z`` is the mask.
        device: device every batch tensor is moved to.

    Returns:
        ``(model, mean_loss, accuracy_percent)``.
    """
    started = time.time()

    model.eval()
    gold_chunks, predicted = [], []
    running_loss = 0
    n_batches = 0
    with torch.no_grad():
        for batch in iterator:
            n_batches += 1
            x, y, z = (tensor.to(device) for tensor in batch)

            decoded = model(x, y, z, is_test=True)
            running_loss += model(x, y, z).item()

            # Flatten the decoded sequences into one running list of tag ids.
            for seq in decoded:
                predicted.extend(seq)
            # Keep only the gold tags at masked-in positions.
            gold_chunks.append(torch.masked_select(y, z == 1).cpu())

    Y = torch.cat(gold_chunks, dim=0).numpy()
    Y_hat = np.array(predicted)
    acc = (Y_hat == Y).mean() * 100

    epoch_time = time.time() - started
    mean_loss = running_loss / n_batches

    print("Epoch: {}, Val Loss:{:.4f}, Val Acc:{:.3f}, epoch_time:{:.2f} sec".format(e, mean_loss, acc, epoch_time))
    return model, mean_loss, acc

def test(model, iterator, device):
    """Decode every batch and return gold and predicted tags as strings.

    Args:
        model: called as ``model(x, y, z, is_test=True)`` to get decoded tag ids.
        iterator: yields ``(x, y, z)`` batches; ``z`` is the mask.
        device: device the inputs ``x`` and ``z`` are moved to (``y`` is left
            where it is and only used for gold-label extraction).

    Returns:
        ``(y_true, y_pred)``: flat lists of tag strings looked up in ``idx2tag``.
    """
    model.eval()
    gold_chunks, predicted_ids = [], []
    with torch.no_grad():
        for x, y, z in iterator:
            x = x.to(device)
            z = z.to(device)

            # Flatten the decoded sequences into one running list of tag ids.
            for seq in model(x, y, z, is_test=True):
                predicted_ids.extend(seq)

            # Select gold tags at masked-in positions; the mask is brought back
            # to the CPU because y was never moved off it.
            keep = (z == 1).cpu()
            gold_chunks.append(torch.masked_select(y, keep))

    gold_ids = torch.cat(gold_chunks, dim=0).numpy()
    y_true = [idx2tag[tag_id] for tag_id in gold_ids]
    y_pred = [idx2tag[tag_id] for tag_id in predicted_ids]

    return y_true, y_pred

def main(batch_size=64, lr=0.001, n_epochs=10, trainset="data/train.txt", validset="data/msra_eval.txt", testset="data/test.txt"):
    """Fine-tune the Mamba+CRF tagger, then report test-set metrics.

    Trains for ``n_epochs`` epochs on ``trainset``, validating on ``validset``
    after each one. The weights of the best epoch (lower validation loss AND
    higher validation accuracy than the best so far) are snapshotted and
    restored before scoring ``testset``. If no epoch qualifies, the model is
    tested with its current weights.

    Args:
        batch_size: batch size for all three data loaders.
        lr: AdamW learning rate.
        n_epochs: number of training epochs.
        trainset, validset, testset: paths handed to ``NerDataset``.
    """
    # Snapshot of the best weights. Keeping a reference to the model object is not
    # enough, because later epochs keep updating that same object in place.
    best_state = None
    _best_val_loss = 1e18
    _best_val_acc = 1e-18

    model = MambaTextClassification.from_pretrained("state-spaces/mamba-130m").to(device)

    print('Initial model Done.')
    train_dataset = NerDataset(trainset)
    eval_dataset = NerDataset(validset)
    test_dataset = NerDataset(testset)
    print('Load Data Done.')

    # Settings shared by all loaders; only the training loader shuffles.
    loader_kwargs = dict(batch_size=batch_size, num_workers=4, collate_fn=PadBatch)
    train_iter = data.DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
    eval_iter = data.DataLoader(dataset=eval_dataset, shuffle=False, **loader_kwargs)
    test_iter = data.DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. weight_decay=0.0
    # keeps the old optimizer's default (torch's own default is 0.01).
    optimizer = optim.AdamW(model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0)

    # Linear warmup over the first 10% of all optimizer steps. len(train_iter) is
    # the number of batches per epoch, counting a final partial batch.
    total_steps = len(train_iter) * n_epochs
    warm_up_ratio = 0.1
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = warm_up_ratio * total_steps, num_training_steps = total_steps)

    print('Start Train...,')
    for epoch in range(1, n_epochs+1):

        train(epoch, model, train_iter, optimizer, scheduler, device)
        _, loss, acc = validate(epoch, model, eval_iter, device)

        if loss < _best_val_loss and acc > _best_val_acc:
            # Copy the weights to the CPU so the snapshot survives later updates
            # and does not hold a second copy of the model in GPU memory.
            best_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
            _best_val_loss = loss
            _best_val_acc = acc

        print("=============================================")

    if best_state is not None:
        model.load_state_dict(best_state)

    y_test, y_pred = test(model, test_iter, device)
    print(metrics.classification_report(y_test, y_pred, labels=labels, digits=3))

In [None]:
# Train for 5 epochs; batch size, learning rate and data paths all use main()'s defaults.
main(n_epochs=5)