In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import json
import pyarrow.parquet as pq

In [3]:
# Setup the model archi
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from monarch_i2i import MonarchI2i

  warn(f"Failed to load image Python extension: {e}")


In [5]:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [6]:
class Mama(nn.Module):
    """Memory-based Retrieval + Mama Chat model"""

    def __init__(self, retrieval_path=None, model_path=None):
        super(Mama, self).__init__()
        self.retriever = MonarchI2i()
        if retrieval_path:
            self.retriever.load_state_dict(
                torch.load(retrieval_path, map_location="cpu")
            )
        self.generator = MambaLMHeadModel.from_pretrained(
            "state-spaces/mamba-2.8b", dtype=torch.bfloat16
        )
        self.retriever.train()
        self.generator.train()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def retrieve(self, query_r, embedded_corpus, k=3):
        """Retrieve k most similar items from corpus"""
        # Embed the query
        query = self.retriever.model.forward(query_r)
        # Use dot product to find the most similar
        scores_with_idx = []
        for i, item in enumerate(embedded_corpus):
            emb = item[0]
            scores_with_idx.append((self.cos(query, emb), i))
        # Sort by score
        scores_with_idx.sort(reverse=True, key=lambda x: x[0])
        # Return the top k
        return scores_with_idx[:k]

    def generate(self, query_g, embedded_corpus, memory_indices, **kwargs):
        """Generate a response from the query and the retrieved memory"""
        # Retrieve the memory
        memory = [embedded_corpus[i[1]][1] for i in memory_indices]
        # Input is memory + query
        input_ids = torch.cat(
            memory + [query_g], dim=1
        )
        # Generate the response
        response = self.generator(input_ids)
        return_augmented_input_ids = kwargs.get("return_augmented_input_ids", False)
        if return_augmented_input_ids:
            return response, input_ids
        return response

    def forward(self, query_r, query_g, embedded_corpus, k=3):
        """Forward pass"""
        # Retrieve the memory
        memory_indices = self.retrieve(query_r, embedded_corpus, k)
        # Generate the response
        response = self.generate(query_g, embedded_corpus, memory_indices)
        return response

In [7]:
mama = Mama(retrieval_path="./monarch_768_retrieval.pt")
mama

Using Monarch Mixer for Sequence Mixing: True
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05


Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [14]:
import json
import pyarrow.parquet as pq

In [78]:
NOROBOTS_PATH = "./data/norobots_train.parquet"
TASK_SUMMARIZATION_PATH = "./data/instruction_summary_convo_dataset.jsonl"

In [10]:
norobots_pq = pq.read_table(
  NOROBOTS_PATH, 
  columns=['messages', 'category'],
)
norobots_pq = norobots_pq.to_pandas()
norobots_pq = norobots_pq[norobots_pq['category'] == 'Summarize']
norobots_pq.head()

Unnamed: 0,messages,category
0,[{'content': 'Please summarize the goals for s...,Summarize
19,"[{'content': 'In short, why does the article s...",Summarize
34,[{'content': 'Give me the main idea of this po...,Summarize
45,[{'content': 'Summarize what hi-fi audio is in...,Summarize
122,[{'content': 'Few people would have predicted ...,Summarize


In [11]:
norobots_all = norobots_pq['messages'].tolist()

norobots_memory = norobots_all[-100:]

In [79]:
# %%
# Let's load up the instruction summarization data
task_summarization_data = []
with open(TASK_SUMMARIZATION_PATH, "r") as f:
    for line in f:
        task_summarization_data.append(json.loads(line))
len(task_summarization_data)

# %%
# Let's sample some to add to the memory
task_summarization_memory = task_summarization_data[:30]
task_summarization_data = task_summarization_data[30:]
len(task_summarization_memory)

30

In [85]:
task_summarization_memory[0]

[{'role': 'system', 'content': 'You are a helpful AI assistant.'},
 {'role': 'user',
  'content': 'Please summarize the following instructions in five words or less: Using the following resources:\r\nAPI Documentation and Resources:\r\n•\tGoogle Docs API Quickstart: https://developers.google.com/docs/api/quickstart/python\r\n•\tAdditional resource for Google Docs: https://stackoverflow.com/questions/14726409/using-python-how-can-i-read-plain-text-from-a-google-doc\r\n\r\nWrite the python functions: \r\nGoogle Docs API Python Functions:\r\n•\tread_google_doc: This function fetches data from a specified Google Doc.\r\nAPI Documentation and Resources:\r\n•\tGoogle Docs API Quickstart: https://developers.google.com/docs/api/quickstart/python\r\n•\tAdditional resource for Google Docs: https://stackoverflow.com/questions/14726409/using-python-how-can-i-read-plain-text-from-a-google-doc\r\n\r\nSend the code for these functions in an email to tobi@donada.ai using the tobi@donada.com email addr

In [88]:
task_summarization_memory_ = [
    "\n".join([x['role'] + ": " + x['content'] for x in conv])
    for conv in task_summarization_memory 
]
task_summarization_memory_[-1]

'system: You are a helpful AI assistant.\nuser: The following are instructions for a task. Please summarize them in five words or less:\nSend a confirmation email to the pilot customers identified in the previous step. Use the following template:\n\nHello [FIRST_NAME],\n\nThis is Tobi\'s AI Receptionist. I am reaching out to confirm that you are still available for our demo presentation tomorrow. Please feel free to reach out to Tobi at tobi@donada.com if you need to reschedule or have any questions.\n\nThanks,\n\n-T\n\nThe task is finished when the email has been sent to all identified customers.\nassistant:  "Send demo confirmation email to pilot customers."'

In [89]:
memory_corpus = [
  "Johnson's last name is Cook",
  "Johnson is 32 years old",
  "The capital of Brazil is Brasília",
  "Tracy is 55 years old",
  "Dougie is an elephant, not a human",
  """
  user: Summarize the following task description in four words:
  "Send an email to the customer about the new product"
  assistant: New product customer email
  """,
]
# for item in norobots_memory:
#     s = "\n".join(f"{x['role']}: {x['content']}" for x in item)
#     memory_corpus.append(s)
memory_corpus = memory_corpus + task_summarization_memory_

In [90]:
memory_corpus[-1]

'system: You are a helpful AI assistant.\nuser: The following are instructions for a task. Please summarize them in five words or less:\nSend a confirmation email to the pilot customers identified in the previous step. Use the following template:\n\nHello [FIRST_NAME],\n\nThis is Tobi\'s AI Receptionist. I am reaching out to confirm that you are still available for our demo presentation tomorrow. Please feel free to reach out to Tobi at tobi@donada.com if you need to reschedule or have any questions.\n\nThanks,\n\n-T\n\nThe task is finished when the email has been sent to all identified customers.\nassistant:  "Send demo confirmation email to pilot customers."'

In [91]:
mama.load_state_dict(torch.load("mama_in_progress_epoch_29.pt", map_location="cpu"))

<All keys matched successfully>

In [92]:
mama.cpu()

Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [93]:
from transformers import AutoTokenizer
r_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
g_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

In [94]:
def embed_corpus(corpus, device="cpu"):
    corpus_r_tokens = []
    corpus_g_tokens = []
    for item in corpus:
        r = r_tokenizer(item, return_tensors="pt", max_length=512)
        g = g_tokenizer(f"<|memory|>{item}{g_tokenizer.eos_token}", return_tensors="pt")
        corpus_r_tokens.append(r)
        corpus_g_tokens.append(g)
    with torch.no_grad():
        embedded_corpus = []
        for r, g in zip(corpus_r_tokens, corpus_g_tokens):
            r_emb = mama.retriever.model.forward(r["input_ids"].to(device))
            embedded_corpus.append((r_emb, g["input_ids"].to(device)))
    return embedded_corpus

In [95]:
Z = embed_corpus(memory_corpus, device="cpu")

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [96]:
mama_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
g_tokenizer.chat_template = mama_template

In [97]:
mama.cuda()

Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [98]:
asst_tokens = g_tokenizer.encode("<|assistant|>\n", return_tensors="pt")

In [113]:
with torch.no_grad():
    mama.eval()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": """Please give a super short summary of the following task instructions: Research online tutorials on RHLF, send this in an email to tobi@donada.ai using the tobi@donada.com email address"""},
    ]
    input_ids_g = torch.LongTensor(
        g_tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=False
        )
    )
    input_ids_r = torch.LongTensor(
        r_tokenizer.apply_chat_template(
            messages,
            chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'user:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'system:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'assistant:\n'  + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'assistant:' }}\n{% endif %}\n{% endfor %}",
            max_length=64,
            truncation=True,
        )
    ).unsqueeze(0)
    print(input_ids_g.shape)
    topk = mama.retrieve(input_ids_r.cuda(), [(x.cuda(), y.cuda()) for x, y in Z], k=3)
    memory_indices = topk
    memory = [Z[i[1]][1].cuda() for i in memory_indices]
    print(memory[0].shape)
    # Input is memory + query
    input_ids = torch.cat(memory + [input_ids_g.cuda()] + [asst_tokens.to("cuda")], dim=1)
    print(g_tokenizer.batch_decode(input_ids))
    # Generate the response
    out = mama.generator.generate(
        input_ids=input_ids.cuda(),
        max_length=2000,
        temperature=0.9,
        top_p=0.7,
        eos_token_id=g_tokenizer.eos_token_id,
    )
    decoded = g_tokenizer.batch_decode(out)
    print(decoded)
    print("--------------")
    print("Model:", decoded[0].split("<|assistant|>\n")[-1])


torch.Size([1, 64])
torch.Size([1, 98])
["<|memory|>system: You are a helpful AI assistant.\nuser: Write a concise header for a task with the following instructions: For each pilot customer, compile a list of features they have requested from the 'feature_requests' table. Then, send an email to tobi@donada.ai asking for an update on these feature requests. This task is finished when the email has been sent.\nassistant:  Compile list, send email, task complete.<|endoftext|><|memory|>system: You are a helpful AI assistant.\nuser: Write a super short (less than 5 words) summary of the following instructions: Search the following websites for updates related to:\n1. CoC Leadership Board Notes\n2. Policies\n3. Affordable Housing Being Built\n- https://everyonehome.org/about/leadership-board/\n- https://homelessness.acgov.org/\n- https://www.oaklandca.gov/departments/department-of-housing-and-community-development\nThe task is finished when all relevant updates from the fourth week of each m

In [59]:
memory_corpus[-2]

"user: Who composed the concerti 'Four Seasons'?"

In [38]:
len(decoded)

1