In [52]:
import numpy as np
import torch
import json
import networkx as nx
from transformers import AutoTokenizer, AutoConfig,\
      T5ForConditionalGeneration, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForCausalLM
from scipy.special import softmax
import time

In [2]:
# download the models
cot_tokenizer = AutoTokenizer.from_pretrained("prakharz/DIAL-BART0")
cot_model = AutoModelForSeq2SeqLM.from_pretrained("prakharz/DIAL-BART0")
cot_model.load_state_dict(torch.load('./CoT/topic_extraction/model/topic_er2.pt'))

recommender_tokenizer = AutoTokenizer.from_pretrained("t5-large")
recommender_model = AutoModelForSeq2SeqLM.from_pretrained("t5-large")
recommender_model.load_state_dict(torch.load('./CoT/recommender/model/rec_er.pt'))

# guideliner_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
# guideliner = T5ForConditionalGeneration.from_pretrained("TrevorAshby/guideliner")

blen_tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-400M-distill")
blen_model = AutoModelForSeq2SeqLM.from_pretrained("TrevorAshby/blenderbot-400M-distill")

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-large automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [17]:
def generate_cot(text_in):
    tok_text = cot_tokenizer(text_in, return_tensors='pt')
    gen_text = cot_model.generate(**tok_text, max_length=60)
    dec_text = cot_tokenizer.decode(gen_text[0], skip_special_tokens=True)
    return dec_text

def generate_recommendation(text_in):
    tok_text = recommender_tokenizer(text_in, return_tensors='pt')
    gen_text = recommender_model.generate(**tok_text, max_new_tokens=32)
    dec_text = recommender_tokenizer.decode(gen_text[0], skip_special_tokens=True)
    return dec_text

# def generate_guideline(text_in, suggested_topics):
#     guide_in_str =f'A: {text_in}| {suggested_topics}'
#     in_ids = guideliner_tokenizer(guide_in_str, max_length=512, padding='max_length', return_tensors='pt').input_ids
#     guideline_example = guideliner.generate(in_ids, max_new_tokens=50)
#     guideline = guideliner_tokenizer.decode(guideline_example[0], skip_special_tokens=True)
#     return guideline

def generate_response(text_in, guideline):
    blend_in_str = text_in + ' [GUIDELINE] ' + guideline
    blend_in_ids = blen_tokenizer([blend_in_str], max_length=512, return_tensors='pt', truncation=True)
    blend_example = blen_model.generate(**blend_in_ids, max_length=60)
    response = blen_tokenizer.batch_decode(blend_example, skip_special_tokens=True)[0]
    return response

def CoT_to_Preference(cot):
    # (sports,yes)|(football team,yes)
    # "{\"sports\":\"positive\", \"football\":\"positive\"}"
    topics = cot.split('|')
    top_dict = {}
    for top in topics:
        top = top.replace('(', '')
        top = top.replace(')', '')
        the_top, pref = top.split(',')
        # print(pref)
        if pref == 'yes':
            pref = 'positive'
        elif pref == 'no':
            pref = 'negative'
        else:
            pref = 'unknown'
        top_dict[the_top] = pref
    return top_dict

In [53]:
db = nx.Graph()
# db.add_node('test')
while 1:
    # human response
    #user_in = "I watched a really good movie today."
    #user_in = "The weather was great for swimming today."
    user_in = input('User:')
    if user_in == 'exit()':
        break
    print(f'User: {user_in}')
    # topic extraction
    topic_xtract = generate_cot(user_in)
    print(f'==== cot: {topic_xtract.strip()}')

    topic_pref_profile = CoT_to_Preference(topic_xtract.strip())

    prev_tpxt = []
    # memory update
    for tpxt in topic_pref_profile:
        # add node if not in graph, else update it
        if tpxt not in db.nodes:
            db.add_node(tpxt, pref=topic_pref_profile[tpxt])
        else:
            db.nodes[tpxt]['pref'] = topic_pref_profile[tpxt]
            
        # add all links between nodes in chain if not already existing only if more than 1 node
        if len(topic_pref_profile) > 1 and len(prev_tpxt) >= 1:
            for pt in prev_tpxt:
                if (pt.split(',')[0], tpxt.split(',')[0]) not in db.edges:
                    db.add_edge(pt.split(',')[0], tpxt.split(',')[0])
        # prev_tpxt = tpxt
        prev_tpxt.append(tpxt)   
    print(f'==== nodes: {db.nodes}')

    # memory retrieve
    focus_topic = list(topic_pref_profile.keys())[-1]
    print(f'==== focus topic: {focus_topic}')
    print(f'==== db edges to focus topic: {db.edges([focus_topic])}')
    print(f'==== db all edges: {db.edges()}')
    print(f'==== db nodes: {db.nodes[focus_topic]}')
    print(f'==== pref: {topic_pref_profile}')

    xtract_prof = {}
    xtract_prof[focus_topic] = db.nodes[focus_topic]['pref']
    for x_nodes in db.edges([focus_topic]):
        xn = x_nodes[1]
        xtract_prof[xn] = db.nodes[xn]['pref']

    #xtract_prof = 
    print(f'==== xtract_prof: {xtract_prof}')


    # topic recommender
    num_sugg = 3
    prompt = f"Instruction: Generate only {num_sugg} similar topics that could be suggested for new conversation that takes influence from but are not present in the following user profile: {xtract_prof} In the generated answer, generate each of the suggested topics separated by a comma like so: TOPIC1,TOPIC2,TOPIC3,TOPIC4,etc.\nSuggested Topics:"
    topic_recs = generate_recommendation(prompt).split(',')
    print(f'==== topics recs: {topic_recs}')

    # template guideline generation
    if xtract_prof[focus_topic] == 'positive':
        tpref = 'The user likes'
    elif xtract_prof[focus_topic] == 'negative':
        tpref = 'The user dislikes'
    else:
        tpref = 'It is unclear if the user likes or dislikes'

    guideline = f'{tpref} {focus_topic}. Direct the conversation to one of the following 3 topics: {topic_recs}.'

    print(f'==== guideline: {guideline}')

    # response generate

    blend_in_ids = blen_tokenizer([f'{user_in} [GUIDELINE] {guideline}'], max_length=128, return_tensors='pt', truncation=True)
    blend_example = blen_model.generate(**blend_in_ids, max_length=60)
    blend_response = blen_tokenizer.batch_decode(blend_example, skip_special_tokens=True)[0]
    print(f'Bot: {blend_response}')
    time.sleep(2.5)

User: I watched football yesterday.
==== cot: sports,yes)|(football,yes)
==== nodes: ['sports', 'football']
==== focus topic: football
==== db edges to focus topic: [('football', 'sports')]
==== db all edges: [('sports', 'football')]
==== db nodes: {'pref': 'positive'}
==== pref: {'sports': 'positive', 'football': 'positive'}
==== xtract_prof: {'football': 'positive', 'sports': 'positive'}
==== topics recs: ['soccer', 'sportsmanship', 'sportsmanship']
==== guideline: The user likes football. Direct the conversation to one of the following 3 topics: ['soccer', 'sportsmanship', 'sportsmanship'].
Bot:  Football is the most popular sport in the world. Do you have a favorite team?
User: My favorite team is the Kansas City Chiefs.
==== cot: sports,yes)|(favorite team,yes)
==== nodes: ['sports', 'football', 'favorite team']
==== focus topic: favorite team
==== db edges to focus topic: [('favorite team', 'sports')]
==== db all edges: [('sports', 'football'), ('sports', 'favorite team')]
==== d