In [1]:
# !conda install python-telegram-bot -y
# !pip install dialogflow

In [2]:
# -*- coding: utf-8 -*-

In [3]:
import os
import dialogflow
import logging
import tensorflow as tf
import string
import numpy as np

from telegram import Update
from telegram.ext import Updater, CommandHandler, MessageHandler, Filters, CallbackContext

In [4]:
# Telegram bot token, kept out of the notebook in a local file.
# The context manager closes the file, and strip() removes the trailing
# newline most editors append (which would otherwise end up in the token).
with open('data/api_token.txt', 'r') as token_file:
    API_TOKEN = token_file.read().strip()
# Service-account key file; exported later via GOOGLE_APPLICATION_CREDENTIALS.
JSON_PROJECT = "data/nlp-bot-df-jmyl-63214f5d01e9.json"

# Dialogflow agent settings used when building the session and the query.
DIALOGFLOW_PROJECT_ID = "nlp-bot-df-jmyl"
DIALOGFLOW_LANGUAGE_CODE = "ru"
SESSION_ID = "NLP_GB_project_bot"

# Directory holding the generator's checkpoints and the text the vocabulary is built from.
WEIGHTS_GEN_MODEL_PATH = os.path.normpath("./data/checkpoints")
PATH_TO_TEXT = "data/comedy.txt"

# Generator hyper-parameters passed to build_model().
EMBEDDING_DIM = 256
RNN_UNITS = 512

In [5]:
# Point the Google client libraries at the service-account key file.
os.environ.update({"GOOGLE_APPLICATION_CREDENTIALS": JSON_PROJECT})

# Telegram updater and the dispatcher that handlers are registered on.
updater = Updater(token=API_TOKEN, use_context=True)
dispatcher = updater.dispatcher

In [6]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    """Assemble the generator network: Embedding -> 3 stateful GRUs -> Dense.

    Args:
        vocab_size: size of the vocabulary; used as the embedding input
            dimension and as the width of the final Dense layer.
        embedding_dim: width of the embedding vectors.
        rnn_units: number of units in each GRU layer.
        batch_size: fixed batch size baked into the input shape
            (the sequence length is left as None).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    embedding = tf.keras.layers.Embedding(
        vocab_size,
        embedding_dim,
        batch_input_shape=[batch_size, None],
    )

    # Three identical recurrent layers; each one returns the full sequence
    # and keeps its state between calls (stateful=True).
    recurrent_stack = [
        tf.keras.layers.GRU(
            rnn_units,
            return_sequences=True,
            stateful=True,
            recurrent_initializer='glorot_uniform',
        )
        for _ in range(3)
    ]

    projection = tf.keras.layers.Dense(vocab_size)

    return tf.keras.Sequential([embedding, *recurrent_stack, projection])

In [7]:
# Read the corpus as raw bytes and decode it explicitly. Binary mode is kept on
# purpose: text mode would translate line endings and could change the vocabulary.
# The context manager makes sure the file handle is closed.
with open(PATH_TO_TEXT, 'rb') as corpus_file:
    text = corpus_file.read().decode(encoding='1251')

# Vocabulary = every distinct character of the corpus, in sorted order.
vocab = sorted(set(text))

# Lookup tables between characters and their integer ids (position in vocab).
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

vocab_size = len(vocab)

In [8]:
# Rebuild the network with batch size 1 for inference.
model = build_model(vocab_size, EMBEDDING_DIM, RNN_UNITS, batch_size=1)

# latest_checkpoint() returns None when the directory holds no checkpoint;
# fail with a clear message instead of an obscure error inside load_weights().
checkpoint_path = tf.train.latest_checkpoint(WEIGHTS_GEN_MODEL_PATH)
if checkpoint_path is None:
    raise FileNotFoundError(
        "No TensorFlow checkpoint found in {!r}".format(WEIGHTS_GEN_MODEL_PATH)
    )

model.load_weights(checkpoint_path)
model.build(tf.TensorShape([1, None]))

In [9]:
# Show the restored model's layers, output shapes and parameter counts.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (1, None, 256)            33280     
_________________________________________________________________
gru (GRU)                    (1, None, 512)            1182720   
_________________________________________________________________
gru_1 (GRU)                  (1, None, 512)            1575936   
_________________________________________________________________
gru_2 (GRU)                  (1, None, 512)            1575936   
_________________________________________________________________
dense (Dense)                (1, None, 130)            66690     
Total params: 4,434,562
Trainable params: 4,434,562
Non-trainable params: 0
_________________________________________________________________


In [10]:
def generate_text(model, start_string, char2idx, temperature=1, num_generate=600):
    """Sample text from the character-level model, seeded with ``start_string``.

    Args:
        model: model that is called on a ``(1, seq_len)`` batch of character
            ids and supports ``reset_states()``.
        start_string: seed text. Every character is encoded and fed to the
            model before sampling starts; an empty seed falls back to id 0.
        char2idx: mapping from character to integer id. Characters missing
            from it are encoded as id 0.
        temperature: the logits are divided by this value before sampling,
            so larger values give more random text. Must be non-zero.
        num_generate: number of characters to sample.

    Returns:
        ``start_string`` immediately followed by the generated characters.
    """
    # Encode the seed character by character: the vocabulary is made of single
    # characters, so looking the whole string up as one key would miss for any
    # seed longer than one character.
    seed_ids = [char2idx.get(ch, 0) for ch in start_string] or [0]
    input_eval = tf.expand_dims(seed_ids, 0)

    # Reverse mapping derived from the argument, so decoding always matches the
    # encoding table that was passed in rather than a module-level global.
    id_to_char = {idx: ch for ch, idx in char2idx.items()}

    text_generated = []

    # Start from a clean recurrent state; afterwards the state carries the
    # context, so only the newly sampled character is fed back each step.
    model.reset_states()
    for _ in range(num_generate):
        predictions = model(input_eval)
        predictions = tf.squeeze(predictions, 0)

        predictions = predictions / temperature
        # Sample from the distribution at the last time step.
        predicted_id = int(
            tf.random.categorical(predictions, num_samples=1)[-1, 0].numpy()
        )

        input_eval = tf.expand_dims([predicted_id], 0)

        text_generated.append(id_to_char[predicted_id])

    # Characters are concatenated as-is: the model emits its own spaces and
    # newlines, so no separator is inserted between them.
    return start_string + ''.join(text_generated)

In [11]:
def startCommand(update: Update, context: CallbackContext):
    """Answer the /start command with a greeting that lists the bot's abilities."""
    greeting = (
        "Добрый день!\n"
        "Я могу поболтать, а могу сочинить стихотворение в стиле Данте "
        "(команда /verse <Первое слово>)"
    )
    update.message.reply_text(greeting)
    

def stopCommand(update: Update, context: CallbackContext):
    """Answer the /stop command with a short farewell."""
    farewell = 'Пока!'
    update.message.reply_text(farewell)


def verseCommand(update: Update, context: CallbackContext):
    """Answer /verse by generating text from the given seed words.

    The words following the command are joined with spaces and used as the
    seed for ``generate_text``; the first ``n_lines + 1`` lines of the result
    are sent back to the chat.
    """
    n_lines = 4
    # Take the seed from the parsed command arguments rather than filtering the
    # raw text for the literal "/verse": the raw text may carry the command in
    # the "/verse@BotName" form, which would otherwise leak into the seed.
    seed = " ".join(context.args or [])
    verse = generate_text(model, start_string=seed, char2idx=char2idx, temperature=1).split("\n")
    update.message.reply_text("\n".join(verse[:n_lines + 1]))
    
    
def textMessage(update: Update, context: CallbackContext):
    """Forward a plain text message to Dialogflow and reply with its answer.

    ASCII punctuation (``string.punctuation``) is removed from the message
    before it is sent. If nothing is left to send, or Dialogflow returns no
    fulfillment text, the bot replies with a fallback question instead.

    Errors raised by ``detect_intent`` are not handled here and propagate to
    the caller.
    """
    fallback_reply = 'Что?'

    user_text = ''.join(ch for ch in update.message.text if ch not in string.punctuation)

    # A message made only of punctuation/whitespace leaves nothing to query
    # with; answer with the fallback instead of sending an empty request.
    if not user_text.strip():
        update.message.reply_text(fallback_reply)
        return

    session_client = dialogflow.SessionsClient()
    session = session_client.session_path(DIALOGFLOW_PROJECT_ID, SESSION_ID)
    text_input = dialogflow.types.TextInput(text=user_text, language_code=DIALOGFLOW_LANGUAGE_CODE)
    query_input = dialogflow.types.QueryInput(text=text_input)

    # No try/except here: the only thing a handler could do is re-raise, and
    # catching a name that is not imported in this file would itself fail.
    response = session_client.detect_intent(session=session, query_input=query_input)

    text = response.query_result.fulfillment_text
    update.message.reply_text(text if text else fallback_reply)

In [12]:
# Registration order is kept: the three command handlers first,
# then the handler for plain text messages.
for command_name, callback in (
    ("start", startCommand),
    ("stop", stopCommand),
    ("verse", verseCommand),
):
    dispatcher.add_handler(CommandHandler(command_name, callback))

dispatcher.add_handler(MessageHandler(Filters.text, textMessage))

In [13]:
# Start fetching updates from Telegram.
# NOTE(review): clean=True presumably discards updates queued while the bot was
# offline — confirm against the installed python-telegram-bot version.
updater.start_polling(clean=True)
# NOTE(review): idle() is expected to block this cell until the process is
# interrupted — confirm; no cell placed after this one would run until then.
updater.idle()