<a href="https://colab.research.google.com/github/BragdonD/text-classification-NLP-tf/blob/main/Multiclass_text_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import the depencies

In [None]:
%pip install tensorflow_addons
%pip install tensorflow-model-analysis

In [51]:
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, LSTM, SpatialDropout1D, Dropout, Bidirectional, GRU, Conv1D, GlobalMaxPooling1D,BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.data import Dataset
from tensorflow.keras import regularizers
#import tensorflow_addons as tfa
#import tensorflow_model_analysis as tfma
import tensorflow as tf
from keras.utils.np_utils import to_categorical
from keras import backend as K
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, LabelBinarizer
from sklearn.utils import class_weight

import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import re
import nltk

nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('averaged_perceptron_tagger')
nltk.download('words')

from nltk.corpus import stopwords
from nltk import word_tokenize
from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet

STOPWORDS = set(stopwords.words('english'))

from bs4 import BeautifulSoup
from collections import Counter

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package words to /root/nltk_data...
[nltk_data]   Package words is already up-to-date!


Mouting the drive and defining the path toward the dataset file

In [5]:
from google.colab import drive
import json

drive.mount("/content/drive/")

JSON_FILE_PATH = "drive/MyDrive/NLP/"
JSON_FILE_NAME = "News_Category_Dataset_v3_balanced.json";

Mounted at /content/drive/


Extract the data from the Json file and turn it into a panda DataFrame

In [6]:
def extractJsonData(jsonData):
  return pd.read_json(jsonData, lines=True);

In [7]:
jsonFile = open(JSON_FILE_PATH + JSON_FILE_NAME);
df = extractJsonData(jsonFile);

In [8]:
df.head()

Unnamed: 0,link,headline,category,short_description,authors,date
0,https://www.huffpost.com/entry/covid-boosters-...,Over 4 Million Americans Roll Up Sleeves For O...,U.S. NEWS,Health experts said it is too early to predict...,"Carla K. Johnson, AP",2022-09-23
1,https://www.huffpost.com/entry/american-airlin...,"American Airlines Flyer Charged, Banned For Li...",U.S. NEWS,He was subdued by passengers and crew when he ...,Mary Papenfuss,2022-09-23
2,https://www.huffpost.com/entry/funniest-tweets...,23 Of The Funniest Tweets About Cats And Dogs ...,COMEDY,"""Until you have a dog you don't understand wha...",Elyse Wanshel,2022-09-23
3,https://www.huffpost.com/entry/funniest-parent...,The Funniest Tweets From Parents This Week (Se...,PARENTING,"""Accidentally put grown-up toothpaste on my to...",Caroline Bologna,2022-09-23
4,https://www.huffpost.com/entry/amy-cooper-lose...,Woman Who Called Cops On Black Bird-Watcher Lo...,U.S. NEWS,Amy Cooper accused investment firm Franklin Te...,Nina Golgowski,2022-09-22


In [9]:
df = df.drop('link', axis=1)
df = df.drop('authors', axis=1)
df = df.drop('date', axis=1)
df = df.drop('headline', axis=1)

Text Preprocessing

In [10]:
def lower_text(dataframe):
  return dataframe["short_description"].str.lower()

In [11]:
import string
punctuation = string.punctuation
def remove_punctuation(text):
  return text.translate(str.maketrans('', '', punctuation))

def remove_punctuation_text(dataframe):
  return dataframe["short_description_pre_process"].apply(lambda text: remove_punctuation(text));

In [12]:
def remove_stopwords(text):
  return " ".join([word for word in str(text).split() if word not in STOPWORDS])

def remove_stopwords_text(dataframe):
  return dataframe["short_description_pre_process"].apply(lambda text: remove_stopwords(text));

In [13]:
cnt = Counter()
def count_word_occurence_df(dataframe):
  for text in dataframe["short_description_pre_process"].values:
    for word in text.split():
      cnt[word] += 1;
  return cnt;

def remove_frequentWord(text, freqwords):
  return " ".join([word for word in str(text).split() if word not in freqwords]);

def remove_frequent_word_text(dataframe):
  freqwords = set([w for (w, wc) in cnt.most_common(20)])
  return dataframe["short_description_pre_process"].apply(lambda text: remove_frequentWord(text,freqwords))



In [14]:
n_rare_words = 0
def count_rare_word_df(dataframe):
  rare_words = 0
  for (w, wc) in cnt.most_common():
    if wc == 1:
      rare_words += 1;
  return rare_words;

def remove_rareWord(text, rare_words):
  return " ".join([word for word in str(text).split() if word not in rare_words]);

def remove_rare_word_text(dataframe):
  rare_words = set([w for (w, wc) in cnt.most_common()[:-n_rare_words-1:-1]])
  return dataframe["short_description_pre_process"].apply(lambda text: remove_rareWord(text, rare_words))

In [15]:
def lemmatize(dataframe):
  lemmatizer = WordNetLemmatizer()
  wordnet_map = {"N":wordnet.NOUN, "V":wordnet.VERB, "J":wordnet.ADJ, "R":wordnet.ADV}

  def lemmatize_words(text):
    pos_tagged_text = nltk.pos_tag(text.split())
    return " ".join([lemmatizer.lemmatize(word, wordnet_map.get(pos[0], wordnet.NOUN)) for word, pos in pos_tagged_text])
  
  return dataframe["short_description_pre_process"].apply(lambda text: lemmatize_words(text))

In [16]:
def remove_emoji(string):
  emoji_pattern = re.compile("["
    u"\U0001F600-\U0001F64F"  # emoticons
    u"\U0001F300-\U0001F5FF"  # symbols & pictographs
    u"\U0001F680-\U0001F6FF"  # transport & map symbols
    u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
    u"\U00002702-\U000027B0"
    u"\U000024C2-\U0001F251"
    "]+", flags=re.UNICODE)
  return emoji_pattern.sub(r'', string)

def remove_emoji_text(dataframe):
  return dataframe["short_description_pre_process"].apply(lambda text: remove_emoji(text))

In [17]:
REPLACE_BY_SPACE_RE = re.compile('[/(){}\[\]\|@,;]')
BAD_SYMBOLS_RE = re.compile('[^0-9a-z #+_]')
def symbols_cleaning_text(dataframe):
  return dataframe["short_description_pre_process"].apply(lambda text: REPLACE_BY_SPACE_RE.sub(' ', text))

def numbers_cleaning_text(dataframe):
  return dataframe["short_description_pre_process"].apply(lambda text: BAD_SYMBOLS_RE.sub('', text))

In [18]:
def preprocess_text(dataframe):
  dataframe["short_description_pre_process"] = lower_text(dataframe)
  dataframe["short_description_pre_process"] = remove_punctuation_text(dataframe)
  dataframe["short_description_pre_process"] = remove_stopwords_text(dataframe)
  dataframe["short_description_pre_process"] = remove_frequent_word_text(dataframe)
  dataframe["short_description_pre_process"] = remove_rare_word_text(dataframe)
  dataframe["short_description_pre_process"] = remove_emoji_text(dataframe)
  dataframe["short_description_pre_process"] = symbols_cleaning_text(dataframe)
  dataframe["short_description_pre_process"] = numbers_cleaning_text(dataframe)
  #dataframe["short_description_pre_process"] = lemmatize(dataframe)
  return dataframe["short_description_pre_process"]

df["short_description_pre_process"] = preprocess_text(df)

In [19]:
df = df.drop(df[df.short_description_pre_process == ""].index)

In [20]:
df["short_description_pre_process"].apply(lambda x: len(x.split(" "))).describe()


count    189492.000000
mean         12.168894
std           6.955057
min           1.000000
25%           7.000000
50%          11.000000
75%          15.000000
max         140.000000
Name: short_description_pre_process, dtype: float64

In [21]:
df = df.drop('short_description', axis=1)

In [22]:
class_weights = class_weight.compute_class_weight(class_weight = 'balanced',
                                                  classes = np.unique(df.category.values),
                                                  y = df.category.values)

In [23]:
class_weights = dict(zip(np.unique([0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26]), class_weights))
class_weights

{0: 2.1508495930806686,
 1: 1.0196458262708443,
 2: 1.5177816224529028,
 3: 2.479061187644727,
 4: 2.0485178698838946,
 5: 3.854048447129172,
 6: 0.4771055215650729,
 7: 2.01210499490316,
 8: 0.8495608548870865,
 9: 0.65116183171481,
 10: 1.6264709669112913,
 11: 2.274213293007849,
 12: 2.9266981744045966,
 13: 3.3839065680917177,
 14: 0.5689220348753423,
 15: 0.21683934444238467,
 16: 3.7410566216536365,
 17: 1.7995441595441595,
 18: 1.5946880759423363,
 19: 0.6192184773444699,
 20: 0.7452715538093047,
 21: 5.096748164286291,
 22: 1.9212215226450102,
 23: 3.042142272311323,
 24: 0.30253565920433756,
 25: 2.2139502278303542,
 26: 0.8559851472401784}

In [24]:
encoder = LabelBinarizer()
df_cat_enc = encoder.fit_transform(df.category.values);
df["category"] = df_cat_enc.tolist()

In [25]:
# df = df.drop('new_category', axis=1)

In [26]:
df.head()

Unnamed: 0,category,short_description_pre_process
0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",health experts said early predict whether dema...
1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",subdued passengers crew fled back aircraft con...
2,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",dog dont understand could eaten
3,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",accidentally put grownup toothpaste toddlers t...
4,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",amy cooper accused investment firm franklin te...


In [27]:
train, test = train_test_split(df, test_size=0.2, shuffle=True, stratify=df.category)
train, validation = train_test_split(train, test_size=0.4, shuffle=True, stratify=train.category)

In [28]:
print(train["category"].value_counts(normalize=True))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.170810
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]    0.122423
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.077632
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.065098
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]    0.059810
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.056874
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]    0.049695
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.043593
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]    0.043263
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.036326
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [29]:
print(test["category"].value_counts(normalize=True))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.170796
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]    0.122431
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.077627
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.065094
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]    0.059817
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.056888
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]    0.049685
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.043590
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]    0.043273
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.036333
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [30]:
print(validation["category"].value_counts(normalize=True))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.170800
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]    0.122415
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.077625
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.065108
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0]    0.059814
[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.056879
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]    0.049705
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.043603
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]    0.043273
[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]    0.036314
[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [31]:
train_ds = Dataset.from_tensor_slices((train.short_description_pre_process, train.category.tolist()));
validation_ds = Dataset.from_tensor_slices((validation.short_description_pre_process, validation.category.tolist()));
test_ds = Dataset.from_tensor_slices((test.short_description_pre_process, test.category.tolist()));

In [32]:
BATCH_SIZE = 32

train_ds = train_ds.shuffle(100000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
validation_ds = validation_ds.shuffle(70000).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

In [33]:
encoder.classes_

array(['ARTS & CULTURE', 'BUSINESS & FINANCES', 'COMEDY', 'CRIME',
       'DIVORCE', 'EDUCATION', 'ENTERTAINMENT', 'ENVIRONMENT',
       'FOOD & DRINK', 'GROUPS VOICES', 'HOME & LIVING', 'IMPACT',
       'MEDIA', 'MISCELLANEOUS', 'PARENTING', 'POLITICS', 'RELIGION',
       'SCIENCE & TECH', 'SPORTS', 'STYLE & BEAUTY', 'TRAVEL',
       'U.S. NEWS', 'WEDDINGS', 'WEIRD NEWS', 'WELLNESS', 'WOMEN',
       'WORLD NEWS'], dtype='<U19')

In [34]:
def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = [i for i, elem in enumerate(encoded_labels) if elem == 1]
    return np.take(encoder.classes_, hot_indices)[0]

In [35]:
for example, label in train_ds.take(1):
  print('text: ', example.numpy()[0])
  print('label: ', invert_multi_hot(label.numpy()[0]))


text:  b'maybe shouldnt blame woman baby broken hazed rebuilt form mother thinskinned sometimes sanctimonious desperately insecure'
label:  PARENTING


In [36]:
for example, label in train_ds.take(1):
  print('text: ', example.numpy())
  print('label: ', label.numpy())

text:  [b'needs stop ben jacobs said sentencing greg gianforte'
 b'modern life utterly dependent electricity supply becomes erratic read mother nature'
 b'one fundamental arguments knocked'
 b'this year least 46 transgender individuals country  hundreds around world  killed horrifying acts violence biden said'
 b'    '
 b'since birds grow slowly move around meat isnt soft finegrained one reasons dry white'
 b'infectious beats rhythms key good exercise music help forget exercising get caught vibe often lyrics excellent distraction'
 b'cliven bundy controversial nevada rancher center armed standoff federal officials 2014 arrested'
 b'year dont beat dont get around sending holiday card instead hug children tight focus less unearned trophies missed soccer goals family squabbles heres boastfulfree holiday'
 b'like read sign huffpost hill get cheeky dose political news every evening trump transition'
 b'triggers incredible ally triggers lead needs healing instead ashamed triggered get excite

In [37]:
MAX_WORDS = 20000
MAX_LEN = 11 # 50% of the dataset have an average length of 11 words
vector = tf.keras.layers.TextVectorization(
    max_tokens=MAX_WORDS,
    standardize=None,
    split="whitespace",
    ngrams=2
  )
vector.adapt(train_ds.map(lambda text, label: text))

In [41]:
embeddings_index = dict()
path_to_glove_embed = "drive/MyDrive/NLP/glove.6B.100d.txt"
def get_embeddings_from_glove(glove_path):
  embeddings_index = dict();
  for line in open(glove_path, encoding="utf8"):
    values = line.split()
    word = values[0]
    coefs = np.asarray(values[1:], dtype='float32')
    embeddings_index[word] = coefs
  return embeddings_index;

embeddings_index = get_embeddings_from_glove(path_to_glove_embed);

In [48]:
def get_embedding_matrix():
  embedding_matrix = np.zeros((len(vector.get_vocabulary()), 100))
  for word, i in zip(vector.get_vocabulary(),range(len(vector.get_vocabulary()))):
      embedding_vector = embeddings_index.get(word)
      if embedding_vector is not None:
          embedding_matrix[i] = embedding_vector
  return embedding_matrix

embedding_matrix = get_embedding_matrix()

In [49]:
vocab = np.array(vector.get_vocabulary())
print(len(vocab))
vocab[:20]

20000


array(['', '[UNK]', 'one', 'new', 'us', 'people', 'time', 'like', 'said',
       'get', 'day', 'life', 'would', 'many', 'make', 'dont', 'years',
       'first', 'want', 'know'], dtype='<U35')

In [50]:
# sample_weights = class_weight.compute_sample_weight( 'balanced', Y_train )

Splitting the dataset between train, validation and test

In [162]:
# model = Sequential([
#     Dense(524, activation='relu'),
#     Dense(262, activation='relu'),
#     Dense(Y.shape[1], activation='sigmoid')
# ])

In [163]:
# model = Sequential([
#     vector,
#     Embedding(len(vocab), EMBEDDING_DIM, mask_zero=True),
#     GRU(EMBEDDING_DIM, return_sequences=True),
#     GRU(int(EMBEDDING_DIM/2)),
#     Dense(int(EMBEDDING_DIM/2), activation='relu'),
#     Dropout(0.3),
#     Dense(len(encoder.classes_), activation='sigmoid')
# ])


In [164]:
# model = Sequential([
#     vector,
#     Embedding(MAX_NB_WORDS, EMBEDDING_DIM, mask_zero=True),
#     Conv1D(128, 5, activation='relu'),
#     GlobalMaxPooling1D(),
#     Dense(28, activation='relu'),
#     Dropout(0.2),
#     Dense(len(encoder.classes_))
# ])


In [57]:
model = Sequential([
    vector,
    Embedding( 
        input_dim=len(vocab), 
        output_dim=100,
        weights=[embedding_matrix],
        input_length=MAX_LEN,
        trainable=False
    ),  
    SpatialDropout1D(0.2),
    GRU(64, dropout=0.2, return_sequences=True),  
    BatchNormalization(),
    GRU(64, dropout=0.2),
    BatchNormalization(),
    Dense(len(encoder.classes_), activation="sigmoid")
])


In [58]:
print([layer.supports_masking for layer in model.layers])


[False, False, True, True, True, True, True, True]


In [59]:
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions[0])

[0.5000374  0.499967   0.49995926 0.49996588 0.50003153 0.499988
 0.50004715 0.5000145  0.49997455 0.5000109  0.4999949  0.5000078
 0.50006384 0.5000598  0.50004166 0.5000127  0.49991375 0.50006753
 0.50008434 0.499963   0.49997386 0.5000147  0.4999967  0.50000495
 0.49997422 0.5000115  0.500022  ]


In [60]:
padding = "the " * 2000
predictions = model.predict(np.array([sample_text, padding]))
print(predictions[0])

[0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5
 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5 0.5]


In [64]:
model.compile(  loss='binary_crossentropy', 
                optimizer='adam', 
                weighted_metrics=['accuracy'])

In [65]:
model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 text_vectorization (TextVec  (None, None)             0         
 torization)                                                     
                                                                 
 embedding_3 (Embedding)     (None, None, 100)         2000000   
                                                                 
 spatial_dropout1d_3 (Spatia  (None, None, 100)        0         
 lDropout1D)                                                     
                                                                 
 gru_6 (GRU)                 (None, None, 64)          31872     
                                                                 
 batch_normalization_6 (Batc  (None, None, 64)         256       
 hNormalization)                                                 
                                                      

In [None]:
epochs = 20

history = model.fit(
    train_ds, 
    epochs=epochs,
    validation_data=validation_ds,
    class_weight=class_weights,
    callbacks=[
        EarlyStopping(monitor='val_loss', patience=3, min_delta=0.0001)
    ]
)

In [89]:
accr = model.evaluate(test_ds)
print('Test set\n  Loss: {:0.3f}\n  Accuracy: {:0.3f}'.format(accr[0],accr[1]))

Test set
  Loss: 0.109
  Accuracy: 0.448


In [None]:
plt.title('Loss')
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.legend()
plt.show();

In [None]:
plt.title('Accuracy')
plt.plot(history.history['categorical_accuracy'], label='train')
plt.plot(history.history['val_categorical_accuracy'], label='test')
plt.legend()
plt.show();

In [None]:
checkpoint_path = "drive/MyDrive/NLP/training_2/cp-{epoch:04d}.ckpt"
model.save_weights(checkpoint_path.format(epoch=0))