In [None]:
import json
import re
import pandas as pd
pd.set_option('display.max_rows', 150)
pd.set_option('display.max_colwidth', None)

from transformers import AutoTokenizer
from tqdm import tqdm
from nltk import sent_tokenize

In [None]:
# Llama-2 tokenizer, used in this notebook only to measure how many tokens a
# piece of text occupies (see llama_seq_len).
llama_tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    model_max_length=2048,
    padding_side="right",
    use_fast=False,
)

def llama_seq_len(text):
    """Return the number of ids the Llama tokenizer yields for ``text``, plus 6.

    NOTE(review): the fixed +6 presumably reserves room for tokens added
    around the text further downstream -- confirm what it is meant to cover.
    """
    token_ids = llama_tokenizer(text)['input_ids']
    return 6 + len(token_ids)

In [None]:
# segment_topic_pairs = []
# for call in new_calls:
#     for segment in call['segments']:
#         segment_topic_pairs.append((segment['text'], segment['topics']))

# segment_topic_df = pd.DataFrame(segment_topic_pairs, columns=['text', 'topics'])

In [None]:
# Load the enriched per-domain company records. JSON text is UTF-8, so decode
# it explicitly instead of relying on the platform's default encoding.
with open('/home/ubuntu/speech_to_text/data/enriched_domains_with_empty_values_filled.json', "r", encoding="utf-8") as f:
    enriched_domains = json.load(f)

# Index the records by their 'domain' value for direct lookup. If two records
# share a domain, the later one silently replaces the earlier one.
domain_dict = {record['domain']: record for record in enriched_domains}

In [None]:
# Load the per-call category assignments (looked up by call id when examples
# are built). Decode explicitly as UTF-8 rather than using the platform default.
with open("/home/ubuntu/speech_to_text/data/matched_categories_simplified.json", "r", encoding="utf-8") as f:
    category_dict = json.load(f)

# Sanity check: number of entries loaded.
len(category_dict)

In [None]:
# Sentence added to the system message for each recognised call category.
# Categories that are not listed here contribute no line.
category_line_map = {
    # Use-case categories
    "Lifecycle Marketing": "They want to use Hightouch to make their lifecycle marketing more effective.",
    "Performance Marketing": "They want to use Hightouch to make their performance marketing more effective.",
    "B2B SaaS (PLG)": "They want to use Hightouch to give more insights and lead prioritization to their sales team.",
    # Competitive-evaluation categories
    "CDP Compete": "They are evaluating Hightouch versus traditional CDPs like Segment, MParticle, and Simon Data.",
    "Census Compete": "They are evaluating Hightouch versus other Reverse ETL providers like Census.",
    "LiveRamp Compete": "They are evaluating Hightouch versus other data enrichment and ad performance boosting solutions like LiveRamp.",
}

In [None]:
# Common speech-to-text renderings of the product name ("high tech",
# "hi-touch", "high tension", ...). The \b anchors keep the pattern from firing
# inside longer words: without them "high technology" became "Hightouchnology"
# and "thigh tension" became "tHightouch".
_HIGHTOUCH_MISHEARD = re.compile(
    r"\bhi(?:gh)?[\s-]?(?:touch|tech|treasure|tax|test|tips|tension|towers|tex)\b",
    flags=re.IGNORECASE,
)


def clean_text(text):
    """Normalise mis-transcribed variants of "Hightouch" and trim whitespace.

    Args:
        text: Transcript text for one segment.

    Returns:
        ``text`` with every whole-word match of the mis-hearing pattern
        replaced by "Hightouch", stripped of leading/trailing whitespace.
    """
    return _HIGHTOUCH_MISHEARD.sub("Hightouch", text).strip()


def get_info_record(company_info):
    """Reduce a call's company-info entries to a name/industry/description record.

    Args:
        company_info: Iterable of dicts, each carrying a 'name' key (the field
            label) and a 'value' key.

    Returns:
        dict with keys 'name', 'industry' and 'description'. 'industry' and
        'description' default to None when nothing was found for them.

    Raises:
        KeyError: if ``company_info`` contains no entry labelled "Name".
    """
    info_record = {}
    for elem in company_info:
        field = elem['name']
        value = elem['value']
        if field == "Name" and field in info_record:
            current = info_record[field]
            # Prefer the shorter of two usable names. An empty or missing
            # earlier value must not block a usable later one (previously the
            # first "Name" entry stuck even when it was empty).
            if value and (not current or len(value) < len(current)):
                info_record[field] = value
        else:
            # Any other field (and the first "Name"): last value wins.
            info_record[field] = value

    # Enrich from the per-domain records when the website is known there.
    if 'Website' in info_record and info_record['Website'] in domain_dict:
        source_record = domain_dict[info_record['Website']]
        # Treat a missing or null nested record as empty so the lookups below
        # cannot raise on incomplete enrichment data.
        organization = source_record.get('organization') or {}
        category = source_record.get('category') or {}

        if 'description' in source_record:
            info_record['description'] = source_record['description']
        elif 'short_description' in organization:
            info_record['description'] = organization['short_description']

        if 'Industry' not in info_record:
            if 'category' in source_record:
                info_record['Industry'] = category.get('industry')
            elif 'organization' in source_record:
                info_record['Industry'] = organization.get('industry')

    info_record.setdefault('Industry', None)
    info_record.setdefault('description', None)

    description = info_record['description']
    # A trailing ellipsis means the text was cut off: drop the last
    # (incomplete) sentence.
    if description and description.strip().endswith('...'):
        description = " ".join(sent_tokenize(description)[:-1])

    # Keep only the first paragraph.
    if description:
        description = description.split('\n\n')[0].strip()
    info_record['description'] = description

    filtered_record = {
        "name": info_record['Name'],
        "industry": info_record['Industry'],
        "description": info_record['description'],
    }

    return filtered_record


def clean_description(description, max_desc_len=80, margin=10):
    """Trim a company description so it fits a token budget.

    Args:
        description: Raw description text; falsy values are returned as-is.
        max_desc_len: Token count (per ``llama_seq_len``) above which the
            description is shortened.
        margin: Extra tokens tolerated on top of ``max_desc_len`` while whole
            sentences are being accumulated.

    Returns:
        The cleaned description, or None if nothing is left after cleaning.
    """
    if not description:
        return description

    text = description.strip()

    # A trailing ellipsis marks truncated text: discard the last sentence.
    if text.endswith('...'):
        text = " ".join(sent_tokenize(text)[:-1])

    # Only the first paragraph is kept.
    text = text.split('\n\n')[0].strip()

    if llama_seq_len(text) > max_desc_len:
        budget = max_desc_len + margin
        kept = []
        used = 0
        # Take leading sentences until the next one would exceed the budget.
        for sentence in sent_tokenize(text):
            cost = llama_seq_len(sentence)
            if used + cost > budget:
                break
            kept.append(sentence)
            used += cost
        text = " ".join(kept)

    return text if text else None



def get_system_message(company_info, categories):
    """Build the system prompt that frames a call for one prospect company.

    Args:
        company_info: dict with 'name', 'industry' and 'description' keys.
        categories: Iterable of category labels; only labels present in
            ``category_line_map`` add a bullet line.

    Returns:
        The stripped system message: the fixed scenario text, then the cleaned
        company description and the category bullets when they are non-empty.
    """
    bullets = []
    for category in categories:
        if category in category_line_map:
            bullets.append('- ' + category_line_map[category])
    category_lines = "\n".join(bullets)

    # Pick "a" vs "an" for the industry phrase; no industry means plain "a company".
    industry = company_info['industry']
    industry_starts_with_vowel = bool(industry) and industry[0].lower() in ('a', 'e', 'i', 'o', 'u')

    system_message = f"A sales call between a sales representative at Hightouch, a data integration company, and a prospect at {company_info['name']}, a{'n' if industry_starts_with_vowel else ''} {company_info['industry'].lower() + ' ' if company_info['industry'] else ''}company. Hightouch helps companies get data into business tools. The prospect's goal is to understand how Hightouch compares to their current solution for syncing data into business tools and how Hightouch can deliver value for the prospect's business. The prospect answers any questions that the sales rep asks them. The prospect also asks questions to the sales rep in order to learn more about Hightouch and how it can help the prospect's business."

    cleaned_description = clean_description(company_info['description'])
    if cleaned_description:
        system_message += f"\n\nHere is a description of {company_info['name']}:\n{cleaned_description}"

    if category_lines:
        system_message += f"\n\nHere is some additional information about how {company_info['name']} is considering to use Hightouch:\n{category_lines}"

    return system_message.strip()


def get_example_from_call(call):
    """Convert one transcribed call into a training example.

    Args:
        call: dict with 'metadata' (providing 'id' and 'title'), 'segments'
            (each with 'topics', 'speakerAffiliation' and 'text') and
            'company_info'.

    Returns:
        dict with keys 'id', 'title', 'system_message' and 'conversations',
        where 'conversations' is a list of {'from', 'value'} turns with
        consecutive segments by the same speaker merged into one turn.
    """
    call_id = call['metadata']['id']

    # Segments tagged *only* with these topics are dropped. The label was
    # originally written as 'Next Steps - Scheduliing'; that spelling is kept
    # in case the upstream labels really carry it (TODO confirm), and the
    # correctly spelled form is accepted as well so scheduling chatter is
    # excluded either way.
    bad_topics = {
        'Small Talk',
        'Wrap-Up',
        'Next Steps - Scheduliing',
        'Next Steps - Scheduling',
    }

    conversations = []
    for segment in call['segments']:
        segment_topics = set(segment['topics'])
        # Untagged segments are kept; tagged ones are skipped only when every
        # one of their topics is in the excluded set.
        if segment_topics and segment_topics.issubset(bad_topics):
            continue

        speaker = "sales rep" if segment['speakerAffiliation'] == "Internal" else "prospect"
        text = clean_text(segment['text'])
        if conversations and conversations[-1]['from'] == speaker:
            # Same speaker as the previous turn: extend that turn.
            conversations[-1]['value'] += " " + text
        else:
            conversations.append({
                "from": speaker,
                "value": text
            })

    info_record = get_info_record(call['company_info'])
    # Calls without a category assignment simply get no category lines.
    categories = category_dict.get(call_id, [])
    system_message = get_system_message(info_record, categories)

    return {
        "id": call_id,
        "title": call['metadata']['title'],
        "system_message": system_message,
        "conversations": conversations
    }

In [None]:
# Load the re-transcribed intro calls. Decode explicitly as UTF-8 so the
# transcript text does not depend on the platform's default encoding.
with open("/home/ubuntu/speech_to_text/data/intro_calls_retranscribed_v2.json", "r", encoding="utf-8") as f:
    new_calls = json.load(f)

# Sanity check: number of calls loaded.
len(new_calls)

In [None]:
# Turn every loaded call into one training example, preserving call order.
dataset = list(map(get_example_from_call, new_calls))

In [None]:
# Spot-check the generated system message of one example.
print(dataset[3]['system_message'])

In [None]:
# Pair each system message with its length as measured by llama_seq_len.
system_messages = [
    (message, llama_seq_len(message))
    for message in (example['system_message'] for example in dataset)
]


In [None]:
# Histogram of system-message lengths (column 1 holds the length of each
# (message, length) pair).
pd.DataFrame(system_messages)[1].plot.hist()

In [None]:
# Show the twenty longest system messages, longest first.
pd.DataFrame(system_messages).sort_values(by=1, ascending=False).head(20)

In [None]:
# Number of examples in the dataset.
len(dataset)

In [None]:
# Write the full list of examples out as a single JSON file.
with open("/home/ubuntu/FastChat/data/prospect_lm/prospect_lm_v1_intro_calls_full.json", "w+") as f:
    json.dump(dataset, f)