### Install dependencies and load Libraries

In [2]:
!pip install -Uqqq pip --progress-bar off
!pip install -qqq torch==2.1.2 --progress-bar off
!pip install -qqq causal-conv1d==1.1.1 --progress-bar off
!pip install -qqq mamba-ssm==1.1.1 --progress-bar off

In [3]:
from inspect import cleandoc
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda'

### Load model + Tokenizer

In [4]:
model_name = "havenhq/mamba-chat"
ANSWER_START = "<|assistant|>\n"
ANSWER_END = "<|endoftext|>"

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.eos_token = "<|endoftext|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = AutoTokenizer.from_pretrained(
    "HuggingFaceH4/zephyr-7b-beta"
).chat_template

model = MambaLMHeadModel.from_pretrained(
    model_name,
    device = DEVICE,
    dtype = torch.float16
)

tokenizer_config.json:   0%|          | 0.00/4.79k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/131 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


tokenizer_config.json:   0%|          | 0.00/1.43k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/168 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/201 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/5.55G [00:00<?, ?B/s]

In [13]:
ANSWER_START = "<|assistant|>\n"
ANSWER_END = "<|endoftext|>"

### Prompt + Generating Output

In [5]:
prompt = 'What is the capital of Egpyt?'

messages = [
    {
        "role": "user",
        "content": prompt
    }
]
messages

[{'role': 'user', 'content': 'What is the capital of Egpyt?'}]

In [6]:
input_ids = tokenizer.apply_chat_template(
    messages, 
    return_tensors = "pt",
    add_generation_prompt = True
).to(DEVICE)

input_ids

tensor([[   29,    93,  4537, 49651,   187,  1276,   310,   253,  5347,   273,
           444, 17788,  1767,    32,     0,   187,    29,    93,   515,  5567,
         49651,   187]], device='cuda:0')

In [7]:
outputs = model.generate(
    input_ids = input_ids,
    max_length = 1024,
    temperature = 0.7,
    top_p = 0.7,
    eos_token_id = tokenizer.eos_token_id
)

response = tokenizer.decode(outputs[0])
print(response)

<|user|>
What is the capital of Egpyt?<|endoftext|>
<|assistant|>
The capital of Egypt is Cairo.<|endoftext|>


### Testing

In [20]:
def predict(prompt: str, system_prompt: str = "") -> str:
    messages = []
    
    if system_prompt:
        messages.append(
            {
                "role": "system",
                "content": system_prompt
            }
        )
        
    messages.append(
            {
                "role": "user",
                "content": prompt
            }
        )
    
    input_ids = tokenizer.apply_chat_template(
        messages,
        return_tensors = "pt",
        add_generation_prompt = True
    ).to(DEVICE)
    
    outputs = model.generate(
        input_ids = input_ids,
        max_length = 1024,
        temperature = 0.9,
        top_p = 0.7,
        eos_token_id = tokenizer.eos_token_id
    )
    
    response = tokenizer.decode(outputs[0])
    
    return extract_response(response)

In [14]:
def extract_response(output: str) -> str:
    response_start = output.find(ANSWER_START) + len(ANSWER_START)
    return output[response_start : output.find(ANSWER_END, response_start)]

In [23]:
%%time
prompt = cleandoc(
    """
What is the capital of Egypt and what is famous for?
"""
)
print(predict(prompt))
print()

The capital of Egypt is Cairo. It is famous for its ancient monuments, museums, and historical sites.

CPU times: user 1.69 s, sys: 1.34 ms, total: 1.69 s
Wall time: 1.69 s


##### Coding

In [22]:
%%time
prompt = cleandoc(
    """
Write a python function that calculates the squareroot of a multiplication of two numbers.
"""
)
print(predict(prompt))
print()

Here's a Python function that calculates the squareroot of a multiplication of two numbers:

```python
def squareroot(num1, num2):
    """
    Calculates the squareroot of a multiplication of two numbers.
    """
    
    # Check if both numbers are positive
    if num1 < 0 or num2 < 0:
        raise ValueError("Both numbers must be positive")
    
    # Calculate the square root of the product
    sqrt_num = sqrt(num1 * num2)
    
    # Return the square root of the product
    return sqrt_num
```

Here's how you can use the function:

```python
num1 = 5
num2 = 3

print(squareroot(num1, num2)) # Output: 2.0
```

In this example, the function is called with the multiplication of `num1` and `num2` as arguments. The function then calculates the square root of the product and returns it. The output of the function is `2.0`.

CPU times: user 17.4 s, sys: 858 µs, total: 17.4 s
Wall time: 17.4 s


In [24]:
%%time
prompt = cleandoc(
    """
Write a function in python that checks for palindrom number.
"""
)
print(predict(prompt))
print()

Here's a function in Python that checks for palindrom number:

```python
def is_palindrome(num):
    if num == num[::-1]:
        return True
    else:
        return False
```

This function takes a number `num` as an argument and checks if it is a palindrome or not. A palindrome number is a number that is equal to its reverse. For example, `1234` is a palindrome number because it is equal to its reverse `4321`.

To check if a number is a palindrome or not, we can use the `is_palindrome` function. If the function returns `True`, then the number is a palindrome. Otherwise, it is not a palindrome.

To test the function, we can use the following code:

```python
num = 1234
print(is_palindrome(num)) # True
print(is_palindrome('1234')) # True
print(is_palindrome('1234')) # False
```

In the first line, we call the `is_palindrome` function with the number `1234`. The function returns `True` because the number is a palindrome. In the second line, we call the `is_palindrome` function with the

##### Math

In [21]:
%%time
prompt = cleandoc(
    """
Calculate the answer:
3 + 8 - 2 = ?
"""
)
print(predict(prompt))
print()

The answer is 11.

CPU times: user 495 ms, sys: 4.55 ms, total: 499 ms
Wall time: 498 ms


In [27]:
%%time
prompt = cleandoc(
    """
Calculate the answer:
3 + 15 * 3 - 4 = ?
"""
)
print(predict(prompt))
print()

The answer is:
3 + 15 * 3 - 4 = 15 + 45 - 4 = 34

CPU times: user 1.58 s, sys: 1.44 ms, total: 1.58 s
Wall time: 1.58 s


##### Reasoning

In [26]:
%%time
prompt = cleandoc(
    """
Naira is faster than Aml and Aml is faster than Rehab.
Is Rehab is faster than Naira?
"""
)
print(predict(prompt))
print()

No, Rehab is not faster than Naira.

CPU times: user 1.05 s, sys: 716 µs, total: 1.05 s
Wall time: 1.05 s
