In [None]:
pip install -q keras-nlp

In [None]:
import keras_nlp
import tensorflow as tf
from tensorflow import keras
import time

In [None]:
# Training and generation are faster with a short context, so the preprocessor
# is built with a sequence length of 128 rather than the full 1024.
GPT2_PRESET = "gpt2_base_en"
SEQUENCE_LENGTH = 128

preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    GPT2_PRESET, sequence_length=SEQUENCE_LENGTH
)
# The causal LM is loaded from the same preset and wired to the shortened
# preprocessor created above.
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    GPT2_PRESET, preprocessor=preprocessor
)

In [None]:
pip install datasets

In [None]:
from datasets import load_dataset

# Identifier of the poetry dataset handed to `load_dataset`.
POETRY_DATASET_ID = "merve/poetry"
dataset = load_dataset(POETRY_DATASET_ID)

In [None]:
import os

# Training corpus: one string per document. The list is rebuilt from scratch
# each time this cell runs, so re-running the cell never duplicates entries.
reddit_ds = []

# 1) Local text files: every regular file inside ./dataset is one document.
local_dir = os.path.join(os.getcwd(), "dataset")
for filename in sorted(os.listdir(local_dir)):  # sorted -> reproducible order
    path = os.path.join(local_dir, filename)
    if not os.path.isfile(path):
        # Sub-directories cannot be opened as text; skip them instead of crashing.
        continue
    # Explicit encoding so the result does not depend on the platform default.
    with open(path, "r", encoding="utf-8") as f:
        reddit_ds.append(f.read())

# 2) Love poems from the poetry dataset.
# The poem category is stored in the `type` column; `age` holds the era of the
# poem, so comparing `age` against 'Love' selects no rows.
# NOTE(review): column meanings taken from the merve/poetry dataset card — confirm.
dataset = dataset.filter(lambda x: x["type"] == "Love")

# `load_dataset` without a `split` argument yields a dict of splits. Read the
# `content` column of each split directly rather than calling `map` only for
# its side effects.
for split in dataset.values():
    reddit_ds.extend(split["content"])

# Show the corpus size and a small sample instead of dumping every document.
print(f"{len(reddit_ds)} documents collected")
reddit_ds[:3]

In [None]:
# Input pipeline: slice the corpus into individual strings, group them into
# batches, cache the batches, and prefetch them.
BATCH_SIZE = 32

text_ds = tf.data.Dataset.from_tensor_slices(reddit_ds)
batched_ds = text_ds.batch(BATCH_SIZE)
train_ds = batched_ds.cache().prefetch(tf.data.AUTOTUNE)

In [None]:
num_epochs = 150

# Number of steps the schedule spans: batches per epoch times epoch count.
total_steps = train_ds.cardinality() * num_epochs

# Linearly decaying learning rate: starts at 5e-5 and reaches 0.0 after
# `total_steps` steps.
learning_rate = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5,
    decay_steps=total_steps,
    end_learning_rate=0.0,
)

# The loss is computed on raw logits (from_logits=True).
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam(learning_rate)
gpt2_lm.compile(optimizer=optimizer, loss=loss, weighted_metrics=["accuracy"])

gpt2_lm.fit(train_ds, epochs=num_epochs)

In [None]:
# Prompt that seeds the generation.
prompt = "Take me back to the night we met"

# The timed region covers the generation call and the printing of its output.
start = time.time()
output = gpt2_lm.generate(prompt, max_length=200)
print("\nGPT-2 output:")
print(output)
end = time.time()

elapsed = end - start
print(f"TOTAL TIME ELAPSED: {elapsed:.2f}s")

In [None]:
import pickle

# Save the fine-tuned model to disk.
filename = 'english_love_poems.sav'
# The context manager closes (and therefore flushes) the file even if pickling
# raises, instead of leaving an open handle behind.
with open(filename, 'wb') as model_file:
    pickle.dump(gpt2_lm, model_file)
# NOTE(review): a pickled model can only be read back with pickle and compatible
# library versions; consider the model's native save API instead — TODO confirm.

In [None]:
# SECURITY: unpickling can execute arbitrary code, so only load a file that
# this notebook (or another trusted source) produced.
# The context manager makes sure the file handle is closed after reading.
with open(filename, 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# Sanity check: the restored model should be able to generate text.
output = loaded_model.generate("I love you", max_length=200)
print("\nGPT-2 output:")
print(output)