In [1]:
pip install transformers torch

Note: you may need to restart the kernel to use updated packages.


In [None]:
import torch
from transformers import pipeline, set_seed

class TextGenerator:
    """
    A simple text generation tool using Hugging Face's transformers library.
    """
    def __init__(self, model_name="gpt2"):
        """
        Initializes the text generator with a specified model.

        Args:
            model_name (str): The name of the pre-trained model to use from Hugging Face.
                              Common choices: "gpt2", "distilgpt2", "EleutherAI/gpt-neo-125M",
                                              "EleutherAI/gpt-neo-1.3B" (larger).
        """
        print(f"Loading text generation pipeline with model: {model_name}...")
        try:
            self.generator = pipeline('text-generation', model=model_name)
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading model {model_name}: {e}")
            print("Please check if the model name is correct and you have an internet connection.")
            print("Falling back to 'gpt2' model (if it failed).")
            self.generator = pipeline('text-generation', model="gpt2")

        set_seed(42) # For reproducibility of results

    def generate_text(self, prompt, max_length=100, num_return_sequences=1,
                      do_sample=True, temperature=0.7, top_k=50, top_p=0.95):
        """
        Generates text based on a given prompt.

        Args:
            prompt (str): The initial text prompt to base the generation on.
            max_length (int): The maximum total length of the generated text
                              (including the prompt).
            num_return_sequences (int): The number of independent sequences to generate.
            do_sample (bool): If True, uses sampling for generation; otherwise, uses greedy decoding.
                              Sampling is generally preferred for creative text.
            temperature (float): Controls the randomness of predictions. Lower values
                                 make the text more deterministic, higher values more random.
                                 (Only applies if do_sample is True).
            top_k (int): Filters out the k most likely next words.
                         (Only applies if do_sample is True).
            top_p (float): Filters out the smallest set of most likely next words
                           whose cumulative probability exceeds p.
                           (Only applies if do_sample is True).

        Returns:
            list: A list of dictionaries, where each dictionary contains the generated text
                  under the key 'generated_text'.
        """
        if not prompt:
            print("Warning: Empty prompt provided. Generating without a specific context.")

        print(f"\nGenerating text for prompt: '{prompt}'...")
        try:
            generated_sequences = self.generator(
                prompt,
                max_length=max_length,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p
            )
            print("Generation complete!")
            return generated_sequences
        except Exception as e:
            print(f"Error during text generation: {e}")
            return []

# --- How to Use the Text Generator ---
if __name__ == "__main__":
    # Initialize the generator
    # You can change "gpt2" to other models like "distilgpt2" for faster generation
    # or "EleutherAI/gpt-neo-1.3B" for potentially better quality (requires more RAM).
    generator = TextGenerator(model_name="gpt2")

    # --- Interactive Mode (Loop) ---
    print("\n--- Interactive Text Generation Mode ---")
    print("Type your prompt and press Enter. Type 'quit' to exit.")

    while True:
        user_prompt = input("\nEnter your prompt: ")
        if user_prompt.lower() == 'quit':
            break

        # Get generation parameters from user or use defaults
        try:
            max_len_input = input(f"Max length (default: {100}): ")
            max_length = int(max_len_input) if max_len_input else 100

            num_seq_input = input(f"Number of sequences (default: {1}): ")
            num_return_sequences = int(num_seq_input) if num_seq_input else 1

            do_sample_input = input(f"Use sampling (y/n, default: y): ").lower()
            do_sample = True if do_sample_input in ['y', 'yes', ''] else False

            temperature = 0.7
            top_k = 50
            top_p = 0.95

            if do_sample:
                temp_input = input(f"Temperature (0.1-2.0, default: {0.7}): ")
                temperature = float(temp_input) if temp_input else 0.7
                top_k_input = input(f"Top-K (default: {50}): ")
                top_k = int(top_k_input) if top_k_input else 50
                top_p_input = input(f"Top-P (default: {0.95}): ")
                top_p = float(top_p_input) if top_p_input else 0.95

        except ValueError:
            print("Invalid input for generation parameters. Using defaults.")
            max_length = 100
            num_return_sequences = 1
            do_sample = True
            temperature = 0.7
            top_k = 50
            top_p = 0.95

        results = generator.generate_text(
            user_prompt,
            max_length=max_length,
            num_return_sequences=num_return_sequences,
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p
        )

        for i, result in enumerate(results):
            print(f"\n--- Generated Text {i+1} ---")
            print(result['generated_text'])
            print("-" * (len("--- Generated Text ---") + 4))

    print("\nExiting text generator. Goodbye!")

Loading text generation pipeline with model: gpt2...



Device set to use cpu


Model loaded successfully!

--- Interactive Text Generation Mode ---
Type your prompt and press Enter. Type 'quit' to exit.


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Both `max_new_tokens` (=256) and `max_length`(=100) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)



Generating text for prompt: 'artificial intellligence in healthcare'...
Generation complete!

--- Generated Text 1 ---
artificial intellligence in healthcare: What does this mean for you?

The study is published in the journal PLOS ONE.

The authors say this is the first time they've looked at the risk of cancer in patients who are taking a pill. The study looked at patients with known or suspected cancer. The study looked at a range of health care providers from the US to the UK. The study looked at patients with known or suspected cancer.

There are other indications that a large proportion of people who are taking a pill are taking a drug that is not properly regulated by the FDA.

This could make it hard for patients to get the treatment they need.

This is why the FDA says it has to develop a system that protects against side effects, which could mean taking a drug that's not regulated by the FDA.

The FDA says that if a drug is not regulated by the FDA, it can be sold for as lit