In [298]:
from keras.preprocessing.sequence import pad_sequences
from keras.models import load_model
from keras.datasets import imdb
import gradio as gr
import numpy as np

In [299]:
# Trained sentiment model, resolved relative to the notebook's working directory.
MODEL_PATH = "../lecture11/imdb_model.h5"

# Keras IMDB vocabulary: word -> raw integer index. These raw values are NOT
# the ids fed to the model; encode_text shifts them by start_index (default 3)
# because ids 0/1/2 are reserved for padding/start/out-of-vocabulary.
word_to_index = imdb.get_word_index()
# Inverse of word_to_index over the RAW indices. To decode a sequence produced
# by encode_text, subtract the same start_index offset before looking ids up.
index_to_word = {index: word for word, index in word_to_index.items()}
model = load_model(MODEL_PATH)

In [305]:
# Raw integer-encoded IMDB train/test splits.
# NOTE(review): none of the cells shown here use these arrays, and this cell
# appears to have been executed after the UI was launched; consider removing
# it or moving it above the code that depends on it.
_imdb_splits = imdb.load_data()
x_train, y_train = _imdb_splits[0]
x_test, y_test = _imdb_splits[1]
del _imdb_splits

In [300]:
def encode_text(
    text: str,
    start_index: int = 3,
    vocabulary: "dict | None" = None,
    num_words: "int | None" = None,
):
    """Encode free-form text into the integer-id format of ``keras.datasets.imdb``.

    The output starts with the start marker ``1``. Each word found in the
    vocabulary becomes ``rank + start_index``; every other word becomes the
    out-of-vocabulary id ``2``.

    Args:
        text: Raw review text; split on whitespace.
        start_index: Offset added to each vocabulary rank. It must match the
            ``index_from`` used when the model's training data was loaded.
        vocabulary: Mapping ``word -> rank``. Defaults to the module-level
            ``word_to_index`` (from ``imdb.get_word_index()``).
        num_words: Optional vocabulary cap. Encoded ids ``>= num_words`` are
            replaced by ``2``, mirroring ``imdb.load_data(num_words=...)``.
            Leave ``None`` when the model was trained on the full vocabulary.

    Returns:
        list[int]: the encoded sequence, always beginning with ``1``.
    """
    import string

    vocab = word_to_index if vocabulary is None else vocabulary
    encoded = [1]
    for raw_token in text.split():
        # Normalize ONCE so the membership test and the lookup use the same key.
        # NOTE(review): assumes the vocabulary is lowercase with surrounding
        # punctuation removed, so "Great!" should match "great". Inner
        # apostrophes ("don't") are kept because strip() only trims the ends.
        token = raw_token.lower().strip(string.punctuation)
        if not token:
            # Pure punctuation ("-", "!!") carries no word; skip it instead of
            # emitting a spurious OOV id.
            continue
        rank = vocab.get(token)
        if rank is None:
            encoded.append(2)
            continue
        word_id = rank + start_index
        if num_words is not None and word_id >= num_words:
            word_id = 2
        encoded.append(word_id)
    return encoded

def pad_text(encoded_text: list, pad_length: int = 256):
    """Wrap one encoded review in a batch and bring it to ``pad_length`` ids.

    Shorter sequences are filled with trailing zeros (``padding="post"``).
    Longer ones are cut by ``pad_sequences``' default ``truncating`` mode
    ("pre"), which drops leading ids, including the start marker ``1``.
    NOTE(review): padding/truncation should match what the model was trained
    with; confirm against the training notebook.

    Returns:
        A 2-D array of shape ``(1, pad_length)``.
    """
    batch = [encoded_text]
    return pad_sequences(batch, maxlen=pad_length, value=0, padding="post")

def preprocess_text(text: str):
    """Turn raw review text into a model-ready batch of shape ``(1, 256)``.

    Chains ``encode_text`` (words -> ids) and ``pad_text`` (fixed length),
    both with their default settings, and returns the result as a NumPy array.
    """
    return np.array(pad_text(encode_text(text)))

In [301]:

def predict(input_text: str):
    """Return the model's raw sentiment score for ``input_text``.

    Args:
        input_text: Free-form review text.

    Returns:
        float: the first output unit of ``model`` for the single preprocessed
        sequence. Presumably a sigmoid probability that the review is positive
        (``parse_prediction`` thresholds it at 0.5); confirm against the
        model's final layer.
    """
    batch = preprocess_text(input_text)
    # verbose=0 keeps every UI request from writing a Keras progress bar into
    # the notebook output.
    score = model.predict(batch, verbose=0)[0][0]
    # Plain Python float rather than a NumPy scalar, for easy downstream use.
    return float(score)

def parse_prediction(prediction: float, threshold: float = 0.5):
    """Map a sentiment score to a human-readable label.

    Scores strictly greater than ``threshold`` are "Positive". Everything else,
    including a score equal to the threshold (or NaN), is "Negative".
    """
    if prediction > threshold:
        return "Positive"
    return "Negative"

def classify(input_text: str):
    """Gradio callback: score ``input_text`` and return its sentiment label.

    Uses ``parse_prediction``'s default threshold.
    """
    return parse_prediction(predict(input_text))

In [302]:
def start_ui(share: bool = True):
    """Launch a text-in / text-out Gradio interface around ``classify``.

    Args:
        share: Forwarded to ``launch``. When ``True`` (the default), Gradio
            also creates a temporary public ``*.gradio.live`` URL, so anyone
            holding the link can query the model. Pass ``False`` to serve only
            on the local URL.
    """
    ui = gr.Interface(fn=classify, inputs="text", outputs="text")
    # show_api=False hides Gradio's API-usage link in the page footer.
    ui.launch(show_api=False, share=share)

In [304]:
# Launch the sentiment demo. start_ui() passes share=True to Gradio, so besides
# the local URL this also prints a temporary public URL (see output below).
start_ui()

Running on local URL:  http://127.0.0.1:7873
Running on public URL: https://92a7a345f2a34d0a96.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
