In [2]:
import torch
import torch.nn as nn
from transformers import GPT2Model, GPT2Config

class CachedMemoryBank(nn.Module):
    """Project token ids into memory keys and memory values.

    The ids are looked up in a dedicated embedding table, and the resulting
    vectors are passed through two independent linear maps: one yields the
    keys, the other the values.

    Args:
        vocab_size: number of rows in the embedding table.
        embedding_dim: width of each embedding vector.
        memory_dim: output width of both the key and the value projection.
    """

    def __init__(self, vocab_size, embedding_dim, memory_dim):
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.memory_key = nn.Linear(in_features=embedding_dim, out_features=memory_dim)
        self.memory_value = nn.Linear(in_features=embedding_dim, out_features=memory_dim)

    def forward(self, input_ids):
        """Return the ``(keys, values)`` tuple computed from ``input_ids``."""
        token_vectors = self.embedding(input_ids)
        return self.memory_key(token_vectors), self.memory_value(token_vectors)

class ResidualSideNet(nn.Module):
    """Two-layer feed-forward block with a skip connection.

    Computes ``fc2(relu(fc1(x))) + x``; the output therefore has the same
    trailing width (``input_dim``) as the input.

    Args:
        input_dim: width of the input and of the output.
        hidden_dim: width of the intermediate layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()

        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=input_dim)

    def forward(self, input_tensor):
        """Apply the feed-forward path and add ``input_tensor`` back on top."""
        projected = self.fc2(self.relu(self.fc1(input_tensor)))
        # The skip connection is accumulated in place on the freshly computed
        # projection; the caller's input tensor is left untouched.
        projected.add_(input_tensor)
        return projected

class MemoryRetrievalFusion(nn.Module):
    """Attend over a memory tensor using queries derived from hidden states.

    The hidden states are mapped to ``memory_dim`` by a linear layer, scored
    against every memory row with a dot product, normalised with a softmax
    over the last axis, and used to take a weighted sum of the memory rows.

    Args:
        memory_dim: trailing width of the memory tensor (and of the queries).
        input_dim: trailing width of the hidden states fed to the query map.
    """

    def __init__(self, memory_dim, input_dim):
        super().__init__()

        self.linear_query = nn.Linear(in_features=input_dim, out_features=memory_dim)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, memory_augmented, transformer_outputs):
        """Return the attention-weighted combination of ``memory_augmented``."""
        query = self.linear_query(transformer_outputs)
        memory_transposed = memory_augmented.transpose(-2, -1)
        attention_weights = self.softmax(torch.matmul(query, memory_transposed))
        return torch.matmul(attention_weights, memory_augmented)

class BackboneLLM(nn.Module):
    """Wrapper around a pretrained ``GPT2Model`` that returns only its final hidden states.

    Args:
        model_name: identifier passed to ``GPT2Model.from_pretrained``.
        config: configuration object forwarded as ``config=``.
    """

    def __init__(self, model_name, config):
        super().__init__()
        self.gpt2 = GPT2Model.from_pretrained(model_name, config=config)

    def forward(self, input_ids, attention_mask=None):
        """Run the wrapped model and return its ``last_hidden_state``."""
        return self.gpt2(input_ids, attention_mask=attention_mask).last_hidden_state

class LongMEM(nn.Module):
    """Memory-augmented language model on top of a frozen GPT-2 backbone.

    The backbone produces hidden states without gradient tracking. A separate
    memory bank turns the same token ids into keys and values; the keys are
    scored against the backbone hidden states to build a memory tensor, a
    residual side network refines the hidden states, and a fusion module
    attends over the memory with the refined states. A final linear layer maps
    the fused features to vocabulary logits.

    Args:
        vocab_size: size of the vocabulary (embedding rows and logit width).
        embedding_dim: width of the backbone hidden states.
        hidden_dim: width of the side network's intermediate layer.
        num_layers: accepted for interface compatibility; not used by this class.
        memory_dim: width of memory keys/values. Must equal ``embedding_dim``.
        backbone_model_name: checkpoint name handed to ``BackboneLLM``.
        backbone_config: configuration object handed to ``BackboneLLM``.

    Raises:
        ValueError: if ``memory_dim`` differs from ``embedding_dim``.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, memory_dim, backbone_model_name, backbone_config):
        super(LongMEM, self).__init__()

        # forward() multiplies the keys (memory_dim wide) with the backbone
        # hidden states (embedding_dim wide) and feeds the memory_dim-wide fused
        # output into Linear(embedding_dim, vocab_size). Both only work when the
        # two widths agree, so fail here with a clear message instead of with a
        # shape error deep inside the first forward pass.
        if memory_dim != embedding_dim:
            raise ValueError(
                "memory_dim (%s) must equal embedding_dim (%s)" % (memory_dim, embedding_dim)
            )

        self.frozen_llm = BackboneLLM(backbone_model_name, backbone_config)
        self.memory_bank = CachedMemoryBank(vocab_size, embedding_dim, memory_dim)
        self.side_net = ResidualSideNet(embedding_dim, hidden_dim)
        self.memory_fusion = MemoryRetrievalFusion(memory_dim, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

        # Actually freeze the backbone: forward() never builds a graph through
        # it, so its parameters can never receive gradients. Marking them as
        # non-trainable keeps trainable-parameter counts and optimizers honest.
        self.frozen_llm.eval()
        for backbone_param in self.frozen_llm.parameters():
            backbone_param.requires_grad_(False)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen backbone in eval mode.

        Without this override ``model.train()`` would also put the backbone in
        training mode, enabling any dropout inside it and making the "frozen"
        features change from call to call.
        """
        super(LongMEM, self).train(mode)
        self.frozen_llm.eval()
        return self

    def forward(self, input_ids):
        """Compute vocabulary logits for every position of ``input_ids``.

        Args:
            input_ids: integer token ids; the recorded runs use shape
                ``(batch, seq_len)``.

        Returns:
            Logits whose last dimension is ``vocab_size``.
        """
        # Frozen LLM: no autograd graph is recorded for the backbone.
        with torch.no_grad():
            transformer_outputs = self.frozen_llm(input_ids)

        # Memory Bank
        keys, values = self.memory_bank(input_ids)

        # Memory Augmentation: score every key against every hidden state,
        # normalise over the hidden-state axis, then mix the values.
        # NOTE(review): no causal mask is applied here or in the fusion step, so
        # each position can draw on later positions — confirm this is intended
        # before using the per-position logits as a training signal.
        memory_attention = torch.matmul(keys, transformer_outputs.transpose(-1, -2))
        memory_attention = nn.functional.softmax(memory_attention, dim=-1)
        memory_augmented = torch.matmul(memory_attention.transpose(-1, -2), values)

        # Side Net
        side_net_output = self.side_net(transformer_outputs)

        # Memory Retrieval Fusion
        fused_output = self.memory_fusion(memory_augmented, side_net_output)

        # Final Linear Layer
        return self.linear(fused_output)

    def generate_text(self, input_ids, max_length):
        """Greedily extend ``input_ids`` by ``max_length`` tokens.

        Args:
            input_ids: prompt token ids of shape ``(batch, seq_len)``.
            max_length: number of NEW tokens to append (not the total length).
                Values ``<= 0`` return a copy of the prompt.

        Returns:
            Tensor of shape ``(batch, seq_len + max_length)`` holding the prompt
            followed by the argmax token chosen at each step.
        """
        output_ids = input_ids.clone()
        with torch.no_grad():
            for _ in range(max_length):
                # Call the module (not .forward) so registered hooks still run.
                logits = self(output_ids)
                predicted_token = torch.argmax(logits[:, -1, :], dim=-1)
                output_ids = torch.cat((output_ids, predicted_token.unsqueeze(1)), dim=1)

        return output_ids

# Example usage
# The whole example used to be pasted twice in this cell, which loaded the
# GPT-2 weights twice and repeated the same forward pass; it now runs once.
backbone_model_name = "gpt2"  # Change this if you have a different pretrained model
backbone_config = GPT2Config.from_pretrained(backbone_model_name)

# Sizes that must agree with the backbone are read from its configuration so
# that switching to another GPT-2 checkpoint does not leave stale constants
# behind (for "gpt2" these evaluate to 50257, 768 and 12).
vocab_size = backbone_config.vocab_size
embedding_dim = backbone_config.n_embd
num_layers = backbone_config.n_layer
hidden_dim = 768
# LongMEM's forward pass requires memory_dim to match the backbone width.
memory_dim = embedding_dim

# The layers added on top of the backbone are randomly initialised; seed the
# RNG so the printed token ids are reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)

model = LongMEM(vocab_size, embedding_dim, hidden_dim, num_layers, memory_dim, backbone_model_name, backbone_config)

# Example input
input_ids = torch.tensor([[1, 2, 3, 4, 5]])

# Forward pass
output_logits = model(input_ids)

print("Output logits shape:", output_logits.shape)

# Generate text (max_length is the number of tokens appended to the prompt)
max_length = 10
generated_text = model.generate_text(input_ids, max_length)

print("Generated text:", generated_text)


Output logits shape: torch.Size([1, 5, 50257])
Keys shape: torch.Size([1, 5, 768])
Values shape: torch.Size([1, 5, 768])
Memory augmented shape: torch.Size([1, 5, 768])
Generated text: tensor([[    1,     2,     3,     4,     5,   155, 27191, 22686, 22686, 22686,
         22686, 22686, 14755, 35641, 35641]])


In [None]:
# GPT2Tokenizer is not imported by any cell visible above, so this cell raised
# NameError on a fresh kernel; import it here so the cell runs on its own.
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer

In [4]:
# GPT2Tokenizer is not imported by any cell visible above; import it here so
# the cell does not depend on hidden kernel state.
from transformers import GPT2Tokenizer

# Instantiate the GPT2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Example text input
text_input = "Once upon a time"

# Tokenize the text input
input_ids = tokenizer.encode(text_input, return_tensors="pt")

# Create LONGMEM model
longmem_model = LongMEM(vocab_size, embedding_dim, hidden_dim, num_layers, memory_dim, backbone_model_name, backbone_config)

# Generate text using the model (max_length counts newly generated tokens)
generated_ids = longmem_model.generate_text(input_ids, max_length=20)

# Decode generated token IDs to text. Index the single batch row explicitly:
# squeeze() would drop every size-1 dimension, not just the batch axis.
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print("Generated text:", generated_text)


  input_ids = torch.tensor(input_ids)


Output logits shape: torch.Size([1, 4, 50257])
Keys shape: torch.Size([1, 4, 768])
Values shape: torch.Size([1, 4, 768])
Memory augmented shape: torch.Size([1, 4, 768])
Generated text: tensor([[ 7454,  2402,   257,   640,  8610, 35641, 22542,  8610, 35641, 22686,
         22686, 22686, 22686, 22686]])


In [None]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/7.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/7.2 MB[0m [31m46.8 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━[0m [32m5.9/7.2 MB[0m [31m80.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m7.2/7.2 MB[0m [31m80.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m53.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0.16.4-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=

In [None]:
!pip install langchain

Collecting langchain
  Downloading langchain-0.0.234-py3-none-any.whl (1.3 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━[0m [32m1.2/1.3 MB[0m [31m41.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m22.3 MB/s[0m eta [36m0:00:00[0m
Collecting dataclasses-json<0.6.0,>=0.5.7 (from langchain)
  Downloading dataclasses_json-0.5.9-py3-none-any.whl (26 kB)
Collecting langsmith<0.0.6,>=0.0.5 (from langchain)
  Downloading langsmith-0.0.5-py3-none-any.whl (25 kB)
Collecting openapi-schema-pydantic<2.0,>=1.2 (from langchain)
  Downloading openapi_schema_pydantic-1.2.4-py3-none-any.whl (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.0/90.0 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
Collecting marshmallow<4.0.0,>=3.3.0 (from dataclasses-json<0.6.0,>=0.5.7

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Instantiate the tokenizer that matches the model's GPT-2 backbone.
# The cell previously loaded "lmsys/vicuna-13b-delta-v1.1", whose vocabulary is
# unrelated to GPT-2's: the prompt ids would index the wrong embeddings and the
# ids predicted by `model` could not be decoded meaningfully.
tokenizer = AutoTokenizer.from_pretrained(backbone_model_name)

# Example text input
text_input = "Once upon a time"

# Tokenize the text input
input_ids = tokenizer.encode(text_input, return_tensors="pt")

# Generate text using the model (max_length counts newly generated tokens)
max_length = 20
generated_ids = model.generate_text(input_ids, max_length)

# Decode generated token IDs to text; index the single batch row explicitly
# instead of squeeze(), which would drop every size-1 dimension.
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print("Generated text:", generated_text)
