# <ins>M</ins>emory <ins>A</ins>ugmented <ins>Ma</ins>mba test notebook

In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import json

In [3]:
# Load the memory corpus
MEMORY_PATH = "./data/mama_toy_memory.json"
DATA_PATH = "./data/mama_toy_chat.jsonl"

In [4]:
memory_corpus = []
with open(MEMORY_PATH, "r") as f:
    memory_corpus = json.load(f)
len(memory_corpus)

22

In [5]:
toy_data = []
with open(DATA_PATH, "r") as f:
    for line in f:
        toy_data.append(json.loads(line))
len(toy_data)

10

In [6]:
# Setup the model archi
import torch
import torch.nn as nn
import torch.nn.functional as F

In [7]:
from monarch_i2i import MonarchI2i

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


In [8]:
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

In [9]:
class Mama(nn.Module):
    """Memory-based Retrieval + Mama Chat model"""

    def __init__(self, retrieval_path=None, model_path=None):
        super(Mama, self).__init__()
        self.retriever = MonarchI2i()
        if retrieval_path:
            self.retriever.load_state_dict(
                torch.load(retrieval_path, map_location="cpu")
            )
        self.generator = MambaLMHeadModel.from_pretrained(
            "state-spaces/mamba-2.8b-slimpj", dtype=torch.bfloat16
        )
        self.retriever.train()
        self.generator.train()
        self.cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    def retrieve(self, query_r, embedded_corpus, k=3):
        """Retrieve k most similar items from corpus"""
        # Embed the query
        query = self.retriever.model.forward(query_r)
        # Use dot product to find the most similar
        scores_with_idx = []
        for i, item in enumerate(embedded_corpus):
            emb = item[0]
            scores_with_idx.append((self.cos(query, emb), i))
        # Sort by score
        scores_with_idx.sort(reverse=True, key=lambda x: x[0])
        # Return the top k
        return scores_with_idx[:k]

    def generate(self, query_g, embedded_corpus, memory_indices, **kwargs):
        """Generate a response from the query and the retrieved memory"""
        # Retrieve the memory
        memory = [embedded_corpus[i[1]][1] for i in memory_indices]
        # Input is memory + query
        input_ids = torch.cat(
            memory + [query_g], dim=1
        )
        # Generate the response
        response = self.generator(input_ids)
        return_augmented_input_ids = kwargs.get("return_augmented_input_ids", False)
        if return_augmented_input_ids:
            return response, input_ids
        return response

    def forward(self, query_r, query_g, embedded_corpus, k=3):
        """Forward pass"""
        # Retrieve the memory
        memory_indices = self.retrieve(query_r, embedded_corpus, k)
        # Generate the response
        response = self.generate(query_g, embedded_corpus, memory_indices)
        return response

In [10]:
mama = Mama(retrieval_path="./monarch_768_retrieval.pt")
mama

Using Monarch Mixer for Sequence Mixing: True
-- Bidirectional: True
-- Using Long Conv Residual: True
-- Hyena w: 10
-- Hyena w mod: 1
-- Hyena filter order: 128
-- Hyena filter dropout: 0.2
-- Hyena filter wd: 0.1
-- Hyena filter emb dim: 5
-- Hyena filter lr: 0.001
-- Hyena filter lr pos emb: 1e-05


Downloading config.json: 100%|██████████| 200/200 [00:00<00:00, 1.07MB/s]
Downloading pytorch_model.bin: 100%|██████████| 11.1G/11.1G [17:20<00:00, 10.6MB/s]


Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [11]:
from transformers import AutoTokenizer

In [12]:
r_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
g_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")

In [13]:
# Create the embedded corpus
def tokenize_memory_corpus(memory_corpus):
    corpos_r_tokens = []
    corpos_g_tokens = []

    for item in memory_corpus:
        r_tokens = r_tokenizer(item, return_tensors="pt")
        g_tokens = g_tokenizer(f"<|memory|>{item}{g_tokenizer.eos_token}", return_tensors="pt")
        corpos_r_tokens.append(r_tokens)
        corpos_g_tokens.append(g_tokens)
    return corpos_r_tokens, corpos_g_tokens

corpos_r_tokens, corpos_g_tokens = tokenize_memory_corpus(memory_corpus)

In [14]:
mama.retriever.model.forward

<bound method BasicModel.forward of BasicModel(
  (model): HuggingFaceModel(
    (model): BertForMaskedLM(
      (bert): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30528, 768, padding_idx=0)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-11): 12 x BertLayer(
              (attention): MonarchMixerSequenceMixing(
                (filter_fn): HyenaFilter(
                  (dropout): Dropout(p=0.2, inplace=False)
                  (pos_emb): PositionalEmbedding()
                  (implicit_filter): Sequential(
                    (0): Linear(in_features=5, out_features=128, bias=True)
                    (1): Sin()
                    (2): Linear(in_features=128, out_features=128, bias=True)
                    (3): Sin()
  

In [15]:
def embed_corpus(corpos_r_tokens, corpos_g_tokens, device="cpu"):
    with torch.no_grad():
        embedded_corpus = []
        for r, g in zip(corpos_r_tokens, corpos_g_tokens):
            r_emb = mama.retriever.model.forward(r["input_ids"].to(device))
            embedded_corpus.append((r_emb, g["input_ids"].to(device)))
    return embedded_corpus
embedded_corpus = embed_corpus(corpos_r_tokens, corpos_g_tokens, device="cpu")

In [16]:
ORIGINAL_NAMES = ["Kaneema", "Minh", "Django"]
REPLACEMENT_NAMES = [
    "Thomson",
    "Jerry",
    "Alice",
    "Rachel",
    "Ganeesh",
    "Adam",
    "Nic",
    "Veronica",
    "Sam",
    "Samantha",
    "Joe",
    "Donald",
    "Peter",
    "Paul",
    "Jorge",
] + ORIGINAL_NAMES
ORIGINAL_CITIES = ["New York", "Cape Town", "Los Angeles"]
REPLACEMENT_CITIES = [
    "London",
    "Paris",
    "Berlin",
    "Moscow",
    "Lagos",
    "Cairo",
    "Abuja",
] + ORIGINAL_CITIES
ORIGINAL_AGES = ["36", "27", "35"]
ORIGINAL_COLORS = ["Blue"]
REPLACEMENT_COLORS = ["Red", "Green", "Yellow", "Purple", "Black", "White"] + ORIGINAL_COLORS


In [17]:
from tqdm import tqdm

In [18]:
def preprocess(
    conversations, r_tokenizer, g_tokenizer, conversation_template, max_tokens
):
    """
    Preprocess the data by tokenizing.
    """
    all_input_ids_r = []
    all_input_ids_g = []
    all_label_ids = []
    r_tokenizer.use_default_system_prompt = False
    r_tokenizer.eos_token = g_tokenizer.eos_token
    g_tokenizer.use_default_system_prompt = False

    print("Tokenizing dataset...")
    for conv in tqdm(conversations):
        current_conv = conv["messages"]
        tokenized_responses = []
        for msg in current_conv:
            if msg["role"] == "assistant":
                tokenized_responses.append(
                    g_tokenizer.encode(msg["content"], add_special_tokens=False)
                )

        tokenized_conv_r = r_tokenizer.apply_chat_template(
            current_conv,
            chat_template="{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'user:\n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'system:\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'assistant:\n'  + message['content'] }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'assistant:' }}\n{% endif %}\n{% endfor %}",
            max_length=max_tokens,
            truncation=True,
        )
        tokenized_conv_g = g_tokenizer.apply_chat_template(
            current_conv,
            chat_template=conversation_template,
            max_length=max_tokens,
            truncation=True,
        )
        tokenized_labels = g_tokenizer.apply_chat_template(
            [current_conv[-1]],
            chat_template=conversation_template,
            max_length=max_tokens,
            truncation=True,
        )
        all_input_ids_g.append(torch.LongTensor(tokenized_conv_g))
        all_input_ids_r.append(torch.LongTensor(tokenized_conv_r))
        all_label_ids.append(torch.LongTensor(tokenized_labels))
    return {
        "input_ids_r": all_input_ids_r,
        "input_ids_g": all_input_ids_g,
        "label_ids": all_label_ids,
    }

In [19]:
mama_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
print(mama_template)

{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>
' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>
'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}


In [20]:
import re
import random
def randomize_dataset(device="cpu"):
    """Replace the names, cities and ages in the dataset"""
    kaneema_to = random.choice(REPLACEMENT_NAMES)
    minh_to = random.choice(REPLACEMENT_NAMES)
    django_to = random.choice(REPLACEMENT_NAMES)
    new_york_to = random.choice(REPLACEMENT_CITIES)
    cape_town_to = random.choice(REPLACEMENT_CITIES)
    los_angeles_to = random.choice(REPLACEMENT_CITIES)
    age1_to = str(random.randint(19, 60))
    age2_to = str(random.randint(18, int(age1_to) - 1))
    age3_to = str(random.randint(18, 60))
    color_to = random.choice(REPLACEMENT_COLORS)

    randomized_toy_data = []
    randomized_memory_corpus = []

    for conv in toy_data:
        new_conv = []
        for msg in conv["messages"]:
            new_msg = {}
            new_msg["role"] = msg["role"]
            new_msg["content"] = msg["content"]
            new_msg["content"] = new_msg["content"].replace("Kaneema", kaneema_to)
            new_msg["content"] = new_msg["content"].replace("Minh", minh_to)
            new_msg["content"] = new_msg["content"].replace("Django", django_to)
            new_msg["content"] = new_msg["content"].replace("New York", new_york_to)
            new_msg["content"] = new_msg["content"].replace("Cape Town", cape_town_to)
            new_msg["content"] = new_msg["content"].replace("Los Angeles", los_angeles_to)
            new_msg["content"] = new_msg["content"].replace("36", age1_to)
            new_msg["content"] = new_msg["content"].replace("27", age2_to)
            new_msg["content"] = new_msg["content"].replace("35", age3_to)
            pattern = re.compile("blue", re.IGNORECASE)
            new_msg["content"] = pattern.sub(color_to, new_msg["content"])

            new_conv.append(new_msg)

        randomized_toy_data.append({"messages": new_conv})
    for x in memory_corpus:
        new_x = x
        new_x = new_x.replace("Kaneema", kaneema_to)
        new_x = new_x.replace("Minh", minh_to)
        new_x = new_x.replace("Django", django_to)
        new_x = new_x.replace("New York", new_york_to)
        new_x = new_x.replace("Cape Town", cape_town_to)
        new_x = new_x.replace("Los Angeles", los_angeles_to)
        new_x = new_x.replace("36", age1_to)
        new_x = new_x.replace("27", age2_to)
        new_x = new_x.replace("35", age3_to)
        pattern = re.compile("blue", re.IGNORECASE)
        new_x = pattern.sub(color_to, new_x)
        randomized_memory_corpus.append(new_x)

    # Update the embedded_corpus
    global corpus_r_tokens, corpus_g_tokens
    corpus_r_tokens, corpus_g_tokens = tokenize_memory_corpus(randomized_memory_corpus)
    global embedded_corpus
    embedded_corpus = embed_corpus(corpus_r_tokens, corpus_g_tokens, device=device)

    # Update the dataset
    global toy_data_preprocessed
    toy_data_preprocessed = preprocess(
        randomized_toy_data,
        r_tokenizer,
        g_tokenizer,
        mama_template,
        max_tokens=1024,
    )
    return toy_data_preprocessed, embedded_corpus


In [21]:
randomize_dataset()

Tokenizing dataset...


100%|██████████| 10/10 [00:00<00:00, 114.09it/s]


({'input_ids_r': [tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2129,  2214,
            2003, 13133,  1029,  3353,  1024,  5388]),
   tensor([2291, 1024, 3437, 1996, 2445, 3160, 5310, 1024, 2073, 2001, 2703, 2141,
           1029, 3353, 1024, 2414]),
   tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2054,  2003,
            3533,  1005,  1055,  2197,  2171,  1029,  3353,  1024, 16031,  5910]),
   tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2040,  2003,
            3080,  1010,  3533,  2030, 13133,  1029,  3353,  1024, 13133]),
   tensor([2291, 1024, 3437, 1996, 2445, 3160, 5310, 1024, 2054, 2003, 3533, 1005,
           1055, 5440, 3609, 1029, 3353, 1024, 2304]),
   tensor([2291, 1024, 3437, 1996, 2445, 3160, 5310, 1024, 2073, 2515, 3533, 2444,
           1029, 3353, 1024, 3000]),
   tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2040,  2003,
            3080,  1010, 13133,  2030,  3533,  1029,  3353,  1024, 13

In [22]:
toy_data_preprocessed = preprocess(toy_data, r_tokenizer, g_tokenizer, mama_template, 2048 * 4)

Tokenizing dataset...


100%|██████████| 10/10 [00:00<00:00, 1885.42it/s]


In [23]:
toy_data_preprocessed["input_ids_g"][0]

tensor([   29,    93, 10394, 49651,   187, 32869,   253,  1677,  1953,     0,
          187,    29,    93,  4537, 49651,   187,  2347,  1711,   310,  3689,
           73,    32,     0,   187,    29,    93,   515,  5567, 49651,   187,
         1812,     0,   187])

In [24]:
g_tokenizer.decode(toy_data_preprocessed["input_ids_g"][0])

'<|system|>\nAnswer the given question<|endoftext|>\n<|user|>\nHow old is Minh?<|endoftext|>\n<|assistant|>\n36<|endoftext|>\n'

In [25]:
g_tokenizer.eos_token

'<|endoftext|>'

In [26]:
mama = mama.cuda()

In [27]:
out = mama.generator.forward(toy_data_preprocessed["input_ids_g"][0].unsqueeze(0).cuda())
out

CausalLMOutput(logits=tensor([[[ 13.1250,  -1.3906,  15.3125,  ...,  -0.9492,  -1.1953,  -1.1484],
         [  8.3750, -10.6250,  12.7500,  ..., -10.0625, -10.5000,  -9.9375],
         [ 17.2500,   0.4922,  19.3750,  ...,   0.7383,   0.1504,   0.7891],
         ...,
         [ 28.3750,   3.0312,  26.2500,  ...,   3.2812,   2.9531,   3.2969],
         [ 24.2500,   1.1875,  26.8750,  ...,   0.9375,   0.6328,   1.1250],
         [  9.5625, -20.2500,  12.6875,  ..., -20.5000, -20.2500, -20.2500]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>))

In [28]:
g_tokenizer.decode(out.logits[0].argmax(dim=-1))

'divT.\n        : question question.The\ndiv\n_\n\ndy are your??\nThe\ndiv\nignment_\n\n.The\n'

In [29]:
# move embedded corpus to GPU
embedded_corpus = [(r_emb.cuda(), g_emb.cuda()) for r_emb, g_emb in embedded_corpus]

In [30]:
out2 = mama.forward(
  query_r=toy_data_preprocessed["input_ids_r"][0].unsqueeze(0).cuda(),
  query_g=toy_data_preprocessed["input_ids_g"][0].unsqueeze(0).cuda(),
  embedded_corpus=embedded_corpus
)

In [31]:
out2

CausalLMOutput(logits=tensor([[[ 12.9375,  -1.4219,  15.1875,  ...,  -0.8984,  -1.1953,  -1.1641],
         [  8.1250, -10.8125,  12.5625,  ..., -10.2500, -10.6875, -10.1250],
         [ 16.1250,   0.7695,  20.8750,  ...,   1.5703,   0.9414,   1.5469],
         ...,
         [ 28.3750,   3.0000,  26.1250,  ...,   3.3438,   2.9688,   3.3438],
         [ 24.3750,   1.3281,  26.8750,  ...,   1.0312,   0.6992,   1.2656],
         [  9.5000, -20.7500,  12.6250,  ..., -21.0000, -20.7500, -20.6250]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>))

In [32]:
g_tokenizer.decode(out2.logits[0].argmax(dim=-1))

'divT_\nify< a years old.\n is in a. Heonica is a English in theate.Thediv>|\n a future of will us to build the behavior. the speed scale. a to understanding the nextizational Intelligence. to is is.\nThehtml>|\n, I name is <... I am a years old and I am in a, I am a artist in the.Thediv>.\n<: following questions.The\ndiv\n_\n\ndy are yourerv?\nThe\ndiv\nignment_\n\n.The\n'

In [33]:
toy_data[0]

{'messages': [{'role': 'system', 'content': 'Answer the given question'},
  {'role': 'user', 'content': 'How old is Minh?'},
  {'role': 'assistant', 'content': '36'}]}

In [34]:
out3 = mama.retrieve(
  query_r=toy_data_preprocessed["input_ids_r"][0].unsqueeze(0).cuda(),
  embedded_corpus=embedded_corpus
)

In [35]:
out3

[(tensor([0.4365], device='cuda:0', grad_fn=<SumBackward1>), 0),
 (tensor([0.3696], device='cuda:0', grad_fn=<SumBackward1>), 2),
 (tensor([0.3556], device='cuda:0', grad_fn=<SumBackward1>), 8)]

In [36]:
memory_corpus[0]

'Minh is 36 years old. He lives in Los Angeles. Minh is an expert in Karate'

In [37]:
import torch.optim as optim

In [38]:
mama.train()
mama.cuda()

Mama(
  (retriever): MonarchI2i(
    (model): BasicModel(
      (model): HuggingFaceModel(
        (model): BertForMaskedLM(
          (bert): BertModel(
            (embeddings): BertEmbeddings(
              (word_embeddings): Embedding(30528, 768, padding_idx=0)
              (token_type_embeddings): Embedding(2, 768)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (encoder): BertEncoder(
              (layer): ModuleList(
                (0-11): 12 x BertLayer(
                  (attention): MonarchMixerSequenceMixing(
                    (filter_fn): HyenaFilter(
                      (dropout): Dropout(p=0.2, inplace=False)
                      (pos_emb): PositionalEmbedding()
                      (implicit_filter): Sequential(
                        (0): Linear(in_features=5, out_features=128, bias=True)
                        (1): Sin()
                        (

In [39]:
EPOCHS = 30
BATCH_SIZE = 1

In [40]:
optimizer1 = optim.Adam(mama.generator.parameters(), lr=5e-5)
optimizer2 = optim.Adam(mama.retriever.parameters(), lr=3e-6, weight_decay=5e-6)

In [41]:
toy_data_preprocessed

{'input_ids_r': [tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2129,  2214,
           2003, 19538,  1029,  3353,  1024,  4029]),
  tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2073,  2001,
           6520, 23422,  2141,  1029,  3353,  1024,  4880,  2237]),
  tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2054,  2003,
           8472, 14545,  1005,  1055,  2197,  2171,  1029,  3353,  1024, 16031,
           5910]),
  tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2040,  2003,
           3080,  1010,  8472, 14545,  2030, 19538,  1029,  3353,  1024, 19538]),
  tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2054,  2003,
           8472, 14545,  1005,  1055,  5440,  3609,  1029,  3353,  1024,  2630]),
  tensor([ 2291,  1024,  3437,  1996,  2445,  3160,  5310,  1024,  2073,  2515,
           8472, 14545,  2444,  1029,  3353,  1024,  2047,  2259]),
  tensor([ 2291,  1024,  3437,  1996,  2445,  3160, 

In [42]:
import numpy as np

In [43]:
exp = torch.exp

In [44]:
cos = nn.CosineSimilarity(dim=1, eps=1e-6)

In [45]:
# We finetune in an alternating pattern, first generator, then retriever
indices = np.arange(len(toy_data_preprocessed["input_ids_r"]))
for epoch in tqdm(range(EPOCHS)):
    np.random.shuffle(indices)
    print(f"Epoch {epoch}")
    print("-----------------------------------")
    for i in range(0, len(indices), BATCH_SIZE):
        index_slice = indices[i : i + BATCH_SIZE]
        batch_x_r = [toy_data_preprocessed["input_ids_r"][i] for i in index_slice]
        batch_x_g = [toy_data_preprocessed["input_ids_g"][i] for i in index_slice]
        batch_x_g_len = torch.tensor([len(i) for i in batch_x_g])

        batch_y_g = [toy_data_preprocessed["label_ids"][i] for i in index_slice]
        batch_y_g_len = torch.tensor([len(i) for i in batch_y_g])

        batch_x_r = torch.nn.utils.rnn.pad_sequence(
            batch_x_r, batch_first=True, padding_value=0
        ).cuda()
        batch_x_g = torch.nn.utils.rnn.pad_sequence(
            batch_x_g, batch_first=True, padding_value=0
        ).cuda()
        batch_y_g = torch.nn.utils.rnn.pad_sequence(
            batch_y_g, batch_first=True, padding_value=0
        ).cuda()
        print("index_slice")
        print(index_slice)

        # Optimize Generator
        top_k = mama.retrieve(batch_x_r, embedded_corpus, k=3)
        print("top_k")
        print(top_k)

        out, augmented_input_ids = mama.generate(
            query_g=batch_x_g,
            embedded_corpus=embedded_corpus,
            memory_indices=top_k,
            return_augmented_input_ids=True,
        )
        logits = out.logits
        labels = augmented_input_ids[:, 1:].cuda().contiguous()
        labels_r = batch_y_g[:, 6:].cuda().contiguous()
        # labels = batch_x_g[:, 1:].cuda().contiguous()
        shift_offset = 5 + (logits.shape[1] - ((batch_x_g.shape[1] - batch_x_g_len) + batch_y_g_len)).cuda()
        shift_logits_r = logits[:, shift_offset:-1, :].contiguous()
        shift_logits = logits[:, :-1, :].contiguous()
        # shift_offset = (logits.shape[1] - ((batch_x_g.shape[1] - batch_x_g_len))).cuda()
        print("shift_logits_shape")
        print(shift_logits.shape)
        print("-----------")
        # Decode logits
        decoded_logits = g_tokenizer.decode(logits[0].argmax(dim=-1))
        print("decoded_logits")
        print(decoded_logits)

        # Decode labels
        decoded_labels = g_tokenizer.decode(labels[0])
        print("decoded_labels")
        print(decoded_labels)
        print("---------------")

        # Sample 3 random memories
        rand_memory_indices = np.random.choice(len(memory_corpus), 3)
        rand_memory_indices = [(None, i) for i in rand_memory_indices]
        print(rand_memory_indices)

        out2 = mama.generate(batch_x_g, embedded_corpus, rand_memory_indices)
        shift_offset2 = 5 + (out2.logits.shape[1] - ((batch_x_g.shape[1] - batch_x_g_len) + batch_y_g_len)).cuda()
        shift_logits2 = out2.logits[:, shift_offset2:-1, :].contiguous()

        # Decode logits
        decoded_logits2 = g_tokenizer.decode(out2.logits[0].argmax(dim=-1))
        print("decoded_logits2")
        print(decoded_logits2)
        print("-------------")

        print("labels shape")
        print(labels.shape)
        print(labels)



        generator_loss_fn = torch.nn.CrossEntropyLoss()
        loss = generator_loss_fn(
            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
        )
        optimizer1.zero_grad()
        loss.backward(retain_graph=True)
        # Clip gradients
        torch.nn.utils.clip_grad_norm_(mama.generator.parameters(), 1.0)
        optimizer1.step()
        # Print loss
        generator_loss = loss.item()
        print(loss.item())
        # Optimize Retriever

        # Get embedding for the query
        query_r_emb = mama.retriever.model.forward(batch_x_r)

        print("shift_logits_r shape")
        print(shift_logits_r.shape)

        print("shift_logits2 shape")
        print(shift_logits2.shape)

        # Calculate s_retrieval, the probability of the generated response given the memory
        s_retrieval = F.log_softmax(shift_logits_r, dim=2).gather(dim=2, index=labels_r.unsqueeze(2)).squeeze(2).sum(dim=1)
        # Calculate s_random, the probability of the generated response given a random memory
        s_random = F.log_softmax(shift_logits2, dim=2).gather(dim=2, index=labels_r.unsqueeze(2)).squeeze(2).sum(dim=1)

        print("---------- Retriever ------------")

        print("s_retrieval shape")
        print(s_retrieval)

        print("s_random shape")
        print(s_random)

        print("s_retrieval > s_random")
        print(s_retrieval > s_random)

        # s_r1 = exp(s_retrieval  ) / (exp(s_retrieval ) + exp(s_random ))
        # s_r2 = exp(s_random  ) / (exp(s_retrieval ) + exp(s_random ))

        print("s_r1 shape")
        # print(s_r1.shape)

        print("s_r2 shape")
        # print(s_r2.shape)

        # s_r = torch.stack([s_r1, s_r2], dim=1)
        s_r = F.softmax(torch.stack([s_retrieval, s_random], dim=1), dim=1)

        print(top_k[0][1])

        # Calculate the average embedding of the items in the memory
        avg_memory_emb_r = torch.stack(
            [
                mama.retriever.model.forward(corpos_r_tokens[i[1]]["input_ids"].cuda())
                for i in top_k
            ]
        ).mean(dim=0)
        # mama.retriever.cpu()
        # avg_memory_emb_r = torch.stack(
        #     [
        #         mama.retriever.model.forward(corpos_r_tokens[i[1]]["input_ids"])
        #         for i in top_k
        #     ]
        # ).mean(dim=0)

        # Calculate the average embedding of t
        # he items in the random memory
        avg_random_emb_r = torch.stack(
            [
                mama.retriever.model.forward(corpos_r_tokens[i[1]]["input_ids"].cuda())
                for i in rand_memory_indices
            ]
        ).mean(dim=0)

        # Calculate the cosine similarity between the query and the average memory embedding
        a_memory = cos(query_r_emb, avg_memory_emb_r)

        # Calculate the cosine similarity between the query and the average random memory embedding
        a_random = cos(query_r_emb, avg_random_emb_r)

        # a_r1 = exp(a_memory ) / (exp(a_memory ) + exp(a_random ))
        # a_r2 = exp(a_random ) / (exp(a_memory ) + exp(a_random ))

        # a_r = torch.stack([a_r1, a_r2], dim=1)
        a_r = torch.softmax(torch.stack([a_memory, a_random], dim=1), dim=1)

        print("a_r shape")
        print(a_r.shape)

        print("s_r shape")
        print(s_r.shape)

        print("a_r s_r")
        print(a_r)
        print(s_r)

        # Minimize the KL divergence between a_r and s_r
        if epoch > 10:
            kl_loss = torch.nn.KLDivLoss()
            loss = kl_loss(a_r.log(), s_r.detach())
            optimizer2.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(mama.retriever.parameters(), 1.0)
            optimizer2.step()
        # Print loss
        print("Losses")
        print("-----------------------------")
        print("Generator Loss")
        print(generator_loss)
        print("KL loss")
        print(loss.item())
    # Randomize the dataset
    toy_data_preprocessed, embedded_corpus = randomize_dataset(device="cuda")


  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 0
-----------------------------------
index_slice
[3]


top_k
[(tensor([0.4068], device='cuda:0', grad_fn=<SumBackward1>), 8), (tensor([0.3816], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.3523], device='cuda:0', grad_fn=<SumBackward1>), 7)]
shift_logits_shape
torch.Size([1, 127, 50280])
-----------
decoded_logits
divT_
< < stack will us to build the memory in the level scale is a to understanding the nextizational Intelligence. to is is.
Thehtml>|
se's a years old. She is in a. Heonica is a artist in theate.Thediv>|

 the to make up alive
 <areth Rff
 the audience. the could a democracy freedom's ability to keep revenue.
Thediv>.
<: following question.The
div
_

 is <? you or or K??
The
div
ignment.

istry-The

decoded_labels
|memory|>Building the technology that allows you to simulate organizational dynamics at the computational level is key to training the Organizational AI and this vision overall.<|endoftext|><|memory|>Veronica is 58 years old. He lives in London. Veronica is an expert in Karate<|endoftext|><|memory|>AI hur

100%|██████████| 10/10 [00:00<00:00, 1805.94it/s]
  3%|▎         | 1/30 [01:07<32:50, 67.94s/it]

Epoch 1
-----------------------------------
index_slice
[5]
top_k
[(tensor([0.5825], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4791], device='cuda:0', grad_fn=<SumBackward1>), 21), (tensor([0.4115], device='cuda:0', grad_fn=<SumBackward1>), 8)]
shift_logits_shape
torch.Size([1, 108, 50280])
-----------
decoded_logits
|memory|>Ver, my name is Joe Gumbs. I am 36 years old. I live in Paris. I am an expert in Javascript<|endoftext|>
|memory|>Buildingice is Adventures color is Black<|endoftext|>
|memory|>Building the technology that allows you to simulate organizational dynamics at the computational level is key to training the Organizational AI and this vision overall.<|endoftext|>
|memory|>
Answer the given question<|endoftext|>
<|ass|>
What was Kane live?<|endoftext|>
<|assistant|>
Whatape<|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Alice Gumbs. I am 32 years old. I live in Cairo. I am an expert in Javascript<|endoftext|><|memory|>Alice's favorite color is Bla

100%|██████████| 10/10 [00:00<00:00, 1713.92it/s]
  7%|▋         | 2/30 [02:11<30:26, 65.22s/it]

Epoch 2
-----------------------------------
index_slice
[6]
top_k
[(tensor([0.5444], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.5010], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4354], device='cuda:0', grad_fn=<SumBackward1>), 21)]
shift_logits_shape
torch.Size([1, 108, 50280])
-----------
decoded_logits
|memory|>Thomachel is 53 years old. He lives in Berlin Town. Rachel is an expert in Karate<|endoftext|><|memory|>On, my name is Aliceumbs..umbs. I am 32 years old. I live in Cairo. I am an expert in Python<|endoftext|><|system|>Onumbs is is favorite color is Black<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where is older, Thomson or Thomsonumbsema?<|endoftext|>
<|assistant|>
Thomio<|endoftext|>
<
decoded_labels
|memory|>Rachel is 20 years old. He lives in Cape Town. Rachel is an expert in Karate<|endoftext|><|memory|>Hello, my name is Ganeesh Gumbs. I am 18 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|me

100%|██████████| 10/10 [00:00<00:00, 2108.86it/s]
 10%|█         | 3/30 [03:12<28:37, 63.61s/it]

Epoch 3
-----------------------------------
index_slice
[8]
top_k
[(tensor([0.7717], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4409], device='cuda:0', grad_fn=<SumBackward1>), 7), (tensor([0.4126], device='cuda:0', grad_fn=<SumBackward1>), 8)]
shift_logits_shape
torch.Size([1, 124, 50280])
-----------
decoded_logits
|memory|>G is 20 years old. He lives in Capeos. Adam is an expert in Karate<|endoftext|><|memory|>G hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|system|>G the technology that allows you to simulate organizational dynamics at the computational level is key to training the Organizational AI and this vision overall.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What does G live?<|endoftext|><<|assistant|>
G<|endoftext|> in Capeos<|endoftext|><<
decoded_labels
|memory|>Adam is 55 years old. He lives in Lagos. Adam is an expert in Karate<|

100%|██████████| 10/10 [00:00<00:00, 1981.44it/s]
 13%|█▎        | 4/30 [04:12<26:50, 61.94s/it]

Epoch 4
-----------------------------------
index_slice
[9]
top_k
[(tensor([0.7121], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.5972], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.4693], device='cuda:0', grad_fn=<SumBackward1>), 2)]
shift_logits_shape
torch.Size([1, 113, 50280])
-----------
decoded_logits
|memory|>Adam is 55 years old. He lives in Lag York. Joe is an expert in Karate<|endoftext|><|memory|>Building 1989, Adam was born in Lag. He is 56 years old. Joe is an expert in Python<|endoftext|><|memory|>Building, my name is Jorge Gumbs. I am 18 years old. I live in Gja. I am an expert in Javascript<|endoftext|><|memory|>
Answer the given question<|endoftext|>
<|user|>
What old is Jorge?<|endoftext|>
<|assistant|>
R<|endoftext|><<
decoded_labels
|memory|>Joe is 49 years old. He lives in New York. Joe is an expert in Karate<|endoftext|><|memory|>On 1989, Joe was born in Moscow. He is 19 years old. Joe is an expert in Python<|endoftext|><|memory|>Hello, my 

100%|██████████| 10/10 [00:00<00:00, 1804.86it/s]
 17%|█▋        | 5/30 [05:04<24:21, 58.48s/it]

Epoch 5
-----------------------------------
index_slice
[3]
top_k
[(tensor([0.5372], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4999], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4033], device='cuda:0', grad_fn=<SumBackward1>), 7)]
shift_logits_shape
torch.Size([1, 120, 50280])
-----------
decoded_logits
|memory|>Joe is 49 years old. He lives in New. Donald is an expert in Karate<|endoftext|><|memory|>On, my name is Jorge Gumbs. I am 36 years old. I live in Abu. I am an expert in Javascript<|endoftext|><|memory|>On hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|>
|system|>
Answer the given question<|endoftext|>
<|user|>
Where is older, Joe or Joe?<|endoftext|>
<|assistant|>
G<|endoftext|>
<
decoded_labels
|memory|>Donald is 28 years old. He lives in Cairo. Donald is an expert in Karate<|endoftext|><|memory|>Hello, my name is Peter Gumbs. I am 25 years old. I live 

100%|██████████| 10/10 [00:00<00:00, 1697.41it/s]
 20%|██        | 6/30 [05:58<22:45, 56.91s/it]

Epoch 6
-----------------------------------
index_slice
[9]
top_k
[(tensor([0.6524], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.4338], device='cuda:0', grad_fn=<SumBackward1>), 7), (tensor([0.3994], device='cuda:0', grad_fn=<SumBackward1>), 8)]
shift_logits_shape
torch.Size([1, 122, 50280])
-----------
decoded_logits
|memory|>Donald 1989, Joe was born in Moscow. He is 19 years old. Adam is an expert in Python<|endoftext|><|memory|>AI hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|memory|>Hello the technology that allows you to simulate organizational dynamics at the computational level is key to training the Organizational AI and this vision overall.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What old is Adam?<|endoftext|>
<|assistant|>
Adam<|endoftext|>
<
decoded_labels
|memory|>On 1989, Adam was born in London. He is 40 years old. Adam is an ex

100%|██████████| 10/10 [00:00<00:00, 1845.03it/s]
 23%|██▎       | 7/30 [06:55<21:50, 56.98s/it]

Epoch 7
-----------------------------------
index_slice
[3]
top_k
[(tensor([0.6511], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4448], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3496], device='cuda:0', grad_fn=<SumBackward1>), 8)]
shift_logits_shape
torch.Size([1, 124, 50280])
-----------
decoded_logits
|memory|>On is 28 years old. He lives in Moscowja. Peter is an expert in Karate<|endoftext|><|memory|>On, my name is Peter G Gumbs. I am an years old. I live in New. I am an expert in Javascript<|endoftext|><|memory|>Building the technology that allows you to simulate organizational dynamics at the computational level is key to training the Organizational AI and this vision overall.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
How is older, Peter or or Peter?<|endoftext|>
<|assistant|>
Peter<|endoftext|>
<
decoded_labels
|memory|>Peter is 50 years old. He lives in Abuja. Peter is an expert in Karate<|endoftext|><|memory|>Hello, my na

100%|██████████| 10/10 [00:00<00:00, 2034.00it/s]
 27%|██▋       | 8/30 [07:50<20:35, 56.17s/it]

Epoch 8
-----------------------------------
index_slice
[0]
top_k
[(tensor([0.6025], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4105], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3542], device='cuda:0', grad_fn=<SumBackward1>), 20)]
shift_logits_shape
torch.Size([1, 104, 50280])
-----------
decoded_logits
|memory|>Hello is 49 years old. He lives in Abu. Joe is an expert in Karate<|endoftext|><|memory|>Hello, my name is Samolaumbs. I am 33 years old. I live in Parisos. I am an expert in Javascript<|endoftext|><|memory|>Hello are right. The document may become large by large list.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What old is Joe?<|endoftext|>
<|assistant|>
43<|endoftext|>
<
decoded_labels
|memory|>Joe is 43 years old. He lives in Cairo. Joe is an expert in Karate<|endoftext|><|memory|>Hello, my name is Nic Gumbs. I am 39 years old. I live in Lagos. I am an expert in Javascript<|endoftext|><|memory|>You are right. Query docu

100%|██████████| 10/10 [00:00<00:00, 1922.67it/s]
 30%|███       | 9/30 [08:43<19:19, 55.19s/it]

Epoch 9
-----------------------------------
index_slice
[8]
top_k
[(tensor([0.7653], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.3991], device='cuda:0', grad_fn=<SumBackward1>), 7), (tensor([0.3543], device='cuda:0', grad_fn=<SumBackward1>), 2)]
shift_logits_shape
torch.Size([1, 125, 50280])
-----------
decoded_logits
|memory|>Joeorge is 39 years old. He lives in Cairo Town. Jerry is an expert in Karate<|endoftext|><|memory|>Hello hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|memory|>Hello, my name is Nic Gumbs. I am 39 years old. I live in Lagja. I am an expert in Javascript<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where does Adam live?<|endoftext|>
<|assistant|>
Aborge lives in Abu Town<|endoftext|>
<
decoded_labels
|memory|>Jerry is 54 years old. He lives in Cape Town. Jerry is an expert in Karate<|endoftext|><|memory|>AI hurts ability to ke

100%|██████████| 10/10 [00:00<00:00, 1556.27it/s]
 33%|███▎      | 10/30 [09:35<18:04, 54.24s/it]

Epoch 10
-----------------------------------
index_slice
[5]
top_k
[(tensor([0.5200], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4379], device='cuda:0', grad_fn=<SumBackward1>), 7), (tensor([0.4232], device='cuda:0', grad_fn=<SumBackward1>), 21)]
shift_logits_shape
torch.Size([1, 107, 50280])
-----------
decoded_logits
|memory|>J, my name is Adam Gumbs. I am 52 years old. I live in Abu. I am an expert in Javascript<|endoftext|><|memory|>J hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|memory|>Jerry's favorite color is Blue<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What was Jerry live?<|endoftext|>
<|assistant|>
Cairo<|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Jerry Gumbs. I am 31 years old. I live in Cairo. I am an expert in Javascript<|endoftext|><|memory|>AI hurts ability to keep democracy alive - Gina Neff told the BBC that AI i

100%|██████████| 10/10 [00:00<00:00, 1771.32it/s]
 37%|███▋      | 11/30 [10:26<16:50, 53.21s/it]

Epoch 11
-----------------------------------
index_slice
[2]
top_k
[(tensor([0.5623], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.5558], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4223], device='cuda:0', grad_fn=<SumBackward1>), 21)]
shift_logits_shape
torch.Size([1, 101, 50280])
-----------
decoded_logits
|memory|>Hello, my name is Jerry Gumbs. I am 31 years old. I live in Cairoja. I am an expert in Javascript<|endoftext|><|memory|>Hello's 46 years old. He lives in Cairoja. Paul is an expert in Karate<|endoftext|><|memory|>Hello's favorite color is Green<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is Paul's last name?<|endoftext|>
<|assistant|>
Gumbs<|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Paul Gumbs. I am 24 years old. I live in Abuja. I am an expert in Javascript<|endoftext|><|memory|>Paul is 32 years old. He lives in Abuja. Paul is an expert in Karate<|endoftext|><|memory|>Paul's favorite color is White<|en



Losses
-----------------------------
Generator Loss
0.6328125
KL loss
0.042722851037979126
index_slice
[1]
top_k
[(tensor([0.7457], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.4182], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3695], device='cuda:0', grad_fn=<SumBackward1>), 7)]
shift_logits_shape
torch.Size([1, 121, 50280])
-----------
decoded_logits
|memory|>Al 1989, Peter was born in Cairo. He is 48 years old. Alice is an expert in Python<|endoftext|><|memory|>Hello, my name is Jerry Gumbs. I am 31 years old. I live in Newja. I am an expert in Javascript<|endoftext|><|memory|>Hello hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What was Alice born?<|endoftext|>
<|assistant|>
Paris<|endoftext|>
<
decoded_labels
|memory|>On 1989, Alice was born in Paris. He is 43 years old. Alice is an expert in Python<|e

100%|██████████| 10/10 [00:00<00:00, 1374.87it/s]
 40%|████      | 12/30 [11:50<18:49, 62.75s/it]

Epoch 12
-----------------------------------
index_slice
[9]
top_k
[(tensor([0.7634], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.4764], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4221], device='cuda:0', grad_fn=<SumBackward1>), 2)]
shift_logits_shape
torch.Size([1, 115, 50280])
-----------
decoded_logits
|memory|>Al 1989, Peter was born in Newos. He is 24 years old. Django is an expert in Python<|endoftext|><|memory|>Helloerryema is 46 years old. He lives in New. Kaneema is an expert in Karate<|endoftext|><|memory|>Hello, my name is Paul Gumbs. I am 31 years old. I live in Abu. I am an expert in Javascript<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
How old is Django?<|endoftext|>
<|assistant|>
37<|endoftext|>
<
decoded_labels
|memory|>On 1989, Django was born in Lagos. He is 37 years old. Django is an expert in Python<|endoftext|><|memory|>Kaneema is 37 years old. He lives in Berlin. Kaneema is an expert in Karate<|endoftext|><|me

100%|██████████| 10/10 [00:00<00:00, 1799.05it/s]
 43%|████▎     | 13/30 [13:08<19:04, 67.30s/it]

Epoch 13
-----------------------------------
index_slice
[4]
top_k
[(tensor([0.6793], device='cuda:0', grad_fn=<SumBackward1>), 21), (tensor([0.4480], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3230], device='cuda:0', grad_fn=<SumBackward1>), 11)]
shift_logits_shape
torch.Size([1, 96, 50280])
-----------
decoded_logits
|memory|>Paul's favorite color is White<|endoftext|><|memory|>AI, my name is Nic Gumbs. I am 31 years old. I live in London. I am an expert in Javascript<|endoftext|><|memory|>AI can use the Python json module to pretty-print the JSON data. <|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is Nic's favorite color?<|endoftext|>
<|assistant|>
White<|endoftext|>
<
decoded_labels
|memory|>Nic's favorite color is Red<|endoftext|><|memory|>Hello, my name is Nic Gumbs. I am 23 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|memory|>We can use the Python json module to pretty-print the JSON data. <|endoftext|><|sys

100%|██████████| 10/10 [00:00<00:00, 1820.13it/s]
 47%|████▋     | 14/30 [14:25<18:43, 70.22s/it]

Epoch 14
-----------------------------------
index_slice
[4]
top_k
[(tensor([0.7659], device='cuda:0', grad_fn=<SumBackward1>), 21), (tensor([0.6473], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.5229], device='cuda:0', grad_fn=<SumBackward1>), 2)]
shift_logits_shape
torch.Size([1, 112, 50280])
-----------
decoded_logits
|memory|>Helloaneesh is favorite color is Red<|endoftext|><|memory|>AI 1989, Donaldaneesh was born in Berlin. He is 37 years old. Ganeesh is an expert in Python<|endoftext|><|memory|>AI, my name is Nicaneesh Gumbs. I am 39 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is Ganeesh's favorite color?<|endoftext|>
<|assistant|>
Redellow<|endoftext|>
<
decoded_labels
|memory|>Ganeesh's favorite color is Yellow<|endoftext|><|memory|>On 1989, Ganeesh was born in Berlin. He is 39 years old. Ganeesh is an expert in Python<|endoftext|><|memory|>Hello, my name is Ganeesh Gumbs. I am

100%|██████████| 10/10 [00:00<00:00, 1727.90it/s]
 50%|█████     | 15/30 [15:40<17:53, 71.56s/it]

Epoch 15
-----------------------------------
index_slice
[2]
top_k
[(tensor([0.6075], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4713], device='cuda:0', grad_fn=<SumBackward1>), 21), (tensor([0.4023], device='cuda:0', grad_fn=<SumBackward1>), 7)]
shift_logits_shape
torch.Size([1, 113, 50280])
-----------
decoded_logits
|memory|>On, my name is Gema Gumbs. I am 22 years old. I live in Berlinos. I am an expert in Javascript<|endoftext|><|memory|>Kaneema's favorite color is Yellow<|endoftext|><|memory|>K hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
How is Kaneema's favorite name?<|endoftext|>
<|assistant|>
Gumbs<|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Kaneema Gumbs. I am 18 years old. I live in Lagos. I am an expert in Javascript<|endoftext|><|memory|>Kaneema's favorite color is Purple<|endoftext|><|mem

100%|██████████| 10/10 [00:00<00:00, 1725.62it/s]
 53%|█████▎    | 16/30 [16:57<17:05, 73.25s/it]

Epoch 16
-----------------------------------
index_slice
[9]
top_k
[(tensor([0.6108], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.5215], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4811], device='cuda:0', grad_fn=<SumBackward1>), 0)]
shift_logits_shape
torch.Size([1, 114, 50280])
-----------
decoded_logits
|memory|>On 1989, Gantha born in Paris. He is 55 years old. Sam is an expert in Python<|endoftext|><|memory|>Hello, my name is Kane Gumbs. I am 18 years old. I live in Lagos. I am an expert in Javascript<|endoftext|><|memory|>Helloice's 18 years old. He lives in Cairoja. Alice is an expert in Karate<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where old is Samantha<|endoftext|>
<|assistant|>
47<|endoftext|>
<
decoded_labels
|memory|>On 1989, Sam was born in Cairo. He is 47 years old. Sam is an expert in Python<|endoftext|><|memory|>Hello, my name is Alice Gumbs. I am 36 years old. I live in Lagos. I am an expert in Javascript<|endof

100%|██████████| 10/10 [00:00<00:00, 2121.44it/s]
 57%|█████▋    | 17/30 [18:11<15:54, 73.42s/it]

Epoch 17
-----------------------------------
index_slice
[8]
top_k
[(tensor([0.8334], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.5627], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4716], device='cuda:0', grad_fn=<SumBackward1>), 1)]
shift_logits_shape
torch.Size([1, 119, 50280])
-----------
decoded_logits
|memory|>Al's 39 years old. He lives in Abuja. Nic is an expert in Karate<|endoftext|><|memory|>Hello, my name is Nic Gumbs. I am 36 years old. I live in Lag. I am an expert in Javascript<|endoftext|><|memory|>Nic 1989, Nic was was born in Parisja. He is 55 years old. Kaneema is an expert in Python<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What does Nic live?<|endoftext|>
<|assistant|>
Lag lives in Abuja<|endoftext|>
<
decoded_labels
|memory|>Nic is 57 years old. He lives in Abuja. Nic is an expert in Karate<|endoftext|><|memory|>Hello, my name is Nic Gumbs. I am 26 years old. I live in Berlin. I am an expert in Javascript<|endof

100%|██████████| 10/10 [00:00<00:00, 1894.45it/s]
 60%|██████    | 18/30 [19:28<14:56, 74.71s/it]

Epoch 18
-----------------------------------
index_slice
[3]
top_k
[(tensor([0.6405], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.6162], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3993], device='cuda:0', grad_fn=<SumBackward1>), 1)]
shift_logits_shape
torch.Size([1, 115, 50280])
-----------
decoded_logits
|memory|>Nic is 57 years old. He lives in Abuos. Peter is an expert in Karate<|endoftext|><|memory|>Nic, my name is Peter Gumbs. I am 26 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Kane was born in Abu. He is 56 years old. Kane is an expert in Python<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is older, Kane or Peter?<|endoftext|>
<|assistant|>
Peter<|endoftext|>
<
decoded_labels
|memory|>Peter is 54 years old. He lives in Lagos. Peter is an expert in Karate<|endoftext|><|memory|>Hello, my name is Joe Gumbs. I am 41 years old. I live in Moscow. I am an expert in Javascript<|endoftex

100%|██████████| 10/10 [00:00<00:00, 1587.25it/s]
 63%|██████▎   | 19/30 [20:45<13:47, 75.26s/it]

Epoch 19
-----------------------------------
index_slice
[6]
top_k
[(tensor([0.7687], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4901], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4232], device='cuda:0', grad_fn=<SumBackward1>), 7)]
shift_logits_shape
torch.Size([1, 123, 50280])
-----------
decoded_logits
|memory|>Hello is 57 years old. He lives in Lag York. Donald is an expert in Karate<|endoftext|><|memory|>Hello, my name is Joeonica Gumbs. I am 41 years old. I live in Moscow. I am an expert in Javascript<|endoftext|><|memory|>On hurts ability to keep democracy alive - Gina Neff told the BBC that AI is damaging media organisation's ability to generate profits.<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is older, Donald or Veronica?<|endoftext|>
<|assistant|>
Donald<|endoftext|>
<
decoded_labels
|memory|>Donald is 45 years old. He lives in New York. Donald is an expert in Karate<|endoftext|><|memory|>Hello, my name is Veronic

100%|██████████| 10/10 [00:00<00:00, 1541.57it/s]
 67%|██████▋   | 20/30 [22:02<12:39, 75.93s/it]

Epoch 20
-----------------------------------
index_slice
[1]
top_k
[(tensor([0.7193], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.5271], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4367], device='cuda:0', grad_fn=<SumBackward1>), 0)]
shift_logits_shape
torch.Size([1, 118, 50280])
-----------
decoded_logits
|memory|>On 1989, Niconica was born in Cairo. He is 20 years old. Veronica is an expert in Python<|endoftext|><|memory|>Hello, my name is Verema Gumbs. I am 20 years old. I live in London. I am an expert in Javascript<|endoftext|><|memory|>AIson is 54 years old. He lives in New York. Thom is an expert in Karate<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where does Veronica born?<|endoftext|>
<|assistant|>
Cairo<|endoftext|>
<
decoded_labels
|memory|>On 1989, Veronica was born in Cairo. He is 40 years old. Veronica is an expert in Python<|endoftext|><|memory|>Hello, my name is Kaneema Gumbs. I am 25 years old. I live in Berlin. I a

100%|██████████| 10/10 [00:00<00:00, 2090.78it/s]
 70%|███████   | 21/30 [23:25<11:40, 77.87s/it]

Epoch 21
-----------------------------------
index_slice
[8]
top_k
[(tensor([0.7970], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4565], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4155], device='cuda:0', grad_fn=<SumBackward1>), 1)]
shift_logits_shape
torch.Size([1, 120, 50280])
-----------
decoded_logits
|memory|>Helloh is Ver years old. He lives in New York. Minh is an expert in Karate<|endoftext|><|memory|>Hello, my name is Kane Gumbs. I am 25 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|system|>On 1989, Ver was born in Cairo. He is 40 years old. Donald is an expert in Python<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where does Minh live?<|endoftext|>
<|assistant|>
Minh lives in New York<|endoftext|>
<
decoded_labels
|memory|>Minh is 19 years old. He lives in New York. Minh is an expert in Karate<|endoftext|><|memory|>Hello, my name is Thomson Gumbs. I am 18 years old. I live in London. I am an expert

100%|██████████| 10/10 [00:00<00:00, 1439.76it/s]
 73%|███████▎  | 22/30 [24:43<10:24, 78.01s/it]

Epoch 22
-----------------------------------
index_slice
[4]
top_k
[(tensor([0.6146], device='cuda:0', grad_fn=<SumBackward1>), 21), (tensor([0.5327], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3849], device='cuda:0', grad_fn=<SumBackward1>), 1)]
shift_logits_shape
torch.Size([1, 102, 50280])
-----------
decoded_logits
|memory|>Thom is favorite color is Green<|endoftext|><|memory|>Hello, my name is Kane Gumbs. I am 18 years old. I live in London. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Donald was born in London. He is 29 years old. Django is an expert in Python<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is Sam's favorite color?<|endoftext|>
<|assistant|>
Greenple<|endoftext|>
<
decoded_labels
|memory|>Sam's favorite color is Purple<|endoftext|><|memory|>Hello, my name is Sam Gumbs. I am 31 years old. I live in Moscow. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Django was born in Cairo. He is 46 years ol

100%|██████████| 10/10 [00:00<00:00, 1594.85it/s]
 77%|███████▋  | 23/30 [25:59<09:02, 77.53s/it]

Epoch 23
-----------------------------------
index_slice
[6]
top_k
[(tensor([0.6023], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.5752], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.3335], device='cuda:0', grad_fn=<SumBackward1>), 1)]
shift_logits_shape
torch.Size([1, 118, 50280])
-----------
decoded_logits
|memory|>Minerry is 32 years old. He lives in Moscow. Jorge is an expert in Karate<|endoftext|><|memory|>Hello, my name is Sam Gumbs. I am 18 years old. I live in Moscowja. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Django was born in Londonja. He is 46 years old. Rachel is an expert in Python<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
What is older, Sam or Paul?<|endoftext|>
<|assistant|>
Jerry<|endoftext|>
<
decoded_labels
|memory|>Jorge is 21 years old. He lives in Moscow. Jorge is an expert in Karate<|endoftext|><|memory|>Hello, my name is Paul Gumbs. I am 21 years old. I live in Abuja. I am an expert in Javas

100%|██████████| 10/10 [00:00<00:00, 1670.17it/s]
 80%|████████  | 24/30 [27:19<07:48, 78.08s/it]

Epoch 24
-----------------------------------
index_slice
[2]
top_k
[(tensor([0.6235], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.4732], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.4162], device='cuda:0', grad_fn=<SumBackward1>), 0)]
shift_logits_shape
torch.Size([1, 115, 50280])
-----------
decoded_logits
|memory|>Hello, my name is Sam Gumbs. I am 31 years old. I live in Abu. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Rachel was born in Abu. He is 53 years old. Rachel is an expert in Python<|endoftext|><|memory|>J is 21 years old. He lives in Moscow Town. Peter is an expert in Karate<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where is Jerry's last name?<|endoftext|>
<|assistant|>
Gumbs<|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Jerry Gumbs. I am 31 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Rachel was born in Berlin. He is 55 years old. Rachel is an expert

100%|██████████| 10/10 [00:00<00:00, 1866.46it/s]
 83%|████████▎ | 25/30 [28:41<06:35, 79.16s/it]

Epoch 25
-----------------------------------
index_slice
[5]
top_k
[(tensor([0.6669], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.5012], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.4702], device='cuda:0', grad_fn=<SumBackward1>), 0)]
shift_logits_shape
torch.Size([1, 115, 50280])
-----------
decoded_logits
|memory|>On, my name is Jerry Gumbs. I am 31 years old. I live in Cape. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Rachel was born in Berlin. He is 55 years old. Rachel is an expert in Python<|endoftext|><|memory|>Peterorgeesh's 18 years old. He lives in Cape. Gumbsesh is an expert in Karate<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
How does Rachel live?<|endoftext|>
<|assistant|>
Ber<|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Sam Gumbs. I am 28 years old. I live in London. I am an expert in Javascript<|endoftext|><|memory|>On 1989, Rachel was born in Berlin. He is 55 years old. Rachel is an expert 

100%|██████████| 10/10 [00:00<00:00, 1440.20it/s]
 87%|████████▋ | 26/30 [30:02<05:19, 79.78s/it]

Epoch 26
-----------------------------------
index_slice
[0]
top_k
[(tensor([0.7816], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.5556], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.5294], device='cuda:0', grad_fn=<SumBackward1>), 2)]
shift_logits_shape
torch.Size([1, 116, 50280])
-----------
decoded_logits
|memory|>Onaneema is 37 years old. He lives in Capeos. Kaneema is an expert in Karate<|endoftext|><|memory|>Hello 1989, Rachel was born in Berlin. He is 55 years old. Alice is an expert in Python<|endoftext|><|memory|>Hello, my name is Sam Gumbs. I am 31 years old. I live in Berlin. I am an expert in Javascript<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Who old is Kaneema?<|endoftext|>
<|assistant|>
58<|endoftext|>
<
decoded_labels
|memory|>Kaneema is 58 years old. He lives in Lagos. Kaneema is an expert in Karate<|endoftext|><|memory|>On 1989, Alice was born in Paris. He is 50 years old. Alice is an expert in Python<|endoftext|><

100%|██████████| 10/10 [00:00<00:00, 1908.50it/s]
 90%|█████████ | 27/30 [31:25<04:02, 80.75s/it]

Epoch 27
-----------------------------------
index_slice
[9]
top_k
[(tensor([0.6673], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.5488], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.5071], device='cuda:0', grad_fn=<SumBackward1>), 0)]
shift_logits_shape
torch.Size([1, 113, 50280])
-----------
decoded_logits
|memory|>G 1989, Alice was born in Paris. He is 55 years old. Nic is an expert in Python<|endoftext|><|memory|>Hello, my name is Sam Gumbs. I am 30 years old. I live in Lag. I am an expert in Javascript<|endoftext|><|memory|>Onson is 49 years old. He lives in Lag Angeles. Thomson is an expert in Karate<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
How old is Nic?<|endoftext|>
<|assistant|>
20<|endoftext|>
<
decoded_labels
|memory|>On 1989, Nic was born in Cairo. He is 20 years old. Nic is an expert in Python<|endoftext|><|memory|>Hello, my name is Peter Gumbs. I am 22 years old. I live in London. I am an expert in Javascript<|endofte

100%|██████████| 10/10 [00:00<00:00, 1579.06it/s]
 93%|█████████▎| 28/30 [32:44<02:40, 80.39s/it]

Epoch 28
-----------------------------------
index_slice
[6]
top_k
[(tensor([0.6935], device='cuda:0', grad_fn=<SumBackward1>), 2), (tensor([0.6890], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.4850], device='cuda:0', grad_fn=<SumBackward1>), 1)]
shift_logits_shape
torch.Size([1, 121, 50280])
-----------
decoded_logits
|memory|>Thom, my name is Peter Gumbs. I am 22 years old. I live in Londonos. I am an expert in Javascript<|endoftext|><|memory|>Thom is is 55 years old. He lives in Paris. Samantha is an expert in Karate<|endoftext|><|memory|>On 1989, Niconica was born in Cairo. He is 40 years old. Veronica is an expert in Python<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where is older, Adamantha or Adam?<|endoftext|>
<|assistant|>
Sam<|endoftext|><|endoftext|>
<
decoded_labels
|memory|>Hello, my name is Adam Gumbs. I am 19 years old. I live in Lagos. I am an expert in Javascript<|endoftext|><|memory|>Samantha is 30 years old. He lives in Moscow

100%|██████████| 10/10 [00:00<00:00, 2095.79it/s]
 97%|█████████▋| 29/30 [34:00<01:19, 79.08s/it]

Epoch 29
-----------------------------------
index_slice
[1]
top_k
[(tensor([0.7139], device='cuda:0', grad_fn=<SumBackward1>), 1), (tensor([0.5782], device='cuda:0', grad_fn=<SumBackward1>), 0), (tensor([0.5606], device='cuda:0', grad_fn=<SumBackward1>), 2)]
shift_logits_shape
torch.Size([1, 115, 50280])
-----------
decoded_logits
|memory|>Hello 1989, Ver was born in Moscowja. He is 20 years old. Thomson is an expert in Python<|endoftext|><|memory|>Hello is 32 years old. He lives in Capeos. Donald is an expert in Karate<|endoftext|><|memory|>Hello, my name is Adam Gumbs. I am 19 years old. I live in Lag. I am an expert in Javascript<|endoftext|><|system|>
Answer the given question<|endoftext|>
<|user|>
Where was Thomson born?<|endoftext|>
<|assistant|>
Abuja<|endoftext|>
<
decoded_labels
|memory|>On 1989, Thomson was born in Abuja. He is 28 years old. Thomson is an expert in Python<|endoftext|><|memory|>Donald is 24 years old. He lives in Lagos. Donald is an expert in Karate<|endoftex

100%|██████████| 10/10 [00:00<00:00, 1488.03it/s]
100%|██████████| 30/30 [35:18<00:00, 70.61s/it]


In [5]:
sum([(1/100) * (max(n-1, 0) / 99) for n in range(1, 101)])

0.5

In [3]:
weighted_sum = 0
for n in range(1, 101):
  print(f"probability for n={n}")
  prob = (max(n-1, 0) / 99)
  print(prob)
  weighted_sum += (1/101) * prob
weighted_sum

probability for n=0
0.0
probability for n=1
0.0
probability for n=2
0.010101010101010102
probability for n=3
0.020202020202020204
probability for n=4
0.030303030303030304
probability for n=5
0.04040404040404041
probability for n=6
0.050505050505050504
probability for n=7
0.06060606060606061
probability for n=8
0.0707070707070707
probability for n=9
0.08080808080808081
probability for n=10
0.09090909090909091
probability for n=11
0.10101010101010101
probability for n=12
0.1111111111111111
probability for n=13
0.12121212121212122
probability for n=14
0.13131313131313133
probability for n=15
0.1414141414141414
probability for n=16
0.15151515151515152
probability for n=17
0.16161616161616163
probability for n=18
0.1717171717171717
probability for n=19
0.18181818181818182
probability for n=20
0.1919191919191919
probability for n=21
0.20202020202020202
probability for n=22
0.21212121212121213
probability for n=23
0.2222222222222222
probability for n=24
0.23232323232323232
probability for n=2

0.4950495049504948

In [46]:
top_k

[(tensor([0.7401], device='cuda:0', grad_fn=<SumBackward1>), 1),
 (tensor([0.6285], device='cuda:0', grad_fn=<SumBackward1>), 0),
 (tensor([0.4782], device='cuda:0', grad_fn=<SumBackward1>), 2)]

In [47]:
corpos_r_tokens[2]

{'input_ids': tensor([[  101,  7592,  1010,  2026,  2171,  2003,  8472, 14545, 16031,  5910,
          1012,  1045,  2572,  2676,  2086,  2214,  1012,  1045,  2444,  1999,
          2047,  2259,  1012,  1045,  2572,  2019,  6739,  1999,  9262, 22483,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1]])}

In [48]:
out3 = mama.forward(
  query_r=toy_data_preprocessed["input_ids_r"][0].unsqueeze(0).cuda(),
  query_g=toy_data_preprocessed["input_ids_g"][0].unsqueeze(0).cuda(),
  embedded_corpus=embedded_corpus
)

In [49]:
torch.save(mama.state_dict(), "./mama_toy.pt")

In [50]:
out3

CausalLMOutput(logits=tensor([[[  6.0000,  -9.5625,   7.0625,  ...,  -9.6875,  -9.5000,  -9.6875],
         [ 11.5000,  -5.4062,  11.8750,  ...,  -5.3438,  -4.9375,  -4.6562],
         [  8.2500, -11.5000,   7.5938,  ..., -11.1875, -11.5625, -11.4375],
         ...,
         [ 31.7500,  -9.8750,  11.9375,  ..., -10.0625, -10.5625, -10.5000],
         [ 25.2500,  -5.3125,  15.1875,  ...,  -4.6250,  -4.8750,  -4.7812],
         [ 21.2500,  -2.9688,  15.9375,  ...,  -2.4531,  -2.8125,  -2.7188]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>))

In [51]:
s = "'<|system|>\nAnswer the given question<|endoftext|>\n<|user|>\nHow old is Minh?<|endoftext|>\n"
s_tokens = g_tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")
s_tokens_r = r_tokenizer.encode(s, add_special_tokens=False, return_tensors="pt")
s_tokens.shape

torch.Size([1, 25])

In [52]:
s_tokens.cuda()

tensor([[    8,    29,    93, 10394, 49651,   187, 32869,   253,  1677,  1953,
             0,   187,    29,    93,  4537, 49651,   187,  2347,  1711,   310,
          3689,    73,    32,     0,   187]], device='cuda:0')

In [53]:
out3s = mama.forward(
  query_r=s_tokens_r.cuda(),
  query_g=s_tokens.cuda(),
  embedded_corpus=embedded_corpus
)

In [54]:
topks = mama.retrieve(
  query_r=s_tokens_r.cuda(),
  embedded_corpus=embedded_corpus
)
topks

[(tensor([0.4305], device='cuda:0', grad_fn=<SumBackward1>), 1),
 (tensor([0.4060], device='cuda:0', grad_fn=<SumBackward1>), 2),
 (tensor([0.3624], device='cuda:0', grad_fn=<SumBackward1>), 0)]

In [55]:
out3s = mama.generate(
  query_g=s_tokens.cuda(),
  embedded_corpus=embedded_corpus,
  memory_indices=topks
)

In [56]:
g_tokenizer.decode(out3s.logits[0].argmax(dim=-1))

'|memory|>Adam 1989, Adamantha was born in Abu. He is 26 years old. Samantha is an expert in Python<|endoftext|><|memory|>Hello, my name is Thomsonanthaumbs. I am 19 years old. I live in Lag. I am an expert in Javascript<|endoftext|><|memory|>Adam is 41 years old. He lives in Parisja. Donald is an expert in Karate<|endoftext|><<|user|>\nAnswer the given question<|endoftext|>\n<|user|>\nHow old is Samh?<|endoftext|>\n<'

In [57]:
decoded_logits

'|memory|>Adam 1989, Ver was born in Lagos. He is 26 years old. Adam is an expert in Python<|endoftext|><|memory|>Helloachel is 23 years old. He lives in Paris. Rachel is an expert in Karate<|endoftext|><|memory|>Hello, my name is Thomson Gumbs. I am 21 years old. I live in Paris. I am an expert in Javascript<|endoftext|><|system|>\nAnswer the given question<|endoftext|>\n<|user|>\nHow was Adam born?<|endoftext|>\n<|assistant|>\nLagos<|endoftext|>\n<'

In [58]:
rand_memory_indices = np.random.choice(len(memory_corpus), 3)
rand_memory_indices = [(None, i) for i in rand_memory_indices]

In [59]:
out4 = mama.generate(
  query_g=toy_data_preprocessed["input_ids_g"][0].unsqueeze(0).cuda(),
  embedded_corpus=embedded_corpus,
  memory_indices=rand_memory_indices
)

In [60]:
g_tokenizer.decode(out4.logits[0].argmax(dim=-1))

"|memory|>Adam more of's is<|endoftext|> the favorite game sources source sites.<|memory|>Hello the technology that allows you to simulate organizational dynamics at the computational level is key to training the Organizational AI and this vision overall.<|endoftext|><|memory|>Hello are right. Query document may become large by large list.<|endoftext|><|system|>\nAnswer the given question<|endoftext|>\n<|user|>\nWhat old is Ver?<|endoftext|>\n<|assistant|>\n35<|endoftext|>\n<"

In [61]:
topk = mama.retrieve(
  query_r=toy_data_preprocessed["input_ids_r"][0].unsqueeze(0).cuda(),
  embedded_corpus=embedded_corpus
)

out5 = mama.generate(
  query_g=toy_data_preprocessed["input_ids_g"][0].unsqueeze(0).cuda(),
  embedded_corpus=embedded_corpus,
  memory_indices=topk
)

In [62]:
g_tokenizer.decode(out5.logits[0].argmax(dim=-1))

'|memory|>Adam is 41 years old. He lives in Parisja. Donald is an expert in Karate<|endoftext|><|memory|>Hello 1989, Adamantha was born in Abu. He is 26 years old. Samantha is an expert in Python<|endoftext|><|memory|>Hello, my name is Thomsonanthaumbs. I am 19 years old. I live in Lag. I am an expert in Javascript<|endoftext|><|system|>\nAnswer the given question<|endoftext|>\n<|user|>\nWho old is Sam?<|endoftext|>\n<|assistant|>\n58<|endoftext|>\n<'

In [63]:
g_tokenizer.encode("<|assistant|>36", return_tensors="pt")

tensor([[   29,    93,   515,  5567, 49651,  1812]])

In [65]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [67]:
count_parameters(mama)

2848533056