In [24]:
import os
import pandas as pd
import re
import backoff
import time
import json
from openai.error import RateLimitError
from importlib.metadata import version
from typing import List, Callable, Dict
from langchain.llms import OpenAI
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.chains.summarize import load_summarize_chain
from langchain.chains.question_answering import load_qa_chain
from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
from langchain.serpapi import SerpAPIWrapper
from transformers import GPT2TokenizerFast
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

In [2]:
os.environ["OPENAI_API_KEY"] = OAI_KEY
os.environ["SERPAPI_API_KEY"] = SERP_KEY
version('langchain')

'0.0.54'

In [3]:
d = pd.read_csv(FP + "legislation_relevance_dataset_for_llm_evaluation_unbalanced.csv")

In [4]:
print(d.shape[0])
d = d.drop_duplicates(subset=['company_name', 'summary_text'], keep='first')
print(d.shape[0])
d = d.reset_index(drop=True)

500
485


In [5]:
d.relevance.value_counts()

0    344
1    141
Name: relevance, dtype: int64

In [34]:
class Lobbyist:
    def __init__(self, data: pd.DataFrame):
        self.d = data
        self.model_name = "text-davinci-003"
        self.model_name_older = "text-davinci-002"
        self.temp = 0
        
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.text_splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=4000, 
                                                                        chunk_overlap=0, 
                                                                        separator = " ")
        self.llm = OpenAI(temperature=self.temp, model_name=self.model_name, max_tokens=-1)
        self.llm_older = OpenAI(temperature=self.temp, model_name=self.model_name_older, max_tokens=-1)
        
        ###### SUMMARY CHAIN ######
        self.sum_template = """Slightly summarize the following Congressional bill summary:
                                {context}
                                SUMMARY:"""
        self.sum_prompt = PromptTemplate(template=self.sum_template, input_variables=["context"])
        self.sum_combine_template = """Given the following summaries of parts of this bill, create a final summary of the bill. 
                                        =========
                                        SUMMARIES FROM PARTS OF THE OVERALL RESPONSE: {summaries}
                                        =========
                                        FINAL SUMMARY:"""
        self.sum_combine_prompt = PromptTemplate(template=self.sum_combine_template, 
                                                 input_variables=["summaries"])
        self.sum_chain = load_qa_chain(llm=self.llm, 
                                        chain_type="map_reduce", 
                                        question_prompt = self.sum_prompt,
                                        combine_prompt = self.sum_combine_prompt,
                                        collapse_prompt = self.sum_combine_prompt)
        
        ###### ZERO SHOT QA CHAIN ######
        self.zero_shot_template = """You are a lobbyist analyzing Congressional bills for their potential impacts on companies. 
                            Given the title and summary of the bill, plus information on the company from its 10K SEC filing, it is your job to determine if a bill is at least somewhat relevant to a company (in terms of whether it could impact the company if it was later enacted). 
                            Official title of bill: {official_title}
                            Official summary of bill: {summary_text}
                            Official subjects of bill: {subjects}
                            Company name: {company_name}
                            Company business description: {business_description}
                            Is this bill potentially relevant to this company? 
                            Answer in this format:
                            ANSWER: 'YES' or 'NO' (use all caps). EXPLANATION: the step-by-step reasoning you undertook to formulate a response. CONFIDENCE: integer between 0 and 100 for your estimate of confidence in your answer (1 is low confidence and 99 is high)
                            """
        self.zero_shot_temp  = PromptTemplate(template=self.zero_shot_template, input_variables = ["official_title", "summary_text", "subjects",
                                                                     "company_name", "business_description"])
        self.zero_shot_chain = LLMChain(llm=self.llm, prompt=self.zero_shot_temp)
        
        self.zero_shot_chain_older = LLMChain(llm=self.llm_older, prompt=self.zero_shot_temp)
        
        ###### AGENT QA CHAIN ######
        self.self_ask_agent_search = SerpAPIWrapper()
        
        self.self_ask_agent_tools = [
            Tool(
                name="Search",
                func=self.self_ask_agent_search.run,
                description="useful for learning more about a company or bill"
            )
        ]
        
        self.self_ask_agent_prefix = """You are a lobbyist analyzing Congressional bills for their impacts on companies. 
            Given the title and summary of the bill, plus information on the company from its 10K SEC filing, it is your job to determine if a bill is at least somewhat relevant to a company in terms of whether it could impact the company if it was enacted. 
            You have access to the following tools:"""
        self.self_ask_agent_suffix = """
            Official title of bill: {official_title}
            Official summary of bill: {summary_text}
            Official subjects of bill: {subjects}
            Company name: {company_name}
            Company business description: {business_description}
            Question: {input} 
            {agent_scratchpad}"""
        self.self_ask_prompt = ZeroShotAgent.create_prompt(
                    tools=self.self_ask_agent_tools,
                    prefix=self.self_ask_agent_prefix,
                    suffix=self.self_ask_agent_suffix,
                    input_variables=["official_title", "summary_text", "subjects",
                                     "company_name", "business_description" ,
                                     "input", "agent_scratchpad"]
                )
        self.self_ask_llm_chain = LLMChain(llm=self.llm, prompt=self.self_ask_prompt)
        self.self_ask_agent = ZeroShotAgent(llm_chain=self.self_ask_llm_chain, 
                                            tools=self.self_ask_agent_tools)
        self.self_ask_agent_executor = AgentExecutor.from_agent_and_tools(agent=self.self_ask_agent, 
                                                                          tools=self.self_ask_agent_tools, 
                                                                          max_iterations=3, 
                                                                          early_stopping_method="generate",
                                                                          verbose=True)
        
        ###### ZERO SHOT LETTER GENERATION ######
        self.zero_shot_letter_template = """You are a lobbyist analyzing Congressional bills for their impacts on companies. 
                            You have identified the following bill to be relevant to the company you work for.
                            Given the title and summary of the bill, plus information on your company from its 10K SEC filing, it is your job to now write a persuasive letter to sponsor of the bill to convince them to add provisions to the bill that would make it better for the bottom-line of your company.
                            Sign the letter as the general counsel of your company (put the actual name of your company).
                            Official title of bill: {official_title}
                            Official summary of bill: {summary_text}
                            Official subjects of bill: {subjects}
                            Company name: {company_name}
                            Company business description: {business_description}
                            YOUR LETTER: 
                            """
        self.zero_shot_letter_temp  = PromptTemplate(template=self.zero_shot_letter_template, 
                                              input_variables = ["official_title", "summary_text", "subjects",
                                                                "company_name", "business_description"])
        self.zero_shot_letter = LLMChain(llm=self.llm, prompt=self.zero_shot_letter_temp)
    
    def compute_length(self, text: str) -> int:
        tokenized_text = self.tokenizer.tokenize(text)
        nt = len(tokenized_text)
        return nt

    def compute_length_shorten(self, text: str) -> str:
        nt = self.compute_length(text)
        while nt > 4000:
            #print('LENGTH BEFORE SUMMARIZATION:')
            #print(nt)
            texts = self.text_splitter.split_text(text)
            docs = [Document(page_content=t) for t in texts]
            text = self.sum_chain.run(docs)
            nt = self.compute_length(text)
            #print('LENGTH AFTER:')
            #print(nt)
            #print(' ')
        return text

    def add_summary_col(self,):
        results = list()
        for i in range(self.d.shape[0]):
            s = self.compute_length_shorten(self.d.summary_text[i])
            results.append({"summary_text_summarized": s})
        self.d = pd.concat([self.d, pd.DataFrame(results)], axis=1)
    
    def convert_output(self, text: str) -> Dict:
        _out = dict()
        if "YES" in text:
            _out['prediction'] = 1
        elif "NO" in text:
            _out['prediction'] = 0
        elif "Yes" in text:
            _out['prediction'] = 1
        elif "No" in text:
            _out['prediction'] = 0
        elif "yes" in text:
            _out['prediction'] = 1
        elif "no" in text:
            _out['prediction'] = 0
        else:
            _out['prediction'] = -999
        
        match = re.search('CONFIDENCE: (\d+)', text)
        if match:
            _out['prob'] = int(match.group(1))
        else:
            _out['prob'] = -999
        
        return _out

    def dummy_predictor(self, i: int) -> str:
        return "NO and 100"
    
    @backoff.on_exception(backoff.expo, (RateLimitError))
    def zero_shot_predictor(self, i: int) -> str:
        return self.zero_shot_chain.predict(official_title = self.d.official_title[i], 
                                            summary_text = self.d.summary_text_summarized[i], 
                                            subjects = self.d.subjects[i], 
                                            company_name = self.d.company_name[i], 
                                            business_description = self.d.business_description[i])
    
    @backoff.on_exception(backoff.expo, (RateLimitError))
    def zero_shot_predictor_older(self, i: int) -> str:
        return self.zero_shot_chain_older.predict(official_title = self.d.official_title[i], 
                                            summary_text = self.d.summary_text_summarized[i], 
                                            subjects = self.d.subjects[i], 
                                            company_name = self.d.company_name[i], 
                                            business_description = self.d.business_description[i])
    
    @backoff.on_exception(backoff.expo, (RateLimitError))
    def agent_predictor(self, i: int) -> str:
        return self.self_ask_agent_executor.run(official_title= self.d.official_title[i], 
                                                summary_text=self.d.summary_text_summarized[i], 
                                                subjects=self.d.subjects[i], 
                                                company_name=self.d.company_name[i],
                                                business_description=self.d.business_description[i],
                                               input="Is this bill potentially relevant to this company? ('YES' or 'NO'; use all caps)")
    
    @backoff.on_exception(backoff.expo, (RateLimitError))
    def zero_shot_letter_generator(self, i: int) -> str:
        return self.zero_shot_letter.predict(official_title = self.d.official_title[i], 
                                            summary_text = self.d.summary_text_summarized[i], 
                                            subjects = self.d.subjects[i], 
                                            company_name = self.d.company_name[i], 
                                            business_description = self.d.business_description[i])
    
    def evaluate(self, func: Callable, 
                 n = int, m = int) -> List:
        results = list()
        for i in range(n, m):
            try:
                out = func(i)
                pred = self.convert_output(out)
                pp = pred['prediction']
                gt = self.d.relevance[i]

                if pp == -999:
                    correct = False
                    print("-999")
                elif pp == 2:
                    prediction = 1
                    correct = gt == pp
                else:
                    correct = gt == pp

                results.append({"text_response": out,
                                "numeric_prediction": pp,
                                "probability": pred['prob'],
                                "ground_truth": int(gt),
                                "correct": bool(correct)})
            except Exception as e:
                print(f"At data {i}, error: {e}")
                time.sleep(10)
                try:
                    out = func(i)
                    pred = self.convert_output(out)
                    pp = pred['prediction']
                    gt = self.d.relevance[i]

                    if pp == -999:
                        correct = False
                        print("-999")
                    elif pp == 2:
                        prediction = 1
                        correct = gt == pp
                    else:
                        correct = gt == pp

                    results.append({"text_response": out,
                                    "numeric_prediction": pp,
                                    "probability": pred['prob'],
                                    "ground_truth": int(gt),
                                    "correct": bool(correct)})
                except Exception as e2:
                    print(f"At data {i}, error 2: {e2}")
                    results.append({"text_response": "NA",
                                    "numeric_prediction": "NA",
                                    "probability": "NA",
                                    "ground_truth": "NA",
                                    "correct": "NA"})
            finally:
                with open(FP + "gpt_data.json", "wt") as output_file:
                    json.dump(results, output_file)
        
        return results

In [35]:
l = Lobbyist(d)
l.add_summary_col()

In [8]:
dum_result = l.evaluate(l.dummy_predictor, 0, l.d.shape[0])

In [9]:
dum = pd.DataFrame(dum_result)
print(dum.correct.mean())

0.709278350515464


In [10]:
result_chain_older = l.evaluate(l.zero_shot_predictor_older, 0, l.d.shape[0])

In [11]:
result_chain_older = pd.DataFrame(result_chain_older)
result_chain_older.numeric_prediction.value_counts()

1    349
0    136
Name: numeric_prediction, dtype: int64

In [12]:
result_chain_older[['probability', 'numeric_prediction']].value_counts()

probability  numeric_prediction
 80          1                     308
 1           0                      58
-999         0                      37
 99          0                      25
-999         1                      23
 95          1                      11
 100         0                       5
 60          1                       4
 80          0                       4
 95          0                       3
 60          0                       2
 70          1                       2
 50          0                       1
 90          0                       1
             1                       1
dtype: int64

In [48]:
print(result_chain_older.text_response[6])



ANSWER: YES
EXPLANATION: The bill is relevant to Alkermes Plc because it could impact the company's products ARISTADA and VIVITROL.
CONFIDENCE: 80


In [14]:
print(result_chain_older.correct.mean())
print(result_chain_older[result_chain_older.ground_truth == 1].correct.mean())
print(result_chain_older[result_chain_older.ground_truth == 0].correct.mean())
result_chain_older_metadata = pd.concat([l.d, result_chain_older], axis=1)
print(precision_score(result_chain_older_metadata.ground_truth, result_chain_older_metadata.numeric_prediction))
print(recall_score(result_chain_older_metadata.ground_truth, result_chain_older_metadata.numeric_prediction))

0.5175257731958763
0.9078014184397163
0.35755813953488375
0.3667621776504298
0.9078014184397163


In [61]:
print(result_chain_older.shape[0])
print(result_chain_older[result_chain_older.probability > 90].correct.mean())
print(result_chain_older[result_chain_older.probability > 90].shape[0])
print(result_chain_older[result_chain_older.probability <= 90].correct.mean())

485
0.75
44
0.4943310657596372


In [36]:
result_chain = l.evaluate(l.zero_shot_predictor, 0, l.d.shape[0])

Error: The server is overloaded or not ready yet.


In [37]:
result_chain = pd.DataFrame(result_chain)
result_chain.numeric_prediction.value_counts()

0    364
1    121
Name: numeric_prediction, dtype: int64

In [38]:
result_chain[['probability', 'numeric_prediction']].value_counts()

probability  numeric_prediction
 99          0                     345
 95          1                      66
 90          1                      52
 95          0                      14
 100         0                       5
 99          1                       2
-999         1                       1
dtype: int64

In [47]:
print(result_chain.text_response[6])


ANSWER: YES. EXPLANATION: Alkermes Plc develops and commercializes products designed to address the unmet needs of patients suffering from addiction and schizophrenia, which are both addressed in the bill. Additionally, the bill requires the Centers for Medicare & Medicaid Services (CMS) to negotiate with pharmaceutical companies regarding prices for drugs covered under the Medicare prescription drug benefit, which could potentially impact Alkermes Plc. CONFIDENCE: 95


In [40]:
print(result_chain.correct.mean())
print(result_chain[result_chain.ground_truth == 1].correct.mean())
print(result_chain[result_chain.ground_truth == 0].correct.mean())
result_chain_metadata = pd.concat([l.d, result_chain], axis=1)
print(precision_score(result_chain_metadata.ground_truth, result_chain_metadata.numeric_prediction))
print(recall_score(result_chain_metadata.ground_truth, result_chain_metadata.numeric_prediction))

0.7525773195876289
0.5035460992907801
0.8546511627906976
0.5867768595041323
0.5035460992907801


In [60]:
print(result_chain.shape[0])
print(result_chain[result_chain.probability > 90].correct.mean())
print(result_chain[result_chain.probability > 90].shape[0])
print(result_chain[result_chain.probability <= 90].correct.mean())

485
0.7870370370370371
432
0.4716981132075472


In [41]:
letters = list()
for i in result_chain[result_chain.numeric_prediction == 1].index:
    letters.append(l.zero_shot_letter_generator(i))

KeyboardInterrupt: 

In [43]:
print(letters[0])


Dear [Sponsor of the Bill],

I am writing on behalf of Alkermes Plc, a fully integrated, global biopharmaceutical company that applies its scientific expertise and proprietary technologies to research, develop and commercialize pharmaceutical products that are designed to address unmet medical needs of patients in major therapeutic areas.

We are writing to express our support for the Medicare Negotiation and Competitive Licensing Act of 2019. We believe that this bill is an important step in ensuring that Medicare beneficiaries have access to the medications they need at a price they can afford.

We are particularly supportive of the provisions in the bill that would require the Centers for Medicare & Medicaid Services (CMS) to negotiate with pharmaceutical companies regarding prices for drugs covered under the Medicare prescription drug benefit. We believe that this will help to ensure that the prices of these drugs are fair and reasonable.

We are also supportive of the provisions 

In [50]:
result_agent = l.evaluate(l.agent_predictor, 0, l.d.shape[0])



[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Thought: I need to understand the bill and the company
Action: Search
Action Input: Hualapai Tribe Water Rights Settlement Act of 2019, BJ's Restaurants, Inc.[0m
Observation: [36;1m[1;3mHualapai Tribe Water Rights Settlement Act of 2019. This bill modifies and ratifies the Hualapai Tribe water rights settlement agreement negotiated between ...[0m
Thought:[32;1m[1;3m I need to understand the potential impacts of the bill on the company
Action: Search
Action Input: Hualapai Tribe Water Rights Settlement Act of 2019, BJ's Restaurants, Inc.[0m
Observation: [36;1m[1;3mHualapai Tribe Water Rights Settlement Act of 2019. This bill modifies and ratifies the Hualapai Tribe water rights settlement agreement negotiated between ...[0m
Thought:[32;1m[1;3m I need to understand the potential impacts of the bill on the company's operations
Action: Search
Action Input: Hualapai Tribe Water Rights Settlement Act of 2019, BJ's Rest

In [51]:
result_agent = pd.DataFrame(result_agent)
result_agent.numeric_prediction.value_counts()

1     298
0     185
NA      2
Name: numeric_prediction, dtype: int64

In [53]:
print(result_agent[result_agent.numeric_prediction != "NA"].correct.mean())

0.5300207039337475
