In [1]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, MambaForCausalLM
from datasets import load_dataset
from evaluate import load

In [2]:
torch.cuda.set_device(2)

In [3]:
torch.cuda.current_device()

2

In [4]:
dataset = load_dataset("Trelis/tiny-shakespeare")

In [5]:
dataset.shape

{'train': (472, 1), 'test': (49, 1)}

In [6]:
dataset["train"]

Dataset({
    features: ['Text'],
    num_rows: 472
})

In [7]:
dataset["test"]["Text"][0]

"TRANIO:\nIs this your speeding? nay, then, good night our part!\n\nPETRUCHIO:\nBe patient, gentlemen; I choose her for myself:\nIf she and I be pleased, what's that to you?\n'Tis bargain'd 'twixt us twain, being alone,\nThat she shall still be curst in company.\nI tell you, 'tis incredible to believe\nHow much she loves me: O, the kindest Kate!\nShe hung about my neck; and kiss on kiss\nShe vied so fast, protesting oath on oath,\nThat in a twink she won me to her love.\nO, you are novices! 'tis a world to see,\nHow tame, when men and women are alone,\nA meacock wretch can make the curstest shrew.\nGive me thy hand, Kate: I will unto Venice,\nTo buy apparel 'gainst the wedding-day.\nProvide the feast, father, and bid the guests;\nI will be sure my Katharina shall be fine.\n\nBAPTISTA:\nI know not what to say: but give me your hands;\nGod send you joy, Petruchio! 'tis a match.\n\nGREMIO:\nAmen, say we: we will be witnesses.\n\nPETRUCHIO:\nFather, and wife, and gentlemen, adieu;\nI will 

In [8]:
len(dataset["test"]["Text"][0])

2859

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

In [10]:
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf").to(device)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


In [11]:
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
encodings = tokenizer(dataset["test"]["Text"], return_tensors= "pt", padding=True)

In [13]:
model.config.get_config_dict("state-spaces/mamba-370m-hf")



({'architectures': ['MambaForCausalLM'],
  'bos_token_id': 0,
  'conv_kernel': 4,
  'd_inner': 160,
  'd_model': 1024,
  'dt_rank': 'auto',
  'eos_token_id': 0,
  'expand': 2,
  'fused_add_norm': True,
  'hidden_act': 'silu',
  'hidden_size': 1024,
  'initializer_range': 0.1,
  'intermediate_size': 2048,
  'layer_norm_epsilon': 1e-05,
  'model_type': 'mamba',
  'n_layer': 48,
  'num_hidden_layers': 48,
  'pad_token_id': 0,
  'pad_vocab_size_multiple': 8,
  'rescale_prenorm_residual': False,
  'residual_in_fp32': True,
  'rms_norm': True,
  'ssm_cfg': {},
  'state_size': 16,
  'time_step_floor': 0.0001,
  'time_step_init_scheme': 'random',
  'time_step_max': 0.1,
  'time_step_min': 0.001,
  'time_step_rank': 64,
  'time_step_scale': 1.0,
  'torch_dtype': 'float32',
  'transformers_version': '4.39.0.dev0',
  'use_bias': False,
  'use_cache': True,
  'use_conv_bias': True,
  'vocab_size': 50280,
  '_commit_hash': 'b519127f5bfaaa1c27dd938dad051ec360972b23'},
 {})

In [14]:
perplexity = load("perplexity", module_type="metric")

In [15]:
with torch.no_grad():
    input_ids = encodings.input_ids.to(device)
    outputs = model(input_ids)

In [None]:
tokenizer.decode(outputs.logits)

In [15]:
with torch.no_grad():
    for i in range((len(dataset["test"]["Text"]) // 5) + 1):
        batch = dataset["test"]["Text"][i*5:(1+i)*5]
        encodings = tokenizer(batch, return_tensors= "pt", padding=True)
        input_ids = encodings.input_ids.to(device)
        outputs = model(input_ids)
        print(tokenizer.decode(outputs.logits))

MambaCausalLMOutput(loss=None, logits=tensor([[[  5.0041,  -5.3653,   4.8237,  ...,  -5.4267,  -5.4574,  -5.4511],
         [  2.5265,  -5.9306,   4.7795,  ...,  -5.8653,  -5.9365,  -6.1129],
         [  8.4780,  -0.7837,   8.8056,  ...,  -0.8313,  -0.8446,  -0.9851],
         ...,
         [ 23.1578,   7.7686,  24.0983,  ...,   7.6455,   7.6588,   7.3878],
         [ 23.1557,   7.7641,  24.0947,  ...,   7.6411,   7.6543,   7.3834],
         [ 23.1535,   7.7597,  24.0911,  ...,   7.6368,   7.6500,   7.3790]],

        [[ -1.5431, -16.3554,   3.8924,  ..., -16.5352, -16.3415, -16.4642],
         [-30.1016, -43.8163, -32.0511,  ..., -43.7683, -43.7739, -43.4251],
         [-29.4698, -42.7764, -24.9282,  ..., -42.6009, -42.3791, -42.2224],
         ...,
         [ 23.8008,   9.0419,  24.9545,  ...,   8.8931,   8.9215,   8.6503],
         [ 23.7399,   8.9403,  24.8939,  ...,   8.7936,   8.8207,   8.5497],
         [ 23.7064,   8.8789,  24.8581,  ...,   8.7337,   8.7599,   8.4893]],

      

In [19]:
outputs.logits

tensor([[[  5.0040,  -5.3652,   4.8236,  ...,  -5.4267,  -5.4573,  -5.4511],
         [  2.5265,  -5.9306,   4.7795,  ...,  -5.8653,  -5.9365,  -6.1129],
         [  8.4780,  -0.7837,   8.8055,  ...,  -0.8313,  -0.8446,  -0.9851],
         ...,
         [ 23.1579,   7.7685,  24.0984,  ...,   7.6454,   7.6587,   7.3877],
         [ 23.1557,   7.7640,  24.0947,  ...,   7.6410,   7.6542,   7.3833],
         [ 23.1536,   7.7597,  24.0911,  ...,   7.6367,   7.6499,   7.3789]],

        [[ -1.5431, -16.3554,   3.8924,  ..., -16.5353, -16.3415, -16.4642],
         [-30.1016, -43.8162, -32.0511,  ..., -43.7683, -43.7739, -43.4250],
         [-29.4698, -42.7763, -24.9283,  ..., -42.6009, -42.3791, -42.2224],
         ...,
         [ 23.8009,   9.0419,  24.9545,  ...,   8.8931,   8.9214,   8.6503],
         [ 23.7398,   8.9403,  24.8939,  ...,   8.7936,   8.8207,   8.5497],
         [ 23.7064,   8.8789,  24.8581,  ...,   8.7336,   8.7599,   8.4893]],

        [[-25.3489, -37.9731, -23.3293,  ...

In [None]:
results = perplexity.compute(predictions=predictions, model_id='gpt2', add_start_token=False)

In [None]:
results