### Download Modules

In [None]:
!pip install transformers
!pip install annoy
!pip install -U sentence-transformers

### Import Modules

In [7]:
from transformers import GPT2Config
from transformers import GPT2LMHeadModel
from transformers import GPT2Tokenizer
from sklearn.metrics.pairwise import cosine_similarity
import torch
import random
from statistics import mean
from sentence_transformers import SentenceTransformer
from annoy import AnnoyIndex

### Define useful functions

In [None]:
def ANN(test_sentence, saved_ann, top_k=10, n_candidates=1000):
  """Look up the indexed items closest to ``test_sentence`` in a saved Annoy index.

  The sentence is embedded with the module-level ``model_sent`` encoder and
  queried against the index stored at ``saved_ann`` using the 'angular' metric.

  Args:
    test_sentence: Text to embed and look up.
    saved_ann: Path of an Annoy index file previously written with ``save``.
    top_k: How many neighbours to return (default 10, as before).
    n_candidates: How many neighbours to request from Annoy before truncating
      to ``top_k`` (default 1000, as before).

  Returns:
    A ``(ids, distances)`` pair of lists, nearest first, each holding at most
    ``top_k`` entries. ``ids`` are the item ids stored in the index.
  """
  sentence_embeddings = model_sent.encode(test_sentence)
  # The index must be created with the same dimensionality as the query vector.
  f = sentence_embeddings.shape[0]
  u = AnnoyIndex(f, 'angular')
  u.load(str(saved_ann))  # mmaps the file rather than reading it into memory
  # Deliberately over-fetch: per the Annoy docs the default search effort scales
  # with the number of neighbours requested, so asking for more than top_k
  # should give a more accurate head of the list (NOTE: confirm against the
  # installed Annoy version before lowering n_candidates).
  closest, dist = u.get_nns_by_vector(
      sentence_embeddings, max(n_candidates, top_k), include_distances=True
  )
  return closest[:top_k], dist[:top_k]

### Set up models

In [10]:
# Sentence encoder used by ANN() to embed a reply before the nearest-neighbour lookup.
model_sent = SentenceTransformer('all-MiniLM-L6-v2')

# Dialogue checkpoint that generates the chatbot replies.
# Alternative checkpoints are kept below, commented out; enable exactly one.
model_name = "microsoft/DialoGPT-large"
#model_name = "microsoft/DialoGPT-medium"
# model_name = "microsoft/DialoGPT-small"

# Tokenizer and language model are loaded from the same checkpoint name so
# their vocabularies match; both are read as globals by gen_chatbox().
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/642 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

### Set up chatbot

In [32]:
def gen_chatbox(topic, n_turns=20, memory_turns=4, drift_threshold=1, drift_patience=4):
  """Print a two-robot conversation that is periodically steered back to ``topic``.

  The conversation opens with a random sentence from ``{topic}_relevent.txt``.
  Each turn feeds the last ``memory_turns`` utterances to the module-level
  ``model`` / ``tokenizer`` and samples a reply. Every reply is looked up in
  ``{topic}.ann`` via ``ANN``; when the mean neighbour distance exceeds
  ``drift_threshold`` the turn counts as off-topic, and on every
  ``drift_patience``-th off-topic turn the reply is replaced by one of the
  neighbouring topic sentences.

  Args:
    topic: Prefix of the ``{topic}_relevent.txt`` and ``{topic}.ann`` files.
    n_turns: Number of generated turns after the opening line.
    memory_turns: How many of the most recent utterances form the model input.
    drift_threshold: Mean neighbour distance above which a reply is off-topic.
    drift_patience: Number of off-topic replies that triggers a steer.

  Raises:
    ValueError: If ``memory_turns`` is below 1 or the sentence file contains
      no non-blank line to open the conversation with.
  """
  if memory_turns < 1:
      # A slice of [-0:] would silently mean "the whole history".
      raise ValueError("memory_turns must be at least 1")

  # Keep blank lines in place: ids returned by ANN are used as positions in
  # this list, so dropping entries would shift every later sentence.
  with open(f'{topic}_relevent.txt', encoding='utf-8') as h:
      relevant_sentences = [line.strip("\n") for line in h]

  openers = [sentence for sentence in relevant_sentences if sentence.strip()]
  if not openers:
      raise ValueError(f"no usable sentences in {topic}_relevent.txt")

  text = random.choice(openers)
  list_of_tokens = []
  off_topic_turns = 0
  print('Robot 1', text)

  for step in range(n_turns):
      # Encode the latest utterance, terminated by the end-of-string token.
      input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")
      list_of_tokens.append(input_ids)

      # Model input is the most recent `memory_turns` utterances, oldest first.
      bot_input_ids = torch.cat(list_of_tokens[-memory_turns:], dim=-1)

      chat_history_ids = model.generate(
          bot_input_ids,
          max_length=1000,
          do_sample=True,
          top_p=0.9,
          top_k=50,
          no_repeat_ngram_size=2,
          temperature=0.75,
          pad_token_id=tokenizer.eos_token_id
      )

      # Keep only the newly generated tokens, i.e. everything after the prompt.
      output = tokenizer.decode(
          chat_history_ids[:, bot_input_ids.shape[-1]:][0],
          skip_special_tokens=True,
      )
      text = output

      # Steer the conversation back when replies drift away from the topic.
      sent, dist = ANN(output, f'{topic}.ann')
      if dist and mean(dist) > drift_threshold:
          off_topic_turns += 1
          if off_topic_turns >= drift_patience:
              # Reset unconditionally; otherwise one failed steer would leave
              # the counter past the trigger value and disable steering.
              off_topic_turns = 0
              candidates = [
                  relevant_sentences[idx]
                  for idx in sent
                  if 0 <= idx < len(relevant_sentences)
                  and relevant_sentences[idx].strip()
              ]
              if candidates:
                  text = random.choice(candidates)
                  print("STEERED TOWARDS TOPIC")
              # With no usable candidate the model's own reply is kept, so the
              # turn still appears in the transcript.

      # The opening line was Robot 1, so even steps belong to Robot 2.
      if step % 2 == 0:
          print('Robot 2', text)
      else:
          print('Robot 1', text)

In [43]:
# Requires 'marvel_relevent.txt' and 'marvel.ann' in the working directory.
# The opening line and the replies are sampled without a fixed seed, so the
# transcript differs on every run.
gen_chatbox('marvel')

Robot 1 What with this movie and the excellent video game, your friendly neighbourhood Spider-Man is having a good run lately. 
Robot 2 You have no idea how many times I've watched this video and I have yet to see the part where Spider is actually a woman.
Robot 1 In the comics he is a girl.
Robot 2 Yeah, but he was a man before.
Robot 1 He was referred to as Spiderwoman in the original comic.
Robot 2 I'm pretty sure he's referred as a Spider Man in every other Spiderman book.
Robot 1 The one that was based on was Spider Woman
Robot 2 You're absolutely right. I didn't know that.
Robot 1 Yeah, it's a pretty common mistake.
Robot 2 I've actually read Spider Women and Spider Verse so I know how to read that one.
STEERED TOWARDS TOPIC
Robot 1 Mr fantastic - well, here I cant argue
Robot 2 What are you even doing, champ?
Robot 1 He's trying to be funny.
Robot 2 It's a terrible attempt.
STEERED TOWARDS TOPIC
Robot 1 Another fucking Spider-Man movie?? When will this superhero shit end!?
Robot