In [None]:
import time
from IPython.display import display, Latex, clear_output
import tensorflow as tf
from tensorflow.keras import Sequential, losses, optimizers, layers, Model, mixed_precision
from tensorflow.keras.layers import Layer
import numpy as np
import tiktoken

In [None]:
# Keras mixed precision: layers compute in float16 while their variables stay float32.
policy = mixed_precision.Policy(name='mixed_float16')
mixed_precision.set_global_policy(policy)
# NOTE(review): `Config` is not defined in any cell shown here; it must be
# defined/imported before this cell for Restart & Run All to work — confirm.
tokenizer = tiktoken.get_encoding(Config.tokenizer)

In [None]:
class Embedding(Layer):
    '''
    Token-id lookup layer: wraps keras `layers.Embedding(config.vocab_size, config.d_model)`.

    No positional information is added here; make_predict adds the output of
    Get_Position(...) to this layer's output.
    '''
    def __init__(self, config):
        super().__init__()
        # The attribute name is part of the tracked-weight structure; renaming it
        # may break model.load_weights on existing checkpoints.
        self.embed = layers.Embedding(config.vocab_size, config.d_model)
    
    def call(self, inputs):
        # assumes inputs are integer token ids shaped (batch, seq) — TODO confirm
        return self.embed(inputs)
    
def Get_Position(context_length, d_model, n=10000):
    '''
    Build the sinusoidal positional-encoding matrix.

    For position k and frequency index i (0 <= i < d_model // 2):
        P[k, 2*i]     = sin(k / n**(2*i/d_model))
        P[k, 2*i + 1] = cos(k / n**(2*i/d_model))
    When d_model is odd, the final column is left at 0.

    Parameters
    ----------
    context_length : int
        Number of positions (rows).
    d_model : int
        Encoding width (columns).
    n : float, default 10000
        Base of the geometric frequency progression.

    Returns
    -------
    np.ndarray
        Array of shape (context_length, d_model), dtype float64.
    '''
    P = np.zeros((context_length, d_model))
    half = d_model // 2

    # Vectorised over all positions and frequencies at once. This function is
    # called on every prediction, so per-element Python loops are costly.
    positions = np.arange(context_length)[:, np.newaxis]        # (context_length, 1)
    denominators = np.power(n, 2 * np.arange(half) / d_model)   # (half,)
    angles = positions / denominators                           # (context_length, half)

    P[:, 0:2 * half:2] = np.sin(angles)
    P[:, 1:2 * half:2] = np.cos(angles)

    return P

class Blocks(Layer):
    '''
    One post-norm transformer block: causal multi-head self-attention and a
    GELU feed-forward network, each followed by a residual Add and
    LayerNormalization.

    NOTE(review): the same `self.layernorm` instance is applied after both
    residual connections, so the two normalisations share one set of
    gamma/beta weights. Standard transformer blocks use two separate
    LayerNorms. Changing this would alter the weight structure and break
    loading of existing checkpoints — confirm against the training code.
    Layer creation order below also determines weight order, so keep it stable.
    '''
    def __init__(self, config):
        super().__init__()
        # In Keras MHA, key_dim is the per-head size; here every head uses the
        # full d_model (not d_model // num_heads). Presumably this matches
        # training — confirm.
        self.mha = layers.MultiHeadAttention(num_heads=config.num_heads,
                                             key_dim = config.d_model)
        # Position-wise feed-forward: expand to hidden_unit, project back to d_model.
        self.ffn = Sequential([
            layers.Dense(config.hidden_unit, activation = 'gelu'),
            layers.Dropout(0.1),
            layers.Dense(config.d_model),
            layers.Dropout(0.1)
        ])
        self.layernorm = layers.LayerNormalization()
        # Add has no weights, so reusing one instance for both residuals is harmless.
        self.add = layers.Add()


    def call(self, inputs):
        '''
        Apply attention + FFN with residual connections.

        The residual Adds require the output to have the same shape as
        `inputs` (assumes (batch, seq, d_model) — TODO confirm).
        '''

        ##Multi-head Attention
        # Self-attention with a causal mask: each position attends only to
        # itself and earlier positions.
        attention_output = self.mha(
            query = inputs,
            key = inputs,
            value = inputs,
            use_causal_mask = True,
        )
        x = self.add([inputs, attention_output])
        inputs2 = self.layernorm(x)

        ##Feed Forward
        x = self.ffn(inputs2)
        x = self.add([x, inputs2])
        # Same LayerNormalization instance (shared weights) as above.
        return self.layernorm(x)
    
class Linear(Layer):
    '''
    Output projection from d_model features to vocabulary logits
    (a Dense layer with config.vocab_size units).
    '''
    def __init__(self, config):
        super().__init__()
        # Keep the logits in float32. Under the global 'mixed_float16' policy this
        # Dense would otherwise emit float16 logits. TF's mixed-precision guide
        # recommends a float32 output layer for numeric stability, and float16
        # rounding can create spurious ties in the argmax over the vocabulary.
        # The variables are float32 either way, so saved weights still load.
        self.linear = layers.Dense(config.vocab_size, dtype='float32')

    def call(self, inputs):
        return self.linear(inputs)
    
def create_model(config=None):
    '''
    Build the decoder-only transformer as a Keras Sequential model.

    Layer stack: Embedding -> Sequential of `config.block_count` Blocks -> Linear.
    The positional encoding is NOT part of this model; make_predict adds
    Get_Position(...) to the embedding output itself.

    Parameters
    ----------
    config : Config instance, optional
        Model hyper-parameters. Defaults to a fresh `Config()` created at call
        time. A `Config()` default in the signature would be evaluated only
        once, when the function is defined, and then shared by every call.

    Returns
    -------
    tf.keras.Sequential
    '''
    if config is None:
        config = Config()

    return Sequential([
        tf.keras.Input(shape=(None,)),
        Embedding(config),
        Sequential([Blocks(config) for _ in range(config.block_count)]),
        Linear(config),
    ])

In [None]:
# Rebuild the architecture, then restore the trained weights from the configured path.
model = create_model()
model_path = Config.model_path
model.load_weights(filepath=model_path)
model.summary()

In [None]:
def fix_prompt(prompt):
    '''
    Normalise a user prompt before it is sent to the model.

    - Surrounding whitespace is stripped.
    - The first character is upper-cased; the rest of the text is kept as typed
      (str.capitalize() would lower-case acronyms and proper nouns).
    - If the prompt does not already end with '.', '?' or '!', a '.' is appended.

    An empty or whitespace-only prompt is returned as an empty string instead
    of raising IndexError.
    '''
    prompt = prompt.strip()
    if not prompt:
        return prompt

    prompt = prompt[0].upper() + prompt[1:]

    if not prompt.endswith(('.', '?', '!')):
        prompt += '.'

    return prompt

def init_inputs(human):
    '''
    Tokenise the conversation text and check that it fits the context window.

    Parameters
    ----------
    human : str
        Text to encode with the module-level `tokenizer`.

    Returns
    -------
    (list[int], int)
        Token ids and their count, when the count is strictly below
        Config.context_length (leaving room for at least one generated token).
    (None, None)
        Otherwise, after printing the length and a message.
    '''
    encoding = tokenizer.encode(human)
    current_length = len(encoding)

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently disable the length limit.
    if current_length >= Config.context_length:
        print(f"({current_length})")
        print(" >> Exceeds limit, chat ended.")
        return None, None

    # Fresh list copy, so the caller can append generated tokens to it.
    return list(encoding), current_length


def make_predict(x):
    '''
    Greedily predict the next token id for a sequence of token ids.

    The forward pass is run layer by layer so the sinusoidal position encoding
    can be added to the embedding output:
        model.layers[0] (Embedding) + Get_Position(...)
        -> model.layers[1] (stack of Blocks)
        -> model.layers[-1] (Linear) on the last position only.

    Decoding is greedy (argmax), so the same input always yields the same token.

    Parameters
    ----------
    x : array-like of token ids, 2-D (batch, seq_len)
        Only batch row 0 is used for the prediction.

    Returns
    -------
    NumPy integer
        The argmax token id at the last position of row 0.
    '''
    _, seq_len = x.shape

    embedded = model.layers[0](x) + Get_Position(context_length=seq_len,
                                                 d_model=Config.d_model)

    hidden = model.layers[1](embedded)

    # Only the final position's logits describe the next token.
    logits = model.layers[-1](hidden[:, -1, :])[0]

    return np.argmax(logits)


def check_prediction(gen_array, prediction):
    '''
    Tentatively accept `prediction`, then look one token ahead.

    `prediction` is appended to `gen_array` (mutated in place) and the model is
    asked for the following token. If that look-ahead token is id 25, the
    append is undone and False is returned. Otherwise the token stays appended
    and True is returned.
    '''
    gen_array.append(prediction)
    lookahead = make_predict(np.array(gen_array).reshape(1, len(gen_array)))

    if lookahead != 25:
        return True

    # NOTE(review): 25 is presumably ':' in a GPT-2-style vocabulary, i.e. the
    # model is opening a new "Speaker:" turn — confirm against Config.tokenizer.
    gen_array.pop()
    return False


def clear_and_reload(prompt, content):
    '''
    Clear the cell's output, then re-render the question followed by `content`.

    Both parts are displayed through IPython's Latex display object.
    NOTE(review): model text containing '$' or '\\' may be interpreted as
    LaTeX by the renderer — confirm this is acceptable.
    '''
    question_view = Latex(f"\n*** Your Question:\n\n{prompt}\n")
    answer_view = Latex(content)
    clear_output()
    display(question_view)
    display(answer_view)
    

########################################################################

def generate(fix=False, reload_scaler=3):
    '''
    Read a question from the user, then greedily generate and display an answer.

    Generation stops when:
      - the model predicts token 50256 (end-of-text),
      - a 20490/48902 token is rejected by check_prediction, or
      - the sequence reaches Config.context_length.

    Parameters
    ----------
    fix : bool, default False
        If True, normalise the question with fix_prompt first.
    reload_scaler : int, default 3
        Re-render the answer whenever the sequence length is a multiple of this
        value; it is always rendered once more at the end.
    '''
    prompt = str(input("\n*** Your Question:\n\n"))

    if fix:
        prompt = fix_prompt(prompt)

    # This is the actual text given to the model.
    model_text = f"\n\nHuman: {prompt}\n\nAssistant: "

    gen_array, current_length = init_inputs(model_text)

    if gen_array is None:
        return

    # Generated answer tokens start after the encoded prompt.
    answer_start = len(gen_array)

    ##Pre-process done, now generating ...
    print('\n ***Generated Answer:\n')

    # Rendered with Latex by clear_and_reload.
    header = "*** Generated Answer:\n\n"
    reload = header

    while current_length < Config.context_length:

        fed_in_array = np.array(gen_array).reshape(1, current_length)
        prediction = make_predict(fed_in_array)

        if prediction == 50256:
            # endoftext token
            break

        # NOTE(review): presumably the ids of the speaker-tag tokens; confirm
        # against the tokenizer.
        if prediction in (20490, 48902):
            # check_prediction appends the token itself when it is accepted.
            # An explicit `if` replaces the try/assert, which `python -O` would
            # strip and so desynchronise current_length from gen_array.
            if not check_prediction(gen_array, prediction):
                break
        else:
            gen_array.append(prediction)

        current_length += 1

        # Decode the whole answer each step. A token can carry only part of a
        # multi-byte UTF-8 character, so decoding one token at a time produces
        # replacement characters.
        reload = header + tokenizer.decode(gen_array[answer_start:])

        if current_length % reload_scaler == 0:
            clear_and_reload(prompt, reload)

    clear_and_reload(prompt, reload)

In [None]:
# Run one interactive question/answer round, refreshing the display every 5 tokens.
generate(fix=False, reload_scaler=5)