In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))


os.environ["KERAS_BACKEND"] = "tensorflow"

/kaggle/input/taylorswiftlyrics/taylor_swift_lyrics.csv
/kaggle/input/gpt2/keras/gpt2_base_en/2/config.json
/kaggle/input/gpt2/keras/gpt2_base_en/2/tokenizer.json
/kaggle/input/gpt2/keras/gpt2_base_en/2/metadata.json
/kaggle/input/gpt2/keras/gpt2_base_en/2/model.weights.h5
/kaggle/input/gpt2/keras/gpt2_base_en/2/assets/tokenizer/merges.txt
/kaggle/input/gpt2/keras/gpt2_base_en/2/assets/tokenizer/vocabulary.json


In [2]:
from tensorflow.data import Dataset, AUTOTUNE

from keras_hub.models import GPT2CausalLMPreprocessor, GPT2CausalLM
from keras_hub.samplers import TopPSampler
from keras.optimizers import Adam
from keras.optimizers.schedules import PolynomialDecay
from keras.losses import SparseCategoricalCrossentropy

In [3]:
Preprocessor = GPT2CausalLMPreprocessor.from_preset(
    "gpt2_base_en",
    sequence_length = 384, # 512 is too long and can cause repetition
)

In [4]:
Sampler = TopPSampler(
    p = 0.70,
    k = None,
    seed = None,
    temperature = 1.09,
)

In [5]:
GPT2Model = GPT2CausalLM.from_preset(
    "gpt2_base_en",
    preprocessor=Preprocessor
)

In [6]:
lyrics = pd.read_csv('/kaggle/input/taylorswiftlyrics/taylor_swift_lyrics.csv').dropna()

In [7]:
print(f"There are {len(lyrics)} examples.")

There are 198 examples.


In [8]:
dataset = Dataset.from_tensor_slices(lyrics["Lyrics"].values)
dataset = dataset.batch(9).cache().prefetch(AUTOTUNE)

In [9]:
EPOCHS = 22

In [10]:
scheduler = PolynomialDecay(
    5e-5,
    decay_steps=dataset.cardinality() * EPOCHS,
    end_learning_rate=0.0,
)

In [11]:
loss = SparseCategoricalCrossentropy(from_logits=True)

In [12]:
GPT2Model.compile(
    optimizer=Adam(scheduler),
    loss=loss,
    weighted_metrics=["accuracy"],
    sampler=Sampler,
)

GPT2Model.summary()

In [13]:
GPT2Model.fit(dataset, epochs=EPOCHS + 11)

Epoch 1/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m101s[0m 883ms/step - accuracy: 0.4927 - loss: 2.6895
Epoch 2/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 949ms/step - accuracy: 0.5362 - loss: 2.3573
Epoch 3/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 1s/step - accuracy: 0.5512 - loss: 2.2400
Epoch 4/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 1s/step - accuracy: 0.5621 - loss: 2.1556
Epoch 5/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 1s/step - accuracy: 0.5700 - loss: 2.0865
Epoch 6/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 1s/step - accuracy: 0.5798 - loss: 2.0266
Epoch 7/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 1s/step - accuracy: 0.5858 - loss: 1.9768
Epoch 8/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 1s/step - accuracy: 0.5928 - loss: 1.9214
Epoch 9/33
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x7a8f8f0ec610>

In [14]:
test = GPT2Model.generate(
    "And I said, Romeo take me somewhere we can",
    max_length = 324,
)

print(test)

And I said, Romeo take me somewhere we can go again
And it was a quiet night in New York
And I've never had to look at you this way
But I knew it would end up like this
'Cause I'm so scared of dying
And you are the only thing that keeps me up
And I said, Romeo take me somewhere we can go again
And it was a quiet night in New York
And I've never had to look at you this way
But I know I'm gonna keep you
