In [1]:
import os
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  _torch_pytree._register_pytree_node(


In [3]:
def load_model(model_dir):
    """Load model and tokenizer from the given directory."""
    print(f"Loading model from {model_dir}...")
    
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    
    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        device_map="auto",  # Use available GPU or CPU
        torch_dtype=torch.float16  # Use half precision to reduce memory usage
    )
    
    print("Model loaded successfully!")
    return model, tokenizer

In [4]:
def generate_text(model, tokenizer, prompt, max_length=100):
    """Generate text based on the prompt."""
    print(f"Generating text for prompt: '{prompt}'")
    
    # Encode the prompt
    inputs = tokenizer(prompt, return_tensors="pt")
    
    # Move inputs to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate text
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Decode the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [5]:
def main():
    # Get the model directory
    model_dir = os.path.dirname(os.path.abspath(__file__))
    
    # Load model and tokenizer
    model, tokenizer = load_model(model_dir)
    
    print("\n===== MG MOTOR AI TEXT GENERATOR =====")
    print("Type 'exit' to quit the program")
    
    # Interactive loop
    while True:
        user_prompt = input("\nEnter your prompt: ")
        
        if user_prompt.lower() == 'exit':
            print("Exiting program...")
            break
        
        # Generate and display text
        generated_text = generate_text(model, tokenizer, user_prompt)
        print("\n--- Generated Text ---")
        print(generated_text)
        print("---------------------")

In [6]:
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        print(f"An error occurred: {e}")

An error occurred: name '__file__' is not defined
