In [1]:
from ipywidgets import widgets
import tensorflow.data as tf_data
import keras
from keras import layers
import os
import numpy as np
import pathlib

In [2]:
# Fetch the 20 Newsgroups archive through the Keras download helper and have
# it unpack the tarball (untar=True). The returned path is used by the next
# cell to locate the extracted "20_newsgroup" folder.
data_path = keras.utils.get_file(
    fname="news20.tar.gz",
    origin="http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/news20.tar.gz",
    untar=True,
)

In [3]:
# The dataset lives in a "20_newsgroup" folder that sits alongside the path
# returned by get_file().
data_dir = pathlib.Path(data_path).parent.joinpath("20_newsgroup")

# One sub-directory per newsgroup.
dirnames = os.listdir(data_dir)
# print("Number of directories:", len(dirnames))
# print("Directory names:", dirnames)

# Peek at a single newsgroup to see what its files look like.
fnames = os.listdir(data_dir.joinpath("comp.graphics"))
# print("Number of files in comp.graphics:", len(fnames))
# print("Some example filenames:", fnames[:5])

In [4]:
# Read every post into memory.
#   samples     -> post bodies (str), one entry per file
#   labels      -> integer class index for each entry in `samples`
#   class_names -> newsgroup name for each class index
samples = []
labels = []
class_names = []
class_index = 0
# Directories are visited in sorted order so that class indices are assigned
# deterministically.
for dirname in sorted(os.listdir(data_dir)):
    class_names.append(dirname)
    dirpath = data_dir / dirname
    # NOTE(review): files are visited in os.listdir() order, which is
    # filesystem-dependent; the seeded shuffle below therefore yields a
    # machine-dependent split. Left as-is because changing the order would
    # change which samples the vectorizer is adapted on.
    fnames = os.listdir(dirpath)
    print("Processing %s, %d files found" % (dirname, len(fnames)))
    for fname in fnames:
        fpath = dirpath / fname
        # Use a context manager so each file handle is closed right away
        # instead of leaking one descriptor per post until garbage collection.
        with open(fpath, encoding="latin-1") as f:
            content = f.read()
        lines = content.split("\n")
        # Drop the first 10 lines of every post (presumably the message
        # header block — verify against a sample file).
        lines = lines[10:]
        content = "\n".join(lines)
        samples.append(content)
        labels.append(class_index)
    class_index += 1

Processing alt.atheism, 1000 files found
Processing comp.graphics, 1000 files found
Processing comp.os.ms-windows.misc, 1000 files found
Processing comp.sys.ibm.pc.hardware, 1000 files found
Processing comp.sys.mac.hardware, 1000 files found
Processing comp.windows.x, 1000 files found
Processing misc.forsale, 1000 files found
Processing rec.autos, 1000 files found
Processing rec.motorcycles, 1000 files found
Processing rec.sport.baseball, 1000 files found
Processing rec.sport.hockey, 1000 files found
Processing sci.crypt, 1000 files found
Processing sci.electronics, 1000 files found
Processing sci.med, 1000 files found
Processing sci.space, 1000 files found
Processing soc.religion.christian, 997 files found
Processing talk.politics.guns, 1000 files found
Processing talk.politics.mideast, 1000 files found
Processing talk.politics.misc, 1000 files found
Processing talk.religion.misc, 1000 files found


In [5]:
# Shuffle the data.
# `samples` and `labels` are shuffled in place by two generators created from
# the same seed, so both lists receive the same permutation and stay aligned.
seed = 1337
rng = np.random.RandomState(seed)
rng.shuffle(samples)
rng = np.random.RandomState(seed)
rng.shuffle(labels)

# Extract a training & validation split: the last `validation_split` fraction
# of the shuffled samples is held out; everything before it is training data.
validation_split = 0.2
num_validation_samples = int(validation_split * len(samples))
# Slice with an explicit end index: `samples[:-n]` would return an empty list
# when n == 0 (because `[:-0]` is `[:0]`), silently discarding all the data.
train_samples = samples[: len(samples) - num_validation_samples]

In [6]:
# Text -> integer-sequence layer: vocabulary capped at 20000 tokens, every
# output sequence fixed to 200 positions.
vectorizer = layers.TextVectorization(
    max_tokens=20000,
    output_sequence_length=200,
)

# Learn the vocabulary from the training texts, streamed in batches of 128.
text_ds = tf_data.Dataset.from_tensor_slices(train_samples)
text_ds = text_ds.batch(128)
vectorizer.adapt(text_ds)

In [7]:
# Load the previously saved classifier from "glove-newsgroups", a path relative
# to the notebook's working directory.
# NOTE(review): presumably produced by a companion training notebook — the
# `vectorizer` adapted above must yield the same vocabulary the model was
# trained with, otherwise token indices will not line up; confirm.
# NOTE(review): Keras 3 `load_model` appears to accept only `.keras` / `.h5`
# files; confirm the installed Keras version can load this extension-less path.
model = keras.models.load_model("glove-newsgroups")

In [8]:
# Chain the vectorizer and the loaded classifier into a single model that
# accepts raw strings (one string per example) and returns the classifier's
# output directly.
string_input = keras.Input(shape=(1,), dtype="string")
x = vectorizer(string_input)
preds = model(x)
end_to_end_model = keras.Model(inputs=string_input, outputs=preds)

# Example usage (kept disabled):
# probabilities = end_to_end_model.predict(
#     [["this message is about computer graphics and 3D modeling"]]
# )

In [9]:
# Output area that receives everything printed by the click handler below.
output = widgets.Output()


# clear_output=True wipes the previous result on every click; without it each
# click appends to the widget and stale predictions pile up.
@output.capture(clear_output=True)
def predict_article(button):
    """Classify the text currently in `article_text` and print its newsgroup.

    Intended as an ipywidgets `Button.on_click` callback. Reads the globals
    `article_text` (the text box created in the UI cell), `end_to_end_model`
    and `class_names`.

    Args:
        button: The widget that fired the click event (unused).

    Returns:
        None. The predicted class name is printed into `output`.
    """
    text = article_text.value
    if not text.strip():
        # Nothing to classify; avoid running the model on an empty string.
        print("Please enter some article text first.")
        return
    # verbose=0 keeps Keras' progress bar out of the captured output.
    probabilities = end_to_end_model.predict([[text]], verbose=0)
    # Report the class with the highest score for the single input row.
    print(class_names[np.argmax(probabilities[0])])

In [10]:
# Build the input form: a caption, a single-line text box for the article,
# and the button that triggers classification.
lbl = widgets.Label("Enter an article")
article_text = widgets.Text()
btn = widgets.Button(description="Categorize")

# Run predict_article whenever the button is clicked.
btn.on_click(predict_article)

# Render the controls followed by the area that shows the prediction.
display(lbl, article_text, btn, output)

Label(value='Enter an article')

Text(value='')

Button(description='Categorize', style=ButtonStyle())

Output()