## Problem: Write a Byte Pain Encoder in Python

### Problem Statement
Implement a **Transformer model** in PyTorch by completing the required sections. The model should consist of an embedding layer, a Transformer encoder, and an output layer for sequence processing and prediction.

### Requirements
1. **Define the Transformer Model Architecture**:
   - **Embedding Layer**:
     - Implement a layer to transform input data into a higher-dimensional space.
     - Use a `torch.nn.Linear` or `torch.nn.Embedding` layer to create embeddings from the input.
   - **Transformer Encoder**:
     - Use `torch.nn.TransformerEncoder` or `torch.nn.Transformer` to process sequences with attention.
     - Configure parameters such as the number of attention heads and encoder layers.
   - **Output Layer**:
     - Add a fully connected (linear) layer to reduce the transformer's sequence output into the desired output dimension.

2. **Implement the Forward Method**:
   - Map the input to the higher-dimensional space using the embedding layer.
   - Pass the transformed input through the Transformer encoder.
   - Use the output layer to convert the encoded sequence into predictions.

### Constraints
- Handle input padding correctly for variable-length sequences.
- Ensure compatibility with batch processing by correctly shaping input and output tensors.


In [107]:
sorted(tuple(list('asdfg'))[1:-1])
get_all_n_grams('ashutosh',n_min=2,n_max=2)

['as', 'sh', 'hu', 'ut', 'to', 'os', 'sh']

In [121]:
vocab

{('l', 'o', 'w', '</w>'): 1,
 ('l', 'o', 'w', 'e', 'r', '</w>'): 1,
 ('l', 'o', 'w', 'e', 's', 't', '</w>'): 1,
 ('n', 'e', 'w', 'e', 'r', '</w>'): 1,
 ('w', 'i', 'd', 'e', 'r', '</w>'): 1,
 'lo': 3}

In [147]:
list(pair.items())[0]

(('l', 'o'), 3)

In [174]:
corpus

['low', 'lower', 'lowest', 'newer', 'wider']

In [173]:
vocab = get_vocab(corpus)
pair = get_stats(vocab)
merge_vocab(pair, vocab)
byte_pair_encoding(corpus=corpus)

[('l', 'o'), ('o', 'w'), ('w', '</w>')]
[('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r'), ('r', '</w>')]
[('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 's'), ('s', 't'), ('t', '</w>')]
[('n', 'e'), ('e', 'w'), ('w', 'e'), ('e', 'r'), ('r', '</w>')]
[('w', 'i'), ('i', 'd'), ('d', 'e'), ('e', 'r'), ('r', '</w>')]
most_freq_pair_chars = lo | most_freq_pair = (('l', 'o'), 3)
[('l', 'o'), ('o', 'w'), ('w', '</w>')]
[('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r'), ('r', '</w>')]
[('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 's'), ('s', 't'), ('t', '</w>')]
[('n', 'e'), ('e', 'w'), ('w', 'e'), ('e', 'r'), ('r', '</w>')]
[('w', 'i'), ('i', 'd'), ('d', 'e'), ('e', 'r'), ('r', '</w>')]
stats = {('l', 'o'): 3, ('o', 'w'): 3, ('w', '</w>'): 1, ('w', 'e'): 3, ('e', 'r'): 3, ('r', '</w>'): 3, ('e', 's'): 1, ('s', 't'): 1, ('t', '</w>'): 1, ('n', 'e'): 1, ('e', 'w'): 1, ('w', 'i'): 1, ('i', 'd'): 1, ('d', 'e'): 1}
most_freq_pair_chars = lo | most_freq_pair = (('l', 'o'), 3)
vocab = {('lo', 'w', '</w>'): 1, (

IndexError: list index out of range

In [172]:
from collections import defaultdict, Counter

def get_vocab(corpus, cased=False):
    """Creates a vocabulary with words split into characters and a special end-of-word token."""
    eow_char = '</w>'
    from collections import Counter
    if cased == True:
        corpus = [lower(word.strip()) for word in corpus]
    word_frequency = Counter(corpus)
    vocab = {}
    for word in word_frequency:
        chars = list(word)
        chars.append(eow_char)
        chars = tuple(chars)
        vocab[chars] = word_frequency.get(word, 1)
    return vocab

# def get_all_n_grams(word):
#     n_grams = []
#     for n_gram in range(len(word)):
#         for i in range(len(word)-n_gram):
#             n_grams.append(word[i:i+n_gram+1])
#     return n_grams


def get_all_n_grams(word, n_min=None, n_max=None):
    n_grams = []
    if n_max == None:
        n_max = len(word)
    if n_min == None:
        n_min = 1
    for n_gram in range(n_min, n_max+1):
        for i in range(len(word)-(n_gram-1)):
            n_grams.append(word[i:i+n_gram])
    print(n_grams)
    return n_grams

def get_stats(vocab):
    """Counts frequency of adjacent symbol pairs."""
    stats = {}
    for characters, word_count in vocab.items():
        # word = ''.join(characters[:-1])
        # word_ngrams = get_all_n_grams(word)
        word_ngrams = get_all_n_grams(characters, n_min=2, n_max=2)
        for pair in word_ngrams:
            stats[pair] = stats.get(pair, 0) + word_count
    return stats

def get_merged_key(key, most_freq_pair):
    """
    1. abcde, bc
    2. abcde, fg
    3. abbcdd, bb
    """
    first_elem = most_freq_pair[0]
    second_elem = most_freq_pair[1]
    first_elem_already_match = False
    new_key = []
    for char in key:
        if char == second_elem:
            if first_elem_already_match:
                # new_key = new_key[:-1] + [first_elem+second_elem]
                new_key[-1] = first_elem+second_elem
        elif char == first_elem:
            first_elem_already_match = True
            new_key.append(char)
        else:
            new_key.append(char)

        if first_elem_already_match == True:
            first_elem_already_match = False
        if char == first_elem:
            first_elem_already_match = True
    return tuple(new_key)


def merge_vocab(pair, vocab):
    """Merges the most frequent pair into a single symbol."""
    sorted_pair_copy = sorted(pair.items(), key=lambda x: x[1], reverse=True)
    most_freq_pair = sorted_pair_copy[0]
    most_freq_pair_chars = ''.join(most_freq_pair[0])
    print(f"most_freq_pair_chars = {most_freq_pair_chars} | most_freq_pair = {most_freq_pair}")
    new_vocab = {}
    for key in vocab:
        val = vocab.get(key)
        new_key = get_merged_key(key, most_freq_pair[0])
        new_vocab[new_key] = val
    # new_vocab.update(vocab)
    return new_vocab

def byte_pair_encoding(corpus, num_merges=10):
    """Performs BPE on a corpus."""
    vocab = get_vocab(corpus)
    for i in range(num_merges):
        stats = get_stats(vocab)
        print(f"stats = {stats}")
        vocab = merge_vocab(stats, vocab)
        print(f"vocab = {vocab}")
        print("="*100)
    return vocab

# Example usage
corpus = ["low", "lower", "lowest", "newer", "wider"]
# get_vocab(corpus)
sorted(get_stats(get_vocab(corpus)).items(), key=lambda x: x[1], reverse=True)
# get_stats(get_vocab(corpus))
# final_vocab, merge_operations = byte_pair_encoding(corpus, num_merges=10)

# print("\nFinal Vocabulary:")
# for word in final_vocab:
#     print(' '.join(word), ":", final_vocab[word])


[('l', 'o'), ('o', 'w'), ('w', '</w>')]
[('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 'r'), ('r', '</w>')]
[('l', 'o'), ('o', 'w'), ('w', 'e'), ('e', 's'), ('s', 't'), ('t', '</w>')]
[('n', 'e'), ('e', 'w'), ('w', 'e'), ('e', 'r'), ('r', '</w>')]
[('w', 'i'), ('i', 'd'), ('d', 'e'), ('e', 'r'), ('r', '</w>')]


[(('l', 'o'), 3),
 (('o', 'w'), 3),
 (('w', 'e'), 3),
 (('e', 'r'), 3),
 (('r', '</w>'), 3),
 (('w', '</w>'), 1),
 (('e', 's'), 1),
 (('s', 't'), 1),
 (('t', '</w>'), 1),
 (('n', 'e'), 1),
 (('e', 'w'), 1),
 (('w', 'i'), 1),
 (('i', 'd'), 1),
 (('d', 'e'), 1)]

In [112]:
{('t', 'e', 's', 't', '</w>'): 1}

{('t', 'e', 's', 't', '</w>'): 1}

In [151]:
def test_get_vocab():
    corpus = ["test"]
    vocab = get_vocab(corpus)
    print(vocab)
    assert vocab == {('t', 'e', 's', 't', '</w>'): 1}
    print("✓ test_get_vocab passed")

def test_get_stats():
    vocab = {('t', 'e', 's', 't', '</w>'): 1}
    stats = get_stats(vocab)
    expected = {
        ('t', 'e'): 1,
        ('e', 's'): 1,
        ('s', 't'): 1,
        ('t', '</w>'): 1
    }
    assert stats == expected
    print("✓ test_get_stats passed")

def test_merge_vocab():
    vocab = {('t', 'e', 's', 't', '</w>'): 1}
    stats = {
        ('t', 'e'): 1,
        ('e', 's'): 2,
        ('s', 't'): 1,
        ('t', '</w>'): 1
    }
    merged = merge_vocab(stats, vocab)
    # merged = merge_vocab(('e', 's'), vocab)
    print(merged)
    expected = {('t', 'es', 't', '</w>'): 1}
    assert merged == expected
    print("✓ test_merge_vocab passed")

def test_bpe_sequence():
    corpus = ["low", "lower", "newest", "widest"]
    final_vocab, merges = byte_pair_encoding(corpus, num_merges=5)
    assert isinstance(final_vocab, dict)
    assert all(isinstance(pair, tuple) for pair in merges)
    assert len(merges) == 5
    print("✓ test_bpe_sequence passed")

# Run all tests
test_get_vocab()
test_get_stats()
test_merge_vocab()

{('t', 'e', 's', 't', '</w>'): 1}
✓ test_get_vocab passed
[('t', 'e'), ('e', 's'), ('s', 't'), ('t', '</w>')]
✓ test_get_stats passed
most_freq_pair_chars = es | most_freq_pair = (('e', 's'), 2)
{('t', 'es', 't', '</w>'): 1}
✓ test_merge_vocab passed
