In [1]:
import os
import re
import time

import requests
import together
from langchain.llms.base import LLM

class TogetherLLM(LLM):
    """LangChain ``LLM`` that sends prompts to the Together inference REST API.

    ``_call`` POSTs the prompt plus sampling settings to
    ``https://api.together.xyz/inference`` and returns the text of the first
    completion choice, retrying failed requests with exponential backoff.
    """

    model: str = "togethercomputer/llama-2-13b-chat"
    """model endpoint to use"""

    # Read once, when the class is defined; raises KeyError if the variable is unset.
    together_api_key: str = os.environ['TOGETHERAI_API_KEY']
    """Together API key"""

    temperature: float = 0.0
    """What sampling temperature to use."""

    max_tokens: int = 512
    """The maximum number of tokens to generate in the completion."""

    max_retries: int = 10
    """How many times to attempt the request before giving up."""

    request_timeout: float = 300.0
    """Seconds to wait for the HTTP response, so a stalled request cannot hang forever."""

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "together"

    def _call(
        self,
        prompt: str,
        **kwargs,
    ) -> str:
        """Send ``prompt`` to Together and return the generated text.

        Extra keyword arguments supplied by LangChain (e.g. ``stop``) are accepted
        but ignored.

        Raises:
            Exception: if none of the ``max_retries`` attempts returned a response
                with a parsable ``output.choices[0].text``; the last underlying
                error is chained as the cause.
        """
        endpoint = 'https://api.together.xyz/inference'

        print("model =", self.model)
        print("temperature =", self.temperature)
        print("max_tokens=", self.max_tokens)

        # Request body and headers do not change between attempts, so build them once.
        payload = {
            "prompt": prompt,
            "model": self.model,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens
        }
        headers = {
            "Authorization": f"Bearer {self.together_api_key}",
            "User-Agent": "<YOUR_APP_NAME>"
        }

        last_error = None
        for attempt in range(self.max_retries):
            try:
                res = requests.post(endpoint, json=payload, headers=headers, timeout=self.request_timeout)
                # Report HTTP failures (401, 429, 5xx) as such instead of as a KeyError on the body.
                res.raise_for_status()
                return res.json()['output']['choices'][0]['text']
            except Exception as e:
                print(e)
                last_error = e
                if attempt + 1 < self.max_retries:
                    # Exponential backoff (capped) gives rate limits / outages time to clear.
                    time.sleep(min(2 ** attempt, 30))

        raise Exception(f"Request did not succeed with prompt = {prompt}") from last_error


In [2]:
import json
import pandas as pd
from langchain import PromptTemplate

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report
)

class Utility:
    """Stateless helpers: Llama-2 prompt building, answer parsing, metrics and CSV I/O."""

    # Llama-2 chat markup: sequence begin/end, instruction begin/end, system block begin/end.
    B_CHAT, E_CHAT = "<s>", "</s>"
    B_INST, E_INST = "[INST]", "[/INST]"
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    @staticmethod
    def _format_turn(query, response=None, system_prompt=""):
        """Render one ``<s>[INST] ... [/INST]`` turn; a ``response`` closes the turn with ``</s>``."""
        turn = f"{Utility.B_CHAT}{Utility.B_INST} {system_prompt}{query} {Utility.E_INST}"
        if response is not None:
            turn += f" {response}{Utility.E_CHAT}"
        return turn

    @staticmethod
    def get_prompt(system_message, user_message, input_variables, history):
        """Build a LangChain ``PromptTemplate`` in Llama-2 chat format.

        Args:
            system_message: system prompt text, wrapped in ``<<SYS>>`` markers.
            user_message: the new, still unanswered user turn.
            input_variables: template variable names handed to ``PromptTemplate``.
            history: iterable of ``(query, response)`` pairs from earlier turns.

        The system prompt appears only once, inside the first turn: the first
        history turn when there is history, otherwise the new user turn.
        """
        system_prompt = f"{Utility.B_SYS}{system_message}{Utility.E_SYS}"

        turns = [
            Utility._format_turn(query, response, system_prompt if index == 0 else "")
            for index, (query, response) in enumerate(history)
        ]
        turns.append(Utility._format_turn(user_message, system_prompt="" if history else system_prompt))

        prompt_template = "".join(turns)
        return PromptTemplate(template=prompt_template, input_variables=input_variables)

    @staticmethod
    def extract_number(s):
        """Return the 0/1 label contained in the model response ``s``.

        Prefers the last *standalone* ``0``/``1``, i.e. one that is not part of a
        longer word or number, so digits in text like "COVID19", "2020" or "0.9"
        that follow the answer are not mistaken for it. If no standalone digit
        exists, falls back to the last ``0``/``1`` character anywhere in ``s``.

        Raises:
            Exception: if ``s`` contains no ``0`` or ``1`` at all.
        """
        standalone = re.findall(r"(?<![\w.])[01](?![\w]|\.\d)", s)
        if standalone:
            return int(standalone[-1])

        for c in reversed(s):
            if c in "01":
                return int(c)

        raise Exception(f"No 0 or 1 found in response = {s}")

    @staticmethod
    def calculate_metrics(ground_truth, predicted):
        """Return accuracy and weighted-average precision/recall/F1 (as percentages) plus the confusion matrix."""
        clsf_report = classification_report(y_true = ground_truth, y_pred = predicted, output_dict=True)
        cf_matrix = confusion_matrix(ground_truth, predicted)

        weighted = clsf_report['weighted avg']
        accuracy = accuracy_score(ground_truth, predicted)

        return {
            "Accuracy": accuracy * 100,
            "Precision": weighted['precision'] * 100,
            "Recall": weighted['recall'] * 100,
            "F1": weighted['f1-score'] * 100,
            "Confusion Matrix": cf_matrix
        }

    @staticmethod
    def get_tweet_data(file_name):
        """Load the tweets CSV, using its first column as the DataFrame index."""
        df = pd.read_csv(file_name, index_col=0)
        return df

    @staticmethod
    def write_prediction_output(tweet_objects, file_name_to_write):
        """Write ``tweet_objects`` to ``file_name_to_write`` as CSV, deleting any existing file first."""
        if os.path.exists(file_name_to_write):
            os.remove(file_name_to_write)

        tweet_objects.to_csv(file_name_to_write)


In [3]:
class Default(dict):
    """Mapping for ``str.format_map`` that leaves unknown placeholders untouched.

    ``"{a} {b}".format_map(Default(a=1))`` yields ``"1 {b}"``, so placeholders
    meant for a later formatting pass survive the first one.
    """

    def __missing__(self, key):
        # Re-emit the placeholder exactly as written: "{" + key + "}".
        return "{{{0}}}".format(key)

In [4]:
from langchain import LLMChain

class Category:
    """Base class that scores one binary (0/1) tweet category with an LLM.

    Subclasses must assign ``self.system_message`` (system prompt text, e.g. with a
    ``{delimiter}`` placeholder) and ``self.tweet_objects`` (a pandas DataFrame of
    tweets) before calling ``Category.__init__``; this class reads both but never
    sets them.
    """

    # Template variables the user prompt is filled with when the chain runs.
    INPUT_VARIABLES=["delimiter", "tweet"]

    # Fence placed on both sides of the tweet text inside the prompt.
    DELIMITER = "```"

    # LLM calls per tweet before giving up on it.
    MAX_PREDICTION_ATTEMPTS = 10

    def __init__(self, category_type, llm):
        """Store the category number (used in column names) and the LangChain LLM to query."""
        self.category_type = category_type

        self.llm = llm
        # Chat-history slot; not read anywhere in this class.
        self.history = []

    def does_tweet_fall_into_category(self, tweet):
        """Ask the LLM whether ``tweet`` belongs to this category.

        Builds a single-turn Llama-2 prompt from ``self.system_message`` and runs it
        through a LangChain ``LLMChain``.

        Returns:
            The 0/1 label parsed by ``Utility.extract_number`` (which raises if the
            response contains no 0 or 1).
        """
        print("Inside does_tweet_fall_into_category function")

        user_message = """
         Classify the following tweet:
         Tweet: {delimiter} {tweet} {delimiter}
         """

        prompt = Utility.get_prompt(self.system_message, user_message, Category.INPUT_VARIABLES, [])

        chain = LLMChain(llm = self.llm, prompt = prompt)

        input_values = {"tweet": tweet, "delimiter": Category.DELIMITER}

        response = chain.run(input_values)

        print("Response =", response)

        response_number = Utility.extract_number(response)

        return response_number

    def _prepare_prediction_column(self, prediction_column, start_index, end_index):
        """Create ``prediction_column`` filled with -1, or refuse to overwrite predictions already in the range."""
        if prediction_column not in self.tweet_objects:
            self.tweet_objects[prediction_column] = -1
            return

        for index in range(start_index, end_index + 1):
            if index not in self.tweet_objects.index:
                continue
            if self.tweet_objects.loc[index, prediction_column] != -1:
                raise Exception(
                    "Some of the indices for the specified range have already been computed."
                )

    def _predict_with_retries(self, tweet):
        """Query the LLM up to MAX_PREDICTION_ATTEMPTS times; return the 0/1 answer, or None if all attempts fail."""
        for attempt in range(self.MAX_PREDICTION_ATTEMPTS):
            try:
                predicted_output = self.does_tweet_fall_into_category(tweet)
            except Exception as error:  # not a bare except, so KeyboardInterrupt still stops the run
                print("Attempt", attempt + 1, "failed:", error)
                continue
            if predicted_output is not None:
                return predicted_output
        return None

    def generate_cat_metrics(self, output_file_name, tweet_content_column="polished_text",
                             start_index=0, end_index=400):
        """Classify every tweet whose index label lies in ``[start_index, end_index]``.

        Predictions are stored in column ``predicted_<category_type>`` of
        ``self.tweet_objects`` (created with -1 meaning "not predicted yet"). The whole
        frame is then written to ``output_file_name``, and metrics comparing the
        predictions with column ``cat<category_type>`` are returned.

        Args:
            output_file_name: CSV path the updated tweets are written to.
            tweet_content_column: column holding the tweet text sent to the LLM.
            start_index, end_index: inclusive range of index labels to process.

        Returns:
            The dict produced by ``Utility.calculate_metrics``.

        Raises:
            Exception: if a tweet in the range already has a prediction, or the LLM
                never produced a 0/1 answer for a tweet.
        """
        ground_truths = []
        predicted_outputs = []

        print("<======= Generating metrics for category type =", self.category_type, "=======>")
        print()

        prediction_column = f"predicted_{self.category_type}"
        label_column = f"cat{self.category_type}"

        self._prepare_prediction_column(prediction_column, start_index, end_index)

        for index in range(start_index, end_index + 1):
            # Rows are addressed by index *label* throughout (.loc). Subclasses may drop
            # rows (Category1 drops its few-shot examples), so positions != labels.
            if index not in self.tweet_objects.index:
                continue

            print("Processing tweet with index# =", index)
            tweet = self.tweet_objects.loc[index, tweet_content_column]
            print("Tweet content = ", tweet)

            try:
                ground_truth = int(self.tweet_objects.loc[index, label_column])
            except (TypeError, ValueError):
                # Unlabelled row: it cannot be scored, and must not inherit the previous tweet's label.
                print("Skipping index# =", index, "because its", label_column, "value is not an integer")
                print()
                continue

            predicted_output = self._predict_with_retries(tweet)
            if predicted_output is None:
                print("None for index# = ", index, "and tweet content =", tweet)
                raise Exception("Did not get predicted output for tweet")

            if index > 0 and index % 5 == 0 and ground_truths:
                print("Metrics till now =", Utility.calculate_metrics(ground_truths, predicted_outputs))

            ground_truths.append(ground_truth)
            predicted_outputs.append(predicted_output)

            self.tweet_objects.loc[index, prediction_column] = predicted_output

            print("Ground truth =", ground_truth, "Predicted output =", predicted_output)
            print("Finished Processing tweet with index# =", index)
            print()

        print("<======= Finished generating metrics for category type =", self.category_type, "=======>")

        print("Ground truths = ", ground_truths)
        print("Predictions = ", predicted_outputs)

        Utility.write_prediction_output(self.tweet_objects, output_file_name)

        return Utility.calculate_metrics(ground_truths, predicted_outputs)


In [5]:
class Category1(Category):
    """Category 1: is the tweet a scientifically verifiable claim? (1 = yes, 0 = no)."""

    CATEGORY_TYPE = 1
    CATEGORY_DESCRIPTION = ""

    def __init__(self, llm, input_file_name):
        # The base-class methods read these two attributes and Category.__init__
        # does not set them, so they are filled in before delegating.
        system_message, tweet_objects = Category1.generate_system_prompt_for_category1(input_file_name)
        self.system_message = system_message
        self.tweet_objects = tweet_objects

        super().__init__(Category1.CATEGORY_TYPE, llm)

    @staticmethod
    def generate_system_prompt_for_category1(input_file_name):
        """Return ``(system_message, tweets)`` for category 1.

        ``system_message`` is a few-shot prompt whose ``{delimiter}`` placeholder is
        left unfilled for the LangChain template. ``tweets`` is the CSV at
        ``input_file_name`` with the rows in the example-index lists dropped by
        label (presumably the tweets quoted as few-shot examples — confirm).
        """
        positive_example_indices = [22, 56, 71]
        negative_example_indices = [21, 61, 90]

        positive_examples = """
        Some examples of tweets that ARE scientifically verifiable (expected response 1):
            a) " ::people_holding_hands:: We can now meet our family and friends outdoors in a group of 6, or 2 households ::leftright_arrow:: Its important that when we do, we follow social distancing guidance ::backhand_index_pointing_right:: This will help to stop the spread of COVID19 as we take the next step out of lockdown LetsDoItForLancashire "
            b) ": BREAKING: Dozens of cops in Massachusetts have resigned in protest of the vaccine mandates. TO WISH THEM GOOD RIDDA"
            c) ": BREAKING Syria president and first lady test positive for COVID19: presidency AFP"
        """

        negative_examples = """
        Some examples of tweets ARE NOT scientifically verifiable (expected response 0):
            a) " : The ones calling for lockdown, without risk or injury to themselves, should pay up."
            b) ": Can you catch coronavirus from handling cash? A new study says the risk is low"
            c) ": I wouldnt trust anything this man touches. NoVaccineForMe"
        """

        # Default keeps unknown placeholders such as {delimiter} intact.
        placeholders = Default(
            tweet_examples_of_non_category1=negative_examples,
            tweet_examples_of_category1=positive_examples,
        )
        system_message = """
        Imagine you're a COVID-19 tweets classifier. You need to determine whether tweets fall into scientifically verifiable claim category.
        
        The tweets will be delimited with {delimiter} characters.

        A claim or a question is scientifically verified if it's scientifically shown to be true or scientifically shown to be false.
            
        {tweet_examples_of_category1}
        
        {tweet_examples_of_non_category1}

        If the tweet is scientifically verifiable, return 1. Otherwise, return 0.
        """.format_map(placeholders)

        all_tweets = Utility.get_tweet_data(input_file_name)
        scored_tweets = all_tweets.drop(positive_example_indices + negative_example_indices)

        return system_message, scored_tweets

In [6]:
# Llama-2 chat model size to query; also embedded in the output file name.
model_parameters = "70b"

llm = TogetherLLM(
    model=f"togethercomputer/llama-2-{model_parameters}-chat",
    temperature=0.2,
    max_tokens=1500,
)

input_file_name = "tweets - original.csv"
# f-string, so each model size writes its own file instead of a literal "{model_parameters}" name.
output_file_name = f"updated_tweets_cat_1-{model_parameters}.csv"

In [7]:
# Reads the tweets CSV and builds the category-1 few-shot system prompt.
cat1 = Category1(llm=llm, input_file_name=input_file_name)

In [None]:
# Classifies the tweets, writes predictions to output_file_name and displays the metrics dict.
cat1.generate_cat_metrics(output_file_name=output_file_name)