In [4]:
import numpy as np
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Dense, Input
import matplotlib.pyplot as plt
import keras

# DATAPATH = "../data/"
DATAPATH = "F:/data/"
encoding_dim = 20 # 压缩后的维度

In [5]:
country_feature = np.load(DATAPATH + "features/country_feature.npy")
genre_feature = np.load(DATAPATH + "features/genre_feature.npy")
audio_feature = np.load(DATAPATH + "features/audio_feature.npy")
video_feature = np.load(DATAPATH + "features/video_feature.npy")

# 图像特征归一化
max = video_feature.max()
min = video_feature.min()
video_feature = (video_feature-min)/(max-min)

In [6]:
# 将nan填充为0
audio_feature[np.isnan(audio_feature)] = 0.

In [7]:
country_input = Input(shape=(country_feature.shape[1],))
genre_input = Input(shape=(genre_feature.shape[1],))
audio_input = Input(shape=(audio_feature.shape[1],))
video_input = Input(shape=(video_feature.shape[1],))

In [8]:
input_layer = keras.layers.concatenate([country_input, genre_input, audio_input, video_input])

# encoder layers
encoded = Dense(128, activation='relu')(input_layer)
encoded = Dense(64, activation='relu')(encoded)
encoded = Dense(10, activation='relu')(encoded)
encoder_output = Dense(encoding_dim)(encoded)

# decoder layers
decoded = Dense(10, activation='relu')(encoder_output)
decoded = Dense(64, activation='relu')(decoded)
decoded = Dense(128, activation='relu')(decoded)
decoded = Dense(784, activation='relu')(decoded)

In [9]:
# output layers
'''
softmax [0,1]，每一行的和为1
sigmoid [0,1]，每一行的元素大小相互独立
tanh [-1,1]，每一行的元素大小互相独立
'''
country_output = Dense(country_feature.shape[1], activation='softmax')(decoded)
genre_output = Dense(genre_feature.shape[1], activation='sigmoid')(decoded)
audio_output = Dense(audio_feature.shape[1], activation='tanh')(decoded)
video_output = Dense(video_feature.shape[1], activation='sigmoid')(decoded)

In [10]:
# construct the autoencoder model
autoencoder = Model(input=[country_input, genre_input, audio_input, video_input],
                    output=[country_output, genre_output, audio_output, video_output])

# construct the encoder model for plotting
encoder = Model(input=[country_input, genre_input, audio_input, video_input], output=encoder_output)

# compile autoencoder
autoencoder.compile(optimizer='adam', loss='mse')

# training
autoencoder.fit([country_feature, genre_feature, audio_feature, video_feature],
                [country_feature, genre_feature, audio_feature, video_feature],
                epochs=5, batch_size=256, shuffle=True, verbose=1)

  after removing the cwd from sys.path.
  import sys


Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.callbacks.History at 0x21d651a1eb8>

In [11]:
test = autoencoder.predict([country_feature, genre_feature, audio_feature, video_feature])
test

[array([[6.92525282e-06, 1.10003266e-05, 1.51139795e-06, ...,
         9.92160022e-01, 4.15383056e-06, 2.90427192e-06],
        [6.23453502e-07, 1.09210578e-06, 1.00272274e-07, ...,
         9.97558594e-01, 3.36583355e-07, 2.20533494e-07],
        [3.03126879e-09, 6.52572840e-09, 2.45930332e-10, ...,
         9.99797761e-01, 1.30727829e-09, 7.30020377e-10],
        ...,
        [1.51467987e-07, 2.80833262e-07, 2.03698001e-08, ...,
         9.98754144e-01, 7.66793136e-08, 4.84691753e-08],
        [1.54409852e-09, 3.41422535e-09, 1.14895642e-10, ...,
         9.99851584e-01, 6.48472054e-10, 3.54346580e-10],
        [2.72372302e-09, 5.88447069e-09, 2.17649121e-10, ...,
         9.99807060e-01, 1.17244892e-09, 6.50444920e-10]], dtype=float32),
 array([[3.6567450e-05, 3.9815903e-05, 2.2053719e-06, ..., 4.0787458e-04,
         4.6491623e-05, 1.7309189e-04],
        [4.6491623e-06, 5.1558018e-06, 1.7881393e-07, ..., 8.4042549e-05,
         6.1690807e-06, 2.9861927e-05],
        [2.9802322e-08

In [12]:
test_encoder = encoder.predict([country_feature, genre_feature, audio_feature, video_feature])
test_encoder


array([[ -4.55892  ,  -7.8308864,  -2.1602757, ...,   1.6574103,
         -1.1895219,  -4.938837 ],
       [ -5.464244 ,  -9.463632 ,  -2.5906794, ...,   1.9835052,
         -1.3631779,  -5.997165 ],
       [ -7.5566506, -13.035078 ,  -3.5814118, ...,   2.7458096,
         -1.9304525,  -8.251015 ],
       ...,
       [ -6.02774  , -10.411155 ,  -2.8454287, ...,   2.1982532,
         -1.5104538,  -6.6507497],
       [ -7.8010726, -13.4939785,  -3.7082815, ...,   2.8311756,
         -1.9714314,  -8.511859 ],
       [ -7.576176 , -13.113962 ,  -3.598633 , ...,   2.7373507,
         -1.9174386,  -8.264103 ]], dtype=float32)