Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

supporting different architectures - vggnet, resnet #297

Merged
merged 14 commits into from
Jul 25, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions libra/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,8 @@ def convolutional_query(self,
image_column=None,
test_size=0.2,
augmentation=True,
custom_arch=None,
pretrained=None,
epochs=10,
height=None,
width=None):
Expand Down Expand Up @@ -738,6 +740,8 @@ def convolutional_query(self,
image_column=image_column,
training_ratio=1 - test_size,
augmentation=augmentation,
custom_arch=custom_arch,
pretrained=pretrained,
epochs=epochs,
height=height,
width=width)
Expand Down
179 changes: 120 additions & 59 deletions libra/query/feedforward_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
set_distinguisher,
already_processed)
from libra.preprocessing.data_reader import DataReader
from keras.models import Sequential
from keras.layers import (Dense, Conv2D, Flatten, MaxPooling2D, Dropout)
from keras import Model
from keras.models import Sequential, model_from_json
from keras.layers import (Dense, Conv2D, Flatten, MaxPooling2D, Dropout, GlobalAveragePooling2D)
from keras.applications import VGG16, VGG19, ResNet50, ResNet101, ResNet152

import pandas as pd
import json
from libra.query.supplementaries import save, generate_id
from keras.preprocessing.image import ImageDataGenerator
from sklearn.preprocessing import OneHotEncoder
Expand Down Expand Up @@ -180,8 +184,8 @@ def regression_ann(
print((" " * 2 * counter) + "| " + ("".join(word.ljust(col_width)
for word in row)) + " |")
datax = []
#while all(x > y for x, y in zip(losses, losses[1:])):
while (len(losses)<=2 or losses[len(losses)-1] < losses[len(losses)-2]):
# while all(x > y for x, y in zip(losses, losses[1:])):
while (len(losses) <= 2 or losses[len(losses) - 1] < losses[len(losses) - 2]):
model = get_keras_model_reg(data, i)
history = model.fit(
X_train,
Expand Down Expand Up @@ -290,7 +294,7 @@ def classification_ann(instruction,

X_train = data['train']
X_test = data['test']

if num_classes > 2:
# ANN needs target one hot encoded for classification
one_hot_encoder = OneHotEncoder()
Expand All @@ -299,8 +303,8 @@ def classification_ann(instruction,
np.reshape(
y.values,
(-1,
1))).toarray(),
columns = one_hot_encoder.get_feature_names())
1))).toarray(),
columns=one_hot_encoder.get_feature_names())

y_train = y.iloc[:len(X_train)]
y_test = y.iloc[len(X_train):]
Expand Down Expand Up @@ -360,7 +364,7 @@ def classification_ann(instruction,
losses.append(history.history[maximizer]
[len(history.history[maximizer]) - 1])
accuracies.append(history.history['val_accuracy']
[len(history.history['val_accuracy']) - 1])
[len(history.history['val_accuracy']) - 1])
# keeps running model and fit functions until the validation loss stops
# decreasing

Expand All @@ -373,8 +377,8 @@ def classification_ann(instruction,
print((" " * 2 * counter) + "| " + ("".join(word.ljust(col_width)
for word in row)) + " |")
datax = []
#while all(x < y for x, y in zip(accuracies, accuracies[1:])):
while (len(accuracies)<=2 or accuracies[len(accuracies)-1] > accuracies[len(accuracies)-2]):
# while all(x < y for x, y in zip(accuracies, accuracies[1:])):
while (len(accuracies) <= 2 or accuracies[len(accuracies) - 1] > accuracies[len(accuracies) - 2]):
model = get_keras_model_class(data, i, num_classes)
history = model.fit(
X_train,
Expand Down Expand Up @@ -418,7 +422,7 @@ def classification_ann(instruction,
logger('->', "Training Accuracy: " + str(final_hist.history['accuracy']
[len(final_hist.history['val_accuracy']) - 1]))
logger('->', "Test Accuracy: " + str(final_hist.history['val_accuracy'][
len(final_hist.history['val_accuracy']) - 1]))
len(final_hist.history['val_accuracy']) - 1]))

# generates appropriate classification plots by feeding all information
plots = {}
Expand Down Expand Up @@ -459,6 +463,8 @@ def convolutional(instruction=None,
image_column=None,
training_ratio=0.8,
augmentation=True,
custom_arch=None,
pretrained=None,
epochs=10,
height=None,
width=None):
Expand All @@ -471,7 +477,18 @@ def convolutional(instruction=None,

logger("Generating datasets for classes")

if pretrained:
if not height:
height = 224
if not width:
width = 224
if height != 224 or width != 224:
raise ValueError("For pretrained models, height must be 224 and width must be 224.")

if preprocess:
if custom_arch:
raise ValueError("If custom_arch is not None, preprocess must be set to false.")

read_mode_info = set_distinguisher(data_path, read_mode)
read_mode = read_mode_info["read_mode"]

Expand Down Expand Up @@ -525,54 +542,97 @@ def convolutional(instruction=None,
elif num_classes == 2:
loss_func = "binary_crossentropy"

logger("Creating convolutional neural network dynamically")
logger("Creating convolutional neural netwwork dynamically")

# Convolutional Neural Network
model = Sequential()
# model.add(
# Conv2D(
# 64,
# kernel_size=3,
# activation="relu",
# input_shape=input_shape))
# model.add(MaxPooling2D(pool_size=(2, 2)))
# model.add(Conv2D(64, kernel_size=3, activation="relu"))
# model.add(MaxPooling2D(pool_size=(2, 2)))
# model.add(Flatten())
# model.add(Dense(num_classes, activation="softmax"))
# model.compile(
# optimizer="adam",
# loss=loss_func,
# metrics=['accuracy'])
model.add(Conv2D(
filters=64,
kernel_size=5,
activation="relu",
input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Conv2D(
filters=64,
kernel_size=3,
activation="relu"))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.25))
model.add(Conv2D(
filters=64,
kernel_size=3,
activation="relu"))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Flatten())
model.add(Dense(
units=256,
activation="relu"))
model.add(Dropout(0.25))
model.add(Dense(
units=num_classes,
activation="softmax"
))
# Build model based on custom_arch configuration if given
if custom_arch:
with open(custom_arch, "r") as f:
custom_arch_dict = json.load(f)
custom_arch_json_string = json.dumps(custom_arch_dict)
model = model_from_json(custom_arch_json_string)

# Build an existing state-of-the-art model
elif pretrained:

arch_lower = pretrained.get('arch').lower()

# If user specifies value of pretrained['weights'] as 'imagenet', weights pretrained on ImageNet will be used
if 'weights' in pretrained and pretrained.get('weights') == 'imagenet':
# Load ImageNet pretrained weights
if arch_lower == "vggnet16":
base_model = VGG16(include_top=False, weights='imagenet', input_shape=input_shape)
x = Flatten()(base_model.output)
x = Dense(4096)(x)
x = Dropout(0.5)(x)
x = Dense(4096)(x)
x = Dropout(0.5)(x)
pred = Dense(num_classes, activation='softmax')(x)
model = Model(base_model.input, pred)
elif arch_lower == "vggnet19":
base_model = VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
x = Flatten()(base_model.output)
x = Dense(4096)(x)
x = Dropout(0.5)(x)
x = Dense(4096)(x)
x = Dropout(0.5)(x)
pred = Dense(num_classes, activation='softmax')(x)
model = Model(base_model.input, pred)
elif arch_lower == "resnet50":
base_model = ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
x = Flatten()(base_model.output)
x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.5)(x)
pred = Dense(num_classes, activation='softmax')(x)
model = Model(base_model.input, pred)
elif arch_lower == "resnet101":
base_model = ResNet101(include_top=False, weights='imagenet', input_shape=input_shape)
x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.5)(x)
pred = Dense(num_classes, activation='softmax')(x)
model = Model(base_model.input, pred)
elif arch_lower == "resnet152":
base_model = ResNet152(include_top=False, weights='imagenet', input_shape=input_shape)
x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.5)(x)
pred = Dense(num_classes, activation='softmax')(x)
model = Model(base_model.input, pred)
else:
raise ModuleNotFoundError("arch \'" + pretrained.get('arch') + "\' not supported.")

else:
# Randomly initialized weights
if arch_lower == "vggnet16":
model = VGG16(include_top=True, weights=None, classes=num_classes)
elif arch_lower == "vggnet19":
model = VGG19(include_top=True, weights=None, classes=num_classes)
elif arch_lower == "resnet50":
model = ResNet50(include_top=True, weights=None, classes=num_classes)
elif arch_lower == "resnet101":
model = ResNet101(include_top=True, weights=None, classes=num_classes)
elif arch_lower == "resnet152":
model = ResNet152(include_top=True, weights=None, classes=num_classes)
else:
raise ModuleNotFoundError("arch \'" + pretrained.get('arch') + "\' not supported.")
else:
model = Sequential()
model.add(
Conv2D(
64,
kernel_size=3,
activation="relu",
input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=3, activation="relu"))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(num_classes, activation="softmax"))

model.compile(
optimizer="adam",
loss=loss_func,
metrics=['accuracy'])

logger("Located image data")

if augmentation:
Expand All @@ -599,22 +659,23 @@ def convolutional(instruction=None,
batch_size=(32 if processInfo["test_size"] >= 32 else 1),
class_mode=loss_func[:loss_func.find("_")])


if epochs < 0:
raise BaseException("Number of epochs has to be greater than 0.")
logger('Training image model')
history = model.fit_generator(
X_train,
steps_per_epoch=X_train.n //
X_train.batch_size,
X_train.batch_size,
validation_data=X_test,
validation_steps=X_test.n //
X_test.batch_size,
X_test.batch_size,
epochs=epochs,
verbose=verbose)

logger('->', 'Final training accuracy: {}'.format(history.history['accuracy'][len(history.history['accuracy']) - 1]))
logger('->', 'Final validation accuracy: {}'.format(history.history['val_accuracy'][len(history.history['val_accuracy']) - 1]))
logger('->',
'Final training accuracy: {}'.format(history.history['accuracy'][len(history.history['accuracy']) - 1]))
logger('->', 'Final validation accuracy: {}'.format(
history.history['val_accuracy'][len(history.history['val_accuracy']) - 1]))
# storing values in the model dictionary

logger("Stored model under 'convolutional_NN' key")
Expand Down
1 change: 0 additions & 1 deletion tests/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ def compare(a, b):

return ordered, compare


ordered, compare = make_orderer()
unittest.defaultTestLoader.sortTestMethodsUsing = compare

Expand Down