# Mama Script 3 Eval

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
from monarch_i2i import MonarchI2i

  warn(f"Failed to load image Python extension: {e}")


In [3]:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [4]:
class Mama(nn.Module):
    """Memory-based Retrieval + Mama Chat model"""

    def __init__(self, retrieval_path=None, model_path=None):
        super(Mama, self).__init__()
        self.retriever = MonarchI2i()
        if retrieval_path:
            self.retriever.load_state_dict(
                torch.load(retrieval_path, map_location="cpu")
            )
        self.generator = MambaLMHeadModel.from_pretrained(
            "state-spaces/mamba-2.8b", dtype=torch.bfloat16
        )
        self.retriever.train()
        self.generator.train()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def retrieve(self, query_r, embedded_corpus, k=3):
        """Retrieve k most similar items from corpus"""
        # Embed the query
        query = self.retriever.model.forward(query_r)
        # Use dot product to find the most similar
        scores_with_idx = []
        for i, item in enumerate(embedded_corpus):
            emb = item[0]
            scores_with_idx.append((self.cos(query, emb), i))
        # Sort by score
        scores_with_idx.sort(reverse=True, key=lambda x: x[0])
        # Return the top k
        return scores_with_idx[:k]

    def generate(self, query_g, embedded_corpus, memory_indices, **kwargs):
        """Generate a response from the query and the retrieved memory"""
        # Retrieve the memory
        memory = [embedded_corpus[i[1]][1] for i in memory_indices]
        # Input is memory + query
        input_ids = torch.cat(
            memory + [query_g], dim=1
        )
        # Generate the response
        response = self.generator(input_ids)
        return_augmented_input_ids = kwargs.get("return_augmented_input_ids", False)
        if return_augmented_input_ids:
            return response, input_ids
        return response

    def forward(self, query_r, query_g, embedded_corpus, k=3):
        """Forward pass"""
        # Retrieve the memory
        memory_indices = self.retrieve(query_r, embedded_corpus, k)
        # Generate the response
        response = self.generate(query_g, embedded_corpus, memory_indices)
        return response

In [5]:
mama = Mama(retrieval_path="./monarch_768_retrieval.pt")
mama

Using Monarch Mixer for Sequence Mixing: True
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05


Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [6]:
from transformers import AutoTokenizer

In [7]:
r_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
g_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

In [8]:
# Create the embedded corpus
def tokenize_memory_corpus(memory_corpus):
    corpus_r_tokens = []
    corpus_g_tokens = []

    for item in memory_corpus:
        r_tokens = r_tokenizer(item, return_tensors="pt", max_length=512)
        g_tokens = g_tokenizer(f"<|memory|>{item}{g_tokenizer.eos_token}", return_tensors="pt", max_length=512)
        corpus_r_tokens.append(r_tokens)
        corpus_g_tokens.append(g_tokens)
    return corpus_r_tokens, corpus_g_tokens

In [9]:
# Load the memory corpus
MEMORY_PATH = "./data/mama_toy_memory.json"
DATA_PATH = "./data/mama_toy_chat.jsonl"
NOROBOTS_PATH = "./data/norobots_train.parquet"
DISTRACTOR_PATH = "./data/distract.txt"

In [11]:
import json

In [12]:
memory_corpus = []
with open(MEMORY_PATH, "r") as f:
    memory_corpus = json.load(f)
len(memory_corpus)

22

In [26]:
mama.load_state_dict(torch.load('mama_in_progress_epoch_29.pt'))

<All keys matched successfully>

In [14]:
def embed_corpus(corpus_r_tokens, corpus_g_tokens, device="cpu"):
    with torch.no_grad():
        embedded_corpus = []
        for r, g in zip(corpus_r_tokens, corpus_g_tokens):
            r_emb = mama.retriever.model.forward(r["input_ids"].to(device))
            embedded_corpus.append((r_emb, g["input_ids"].to(device)))
    return embedded_corpus

In [16]:
# Create the embedded corpus
def tokenize_memory_corpus(memory_corpus):
    corpus_r_tokens = []
    corpus_g_tokens = []

    for item in memory_corpus:
        r_tokens = r_tokenizer(item, return_tensors="pt", max_length=512)
        g_tokens = g_tokenizer(f"<|memory|>{item}{g_tokenizer.eos_token}", return_tensors="pt", max_length=512)
        corpus_r_tokens.append(r_tokens)
        corpus_g_tokens.append(g_tokens)
    return corpus_r_tokens, corpus_g_tokens

In [17]:
mr_tokens, mg_tokens = tokenize_memory_corpus(memory_corpus)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [18]:
embedded_corpus = embed_corpus(mr_tokens, mg_tokens, device="cpu")

In [19]:
s = "'<|system|>\nAnswer the given question<|endoftext|>\n<|user|>\nHow old is Minh?<|endoftext|>\n"
s_tokens = g_tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")
s_tokens_r = r_tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")
s_tokens.shape

torch.Size([1, 25])

In [27]:
mama.cuda()

Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [23]:
embedded_corpus = [(r_emb.cuda(), g_emb.cuda()) for r_emb, g_emb in embedded_corpus]

In [28]:
out3s = mama.forward(
  query_r=s_tokens_r.cuda(),
  query_g=s_tokens.cuda(),
  embedded_corpus=embedded_corpus
)
out3s

CausalLMOutput(logits=tensor([[[  5.9688,  -4.8125,   9.4375,  ...,  -4.7812,  -5.2500,  -5.5312],
         [ 11.6875,  -0.4062,  13.3125,  ...,  -0.3945,  -0.9570,  -0.9219],
         [  6.5312, -13.1250,   5.3125,  ..., -12.5625, -12.5625, -12.5000],
         ...,
         [ 27.3750,  -5.7500,   5.2812,  ...,  -6.2188,  -4.7500,  -5.4375],
         [ 19.8750,  -5.5312,  13.0000,  ...,  -4.5625,  -4.7188,  -4.4688],
         [ 13.4375,  -2.6875,  13.1875,  ...,  -3.2969,  -4.2812,  -3.4688]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>))

In [29]:
g_tokenizer.decode(out3s.logits[0].argmax(dim=-1))

"|memory|>user foments dissent and increases attrition amongst your ranks<|endoftext|><|memory|>Fions are huge so if you're hungry this is the place. The biscuits are amazing with the homemade jam. Make sure to try one to share regardless what time you are there<|endoftext|><|memory|>F Vegas Biscuits and gravy · 1. Omelet House. (859). Open Now. American$$ - $$$ · 2. Mr. Mamas. (2,205). Open Now · 3. Jamm's Restaurant.<|endoftext|><use|user|>\nYou the given question<|endoftext|>\n<|user|>\nWrite can is Minh?<|endoftext|>\n<"

In [None]:
topk = mama.retrieve(
    query_r=