In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from merge.modules.config import TransformerConfig
from merge.modules.transformer import Transformer


In [None]:
from transformers import AutoTokenizer as HFAutoTokenizer


class HFTokenizerWrapper:
    """Thin wrapper around a HuggingFace ``AutoTokenizer``.

    Calls and any attribute lookups the wrapper does not define itself are
    forwarded to the wrapped tokenizer, so an instance can be used wherever the
    underlying tokenizer is expected.

    Args:
        name: Model name or path passed to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, name: str):
        super().__init__()
        self.tokenizer = HFAutoTokenizer.from_pretrained(name)

    def __getattr__(self, name):
        """Forward attribute lookups that normal lookup missed to the tokenizer."""
        # __getattr__ only runs when normal lookup fails. If `tokenizer` itself is
        # missing (e.g. an instance built via __new__ by copy/pickle before its
        # state is restored), delegating would re-enter __getattr__ forever and
        # raise RecursionError. Raise AttributeError instead so hasattr()/copy work.
        if name == "tokenizer":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'tokenizer'"
            )
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """Tokenize by delegating the call, unchanged, to the wrapped tokenizer."""
        return self.tokenizer(*args, **kwargs)

In [None]:
# Wrapped BERT tokenizer; `name` goes straight to AutoTokenizer.from_pretrained,
# which loads (and may download) the "bert-base-uncased" vocabulary.
tokenizer = HFTokenizerWrapper(name="bert-base-uncased")

In [None]:
text = "This is a test"
# The wrapper forwards positional args unchanged to the HuggingFace tokenizer,
# whose second positional parameter is `text_pair`, so a call such as
# `tokenizer(text, 2)` treats 2 as a (non-string) text pair and is rejected.
# Tokenize the single string. If a length cap is wanted, pass it by keyword,
# e.g. `tokenizer(text, max_length=..., truncation=True)`.
tokenizer(text)

In [None]:
# Reference output: the raw HuggingFace tokenizer on the same sentence, for
# comparison with what the wrapper returns.
text = "This is a test"
hftokenizer = HFAutoTokenizer.from_pretrained("bert-base-uncased")
hftokenizer(text)

In [None]:
import torch
from dataclasses import dataclass
from merge.modules.attention import GQA

from typing import Optional


@dataclass
class TransformerConfig:
    """Minimal config for the GQA smoke test in this cell.

    NOTE: this shadows the ``TransformerConfig`` imported from
    ``merge.modules.config`` earlier in the notebook and declares only the fields
    set here (presumably the ones ``GQA`` reads -- confirm against GQA).
    """

    num_heads: int = 8          # query heads
    num_kv_heads: int = 2       # key/value heads shared by groups of query heads
    d_model: int = 512          # feature width of the (batch, seq, d_model) input
    attention_bias: bool = False  # presumably toggles bias in the projections
    attention_dropout: float = 0.1
    pos_encoding_type: Optional[str] = None  # may be None, so typed Optional

# Smoke test: a GQA forward pass should preserve the input shape, and the Q / KV
# projections should have the widths implied by the configured head counts.
torch.manual_seed(0)  # reproducible random inputs and weight init

# Create test inputs
batch_size = 2
seq_length = 16
config = TransformerConfig()

# Initialize GQA layer. eval() puts the module in inference mode, disabling any
# nn.Dropout layers (presumably how attention_dropout is applied -- confirm in
# GQA), so re-running this cell is deterministic.
gqa = GQA(config)
gqa.eval()

# Dummy input of shape (batch, seq, d_model)
x = torch.randn(batch_size, seq_length, config.d_model)

# Additive causal mask of shape (1, 1, seq, seq): 0.0 on/below the diagonal
# (allowed positions) and -10000.0 above it (future positions).
mask = torch.tril(torch.ones((1, 1, seq_length, seq_length)))
mask = torch.where(mask == 1.0, 0.0, -10000.0)

# Shape checks only -- no gradients needed.
with torch.no_grad():
    # Forward pass
    output = gqa(x, mask, pos_info=None)

    # Recompute the intermediate projections directly for verification
    q = gqa.W_Q(x)
    kv = gqa.W_KV(x)
# W_KV appears to pack K and V along the last dim; split them back apart.
k, v = kv.chunk(2, dim=-1)

# Verification
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == x.shape, "Output shape doesn't match input shape"

# Expected widths: each head has d_model // num_heads dims, and K and V each
# carry num_kv_heads heads.
head_dim = config.d_model // config.num_heads
kv_width = head_dim * config.num_kv_heads

print("\nProjection shapes:")
print(f"Q projection: {q.shape}")
print(f"KV projection: {kv.shape}")
print(f"K / V after split: {k.shape} / {v.shape}")
assert q.shape == (batch_size, seq_length, config.d_model), "Unexpected Q projection shape"
assert kv.shape == (batch_size, seq_length, 2 * kv_width), "Unexpected KV projection shape"
assert k.shape == v.shape == (batch_size, seq_length, kv_width), "Unexpected K/V shape after split"

# Verify the grouping ratio
assert gqa.num_heads % gqa.num_kv_heads == 0, "Number of heads must be divisible by number of KV heads"
print(f"\nGrouping ratio (queries per k/v): {gqa.num_heads // gqa.num_kv_heads}")

print("\nTest passed!")

In [11]:


# Toy BPE corpus: each word is pre-split into space-separated symbols. "lo" is
# already a single symbol (presumably after merging "l" + "o"); the unsplit
# words are "low", "lowest", "newer" and "wider".
dataset = [
    "lo w",
    "lo w e s t",
    "n e w e r",
    "w i d e r",
]
# Next: count how often each pair of adjacent symbols occurs in the dataset.

from collections import Counter

def count_pairs(dataset):
    """Tally adjacent symbol pairs across a whitespace-tokenized corpus.

    Each entry of ``dataset`` is split on whitespace into symbols, and every
    pair of neighbouring symbols is counted once per occurrence.

    Args:
        dataset: Iterable of strings whose symbols are separated by whitespace.

    Returns:
        Counter mapping ``(left, right)`` symbol tuples to occurrence counts.
    """
    pairs = Counter()
    for symbols in (word.split() for word in dataset):
        # zip stops at the shorter input, so words with 0 or 1 symbols add nothing.
        pairs.update(zip(symbols, symbols[1:]))
    return pairs

# Display the pair counts for the toy corpus (Counter repr lists the most frequent first).
count_pairs(dataset=dataset)


Counter({('lo', 'w'): 2,
         ('w', 'e'): 2,
         ('e', 'r'): 2,
         ('e', 's'): 1,
         ('s', 't'): 1,
         ('n', 'e'): 1,
         ('e', 'w'): 1,
         ('w', 'i'): 1,
         ('i', 'd'): 1,
         ('d', 'e'): 1})

dict_items([])