In [1]:
import pandas as pd
import numpy as np
import torch
from transformers import AutoTokenizer, MambaForCausalLM
from datasets import load_dataset
from evaluate import load

In [2]:
torch.cuda.set_device(2)

In [3]:
torch.cuda.current_device()

2

In [4]:
dataset = load_dataset("Trelis/tiny-shakespeare")

In [5]:
dataset.shape

{'train': (472, 1), 'test': (49, 1)}

In [6]:
dataset["train"]

Dataset({
    features: ['Text'],
    num_rows: 472
})

In [7]:
dataset["test"]["Text"][0]

"TRANIO:\nIs this your speeding? nay, then, good night our part!\n\nPETRUCHIO:\nBe patient, gentlemen; I choose her for myself:\nIf she and I be pleased, what's that to you?\n'Tis bargain'd 'twixt us twain, being alone,\nThat she shall still be curst in company.\nI tell you, 'tis incredible to believe\nHow much she loves me: O, the kindest Kate!\nShe hung about my neck; and kiss on kiss\nShe vied so fast, protesting oath on oath,\nThat in a twink she won me to her love.\nO, you are novices! 'tis a world to see,\nHow tame, when men and women are alone,\nA meacock wretch can make the curstest shrew.\nGive me thy hand, Kate: I will unto Venice,\nTo buy apparel 'gainst the wedding-day.\nProvide the feast, father, and bid the guests;\nI will be sure my Katharina shall be fine.\n\nBAPTISTA:\nI know not what to say: but give me your hands;\nGod send you joy, Petruchio! 'tis a match.\n\nGREMIO:\nAmen, say we: we will be witnesses.\n\nPETRUCHIO:\nFather, and wife, and gentlemen, adieu;\nI will 

In [8]:
len(dataset["test"]["Text"][0])

2859

In [9]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"

In [10]:
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-370m-hf").to(device)

The fast path is not available because on of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d


In [11]:
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
encodings = tokenizer(dataset["test"]["Text"], return_tensors= "pt", padding=True)

In [13]:
model.config.get_config_dict("state-spaces/mamba-370m-hf")



({'architectures': ['MambaForCausalLM'],
  'bos_token_id': 0,
  'conv_kernel': 4,
  'd_inner': 160,
  'd_model': 1024,
  'dt_rank': 'auto',
  'eos_token_id': 0,
  'expand': 2,
  'fused_add_norm': True,
  'hidden_act': 'silu',
  'hidden_size': 1024,
  'initializer_range': 0.1,
  'intermediate_size': 2048,
  'layer_norm_epsilon': 1e-05,
  'model_type': 'mamba',
  'n_layer': 48,
  'num_hidden_layers': 48,
  'pad_token_id': 0,
  'pad_vocab_size_multiple': 8,
  'rescale_prenorm_residual': False,
  'residual_in_fp32': True,
  'rms_norm': True,
  'ssm_cfg': {},
  'state_size': 16,
  'time_step_floor': 0.0001,
  'time_step_init_scheme': 'random',
  'time_step_max': 0.1,
  'time_step_min': 0.001,
  'time_step_rank': 64,
  'time_step_scale': 1.0,
  'torch_dtype': 'float32',
  'transformers_version': '4.39.0.dev0',
  'use_bias': False,
  'use_cache': True,
  'use_conv_bias': True,
  'vocab_size': 50280,
  '_commit_hash': 'b519127f5bfaaa1c27dd938dad051ec360972b23'},
 {})

In [14]:
perplexity = load("perplexity", module_type="metric")

In [40]:
BATCH_SIZE = 1

In [43]:
predictions = []
with torch.no_grad():
    for i in range((len(dataset["test"]["Text"]) // BATCH_SIZE) + 1):
        batch = dataset["test"]["Text"][i*BATCH_SIZE:(1+i)*BATCH_SIZE]
        encodings = tokenizer(batch, return_tensors= "pt", padding=True)
        input_ids = encodings.input_ids.to(device)
        outputs = model(input_ids)
        p = tokenizer.decode(outputs.logits.argmax(dim=-1)[0], skip_special_tokens=True)
        predictions.append(p)

IndexError: list index out of range

In [44]:
predictions[0:2]

["SF_\n\t there the first ticket\nah, but,\n sir. Lording\n\n_ERCHIO:\nI it, my, I'll to for my.\n\n she be I are not, we need the to you?\n\nTis but,,twixt us,ain, and both.\nAnd I shall be be minest, my.\n\n'll you, Itis a, me\nThat much I loves me, I, I love of heart!\n\n's upon my neck, she,'d my\nShe gaveouch with sweet with that,s oath,\nThat I the momentinklingle would me. her heart.\n\n, I are madices, Itis a wonder of me\n\nAnd much and how you are women are so!\nHow womanekock'soo, love a moststest love.\n\n me your hand, my, I'll not her,\nAnd see theearel forgainst the plague-day.\n\nidence me money, and, and the the maid\n\nI'll be ready to wifearineina will be there.\n\n[ETTISTA:\nI'll not what to say, I, me leave hand,\nI bless you both, andruchio!\ntis a joy\n\n\nPETTEO:\nI match! my you, and'll be merry.\n\nPETRUCHIO:\nI, I you, and friends, farewellieu.\nI'll not Venice, and is notace.\nI will be a, jewels, all clothes,\nAnd then the, Kate, and will, friends.'er.\n\nB

In [47]:
results = perplexity.compute(predictions=predictions, model_id='gpt2', add_start_token=False)

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

  0%|          | 0/4 [00:00<?, ?it/s]

In [48]:
results

{'perplexities': [242.85498046875,
  194.62364196777344,
  257.1144714355469,
  216.77117919921875,
  134.54393005371094,
  269.7972106933594,
  184.23556518554688,
  144.16860961914062,
  203.2635040283203,
  136.79092407226562,
  194.1909942626953,
  288.9359130859375,
  216.0003204345703,
  266.634765625,
  271.5791931152344,
  176.6090850830078,
  201.47955322265625,
  232.98898315429688,
  162.88418579101562,
  184.16661071777344,
  276.2787780761719,
  122.96223449707031,
  103.34008026123047,
  286.822509765625,
  197.147705078125,
  173.7272186279297,
  146.87362670898438,
  122.0363998413086,
  142.4132537841797,
  101.52983856201172,
  181.15313720703125,
  75.74742889404297,
  206.21128845214844,
  205.3009033203125,
  347.85528564453125,
  245.13546752929688,
  200.1460723876953,
  313.5729675292969,
  416.6464538574219,
  238.07554626464844,
  358.1771545410156,
  298.20269775390625,
  340.5093078613281,
  223.38792419433594,
  348.784423828125,
  93.56739044189453,
  89.2