In [1]:
import numpy as np
import torch
import transformers
from datasets import load_dataset, load_metric
from evaluate import load
from tokenizers.processors import BertProcessing
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import MambaForCausalLM
from transformers import TrainingArguments, Trainer

In [2]:
# Smoke test: spin through an empty loop to confirm the tqdm progress bar renders.
for _ in tqdm(range(10_000_000)):
    pass

100%|██████████| 10000000/10000000 [00:00<00:00, 10470559.37it/s]


In [3]:
# Checkpoint identifier shared by the tokenizer here and the model loaded in a later cell.
model_id = "state-spaces/mamba-130m-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Set the EOS token explicitly and reuse it as the padding token.
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
# Disabled: left padding.
# tokenizer.padding_side = "left"

# Disabled experiment: adding CLS/SEP special tokens with a BERT-style post-processor.
# NOTE(review): the dict below registers "<CLS>"/"<SEP>" while BertProcessing is given
# "[SEP]"/"[CLS]" -- reconcile the spellings before re-enabling.
# special_tokens_dict = {"cls_token": "<CLS>", "sep_token": "<SEP>"}
# num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
# tokenizer.pad_token = tokenizer.eos_token
# tokenizer._tokenizer.post_processor = BertProcessing(
#       (str("[SEP]"), 50278), (str("[CLS]"), 50277)
# )

# Token budget; later cells pass it as `max_length` to `model.generate`.
max_length = 512
# Overlap between consecutive context windows when a long context is split.
# NOTE(review): not referenced by any cell visible in this notebook.
doc_stride = 128
# True when the tokenizer pads on the right.
# NOTE(review): not referenced by any cell visible in this notebook.
pad_on_right = tokenizer.padding_side == "right"

# Load the "squad" dataset and show its splits.
datasets = load_dataset("squad")
print(datasets)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})


In [4]:
def length_filter(example):
    """Return True when the full prompt plus first answer fits in ``max_length`` tokens.

    The text that is measured is the context, a blank line, ``Q: <question>``,
    ``A: <first answer>`` and the end-of-text marker, tokenized with the
    module-level ``tokenizer``.

    Args:
        example: One dataset row with ``context``, ``question`` and
            ``answers["text"]`` (only the first answer is used).

    Returns:
        bool: whether the token count is at most ``max_length``.
    """
    full_text = f"{example['context']}\n\nQ: {example['question']}\nA: {example['answers']['text'][0]}<|endoftext|>"
    total_length = len(tokenizer(full_text)["input_ids"])
    # Compare against the shared budget rather than a second hard-coded 512, so the
    # filter stays in step with the limit handed to `model.generate`.
    return total_length <= max_length

# Drop every example (in all splits) that exceeds the token budget.
datasets = datasets.filter(length_filter)
print(datasets)

Filter:   0%|          | 0/10570 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87428
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10515
    })
})


In [5]:
# Load the pretrained checkpoint and move it to the first CUDA device in one step.
model = MambaForCausalLM.from_pretrained(model_id).to("cuda:0")
# Output keys to ignore at inference time.
# NOTE(review): presumably consumed by `Trainer` during evaluation -- confirm; no
# Trainer is constructed in the cells visible here.
model.config.keys_to_ignore_at_inference = ["cache_params", "hidden_states"]
print(model)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


MambaForCausalLM(
  (backbone): MambaModel(
    (embeddings): Embedding(50280, 768)
    (layers): ModuleList(
      (0-23): 24 x MambaBlock(
        (norm): MambaRMSNorm()
        (mixer): MambaMixer(
          (conv1d): Conv1d(1536, 1536, kernel_size=(4,), stride=(1,), padding=(3,), groups=1536)
          (act): SiLU()
          (in_proj): Linear(in_features=768, out_features=3072, bias=False)
          (x_proj): Linear(in_features=1536, out_features=80, bias=False)
          (dt_proj): Linear(in_features=48, out_features=1536, bias=True)
          (out_proj): Linear(in_features=1536, out_features=768, bias=False)
        )
      )
    )
    (norm_f): MambaRMSNorm()
  )
  (lm_head): Linear(in_features=768, out_features=50280, bias=False)
)


In [6]:
def prepare_validation_features(example):
    """Generate an answer for one example and store it under ``example["guess"]``.

    The prompt is the context, a blank line, ``Q: <question>`` and ``A:``. The
    module-level ``model`` continues the prompt, and the guess is the first
    paragraph of that continuation with surrounding whitespace removed.

    Args:
        example: One dataset row with ``context`` and ``question``. The row is
            updated in place: ``question`` loses leading whitespace and
            ``guess`` is added.

    Returns:
        The same ``example`` mapping, including the new ``guess`` field.
    """
    # Leading whitespace in the question carries no information and only spends tokens.
    example["question"] = example["question"].lstrip()

    text = f"{example['context']}\n\nQ: {example['question']}\nA:"
    # Put the prompt on whatever device the model lives on instead of assuming CUDA.
    input_ids = torch.LongTensor([tokenizer.encode(text)]).to(model.device)

    out = model.generate(input_ids, max_length=max_length, eos_token_id=tokenizer.eos_token_id)

    # Decode only the tokens produced after the prompt. Removing the prompt with
    # `str.replace` on the decoded text breaks whenever decode(encode(text)) does not
    # reproduce `text` exactly, and would also delete any later repeat of the prompt.
    continuation = out[0, input_ids.shape[1]:]
    cleaned = tokenizer.decode(continuation, skip_special_tokens=True)

    # The answer is taken to end at the first blank line of the continuation.
    guess = cleaned.split("\n\n")[0].strip()
    example["guess"] = guess

    return example

# Run generation over the validation split one example at a time.
tokenized_validsets = datasets["validation"].map(prepare_validation_features, batched=False)
print(datasets["validation"])
print(tokenized_validsets)

Map:   0%|          | 0/10515 [00:00<?, ? examples/s]

KeyboardInterrupt: 

In [None]:
# Spot-check one validation example through the same helper used for the full run,
# rather than repeating its prompt-building and decoding steps inline.
sample = prepare_validation_features(datasets["validation"][3])
print(sample["question"])
print(sample["guess"])


Which NFL team won Super Bowl 50?
The Denver Broncos won Super Bowl 50.


In [None]:
# Score the generated guesses with the "squad" metric.
# `load` comes from the `evaluate` package, imported in the notebook's import cell.
squad_metric = load("squad")

# One prediction and one reference per processed validation row, keyed by example id.
predictions = [
    {"id": row["id"], "prediction_text": row["guess"]}
    for row in tokenized_validsets
]
references = [
    {"id": row["id"], "answers": row["answers"]}
    for row in tokenized_validsets
]

results = squad_metric.compute(predictions=predictions, references=references)
print(results)

{'exact_match': 10.0, 'f1': 33.04924242424242}


In [None]:
# Show one processed validation row (original fields plus the generated guess).
print(text := tokenized_validsets[3])

{'id': '56be4db0acb8001400a502ef', 'title': 'Super_Bowl_50', 'context': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.', 'question': 'Which NFL team won Super Bowl 50?', 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 1