In [6]:
import numpy as np
import matplotlib.pyplot as plt
from progressbar import ProgressBar
import matplotlib.image as mpimg
import tensorflow as tf
from tensorflow.keras.preprocessing.image import img_to_array, load_img


In [7]:


def get_image_arrays(image_column, image_path):
    progressBar = ProgressBar()
    X = []

    for image_id in progressBar(image_column.values):
        image = load_img(image_path + image_id, target_size=(224, 224))
        image_array = img_to_array(image)

        X.append(image_array)

    X_array = np.asarray(X, dtype='float32')
    X_array /= 255.

    return X_array


def get_image_predictions(image_array, model_path):
    # Load the TFLite model and allocate tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    # Get input and output tensors.
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Test the model on random input data.
    input_shape = input_details[0]['shape']
    input_data = image_array
    interpreter.set_tensor(input_details[0]['index'], input_data)

    interpreter.invoke()

    # The function `get_tensor()` returns a copy of the tensor data.
    # Use `tensor()` in order to get a pointer to the tensor.
    output_data = interpreter.get_tensor(output_details[0]['index'])

    return output_data


def show_image(image_id, image_path):
    image_id_dict = dict(image_id).values()
    image_id_string = list(image_id_dict)[0]
    img = mpimg.imread(image_path + image_id_string)
    plt.imshow(img, interpolation='nearest', aspect='auto')
    plt.show()


In [10]:

import pandas as pd
import numpy as np
import pickle


In [11]:
data_dir = r'E:\datasets\MADE\3_graduation\parthplc\archive\data\\'

train_path = data_dir + 'train.jsonl'
dev_path = data_dir + 'dev.jsonl'


train_data = pd.read_json(train_path, lines=True)
test_data = pd.read_json(dev_path, lines=True)
demo_data = pd.read_json(dev_path, lines=True)

In [13]:
demo_data.head(3)

Unnamed: 0,id,img,label,text
144,28951,img/28951.png,0,if the brim of your hat is flat and has a stic...


In [15]:
TFLITE_FILE_PATH = 'image_model.tflite'

demo_data = demo_data.sample(1)
y_true = demo_data['label']
image_id = demo_data['img']
text = demo_data['text']

image_id_dict = dict(image_id).values()
image_id_string = list(image_id_dict)[0]

In [18]:

image_array = get_image_arrays(image_id, data_dir)
image_prediction = get_image_predictions(image_array, TFLITE_FILE_PATH)
y_pred_image = np.argmax(image_prediction, axis=1)
print('Image Prediction Probabilities:')
print(image_prediction)

100% |########################################################################|

Image Prediction Probabilities:
[[0.5786022 0.4380305]]





In [21]:
# TFIDF Model
model = 'tfidf_model.pickle'
vectorizer = 'tfidf_vectorizer.pickle'
tfidf_model = pickle.load(open(model, 'rb'))
tfidf_vectorizer = pickle.load(open(vectorizer, 'rb'))
transformed_text = tfidf_vectorizer.transform(text)
text_prediction = tfidf_model.predict_proba(transformed_text)
y_pred_text = np.argmax(text_prediction, axis=1)
print('Text Prediction Probabilities:')
print(text_prediction)


Text Prediction Probabilities:
[[0.67821008 0.32178992]]


https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
  tfidf_vectorizer = pickle.load(open(vectorizer, 'rb'))
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [22]:
# Ensemble Probabilities
ensemble_prediction = np.mean(np.array([image_prediction, text_prediction]), axis=0)
y_pred_ensemble = np.argmax(ensemble_prediction, axis=1)
print(ensemble_prediction)

[[0.62840614 0.37991022]]


In [23]:

true_label = list(dict(y_true).values())[0]
predicted_label = y_pred_ensemble[0]

if true_label == 0:
    print('True Label: non-hateful')
if true_label == 1:
    print('True Label: hateful')

if predicted_label == 0:
    print('Predicted Label: non-hateful')
if predicted_label == 1:
    print('Predicted Label: hateful')


True Label: non-hateful
Predicted Label: non-hateful
