In [1]:
from flask import Flask, request, jsonify
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, TFAutoModelForCausalLM, AutoTokenizer
import tensorflow as tf

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
app = Flask(__name__)

# Load the pre-trained GPT-2 and DialoGPT models
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = TFGPT2LMHeadModel.from_pretrained("gpt2")

dialoGPT_tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
dialoGPT_model = TFAutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")

Downloading (…)"tf_model.h5";: 100%|██████████| 498M/498M [03:00<00:00, 2.76MB/s] 
All model checkpoint layers were used when initializing TFGPT2LMHeadModel.

All the layers of TFGPT2LMHeadModel were initialized from the model checkpoint at gpt2.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFGPT2LMHeadModel for predictions without further training.
Downloading (…)neration_config.json: 100%|██████████| 124/124 [00:00<00:00, 10.4kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 26.0/26.0 [00:00<00:00, 2.89kB/s]
Downloading (…)lve/main/config.json: 100%|██████████| 642/642 [00:00<00:00, 45.9kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 1.04M/1.04M [00:01<00:00, 752kB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 456k/456k [00:09<00:00, 46.4kB/s]
Downloading (…)"tf_model.h5";: 100%|██████████| 1.42G/1.42G [08:55<00:00, 2.65MB/s]
All model checkpoint layers were used when initializing TFGPT2LMHeadMode

In [3]:
# Define a class to manage the conversation history
class Conversation:
    def __init__(self):
        self.history = ""

    def add_to_history(self, input_text):
        self.history += input_text + "\n"

    def get_history(self):
        return self.history

    def clear_history(self):
        self.history = ""


In [4]:
# Define a function to generate text from the ensemble model
def generate_text(conversation, input_text):
    # Add the user's input to the conversation history
    conversation.add_to_history(input_text)

    # Generate text using the GPT-2 model, using the conversation history as input
    gpt2_input = conversation.get_history()
    gpt2_input_ids = gpt2_tokenizer.encode(gpt2_input, return_tensors='tf')
    gpt2_outputs = gpt2_model.generate(
        gpt2_input_ids,
        max_length=100,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=1.0,
        num_return_sequences=1,
    )
    gpt2_generated_text = gpt2_tokenizer.decode(gpt2_outputs[0], skip_special_tokens=True)

    # Generate text using the DialoGPT model, using the conversation history as input
    dialoGPT_input = conversation.get_history()
    dialoGPT_input_ids = dialoGPT_tokenizer.encode(dialoGPT_input, return_tensors='tf')
    dialoGPT_outputs = dialoGPT_model.generate(
        dialoGPT_input_ids,
        max_length=100,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=1.0,
        num_return_sequences=1,
    )
    dialoGPT_generated_text = dialoGPT_tokenizer.decode(dialoGPT_outputs[0], skip_special_tokens=True)

    # Combine the generated text from the two models
    ensemble_text = gpt2_generated_text + dialoGPT_generated_text

    # Add the generated text to the conversation history
    conversation.add_to_history(ensemble_text)

    return ensemble_text


In [13]:
# Define an endpoint to handle incoming text messages
@app.route("/chat", methods=["POST"])
def chat():
    try:
        input_text = request.json["input_text"]
        conversation_id = request.json["conversation_id"]

        if conversation_id not in conversations:
            conversations[conversation_id] = Conversation()

        conversation = conversations[conversation_id]
        generated_text = generate_text(conversation, input_text)

        response = {"generated_text": generated_text}

        return jsonify(response)
    except Exception as e:
        # log the error message
        print(f"Error in chatbot2 endpoint: {e}")
        
        # return an error response to the client
        return jsonify({'error': 'Internal server error'})

if __name__ == "__main__":
    conversations = {}
    app.run()


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
 * Restarting with stat


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [11]:
%tb

SystemExit: 1