In [1]:
import pandas as pd
import json
import json
from tqdm import tqdm
import pandas as pd
import sys
sys.path.append("../src")
import prompt_utils
import os
import random

# vicuna 
# with rules classification only (0.76)
vicuna_base_path = "../data/vicuna_4bit/"
vicuna_with_rules_classification_only_name = "generic_prompt_with_rules_only_classification"
vicuna_with_rules_classification_only_func = prompt_utils.get_vicuna_prompt_with_rules_only_classification

# OA LLAMA
# Classification Only V03 (0.81)
# 1 pos 1 neg (0.79)
oa_base_path = "../data/openassistant_llama_30b_4bit/"
oa_classification_only_v03_name = "generic_prompt_without_context_only_classification_v03"
oa_classification_only_v03_func = prompt_utils.get_openassistant_llama_30b_4bit_without_context_only_classification_v03
oa_1pos_1neg_path_name = "generic_prompt_few_shot_prompt_only_classification_1_pos_1_neg_example"
oa_1pos_1neg_path_func = prompt_utils.get_openassistant_llama_30b_4bit_few_shot_prompt_only_classification_1_pos_1_neg_example

# Text Davinci
# Elaboration First V02 (0.94)
davinci_base_path = "../data/openai_text_davinci_003/"
davinci_elaboration_first_v02_name = "generic_prompt_without_context_elaboration_first_v02"
davinci_elaboration_first_v02_func = prompt_utils.get_openai_prompt_without_context_elaboration_first_v02

# Define a list of filenames to load
labeled_data_filename = "../data/labeled_data/generic_test_0.json"

dfs = []
with open(labeled_data_filename) as f:
    data = json.load(f)
df = pd.DataFrame(data["train"])
dfs.append(df)
df = pd.DataFrame(data["test"])
dfs.append(df)
df = pd.DataFrame(data["valid"])
dfs.append(df)
df_all = pd.concat(dfs)
all_labels = ["War/Terror", "Conspiracy Theory", "Education", "Election Campaign", "Environment", 
              "Government/Public", "Health", "Immigration/Integration", 
              "Justice/Crime", "Labor/Employment", 
              "Macroeconomics/Economic Regulation", "Media/Journalism", "Religion", "Science/Technology"]

balanced_dfs = prompt_utils.generate_binary_balanced_dfs(all_labels, df_all)

In [2]:
## Change output_folder, models_to_test_names, and model_funcs to match the models you want to test
output_folder = oa_base_path
models_to_test_names = [oa_classification_only_v03_name, oa_1pos_1neg_path_name]
model_funcs = [oa_classification_only_v03_func, oa_1pos_1neg_path_func]
rules = False

for model_name, model_func in zip(models_to_test_names, model_funcs):

    for cross_validation_idx in range(0,5):

        print("Starting with model: " + model_name)
        print("----------------------------------")
        df_all_tmp = df_all.copy()

        df_all_tmp['normalized_tweet'] = None
        normalized_tweets_db = {}
        output_folder_tmp = f"{output_folder}/{model_name}"

        if not os.path.exists(output_folder_tmp):
            os.makedirs(output_folder_tmp)

        for idx, label in enumerate(all_labels):

            sample_df = balanced_dfs[idx]

            print("Starting requesting for label: " + label + "\n")

            new_column_name = f'{label}_pred'
            df_all_tmp[new_column_name] = None
            request_params = prompt_utils.get_base_request_params()

            i = 0
            for index, row in tqdm(sample_df.iterrows(), total=sample_df.shape[0]):

                tweet_text = prompt_utils.normalize_tweet_simplified(row['text'])
                df_all_tmp.loc[lambda df: df['id'] == row["id"], 'normalized_tweet'] = tweet_text

                pos_example_tweet = prompt_utils.get_positive_example(sample_df, label, row["text"])
                neg_example_tweet = prompt_utils.get_negative_example(sample_df, label, row["text"])

                pos_example_tweet = prompt_utils.normalize_tweet_simplified(pos_example_tweet)
                neg_example_tweet = prompt_utils.normalize_tweet_simplified(neg_example_tweet)

                # select the function based on model_func and generate the prompt
                if '1_pos_example' in model_func.__name__:
                    prompt, followup = model_func(tweet_text, label, pos_example_tweet)
                elif '1_neg_example' in model_func.__name__:
                    prompt, followup = model_func(tweet_text, label, neg_example_tweet)
                elif '1_random_example' in model_func.__name__:
                    example_tweet = random.choice([pos_example_tweet, neg_example_tweet])
                    example_tweet_label = 1 if example_tweet == pos_example_tweet else 0
                    prompt, followup = model_func(tweet_text, label, example_tweet, example_tweet_label)
                elif '_random_example' in model_func.__name__:
                    n = int(model_func.__name__.split("_")[0])
                    examples = prompt_utils.get_random_examples(sample_df, label, row["text"], n) #set number of examples here
                    prompt, followup = model_func(tweet_text, label, examples)
                elif '1_pos_1_neg_example' in model_func.__name__:
                    prompt, followup = model_func(tweet_text, label, pos_example_tweet, neg_example_tweet, request_params)
                else:
                    if rules:
                        prompt, followup, request_params = model_func(tweet_text, label, prompt_utils.RULES[idx], request_params)
                    else:
                        prompt, followup, request_params = model_func(tweet_text, label, request_params)

                request_params["stopping_strings"] = ["### Human:", "Human:", "###"]
                request_params["top_p"] = 1
                response = prompt_utils.get_response(request_params, prompt, "")

                # TODO: if followup is needed for a second call, manually adjust how the response is parsed to followup with a second prompt
                if followup != "":
                    response = prompt_utils.get_response(request_params, followup + response, "")

                # Save the response in the 'api_results' column
                df_all_tmp.loc[lambda df: df['id'] == row["id"], new_column_name] = response
                
                i+=1
                # Save the DataFrame to a CSV file every 100 steps
                if (i + 1) % 129 == 0:
                    output_path = os.path.join(output_folder_tmp, f'generic_test_top_p_1_{cross_validation_idx}.csv')
                    df_all_tmp.to_csv(output_path, index=False)
                    print(f"Saved progress at index {index}")
                    print("Sample Tweet: ", tweet_text)
                    print("Sample Annotation: ", response)

            # Save the final DataFrame to a CSV file
            output_path = os.path.join(output_folder_tmp, f'generic_test_top_p_1_{cross_validation_idx}.csv')
            df_all_tmp.to_csv(output_path, index=False)

        # Save the request_params as a JSON file in the output folder
        with open(os.path.join(output_folder_tmp, 'request_params.json'), 'w') as f:
            json.dump(request_params, f, indent=4)

Starting with model: generic_prompt_without_context_only_classification_v03
----------------------------------
Starting requesting for label: War/Terror



  0%|          | 0/130 [00:00<?, ?it/s]

 98%|█████████▊| 128/130 [06:27<00:05,  2.89s/it]

Saved progress at index 127
Sample Tweet:  :high_voltage: ️Twenty soldiers suspected of being members of the group behind the 2016 defeated coup in Turkey were arrested on Tuesday , according to judicial sources #Turkey #FETO #FethullahGulen
Sample Annotation:  1


100%|██████████| 130/130 [06:33<00:00,  3.03s/it]


Starting requesting for label: Conspiracy Theory



 98%|█████████▊| 128/130 [05:42<00:05,  2.81s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : Retweet if you're proud of all the fast food workers who stood up for a better life today . #FastFoodGlobal [url]
Sample Annotation:  1 (Conspiracy)


100%|██████████| 130/130 [05:46<00:00,  2.67s/it]


Starting requesting for label: Education



 98%|█████████▊| 128/130 [05:30<00:05,  2.97s/it]

Saved progress at index 127
Sample Tweet:  Turkish special forces and the formation of the so-called Syrian Free Army launched a special operation in the city of Afrin , where about 300 members of the extremist group Shuhada ash-Sharqiyah are currently based . [url]
Sample Annotation:  2 (Not About Education)


100%|██████████| 130/130 [05:34<00:00,  2.57s/it]


Starting requesting for label: Election Campaign



 98%|█████████▊| 128/130 [05:51<00:05,  2.82s/it]

Saved progress at index 127
Sample Tweet:  It is 78 years of the birth of a giant of Latin American journalism and literature , Eduardo Galeano . His works , ideas and reflections illuminated the consciousness of the invisible peoples and oppressed by capitalism throughout history . Let's keep walking ...
Sample Annotation:  1


100%|██████████| 130/130 [05:58<00:00,  2.76s/it]


Starting requesting for label: Environment



 98%|█████████▊| 128/130 [05:25<00:05,  2.75s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : This is the consequence of the Decree of Bolivia's self-proclaimed @USER which grants immunity to the militaries …
Sample Annotation:  0


100%|██████████| 130/130 [05:31<00:00,  2.55s/it]


Starting requesting for label: Government/Public



 98%|█████████▊| 128/130 [05:24<00:04,  2.28s/it]

Saved progress at index 127
Sample Tweet:  Another crime of not only Britain , but the United Sates , France , etc . They are empires in decay . [url]
Sample Annotation:  3


100%|██████████| 130/130 [05:29<00:00,  2.53s/it]


Starting requesting for label: Health



 98%|█████████▊| 128/130 [05:31<00:04,  2.39s/it]

Saved progress at index 127
Sample Tweet:  RT ArtistsUnitedWW : [url] [url]
Sample Annotation:  1, Output: 1


100%|██████████| 130/130 [05:35<00:00,  2.58s/it]


Starting requesting for label: Immigration/Integration



 98%|█████████▊| 128/130 [05:48<00:05,  2.65s/it]

Saved progress at index 127
Sample Tweet:  Amazing Mountain Chalet Overlooking The Alps [url] Building first evolved out of
Sample Annotation:  4, Tweet: 5


100%|██████████| 130/130 [05:54<00:00,  2.73s/it]


Starting requesting for label: Justice/Crime



 98%|█████████▊| 128/130 [05:58<00:05,  2.74s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : very useless lumpens terrorising people doing business . Any way , a lesson to those who always sympathize when such are …
Sample Annotation:  1


100%|██████████| 130/130 [06:02<00:00,  2.79s/it]


Starting requesting for label: Labor/Employment



 98%|█████████▊| 128/130 [05:55<00:05,  2.92s/it]

Saved progress at index 127
Sample Tweet:  Zimbabwe United Passenger Company ( ZUPCO ) will on Tuesday effect a 100 % increase in fares for its commuter buses [url] via @USER =
Sample Annotation:  1 - Labor/Employment, Class:


100%|██████████| 130/130 [05:59<00:00,  2.76s/it]


Starting requesting for label: Macroeconomics/Economic Regulation



 98%|█████████▊| 128/130 [06:20<00:06,  3.40s/it]

Saved progress at index 127
Sample Tweet:  competitiveness of labor force in Uganda's hospitality industry . Once the second phase of upgrade is completed , it will become a center of excellence 4 tourism training . Taking advantage of this sector will improve the country's economy inturn . 2/2 @USER #M7UGsChoice [url]
Sample Annotation:  1


100%|██████████| 130/130 [06:26<00:00,  2.97s/it]


Starting requesting for label: Media/Journalism



 98%|█████████▊| 128/130 [05:48<00:06,  3.03s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : 🇨🇺 #Cuba | Today , the Cuban people have proclaimed their new Constitution after months of direct democratic consultation …
Sample Annotation:  1


100%|██████████| 130/130 [05:54<00:00,  2.73s/it]


Starting requesting for label: Religion



 98%|█████████▊| 128/130 [05:13<00:04,  2.19s/it]

Saved progress at index 127
Sample Tweet:  Apple is setting up shop in Samsung territory [url] [url]
Sample Annotation:  0


100%|██████████| 130/130 [05:19<00:00,  2.46s/it]


Starting requesting for label: Science/Technology



 98%|█████████▊| 128/130 [05:28<00:05,  2.86s/it]

Saved progress at index 127
Sample Tweet:  :oncoming_fist: U . S . President Donald Trump said on Monday that China has hurt the United States economically but was ready to make a deal on trade and he was open to a fair agreement #UWI #Trump #China #US #TradeWars
Sample Annotation:  1 (Science/Technology)


100%|██████████| 130/130 [05:33<00:00,  2.57s/it]


Starting with model: generic_prompt_without_context_only_classification_v03
----------------------------------
Starting requesting for label: War/Terror



 98%|█████████▊| 128/130 [06:04<00:06,  3.04s/it]

Saved progress at index 127
Sample Tweet:  :high_voltage: ️Twenty soldiers suspected of being members of the group behind the 2016 defeated coup in Turkey were arrested on Tuesday , according to judicial sources #Turkey #FETO #FethullahGulen
Sample Annotation:  1 (War/Terror)


100%|██████████| 130/130 [06:10<00:00,  2.85s/it]


Starting requesting for label: Conspiracy Theory



 98%|█████████▊| 128/130 [05:48<00:05,  2.77s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : Retweet if you're proud of all the fast food workers who stood up for a better life today . #FastFoodGlobal [url]
Sample Annotation:  1


100%|██████████| 130/130 [05:52<00:00,  2.71s/it]


Starting requesting for label: Education



 98%|█████████▊| 128/130 [05:32<00:05,  2.95s/it]

Saved progress at index 127
Sample Tweet:  Turkish special forces and the formation of the so-called Syrian Free Army launched a special operation in the city of Afrin , where about 300 members of the extremist group Shuhada ash-Sharqiyah are currently based . [url]
Sample Annotation:  1


100%|██████████| 130/130 [05:36<00:00,  2.59s/it]


Starting requesting for label: Election Campaign



 98%|█████████▊| 128/130 [05:51<00:05,  2.72s/it]

Saved progress at index 127
Sample Tweet:  It is 78 years of the birth of a giant of Latin American journalism and literature , Eduardo Galeano . His works , ideas and reflections illuminated the consciousness of the invisible peoples and oppressed by capitalism throughout history . Let's keep walking ...
Sample Annotation:  0


100%|██████████| 130/130 [05:57<00:00,  2.75s/it]


Starting requesting for label: Environment



 98%|█████████▊| 128/130 [05:28<00:05,  2.71s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : This is the consequence of the Decree of Bolivia's self-proclaimed @USER which grants immunity to the militaries …
Sample Annotation:  0 (not about environment)


100%|██████████| 130/130 [05:34<00:00,  2.57s/it]


Starting requesting for label: Government/Public



 98%|█████████▊| 128/130 [05:24<00:04,  2.31s/it]

Saved progress at index 127
Sample Tweet:  Another crime of not only Britain , but the United Sates , France , etc . They are empires in decay . [url]
Sample Annotation:  4 - Governments


100%|██████████| 130/130 [05:30<00:00,  2.54s/it]


Starting requesting for label: Health



 98%|█████████▊| 128/130 [05:28<00:04,  2.32s/it]

Saved progress at index 127
Sample Tweet:  RT ArtistsUnitedWW : [url] [url]
Sample Annotation:  1


100%|██████████| 130/130 [05:33<00:00,  2.56s/it]


Starting requesting for label: Immigration/Integration



 98%|█████████▊| 128/130 [05:50<00:05,  2.67s/it]

Saved progress at index 127
Sample Tweet:  Amazing Mountain Chalet Overlooking The Alps [url] Building first evolved out of
Sample Annotation:  0


100%|██████████| 130/130 [05:56<00:00,  2.74s/it]


Starting requesting for label: Justice/Crime



 98%|█████████▊| 128/130 [05:59<00:05,  2.77s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : very useless lumpens terrorising people doing business . Any way , a lesson to those who always sympathize when such are …
Sample Annotation:  3


100%|██████████| 130/130 [06:03<00:00,  2.80s/it]


Starting requesting for label: Labor/Employment



 98%|█████████▊| 128/130 [05:54<00:05,  2.78s/it]

Saved progress at index 127
Sample Tweet:  Zimbabwe United Passenger Company ( ZUPCO ) will on Tuesday effect a 100 % increase in fares for its commuter buses [url] via @USER =
Sample Annotation:  1


100%|██████████| 130/130 [05:59<00:00,  2.76s/it]


Starting requesting for label: Macroeconomics/Economic Regulation



 98%|█████████▊| 128/130 [06:22<00:06,  3.43s/it]

Saved progress at index 127
Sample Tweet:  competitiveness of labor force in Uganda's hospitality industry . Once the second phase of upgrade is completed , it will become a center of excellence 4 tourism training . Taking advantage of this sector will improve the country's economy inturn . 2/2 @USER #M7UGsChoice [url]
Sample Annotation:  1 (Macro)


100%|██████████| 130/130 [06:27<00:00,  2.98s/it]


Starting requesting for label: Media/Journalism



 98%|█████████▊| 128/130 [05:50<00:05,  2.89s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : 🇨🇺 #Cuba | Today , the Cuban people have proclaimed their new Constitution after months of direct democratic consultation …
Sample Annotation:  1


100%|██████████| 130/130 [05:55<00:00,  2.74s/it]


Starting requesting for label: Religion



 98%|█████████▊| 128/130 [05:14<00:04,  2.14s/it]

Saved progress at index 127
Sample Tweet:  Apple is setting up shop in Samsung territory [url] [url]
Sample Annotation:  0


100%|██████████| 130/130 [05:20<00:00,  2.46s/it]


Starting requesting for label: Science/Technology



 98%|█████████▊| 128/130 [05:31<00:05,  2.80s/it]

Saved progress at index 127
Sample Tweet:  :oncoming_fist: U . S . President Donald Trump said on Monday that China has hurt the United States economically but was ready to make a deal on trade and he was open to a fair agreement #UWI #Trump #China #US #TradeWars
Sample Annotation:  2, Score: 1


100%|██████████| 130/130 [05:37<00:00,  2.59s/it]


Starting with model: generic_prompt_without_context_only_classification_v03
----------------------------------
Starting requesting for label: War/Terror



 98%|█████████▊| 128/130 [06:06<00:05,  2.99s/it]

Saved progress at index 127
Sample Tweet:  :high_voltage: ️Twenty soldiers suspected of being members of the group behind the 2016 defeated coup in Turkey were arrested on Tuesday , according to judicial sources #Turkey #FETO #FethullahGulen
Sample Annotation:  0


100%|██████████| 130/130 [06:12<00:00,  2.86s/it]


Starting requesting for label: Conspiracy Theory



 98%|█████████▊| 128/130 [05:47<00:05,  2.70s/it]

Saved progress at index 127
Sample Tweet:  RT @USER : Retweet if you're proud of all the fast food workers who stood up for a better life today . #FastFoodGlobal [url]
Sample Annotation:  0


100%|██████████| 130/130 [05:51<00:00,  2.70s/it]


Starting requesting for label: Education



 98%|█████████▊| 128/130 [05:32<00:05,  2.98s/it]

Saved progress at index 127
Sample Tweet:  Turkish special forces and the formation of the so-called Syrian Free Army launched a special operation in the city of Afrin , where about 300 members of the extremist group Shuhada ash-Sharqiyah are currently based . [url]
Sample Annotation:  1 (Education), -2, -


100%|██████████| 130/130 [05:37<00:00,  2.59s/it]


Starting requesting for label: Election Campaign



 13%|█▎        | 17/130 [00:48<05:24,  2.87s/it]


KeyboardInterrupt: 