In [1]:
import os
import time

import requests
import together
from langchain.llms.base import LLM

class TogetherLLM(LLM):
    """LangChain ``LLM`` that sends prompts to the Together ``/inference`` REST endpoint."""

    model: str = "togethercomputer/llama-2-13b-chat"
    """model endpoint to use"""

    together_api_key: str = os.environ.get('TOGETHERAI_API_KEY', '')
    """Together API key (defaults to TOGETHERAI_API_KEY, read once when this class is defined)"""

    temperature: float = 0.0
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    max_attempts: int = 10
    """How many times to try the HTTP request before giving up."""

    request_timeout: float = 60.0
    """Seconds to wait for an HTTP response before treating the attempt as failed."""

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to Together and return the text of the first choice.

        Extra keyword arguments (for example ``stop``) are accepted but are not
        forwarded to the API.

        Raises:
            ValueError: if no API key is configured.
            Exception: if every attempt fails (network error, timeout, non-2xx
                status, or a body without ``output.choices[0].text``).
        """
        if not self.together_api_key:
            raise ValueError(
                "No Together API key: set TOGETHERAI_API_KEY or pass together_api_key."
            )

        endpoint = 'https://api.together.xyz/inference'

        print("model =", self.model)
        print("temperature =", self.temperature)
        print("max_tokens=", self.max_tokens)

        payload = {
            "prompt": prompt,
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        headers = {
            "Authorization": f"Bearer {self.together_api_key}",
            "User-Agent": "<YOUR_APP_NAME>",
        }

        for attempt in range(self.max_attempts):
            try:
                res = requests.post(
                    endpoint, json=payload, headers=headers, timeout=self.request_timeout
                )
                # Report 4xx/5xx directly instead of failing later on a missing JSON key.
                res.raise_for_status()
                return res.json()['output']['choices'][0]['text']
            except Exception as e:
                print(e)
                # Capped exponential backoff so transient rate limits / outages can
                # clear; no pointless sleep after the final attempt.
                if attempt < self.max_attempts - 1:
                    time.sleep(min(2 ** attempt, 30))

        raise Exception(f"Request did not succeed with prompt = {prompt}")


In [2]:
import json
import pandas as pd
from langchain import PromptTemplate
import re

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report
)

class Utility:
    """Static helpers: Llama-2 chat prompt building, response parsing, metrics and CSV I/O."""

    B_CHAT, E_CHAT = "<s>", "</s>"
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    @staticmethod
    def _build_prompt_text(system_message, user_message, history):
        """Return the raw Llama-2 chat string for ``history`` followed by ``user_message``.

        Each ``(query, response)`` pair in ``history`` is rendered as
        ``<s>[INST] query [/INST] response</s>``; the new user turn is left open
        as ``<s>[INST] user_message [/INST]``. The ``<<SYS>>`` block is placed
        in front of the very first query only.
        """
        system_prompt = f"{Utility.B_SYS}{system_message}{Utility.E_SYS}"

        parts = []
        for position, (query, response) in enumerate(history):
            leading = system_prompt if position == 0 else ""
            parts.append(
                f"{Utility.B_CHAT}{Utility.B_INST} {leading}{query} {Utility.E_INST}"
                f" {response}{Utility.E_CHAT}"
            )

        # With no prior turns, the system block belongs to the new user turn.
        leading = "" if parts else system_prompt
        parts.append(f"{Utility.B_CHAT}{Utility.B_INST} {leading}{user_message} {Utility.E_INST}")
        return "".join(parts)

    @staticmethod
    def get_prompt(system_message, user_message, input_variables, history):
        """Build a LangChain ``PromptTemplate`` in Llama-2 chat format.

        Args:
            system_message: text placed inside the ``<<SYS>>`` block.
            user_message: the new user turn; may contain template placeholders.
            input_variables: placeholder names passed to ``PromptTemplate``.
            history: iterable of ``(query, response)`` pairs for prior turns.
        """
        prompt_template = Utility._build_prompt_text(system_message, user_message, history)
        return PromptTemplate(template=prompt_template, input_variables=input_variables)

    @staticmethod
    def extract_type_from_response(response, search_term):
        """Map a free-text LLM answer to ``'@'`` (not in category) or ``'#'`` (in category).

        Negative phrasings are checked first, then positive phrasings, and as a
        last resort the final ``@``/``#`` character in the response is used.
        Matching is case-insensitive; ``search_term`` is matched literally.

        Raises:
            Exception: if no pattern matches and the response contains neither symbol.
        """
        # Escape so terms containing regex metacharacters are matched literally.
        term = re.escape(search_term.lower())
        text = response.lower()

        does_not_include_patterns = [
            rf"that tweet as it is not {term}",
            rf"(the|this|the given) tweet.*is not {term}",
            rf"i.*classify (it|the tweet) as not {term}",
            rf"(?:the|this|the given)?\s*(?:tweet|it) does not fall (into|under) the (?:category of )?{term}",
            rf"it is therefore classified as not {term}",
            rf"i would classify the (given )?tweet as not {term}",
        ]

        includes_patterns = [
            rf"(the|this|the given) tweet.*is {term}",
            rf"i.*classify (it|the tweet) as {term}",
            rf"(?:the|this|the given)?\s*(?:tweet|it) (falls into|falls under|under) the (?:category of )?{term}",
            rf"this is a {term} claim as it is based on",
            rf"the tweet is classified as {term}",
            rf"this claim can be verified.*{term}",
            rf"therefore, it can be classified as {term}",
            rf"this is a direct statement about.*{term}",
            rf"so it falls under the category of {term}",
            rf"the tweet is reporting on a {term} fact",
            rf"this tweet contains a direct statement about.*{term}",
        ]

        for pattern in does_not_include_patterns:
            if re.search(pattern, text):
                print("pattern does not include = ", pattern)
                return '@'

        for pattern in includes_patterns:
            if re.search(pattern, text):
                print("pattern include = ", pattern)
                return '#'

        # Fall back to the last verdict symbol the model emitted.
        for c in response[::-1]:
            if c == '@' or c == '#':
                print("Inside last search!")
                return c

        raise Exception(f"No @ or # found in response = {response}")

    @staticmethod
    def calculate_metrics(ground_truth, predicted):
        """Return accuracy and weighted-average precision/recall/F1 (as percentages) plus the confusion matrix."""
        clsf_report = classification_report(y_true=ground_truth, y_pred=predicted, output_dict=True)
        cf_matrix = confusion_matrix(ground_truth, predicted)

        weighted = clsf_report['weighted avg']
        accuracy = accuracy_score(ground_truth, predicted)

        return {
            "Accuracy": accuracy * 100,
            "Precision": weighted['precision'] * 100,
            "Recall": weighted['recall'] * 100,
            "F1": weighted['f1-score'] * 100,
            "Confusion Matrix": cf_matrix
        }

    @staticmethod
    def get_tweet_data(file_name):
        """Load the tweet CSV, using its first column as the DataFrame index."""
        df = pd.read_csv(file_name, index_col=0)
        return df

    @staticmethod
    def write_prediction_output(tweet_objects, file_name_to_write):
        """Write ``tweet_objects`` to CSV, removing any existing file at that path first."""
        if os.path.exists(file_name_to_write):
            os.remove(file_name_to_write)

        tweet_objects.to_csv(file_name_to_write)


In [3]:
class Default(dict):
    """A ``dict`` whose missing keys render back as their own ``{key}`` placeholder.

    Lookups of absent keys do not insert anything; they just return the
    placeholder text, so e.g. ``str.format_map`` leaves unknown fields intact.
    """

    def __missing__(self, key):
        placeholder = "{{{}}}".format(key)
        return placeholder

In [29]:
from langchain import LLMChain

class Category:
    """Base class for LLM-driven binary classification of tweets into one category.

    Subclasses are expected to set ``self.system_message`` (the system prompt,
    which may use the ``{delimiter}`` placeholder) and ``self.tweet_objects``
    (a pandas DataFrame of tweets) before the methods below are used.
    """

    INPUT_VARIABLES = ["delimiter", "tweet"]

    DELIMITER = "```"

    CATEGORY_DESCRIPTIONS = {1: "Scientifically Verifiable"}

    def __init__(self, category_type, llm):
        self.category_type = category_type

        self.llm = llm
        # Conversation history; not read by the methods in this class
        # (does_tweet_fall_into_category always builds a zero-history prompt).
        self.history = []

    def does_tweet_fall_into_category(self, tweet, search_term):
        """Ask the LLM to classify ``tweet`` and return ``'@'`` (no) or ``'#'`` (yes).

        Raises whatever ``Utility.extract_type_from_response`` raises when the
        response cannot be interpreted.
        """
        print("Inside does_tweet_fall_into_category function")

        user_message = """
         Classify the following tweet:
         Tweet: {delimiter} {tweet} {delimiter}
         """

        prompt = Utility.get_prompt(self.system_message, user_message, Category.INPUT_VARIABLES, [])

        chain = LLMChain(llm=self.llm, prompt=prompt)

        input_values = {"tweet": tweet, "delimiter": Category.DELIMITER}

        response = chain.run(input_values)

        print("Response =", response)

        response_number = Utility.extract_type_from_response(response, search_term)

        return response_number

    def _ensure_range_not_computed(self, prediction_column, start_index, end_index):
        """Raise if any row labelled in [start_index, end_index] already holds a prediction (!= -1)."""
        predictions = self.tweet_objects[prediction_column]
        for index in range(start_index, end_index + 1):
            if index in predictions.index and predictions.loc[index] != -1:
                raise Exception(
                    "Some of the indices for the specified range have already been computed."
                )

    def _ground_truth_for(self, index, ground_truth_column):
        """Return the integer label of row ``index``, or None if it is missing / not an integer."""
        try:
            return int(self.tweet_objects.loc[index, ground_truth_column])
        except (TypeError, ValueError):
            return None

    def _predict_with_retries(self, tweet, search_term, attempts=10):
        """Classify ``tweet``; return 0 for '@', 1 otherwise, or None if every attempt raised."""
        for attempt in range(attempts):
            try:
                symbol = self.does_tweet_fall_into_category(tweet, search_term)
            except Exception as e:
                print("Classification attempt", attempt + 1, "failed:", e)
                continue
            return 0 if symbol == '@' else 1
        return None

    def generate_cat_metrics(self, output_file_name, tweet_content_column="polished_text",
                             start_index=201, end_index=400):
        """Classify tweets whose index labels lie in [start_index, end_index] and score them.

        Predictions are stored in the ``predicted_<category_type>`` column of
        ``self.tweet_objects`` (created with -1 if absent), the whole frame is
        written to ``output_file_name``, and metrics comparing the
        ``cat<category_type>`` column against the predictions are returned.

        Rows whose ground truth is missing or non-integer are skipped.

        Raises:
            Exception: if a row in the range already has a prediction, or if the
                LLM could not produce a prediction for some tweet.
        """
        ground_truths = []
        predicted_outputs = []

        print("<======= Generating metrics for category type =", self.category_type, "=======>")
        print()

        prediction_column = f"predicted_{self.category_type}"
        ground_truth_column = f"cat{self.category_type}"
        search_term = Category.CATEGORY_DESCRIPTIONS[self.category_type]

        # Create the prediction column on first use; otherwise refuse to
        # overwrite predictions already made for this range.
        if prediction_column not in self.tweet_objects:
            self.tweet_objects[prediction_column] = -1
        else:
            self._ensure_range_not_computed(prediction_column, start_index, end_index)

        for index in range(start_index, end_index + 1):
            # `index` is an index label; reads and writes below both use .loc.
            if index not in self.tweet_objects.index:
                continue

            print("Processing tweet with index# =", index)
            tweet = self.tweet_objects.loc[index, tweet_content_column]
            print("Tweet content = ", tweet)

            ground_truth = self._ground_truth_for(index, ground_truth_column)
            if ground_truth is None:
                print("Skipping index# =", index, "because its ground truth is missing or not an integer")
                continue

            predicted_output = self._predict_with_retries(tweet, search_term)
            if predicted_output is None:
                print("None for index# = ", index, "and tweet content =", tweet)
                raise Exception("Did not get predicted output for tweet")

            ground_truths.append(ground_truth)
            predicted_outputs.append(predicted_output)

            # Lists are non-empty here, so the metric computation is well-defined.
            if index % 5 == 0:
                print("Metrics till now =", Utility.calculate_metrics(ground_truths, predicted_outputs))

            self.tweet_objects.loc[index, prediction_column] = predicted_output

            print("Ground truth =", ground_truth, "Predicted output =", predicted_output)
            print("Finished Processing tweet with index# =", index)
            print()

        print("<======= Finished generating metrics for category type =", self.category_type, "=======>")

        print("Ground truths = ", ground_truths)
        print("Predictions = ", predicted_outputs)

        Utility.write_prediction_output(self.tweet_objects, output_file_name)

        return Utility.calculate_metrics(ground_truths, predicted_outputs)


In [30]:
class Category1(Category):
    """Category 1 classifier: decides whether a COVID-19 tweet is scientifically verifiable."""

    CATEGORY_TYPE = 1
    CATEGORY_DESCRIPTION = ""

    # CARP-style system prompt; {delimiter} is filled in by the prompt template.
    _SYSTEM_MESSAGE = """
        Imagine you're a COVID-19 tweets classifier using the Clue And Reasoning Prompting (CARP) approach to discern scientifically verifiable claims from tweets.
        
        Tweets will be encased within {delimiter} characters. 
        
        Apply the following steps:

        1. CLUE IDENTIFICATION: Determine the presence of direct statements, reports, or factual claims related to COVID-19 using keywords and context within the tweet.

        2. REASONING PROCESS: Analyze the clues to ascertain if they align with the scientific facts, data, or reputable health authority guidelines (Limit your reasoning to 130 words).

        3. VERIFICATION DETERMINATION: Decide if the tweet's claim is scientifically verifiable, based on the evidence and reasoning.
        
        Example 1:
        TWEET: "New research indicates that COVID-19 can remain on surfaces for days."
        CLUES: "New research," "COVID-19," "remain on surfaces," "days."
        REASONING: The claim is presented as a finding from new research, which is a direct statement about the virus's transmission and is likely based on scientific studies.
        VERIFICATION: # (Scientifically Verifiable)

        Example 2:
        TWEET: "Most COVID-19 infections are mild and don't require hospitalization."
        CLUES: "Most," "COVID-19 infections," "mild," "don't require," "hospitalization."
        REASONING: The statement makes a direct claim about the nature of COVID-19 infections, suggesting a general trend in symptoms and treatment requirements. This claim can be verified against statistical data from health authorities regarding the proportion of cases requiring hospitalization.
        VERIFICATION: # (Scientifically Verifiable)

        Example 3:
        TWEET: "Our city has reported zero new cases of COVID-19 today."
        CLUES: "Our city," "reported," "zero new cases," "COVID-19," "today."
        REASONING: This is a direct report concerning COVID-19 cases, which can be verified with health department data and is a factual claim about the virus.
        VERIFICATION: # (Scientifically Verifiable)

        Example 4:
        TWEET: "How long does the virus stay airborne?"
        CLUES: "How long," "virus," "stay airborne."
        REASONING: This is a question rather than a claim, and it does not provide a direct statement or fact that can be scientifically verified.
        VERIFICATION: @ (Not Scientifically Verifiable)
        
        Example 5:
        TWEET: "The government's response to COVID-19 will surely boost the economy."
        CLUES: "government's response," "COVID-19," "boost," "economy."
        REASONING: The input is making a speculative assertion about the impact of the government's COVID-19 response on the economy. While it relates to COVID-19, it is framed as a prediction rather than a fact and does not directly pertain to the scientific aspects of the virus itself. The claim is about economic impact, which is outside the scope of scientific verification as per the guidelines.
        VERIFICATION: @ (Not Scientifically Verifiable)

        Indicate your conclusion with a single symbol: use @ for non-verifiable claims or # for scientifically verifiable claims.
        """

    def __init__(self, llm, input_file_name):
        # Prompt and tweet table are attached before delegating to the base initializer.
        prompt_text, tweets = Category1.generate_system_prompt_for_category1(input_file_name)
        self.system_message = prompt_text
        self.tweet_objects = tweets

        super().__init__(Category1.CATEGORY_TYPE, llm)

    @staticmethod
    def generate_system_prompt_for_category1(input_file_name):
        """Return ``(system_message, tweet_objects)``: the category-1 prompt and the loaded tweet CSV."""
        return Category1._SYSTEM_MESSAGE, Utility.get_tweet_data(input_file_name)

In [31]:
# Llama-2 chat model size on Together; also used to tag the output file.
model_parameters = "70b"

llm = TogetherLLM(
    model=f"togethercomputer/llama-2-{model_parameters}-chat",
    temperature=0.2,
    max_tokens=1500
)

input_file_name = "tweets - original.csv"
# Must be an f-string so the model size actually appears in the file name
# (previously the literal text "{model_parameters}" was written).
output_file_name = f"updated_tweets_cat_1-{model_parameters}.csv"

In [32]:
# Build the category-1 classifier from the configured LLM and tweet CSV.
cat1 = Category1(llm=llm, input_file_name=input_file_name)