In [None]:

## Load NN

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def load_tokenizer_and_model(model_name_or_path):
    """Load a GPT-2 tokenizer and LM-head model from ``model_name_or_path``.

    ``model_name_or_path`` is passed straight to ``from_pretrained`` (a hub id or a
    local directory). The tokenizer is loaded first, then the model, which is moved
    to ``DEVICE`` before being returned.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path)
    loaded_model = GPT2LMHeadModel.from_pretrained(model_name_or_path)
    return loaded_tokenizer, loaded_model.to(DEVICE)

# Load the model and tokenizer from a local directory.

tokenizer, model = load_tokenizer_and_model("models/aistalker/final/")

# Switch to inference mode: eval() disables dropout. It does not by itself stop
# gradient tracking (torch.no_grad() is what does that).
model.eval()

print('Model was loaded')


In [None]:

## Define sentences function

import re


def split_into_sentences(text):
    """Split (mostly Russian) text into sentences.

    Periods that should not end a sentence (decimals, initials, abbreviations,
    domain names, listed prefixes/suffixes, a period followed by a lowercase word)
    are first masked as ``<prd>``. Then ``.``, ``?`` and ``!`` get a ``<stop>``
    marker appended, the masked periods are restored, and the text is split on
    ``<stop>``. Line breaks survive as ``<br>`` markers.

    The piece after the last ``<stop>`` (text with no terminal punctuation) is
    discarded, so an unfinished trailing sentence never appears in the result.

    Args:
        text: input string.

    Returns:
        List of stripped sentence strings (empty if no terminated sentence).
    """
    alphabets = r"([А-Яа-я])"
    prefixes = r"(Гн|Др|Ув|Тов)[.]"
    suffixes = r"(Inc|Ltd|Jr|Sr|Co|тыс|млн|руб)"
    starters = r"(Гн|Др|Ув|Тов|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
    acronyms = r"([А-Я][.][А-Я][.](?:[А-Я][.])?)"
    websites = r"[.](com|net|org|io|gov|ru|ch|su|ai|ua|kz|by)"
    digits = r"([0-9])"

    text = ' <br>'.join(text.splitlines())

    # Collapse ellipses ("..." or longer) to a single period so they end the
    # sentence once instead of producing stray "." sentences.
    text = re.sub(r"\.{2,}", ".", text)

    # Periods inside numbers / between letters are not sentence ends.
    text = re.sub(digits + "[.]" + digits, r"\1<prd>\2", text)
    text = re.sub(r"([А-Яа-я])[.]" + digits, r"\1<prd>\2", text)
    text = re.sub(r"([A-Za-z])[.]" + digits, r"\1<prd>\2", text)
    text = re.sub(r"([А-Яа-я])[.]([А-Яа-я])", r"\1<prd>\2", text)
    text = re.sub(r"([A-Za-z])[.]([A-Za-z])", r"\1<prd>\2", text)
    # A period followed by a lowercase word does not end the sentence; keep the space.
    text = re.sub(r"([а-я])[.][ ]([а-я])", r"\1<prd> \2", text)
    text = " " + text + "  "
    text = text.replace("\n", " ")
    text = text.replace("<p>", "")
    text = re.sub(prefixes, r"\1<prd>", text)
    text = re.sub(websites, r"<prd>\1", text)
    if "Ph.D" in text:
        text = text.replace("Ph.D.", "Ph<prd>D<prd>")
    # NOTE(review): single-letter words such as "я." are masked here too, so
    # "Это был я. Потом" is not split at that period.
    text = re.sub(r"\s" + alphabets + r"[.] ", r" \1<prd> ", text)
    text = re.sub(acronyms + " " + starters, r"\1<stop> \2", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>\3<prd>", text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]", r"\1<prd>\2<prd>", text)
    text = re.sub(" " + suffixes + "[.] " + starters, r" \1<stop> \2", text)
    text = re.sub(" " + suffixes + "[.]", r" \1<prd>", text)
    text = re.sub(" " + alphabets + "[.]", r" \1<prd>", text)

    # Move terminal punctuation outside closing quotes.
    if "”" in text:
        text = text.replace(".”", "”.")
    if "\"" in text:
        text = text.replace(".\"", "\".")
    if "!" in text:
        text = text.replace("!\"", "\"!")
    if "?" in text:
        text = text.replace("?\"", "\"?")

    # Mark sentence ends FIRST and only then restore the masked periods: restoring
    # them earlier would turn every masked period into a sentence break again.
    text = text.replace(".", ".<stop>")
    text = text.replace("?", "?<stop>")
    text = text.replace("!", "!<stop>")
    text = text.replace("<prd>", ".")

    # The last piece has no terminator (at least the padding spaces) and is dropped.
    sentences = text.split("<stop>")[:-1]
    return [s.strip() for s in sentences]

print('Sentences function was defined')


In [None]:

## Define Generate functions

import re

# Generated sentences containing any of these substrings are treated as
# web/markup/code junk and dropped from the story.
STORY_JUNK_MARKERS = ("div", "def", "php", "http", "css", "Глава")


def drop_junk_sentences(sentences):
    """Return, in order, the sentences that contain none of ``STORY_JUNK_MARKERS``."""
    return [
        sentence for sentence in sentences
        if not any(marker in sentence for marker in STORY_JUNK_MARKERS)
    ]


def tidy_story_text(text):
    """Normalize generated story text before it is sent to the player.

    Removes ``<...>`` / ``[...]`` and ``{...}`` spans (e.g. ``<br>`` markers and
    leftover markup), strips asterisks, collapses whitespace runs to one space and
    trims both ends.
    """
    text = re.sub(r"\s+", " ", text)
    text = re.sub(r"[<\[].*?[>\]]", "", text)
    text = re.sub(r"[{\[].*?[}\]]", "", text)
    text = text.replace("*", "")
    # Removing spans above can leave double spaces behind, so collapse again.
    return re.sub(r"\s+", " ", text).strip()


def generate_continuation(model, tok, text, last_text=None, **generate_kwargs):
    """Generate from ``text`` and return only the decoded text that follows the prompt.

    If ``last_text`` is non-empty and occurs in the decoded output, everything after
    its LAST occurrence is returned. Otherwise (no anchor given, or the anchor did
    not survive the encode/decode round trip) only the tokens produced after the
    prompt are decoded, so the prompt is never echoed back. Splitting on ``None``
    would split on whitespace and keep only the final word, and an empty anchor
    would raise ``ValueError``.

    Extra keyword arguments are forwarded to ``model.generate``.
    """
    encoded = tok(text, return_tensors="pt").to(DEVICE)
    # GPT-2 defines no pad token; generate() falls back to EOS with a warning,
    # so pass it explicitly unless the caller chose one.
    generate_kwargs.setdefault("pad_token_id", tok.eos_token_id)
    out = model.generate(**encoded, **generate_kwargs)

    generated_content = tok.decode(out[0])
    if last_text and last_text in generated_content:
        return generated_content.split(last_text)[-1]

    prompt_length = encoded["input_ids"].shape[-1]
    return tok.decode(out[0][prompt_length:])


def generate_story_start(
    model, 
    tok, 
    text,
    max_length = 350,
    top_k = 5,
    top_p = 0.95,
    temperature = 1.2,
    do_sample = True,
    num_beams = 3,
    no_repeat_ngram_size = 3,
    repetition_penalty = 2.,
    num_sentences = 4
    ):
    """Generate the opening scene of a new story from the few-shot prompt ``text``.

    ``text`` is expected to end with the marker ``'— Твой путь начинается '``. The
    model output after the last occurrence of that marker is split into sentences.
    The final (probably truncated) sentence and junk sentences are dropped, and the
    rest is returned prefixed with ``'Твой путь начинается '``.

    Args:
        model, tok: causal LM and its tokenizer.
        text: few-shot prompt.
        max_length: total token budget INCLUDING the prompt (``generate`` semantics).
        top_k, top_p, temperature, do_sample, num_beams, no_repeat_ngram_size,
            repetition_penalty: forwarded to ``model.generate``.
        num_sentences: currently unused; kept for backward compatibility.

    Returns:
        Cleaned story text (str).
    """
    last_text = '''— Твой путь начинается '''

    continuation = generate_continuation(
        model, tok, text, last_text,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        top_k=top_k, top_p=top_p, temperature=temperature,
        num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size,
    )

    # The last sentence is probably cut off by max_length, so it is discarded.
    story_sentences = drop_junk_sentences(split_into_sentences(continuation)[:-1])

    return tidy_story_text('Твой путь начинается ' + ' '.join(story_sentences))


def generate_story_actions(
    model, 
    tok, 
    text,
    max_length = 500,
    top_k = 5,
    top_p = 0.95,
    do_sample = True,
    temperature = 1.2,
    num_beams = 3,
    no_repeat_ngram_size = 3,
    repetition_penalty = 2.,
    last_text = None,
    num_sentences = 3
    ):
    """Continue the game history ``text`` with the model's next story passage.

    The continuation is taken after the last occurrence of ``last_text`` (or, when
    no usable anchor is given, after the prompt tokens). The final (probably
    truncated) sentence and junk sentences are dropped, and the rest is cleaned up.

    Args:
        model, tok: causal LM and its tokenizer.
        text: prompt built from the game history.
        max_length: total token budget INCLUDING the prompt (``generate`` semantics).
            NOTE(review): a long history prompt can leave little or no room for new
            tokens; consider ``max_new_tokens``.
        top_k, top_p, do_sample, temperature, num_beams, no_repeat_ngram_size,
            repetition_penalty: forwarded to ``model.generate``.
        last_text: anchor string, typically the tail of ``text``.
        num_sentences: currently unused; kept for backward compatibility.

    Returns:
        Cleaned story text (str); may be empty if everything was filtered out.
    """
    continuation = generate_continuation(
        model, tok, text, last_text,
        max_length=max_length,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        top_k=top_k, top_p=top_p, temperature=temperature,
        num_beams=num_beams, no_repeat_ngram_size=no_repeat_ngram_size,
    )

    # The last sentence is probably cut off by max_length, so it is discarded.
    story_sentences = drop_junk_sentences(split_into_sentences(continuation)[:-1])

    return tidy_story_text(' '.join(story_sentences))

print('Generative functions were defined')


In [None]:
# Update history function

def update_message_history(df, message_details, df_name, log_dir='./ai_stalker_bot_messages/'):
    """Append one message row to a per-user log and save the log as CSV.

    Args:
        df: pandas DataFrame holding the log; the row is added to it in place.
        message_details: list of values, one per column of ``df``, in column order.
        df_name: CSV file name inside ``log_dir``.
        log_dir: directory holding the CSV logs; created if it does not exist.

    Returns:
        The same ``df`` object, now including the new row.
    """
    import os

    # Assumes df has a default 0..n-1 RangeIndex (empty frame or read_csv result),
    # so len(df) is the next unused row label.
    df.loc[len(df)] = message_details

    os.makedirs(log_dir, exist_ok=True)
    df.to_csv(os.path.join(log_dir, df_name), index = False)

    return df

print('History function was defined')


In [None]:

## Prompts for few shot inferences

# Few-shot prompt for the opening scene. It holds five sample openings, each
# beginning with "— Твой путь начинается", and ends with the bare marker so the
# model writes a new opening. generate_story_start looks for the same marker in
# the model output. The misspelled name "promt" is kept because handle_text
# refers to it.
story_start_promt = '''— Твой путь начинается в баре "100 рентген", куда ты пришел пополнить арсенал оружия, сдать пару артефактов, выпить пива и послушать сталкеров.

— Твой путь начинается у НИИ "Агропром", где, судя по слухам, есть подземные сооружения и даже, возможно, остатки военной лаборатории.

— Твой путь начинается в Болотах. Вокруг, куда ни глянь, рыжеватая растительность и бочаги с водой. На горизонте виднеется крыша чьего-то домика.

— Твой путь начинается около Дикой территории, куда ты собрался пробраться завтра с рассветом, а сегодня нашел овраг под деревом в чахлом лесу и затаился на ночь.

— Твой путь начинается промозглой ночью, когда сквозь облака едва проглядывала бледно-жёлтая луна, ты поднял лежащий на коленях АКМ и пошел на свет аномалии.

— Твой путь начинается '''


In [None]:

## Run bot

import telebot
import pandas as pd
import time

# Create the bot instance. The API token is read from the TELEGRAM_BOT_TOKEN
# environment variable so it never has to be stored in the notebook; the
# "...:..." placeholder is used only when that variable is unset.

import os

bot = telebot.TeleBot(os.environ.get("TELEGRAM_BOT_TOKEN", "...:..."))

# Handler for the /start command.

@bot.message_handler(commands=["start"])
def start(message, res=False):
    """Begin a new game for the sender of /start.

    Resets the sender's CSV log in ./ai_stalker_bot_messages/ to an empty table
    (header only), then sends the player tips and the opening question.

    Args:
        message: incoming telebot message carrying the /start command.
        res: unused; kept for compatibility with existing callers.
    """
    import os

    # Reset the log BEFORE messaging the player. handle_text treats an empty log as
    # "new story", so a reply sent during the pause below must already see the reset.
    log_dir = './ai_stalker_bot_messages/'
    os.makedirs(log_dir, exist_ok=True)

    message_user_id = str(message.from_user.id)

    ai_stalker_bot_logs = pd.DataFrame(
        columns = ['chat_id', 'user_id', 'message_id', 'is_bot', 'message_datetime', 'message_text']
        )
    ai_stalker_bot_logs.to_csv(
        log_dir + 'ai_stalker_bot_logs' + '_' + message_user_id + '.csv',
        index = False
    )

    # Tips for the player.

    manual = '''*******
    
    Советы для новых игроков:
    
    Писать лучше полными предложениями: Я пошел..., Я сказал: ..., ..., - обратился я к нему, Я направил ...

    Если предложение не оканчивается знаком препинания то модель допишет его за вас, как ей вздумается. Например, "Я вижу на горизонте лес, на его опушке", а модель довершит описание сцены.
    
    Если в своем представлении вы указали конкретные локации, персонажей, лут и т.д. будьте готовы, что система будет их часто упоминать.
    Хорошая идея - описать локации, монстров, предметы и т.п., которые хотите видеть в игре, и они точно будут.
    
    Бывает, что клиент подвисает и тогда нет фразы "Пишу историю...", через 10 секунд надо повторить ввод.
    
    Удачной охоты!
    
    *******
    '''

    bot.send_message(message.chat.id, manual)

    time.sleep(3)

    bot.send_message(message.chat.id, 'Опусти оружие, путник, и ответь, кто ты и зачем здесь.')

# Receive text messages from the player.

LOG_DIR = './ai_stalker_bot_messages/'
LOG_COLUMNS = ['chat_id', 'user_id', 'message_id', 'is_bot', 'message_datetime', 'message_text']


def load_message_history(df_name):
    """Return the player's log ``LOG_DIR/df_name`` with every column read as text.

    If the file does not exist (the player never sent /start), an empty log with
    ``LOG_COLUMNS`` is returned instead of raising, so the message starts a new
    story. ``LOG_DIR`` is created if needed so the log can be written afterwards.
    """
    import os

    os.makedirs(LOG_DIR, exist_ok=True)
    path = os.path.join(LOG_DIR, df_name)
    if not os.path.exists(path):
        return pd.DataFrame(columns=LOG_COLUMNS)
    # Read everything as literal text. With pandas' defaults a player message such
    # as "NA", "None" or "null" comes back as NaN (and "42" as an int), which breaks
    # the string concatenation that builds the prompt.
    return pd.read_csv(path, dtype=str, keep_default_na=False)


def build_player_action_prompt(history_texts, head=2, tail=7):
    """Assemble the generation prompt from logged message texts.

    The first ``head`` texts (the player's introduction and the generated opening)
    are each followed by a blank line. After them come at most the last ``tail``
    texts, each followed by a space. With ``head + tail`` texts or fewer, all are kept.
    """
    texts = list(history_texts)
    if len(texts) > head + tail:
        texts = texts[:head] + texts[len(texts) - tail:]
    return (
        ''.join(text + '\n\n' for text in texts[:head])
        + ''.join(text + ' ' for text in texts[head:])
    )


@bot.message_handler(content_types=["text"])
def handle_text(message):
    """Log the player's message and reply with the next piece of the story.

    With an empty log the message is the player's introduction, and the opening
    scene is generated from ``story_start_promt``. Otherwise a prompt is built from
    the log with ``build_player_action_prompt`` and continued with
    ``generate_story_actions``. Both the player's text and the bot's reply are
    appended to the per-user CSV log.
    """
    message_text = message.text

    # Input the tokenizer cannot encode is answered in-character and not logged.
    try:
        tokenizer.encode(message_text)
    except Exception:
        bot.send_message(message.chat.id, 'Что ты прошептал? Уши забило фиолетовым мхом, плохо слышу.')
        return

    message_user_id = str(message.from_user.id)
    df_name = 'ai_stalker_bot_logs_' + message_user_id + '.csv'
    ai_stalker_bot_logs = load_message_history(df_name)

    # An empty log means the story has just started: this message is the introduction.
    is_new_story = len(ai_stalker_bot_logs) == 0

    # The introduction may be longer than later replies (500 vs 250 characters).
    player_text_limit = 500 if is_new_story else 250
    ai_stalker_bot_logs = update_message_history(
        ai_stalker_bot_logs,
        [
            str(message.chat.id),
            message_user_id,
            str(message.message_id),
            message.from_user.is_bot,
            str(message.date),
            message_text[:player_text_limit],
        ],
        df_name,
    )

    if is_new_story:
        bot.send_message(message.chat.id, 'Жую жгучий пух и забиваю трубку из кости псевдо-филина, да ты не шелести особо.')

        story_start = generate_story_start(model, tokenizer, story_start_promt)
        bot.send_message(message.chat.id, story_start)

        time.sleep(7)
        bot.send_message(message.chat.id, 'Что думаешь делать дальше?')

        bot_reply_to_log = story_start[:400]
    else:
        player_action_promt = build_player_action_prompt(
            ai_stalker_bot_logs['message_text'].astype(str)
        )

        bot.send_message(message.chat.id, 'Пишу историю...')

        story_action = generate_story_actions(
            model, tokenizer, text = player_action_promt, last_text = player_action_promt[-20:]
        )

        # Everything may have been filtered out of the generation. Telegram rejects
        # empty messages, so ask the player to repeat instead of failing silently.
        if not story_action.strip():
            bot.send_message(message.chat.id, 'Что ты прошептал? Уши забило фиолетовым мхом, плохо слышу.')
            return

        bot.send_message(message.chat.id, story_action)

        bot_reply_to_log = story_action[-250:]

    # Bot replies are logged with placeholder user id 777 and message id 0.
    update_message_history(
        ai_stalker_bot_logs,
        [
            str(message.chat.id),
            str(777),
            str(0),
            True,
            str(int(time.time())),
            bot_reply_to_log,
        ],
        df_name,
    )


### Starting bot

if __name__ == '__main__':
    # Poll Telegram forever. If polling raises, report the error and retry after a
    # pause instead of letting the process exit.
    retry_delay_seconds = 10
    while True:
        try:
            bot.polling(none_stop = True, interval = 1)
        except Exception as polling_error:
            print(polling_error)
            time.sleep(retry_delay_seconds)
