In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# locally installed code (e.g. the `based` package from Step 1) are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Step 1 (done outside this notebook): install the `based` package by running the setup.py in the based/ directory (e.g. `pip install -e .` from inside based/).

In [3]:
# Step 2: Point the Hugging Face model cache at a path with sufficient space.
# NOTE: `! export VAR=...` only sets VAR inside a throwaway subshell, so it never
# reaches this kernel (or transformers). Set it on os.environ instead, and run
# this cell BEFORE Step 3 imports transformers: the cache location is resolved
# from the environment when the library is imported.
# NOTE(review): newer transformers releases prefer HF_HOME over TRANSFORMERS_CACHE.
import os

os.environ["TRANSFORMERS_CACHE"] = "/var/cr05_data/sim_data"

### Loading models

In [4]:
# Step 3: Fetch the pretrained Based checkpoint and place it on the GPU.
# Later cells rely on the `tokenizer` and `model` names defined here.
import torch
from transformers import AutoTokenizer
from based.models.gpt import GPTLMHeadModel

TOKENIZER_NAME = "gpt2"
BASED_CHECKPOINT = "hazyresearch/based-360m"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

# nn.Module.to returns the module itself, so rebinding keeps `model` on the GPU.
model = GPTLMHeadModel.from_pretrained_hf(BASED_CHECKPOINT)
model = model.to("cuda")

  from .autonotebook import tqdm as notebook_tqdm


No module named 'causal_attention_cuda'
No module named 'causal_attention_cuda'
Successfully imported the causal dot product kernel! 
Successfully imported the FLA triton kernels! 


In [5]:
# Optional: download the baseline checkpoints (attention and Mamba) for comparison.
# Each baseline is bound to its own name. Reusing `model` would throw away the
# attention model right after loading it. It would also replace the Based model
# that the cells below ("Generation with Based") expect in `model`.
do_download = False

if do_download:
    from transformers import AutoTokenizer
    from based.models.mamba import MambaLMHeadModel
    from based.models.transformer.gpt import GPTLMHeadModel as AttentionGPTLMHeadModel

    # All checkpoints here use the same GPT-2 tokenizer, so load it only once.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Attention baseline
    attn_model = AttentionGPTLMHeadModel.from_pretrained_hf("hazyresearch/attn-360m").to("cuda")

    # Mamba baseline
    mamba_model = MambaLMHeadModel.from_pretrained_hf("hazyresearch/mamba-360m").to("cuda")

In [6]:
# Step 4: Inspect the hybrid structure of Based. Displaying the module shows its
# layer stack; the output below includes BaseConv mixers and GatedMlp blocks.
model

GPTLMHeadModel(
  (transformer): GPTModel(
    (embeddings): GPT2Embeddings(
      (word_embeddings): Embedding(50264, 1024)
    )
    (layers): ModuleList(
      (0): Block(
        (mixer): BaseConv(
          (in_proj): Linear(in_features=1024, out_features=4096, bias=True)
          (out_proj): Linear(in_features=2048, out_features=1024, bias=True)
          (conv): ShortConvolution(
            (conv): Conv1d(2048, 2048, kernel_size=(3,), stride=(1,), padding=(2,), groups=2048, bias=False)
          )
        )
        (dropout1): Dropout(p=0, inplace=False)
        (drop_path1): StochasticDepth(p=0.0, mode=row)
        (norm1): RMSNorm()
        (mlp): GatedMlp(
          (fc1): Linear(in_features=1024, out_features=4096, bias=False)
          (fc2): Linear(in_features=2048, out_features=1024, bias=False)
        )
        (dropout2): Dropout(p=0, inplace=False)
        (drop_path2): StochasticDepth(p=0.0, mode=row)
        (norm2): RMSNorm()
      )
      (1): Block(
        (mi

### Sample next token predictions

In [8]:

# For each prompt token, show the model's greedy (argmax) prediction for the token
# that follows it. For example, the output pairs "Stan" with the prediction "ford".
# Avoid naming these `input` / `max`: assigning those names at notebook top level
# shadows the Python builtins for every later cell in this kernel.
encoded = tokenizer("Stanford university is in the state of", return_tensors="pt").to("cuda")

model.eval()
with torch.no_grad():
    output = model(**encoded)
print(len(output.logits[0]))  # number of scored positions (one per prompt token)
next_token_ids = output.logits.argmax(dim=-1)[0]

# next token predictions
for tok, out_tok in zip(encoded["input_ids"][0], next_token_ids):
    print(f"{tokenizer.decode(tok.item())} -> {tokenizer.decode(out_tok.item())}")


8
Stan -> ford
ford ->  University
 university -> ,
 is ->  a
 in ->  the
 the ->  process
 state ->  of
 of ->  California


### Generation with Based

In [76]:
# Inputs
input_text = "The capital of California is Sacramento. The capital of Italy is Rome. The capital of France is"

context_length = 2048    # maximum number of prompt tokens kept after truncation
generation_length = 2    # number of new tokens to generate

# Pad on the left so every prompt ends where generation begins. Truncate on the
# left as well: with the default right-side truncation, an over-long prompt would
# lose its final tokens, which are exactly the ones generation continues from.
tokenizer.padding_side = "left"
tokenizer.truncation_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
inputs = tokenizer(
    [input_text], return_tensors="pt", padding=True, truncation=True, max_length=context_length
).input_ids.to("cuda")

start = inputs.shape[-1]             # index of the first generated token
limit = start + generation_length    # total length (prompt + new tokens) passed as max_length
print(f"{start=}, {limit=}")

# Generate. With top_k=1 only the highest-scoring token can be chosen, so decoding
# is greedy and the temperature value does not change the result.
model.eval()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    with torch.no_grad():
        generations = model.generate(
            input_ids=inputs,
            max_length=limit,
            temperature=0.1,
            top_k=1,
            top_p=1.0,
        )
        preds = generations[:, start:]
        pred_ids = preds[0].tolist()
        pred = tokenizer.decode(pred_ids)
        # Decode into a separate name so `input_text` keeps holding the raw prompt.
        decoded_input = tokenizer.decode(inputs[0].tolist())

print(f"{decoded_input} -> {pred}")

start=19, limit=21
The capital of California is Sacramento. The capital of Italy is Rome. The capital of France is ->  Paris.
