In [2]:
import threading
import time
from flask import Flask, render_template_string, request
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize Flask app
app = Flask(__name__)

# Load the fine-tuned model and tokenizer
model = AutoModelForCausalLM.from_pretrained("fine_tuned_DialoGPT_model").to('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model.eval()  # Set the model to evaluation mode

# Store chat history
chat_history = []

def ask_question(question, model, tokenizer, chat_history_ids=None):
    new_input_ids = tokenizer.encode(question + tokenizer.eos_token, return_tensors='pt').to(model.device)
    if chat_history_ids is None:
        chat_history_ids = new_input_ids
    else:
        chat_history_ids = torch.cat([chat_history_ids, new_input_ids], dim=-1)
    reply_ids = model.generate(chat_history_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id, num_return_sequences=1)
    response = tokenizer.decode(reply_ids[:, chat_history_ids.shape[-1]:][0], skip_special_tokens=True)
    return response, chat_history_ids

chat_history_ids = None

@app.route("/", methods=["GET", "POST"])
def index():
    global chat_history_ids, chat_history
    if request.method == "POST":
        user_input = request.form["user_input"]
        response, chat_history_ids = ask_question(user_input, model, tokenizer, chat_history_ids)
        chat_history.append({"user": user_input, "bot": response})
    
    # HTML template as a string
    html_template = '''
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Chatbot</title>
        <style>
            body {
                font-family: Arial, sans-serif;
                background-color: #f4f4f4;
                margin: 0;
                padding: 0;
                display: flex;
                justify-content: center;
                align-items: center;
                height: 100vh;
            }
            .chat-container {
                background-color: #fff;
                width: 500px;
                padding: 20px;
                box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
                border-radius: 10px;
            }
            .chat-box {
                background-color: #f9f9f9;
                padding: 10px;
                border-radius: 5px;
                max-height: 300px;
                overflow-y: auto;
                margin-bottom: 15px;
            }
            .chat-message {
                margin-bottom: 10px;
                padding: 8px;
                border-radius: 5px;
            }
            .user-message {
                background-color: #d1e7dd;
                text-align: right;
            }
            .bot-message {
                background-color: #e2e3e5;
            }
            .chat-input {
                width: 100%;
                padding: 10px;
                border-radius: 5px;
                border: 1px solid #ccc;
            }
            .submit-btn {
                width: 100%;
                padding: 10px;
                border-radius: 5px;
                border: none;
                background-color: #007bff;
                color: white;
                font-weight: bold;
                cursor: pointer;
                margin-top: 10px;
            }
            .submit-btn:hover {
                background-color: #0056b3;
            }
        </style>
    </head>
    <body>
        <div class="chat-container">
            <h2>Chat with the Bot</h2>
            <div class="chat-box">
                {% for chat in chat_history %}
                    <div class="chat-message user-message">
                        <strong>You:</strong> {{ chat.user }}
                    </div>
                    <div class="chat-message bot-message">
                        <strong>Bot:</strong> {{ chat.bot }}
                    </div>
                {% endfor %}
            </div>
            <form method="POST">
                <input type="text" class="chat-input" name="user_input" placeholder="Ask something..." required>
                <input type="submit" class="submit-btn" value="Send">
            </form>
        </div>
    </body>
    </html>
    '''
    
    return render_template_string(html_template, chat_history=chat_history)

# Function to run Flask app in a separate thread
def run_flask():
    app.run(port=5000)

# Start Flask app in a new thread
thread = threading.Thread(target=run_flask)
thread.start()

time.sleep(2)  # Give Flask time to start


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [27/Aug/2024 15:33:38] "GET / HTTP/1.1" 200 -
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
127.0.0.1 - - [27/Aug/2024 15:33:48] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [27/Aug/2024 15:34:10] "POST / HTTP/1.1" 200 -
