In [None]:
import keras
import warnings

# This notebook uses Keras 2-era APIs (e.g. `keras.preprocessing.image.ImageDataGenerator`,
# `Dense(input_dim=...)`) that may be unavailable or deprecated in Keras 3.
# Warn up front instead of failing several cells later.
if int(keras.__version__.split('.')[0]) >= 3:
    warnings.warn('Keras %s detected; this notebook targets Keras 2 and some cells '
                  '(e.g. ImageDataGenerator) may fail.' % keras.__version__)

keras.__version__

In [None]:
from keras.datasets import cifar10
# CIFAR-10 is returned as two (images, labels) pairs: the training split and the test split.
train_pair, test_pair = cifar10.load_data()
X_train, y_train = train_pair
X_test, y_test = test_pair
del train_pair, test_pair

In [None]:
from keras.utils import to_categorical

NUM_CLASSES = 10  # CIFAR-10 has 10 classes

# One-hot encode the integer class labels. cifar10.load_data() returns labels
# shaped (n, 1); the guard skips arrays that are already one-hot so re-running
# this cell does not turn (n, 10) labels into a wrong (n, 10, 2) array.
if y_train.ndim == 1 or y_train.shape[-1] == 1:
    y_train = to_categorical(y_train, num_classes=NUM_CLASSES)
if y_test.ndim == 1 or y_test.shape[-1] == 1:
    y_test = to_categorical(y_test, num_classes=NUM_CLASSES)

print(y_train.shape,  y_test.shape)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 30% of the original training images for validation. Stratifying on
# the (one-hot) labels keeps each class at the same proportion in both parts;
# random_state makes the split reproducible.
X_train, X_valid, y_train, y_valid = train_test_split(
    X_train, y_train,
    test_size=0.3,
    random_state=2045,
    stratify=y_train,
)

print(X_train.shape, y_train.shape, '/', X_valid.shape, y_valid.shape, '/', X_test.shape, y_test.shape)

In [None]:
from keras.applications import VGG16

# VGG16 convolutional base with ImageNet weights. The dense classifier head is
# dropped (include_top=False), and the input size is set to CIFAR-10's 32x32 RGB images.
conv_base = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(32, 32, 3),
)
conv_base.summary()

In [None]:
from keras.preprocessing.image import ImageDataGenerator
import numpy as np

# Only rescales pixel values from [0, 255] to [0, 1]; no augmentation is configured.
# NOTE(review): Keras documents that the ImageNet VGG16 weights expect
# `keras.applications.vgg16.preprocess_input` (BGR, mean-subtracted) rather than
# 1/255 scaling -- consider switching if feature quality matters.
datagen = ImageDataGenerator(rescale=1./255)
# Number of images per forward pass through conv_base during feature extraction.
# This is just a batch size; it is unrelated to the 32x32 image resolution.
batch_size = 32

def extract_features(x, y, sample_count=None, batch_size=batch_size):
    """Run images through the frozen VGG16 base and collect features with their labels.

    Parameters
    ----------
    x : np.ndarray
        Images to encode. They are fed through ``datagen.flow`` (1/255 rescaling).
    y : np.ndarray
        Labels row-aligned with ``x`` (one-hot in this notebook).
    sample_count : int, optional
        Number of rows to extract; defaults to ``len(x)``. If larger than
        ``len(x)``, the generator wraps around and samples repeat.
    batch_size : int, optional
        Images per ``conv_base.predict`` call; defaults to the module-level
        ``batch_size`` at definition time.

    Returns
    -------
    features : np.ndarray
        Shape ``(sample_count,) + conv_base.output_shape[1:]``.
    labels : np.ndarray
        Shape ``(sample_count,) + y.shape[1:]``, row-aligned with ``features``.
    """
    if sample_count is None:
        sample_count = len(x)

    # Derive output shapes from the model and the labels instead of hard-coding
    # (1, 1, 512) and 10, so other input sizes / label encodings also work.
    features = np.zeros(shape=(sample_count,) + tuple(conv_base.output_shape[1:]))
    labels = np.zeros(shape=(sample_count,) + tuple(y.shape[1:]))
    if sample_count == 0:
        return features, labels

    # shuffle=False keeps the output in the same order as x (deterministic).
    # The explicit batch_size ties the generator to our bookkeeping instead of
    # relying on flow()'s default of 32.
    generator = datagen.flow(x, y, batch_size=batch_size, shuffle=False)

    filled = 0
    for input_batch, label_batch in generator:
        # The last batch of an epoch may be short, and the final batch we need
        # may overshoot sample_count: copy only the rows that still fit.
        take = min(len(input_batch), sample_count - filled)
        # verbose=0 avoids a progress bar on every per-batch call (the default in newer Keras).
        features_batch = conv_base.predict(input_batch, verbose=0)
        features[filled:filled + take] = features_batch[:take]
        labels[filled:filled + take] = label_batch[:take]
        filled += take
        # datagen.flow() loops forever, so we must stop explicitly.
        if filled >= sample_count:
            break
    return features, labels

# Replace the raw images with VGG16 feature maps. Sample counts come from the
# arrays themselves, so they stay correct if the validation split changes.
# The guard makes re-running this cell a no-op once features exist: raw images
# are (n, 32, 32, 3), whereas extracted features no longer have 3 channels.
if X_train.ndim == 4 and X_train.shape[-1] == 3:
    X_train, y_train = extract_features(X_train, y_train, len(X_train))
    X_valid, y_valid = extract_features(X_valid, y_valid, len(X_valid))
    X_test, y_test = extract_features(X_test, y_test, len(X_test))

print(X_train.shape, y_train.shape)
print(X_valid.shape, y_valid.shape)
print(X_test.shape, y_test.shape)

In [None]:
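# NOTE: abandoned alternative to the generator-based extract_features, kept for
# reference only. It does not work as written: `features_batch = features`
# discards the predictions (the zero array is returned), and the images are
# not rescaled the way datagen rescales them.
# def extract_features(x, y, sample_count):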
#   features = np.zeros(shape =(sample_count, 1, 1, 512))
#   labels = np.zeros(shape =(sample_count, 10))

#   features_batch = conv_base.predict(x)
#   features_batch = features
#   labels = y
#   return features, labels



# X_train, y_train = extract_features(X_train, y_train, 35000)
# X_valid, y_valid = extract_features(X_valid, y_valid, 15000)
# X_test, y_test = extract_features(X_test, y_test, 10000)

# print(X_train.shape, y_train.shape)
# print(X_valid.shape, y_valid.shape)
# print(X_test.shape, y_test.shape)

In [None]:
# Flatten each sample's VGG16 feature map (1x1x512 for 32x32 inputs) into a
# single vector for the Dense classifier. Using each array's own length and -1
# avoids hard-coding the split sizes (35000/15000/10000) and the feature size.
X_train = X_train.reshape(len(X_train), -1)
X_valid = X_valid.reshape(len(X_valid), -1)
X_test = X_test.reshape(len(X_test), -1)

X_train.shape, X_valid.shape, X_test.shape

In [None]:
from keras import models, layers

# Small classifier head trained on the frozen VGG16 features:
# 512-d feature vector -> 256 ReLU units -> dropout -> 10-way softmax.
model = models.Sequential([
    layers.Dense(256, activation='relu', input_dim=1 * 1 * 512),
    layers.Dropout(0.5),
    layers.Dense(10, activation='softmax'),
])

model.summary()

# One-hot multi-class targets -> softmax output + categorical cross-entropy.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
%%time
Hist = model.fit(X_train, y_train, epochs=100, batch_size=20, validation_data=(X_valid, y_valid))

In [None]:
# Final held-out performance on the CIFAR-10 test features.
loss, accuracy = model.evaluate(X_test, y_test)

print(f'Loss = {loss:.5f}')
print(f'Accuracy = {accuracy:.5f}')

In [None]:
import matplotlib.pyplot as plt


def plot_history(history, metric, name, ylabel):
    """Plot the training and validation curves of one metric.

    Parameters
    ----------
    history : dict
        A Keras ``History.history`` dict; must contain ``metric`` and
        ``'val_' + metric``.
    metric : str
        History key of the training metric, e.g. ``'loss'``.
    name : str
        Human-readable metric name used in the title and legend.
    ylabel : str
        Y-axis label.
    """
    train_values = history[metric]
    epochs = range(1, len(train_values) + 1)

    fig, ax = plt.subplots(figsize=(9, 6))
    ax.plot(epochs, train_values)
    ax.plot(epochs, history['val_' + metric])
    ax.set_title('Training & Validation ' + name)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend(['Training ' + name, 'Validation ' + name])
    ax.grid()
    plt.show()


# Older Keras releases (< 2.3) log the 'accuracy' metric under the key 'acc'.
acc_key = 'accuracy' if 'accuracy' in Hist.history else 'acc'

plot_history(Hist.history, 'loss', 'Loss', 'Loss')
plot_history(Hist.history, acc_key, 'accuracy', 'Accuracy')