In [0]:
# Install the `transformers` and `torch` packages into the kernel's environment
# (only needed if they are not installed on the cluster yet).
# NOTE(review): no versions are pinned here, so the installed releases may
# differ between runs — consider pinning for reproducibility.
%pip install transformers torch

In [0]:
from transformers import pipeline

# Build a text-generation pipeline backed by the pre-trained GPT-2 model
generator = pipeline(
    task="text-generation",
    model="gpt2",
)

In [0]:
def generate_text(prompt, max_new_tokens=250):
    """Generate text for ``prompt`` using the module-level ``generator``.

    Args:
        prompt: Input passed to ``generator`` as its first argument.
        max_new_tokens: Value forwarded to ``generator`` as
            ``max_new_tokens`` (defaults to 250).

    Returns:
        The ``"generated_text"`` entry of the first item returned by
        ``generator``.
    """
    generation_options = {
        "max_new_tokens": max_new_tokens,  # cap on the generated length
        "num_return_sequences": 1,         # ask for a single output
        "truncation": True,                # turn truncation on explicitly
    }
    candidates = generator(prompt, **generation_options)
    first_candidate = candidates[0]
    return first_candidate["generated_text"]

# Smoke test: run one generation from a short seed prompt and show the result
prompt = "In the world of AI,"
output = generate_text(prompt=prompt)
print(f"Generated Text: {output}")

In [0]:
# Install LangChain into the kernel's environment; the next cell imports
# PromptTemplate and RunnableMap from it.
# NOTE(review): the version is unpinned — consider pinning for reproducibility.
%pip install langchain

In [0]:
from langchain.prompts import PromptTemplate
from langchain.schema.runnable import RunnableMap
from transformers import pipeline

# Create the GPT-2 text-generation pipeline via Hugging Face `transformers`.
# NOTE(review): this rebinds `generator`, which an earlier cell already created
# with the same task and model — the reload looks redundant; confirm whether
# this cell is meant to run on its own.
generator = pipeline(
    task='text-generation',
    model='gpt2',
)

# Define a custom wrapper for Hugging Face model compatibility with LangChain
class DatabricksLLM:
    """Callable wrapper around a Hugging Face text-generation pipeline.

    An instance is a plain callable that takes a prompt and returns the
    generated text as a string. In this notebook it is used as the right-hand
    side of ``prompt | databricks_llm``.

    Attributes are documented on ``__init__``.
    """

    def __init__(self, generator, max_new_tokens=200):
        """Store the pipeline and the generation length limit.

        Args:
            generator: Callable invoked as
                ``generator(prompt, max_new_tokens=..., num_return_sequences=1)``
                that returns a sequence whose first item has a
                ``"generated_text"`` key.
            max_new_tokens: Value forwarded to ``generator`` on every call.
                Defaults to 200, the limit that was previously hard-coded.
        """
        self.generator = generator
        self.max_new_tokens = max_new_tokens

    def __call__(self, prompt) -> str:
        """Generate text for ``prompt`` and return it.

        Args:
            prompt: Either a string or any object exposing ``to_string()``
                (e.g. the prompt value a prompt template produces); the
                latter is converted to a string first.

        Returns:
            The ``"generated_text"`` entry of the first result.
        """
        # Prompt templates hand over a prompt-value object rather than a str.
        if hasattr(prompt, 'to_string'):
            prompt = prompt.to_string()
        result = self.generator(
            prompt,
            max_new_tokens=self.max_new_tokens,
            num_return_sequences=1
        )
        return result[0]["generated_text"]

# Wrap the GPT-2 pipeline in the custom callable defined above
databricks_llm = DatabricksLLM(generator=generator)

# Prompt template with a single `topic` placeholder
prompt = PromptTemplate(
    template="Explain {topic} in detail.",
    input_variables=["topic"],
)

# Compose the two steps with `|`: the prompt's output feeds the wrapper
chain = prompt | databricks_llm

# Example usage: fill in the topic, run the chain once, and show the result
response = chain.invoke({"topic": "gravity"})
print("Generated Response:\n", response)