In [None]:
from collections import Counter
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.models import Sequential
import csv, cv2, matplotlib.pyplot as plt, numpy as np, tensorflow as tf

In [None]:
filename = ".\\data.csv"

post_ids = []
labels = []

with open(filename, 'r') as fp:
	reader = csv.reader(fp)

	headings = next(reader)

	for row in reader:
		match row[2]:
			case "Safe":
				if np.random.random() < 0.08:
					post_ids.append(row[0])
					labels.append(row[2])
			case "Explicit":
				post_ids.append(row[0])
				labels.append(row[2])

In [None]:
# Integer class index assigned to each label name.
d = {"Safe": 0, "Questionable": 2, "Explicit": 1}


def _read_image(post_id):
	"""Read ``images/<post_id>.jpg`` with OpenCV.

	Raises:
		OSError: if OpenCV cannot read the file. ``cv2.imread`` signals that by
			returning ``None`` instead of raising, which would otherwise show
			up later as an opaque NumPy conversion error.
	"""
	# Forward slash so the path resolves on Windows and POSIX alike.
	path = f"images/{post_id}.jpg"
	img = cv2.imread(path)
	if img is None:
		raise OSError(f"missing or unreadable image: {path}")
	return img


# Stack every image into one array and scale 8-bit pixel values to [0, 1].
# float32 is ample for 8-bit data and halves the memory of a float64 buffer.
# Building a single numeric array requires all images to share one shape.
all_imgs = np.array([_read_image(post_id) for post_id in post_ids], dtype=np.float32) / 255

# `labels` is rebound below from label names to class indices. Remember the
# names the first time through so that re-running this cell re-encodes the
# names instead of failing with a KeyError on already-encoded integers.
if not isinstance(labels, np.ndarray):
	label_names = list(labels)
labels = np.array([d[name] for name in label_names])

# Show how many samples ended up in each class.
print(Counter(labels))

In [None]:
# Quick sanity check, one item per line: the shape of the image array, the
# shape of the label array, and a small slice (first two entries along each
# of the first three axes) of the scaled pixel data.
print(all_imgs.shape, labels.shape, all_imgs[:2, :2, :2], sep="\n")

In [None]:
# Hold out 20% of the samples for evaluation. A fixed random_state makes the
# shuffle, and therefore every metric reported further down, reproducible
# across kernel restarts.
training_imgs, test_imgs, training_labels, test_labels = train_test_split(
	all_imgs, labels, train_size=0.8, random_state=42)

In [None]:
# Small CNN over 128x128 three-channel inputs, declared as one layer list.
model = Sequential([
	tf.keras.Input(shape=(128, 128, 3)),
	# Pooling is applied directly to the input, before the first convolution.
	MaxPooling2D(pool_size=(2, 2)),
	Conv2D(filters=64, kernel_size=(3, 3), activation="relu"),
	MaxPooling2D(pool_size=(2, 2)),
	Conv2D(filters=128, kernel_size=(3, 3), activation="relu"),
	MaxPooling2D(pool_size=(2, 2)),
	# Classifier head: flatten, one hidden layer with dropout, 3-way softmax.
	Flatten(),
	Dense(units=512, activation="relu"),
	Dropout(rate=0.5),
	Dense(units=3, activation="softmax"),
])

# The sparse variant of categorical cross-entropy takes integer class labels
# rather than one-hot vectors.
model.compile(
	loss="sparse_categorical_crossentropy",
	optimizer="adam",
	metrics=["accuracy"],
)

model.summary()

In [None]:
# Training hyper-parameters.
epochs, batch_size = 5, 32

# Fit on the training split and score the held-out split after every epoch.
# Only the per-epoch metrics dict (`.history`) is kept.
history = model.fit(
	x=training_imgs,
	y=training_labels,
	batch_size=batch_size,
	epochs=epochs,
	validation_data=(test_imgs, test_labels),
).history

In [None]:
# Accuracy per epoch for the training split and the held-out split.
epoch_axis = range(1, epochs + 1)
for metric_key, series_label in (("accuracy", "training"), ("val_accuracy", "test")):
	plt.plot(epoch_axis, history[metric_key], label=series_label)

# The title reports how many samples were loaded in total.
plt.title(f"{labels.size} Images")
plt.xlabel("epoch")
plt.ylabel("accuracy")

# Pin both axes: epochs run from 1 to `epochs`, accuracy from 0 to 1.
plt.xlim((1, epochs))
plt.ylim(0, 1)

plt.legend()
plt.show()