In [1]:
import os, shutil
import numpy as np
import pandas as pd
import tensorflow as tf

from PIL import Image
from tensorflow.keras import applications
from tensorflow.keras import optimizers
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.models import model_from_json
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Dropout, Flatten, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [2]:
# Expose only CUDA device 0 to this process.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [3]:
# Folder with the model artefacts read below (test.csv, model.json, model.h5).
MODEL_DIR = '../model'
# Folder containing the chart images that are fed to the classifier.
TEST_DIR = './../../DataMining/AllData/filtered/images/'
# Destination folder for the images predicted to be bar graphs.
OUTPUT_DIR = 'barplots'

In [4]:
# Manifest of images to classify; the "chart" column is used as the
# file-name column and is looked up inside TEST_DIR.
testdf = pd.read_csv(f"{MODEL_DIR}/test.csv")

# Inference-time preprocessing: only rescale pixel values by 1/255.
test_datagen = ImageDataGenerator(rescale=1.0 / 255.0)

# No labels (y_col / class_mode are None), no shuffling, one 224x224 image
# per batch, so the generator walks the manifest in its stored order.
generator_options = dict(
    dataframe=testdf,
    directory=TEST_DIR,
    x_col="chart",
    y_col=None,
    batch_size=1,
    seed=42,
    shuffle=False,
    class_mode=None,
    target_size=(224, 224),
)
test_generator = test_datagen.flow_from_dataframe(**generator_options)

Found 9709 validated image filenames.


In [5]:
# Read the serialized architecture. The context manager guarantees the file
# handle is closed even if reading fails part-way through.
with open(MODEL_DIR + '/model.json', 'r') as json_file:
    loaded_model_json = json_file.read()

# Rebuild the (still untrained) network from its JSON description.
loaded_model = model_from_json(loaded_model_json)

# Load the saved weights into the freshly built model.
loaded_model.load_weights(MODEL_DIR + '/model.h5')
print("Loaded model from disk")

Loaded model from disk


In [None]:
# Rewind the generator so prediction starts from its first batch.
test_generator.reset()

# Ceiling division: a trailing partial batch must still be predicted, otherwise
# `pred` would have fewer rows than `test_generator.filenames` whenever
# batch_size does not divide n. For the current batch_size of 1 this equals n.
STEP_SIZE_TEST = (
    (test_generator.n + test_generator.batch_size - 1)
    // test_generator.batch_size
)

pred = loaded_model.predict(test_generator,
                            steps=STEP_SIZE_TEST,
                            verbose=1)

# Position of the highest score along axis 1, i.e. the predicted class index
# for each row of `pred` (presumably one row per image -- confirm model output).
predicted_class_indices = np.argmax(pred, axis=1)



In [None]:
# Class name -> output index of the classifier.
# NOTE(review): hard-coded; presumably mirrors the class_indices of the
# generator used at training time -- verify before trusting the labels.
class_indices = {
    name: idx
    for idx, name in enumerate([
        'AreaGraph',
        'BarGraph',
        'BoxPlot',
        'BubbleChart',
        'FlowChart',
        'LineGraph',
        'Map',
        'NetworkDiagram',
        'ParetoChart',
        'PieChart',
        'ScatterGraph',
        'TreeDiagram',
        'VennDiagram',
    ])
}

# Inverse lookup: output index -> class name.
labels = {idx: name for name, idx in class_indices.items()}
predictions = [labels[class_idx] for class_idx in predicted_class_indices]

# Pair every image file name with its predicted chart type.
filenames = test_generator.filenames
results = pd.DataFrame({"chart": filenames, "type": predictions})

In [None]:
# Keep only the rows whose predicted type is 'BarGraph' ...
is_bar_graph = results['type'].eq('BarGraph')
barplots = results[is_bar_graph]
# ... and collect their file names as a plain Python list.
barplotlist = barplots['chart'].tolist()

In [None]:
# Create the destination folder (including missing parents). exist_ok=True
# replaces the racy "check, then mkdir" sequence.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# A set gives constant-time membership tests; scanning the list for every
# directory entry would be quadratic in the number of images.
barplot_names = set(barplotlist)

# Copy every file in TEST_DIR whose name was classified as a bar graph.
for file in os.listdir(TEST_DIR):
    if file in barplot_names:
        shutil.copy(os.path.join(TEST_DIR, file), OUTPUT_DIR)