In [1]:
# Read the whole corpus into memory. The encoding is given explicitly so the
# file decodes the same way on every platform instead of falling back to the
# locale's default encoding.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
data = text[:1000] # first 1,000 characters
print(data[:200])


First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [2]:
print(len(text))

1115394


In [3]:
print(len(data))

1000


In [4]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
# sorted() accepts the set directly and already returns a list.
chars = sorted(set(text))
vocab_size = len(chars)

In [5]:
print(vocab_size)

65


In [6]:
print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


# Tokenization

## Character Tokenization

In [7]:
# Character-level tokenizer built from two lookup tables over the sorted `chars` list.
# A character's position in `chars` is its integer id.
stoi = { ch:i for i,ch in enumerate(chars) }  # character -> integer id
itos = { i:ch for i,ch in enumerate(chars) }  # integer id -> character
# NOTE: the names `encode` and `decode` are rebound further down by the byte-level BPE versions.
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers (KeyError for a character that is not in `chars`)
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string


<pre>
<b>stoi</b> = “string to integer”  
→ A dictionary that maps each character in the list <i>chars</i> to a unique integer ID.

<b>itos</b> = “integer to string”  
→ A reverse dictionary that maps integers back to their corresponding characters.

<b>Example:</b>
    chars = ['a', 'b', 'c']

    stoi = {'a': 0, 'b': 1, 'c': 2}
    itos = {0: 'a', 1: 'b', 2: 'c'}

<b>{ch: i for i, ch in enumerate(chars)}</b>  
→ A dictionary comprehension.

<b>Meaning:</b>
    For each (i, ch) pair from enumerate(chars),
    Create a dictionary entry where:
        - key = ch
        - value = i

<b>Result:</b>
    {'a': 0, 'b': 1, 'c': 2}

This is called a "lookup table".
In this case, it maps characters to integers.

--------------------------------------------------

<b>List Comprehension Syntax:</b>
    [new_element for variable in iterable]

<b>Equivalent to:</b>
    result = []
    for variable in iterable:
        result.append(new_element)

<b>Example:</b>
    numbers = [1, 2, 3, 4]

    # Using a loop:
    squares = []
    for num in numbers:
        squares.append(num ** 2)

    # Using list comprehension:
    squares = [num ** 2 for num in numbers]

--------------------------------------------------

<b>Dictionary Comprehension Syntax:</b>
    {key_expr: value_expr for item in iterable}

<b>Example:</b>
    squares_dict = {num: num ** 2 for num in range(5)}
    # Output: {0: 0, 1: 1, 2: 4, 3: 9, 4: 16}

--------------------------------------------------

<b>Lambda Function Template:</b>
    lambda arguments: expression

<b>Example:</b>
    add = lambda x, y: x + y
    add(3, 5) => 8

<b>encode = lambda s: [stoi[c] for c in s]</b>

<b>Explanation:</b>
    - Takes one input s (a string)
    - Uses a list comprehension to:    
        • Iterate over each character c in s  
        • Look up its integer value from the 'stoi' dictionary  
        • Return the list of integers (encoded form of the string)

--------------------------------------------------

<b>[stoi[c] for c in s]</b>
    → A list comprehension.

<b>General Template:</b>
    [new_element for variable in iterable]

<b>Equivalent to:</b>
    result = []
    for variable in iterable:
        result.append(new_element)

<b>So,</b>
    [stoi[c] for c in s]

<b>Is the same as:</b>
    result = []
    for c in s:
        result.append(stoi[c])
</pre>


**How enumerate works** <br>
*enumerate(chars) gives you pairs of (index, value) as you loop through the list.*


In [8]:
for i, ch in enumerate(chars[:20]):
    print(i, ch)

0 

1  
2 !
3 $
4 &
5 '
6 ,
7 -
8 .
9 3
10 :
11 ;
12 ?
13 A
14 B
15 C
16 D
17 E
18 F
19 G


In [9]:
print(encode("Hii !! How are you"))

[20, 47, 47, 1, 2, 2, 1, 20, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59]


In [10]:
print(decode([20, 47, 47, 1, 2, 2, 1, 20, 53, 61, 1, 39, 56, 43, 1, 63, 53, 59]))

Hii !! How are you


In [11]:
import torch
print(torch.__version__)

2.6.0


In [12]:
import torch
import torch.nn as nn
from torch.nn import functional as F

## Train test split

In [13]:

# Encode the full corpus once with the character-level encoder, then hold out
# the tail of the sequence for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(len(data) * 0.9)  # boundary index: the first 90% is train, the rest is val
train_data, val_data = data[:n], data[n:]

In [14]:
print(data.shape,data.dtype)

torch.Size([1115394]) torch.int64


In [15]:
print(data[:1000])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

## Utf encoding

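**With UTF-8 the base vocabulary is just the 256 possible byte values. UTF-16 and UTF-32 waste a lot of space on mostly-ASCII text: most of their bytes are zeros, as the outputs below show.**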

In [16]:
ord('P') # ord() = “ordinal” is used to convert a character into its Unicode (or ASCII) integer code point.

80

In [17]:
chr(80) # chr is used to convert it back

'P'

In [18]:
[ord(x) for x in "Hello how are you" ]

[72,
 101,
 108,
 108,
 111,
 32,
 104,
 111,
 119,
 32,
 97,
 114,
 101,
 32,
 121,
 111,
 117]

In [19]:
list("Hello how are you".encode("utf-8"))

[72,
 101,
 108,
 108,
 111,
 32,
 104,
 111,
 119,
 32,
 97,
 114,
 101,
 32,
 121,
 111,
 117]

In [20]:
list("Hello how are you".encode("utf-16"))

[255,
 254,
 72,
 0,
 101,
 0,
 108,
 0,
 108,
 0,
 111,
 0,
 32,
 0,
 104,
 0,
 111,
 0,
 119,
 0,
 32,
 0,
 97,
 0,
 114,
 0,
 101,
 0,
 32,
 0,
 121,
 0,
 111,
 0,
 117,
 0]

In [21]:
list("Hello how are you".encode("utf-32"))

[255,
 254,
 0,
 0,
 72,
 0,
 0,
 0,
 101,
 0,
 0,
 0,
 108,
 0,
 0,
 0,
 108,
 0,
 0,
 0,
 111,
 0,
 0,
 0,
 32,
 0,
 0,
 0,
 104,
 0,
 0,
 0,
 111,
 0,
 0,
 0,
 119,
 0,
 0,
 0,
 32,
 0,
 0,
 0,
 97,
 0,
 0,
 0,
 114,
 0,
 0,
 0,
 101,
 0,
 0,
 0,
 32,
 0,
 0,
 0,
 121,
 0,
 0,
 0,
 111,
 0,
 0,
 0,
 117,
 0,
 0,
 0]

In [22]:
print(data)

tensor([18, 47, 56,  ..., 45,  8,  0])


In [23]:
data = text[:1000] # first 1,000 characters
print(data[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [24]:
print(data)

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [25]:
token = text.encode("utf-8")

In [26]:
text = """On the night of 5 June, HMS Cardiff was stationed to the east of the islands to provide gunfire support to the land forces and intercept enemy aircraft. At around 02:00 a radar contact was detected; a British Army Air Corps Westland Gazelle helicopter was making a routine delivery of personnel and equipment to a radio rebroadcast station on East Falkland. From the contact's speed and course, Cardiff's operations room crew assumed it to be hostile. One Sea Dart missile was fired, missing the target. A second destroyed it.[1] The Gazelle's wreckage and crew were discovered the next morning, and the loss was attributed to enemy fire. Although Cardiff was suspected, later scientific tests on the wreckage proved inconclusive. No formal inquiry was held until four years later. Defending their claim that the helicopter had been lost in action, the United Kingdom's Ministry of Defence (MoD) stated that they had not wanted to "cause further anguish to relatives" while they were still trying to ascertain how the Gazelle had been shot down. The board of inquiry finally confirmed that the soldiers died due to friendly fire. It recommended that "neither negligence nor blame should be attributed to any individual",[2] but identified several factors. A lack of communication between the army and the navy meant that 5th Infantry Brigade had not notified anyone of the helicopter's flight. The navy had not informed the land forces that Cardiff had changed position to set up an ambush for Argentine aircraft travelling over the area. The helicopter's identification friend or foe (IFF) transmitter was turned off, because it caused interference with the army's Rapier anti-aircraft missile system. The board of inquiry's findings prompted criticism of the MoD's initial response to the incident."""
# UTF-8 encode the sample text and keep the bytes as a plain list of ints (0..255)
token = list(text.encode("utf-8"))

**Converting the text to UTF-8 bytes so that it can be tokenized at the byte level for Byte Pair Encoding (BPE).**

In [27]:
print(len(token))
print(token[:50]) 

1800
[79, 110, 32, 116, 104, 101, 32, 110, 105, 103, 104, 116, 32, 111, 102, 32, 53, 32, 74, 117, 110, 101, 44, 32, 72, 77, 83, 32, 67, 97, 114, 100, 105, 102, 102, 32, 119, 97, 115, 32, 115, 116, 97, 116, 105, 111, 110, 101, 100, 32]


In [28]:
def get_stats(ids):
    """Count how often each pair of adjacent elements occurs in `ids`.

    Returns a dict mapping (left, right) tuples to their number of occurrences,
    with keys in order of first appearance. Sequences shorter than two
    elements produce an empty dict.
    """
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts
    

**How zip works**
<pre>
```
a = [10, 20, 30, 40]
b = [20, 30, 40]
z = zip(a, b)
list(z) → [
    (10, 20),  # a[0], b[0]
    (20, 30),  # a[1], b[1]
    (30, 40)   # a[2], b[2]
    # a[3] has no match in b → ignored
]

a = [10, 20, 30, 40]
     ↓   ↓   ↓
b = [20, 30, 40]

Result:
[
  (10, 20),
  (20, 30),
  (30, 40)
]
```
</pre>


**How counts.get(pair, 0) works**
<pre>
counts.get(pair, 0)
	•	Looks for pair in the counts dictionary.
	•	If the pair already exists, it returns its current count.
	•	If the pair doesn’t exist, it returns 0 by default (thanks to the get(..., 0) part).
</pre>

In [29]:
stats = get_stats(token)
print(stats)

{(79, 110): 2, (110, 32): 13, (32, 116): 49, (116, 104): 35, (104, 101): 36, (101, 32): 53, (32, 110): 10, (110, 105): 6, (105, 103): 4, (103, 104): 3, (104, 116): 2, (116, 32): 30, (32, 111): 15, (111, 102): 10, (102, 32): 12, (32, 53): 2, (53, 32): 1, (32, 74): 1, (74, 117): 1, (117, 110): 5, (110, 101): 13, (101, 44): 2, (44, 32): 7, (32, 72): 1, (72, 77): 1, (77, 83): 1, (83, 32): 1, (32, 67): 5, (67, 97): 4, (97, 114): 14, (114, 100): 6, (100, 105): 11, (105, 102): 8, (102, 102): 5, (32, 119): 15, (119, 97): 9, (97, 115): 13, (115, 32): 27, (32, 115): 15, (115, 116): 14, (116, 97): 7, (97, 116): 17, (116, 105): 21, (105, 111): 8, (111, 110): 17, (101, 100): 24, (100, 32): 43, (116, 111): 14, (111, 32): 15, (32, 101): 4, (101, 97): 5, (32, 105): 17, (105, 115): 9, (115, 108): 1, (108, 97): 11, (97, 110): 21, (110, 100): 19, (100, 115): 1, (32, 112): 5, (112, 114): 3, (114, 111): 9, (111, 118): 4, (118, 105): 2, (105, 100): 5, (100, 101): 9, (32, 103): 1, (103, 117): 2, (110, 102): 

In [30]:
print(sorted(((v, k) for k, v in stats.items()), reverse=True))

[(53, (101, 32)), (49, (32, 116)), (43, (100, 32)), (36, (104, 101)), (35, (116, 104)), (30, (116, 32)), (27, (115, 32)), (27, (32, 97)), (25, (105, 110)), (24, (101, 100)), (23, (101, 114)), (21, (116, 105)), (21, (97, 110)), (20, (116, 101)), (19, (110, 100)), (18, (101, 110)), (17, (114, 101)), (17, (111, 110)), (17, (97, 116)), (17, (32, 105)), (16, (121, 32)), (16, (32, 102)), (15, (114, 32)), (15, (111, 32)), (15, (110, 116)), (15, (32, 119)), (15, (32, 115)), (15, (32, 111)), (14, (116, 111)), (14, (115, 116)), (14, (97, 114)), (13, (110, 101)), (13, (110, 32)), (13, (105, 114)), (13, (104, 97)), (13, (99, 111)), (13, (97, 115)), (13, (46, 32)), (13, (32, 104)), (12, (102, 32)), (12, (101, 108)), (11, (111, 114)), (11, (108, 97)), (11, (105, 116)), (11, (102, 105)), (11, (100, 105)), (11, (32, 99)), (10, (111, 102)), (10, (110, 103)), (10, (97, 100)), (10, (32, 110)), (10, (32, 98)), (9, (119, 97)), (9, (114, 111)), (9, (114, 97)), (9, (105, 115)), (9, (100, 101)), (9, (39, 115)

In [31]:
top_pair = max(stats,key=stats.get)
top_pair

(101, 32)

**Decoding using chr()**

In [32]:
# Rank all pairs by count (ties keep their original order) and keep the first five
ranked = sorted(stats.items(), key=lambda item: item[1], reverse=True)
top_5_pairs = [pair for pair, _ in ranked[:5]]

# For each pair show its count and the character behind each byte value
for i, pair in enumerate(top_5_pairs, start=1):
    print(f"\nTop {i}: Pair {pair} (count={stats[pair]})")
    print("Characters:")
    print("\n".join(f"'{chr(b)}'" for b in pair))


Top 1: Pair (101, 32) (count=53)
Characters:
'e'
' '

Top 2: Pair (32, 116) (count=49)
Characters:
' '
't'

Top 3: Pair (100, 32) (count=43)
Characters:
'd'
' '

Top 4: Pair (104, 101) (count=36)
Characters:
'h'
'e'

Top 5: Pair (116, 104) (count=35)
Characters:
't'
'h'


## Merge Function

In [33]:
def merge(ids, pair, idx):
  """Return a new list in which each adjacent occurrence of `pair` is replaced by `idx`.

  The scan runs left to right and matches never overlap: once a pair has been
  replaced, both of its elements are consumed. `ids` itself is not modified.
  """
  merged = []
  skip_next = False  # True when the current element was consumed by the previous match
  for pos, current in enumerate(ids):
    if skip_next:
      skip_next = False
      continue
    # only look ahead when there is a next element
    if pos + 1 < len(ids) and current == pair[0] and ids[pos + 1] == pair[1]:
      merged.append(idx)
      skip_next = True
    else:
      merged.append(current)
  return merged


## Merge Function for Byte Pair Encoding (BPE)


```python
def merge(ids, pair, idx):
```
This function replaces **every** occurrence of a specific **pair** of integers in a **list (ids)** with a new single integer token **(idx).**

Example -
	•	ids: A list of integers, each representing a character/token (e.g., [104, 101, 108, 108, 111] for "hello").
	•	pair: A tuple like (108, 108), meaning we want to merge this adjacent pair ('l', 'l').
	•	idx: A new integer token to represent the merged pair, like 999.

---
```python
newids = []
i = 0
```
•	newids will store the updated list after merging.<br>
•	i is your index to move through ids.

---

``` python
if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
```

This line checks whether the current and next elements in the list ids form the target pair that we want to merge.

---
 ``` python 
     i < len(ids) - 1 
 ```
•	Why: To make sure that we don’t go out of bounds when accessing ids[i+1].
•	You’re only allowed to look ahead one step if i is not the last index.
•	For example, if len(ids) == 5, valid i values to check ids[i+1] are 0, 1, 2, 3 (not 4).

---
 ``` python       
ids[i] == pair[0]
```
•	Checks whether the current token matches the first token of the pair we want to merge.
 ``` python 
ids[i+1] == pair[1]
```
•	Checks whether the next token matches the second token of the pair.

---
        
 ``` python 
    newids.append(idx)
    i += 2
```
Replace the entire pair with the new token idx.
	•	Skip both tokens (i += 2) since you merged them.

``` python
else:
    newids.append(ids[i])
    i += 1
```
If it’s not the target pair, just copy the current token to newids and move to the next one.



In [34]:
print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))


[5, 6, 99, 9, 1]


In [35]:
tokens2 = merge(token, top_pair, 256)
print(tokens2[:50])
print("length:", len(tokens2))

[79, 110, 32, 116, 104, 256, 110, 105, 103, 104, 116, 32, 111, 102, 32, 53, 32, 74, 117, 110, 101, 44, 32, 72, 77, 83, 32, 67, 97, 114, 100, 105, 102, 102, 32, 119, 97, 115, 32, 115, 116, 97, 116, 105, 111, 110, 101, 100, 32, 116]
length: 1747


In [36]:
# Train the BPE merges: repeatedly replace the most frequent adjacent pair with
# a fresh token id until the vocabulary reaches the desired size.
vocab_size = 276 # the desired final vocabulary size
num_merges = vocab_size - 256 # ids 0..255 are reserved for the raw byte values
ids = list(token) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(ids)
  if not stats:
    # Fewer than two tokens are left, so there is no pair to merge.
    # max() on an empty dict would raise ValueError, so stop early instead.
    break
  pair = max(stats, key=stats.get) # most frequent pair; the first one seen wins ties
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

merging (101, 32) into a new token 256
merging (32, 116) into a new token 257
merging (100, 32) into a new token 258
merging (257, 104) into a new token 259
merging (105, 110) into a new token 260
merging (115, 32) into a new token 261
merging (101, 114) into a new token 262
merging (116, 32) into a new token 263
merging (259, 256) into a new token 264
merging (97, 110) into a new token 265
merging (116, 105) into a new token 266
merging (101, 110) into a new token 267
merging (111, 110) into a new token 268
merging (121, 32) into a new token 269
merging (97, 114) into a new token 270
merging (105, 114) into a new token 271
merging (46, 32) into a new token 272
merging (101, 100) into a new token 273
merging (111, 32) into a new token 274
merging (101, 108) into a new token 275


In [37]:
print("tokens length:", len(token))
print("ids length:", len(ids))
print(f"compression ratio: {len(token) / len(ids):.2f}X")

tokens length: 1800
ids length: 1359
compression ratio: 1.32X


In [38]:
print(token[:50])
print("")
print(ids[:50])
print("")
print("length:", len(ids))

[79, 110, 32, 116, 104, 101, 32, 110, 105, 103, 104, 116, 32, 111, 102, 32, 53, 32, 74, 117, 110, 101, 44, 32, 72, 77, 83, 32, 67, 97, 114, 100, 105, 102, 102, 32, 119, 97, 115, 32, 115, 116, 97, 116, 105, 111, 110, 101, 100, 32]

[79, 110, 264, 110, 105, 103, 104, 263, 111, 102, 32, 53, 32, 74, 117, 110, 101, 44, 32, 72, 77, 83, 32, 67, 270, 100, 105, 102, 102, 32, 119, 97, 261, 115, 116, 97, 266, 268, 273, 257, 111, 264, 101, 97, 115, 263, 111, 102, 264, 105]

length: 1359


In [39]:
# Byte string for every token id: ids 0..255 are the single raw bytes, and each
# merged id is the concatenation of its two parents. This relies on `merges`
# iterating in insertion order, so both parents exist before a child is built.
vocab = {byte_value: bytes((byte_value,)) for byte_value in range(256)}
for parents, idx in merges.items():
    left, right = parents
    vocab[idx] = vocab[left] + vocab[right]

def decode(ids):
  """Turn a sequence of token ids back into a Python string.

  Each id is looked up in `vocab` to get its bytes; the bytes are concatenated
  and decoded as UTF-8. Invalid byte sequences are replaced with U+FFFD
  instead of raising.
  """
  raw = bytearray()
  for idx in ids:
    raw.extend(vocab[idx])
  return raw.decode("utf-8", errors="replace")

print(decode([122]))

z


In [40]:
merges

{(101, 32): 256,
 (32, 116): 257,
 (100, 32): 258,
 (257, 104): 259,
 (105, 110): 260,
 (115, 32): 261,
 (101, 114): 262,
 (116, 32): 263,
 (259, 256): 264,
 (97, 110): 265,
 (116, 105): 266,
 (101, 110): 267,
 (111, 110): 268,
 (121, 32): 269,
 (97, 114): 270,
 (105, 114): 271,
 (46, 32): 272,
 (101, 100): 273,
 (111, 32): 274,
 (101, 108): 275}

In [41]:
def encode(text):
  """Tokenize a string: UTF-8 encode it, then apply the learned merges.

  In each round, the pair present in the sequence with the lowest merge id
  is merged. The loop ends when no pair of the current sequence appears in
  `merges`, or when fewer than two tokens remain.
  """
  ids = list(text.encode("utf-8"))
  while len(ids) > 1:
    # (merge id, pair) for every pair of the sequence that has a learned merge
    mergeable = [(merges[p], p) for p in get_stats(ids) if p in merges]
    if not mergeable:
      break # nothing else can be merged
    new_id, pair = min(mergeable, key=lambda entry: entry[0])
    ids = merge(ids, pair, new_id)
  return ids

print(encode("hello world"))

[104, 275, 108, 274, 119, 111, 114, 108, 100]


In [42]:
print(decode(encode("hello world")))

hello world


In [43]:
import regex as re
# GPT-2 style split pattern. Alternatives, in order: the contractions
# 's 't 're 've 'm 'll 'd; an optional space followed by letters; an optional
# space followed by digits; an optional space followed by other non-space
# symbols; and runs of whitespace.
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

print(gpt2pat.findall("Hello've world123 how's   124324235235532are              you!!!?          "))

['Hello', "'ve", ' world', '123', ' how', "'s", '  ', ' 124324235235532', 'are', '             ', ' you', '!!!?', '          ']


**cl100k_base (GPT-4) split pattern**

In [44]:
# cl100k_base (GPT-4) split pattern. It is bound to its own name so it does not
# overwrite the GPT-2 pattern held in `gpt2pat`. Compared with the GPT-2
# pattern above: contractions match case-insensitively, digit runs are split
# into chunks of at most three, and possessive quantifiers (`?+`, `++`, `*+`)
# are used.
gpt4pat = re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s""")

print(re.findall(gpt4pat, "Hello've world123 how's 1132423134121214are          you!!!?          "))


['Hello', "'ve", ' world', '123', ' how', "'s", ' ', '113', '242', '313', '412', '121', '4', 'are', '         ', ' you', '!!!?', '          ']
