Load the data

In [4]:
import pickle

import numpy as np
import pandas as pd

import config as c

objects_dir = c.PATHS.OBJECTS_FOLDER

# Load the prepared feature frame and its embeddings (one embedding per row of df).
# NOTE: unpickling can execute arbitrary code — only load files this project produced.
df = pd.read_pickle(objects_dir / "transformed_data_popular.pickle")
with open(objects_dir / "embeddings_popular.pkl", mode="rb") as emb_file:
    embeddings = pickle.load(emb_file)

In [24]:
# Drop the per-city columns, i.e. every column whose name starts with "city_".
# A prefix match (not a substring test) keeps unrelated columns that merely
# contain "city_" somewhere in their name, e.g. "capacity_*" or "ethnicity_*".
df = df.drop(columns=[col for col in df.columns if str(col).startswith("city_")])

In [None]:
# The next cell attaches the embeddings to `df` positionally, so there must be
# exactly one embedding per row. (Row *order* cannot be verified here — it is
# assumed both files were written in the same order.)
assert len(df) == len(embeddings), (
    f"row mismatch: df has {len(df)} rows, embeddings has {len(embeddings)}"
)
df.shape, embeddings.shape

((223152, 12), (223152, 768), (203012, 12), (203012, 768))

In [26]:
# Add the embeddings as a column holding one vector per row.
# Assumes `embeddings` is a 2-D numpy array (rows x dims), so tolist() yields
# one Python list per row.
df = df.assign(embeddings=embeddings.tolist())

Create train and test sets

In [27]:
from sklearn.model_selection import train_test_split

# Split the frame prepared above (city columns dropped, embeddings attached)
# into features and the binary target. Rebuilding `df` here from other frames
# would discard that preparation; if popular and unpopular data are meant to be
# combined, load and prepare both in the loading cells and concatenate there.
X = df.drop(columns=["is_popular"])
y = df["is_popular"]

# A binary classifier needs both classes present; fail early if only one was loaded.
assert y.nunique() == 2, f"is_popular must contain exactly 2 classes, found {y.nunique()}"

# Stratify so train and test keep the same popular/unpopular ratio.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

Build a Keras neural network to predict binary popularity

In [28]:
# Move the embedding column into its own (n_samples, embedding_dim) float32
# matrix for the model's second input. A plain list of per-row lists keeps one
# Python float object per value (very memory-heavy) and is not a clean 2-D
# input for Keras. The remaining columns stay in X_train / X_test as the
# tabular features.
X_train_embeddings = np.asarray(X_train["embeddings"].tolist(), dtype="float32")
X_test_embeddings = np.asarray(X_test["embeddings"].tolist(), dtype="float32")

X_train = X_train.drop(columns=["embeddings"])
X_test = X_test.drop(columns=["embeddings"])

In [29]:
from keras.layers import Input, Dense, Flatten, concatenate
from keras.models import Model

# Convert every model input to a plain float32 numpy array up front.
# - The embeddings may arrive as a list of per-row lists; Keras may interpret
#   such nested lists as a nested input structure rather than one matrix.
# - A feature frame with non-numeric columns fails here with a clear pandas
#   error instead of deep inside Keras.
# - The target may be bool/int; binary cross-entropy works on floats.
X_train_tab = X_train.to_numpy(dtype="float32")
X_test_tab = X_test.to_numpy(dtype="float32")
X_train_emb = np.asarray(X_train_embeddings, dtype="float32")
X_test_emb = np.asarray(X_test_embeddings, dtype="float32")
y_train_arr = np.asarray(y_train, dtype="float32")
y_test_arr = np.asarray(y_test, dtype="float32")

# Two inputs: the remaining tabular feature columns and the embedding vectors.
input_text = Input(shape=(X_train_tab.shape[1],))
input_embed = Input(shape=(X_train_emb.shape[1],))

# One hidden ReLU layer per input branch.
dense_text = Dense(64, activation='relu')(input_text)
dense_embed = Dense(64, activation='relu')(input_embed)

# Merge both branches side by side.
concat = concatenate([dense_text, dense_embed])

# Single sigmoid unit: probability that the sample is popular.
output = Dense(1, activation='sigmoid')(concat)

model = Model(inputs=[input_text, input_embed], outputs=output)
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# NOTE(review): the test split doubles as validation data here, so per-epoch
# validation scores are test scores — don't tune epochs/hyperparameters on them.
model.fit(
    [X_train_tab, X_train_emb],
    y_train_arr,
    epochs=10,
    batch_size=32,
    validation_data=([X_test_tab, X_test_emb], y_test_arr),
)
