# SparseFlash2LinformerAttention

In [144]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam



def random_projection(matrix, k):
    """Random projection to reduce dimensionality of matrix to k dimensions."""
    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
    return torch.matmul(matrix, random_matrix)

def cur_decomposition(matrix, projection_dim):  # Change 'k' argument
    """Applies CUR decomposition with C matrix dimension aligned to projection dimension."""
    batch_size, seq_length, heads, dim = matrix.shape
    k = min(projection_dim // 2, dim)  # Use projection dimension to determine 'k'
    C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
    R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)

    for b in range(batch_size):
        for h in range(heads):
            col_indices = np.random.choice(dim, k, replace=False)
            row_indices = np.random.choice(seq_length, k, replace=False)
            C[b, :, h] = matrix[b, :, h, col_indices]
            R[b, :, h] = matrix[b, row_indices, h]
    return C, R

class PartitionedLinformerAttentionACT(nn.Module):
    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.value_projection = nn.Linear(self.head_dim, self.projection_dim//2)  # To project values to match dimensions
        self.ponder = nn.Linear(self.partition_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        print(f"BEGIN FORWARD")
        print(f"x input shape: {x.shape}")
        N, seq_length, _ = x.shape
        x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
        print(f"x_reshaped shape: {x_reshaped.shape}")

        values = self.values(x_reshaped)
        queries = random_projection(self.queries(x_reshaped), self.projection_dim // 2)
        keys = random_projection(self.keys(x_reshaped), self.projection_dim // 2 )
        print("values :", values.shape)
        print("queries :", queries.shape)
        print("keys :", keys.shape)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)
        print(f"attention_scores after random projection: {attention_scores.shape}")

        for i in range(0, seq_length, self.partition_size):
            print(f"PARTITION START")
            partition_start = i
            partition_end = min(i + self.partition_size, seq_length)
            keys_part = keys[:, partition_start:partition_end, :, :]
            queries_part = queries[:, partition_start:partition_end, :, :]

            C_keys, R_queries = cur_decomposition(keys_part, self.projection_dim)
            print("C_keys shape before return:", C_keys.shape)

            ponder_scores = torch.zeros(N, self.heads, partition_end - partition_start, 1, device=x.device)
            print(f"Partition Start {i}, Partition End {partition_end} , ponder_scores: {ponder_scores.shape}")

            for h in range(self.heads):
                #print(f"HEADS START")
                head_queries = queries_part[:, :, h, :]
                #print(f"head_queries: {head_queries.shape}")
                head_ponder_scores = self.sigmoid(self.ponder(head_queries))
                #print(f"head_ponder_scores: {head_ponder_scores.shape}")
                ponder_scores[:, h, :, 0] = head_ponder_scores.squeeze(-1)

            # Correctly expand ponder_scores without adding an unnecessary dimension
            print("BEFORE 1ST EINSUM:")
            ponder_scores_permuted = ponder_scores.permute(0, 2, 1, 3)  # Move to [2, 128, 8, 1]
            print("ponder_scores_permuted shape:", ponder_scores_permuted.shape) 
            ponder_scores_broadcastable = ponder_scores_permuted.expand(-1, -1, -1, 128)  # Expand to [2, 128, 8, 128]            print(f"ponder_scores_broadcastable shape: {ponder_scores_broadcastable.shape}")
            print("ponder_scores_broadcastable shape:", ponder_scores_broadcastable.shape) 
            print("queries_part shape:", queries_part.shape) 
            print("C_keys shape:", C_keys.shape)
            energy = torch.einsum('bnhd,bnhk->bnhd', queries_part, C_keys)
            attention_weights = F.softmax(energy, dim=-1)
            print("AFTER 1ST EINSUM:")
            print("energy shape:", energy.shape) 
            print("attention_weights shape:", attention_weights.shape)
            attention = attention_weights * ponder_scores_broadcastable
            print("attention shape:", attention.shape)
            attention_corrected = attention.permute(0, 2, 1, 3)
            attention_scores[:, :, partition_start:partition_end, :] = attention_corrected

        values = values.permute(0, 2, 1, 3)  # Swap heads and seq_length to bring heads next to head_dim
        print("values shape:", values.shape)
        values = values.reshape(-1, self.head_dim)  # Flatten to [N*heads*seq_length, head_dim] for linear layer
        print("values.reshape(-1, self.head_dim) shape:", values.shape)
        projected_values = self.value_projection(values)  # Now [N*heads*seq_length, projection_dim / 2]
        print("self.value_projection(values) shape:", projected_values.shape)
        projected_values = projected_values.view(N, self.heads, seq_length, self.projection_dim // 2)
        print("projected_values shape:", projected_values.shape)

        print(f"2ND EINSUM")
        # Combine attention_scores and projected_values then pass through the final linear layer
        out = torch.einsum('bnhp,bnhp->bnh', attention_scores, projected_values)
        print("out shape:", out.shape)
       
        return out

'''
# Example usage
model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)
input_tensor = torch.rand(2, model.sequence_length, model.embed_size)
output = model(input_tensor)
print(output.shape)  
'''

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assuming the input tensor size is [batch_size, sequence_length, embed_size]
embed_size = 512  # Size of each word embedding
heads = 8  # Number of attention heads
sequence_length = 1024  # Length of the input sequence
projection_dim = 256  # Dimension to project the input embeddings
partition_size = 128  # Size of partitions for processing
num_classes = 10
# Instantiate the model
attention_model  = PartitionedLinformerAttentionACT(embed_size, heads, sequence_length, projection_dim, partition_size)
class AttentionClassifier(nn.Module):
    def __init__(self, attention_model, num_classes):
        super().__init__()
        self.attention_model = attention_model
        self.global_pool = nn.AdaptiveAvgPool1d(1)  # Global pooling
        # Assuming the attention model's output dimensions here; adjust as necessary
        projection_dim = attention_model.projection_dim
        heads = attention_model.heads
        self.classifier = nn.Linear(heads * (projection_dim // 2), num_classes)  # Classifier layer

    def forward(self, x):
        N, seq_length, _ = x.shape
        out = self.attention_model(x)
        # Ensure out has dimensions [batch_size, seq_length, num_heads * (projection_dim // 2)]
        # You might need to adjust this depending on your attention model's exact output
        out = out.permute(0, 2, 1)  # Change shape to [batch_size, sequence_length, num_heads]
        out = self.global_pool(out).squeeze(-1)  # Global pooling, resulting in [batch_size, num_heads]
        out = out.view(N, -1)  # Flatten
        out = self.classifier(out)  # Pass through classifier
        return out
# Move the model to the appropriate device
model = AttentionClassifier(attention_model, num_classes).to(device)

model = model.to(device)

# Create a DataLoader with pinned memory
class MyDataset(Dataset):
    def __init__(self):
        # Initialize your data here
        pass
    def __len__(self):
        # Return the size of your dataset
        return 100  # Example size
    def __getitem__(self, idx):
        # Example: Pad the sequence to length 1024 and ensure embed_size is 512
        # This is a simplified example; adjust according to your actual data
        input_tensor = torch.randn(1024, 512)  # Padded/modeled input to match expected dimensions
        target = torch.randint(0, 10, (1,)).item()  # Generate a random class index (0 to 9 for 10 classes)
        return input_tensor, target


dataset = MyDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
optimizer = Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for epoch in range(10):  # Example: 10 epochs
    for inputs, targets in loader:
        # Move inputs and targets to the correct device
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        
        optimizer.zero_grad()
        
        with autocast():
            # Forward pass through the model
            predictions = model(inputs)
            
            # Calculate loss (assuming classification task and targets are class indices)
            loss = nn.functional.cross_entropy(predictions, targets)
        
        # Backward pass and optimizer step using gradient scaling for mixed precision
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    print(f"Epoch: {epoch}, Loss: {loss}")



BEGIN FORWARD
x input shape: torch.Size([32, 1024, 512])
x_reshaped shape: torch.Size([32, 1024, 8, 64])
values : torch.Size([32, 1024, 8, 64])
queries : torch.Size([32, 1024, 8, 128])
keys : torch.Size([32, 1024, 8, 128])
attention_scores after random projection: torch.Size([32, 8, 1024, 128])
PARTITION START
C_keys shape before return: torch.Size([32, 128, 8, 128])
Partition Start 0, Partition End 128 , ponder_scores: torch.Size([32, 8, 128, 1])
BEFORE 1ST EINSUM:
ponder_scores_permuted shape: torch.Size([32, 128, 8, 1])
ponder_scores_broadcastable shape: torch.Size([32, 128, 8, 128])
queries_part shape: torch.Size([32, 128, 8, 128])
C_keys shape: torch.Size([32, 128, 8, 128])
AFTER 1ST EINSUM:
energy shape: torch.Size([32, 128, 8, 128])
attention_weights shape: torch.Size([32, 128, 8, 128])
attention shape: torch.Size([32, 128, 8, 128])
PARTITION START
C_keys shape before return: torch.Size([32, 128, 8, 128])
Partition Start 128, Partition End 256 , ponder_scores: torch.Size([32, 8,

# Test SPLASH Attention on Tokenizer

In [18]:
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
# Test on larger text
import re
import collections
from collections import Counter, defaultdict
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence


class SPLASH(nn.Module):
    def __init__(self, vocab_size, embed_size, heads, sequence_length, projection_dim, partition_size):
        super().__init__()
        # Embedding layer added
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Existing initialization code...
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.value_projection = nn.Linear(self.head_dim, self.projection_dim//2)  # To project values to match dimensions
        self.ponder = nn.Linear(self.partition_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.final_projection = nn.Linear(self.heads * (self.projection_dim // 2), vocab_size)


    def random_projection(self, matrix, k):
        """Random projection to reduce dimensionality of matrix to k dimensions."""
        random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
        return torch.matmul(matrix, random_matrix)

    def cur_decomposition(self, matrix, projection_dim):
        """Applies CUR decomposition with C matrix dimension aligned to projection dimension."""
        batch_size, seq_length, heads, dim = matrix.shape
        k = min(projection_dim // 2, dim)
        C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
        R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)

        for b in range(batch_size):
            for h in range(heads):
                col_indices = np.random.choice(dim, k, replace=False)
                row_indices = np.random.choice(seq_length, k, replace=False)
                C[b, :, h] = matrix[b, :, h, col_indices]
                R[b, :, h] = matrix[b, row_indices, h]
        return C, R

    def forward(self, input_ids):
        print(f"BEGIN FORWARD")
        x = self.embedding(input_ids)
        print(f"x input shape: {x.shape}")
        N, seq_length, _ = x.shape
        x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
        print(f"x_reshaped shape: {x_reshaped.shape}")

        values = self.values(x_reshaped)
        queries = self.random_projection(self.queries(x_reshaped), self.projection_dim // 2)
        keys = self.random_projection(self.keys(x_reshaped), self.projection_dim // 2 )
        print("values :", values.shape)
        print("queries :", queries.shape)
        print("keys :", keys.shape)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)
        print(f"attention_scores after random projection: {attention_scores.shape}")

        for i in range(0, seq_length, self.partition_size):
            print(f"PARTITION START")
            partition_start = i
            partition_end = min(i + self.partition_size, seq_length)
            keys_part = keys[:, partition_start:partition_end, :, :]
            queries_part = queries[:, partition_start:partition_end, :, :]

            C_keys, R_queries = self.cur_decomposition(keys_part, self.projection_dim)
            print("C_keys shape before return:", C_keys.shape)

            ponder_scores = torch.zeros(N, self.heads, partition_end - partition_start, 1, device=x.device)
            print(f"Partition Start {i}, Partition End {partition_end} , ponder_scores: {ponder_scores.shape}")

            for h in range(self.heads):
                #print(f"HEADS START")
                head_queries = queries_part[:, :, h, :]
                #print(f"head_queries: {head_queries.shape}")
                head_ponder_scores = self.sigmoid(self.ponder(head_queries))
                #print(f"head_ponder_scores: {head_ponder_scores.shape}")
                ponder_scores[:, h, :, 0] = head_ponder_scores.squeeze(-1)

            # Correctly expand ponder_scores without adding an unnecessary dimension
            print("BEFORE 1ST EINSUM:")
            ponder_scores_permuted = ponder_scores.permute(0, 2, 1, 3)  # Move to [2, 128, 8, 1]
            print("ponder_scores_permuted shape:", ponder_scores_permuted.shape) 
            ponder_scores_broadcastable = ponder_scores_permuted.expand(-1, -1, -1, 128)  # Expand to [2, 128, 8, 128]            
            print("ponder_scores_broadcastable shape:", ponder_scores_broadcastable.shape) 
            print("queries_part shape:", queries_part.shape) 
            print("C_keys shape:", C_keys.shape)
            energy = torch.einsum('bnhd,bnhk->bnhd', queries_part, C_keys)
            attention_weights = F.softmax(energy, dim=-1)
            print("AFTER 1ST EINSUM:")
            print("energy shape:", energy.shape) 
            print("attention_weights shape:", attention_weights.shape)
            attention = attention_weights * ponder_scores_broadcastable
            print("attention shape:", attention.shape)
            attention_corrected = attention.permute(0, 2, 1, 3)
            attention_scores[:, :, partition_start:partition_end, :] = attention_corrected

        values = values.permute(0, 2, 1, 3)  # Swap heads and seq_length to bring heads next to head_dim
        print("values shape:", values.shape)
        values = values.reshape(-1, self.head_dim)  # Flatten to [N*heads*seq_length, head_dim] for linear layer
        print("values.reshape(-1, self.head_dim) shape:", values.shape)
        projected_values = self.value_projection(values)  # Now [N*heads*seq_length, projection_dim / 2]
        print("self.value_projection(values) shape:", projected_values.shape)
        projected_values = projected_values.view(N, self.heads, seq_length, self.projection_dim // 2)
        print("projected_values shape:", projected_values.shape)

        print(f"2ND EINSUM")
        # Combine attention_scores and projected_values then pass through the final linear layer
        '''out = torch.einsum('bnhp,bnhp->bnh', attention_scores, projected_values)
        print("out shape:", out.shape)
        out = self.final_projection(out.view(-1, self.embed_size))  # Reshape and project
        print(f"out from self.final_projection(out.view(-1, self.embed_size)) : {out.shape}")
        print(f"Final output: { out.view(-1, vocab_size).shape }")
        return out.view(-1, vocab_size)  # Reshape to [batch_size, sequence_length, vocab_size]'''
        # Combine attention_scores and projected_values then pass through the final linear layer
        out = torch.einsum('bnhp,bnhp->bnhp', attention_scores, projected_values)
        print("out shape after einsum:", out.shape)

        # Correct reshaping: Flatten batch and sequence length dimensions, keep the last two dimensions for projection
        out = out.reshape(-1, self.heads * (self.projection_dim // 2))
        print("out reshaped for projection:", out.shape)

        # Ensure the final_projection layer matches the flattened shape expected after reshaping
        # Assuming final_projection is defined as nn.Linear(self.heads * (self.projection_dim // 2), vocab_size)
        out = self.final_projection(out)
        print(f"out after final_projection:", out.shape)

        # At this point, out should have a shape of [batch_size * sequence_length, vocab_size], ready for loss calculation
        return out



##################################################
# Tokenizer

class TrieNode:
    def __init__(self):
        self.children = {}
        self.token_id = None
        self.frequency = 0
        self.failure_link = None
        self.is_end = False  # Add is_end attribute to mark the end of a word
        self.token = None  # Add token attribute to store the token associated with the node


class Trie:
    def __init__(self, unk_token_id=0):
        self.root = TrieNode()
        self.unk_token_id = unk_token_id

    def insert(self, token, token_id, frequency):
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.token_id = token_id
        node.frequency = frequency

    def find_subwords(self, token):
        """Finds the most probable subwords based on frequency."""
        node = self.root
        best_subwords = []

        def dfs(current_node, subword='', collected_subwords=[]):
            if current_node.token_id is not None:
                # Update to correctly calculate total_frequency based on the structure of collected_subwords
                total_frequency = sum(n.frequency for _, _, n in collected_subwords) + current_node.frequency
                probability = current_node.frequency / total_frequency if total_frequency else 0
                collected_subwords.append((subword, probability, current_node))

            for char, next_node in current_node.children.items():
                dfs(next_node, subword + char, list(collected_subwords))  # Create a copy of the list to avoid shared state

        dfs(node)
        best_subwords = sorted(best_subwords, key=lambda x: x[1], reverse=True)
        return [subword for subword, _, _ in best_subwords][:5] or [self.unk_token_id]


    def compute_failure_links(self):
        root = self.root
        root.failure_link = root  # Root's failure link points to itself
        queue = [root]

        while queue:
            current_node = queue.pop(0)

            for char, child_node in current_node.children.items():
                queue.append(child_node)

                # Follow failure link to find the longest suffix for the child_node
                failure_candidate = current_node.failure_link
                while failure_candidate != root and char not in failure_candidate.children:
                    failure_candidate = failure_candidate.failure_link
                child_node.failure_link = failure_candidate.children.get(char, root)


class SimpleSentencePiece:
    def __init__(self, model_type="bpe", vocab_size=30522):
        self.vocab = {}
        self.id_to_subword = {}
        self.unk_token = "[UNK]"
        self.unk_token_id = 0
        self.vocab_size = vocab_size
        self.model = None if model_type == "bpe" else None
        self.model_type = model_type

    def train(self, text):
        if self.model_type == "bpe":
            self.model = BPE(num_merges=self.vocab_size, unk_token_id=self.unk_token_id)
            self.model.train(text)
            self.vocab = self.model.vocab
            self.id_to_subword = {i: word for word, i in self.vocab.items()}
        else:
            raise NotImplementedError(f"Model type {self.model_type} not supported yet.")

    def encode(self, text):
        text = self.preprocess_text(text)  # Preprocess text before encoding
        if not self.model:
            raise ValueError("Model has not been trained yet.")
        encoded = self.model.encode(text)
        #print(f"Encoded: {encoded[:10]}")  # Print first 10 encoded tokens
        return encoded

    def decode(self, ids):
        if not self.id_to_subword:
            raise ValueError("Vocabulary is empty. Ensure the model is trained first.")
        text = " ".join([self.id_to_subword.get(id_, self.unk_token) for id_ in ids])
        text = text.replace(" </w>", "").replace("</w>", " ").strip()
        return text

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()
        # Optionally, handle punctuation by adding spaces around it for better tokenization
        text = re.sub(r'([.,!?()])', r' \1 ', text)
        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)
        # Trim leading and trailing spaces
        text = text.strip()
        return text
    
    def save_model(self, filepath):
        model_data = {
            'vocab': self.vocab,
            'id_to_subword': self.id_to_subword,
            'model_type': self.model_type,
            'vocab_size': self.vocab_size,
            # Potentially include other relevant attributes
        }
        # Save the high-level tokenizer settings
        with open(filepath, 'w') as f:
            json.dump(model_data, f)
        
        # Now save the BPE model specifically
        if self.model_type == "bpe" and self.model:
            self.model.save_model(filepath + "_bpe")

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            model_data = json.load(f)
        
        self.vocab = model_data['vocab']
        self.id_to_subword = model_data['id_to_subword']
        self.model_type = model_data['model_type']
        self.vocab_size = model_data['vocab_size']
        
        # Assuming model_type is still "bpe", we now load the BPE model
        if self.model_type == "bpe":
            self.model = BPE(self.vocab_size, self.unk_token_id)
            self.model.load_model(filepath + "_bpe")

class BPE:
    def __init__(self, num_merges=100, unk_token_id=0):  # Accept unk_token_id parameter
        self.vocab = {}
        self.merges = []
        self.num_merges = num_merges
        self.unk_token_id = unk_token_id  # Store the unknown token ID

    def train(self, text):
        words = re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
        vocab = collections.Counter(words)
        vocab = {word + '</w>': count for word, count in vocab.items()}
        
        for _ in range(self.num_merges):  # Use the num_merges from the instance variable
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges.append(best)

        self.vocab = {word: i for i, word in enumerate(vocab.keys())}

    @staticmethod
    def get_stats(vocab):
        pairs = collections.defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pairs[symbols[i], symbols[i+1]] += freq
        return pairs

    @staticmethod
    def merge_vocab(pair, vocab):
        v_out = {}
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in vocab:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = vocab[word]
        return v_out

    def encode(self, text):
        """Encode text into subwords using learned BPE merges."""
        encoded_tokens = []
        for word in re.findall(r'\w+|[^\w\s]', text, re.UNICODE):
            word += '</w>'
            subwords = [word]  # Start with the entire word as one subword
            for merge in self.merges:
                new_subwords = []
                for subword in subwords:
                    # If the merge is in subword, split it; otherwise, keep it as is
                    if ' '.join(merge) in subword:
                        new_subwords.extend(subword.replace(' '.join(merge), ''.join(merge)).split(' '))
                    else:
                        new_subwords.append(subword)
                subwords = new_subwords
            encoded_tokens.extend(subwords)
        return [self.vocab.get(token, self.unk_token_id) for token in encoded_tokens]
    
        # New method to save trained model
    def save_model(self, filepath):
        bpe_data = {
            'merges': self.merges,
            'vocab': self.vocab,
            'num_merges': self.num_merges,
            # Include other attributes as needed
        }
        with open(filepath, 'w') as f:
            json.dump(bpe_data, f)

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            bpe_data = json.load(f)
        
        self.merges = bpe_data['merges']
        self.vocab = bpe_data['vocab']
        self.num_merges = bpe_data['num_merges']


class WordPiece:
    def __init__(self, vocab, unk_token_id=0, unk_token="[UNK]"):
        self.vocab = vocab
        self.unk_token_id = unk_token_id
        self.unk_token = unk_token  # Define the unknown token
        self.root = self.build_trie(vocab)
        self.id_to_token = {id_: token for token, id_ in vocab.items()}  # Inverse mapping
        self.compute_failure_links(self.root)
        print("Trie built successfully.")

    def convert_ids_to_tokens(self, ids):
        """
        Convert a list of token ids back to their string token representations.
        """
        return [self.id_to_token.get(id_, self.unk_token) for id_ in ids]

    # Add debug prints to build_trie to confirm structure
    def build_trie(self, vocab):
        root = TrieNode()
        for token in vocab:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_end = True
            node.token = token
        #print("Trie Construction Completed Successfully")
        return root


    def compute_failure_links(self, root):
        queue = [root]
        while queue:
            current_node = queue.pop(0)
            for char, child_node in current_node.children.items():
                failure_node = current_node.failure_link
                while failure_node and char not in failure_node.children:
                    failure_node = failure_node.failure_link
                child_node.failure_link = failure_node.children[char] if failure_node else root
                queue.append(child_node)

    # Improved debug prints in tokenize method
                
    def tokenize(self, text):
        # Preprocess input text
        text = self.preprocess_text(text)
        node = self.root
        token_ids = []  # Will store token IDs instead of tokens
        i = 0

        while i < len(text):
            char = text[i]
            if char == ' ':
                node = self.root
                i += 1
                continue

            if char not in node.children:
                if node != self.root and node.token is not None:
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root  # Reset to root
                    continue
                else:
                    # Append unknown token ID
                    token_ids.append(self.unk_token_id)
                    i += 1
                    continue

            node = node.children[char]
            if node.is_end:
                if i + 1 == len(text) or text[i + 1] == ' ':
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root

            i += 1

        #print(f"Token IDs: {token_ids[:10]}")
        return token_ids

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()

        # Optionally, handle punctuation by adding spaces around it for better tokenization
        # This depends on how your vocabulary handles punctuation
        text = re.sub(r'([.,!?()])', r' \1 ', text)

        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)

        # Trim leading and trailing spaces
        text = text.strip()

        return text


def load_corpus(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        texts = [line.strip() for line in file.readlines()]
    return texts

texts = load_corpus("D:\\EXPERT_WEIGHTS\\sample.txt")
#texts = load_corpus("/content/drive/MyDrive/EXPERT_STUFF/sample.txt")
# texts = load_corpus("C:/Users/robbi/Expert/sample.txt")
num_merges = 100
def adapt_vocab_for_wordpiece(ssp_vocab):
    adapted_vocab = {}
    for token, id_or_freq in ssp_vocab.items():
        if not token.startswith(" ") and not token.endswith("</w>"):
            adapted_token = "##" + token.replace("</w>", "")
        else:
            adapted_token = token.replace("</w>", "")
        adapted_vocab[adapted_token] = id_or_freq
    return adapted_vocab

# Initialize and train the SimpleSentencePiece model with BPE
ssp = SimpleSentencePiece(model_type="bpe", vocab_size=30522)
# Assume `texts` is a list of text to train the tokenizer
ssp.train('\n'.join(texts))
wordpiece_vocab = adapt_vocab_for_wordpiece(ssp.vocab)
# Debugging step to ensure vocabulary completeness
def debug_vocab(adapted_vocab):
    print("Sample Vocabulary Check:")
    # Iterate over the first 10 key-value pairs in the adapted vocabulary
    for i, (token, id_or_freq) in enumerate(adapted_vocab.items()):
        print(f"{token}: {id_or_freq}")
        if i >= 9:  # Stop after printing 10 entries
            break
    # Specifically check for subtokens if your tokenizer expects them
    subtokens = [token for token in adapted_vocab.keys() if token.startswith("##")]
    print(f"Found {len(subtokens)} subtokens in vocabulary.")

# Ensure wordpiece_vocab is a list of vocabulary tokens
# debug_vocab(wordpiece_vocab)  # Call this after initializing wordpiece_vocab

# Initialize WordPiece with the adapted vocabulary
wordpiece_tokenizer = WordPiece(wordpiece_vocab, unk_token_id=0, unk_token="[UNK]")

class WikiTextDatasetForLM(Dataset):
    def __init__(self, texts, tokenizer, sequence_length):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.inputs, self.labels = self.process_texts(texts)
        
    def process_texts(self, texts):
        inputs, labels = [], []
        for text in texts:
            token_ids = self.tokenizer.tokenize(text)
            for i in range(0, len(token_ids) - self.sequence_length, self.sequence_length):
                # Create sequences of `sequence_length` for inputs
                inputs.append(token_ids[i:i+self.sequence_length])
                # Shift by one for the labels to predict the next token
                labels.append(token_ids[i+1:i+self.sequence_length+1])
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

# Assuming vocab_size is determined by your tokenizer's vocabulary size
vocab_size = len(wordpiece_tokenizer.vocab)  # Adjust this to match your tokenizer's vocab size method
print(vocab_size)
# Define other parameters for SPLASH model instantiation
embed_size = 512  # Embedding size
heads = 8  # Number of attention heads
sequence_length = 1024  # Input sequence length
projection_dim = 256  # Dimension for projections inside SPLASH
partition_size = 128  # Size of partitions for SPLASH processing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate SPLASH Model
splash_model = SPLASH(vocab_size=vocab_size, embed_size=embed_size, heads=heads, 
                      sequence_length=sequence_length, projection_dim=projection_dim, 
                      partition_size=partition_size).to(device)

# Load Wikipedia dataset and preprocess
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")
batch_size = 32  # Adjust based on your hardware capabilities
num_epochs  = 5
loss_function = nn.CrossEntropyLoss()  # Example, adjust as needed
optimizer = torch.optim.Adam(splash_model.parameters(), lr=1e-4)
# Extract texts from the dataset
texts = dataset['text']

# Initialize the dataset for language modeling
wiki_dataset_for_lm = WikiTextDatasetForLM(texts, wordpiece_tokenizer, sequence_length=1024)
print(f"len(wiki_dataset_for_lm): {len(wiki_dataset_for_lm)}")

# DataLoader
data_loader = DataLoader(wiki_dataset_for_lm, batch_size=32, shuffle=True)

for epoch in range(num_epochs):
    for i, (input_ids, labels) in enumerate(data_loader):
        print(f"Batch {i} input_ids shape: {input_ids.shape}, labels shape: {labels.shape}")
        if i > 5:  # Just to limit output for debugging
            break

    '''for input_ids, labels in data_loader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        print(f"input_ids: {input_ids.shape}")
        print(f"labels: {labels.shape}")

        optimizer.zero_grad()

        # Forward pass
        outputs = splash_model(input_ids)

        # Flatten the output and label tensors for use with CrossEntropyLoss
        outputs_flat = outputs.view(-1, outputs.size(-1))
        labels_flat = labels.view(-1)

        # Compute loss
        loss = loss_function(outputs_flat, labels_flat)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")'''



Trie built successfully.
2749
len(wiki_dataset_for_lm): 3
Batch 0 input_ids shape: torch.Size([3, 1024]), labels shape: torch.Size([3, 1024])
Batch 0 input_ids shape: torch.Size([3, 1024]), labels shape: torch.Size([3, 1024])
Batch 0 input_ids shape: torch.Size([3, 1024]), labels shape: torch.Size([3, 1024])
Batch 0 input_ids shape: torch.Size([3, 1024]), labels shape: torch.Size([3, 1024])
Batch 0 input_ids shape: torch.Size([3, 1024]), labels shape: torch.Size([3, 1024])


# v2

In [1]:
import numpy as np
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
# Test on larger text
import re
import collections
from collections import Counter, defaultdict
import json

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from transformers import BertTokenizer
from torch.utils.data import DataLoader
import fitz  # PyMuPDF
from transformers import BertModel
import math
from torch.optim import AdamW
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from tqdm.notebook import tqdm
from torch.nn.utils.rnn import pad_sequence


class SPLASH(nn.Module):
    def __init__(self, vocab_size, embed_size, heads, sequence_length, projection_dim, partition_size):
        super().__init__()
        # Embedding layer added
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Existing initialization code...
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.value_projection = nn.Linear(self.head_dim, self.projection_dim//2)  # To project values to match dimensions
        self.ponder = nn.Linear(self.partition_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()
        self.final_projection = nn.Linear(self.heads * (self.projection_dim // 2), vocab_size)


    def random_projection(self, matrix, k):
        """Random projection to reduce dimensionality of matrix to k dimensions."""
        random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
        return torch.matmul(matrix, random_matrix)

    def cur_decomposition(self, matrix, projection_dim):
        """Applies CUR decomposition with C matrix dimension aligned to projection dimension."""
        batch_size, seq_length, heads, dim = matrix.shape
        k = min(projection_dim // 2, dim)
        C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
        R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)

        for b in range(batch_size):
            for h in range(heads):
                col_indices = np.random.choice(dim, k, replace=False)
                row_indices = np.random.choice(seq_length, k, replace=False)
                C[b, :, h] = matrix[b, :, h, col_indices]
                R[b, :, h] = matrix[b, row_indices, h]
        return C, R

    def forward(self, input_ids):
        print(f"BEGIN FORWARD")
        x = self.embedding(input_ids)
        print(f"x input shape: {x.shape}")
        N, seq_length, _ = x.shape
        x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
        print(f"x_reshaped shape: {x_reshaped.shape}")

        values = self.values(x_reshaped)
        queries = self.random_projection(self.queries(x_reshaped), self.projection_dim // 2)
        keys = self.random_projection(self.keys(x_reshaped), self.projection_dim // 2 )
        print("values :", values.shape)
        print("queries :", queries.shape)
        print("keys :", keys.shape)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)
        print(f"attention_scores after random projection: {attention_scores.shape}")

        for i in range(0, seq_length, self.partition_size):
            print(f"PARTITION START")
            partition_start = i
            partition_end = min(i + self.partition_size, seq_length)
            keys_part = keys[:, partition_start:partition_end, :, :]
            queries_part = queries[:, partition_start:partition_end, :, :]

            C_keys, R_queries = self.cur_decomposition(keys_part, self.projection_dim)
            print("C_keys shape before return:", C_keys.shape)

            ponder_scores = torch.zeros(N, self.heads, partition_end - partition_start, 1, device=x.device)
            print(f"Partition Start {i}, Partition End {partition_end} , ponder_scores: {ponder_scores.shape}")

            for h in range(self.heads):
                #print(f"HEADS START")
                head_queries = queries_part[:, :, h, :]
                #print(f"head_queries: {head_queries.shape}")
                head_ponder_scores = self.sigmoid(self.ponder(head_queries))
                #print(f"head_ponder_scores: {head_ponder_scores.shape}")
                ponder_scores[:, h, :, 0] = head_ponder_scores.squeeze(-1)

            # Correctly expand ponder_scores without adding an unnecessary dimension
            print("BEFORE 1ST EINSUM:")
            ponder_scores_permuted = ponder_scores.permute(0, 2, 1, 3)  # Move to [2, 128, 8, 1]
            print("ponder_scores_permuted shape:", ponder_scores_permuted.shape) 
            ponder_scores_broadcastable = ponder_scores_permuted.expand(-1, -1, -1, 128)  # Expand to [2, 128, 8, 128]            
            print("ponder_scores_broadcastable shape:", ponder_scores_broadcastable.shape) 
            print("queries_part shape:", queries_part.shape) 
            print("C_keys shape:", C_keys.shape)
            energy = torch.einsum('bnhd,bnhk->bnhd', queries_part, C_keys)
            attention_weights = F.softmax(energy, dim=-1)
            print("AFTER 1ST EINSUM:")
            print("energy shape:", energy.shape) 
            print("attention_weights shape:", attention_weights.shape)
            attention = attention_weights * ponder_scores_broadcastable
            print("attention shape:", attention.shape)
            attention_corrected = attention.permute(0, 2, 1, 3)
            attention_scores[:, :, partition_start:partition_end, :] = attention_corrected

        values = values.permute(0, 2, 1, 3)  # Swap heads and seq_length to bring heads next to head_dim
        print("values shape:", values.shape)
        values = values.reshape(-1, self.head_dim)  # Flatten to [N*heads*seq_length, head_dim] for linear layer
        print("values.reshape(-1, self.head_dim) shape:", values.shape)
        projected_values = self.value_projection(values)  # Now [N*heads*seq_length, projection_dim / 2]
        print("self.value_projection(values) shape:", projected_values.shape)
        projected_values = projected_values.view(N, self.heads, seq_length, self.projection_dim // 2)
        print("projected_values shape:", projected_values.shape)

        print(f"2ND EINSUM")
        # Combine attention_scores and projected_values then pass through the final linear layer
        '''out = torch.einsum('bnhp,bnhp->bnh', attention_scores, projected_values)
        print("out shape:", out.shape)
        out = self.final_projection(out.view(-1, self.embed_size))  # Reshape and project
        print(f"out from self.final_projection(out.view(-1, self.embed_size)) : {out.shape}")
        print(f"Final output: { out.view(-1, vocab_size).shape }")
        return out.view(-1, vocab_size)  # Reshape to [batch_size, sequence_length, vocab_size]'''
        # Combine attention_scores and projected_values then pass through the final linear layer
        out = torch.einsum('bnhp,bnhp->bnhp', attention_scores, projected_values)
        print("out shape after einsum:", out.shape)

        # Correct reshaping: Flatten batch and sequence length dimensions, keep the last two dimensions for projection
        out = out.reshape(-1, self.heads * (self.projection_dim // 2))
        print("out reshaped for projection:", out.shape)

        # Ensure the final_projection layer matches the flattened shape expected after reshaping
        # Assuming final_projection is defined as nn.Linear(self.heads * (self.projection_dim // 2), vocab_size)
        out = self.final_projection(out)
        print(f"out after final_projection:", out.shape)

        # At this point, out should have a shape of [batch_size * sequence_length, vocab_size], ready for loss calculation
        return out

##################################################
# Tokenizer

class TrieNode:
    def __init__(self):
        self.children = {}
        self.token_id = None
        self.frequency = 0
        self.failure_link = None
        self.is_end = False  # Add is_end attribute to mark the end of a word
        self.token = None  # Add token attribute to store the token associated with the node


class Trie:
    def __init__(self, unk_token_id=0):
        self.root = TrieNode()
        self.unk_token_id = unk_token_id

    def insert(self, token, token_id, frequency):
        node = self.root
        for char in token:
            if char not in node.children:
                node.children[char] = TrieNode()
            node = node.children[char]
        node.token_id = token_id
        node.frequency = frequency

    def find_subwords(self, token):
        """Finds the most probable subwords based on frequency."""
        node = self.root
        best_subwords = []

        def dfs(current_node, subword='', collected_subwords=[]):
            if current_node.token_id is not None:
                # Update to correctly calculate total_frequency based on the structure of collected_subwords
                total_frequency = sum(n.frequency for _, _, n in collected_subwords) + current_node.frequency
                probability = current_node.frequency / total_frequency if total_frequency else 0
                collected_subwords.append((subword, probability, current_node))

            for char, next_node in current_node.children.items():
                dfs(next_node, subword + char, list(collected_subwords))  # Create a copy of the list to avoid shared state

        dfs(node)
        best_subwords = sorted(best_subwords, key=lambda x: x[1], reverse=True)
        return [subword for subword, _, _ in best_subwords][:5] or [self.unk_token_id]


    def compute_failure_links(self):
        root = self.root
        root.failure_link = root  # Root's failure link points to itself
        queue = [root]

        while queue:
            current_node = queue.pop(0)

            for char, child_node in current_node.children.items():
                queue.append(child_node)

                # Follow failure link to find the longest suffix for the child_node
                failure_candidate = current_node.failure_link
                while failure_candidate != root and char not in failure_candidate.children:
                    failure_candidate = failure_candidate.failure_link
                child_node.failure_link = failure_candidate.children.get(char, root)


class SimpleSentencePiece:
    def __init__(self, model_type="bpe", vocab_size=30522):
        self.vocab = {}
        self.id_to_subword = {}
        self.unk_token = "[UNK]"
        self.unk_token_id = 0
        self.vocab_size = vocab_size
        self.model = None if model_type == "bpe" else None
        self.model_type = model_type

    def train(self, text):
        if self.model_type == "bpe":
            self.model = BPE(num_merges=self.vocab_size, unk_token_id=self.unk_token_id)
            self.model.train(text)
            self.vocab = self.model.vocab
            self.id_to_subword = {i: word for word, i in self.vocab.items()}
        else:
            raise NotImplementedError(f"Model type {self.model_type} not supported yet.")

    def encode(self, text):
        text = self.preprocess_text(text)  # Preprocess text before encoding
        if not self.model:
            raise ValueError("Model has not been trained yet.")
        encoded = self.model.encode(text)
        #print(f"Encoded: {encoded[:10]}")  # Print first 10 encoded tokens
        return encoded

    def decode(self, ids):
        if not self.id_to_subword:
            raise ValueError("Vocabulary is empty. Ensure the model is trained first.")
        text = " ".join([self.id_to_subword.get(id_, self.unk_token) for id_ in ids])
        text = text.replace(" </w>", "").replace("</w>", " ").strip()
        return text

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()
        # Optionally, handle punctuation by adding spaces around it for better tokenization
        text = re.sub(r'([.,!?()])', r' \1 ', text)
        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)
        # Trim leading and trailing spaces
        text = text.strip()
        return text
    
    def save_model(self, filepath):
        model_data = {
            'vocab': self.vocab,
            'id_to_subword': self.id_to_subword,
            'model_type': self.model_type,
            'vocab_size': self.vocab_size,
            # Potentially include other relevant attributes
        }
        # Save the high-level tokenizer settings
        with open(filepath, 'w') as f:
            json.dump(model_data, f)
        
        # Now save the BPE model specifically
        if self.model_type == "bpe" and self.model:
            self.model.save_model(filepath + "_bpe")

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            model_data = json.load(f)
        
        self.vocab = model_data['vocab']
        self.id_to_subword = model_data['id_to_subword']
        self.model_type = model_data['model_type']
        self.vocab_size = model_data['vocab_size']
        
        # Assuming model_type is still "bpe", we now load the BPE model
        if self.model_type == "bpe":
            self.model = BPE(self.vocab_size, self.unk_token_id)
            self.model.load_model(filepath + "_bpe")

class BPE:
    def __init__(self, num_merges=100, unk_token_id=0):  # Accept unk_token_id parameter
        self.vocab = {}
        self.merges = []
        self.num_merges = num_merges
        self.unk_token_id = unk_token_id  # Store the unknown token ID

    def train(self, text):
        words = re.findall(r'\w+|[^\w\s]', text, re.UNICODE)
        vocab = collections.Counter(words)
        vocab = {word + '</w>': count for word, count in vocab.items()}
        
        for _ in range(self.num_merges):  # Use the num_merges from the instance variable
            pairs = self.get_stats(vocab)
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            vocab = self.merge_vocab(best, vocab)
            self.merges.append(best)

        self.vocab = {word: i for i, word in enumerate(vocab.keys())}

    @staticmethod
    def get_stats(vocab):
        pairs = collections.defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols)-1):
                pairs[symbols[i], symbols[i+1]] += freq
        return pairs

    @staticmethod
    def merge_vocab(pair, vocab):
        v_out = {}
        bigram = re.escape(' '.join(pair))
        p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        for word in vocab:
            w_out = p.sub(''.join(pair), word)
            v_out[w_out] = vocab[word]
        return v_out

    def encode(self, text):
        """Encode text into subwords using learned BPE merges."""
        encoded_tokens = []
        for word in re.findall(r'\w+|[^\w\s]', text, re.UNICODE):
            word += '</w>'
            subwords = [word]  # Start with the entire word as one subword
            for merge in self.merges:
                new_subwords = []
                for subword in subwords:
                    # If the merge is in subword, split it; otherwise, keep it as is
                    if ' '.join(merge) in subword:
                        new_subwords.extend(subword.replace(' '.join(merge), ''.join(merge)).split(' '))
                    else:
                        new_subwords.append(subword)
                subwords = new_subwords
            encoded_tokens.extend(subwords)
        return [self.vocab.get(token, self.unk_token_id) for token in encoded_tokens]
    
        # New method to save trained model
    def save_model(self, filepath):
        bpe_data = {
            'merges': self.merges,
            'vocab': self.vocab,
            'num_merges': self.num_merges,
            # Include other attributes as needed
        }
        with open(filepath, 'w') as f:
            json.dump(bpe_data, f)

    def load_model(self, filepath):
        with open(filepath, 'r') as f:
            bpe_data = json.load(f)
        
        self.merges = bpe_data['merges']
        self.vocab = bpe_data['vocab']
        self.num_merges = bpe_data['num_merges']


class WordPiece:
    def __init__(self, vocab, unk_token_id=0, unk_token="[UNK]"):
        self.vocab = vocab
        self.unk_token_id = unk_token_id
        self.unk_token = unk_token  # Define the unknown token
        self.root = self.build_trie(vocab)
        self.id_to_token = {id_: token for token, id_ in vocab.items()}  # Inverse mapping
        self.compute_failure_links(self.root)
        print("Trie built successfully.")

    def convert_ids_to_tokens(self, ids):
        """
        Convert a list of token ids back to their string token representations.
        """
        return [self.id_to_token.get(id_, self.unk_token) for id_ in ids]

    # Add debug prints to build_trie to confirm structure
    def build_trie(self, vocab):
        root = TrieNode()
        for token in vocab:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_end = True
            node.token = token
        #print("Trie Construction Completed Successfully")
        return root


    def compute_failure_links(self, root):
        queue = [root]
        while queue:
            current_node = queue.pop(0)
            for char, child_node in current_node.children.items():
                failure_node = current_node.failure_link
                while failure_node and char not in failure_node.children:
                    failure_node = failure_node.failure_link
                child_node.failure_link = failure_node.children[char] if failure_node else root
                queue.append(child_node)

    # Improved debug prints in tokenize method
                
    def tokenize(self, text):
        # Preprocess input text
        text = self.preprocess_text(text)
        node = self.root
        token_ids = []  # Will store token IDs instead of tokens
        i = 0

        while i < len(text):
            char = text[i]
            if char == ' ':
                node = self.root
                i += 1
                continue

            if char not in node.children:
                if node != self.root and node.token is not None:
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root  # Reset to root
                    continue
                else:
                    # Append unknown token ID
                    token_ids.append(self.unk_token_id)
                    i += 1
                    continue

            node = node.children[char]
            if node.is_end:
                if i + 1 == len(text) or text[i + 1] == ' ':
                    # Convert found token to its ID
                    token_id = self.vocab.get(node.token, self.unk_token_id)
                    token_ids.append(token_id)
                    node = self.root

            i += 1

        #print(f"Token IDs: {token_ids[:10]}")
        return token_ids

    def preprocess_text(self, text):
        # Convert text to lowercase to ensure case insensitivity
        text = text.lower()

        # Optionally, handle punctuation by adding spaces around it for better tokenization
        # This depends on how your vocabulary handles punctuation
        text = re.sub(r'([.,!?()])', r' \1 ', text)

        # Replace multiple spaces with a single space
        text = re.sub(r'\s+', ' ', text)

        # Trim leading and trailing spaces
        text = text.strip()

        return text


def load_corpus(file_path):
    with open(file_path, 'r', encoding='utf-8') as file:
        texts = [line.strip() for line in file.readlines()]
    return texts

texts = load_corpus("D:\\EXPERT_WEIGHTS\\sample.txt")
#texts = load_corpus("/content/drive/MyDrive/EXPERT_STUFF/sample.txt")
# texts = load_corpus("C:/Users/robbi/Expert/sample.txt")
num_merges = 100
def adapt_vocab_for_wordpiece(ssp_vocab):
    adapted_vocab = {}
    for token, id_or_freq in ssp_vocab.items():
        if not token.startswith(" ") and not token.endswith("</w>"):
            adapted_token = "##" + token.replace("</w>", "")
        else:
            adapted_token = token.replace("</w>", "")
        adapted_vocab[adapted_token] = id_or_freq
    return adapted_vocab

# Initialize and train the SimpleSentencePiece model with BPE
ssp = SimpleSentencePiece(model_type="bpe", vocab_size=30522)
# Assume `texts` is a list of text to train the tokenizer
ssp.train('\n'.join(texts))
wordpiece_vocab = adapt_vocab_for_wordpiece(ssp.vocab)
# Debugging step to ensure vocabulary completeness
def debug_vocab(adapted_vocab):
    print("Sample Vocabulary Check:")
    # Iterate over the first 10 key-value pairs in the adapted vocabulary
    for i, (token, id_or_freq) in enumerate(adapted_vocab.items()):
        print(f"{token}: {id_or_freq}")
        if i >= 9:  # Stop after printing 10 entries
            break
    # Specifically check for subtokens if your tokenizer expects them
    subtokens = [token for token in adapted_vocab.keys() if token.startswith("##")]
    print(f"Found {len(subtokens)} subtokens in vocabulary.")

# Ensure wordpiece_vocab is a list of vocabulary tokens
# debug_vocab(wordpiece_vocab)  # Call this after initializing wordpiece_vocab

# Initialize WordPiece with the adapted vocabulary
wordpiece_tokenizer = WordPiece(wordpiece_vocab, unk_token_id=0, unk_token="[UNK]")


class WikiTextDatasetForLM(Dataset):
    def __init__(self, texts, tokenizer, sequence_length):
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length
        self.inputs, self.labels = self.process_texts(texts)
        
    def process_texts(self, texts):
        inputs, labels = [], []
        step_size = 256  # Example step size for overlapping sequences
        for text in texts:
            token_ids = self.tokenizer.tokenize(text)
            for i in range(0, len(token_ids) - self.sequence_length, step_size):  # Adjust step_size for overlap
                inputs.append(token_ids[i:i+self.sequence_length])
                labels.append(token_ids[i+1:i+self.sequence_length+1])
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]


# Assuming vocab_size is determined by your tokenizer's vocabulary size
vocab_size = len(wordpiece_tokenizer.vocab)  # Adjust this to match your tokenizer's vocab size method
print(vocab_size)
# Define other parameters for SPLASH model instantiation
embed_size = 512  # Embedding size
heads = 8  # Number of attention heads
sequence_length = 1024  # Input sequence length
projection_dim = 256  # Dimension for projections inside SPLASH
partition_size = 128  # Size of partitions for SPLASH processing
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate SPLASH Model
splash_model = SPLASH(vocab_size=vocab_size, embed_size=embed_size, heads=heads, 
                      sequence_length=sequence_length, projection_dim=projection_dim, 
                      partition_size=partition_size).to(device)

# Load Wikipedia dataset and preprocess
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train")
batch_size = 32  # Adjust based on your hardware capabilities
num_epochs  = 5
loss_function = nn.CrossEntropyLoss()  # Example, adjust as needed
optimizer = torch.optim.Adam(splash_model.parameters(), lr=1e-4)
# Extract texts from the dataset
texts = dataset['text']

# Initialize the dataset for language modeling
wiki_dataset_for_lm = WikiTextDatasetForLM(texts, wordpiece_tokenizer, sequence_length=1024)
print(f"len(wiki_dataset_for_lm): {len(wiki_dataset_for_lm)}")

# DataLoader
data_loader = DataLoader(wiki_dataset_for_lm, batch_size=32, shuffle=True)

for epoch in range(num_epochs):
    for i, (input_ids, labels) in enumerate(data_loader):
        print(f"Batch {i} input_ids shape: {input_ids.shape}, labels shape: {labels.shape}")
        if i > 5:  # Just to limit output for debugging
            break

    for input_ids, labels in data_loader:
        input_ids, labels = input_ids.to(device), labels.to(device)
        print(f"input_ids: {input_ids.shape}")
        print(f"labels: {labels.shape}")

        optimizer.zero_grad()

        # Forward pass
        outputs = splash_model(input_ids)

        # Flatten the output and label tensors for use with CrossEntropyLoss
        outputs_flat = outputs.view(-1, outputs.size(-1))
        labels_flat = labels.view(-1)

        # Compute loss
        loss = loss_function(outputs_flat, labels_flat)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")



Trie built successfully.
2749
len(wiki_dataset_for_lm): 3
Batch 0 input_ids shape: torch.Size([3, 1024]), labels shape: torch.Size([3, 1024])
input_ids: torch.Size([3, 1024])
labels: torch.Size([3, 1024])
BEGIN FORWARD
x input shape: torch.Size([3, 1024, 512])
x_reshaped shape: torch.Size([3, 1024, 8, 64])
values : torch.Size([3, 1024, 8, 64])
queries : torch.Size([3, 1024, 8, 128])
keys : torch.Size([3, 1024, 8, 128])
attention_scores after random projection: torch.Size([3, 8, 1024, 128])
PARTITION START
C_keys shape before return: torch.Size([3, 128, 8, 128])
Partition Start 0, Partition End 128 , ponder_scores: torch.Size([3, 8, 128, 1])
BEFORE 1ST EINSUM:
ponder_scores_permuted shape: torch.Size([3, 128, 8, 1])
ponder_scores_broadcastable shape: torch.Size([3, 128, 8, 128])
queries_part shape: torch.Size([3, 128, 8, 128])
C_keys shape: torch.Size([3, 128, 8, 128])
AFTER 1ST EINSUM:
energy shape: torch.Size([3, 128, 8, 128])
attention_weights shape: torch.Size([3, 128, 8, 128])
atte

# Testing 


In [2]:
from datasets import load_dataset
import torch

# Placeholder for a simple token to ID mapping
token_to_id = {}
next_id = 1  # Start with 1 since 0 might be reserved for padding

def simulate_tokenize(text):
    global next_id
    # Split text into "tokens" based on spaces
    tokens = text.split()
    # Convert tokens to IDs
    token_ids = []
    for token in tokens:
        if token not in token_to_id:
            token_to_id[token] = next_id
            next_id += 1
        token_ids.append(token_to_id[token])
    return token_ids

def process_texts(texts, sequence_length=512, step_size=256):
    inputs, labels = [], []
    for text in texts:
        token_ids = simulate_tokenize(text)
        print(f"Text length (tokens): {len(token_ids)}")
        num_sequences = 0
        for i in range(0, len(token_ids) - sequence_length, step_size):
            inputs.append(token_ids[i:i+sequence_length])
            labels.append(token_ids[i+1:i+sequence_length+1])
            num_sequences += 1
        print(f"Generated {num_sequences} sequences from text")
    # Convert lists of integers to tensors
    return torch.tensor(inputs, dtype=torch.long), torch.tensor(labels, dtype=torch.long)

# Load a subset of the dataset for demonstration
dataset = load_dataset("wikitext", "wikitext-2-v1", split="train[:5%]")

# Print lengths of a few sample texts
texts = dataset['text'][:5]
for i, text in enumerate(texts):
    print(f"Sample {i+1} length: {len(text.split())} words")

# Merge texts into a single continuous text
continuous_text = ' '.join(dataset['text'])
texts = [continuous_text]  # Now you have a single large text

# Process the merged text
inputs, labels = process_texts(texts)


Sample 1 length: 0 words
Sample 2 length: 5 words
Sample 3 length: 0 words
Sample 4 length: 127 words
Sample 5 length: 91 words
Text length (tokens): 105648
Generated 411 sequences from text
