[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1G5PJEFHEiYaJQ29pk0lfnTKk1yLxFBWT?usp=sharing)

# Causal Language Models

Causal language models (also called autoregressive or decoder models) are transformers that generate text sequentially from left to right: each position can attend only to itself and to earlier positions, never to later ones.
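
The attention restriction above is usually implemented with a lower-triangular mask. The sketch below builds one in plain Python; the helper name `causal_mask` is hypothetical and only illustrates the idea, it is not a Keras Hub function.

```python
def causal_mask(seq_len):
    """Return a seq_len x seq_len mask as nested lists of booleans.

    mask[i][j] is True when position i may attend to position j,
    which for a causal model means j <= i.
    """
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]


# Print the mask for a 4-token sequence: row i shows what token i can see.
for row in causal_mask(4):
    print(" ".join("1" if allowed else "0" for allowed in row))
```

Each row gains one more visible position than the row above it, which is what lets the model be trained on all positions of a sequence in parallel while still predicting only from the past.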

Here are some of the causal models available from Keras Hub (as of October 2025):

| Model | Size | Description |
|-------|------|-------------|
| bloom_560m_multi | 559.21M | 24-layer Bloom model with hidden dimension of 1024. Trained on 45 natural languages and 12 programming languages. |
| bloomz_560m_multi | 559.21M | 24-layer Bloom model with hidden dimension of 1024. Finetuned on crosslingual task mixture (xP3) dataset. |
| bloom_1.1b_multi | 1.07B | 24-layer Bloom model with hidden dimension of 1536. Trained on 45 natural languages and 12 programming languages. |
| bloomz_1.1b_multi | 1.07B | 24-layer Bloom model with hidden dimension of 1536. Finetuned on crosslingual task mixture (xP3) dataset. |
| bloom_1.7b_multi | 1.72B | 24-layer Bloom model with hidden dimension of 2048. Trained on 45 natural languages and 12 programming languages. |
| bloomz_1.7b_multi | 1.72B | 24-layer Bloom model with hidden dimension of 2048. Finetuned on crosslingual task mixture (xP3) dataset. |
| bloom_3b_multi | 3.00B | 30-layer Bloom model with hidden dimension of 2560. Trained on 45 natural languages and 12 programming languages. |
| bloomz_3b_multi | 3.00B | 30-layer Bloom model with hidden dimension of 2560. Finetuned on crosslingual task mixture (xP3) dataset. |
| falcon_refinedweb_1b_en | 1.31B | 24-layer Falcon model (Falcon with 1B parameters), trained on 350B tokens of RefinedWeb dataset. |
| vault_gemma_1b_en | 1.04B | 1 billion parameter, 26-layer, VaultGemma model. |
| gemma_2b_en | 2.51B | 2 billion parameter, 18-layer, base Gemma model. |
| gemma_instruct_2b_en | 2.51B | 2 billion parameter, 18-layer, instruction tuned Gemma model. |
| gemma_1.1_instruct_2b_en | 2.51B | 2 billion parameter, 18-layer, instruction tuned Gemma model. The 1.1 update improves model quality. |
| code_gemma_1.1_2b_en | 2.51B | 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. The 1.1 update improves model quality. |
| code_gemma_2b_en | 2.51B | 2 billion parameter, 18-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. |
| gemma2_2b_en | 2.61B | 2 billion parameter, 26-layer, base Gemma model. |
| gemma2_instruct_2b_en | 2.61B | 2 billion parameter, 26-layer, instruction tuned Gemma model. |
| shieldgemma_2b_en | 2.61B | 2 billion parameter, 26-layer, ShieldGemma model. |
| gemma_7b_en | 8.54B | 7 billion parameter, 28-layer, base Gemma model. |
| gemma_instruct_7b_en | 8.54B | 7 billion parameter, 28-layer, instruction tuned Gemma model. |
| gemma_1.1_instruct_7b_en | 8.54B | 7 billion parameter, 28-layer, instruction tuned Gemma model. The 1.1 update improves model quality. |
| code_gemma_7b_en | 8.54B | 7 billion parameter, 28-layer, CodeGemma model. This model has been trained on a fill-in-the-middle (FIM) task for code completion. |
| code_gemma_instruct_7b_en | 8.54B | 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. |
| code_gemma_1.1_instruct_7b_en | 8.54B | 7 billion parameter, 28-layer, instruction tuned CodeGemma model. This model has been trained for chat use cases related to code. The 1.1 update improves model quality. |
| gemma2_9b_en | 9.24B | 9 billion parameter, 42-layer, base Gemma model. |
| gemma2_instruct_9b_en | 9.24B | 9 billion parameter, 42-layer, instruction tuned Gemma model. |
| shieldgemma_9b_en | 9.24B | 9 billion parameter, 42-layer, ShieldGemma model. |
| gemma2_27b_en | 27.23B | 27 billion parameter, 42-layer, base Gemma model. |
| gemma2_instruct_27b_en | 27.23B | 27 billion parameter, 42-layer, instruction tuned Gemma model. |
| shieldgemma_27b_en | 27.23B | 27 billion parameter, 42-layer, ShieldGemma model. |
| gemma3_270m | 268.10M | 270-million parameter (170m embedding, 100m transformer params) model, 18-layer, text-only designed for hyper-efficient AI, particularly for task-specific fine-tuning. |
| gemma3_instruct_270m | 268.10M | 270-million parameter (170m embedding, 100m transformer params) model, 18-layer, text-only, instruction-tuned model designed for hyper-efficient AI, particularly for task-specific fine-tuning. |
| gemma3_1b | 999.89M | 1 billion parameter, 26-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_1b | 999.89M | 1 billion parameter, 26-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_4b_text | 3.88B | 4 billion parameter, 34-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_4b_text | 3.88B | 4 billion parameter, 34-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_4b | 4.30B | 4 billion parameter, 34-layer, vision+text pretrained Gemma3 model. |
| gemma3_instruct_4b | 4.30B | 4 billion parameter, 34-layer, vision+text instruction-tuned Gemma3 model. |
| gemma3_12b_text | 11.77B | 12 billion parameter, 48-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_12b_text | 11.77B | 12 billion parameter, 48-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_12b | 12.19B | 12 billion parameter, 48-layer, vision+text pretrained Gemma3 model. |
| gemma3_instruct_12b | 12.19B | 12 billion parameter, 48-layer, vision+text instruction-tuned Gemma3 model. |
| gemma3_27b_text | 27.01B | 27 billion parameter, 62-layer, text-only pretrained Gemma3 model. |
| gemma3_instruct_27b_text | 27.01B | 27 billion parameter, 62-layer, text-only instruction-tuned Gemma3 model. |
| gemma3_27b | 27.43B | 27 billion parameter, 62-layer, vision+text pretrained Gemma3 model. |
| gemma3_instruct_27b | 27.43B | 27 billion parameter, 62-layer, vision+text instruction-tuned Gemma3 model. |
| gpt2_base_en | 124.44M | 12-layer GPT-2 model where case is maintained. Trained on WebText. |
| gpt2_base_en_cnn_dailymail | 124.44M | 12-layer GPT-2 model where case is maintained. Finetuned on the CNN/DailyMail summarization dataset. |
| gpt2_medium_en | 354.82M | 24-layer GPT-2 model where case is maintained. Trained on WebText. |
| gpt2_large_en | 774.03M | 36-layer GPT-2 model where case is maintained. Trained on WebText. |
| gpt2_extra_large_en | 1.56B | 48-layer GPT-2 model where case is maintained. Trained on WebText. |
| llama2_7b_en | 6.74B | 7 billion parameter, 32-layer, base LLaMA 2 model. |
| llama2_instruct_7b_en | 6.74B | 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model. |
| vicuna_1.5_7b_en | 6.74B | 7 billion parameter, 32-layer, instruction tuned Vicuna v1.5 model. |
| llama2_7b_en_int8 | 6.74B | 7 billion parameter, 32-layer, base LLaMA 2 model with activation and weights quantized to int8. |
| llama2_instruct_7b_en_int8 | 6.74B | 7 billion parameter, 32-layer, instruction tuned LLaMA 2 model with activation and weights quantized to int8. |
| llama3.2_1b | 1.50B | 1 billion parameter, 16-layer, base LLaMA 3.2 model. |
| llama3.2_instruct_1b | 1.50B | 1 billion parameter, 16-layer, instruction tuned LLaMA 3.2. |
| llama3.2_guard_1b | 1.50B | 1 billion parameter, 16-layer, base LLaMA 3.2 model fine-tuned for content safety classification. |
| llama3.2_3b | 3.61B | 3 billion parameter, 28-layer, base LLaMA 3.2 model. |
| llama3.2_instruct_3b | 3.61B | 3 billion parameter, 28-layer, instruction tuned LLaMA 3.2. |
| llama3_8b_en | 8.03B | 8 billion parameter, 32-layer, base LLaMA 3 model. |
| llama3_instruct_8b_en | 8.03B | 8 billion parameter, 32-layer, instruction tuned LLaMA 3 model. |
| llama3.1_8b | 8.03B | 8 billion parameter, 32-layer, base LLaMA 3.1 model. |
| llama3.1_instruct_8b | 8.03B | 8 billion parameter, 32-layer, instruction tuned LLaMA 3.1. |
| llama3.1_guard_8b | 8.03B | 8 billion parameter, 32-layer, LLaMA 3.1 fine-tuned for content safety classification. |
| llama3_8b_en_int8 | 8.03B | 8 billion parameter, 32-layer, base LLaMA 3 model with activation and weights quantized to int8. |
| mistral_7b_en | 7.24B | Mistral 7B base model. |
| mistral_instruct_7b_en | 7.24B | Mistral 7B instruct model. |
| mistral_0.2_instruct_7b_en | 7.24B | Mistral 7B instruct version 0.2 model. |
| mistral_0.3_7b_en | 7.25B | Mistral 7B base version 0.3 model. |
| mistral_0.3_instruct_7b_en | 7.25B | Mistral 7B instruct version 0.3 model. |
| mixtral_8_7b_en | 46.70B | 32-layer Mixtral MoE model with 7 billion active parameters and 8 experts per MoE layer. |
| mixtral_8_instruct_7b_en | 46.70B | Instruction fine-tuned 32-layer Mixtral MoE model with 7 billion active parameters and 8 experts per MoE layer. |
| moonshine_tiny_en | 27.09M | Moonshine tiny model for English speech recognition. Developed by Useful Sensors for real-time transcription. |
| moonshine_base_en | 61.51M | Moonshine base model for English speech recognition. Developed by Useful Sensors for real-time transcription. |
| opt_125m_en | 125.24M | 12-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. |
| opt_1.3b_en | 1.32B | 24-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. |
| opt_2.7b_en | 2.70B | 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. |
| opt_6.7b_en | 6.70B | 32-layer OPT model where case is maintained. Trained on BookCorpus, CommonCrawl, Pile, and PushShift.io corpora. |
| pali_gemma_3b_mix_224 | 2.92B | image size 224, mix fine tuned, text sequence length is 256. |
| pali_gemma_3b_224 | 2.92B | image size 224, pre-trained, text sequence length is 128. |
| pali_gemma_3b_mix_448 | 2.92B | image size 448, mix fine tuned, text sequence length is 512. |
| pali_gemma_3b_448 | 2.92B | image size 448, pre-trained, text sequence length is 512. |
| pali_gemma_3b_896 | 2.93B | image size 896, pre-trained, text sequence length is 512. |
| pali_gemma2_mix_3b_224 | 3.03B | 3 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been fine-tuned on a wide range of vision-language tasks and domains. |
| pali_gemma2_pt_3b_224 | 3.03B | 3 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets. |
| pali_gemma_2_ft_docci_3b_448 | 3.03B | 3 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been fine-tuned on the DOCCI dataset for improved descriptions with fine-grained details. |
| pali_gemma2_mix_3b_448 | 3.03B | 3 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been fine-tuned on a wide range of vision-language tasks and domains. |
| pali_gemma2_pt_3b_448 | 3.03B | 3 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets. |
| pali_gemma2_pt_3b_896 | 3.04B | 3 billion parameter, image size 896, 27-layer for SigLIP-So400m vision encoder and 26-layer Gemma2 2B language model. This model has been pre-trained on a mixture of datasets. |
| pali_gemma2_mix_10b_224 | 9.66B | 10 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been fine-tuned on a wide range of vision-language tasks and domains. |
| pali_gemma2_pt_10b_224 | 9.66B | 10 billion parameter, image size 224, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets. |
| pali_gemma2_ft_docci_10b_448 | 9.66B | 10 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been fine-tuned on the DOCCI dataset for improved descriptions with fine-grained details. |
| pali_gemma2_mix_10b_448 | 9.66B | 10 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been fine-tuned on a wide range of vision-language tasks and domains. |
| pali_gemma2_pt_10b_448 | 9.66B | 10 billion parameter, image size 448, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets. |
| pali_gemma2_pt_10b_896 | 9.67B | 10 billion parameter, image size 896, 27-layer for SigLIP-So400m vision encoder and 42-layer Gemma2 9B language model. This model has been pre-trained on a mixture of datasets. |
| phi3_mini_4k_instruct_en | 3.82B | 3.8 billion parameters, 32 layers, 4k context length, Phi-3 model. The model was trained using the Phi-3 datasets which include both synthetic data and filtered publicly available website data, with an emphasis on high-quality and reasoning-dense properties. |
| phi3_mini_128k_instruct_en | 3.82B | 3.8 billion parameters, 32 layers, 128k context length, Phi-3 model. The model was trained using the Phi-3 datasets. |
| qwen2.5_0.5b_en | 494.03M | 24-layer Qwen model with 0.5 billion parameters. |
| qwen2.5_instruct_0.5b_en | 494.03M | Instruction fine-tuned 24-layer Qwen model with 0.5 billion parameters. |
| qwen2.5_3b_en | 3.09B | 36-layer Qwen model with 3.1 billion parameters. |
| qwen2.5_7b_en | 6.99B | 48-layer Qwen model with 7 billion parameters. |
| qwen2.5_instruct_32b_en | 32.76B | Instruction fine-tuned 64-layer Qwen model with 32 billion parameters. |
| qwen2.5_instruct_72b_en | 72.71B | Instruction fine-tuned 80-layer Qwen model with 72 billion parameters. |
| qwen3_0.6b_en | 596.05M | 28-layer Qwen3 model with 596 M parameters, optimized for efficiency and fast inference on resource-constrained devices. |
| qwen3_1.7b_en | 1.72B | 28-layer Qwen3 model with 1.72 billion parameters, offering a good balance between performance and resource usage. |
| qwen3_4b_en | 4.02B | 36-layer Qwen3 model with 4.02 billion parameters, offering improved reasoning capabilities and better performance than smaller variants. |
| qwen3_8b_en | 8.19B | 36-layer Qwen3 model with 8.19 billion parameters, featuring enhanced reasoning, coding, and instruction-following capabilities. |
| qwen3_14b_en | 14.77B | 40-layer Qwen3 model with 14.77 billion parameters, featuring advanced reasoning, coding, and multilingual capabilities. |
| qwen3_32b_en | 32.76B | 64-layer Qwen3 model with 32.76 billion parameters, featuring state-of-the-art performance across reasoning, coding, and general language tasks. |
| qwen1.5_moe_2.7b_en | 14.32B | 24-layer Qwen MoE model with 2.7 billion active parameters and 8 experts per MoE layer. |


In [1]:
import keras
import keras_hub
from keras_hub import samplers

import tensorflow as tf

In [2]:
pretrained_model_name = "gpt2_base_en"
causal_lm = keras_hub.models.CausalLM.from_preset(pretrained_model_name)
causal_lm.summary()


## Generate text

Even with an already-trained model, the decoding strategy we choose affects the quality of the generated text.

In [3]:
prompt = "Once upon a time in a distant galaxy, the explorers discovered"

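def generate_text(causal_lm, prompt, max_length=50, temperature=1.0, top_k=None, top_p=None):
    """Generate a continuation of `prompt` with the sampler implied by the arguments."""
    # Decide which sampler to use: top-p takes precedence over top-k,
    # and greedy decoding is the fallback when neither is given.
    if top_p is not None:
        sampler = samplers.TopPSampler(p=top_p, temperature=temperature)
    elif top_k is not None:
        sampler = samplers.TopKSampler(k=top_k, temperature=temperature)
    else:
        # Greedy decoding always picks the most likely token,
        # so the temperature does not change its output.
        sampler = samplers.GreedySampler(temperature=temperature)

    # Compile model with the sampler
    causal_lm.compile(sampler=sampler)

    # Generate text; strip_prompt=True returns only the newly generated part
    output = causal_lm.generate(prompt, max_length=max_length, strip_prompt=True)
    return output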


### Temperature experiment

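Temperature controls the randomness of predictions: the logits are divided by the temperature before the softmax, so values below 1 sharpen the distribution (safer, more repetitive text) and values above 1 flatten it (more varied, less coherent text).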
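
To make the scaling concrete, here is a minimal pure-Python sketch that applies a temperature to a made-up logit vector. The function name `softmax_with_temperature` and the example logits are assumptions for illustration; the real samplers operate on the model's logits over the full vocabulary.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max to avoid overflow in exp()
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for three candidate tokens.
logits = [2.0, 1.0, 0.0]
for temperature in (0.2, 1.0, 1.5):
    probs = softmax_with_temperature(logits, temperature)
    print(temperature, [round(p, 3) for p in probs])
```

At low temperature almost all of the probability mass sits on the top token, while at higher temperature the three candidates move closer together.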

In [10]:
print("=== Temperature sweep ===")
for temp in [0.2, 0.5, 1.0, 1.5]:
    text = generate_text(causal_lm, prompt, temperature=temp, top_k=50, top_p=None)
    print(f"Temp={temp} → {text}\n")

=== Temperature sweep ===
Temp=0.2 →  a planet called Teth. The planet was inhabited by a race of sentient beings called the Tethians. The Tethians were the descendants of the Tethians who had been

Temp=0.5 →  a strange planet. The planet was a planet in the midst of an interstellar war. The invaders had discovered a planet with a unique technology, and sought to destroy it. The explorers were

Temp=1.0 →  a group of alien races, of unknown origin. But how did they get here? The story of this enigmatic civilization turns to the face of the unknown, a tale of betrayal, deception

Temp=1.5 →  hundreds of small islands out-world. Today they're looking for what life on these distant islands could look like.


Inhabited by space colony types known as HOD,



### Top k experiment

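Top-k sampling considers only the k most likely tokens at each step and renormalizes their probabilities before sampling. With k=1 it reduces to greedy decoding; larger values of k admit less likely tokens and produce more varied text.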
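
The filtering step can be sketched in a few lines of plain Python. The helper `top_k_filter` and the toy distribution are hypothetical and only illustrate the idea; they are not how Keras Hub implements `TopKSampler` internally.

```python
def top_k_filter(probs, k):
    """Keep the k largest probabilities, zero out the rest, and renormalize."""
    if not 1 <= k <= len(probs):
        raise ValueError("k must be between 1 and len(probs)")
    # Indices of the k most likely tokens (ties resolved by original order).
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = set(ranked[:k])
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


# Toy next-token distribution over a 4-token vocabulary.
probs = [0.5, 0.25, 0.125, 0.125]
for k in (1, 2, 4):
    print(k, [round(p, 3) for p in top_k_filter(probs, k)])
```

A token would then be drawn from the filtered distribution, for example with `random.choices(range(len(probs)), weights=filtered)`.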

In [11]:
print("=== Top-k sweep ===")
for k in [1, 10, 50, 200]:
    text = generate_text(causal_lm, prompt, top_k=k, top_p=None)
    print(f"Top_k={k} → {text}\n")

=== Top-k sweep ===




Top_k=1 →  a planet called the "Planet of the Gods." The planet was a planet of great power and great wealth. The planet was a planet of great wealth and great power. The planet was





Top_k=10 →  the planet of the future. The young, beautiful world of Kain and its inhabitants are a paradise to be seen by the young explorers. The explorers have come to realize they have a

Top_k=50 →  hundreds of billions of years ago that Mars wasn't built in 20,000 years.

Today, even if our planet, our home planet, is 5,000 times more huge

Top_k=200 →  a civilization inhabited on a race of moons. Now the mighty civilizations upon the moon are all peaceful. To face off against what could be considered conquest from this alien race, the Orion System



### Top p experiment

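Top-p (nucleus) sampling keeps the smallest set of most likely tokens whose cumulative probability is at least p, renormalizes, and samples from that set. Unlike top-k, the number of candidates adapts to how peaked the distribution is at each step.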
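
As a minimal sketch of that rule, the hypothetical helper `top_p_filter` below walks the tokens from most to least likely and stops as soon as the cumulative probability reaches p. The toy distribution is made up, and the code is illustrative rather than the `TopPSampler` implementation.

```python
def top_p_filter(probs, p):
    """Keep the smallest high-probability set with cumulative mass >= p, then renormalize."""
    if not 0 < p <= 1:
        raise ValueError("p must be in (0, 1]")
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = set(), 0.0
    for i in ranked:
        kept.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break  # the nucleus is complete
    total = sum(probs[i] for i in kept)
    return [probs[i] / total if i in kept else 0.0 for i in range(len(probs))]


# Toy next-token distribution over a 4-token vocabulary.
probs = [0.5, 0.25, 0.125, 0.125]
for p in (0.5, 0.7, 1.0):
    print(p, [round(x, 3) for x in top_p_filter(probs, p)])
```

With this distribution, p=0.5 keeps a single token, p=0.7 keeps two, and p=1.0 keeps the whole vocabulary.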

In [12]:
print("=== Top-p sweep ===")
for p in [0.3, 0.6, 0.9, 0.99]:
    text = generate_text(causal_lm, prompt, temperature=1.0, top_k=None, top_p=p)
    print(f"Top_p={p} → {text}\n")

=== Top-p sweep ===
Top_p=0.3 →  a strange planet. The inhabitants of the planet were determined to discover the secrets of the planet. The planet was called Zilith, and it was a very beautiful planet. It was

Top_p=0.6 →  a world of extraordinary life, and its people were a force to be reckoned with. But a small group of brave and clever men led by a mysterious figure came into contact with the larger

Top_p=0.9 →  a fascinating and absolutely sublime location on Mars where the time has come to prevent humanity from stopping their civilization from having a nuclear meltdown that will ultimately destroy mankind. Even though this disaster, the

Top_p=0.99 →  a world that was traveling through time. And yet – dark for a race with no clear understanding of time – it went awry.

The halves of the largest of the stars



**Next steps**

To avoid toxic language, steer away from topics the model should not answer, or force certain words to appear, we can extend the `generate_text` function.
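
One very simple way to start, sketched below under strong assumptions, is to wrap generation in a retry loop that rejects outputs containing blocked words. The wrapper `generate_with_blocklist`, its parameters, and the stand-in generator are all hypothetical; a crude substring check like this is not a real safety filter and will also match inside longer words.

```python
import itertools


def generate_with_blocklist(generate_fn, prompt, blocked_words, max_attempts=5):
    """Call generate_fn(prompt) until the output contains no blocked word.

    Returns the first clean output, or None if every attempt was rejected.
    """
    blocked = [word.lower() for word in blocked_words]
    for _ in range(max_attempts):
        text = generate_fn(prompt)
        if not any(word in text.lower() for word in blocked):
            return text
    return None


# Stand-in generator that cycles through canned outputs, so the sketch
# runs without loading a model.
canned = itertools.cycle([" a planet at war.", " a quiet blue planet."])


def fake_generate(prompt):
    return next(canned)


print(generate_with_blocklist(fake_generate, "The explorers discovered", ["war"]))
```

With the notebook's model, `generate_fn` could be something like `lambda p: generate_text(causal_lm, p, top_p=0.9)`; retries only help when the sampler is stochastic, since greedy decoding would return the same text every time.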

## Finetune causal language model

Source: https://keras.io/examples/generative/gpt2_text_generation_with_keras_hub/

Learning rate schedule: https://keras.io/api/optimizers/learning_rate_schedules/polynomial_decay/

In [4]:
# Load the Chinese poetry dataset.
!git clone https://github.com/chinese-poetry/chinese-poetry.git

Cloning into 'chinese-poetry'...
remote: Enumerating objects: 7341, done.
remote: Counting objects: 100% (11/11), done.
remote: Compressing objects: 100% (8/8), done.
remote: Total 7341 (delta 7), reused 3 (delta 3), pack-reused 7330 (from 2)
Receiving objects: 100% (7341/7341), 236.99 MiB | 35.06 MiB/s, done.
Resolving deltas: 100% (5012/5012), done.
Updating files: 100% (2285/2285), done.


In [5]:
import os
import json

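poem_collection = []
poem_dir = "chinese-poetry/全唐诗"
for file in os.listdir(poem_dir):
    # Keep only the JSON files whose names mark them as poem collections.
    if ".json" not in file or "poet" not in file:
        continue
    full_filename = os.path.join(poem_dir, file)
    # Read as UTF-8 explicitly so the Chinese text decodes on any platform.
    with open(full_filename, "r", encoding="utf-8") as f:
        content = json.load(f)
        poem_collection.extend(content)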

paragraphs = ["".join(data["paragraphs"]) for data in poem_collection]

In [6]:
print(paragraphs[0])

樂全老子如星日，真一仙人似鳳鸞。早歲光明均照耀，異時文彩避高寒。低回氣類追千劫，邂逅風流得二難。環堵蕭然清徹骨，却疑深雪卧袁安。


In [7]:
train_ds = (
    tf.data.Dataset.from_tensor_slices(paragraphs)
    .batch(16)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Running through the whole dataset takes a long time, so for demo purposes
# we take only 500 batches (16 poems each) and train for 1 epoch.
train_ds = train_ds.take(500)
num_epochs = 1

learning_rate = keras.optimizers.schedules.PolynomialDecay(
    5e-4,
    decay_steps=train_ds.cardinality() * num_epochs,
    end_learning_rate=0.0,
)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
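
For intuition about the schedule configured above, the sketch below reimplements the polynomial decay formula in plain Python, assuming the Keras defaults `power=1.0` and `cycle=False` and the 500 decay steps used in this notebook (500 batches times 1 epoch). The function `polynomial_decay` is a hypothetical stand-in, not the Keras class.

```python
def polynomial_decay(step, initial_lr=5e-4, decay_steps=500, end_lr=0.0, power=1.0):
    """Learning rate at a given optimizer step under (non-cycling) polynomial decay."""
    step = min(step, decay_steps)  # the rate stays at end_lr after decay_steps
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr


# With power=1.0 the rate falls linearly from 5e-4 to 0 over 500 steps.
for step in (0, 125, 250, 375, 500):
    print(step, polynomial_decay(step))
```

Because the decay reaches zero exactly at the last training step, the final updates are very small, which is a common choice for short fine-tuning runs.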

In [8]:
%%time

causal_lm.compile(
    optimizer=keras.optimizers.Adam(learning_rate),
    loss=loss,
    weighted_metrics=["accuracy"],
)

causal_lm.fit(train_ds, epochs=num_epochs)

500/500 ━━━━━━━━━━━━━━━━━━━━ 265s 356ms/step - accuracy: 0.2486 - loss: 0.4986
CPU times: user 3min 38s, sys: 6.81 s, total: 3min 45s
Wall time: 4min 27s


<keras.src.callbacks.history.History at 0x7fe5270d5970>
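
The loss minimized during this fine-tuning is next-token cross-entropy. As a rough sketch, assuming the usual setup in which the inputs are the token sequence without its last token and the targets are the same sequence shifted one position to the left, the two pieces look like this in plain Python. The helper names and the toy token ids are hypothetical; in the notebook the model's preprocessor and `SparseCategoricalCrossentropy(from_logits=True)` do this work.

```python
import math


def next_token_pairs(token_ids):
    """Split one token sequence into (inputs, targets) for next-token prediction."""
    return token_ids[:-1], token_ids[1:]


def sparse_cross_entropy_from_logits(logits, target):
    """Cross-entropy for one position: -log softmax(logits)[target]."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return log_sum_exp - logits[target]


# Toy token ids: each input position is trained to predict the token after it.
inputs, targets = next_token_pairs([5, 7, 9, 2])
print(inputs, targets)

# Toy logits over a 4-token vocabulary for one position whose target is token 2.
print(sparse_cross_entropy_from_logits([0.1, 0.2, 3.0, 0.4], target=2))
```

The loss is small when the logit of the correct next token dominates the others, and it equals log(vocabulary size) when all logits are equal.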

In [9]:
output = generate_text(causal_lm, "昨夜雨疏风骤", top_k=None, top_p=0.3)
print(output)

，暮推林林書爲。江南曾翁
