In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import json
import pyarrow.parquet as pq

In [3]:
# Setup the model archi
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from monarch_i2i import MonarchI2i

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [5]:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [6]:
class Mama(nn.Module):
    """Memory-based Retrieval + Mama Chat model"""

    def __init__(self, retrieval_path=None, model_path=None):
        super(Mama, self).__init__()
        self.retriever = MonarchI2i()
        if retrieval_path:
            self.retriever.load_state_dict(
                torch.load(retrieval_path, map_location="cpu")
            )
        self.generator = MambaLMHeadModel.from_pretrained(
            "state-spaces/mamba-2.8b", dtype=torch.bfloat16
        )
        self.retriever.train()
        self.generator.train()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def retrieve(self, query_r, embedded_corpus, k=3):
        """Retrieve k most similar items from corpus"""
        # Embed the query
        query = self.retriever.model.forward(query_r)
        # Use dot product to find the most similar
        scores_with_idx = []
        for i, item in enumerate(embedded_corpus):
            emb = item[0]
            scores_with_idx.append((self.cos(query, emb), i))
        # Sort by score
        scores_with_idx.sort(reverse=True, key=lambda x: x[0])
        # Return the top k
        return scores_with_idx[:k]

    def generate(self, query_g, embedded_corpus, memory_indices, **kwargs):
        """Generate a response from the query and the retrieved memory"""
        # Retrieve the memory
        memory = [embedded_corpus[i[1]][1] for i in memory_indices]
        # Input is memory + query
        input_ids = torch.cat(
            memory + [query_g], dim=1
        )
        # Generate the response
        response = self.generator(input_ids)
        return_augmented_input_ids = kwargs.get("return_augmented_input_ids", False)
        if return_augmented_input_ids:
            return response, input_ids
        return response

    def forward(self, query_r, query_g, embedded_corpus, k=3):
        """Forward pass"""
        # Retrieve the memory
        memory_indices = self.retrieve(query_r, embedded_corpus, k)
        # Generate the response
        response = self.generate(query_g, embedded_corpus, memory_indices)
        return response

In [7]:
mama = Mama(retrieval_path="./monarch_768_retrieval.pt")
mama

Using Monarch Mixer for Sequence Mixing: True
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05


Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [8]:
import json
import pyarrow.parquet as pq

In [9]:
NOROBOTS_PATH = "./data/norobots_train.parquet"


In [60]:
norobots_pq = pq.read_table(
  NOROBOTS_PATH,
  columns=['messages', 'category'],
)
norobots_pq = norobots_pq.to_pandas()
norobots_pq = norobots_pq[norobots_pq['category'] == 'Summarize']
norobots_pq.head()

Unnamed: 0,messages,category
0,[{'content': 'Please summarize the goals for s...,Summarize
19,"[{'content': 'In short, why does the article s...",Summarize
34,[{'content': 'Give me the main idea of this po...,Summarize
45,[{'content': 'Summarize what hi-fi audio is in...,Summarize
122,[{'content': 'Few people would have predicted ...,Summarize


In [61]:
norobots_all = norobots_pq['messages'].tolist()

norobots_memory = norobots_all[-100:]

In [80]:
memory_corpus = [
  "Johnson's last name is Cook",
  "Johnson is 32 years old",
  "The capital of Brazil is Brasília",
  "Tracy is 55 years old",
  "Dougie is an elephant, not a human",
  """
  user: Summarize the following task description in four words:
  "Send an email to the customer about the new product"
  assistant: New product customer email
  """,
]
for item in norobots_memory:
    s = "\n".join(f"{x['role']}: {x['content']}" for x in item)
    memory_corpus.append(s)

In [63]:
memory_corpus[-1]

"user: Summarize the main points of this article about forests.\n\nHere is the article you'll use:\nDry forest\nDry forests are found in warm climates where seasonal droughts last for several months at a time. Deciduous trees predominate these forests, and during the drought a leafless period occurs, which varies with species type. Tropical and Subtropical Dry Forests are found in southern Mexico, southeastern Africa, the Lesser Sundas, central India, Indochina, Madagascar, New Caledonia, eastern Bolivia and central Brazil, the Caribbean, valleys of the northern Andes, and along the coasts of Ecuador and Peru.\nLatin American dry tropical forests are some of the most endangered on earth. \nassistant: Dry forests exist in drought-ridden, warm climates with various deciduous trees. There are tropical and subtropical variations throughout southeastern Africa, South America, the Caribbean, central India, and other locations, although those in Latin America are some of the most endangered. 

In [13]:
mama.load_state_dict(torch.load("mama_toy.pt", map_location="cpu"))

<All keys matched successfully>

In [14]:
from transformers import AutoTokenizer
r_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
g_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

In [50]:
def embed_corpus(corpus, device="cpu"):
    corpus_r_tokens = []
    corpus_g_tokens = []
    for item in corpus:
        r = r_tokenizer(item, return_tensors="pt", max_length=512)
        g = g_tokenizer(f"<|memory|>{item}{g_tokenizer.eos_token}", return_tensors="pt")
        corpus_r_tokens.append(r)
        corpus_g_tokens.append(g)
    with torch.no_grad():
        embedded_corpus = []
        for r, g in zip(corpus_r_tokens, corpus_g_tokens):
            r_emb = mama.retriever.model.forward(r["input_ids"].to(device))
            embedded_corpus.append((r_emb, g["input_ids"].to(device)))
    return embedded_corpus

In [81]:
Z = embed_corpus(memory_corpus, device="cuda")

In [22]:
mama_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
g_tokenizer.chat_template = mama_template

In [24]:
mama.cuda()

Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [115]:
with torch.no_grad():
    mama.eval()
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": """Rewrite the following task description in four words: Add checker in finish command for CSV, make sure it is formatted using ; instead of ,"""},
    ]
    input_ids_g = torch.LongTensor(
        g_tokenizer.apply_chat_template(
            messages, return_tensors="pt", add_generation_prompt=True
        )
    )
    input_ids_r = torch.LongTensor(
        r_tokenizer.apply_chat_template(
            messages,
            chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'user:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'system:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'assistant:\n'  + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'assistant:' }}\n{% endif %}\n{% endfor %}",
            max_length=512,
            truncation=True,
        )
    ).unsqueeze(0)
    print(input_ids_g.shape)
    topk = mama.retrieve(input_ids_r.cuda(), [(x.cuda(), y.cuda()) for x, y in Z], k=3)
    memory_indices = topk
    memory = [Z[i[1]][1] for i in memory_indices]
    print(memory[0].shape)
    # Input is memory + query
    input_ids = torch.cat(memory + [input_ids_g.cuda()], dim=1)
    print(g_tokenizer.batch_decode(input_ids))
    # Generate the response
    out = mama.generator.generate(
        input_ids=input_ids.cuda(),
        max_length=2000,
        temperature=0.7,
        top_p=0.5,
        eos_token_id=g_tokenizer.eos_token_id,
    )
    decoded = g_tokenizer.batch_decode(out)

    print("Model:", decoded[0].split("<|assistant|>\n")[-1])


torch.Size([1, 54])
torch.Size([1, 45])
['<|memory|>\n  user: Summarize the following task description in four words:\n  "Send an email to the customer about the new product"\n  assistant: New product customer email\n  <|endoftext|><|memory|>user: Can you summarize in simple words what this excerpt says? \n\nIn mathematics and computer science, an algorithm (/ˈælɡərɪðəm/ (listen)) is a finite sequence of rigorous instructions, typically used to solve a class of specific problems or to perform a computation.[1] Algorithms are used as specifications for performing calculations and data processing. More advanced algorithms can use conditionals to divert the code execution through various routes (referred to as automated decision-making) and deduce valid inferences (referred to as automated reasoning), achieving automation eventually. Using human characteristics as descriptors of machines in metaphorical ways was already practiced by Alan Turing with terms such as "memory", "search" and "s

In [59]:
memory_corpus[-2]

"user: Who composed the concerti 'Four Seasons'?"