# AlmondGPT

#### Importing the libraries

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os

#### Load the dataset

In [5]:
# Load the raw training corpus from disk as a single UTF-8 string.
path_file = r'../data/raw/input.txt'
try:
    with open(path_file, 'r', encoding='utf-8') as f:
        text = f.read()
except FileNotFoundError as e:
    # Only a genuinely missing file is reported as "not found". Other failures
    # (permission errors, invalid UTF-8, ...) propagate unchanged instead of
    # being mislabelled as a missing path.
    raise FileNotFoundError(f'This path:{path_file} not found in directory') from e
print('Dataset was succesfully loaded!!!')

Dataset was succesfully loaded!!!


#### Overview

In [6]:
# Report the corpus size (in characters) and show its first 1000 characters.
dataset_length = len(text)
dataset_preview = text[:1000]
print(f"Length of dataset: {dataset_length}")
print(f"Preview dataset:\n{dataset_preview}")

Length of dataset: 6701920
Preview dataset:
User: When did Virgin Australia start operating?
Almond: Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.

User: Which is a species of fish? Tope or Rope
Almond: Tope

User: Why can camels survive for long without water?
Almond: Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.

User: Alice's parents have three daughters: Amy, Jessy, and whatâ€™s the name of the third daughter?
Almond: The name of the third daughter is Alice

User: When was Tomoaki Komorida born?
Almond: Tomoaki Komorida was born on July 10,1981.

User: If I have more pieces at the time of stalemate, have I won?
Almond: No. 
Stalemate is a drawn position. It doesn't matter who has captured more pieces or is in a winning position

User: Given a reference text about Lollapalooza, where does it take place, who started it and what is it?
Almond: Lollapalooze is an ann

In [9]:
# Turn the text into its raw UTF-8 byte values, stored as a list of ints.
tokens = [int(byte) for byte in text.encode('utf-8')]
print(f"Length of tokens: {len(tokens)}")
print(f"Preview tokens:\n{tokens[:1000]}")

Length of tokens: 6711368
Preview tokens:
[85, 115, 101, 114, 58, 32, 87, 104, 101, 110, 32, 100, 105, 100, 32, 86, 105, 114, 103, 105, 110, 32, 65, 117, 115, 116, 114, 97, 108, 105, 97, 32, 115, 116, 97, 114, 116, 32, 111, 112, 101, 114, 97, 116, 105, 110, 103, 63, 10, 65, 108, 109, 111, 110, 100, 58, 32, 86, 105, 114, 103, 105, 110, 32, 65, 117, 115, 116, 114, 97, 108, 105, 97, 32, 99, 111, 109, 109, 101, 110, 99, 101, 100, 32, 115, 101, 114, 118, 105, 99, 101, 115, 32, 111, 110, 32, 51, 49, 32, 65, 117, 103, 117, 115, 116, 32, 50, 48, 48, 48, 32, 97, 115, 32, 86, 105, 114, 103, 105, 110, 32, 66, 108, 117, 101, 44, 32, 119, 105, 116, 104, 32, 116, 119, 111, 32, 97, 105, 114, 99, 114, 97, 102, 116, 32, 111, 110, 32, 97, 32, 115, 105, 110, 103, 108, 101, 32, 114, 111, 117, 116, 101, 46, 10, 10, 85, 115, 101, 114, 58, 32, 87, 104, 105, 99, 104, 32, 105, 115, 32, 97, 32, 115, 112, 101, 99, 105, 101, 115, 32, 111, 102, 32, 102, 105, 115, 104, 63, 32, 84, 111, 112, 101, 32, 111, 114, 32, 8

As you can see, the tokens above are still raw UTF-8 byte values. To compress them with BPE, we need to learn a set of merges and choose a vocabulary size.

#### Tokenization BPE (Byte Pair Encoding)

In [10]:
def get_stats(ids):
    '''Count how often each adjacent pair of ids occurs.

    Args:
        ids: sequence of token ids; it must support slicing (``ids[1:]``).

    Returns:
        dict mapping each ``(left, right)`` tuple of neighbouring ids to the
        number of times it appears, keyed in order of first occurrence.
        Empty when ``ids`` has fewer than two elements.
    '''
    pair_counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

def merge(ids, pair, idx):
    '''Replace every occurrence of ``pair`` in ``ids`` with the single id ``idx``.

    The scan runs left to right and matches do not overlap: once a pair has
    been replaced, scanning resumes after both of its elements.

    Args:
        ids: sequence of token ids (not modified).
        pair: two-element sequence ``(first, second)`` to look for.
        idx: id written in place of each matched pair.

    Returns:
        A new list with the replacements applied.
    '''
    merged = []
    pos = 0
    total = len(ids)
    while pos < total:
        is_match = (
            pos + 1 < total
            and ids[pos] == pair[0]
            and ids[pos + 1] == pair[1]
        )
        if is_match:
            merged.append(idx)
            pos += 2  # skip both members of the matched pair
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

# ------- train -----------
# Learn the BPE merge table: repeatedly take the most frequent adjacent pair
# in `ids`, replace it with a fresh token id, and record the pair -> id mapping.
vocab_size = 768 # size vocab
num_merges = vocab_size - 256 # default byte 0-255
ids = list(tokens) # duplicate to save real tokens
merges = {}

for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids remain, so there is no pair left to merge;
        # calling max() on an empty dict would raise ValueError.
        print(f'stopping after {i} merges: no adjacent pairs left to merge')
        break
    pair = max(stats, key=stats.get)
    idx = 256 + i  # new ids start right after the 256 raw byte values
    print(f'merging {pair} into a new token {idx}')
    ids = merge(ids, pair, idx)
    merges[pair] = idx

merging (101, 32) into a new token 256
merging (115, 32) into a new token 257
merging (116, 104) into a new token 258
merging (105, 110) into a new token 259
merging (101, 114) into a new token 260
merging (97, 110) into a new token 261
merging (100, 32) into a new token 262
merging (116, 32) into a new token 263
merging (111, 110) into a new token 264
merging (44, 32) into a new token 265
merging (258, 256) into a new token 266
merging (97, 114) into a new token 267
merging (111, 114) into a new token 268
merging (101, 110) into a new token 269
merging (121, 32) into a new token 270
merging (97, 108) into a new token 271
merging (111, 32) into a new token 272
merging (111, 117) into a new token 273
merging (116, 105) into a new token 274
merging (58, 32) into a new token 275
merging (97, 32) into a new token 276
merging (259, 103) into a new token 277
merging (114, 101) into a new token 278
merging (102, 32) into a new token 279
merging (46, 32) into a new token 280
merging (261, 262)

In [16]:
# Base vocabulary: each byte value 0-255 maps to its own one-byte string.
vocab = {}
for byte_value in range(256):
    vocab[byte_value] = bytes([byte_value])

# Extend the vocabulary with the learned merges. Iterating in insertion order
# means both halves of a pair are already in `vocab` when it is looked up.
for pair, idx in merges.items():
    vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
print("Vocab dictionary succesfully created.")
print(f"Length of vocab: {len(vocab)}")

Vocab dictionary succesfully created.
Length of vocab: 768


In [18]:
def encode(text):
    '''Encode a string into a list of token ids using the learned merges.

    The text is first turned into its UTF-8 byte values. Then, among the
    adjacent pairs currently present, the one with the lowest id in
    ``merges`` is merged, and this repeats until no present pair appears in
    ``merges`` or fewer than two ids remain.
    '''
    token_ids = list(text.encode('utf-8'))
    while True:
        if len(token_ids) < 2:
            break
        candidates = get_stats(token_ids)
        # Pairs missing from `merges` rank as infinity, so they are never
        # preferred over a pair that can actually be merged.
        best = min(candidates, key=lambda p: merges.get(p, float('inf')))
        if best not in merges:
            break
        token_ids = merge(token_ids, best, merges[best])
    return token_ids

def decode(ids):
    '''Decode a list of token ids back into a string.

    Each id is looked up in ``vocab`` and the resulting byte strings are
    concatenated. Decoding uses ``errors='replace'``, so byte sequences that
    are not valid UTF-8 do not raise.
    '''
    raw_bytes = b''.join(vocab[token_id] for token_id in ids)
    return raw_bytes.decode('utf-8', errors='replace')

In [21]:
# Round-trip sanity check: encoding then decoding should give back the input.
raw = 'Almond What'
encoded_ids = encode(raw)
raw2 = decode(encoded_ids)
print(raw2 == raw)

True
