In [None]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Dropout
from sklearn.model_selection import train_test_split

In [3]:
import torch
from transformers import DistilBertTokenizer, DistilBertModel
import re
from bs4 import BeautifulSoup

In [5]:
# Load the tokenizer and the pretrained encoder from the same checkpoint.
# The checkpoint name is kept in a single constant so the two cannot drift apart.
MODEL_NAME = "distilbert-base-uncased"
tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
bert_model = DistilBertModel.from_pretrained(MODEL_NAME)

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

In [18]:
# Load the IMDB review dataset from a path relative to the working directory
movie_reviews = pd.read_csv('Data/IMDB Dataset.csv')
# Preview the first rows to check the two columns: review text and sentiment label
movie_reviews.head()

Unnamed: 0,review,sentiment
0,One of the other reviewers has mentioned that ...,positive
1,A wonderful little production. <br /><br />The...,positive
2,I thought this was a wonderful way to spend ti...,positive
3,Basically there's a family where a little boy ...,negative
4,"Petter Mattei's ""Love in the Time of Money"" is...",positive


In [19]:
# Per-column summary: row count, number of unique values, most frequent value and its frequency
movie_reviews.describe()

Unnamed: 0,review,sentiment
count,50000,50000
unique,49582,2
top,Loved today's show!!! It was a variety and not...,positive
freq,5,25000


In [20]:
def text_noise_remover(text):
    """
    Remove markup noise from a text.

    Two passes are applied in order: the text is parsed as HTML with
    BeautifulSoup and reduced to its plain text (so tags such as ``<br />``
    disappear), then every square-bracketed span (``[...]``) is deleted.

    Args:
        text (str): raw text, possibly containing HTML tags

    Returns:
        str: text without HTML tags and without square-bracketed spans
    """
    soup = BeautifulSoup(text, "html.parser")
    text = soup.get_text()
    # Raw string: "\[" inside a normal string literal is an invalid escape
    # sequence, which newer Python versions report as a warning.
    text = re.sub(r"\[[^]]*\]", "", text)
    return text

# Clean every review element-wise with text_noise_remover and store the result
# back in the same column
movie_reviews["review"] = movie_reviews["review"].map(text_noise_remover)

  soup = BeautifulSoup(text, "html.parser")


In [23]:
# Remove long reviews
def number_text_token(text):
  """
  Count the whitespace-separated tokens of a text.

  Args:
      text (str): input text

  Returns:
      int: number of tokens produced by splitting the text on whitespace
  """
  tokens = text.split()
  return len(tokens)

# Keep only reviews with fewer than MAX_REVIEW_TOKENS whitespace-separated tokens.
# .copy() makes the result an independent frame, so the column assignments in
# later cells do not write into a slice of the unfiltered frame
# (which pandas reports as SettingWithCopyWarning).
MAX_REVIEW_TOKENS = 100
movie_reviews = movie_reviews[
    movie_reviews['review'].apply(number_text_token) < MAX_REVIEW_TOKENS
].copy()

In [25]:
# encode the target column into 0 or 1 ('positive' -> 1, anything else -> 0)
# The dtype guard makes the cell idempotent: once the column is numeric, a second
# run would compare integers with 'positive' and silently turn every label into 0.
if not pd.api.types.is_numeric_dtype(movie_reviews['sentiment']):
    movie_reviews['sentiment'] = (movie_reviews['sentiment'] == 'positive').astype('int64')

In [26]:
def text_embedding(text):
  """
  Create a sentence embedding for a text.

  The text is tokenized with the module-level ``tokenizer``, passed through the
  module-level ``bert_model``, and the last hidden states of all tokens are
  averaged into a single vector.

  Args:
      text (str): input text

  Returns:
      numpy array: 1-D array with one value per hidden unit of the model
          (768 dimensions in this notebook)
  """
  # Tokenize the text. Truncation caps the sequence at the tokenizer's maximum
  # length; without it an over-long text fails inside the model.
  inputs = tokenizer(text, return_tensors="pt", truncation=True)
  # Encode the tokens; gradients are not needed for feature extraction.
  with torch.no_grad():
      encoded_layers = bert_model(**inputs)
  # Drop dimension 0, the batch axis (one text is encoded per call).
  token_embeddings = torch.squeeze(encoded_layers.last_hidden_state, dim=0)
  # Average over the token axis to get one vector for the whole text.
  sentence_embedding = torch.mean(token_embeddings, dim=0)
  return sentence_embedding.numpy()


# Embed every review element-wise and keep the vectors in a new 'embeddings' column
movie_reviews['embeddings'] = movie_reviews['review'].map(text_embedding)

In [30]:
# network hyper-parameters
input_size = 768
hidden_units = 256
dropout = 0.45
batch_size = 128
# number of distinct values in the target column
num_labels = np.unique(movie_reviews['sentiment']).size

In [47]:
# split the data: 80% train / 20% test, fixed seed for a repeatable split
X_train, X_test, y_train, y_test = train_test_split(
    movie_reviews['embeddings'],
    movie_reviews['sentiment'],
    test_size=0.2,
    random_state=42,
)
# convert each pandas Series into a numpy array (features first, then labels)
X_train, X_test = (np.array([np.array(vec) for vec in split]) for split in (X_train, X_test))
y_train, y_test = (np.array([np.array(label) for label in split]) for split in (y_train, y_test))

In [48]:
# MLP classifier: two hidden Dense layers, each followed by ReLU and dropout,
# then a single sigmoid output unit
model = Sequential([
    Dense(hidden_units, input_dim=input_size),
    Activation('relu'),
    Dropout(dropout),
    Dense(hidden_units),
    Activation('relu'),
    Dropout(dropout),
    Dense(1, activation='sigmoid'),
])

In [49]:
# configure training: Adam optimizer, binary cross-entropy loss, accuracy as the metric
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# train the network for 10 epochs on the training arrays
model.fit(x=X_train, y=y_train, batch_size=batch_size, epochs=10)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7b12bd9001c0>

In [50]:
# measure generalization on the held-out test arrays;
# the first returned value (the loss) is discarded, the second is the accuracy
_, acc = model.evaluate(x=X_test, y=y_test, batch_size=batch_size, verbose=0)
print("\nTest accuracy: %.1f%%" % (acc * 100.0))



Test accuracy: 88.4%


In [51]:
# save the model
# Persists the trained network under the relative path 'tensorflow_model'.
# NOTE(review): the path has no file extension; which on-disk format this yields
# (and whether it is accepted at all) depends on the installed Keras version — confirm.
model.save('tensorflow_model')

In [52]:
# load the model
# Reads 'tensorflow_model' back into a separate object, so the prediction cell
# exercises the persisted copy rather than the in-memory `model`.
loaded_model = tf.keras.models.load_model('tensorflow_model')

In [55]:
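# test example: a single precomputed 768-value embedding (presumably produced by text_embedding), wrapped in an outer list so it forms a batch of one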
sample = [[ 5.20340651e-02,  2.36338988e-01,  1.78531513e-01,  3.71464007e-02,
        3.22910398e-01, -3.91929358e-01,  2.98932623e-02,  3.02443326e-01,
        5.20332158e-02, -4.74965782e-04,  8.45000893e-02, -2.96129942e-01,
       -1.15338206e-01,  3.29076529e-01,  2.10027490e-02,  6.30776167e-01,
       -2.90733632e-02, -3.05735052e-01, -2.46759411e-03,  2.03132167e-01,
        3.30812305e-01, -6.62997067e-02, -6.23699985e-02,  5.33742487e-01,
        1.62404373e-01,  2.03644469e-01, -2.35587910e-01,  4.11016494e-02,
       -2.88673341e-01, -4.83975038e-02,  4.99878347e-01, -3.56269516e-02,
       -1.71774074e-01, -3.71305674e-01, -2.25286573e-01,  1.34844864e-02,
       -2.13677343e-02,  1.26411011e-02,  1.92190241e-02, -3.38497981e-02,
       -3.49103570e-01, -1.32836685e-01,  2.25351170e-01,  4.38235588e-02,
       -3.42169464e-01, -2.06382707e-01,  2.41818875e-01,  6.42225668e-02,
        5.91336526e-02, -1.02593780e-01, -1.08172618e-01,  1.55878097e-01,
       -7.45715499e-02, -2.00651348e-01,  2.78436206e-02,  4.01206493e-01,
        7.94268399e-03, -3.31592232e-01, -3.97013038e-01,  6.68925117e-04,
       -2.83438861e-02,  5.76282367e-02,  7.72598162e-02, -5.12658536e-01,
        1.86074108e-01,  1.96199492e-01, -8.84490609e-02,  5.29277146e-01,
       -2.80148447e-01, -1.83779076e-02, -2.64606029e-01, -4.30156551e-02,
       -9.09992903e-02, -1.32018417e-01, -2.51941293e-01, -1.72699898e-01,
        2.07111180e-01,  8.01157355e-02,  9.01978984e-02,  1.01144299e-01,
        5.72853722e-02,  2.49400690e-01, -1.02699652e-01,  3.22199881e-01,
        5.59882596e-02,  4.23982777e-02, -3.64157408e-02,  1.47372380e-01,
       -4.86633658e-01,  5.60137928e-01, -2.05407590e-01, -3.78641486e-01,
        1.94350109e-01,  4.01857533e-02,  1.03131540e-01, -2.09923655e-01,
        7.93311521e-02, -5.36147095e-02, -1.11851968e-01,  2.34935045e-01,
        1.04730085e-01, -4.85622078e-01, -4.06456739e-02,  1.75663203e-01,
       -1.34285644e-01,  1.94162458e-01,  4.38677937e-01, -1.10022938e-02,
       -2.06238493e-01,  2.65419096e-01,  3.81478965e-01, -7.94875398e-02,
       -1.68639854e-01, -2.50249058e-01,  1.61789898e-02, -8.18261206e-02,
       -4.25786301e-02, -1.76802561e-01, -2.87414566e-02,  1.05392896e-01,
        8.54423852e-04, -2.21329033e-01,  1.13644764e-01,  6.52670562e-01,
       -1.46019951e-01,  1.69148520e-02, -1.16397545e-01,  2.82865614e-01,
        2.48409212e-02, -2.71206170e-01,  4.12739441e-02,  3.79798949e-01,
        2.80598909e-01, -4.02462929e-01, -1.68586388e-01,  1.44552141e-01,
        1.65173858e-01, -1.00494526e-01, -3.60992491e-01,  2.39878520e-01,
        1.89764172e-01, -4.05451143e-03,  1.36343211e-01,  1.18460216e-01,
        1.05589405e-01,  1.40838055e-02, -8.74524713e-02,  1.22762889e-01,
        1.66362897e-01,  2.23484766e-02,  2.06903413e-01,  9.09998044e-02,
       -2.14673787e-01, -3.92653942e-01, -1.41439870e-01, -2.20272373e-02,
       -3.18846315e-01, -7.22564682e-02,  1.46475300e-01, -2.31267378e-01,
        1.63046941e-01, -1.19746044e-01, -4.27318215e-02,  2.80332208e-01,
       -6.21915869e-02, -6.35261685e-02,  7.64021068e-04,  2.00260028e-01,
       -3.51641506e-01,  1.80350363e-01, -1.09108992e-01,  8.63686204e-02,
        3.47230434e-01,  1.64044425e-01, -1.26849473e-01,  1.75439678e-02,
        4.52344328e-01,  9.13501531e-02,  1.85553759e-01,  1.40101120e-01,
       -7.36575186e-01,  1.90364867e-01,  1.21073574e-01, -2.10696563e-01,
        1.89206496e-01, -3.58273268e-01,  2.90292919e-01, -7.79214874e-02,
        1.40401170e-01, -1.62011728e-01,  7.95739368e-02, -9.49750617e-02,
       -1.52936056e-01, -2.69863814e-01, -4.19851281e-02, -1.20147215e-02,
       -1.65208820e-02, -5.51631972e-02, -7.82638639e-02, -7.89317489e-02,
        7.20919743e-02,  2.67958734e-02, -3.00837066e-02, -4.01755236e-03,
       -2.27427080e-01, -2.20951632e-01,  2.29943171e-02, -2.54333228e-01,
       -1.60772458e-01, -6.72870353e-02, -2.11541206e-01,  2.47903079e-01,
       -9.94343907e-02, -2.74340082e-02, -1.94017544e-01,  1.49756089e-01,
        2.86101527e-03, -1.13621049e-01, -1.04791246e-01, -8.23615715e-02,
        5.50294966e-02,  2.81730920e-01, -3.09207320e-01,  1.57492191e-01,
       -1.52109101e-01,  7.86923587e-01, -3.00865527e-02, -5.20248652e-01,
        2.16398150e-01,  4.27326471e-01,  1.11449845e-01, -2.86397457e-01,
        3.99279982e-01, -2.41829455e-01, -1.74326748e-01,  1.29643932e-01,
       -3.39393556e-01, -1.09447494e-01,  2.32514247e-01, -2.47932345e-01,
       -1.82949603e-01,  1.86847135e-01,  1.61987022e-01,  9.23054889e-02,
        5.44477738e-02, -2.06508934e-01, -1.22943342e-01,  3.29603478e-02,
       -2.57624984e-01, -1.54787362e-01, -5.51480591e-01, -2.69271106e-01,
        5.01275584e-02, -4.36392903e-01, -5.95956892e-02, -9.28212926e-02,
       -2.20229596e-01, -4.03322726e-01,  4.11339626e-02,  1.23960413e-01,
        3.72384280e-01, -3.59070897e-02,  1.01030417e-01, -2.97426619e-02,
       -2.95806110e-01, -4.42186177e-01,  7.56530687e-02,  2.05590412e-01,
        4.00734752e-01,  2.68297821e-01, -1.43811241e-01,  3.41112167e-01,
        2.63209879e-01,  7.89335370e-01, -7.45306909e-02, -1.86214581e-01,
        9.76575166e-02,  1.61164775e-02,  2.14398801e-01, -7.94535726e-02,
        2.41461784e-01,  4.95845616e-01, -4.72704738e-01, -1.40755221e-01,
       -5.38746595e-01, -9.52907745e-03,  1.56719714e-01,  7.00369924e-02,
       -1.26326755e-01, -2.43016586e-01, -1.19045489e-01, -4.21539210e-02,
       -4.07776117e-01, -8.45694691e-02,  2.87962675e-01, -1.58244684e-01,
        2.39949137e-01, -5.24349660e-02, -3.84513110e-01, -1.88281685e-01,
        5.54865934e-02, -8.17316771e-02,  6.00040704e-02,  2.28392836e-02,
       -6.17721491e-02,  1.89939469e-01,  1.77468620e-02, -5.66343129e-01,
       -4.13943768e+00,  1.09585658e-01,  1.42820641e-01, -1.70431003e-01,
        2.66199917e-01, -1.86178591e-02, -8.25976804e-02, -3.09305012e-01,
       -4.22797263e-01,  2.30223686e-01, -2.47025311e-01, -1.97055638e-01,
        1.57703638e-01,  2.57100463e-01,  3.06500018e-01, -3.84250097e-02,
       -2.05150187e-01, -2.05634654e-01, -1.37036473e-01,  4.96277750e-01,
       -1.16853066e-01, -5.06655633e-01,  3.49799305e-01, -1.85154885e-01,
        1.23907052e-01,  4.14550602e-01, -4.18117166e-01, -1.80098843e-02,
       -1.88915595e-01, -1.59217089e-01, -2.04115584e-01, -2.32774600e-01,
       -7.26526678e-02,  3.32365967e-02,  2.06136674e-01,  4.62973379e-02,
        1.05287507e-01, -2.81580329e-01, -2.43881140e-02, -3.42152029e-01,
        5.85510656e-02, -4.86458302e-01, -1.17712154e-03, -1.87248200e-01,
        5.94154954e-01, -1.98881194e-01,  1.11001715e-01, -4.58162397e-01,
        6.58564866e-02, -3.07629462e-02,  8.31438303e-02,  4.48260531e-02,
       -9.27313864e-02, -2.32178941e-01, -1.90088719e-01, -3.10660571e-01,
        4.06713337e-01,  4.05026585e-01, -2.91027874e-01, -2.98203200e-01,
        3.52568030e-01, -9.84086320e-02, -3.03286403e-01,  4.16820869e-02,
        2.79600061e-02, -1.40333429e-01, -4.39803809e-01, -2.57664829e-01,
        4.15826850e-02,  2.05407649e-01,  4.84345518e-02,  3.13899428e-01,
       -3.49204600e-01, -5.82581520e-01, -1.85782760e-01, -8.17709118e-02,
        8.72678086e-02,  1.29466102e-01,  3.51160467e-02,  7.42424205e-02,
       -2.00928822e-01, -2.90355682e-01, -4.94455993e-02,  6.22397661e-03,
       -1.52563408e-01, -3.91677588e-01, -2.09093764e-01, -1.14558991e-02,
       -2.85773724e-01, -4.43756670e-01,  3.71447206e-01,  2.24346980e-01,
       -2.20691055e-01,  2.65006304e-01, -1.55755514e-02,  4.87195142e-03,
        1.91860050e-01,  1.20959841e-01,  3.34893078e-01, -1.05380826e-01,
       -2.66472716e-02, -8.85563344e-02,  6.37625873e-01, -2.66657144e-01,
       -1.18155330e-01, -1.42098710e-01, -2.69173145e-01, -3.71462777e-02,
        4.83392254e-02,  5.47190942e-02,  9.97357592e-02, -1.12278402e-01,
        1.83545321e-01, -2.58381277e-01,  1.51072949e-01, -2.97399640e-01,
        1.59399450e-01,  3.20883751e-01, -9.40045118e-02, -5.45633547e-02,
       -8.31284299e-02,  4.41325396e-01, -1.55379698e-01,  3.79332811e-01,
       -1.40745401e-01,  1.77665323e-01, -9.75979567e-02,  6.95721209e-02,
       -5.64450920e-02, -1.67214766e-01, -3.41909051e-01, -2.83418804e-01,
        1.41442046e-01,  1.48409247e-01,  3.43214214e-01,  6.48462474e-02,
        1.95699260e-02, -3.68682384e-01, -1.76143631e-01,  5.31481355e-02,
        1.92670792e-01,  3.46076906e-01,  2.22408473e-01,  1.72704503e-01,
        1.13365918e-01,  4.17984605e-01,  2.28447877e-02,  5.99751696e-02,
       -1.24198027e-01, -4.32108194e-02, -3.53497148e-01,  1.17510352e-02,
       -1.23593971e-01, -5.74785709e-01,  7.62676150e-02,  2.73524642e-01,
        6.00279421e-02,  1.77315734e-02,  1.38185054e-01, -3.05946946e-01,
        5.07561088e-01,  1.41227543e-01,  3.57851535e-01,  1.02521025e-01,
       -7.60204196e-02,  1.32517412e-01,  4.02053893e-02, -3.11969727e-01,
       -2.63988107e-01, -7.03858510e-02, -1.50089502e-01,  8.63525867e-02,
        1.24932542e-01, -4.68670458e-01, -1.40594408e-01,  4.76862192e-01,
       -3.84178162e-02, -1.92818880e-01,  1.87226295e-01,  4.19289559e-01,
       -1.20883651e-01, -5.43239564e-02,  1.07319402e-02,  3.50385252e-03,
        3.43751311e-01,  7.58460090e-02, -6.99715838e-02, -1.67494759e-01,
       -9.42447931e-02,  1.06128402e-01, -1.62058305e-02,  4.32305306e-01,
        7.54823908e-02, -3.13513130e-01, -3.62764508e-01, -3.58962357e-01,
        2.07204700e-01,  4.66701426e-02,  6.44742250e-02,  4.07988206e-02,
        2.16780171e-01, -6.00378821e-03, -1.37872860e-01,  3.50705683e-01,
        4.17392582e-01, -1.68873161e-01,  5.30223511e-02,  1.13812245e-01,
       -9.14959759e-02,  2.06823602e-01, -2.86784708e-01, -2.95684487e-01,
       -1.48443669e-01, -2.76947953e-02,  1.64330021e-01, -1.43663645e-01,
        1.70178153e-02, -2.30404779e-01, -2.28912264e-01,  1.75594747e-01,
       -3.03562909e-01,  6.60676211e-02,  7.70484880e-02,  1.25259265e-01,
       -3.67534995e-01, -2.33395815e-01, -1.72100458e-02,  3.35053146e-01,
       -3.76563132e-01, -9.68274921e-02, -2.17907168e-02, -7.80378699e-01,
        1.53753355e-01,  1.01778232e-01, -1.96868271e-01,  4.29562144e-02,
       -1.41489744e-01,  2.04634443e-02,  1.39570817e-01, -1.79884270e-01,
        3.06650102e-02, -6.32397085e-02,  1.60029545e-01,  6.42561391e-02,
        2.68880576e-01, -1.25637382e-01,  6.67060986e-02,  5.66682220e-01,
       -1.79461449e-01,  8.65146443e-02, -1.23139195e-01,  9.39582363e-02,
       -3.15414697e-01, -5.22285640e-01,  1.54819950e-01, -5.12535393e-01,
       -1.15436181e-01, -1.98537949e-02,  2.92577371e-02, -1.33642405e-02,
       -5.68517186e-02,  3.81593034e-02, -1.96998045e-01, -6.62248507e-02,
       -4.47144061e-02, -5.21122813e-02, -1.13417173e-03, -5.28462455e-02,
        1.89160600e-01, -2.55561292e-01,  8.36480334e-02, -1.34348497e-01,
        5.53108416e-02, -1.33286878e-01,  1.36912704e-01,  2.01110393e-01,
       -1.03323691e-01, -3.31513286e-01,  9.99042615e-02, -5.27536094e-01,
        7.37991706e-02,  2.91844010e-01, -4.45236899e-02, -2.09454715e-01,
        2.23787352e-01, -2.28347227e-01, -2.70839483e-01,  3.02473098e-01,
        1.11593306e-01,  2.08858490e-01,  2.29601026e-01,  1.86192274e-01,
        9.81634185e-02,  2.43903995e-01, -1.75136030e-01, -7.07518607e-02,
        5.18429816e-01,  2.61111796e-01, -5.54995947e-02,  9.59509760e-02,
        1.36965349e-01, -4.52254936e-02,  1.08696021e-01,  2.75711659e-02,
       -2.35339239e-01, -2.45896019e-02,  3.19933563e-01, -3.91070485e-01,
       -5.47538027e-02,  2.75256544e-01, -2.38335237e-01, -3.11554193e-01,
        2.84444749e-01,  4.41088617e-01, -3.82276863e-01, -1.13917686e-01,
        2.88921714e-01, -1.11010969e-01,  1.40502341e-02, -1.46340700e-02,
        1.42324671e-01, -1.60856441e-01,  2.57794678e-01, -9.76497233e-02,
        4.43745136e-01,  2.65695453e-01, -1.45535231e-01, -3.76501769e-01,
        2.01196805e-01,  2.63625085e-01,  1.41210973e-01,  6.46361172e-01,
       -2.71090150e-01,  2.22757161e-01, -2.58495361e-01, -1.62668556e-01,
        7.28517398e-02,  4.49503511e-01,  2.59264141e-01,  3.12427133e-02,
       -1.29327893e-01,  7.77065903e-02,  3.14742237e-01,  6.53949022e-01,
        1.00652672e-01,  1.95482418e-01,  1.70837492e-01,  1.35074958e-01,
        5.19774258e-01,  5.04844636e-02,  2.04156637e-01,  1.62802756e-01,
       -3.12241409e-02, -2.34825835e-01,  7.18507886e-01, -6.01789728e-02,
        8.68651941e-02,  1.97697014e-01,  1.24174237e-01,  5.95605731e-01,
        1.72425985e-01,  1.72548845e-01,  4.16919380e-01, -3.49776179e-01,
       -9.82857943e-02,  1.65822491e-01,  3.58925343e-01, -1.43119663e-01,
       -1.55700415e-01,  5.89764304e-02,  1.05678231e-01, -1.02054290e-02,
       -2.77458221e-01, -2.78051883e-01, -1.01242848e-01,  4.00244370e-02,
       -1.75566301e-02, -9.33056921e-02, -1.10234700e-01,  4.04566973e-01,
        2.56548464e-01, -1.57683324e-02, -6.89859688e-01, -1.25983611e-01,
        9.64178145e-03, -2.61619300e-01, -2.02679008e-01, -1.96608037e-01,
       -1.06095828e-01, -4.94363427e-01, -8.18516538e-02, -2.34242350e-01,
       -5.84817678e-02, -1.65549532e-01, -2.41450444e-02, -3.88003200e-01,
        3.63868207e-01,  9.43293944e-02,  2.20769510e-01, -2.54503340e-02,
        2.58350551e-01, -1.48419783e-01, -3.86305243e-01,  1.53318271e-01,
        2.80947089e-01,  2.77289916e-02, -1.47237197e-01,  1.78571150e-03,
       -1.50650367e-01,  3.62522244e-01, -2.86555558e-01,  3.70166451e-01,
       -4.77044970e-01,  1.55587822e-01, -1.42645538e-01,  9.11027789e-02,
        1.62747353e-01, -1.02501146e-01, -1.46893904e-01, -8.49254876e-02,
       -1.37248635e-01, -1.60076737e-01, -2.54166629e-02,  6.79814955e-04,
       -1.97389364e-01,  1.56347319e-01,  2.26483598e-01,  3.01064223e-01,
       -1.02658585e-01,  9.31802988e-02,  8.32294766e-03,  9.51032937e-02,
       -9.50170904e-02, -2.90117025e-01,  4.00186181e-02, -4.34238374e-01,
       -1.87085763e-01,  6.09676838e-02, -3.73848587e-01, -1.44035921e-01,
        7.65906051e-02,  3.47925514e-01,  1.12025574e-01, -1.62462685e-02,
        4.97680008e-01, -1.66686103e-01, -2.01417416e-01,  1.01609893e-01,
       -1.11325160e-01, -1.37727469e-01, -9.68218073e-02,  4.61517051e-02,
       -1.89185321e-01, -2.01947346e-01,  1.13105364e-02, -4.74422276e-02,
        4.64750901e-02,  1.16737276e-01,  3.00498337e-01, -2.70455092e-01]]

In [65]:
# Convert the nested Python list into a float32 array (one row per sample)
# before predicting: depending on the Keras version, a plain list is either
# converted implicitly or rejected, whereas an ndarray is always one batch.
# [0][0] picks the single output value of the single sample.
loaded_model.predict(np.array(sample, dtype=np.float32))[0][0]



0.9974244