#### IMPORT MODULES

In [None]:
import datetime
import os
from getpass import getpass

import cohere
import pandas as pd
import requests
from tqdm import tqdm
pd.set_option('display.max_colwidth', None)

def get_post_titles(**kwargs):
    """Fetch Reddit submission titles from the Pushshift search API.

    Read more: https://github.com/pushshift/api

    Args:
        **kwargs: Query-string parameters, forwarded unchanged to the
            submission search endpoint.

    Returns:
        list: The ``title`` field of every item in the response's ``data``
        array, in the order the API returned them.

    Raises:
        requests.HTTPError: If the API answers with an error status.
        requests.Timeout: If no response arrives within the timeout.
        KeyError: If the JSON body has no ``data`` key or an item has no
            ``title``.
    """
    base_url = "https://api.pushshift.io/reddit/search/submission/"
    # Bound the wait so a stalled endpoint cannot hang the notebook forever.
    response = requests.get(base_url, params=kwargs, timeout=30)
    # Fail with the real HTTP error instead of a confusing JSON/KeyError below.
    response.raise_for_status()
    return [submission['title'] for submission in response.json()['data']]

In [None]:
# Supply the Cohere API key through the COHERE_API_KEY environment variable,
# or type it at the hidden prompt below. Never paste the key into the notebook:
# anything saved here is visible to everyone who can read the file.
# NOTE(review): the key that used to be hardcoded in this cell has been exposed
# and should be revoked.
# Create and retrieve a Cohere API key from os.cohere.ai
api_key = os.environ.get('COHERE_API_KEY') or getpass('Cohere API key: ')

co = cohere.Client(api_key)

In [None]:
# Few-shot examples for the movie-title extractor, written as
# (label, post text) pairs. The label "none" is used here for posts whose
# text does not name a film.
movie_examples = [
("Deadpool 2", "Deadpool 2 | Official HD Deadpool's \"Wet on Wet\" Teaser | 2018"),
("none", "Jordan Peele Just Became the First Black Writer-Director With a $100M Movie Debut"),
("Joker", "Joker Officially Rated “R”"),
("Free Guy", "Ryan Reynolds’ 'Free Guy' Receives July 3, 2020 Release Date - About a bank teller stuck in his routine that discovers he’s an NPC character in brutal open world game."),
("none", "James Cameron congratulates Kevin Feige and Marvel!"),
("Guardians of the Galaxy", "The Cast of Guardians of the Galaxy release statement on James Gunn"),
]

In [None]:
#@title Create the prompt (Run this cell to execute required code) {display-mode: "form"}

class cohereExtractor():
    """Few-shot entity extractor built on a Cohere-style ``generate`` call.

    The prompt is the task description followed by one block per example.
    Each block is the example text, a newline, the example prompt and the
    label; blocks are separated by a ``---`` line. The text to analyse is
    appended as the last block with an empty label for the model to complete.

    Args:
        examples: List of example input texts.
        example_labels: List with the expected extraction for each entry of
            ``examples`` (same order, same length).
        labels: Stored on the instance; no method of this class reads it.
        task_desciption: Text placed at the very start of the prompt. The
            misspelled name is kept because callers may pass it by keyword.
        example_prompt: Instruction inserted between each text and its label.
        client: Optional object exposing ``generate(...)`` in the same way as
            ``cohere.Client``. When omitted, the module-level ``co`` client is
            looked up at call time, as before.
    """

    def __init__(self, examples, example_labels, labels, task_desciption, example_prompt, client=None):
        self.examples = examples
        self.example_labels = example_labels
        self.labels = labels
        self.task_desciption = task_desciption
        self.example_prompt = example_prompt
        self.client = client

    def make_prompt(self, example):
        """Return the full few-shot prompt ending with ``example`` unlabelled."""
        texts = self.examples + [example]
        # The trailing empty label leaves the last block open for completion.
        labels = self.example_labels + [""]
        blocks = [
            text + "\n" + self.example_prompt + labels[position]
            for position, text in enumerate(texts)
        ]
        return self.task_desciption + "\n---\n".join(blocks)

    def extract(self, example):
        """Generate the extraction for ``example`` and return it as text.

        Leading whitespace produced by the model is preserved; only a
        trailing newline (the stop sequence) is removed.
        """
        client = self.client if self.client is not None else co
        extraction = client.generate(
            model='large',
            prompt=self.make_prompt(example),
            max_tokens=10,
            temperature=0.1,
            stop_sequences=["\n"])
        text = extraction.generations[0].text
        # Drop the stop sequence only when it is really there. A generation
        # that ends without the newline (e.g. cut off at max_tokens) would
        # otherwise lose its last real character.
        return text[:-1] if text.endswith("\n") else text


# Few-shot extractor for movie titles: each (label, post) pair in
# movie_examples supplies the post as example text and the label as its answer.
cohereMovieExtractor = cohereExtractor(
    examples=[post for _, post in movie_examples],
    example_labels=[title for title, _ in movie_examples],
    labels=[],
    task_desciption="",
    example_prompt="extract the movie title from the post:",
)


In [None]:
# Preview the assembled few-shot prompt, with a placeholder where the post
# text would go.
print(cohereMovieExtractor.make_prompt(example='<input text here>'))

In [None]:
# Number of post titles to request.
num_posts = 10

# The after/before bounds are sent as epoch seconds. They are built from
# timezone-aware UTC datetimes because .timestamp() on a naive datetime is
# interpreted in the machine's local timezone, which would shift the query
# window depending on where the notebook runs.
movies_list = get_post_titles(
    size=num_posts,
    after=str(int(datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc).timestamp())),
    before=str(int(datetime.datetime(2022, 1, 1, tzinfo=datetime.timezone.utc).timestamp())),
    subreddit="movies",
    sort_type="score",
    sort="desc",
)

# Show the list
movies_list

In [None]:
# Extract a movie title from every fetched post title. `results` must stay the
# same length as `movies_list`, because the two are later combined into one
# table; a failed call therefore contributes a None placeholder.
results = []
for text in tqdm(movies_list):
    try:
        results.append(cohereMovieExtractor.extract(text))
    except Exception as e:
        print('ERROR: ', e)
        results.append(None)

In [None]:
# Side-by-side table of each post title and the text extracted from it.
pd.DataFrame({'text': movies_list, 'extracted_text': results})

In [None]:
# Load the evaluation set; the first CSV column becomes the DataFrame index.
# NOTE(review): 'Expamle' looks like a misspelling of 'Example' - confirm the
# file name on disk before changing the path.
test_df = pd.read_csv(filepath_or_buffer='data/Expamle.csv', index_col=0)
test_df

In [None]:
from concurrent.futures import ThreadPoolExecutor

# Run the model to extract the entities, fanning the calls out over a small
# thread pool. executor.map yields results in input order, so the list lines
# up row-for-row with test_df.
with ThreadPoolExecutor(max_workers=8) as executor:
    extracted = [
        str(output).strip()
        for output in executor.map(cohereMovieExtractor.extract, test_df['text'])
    ]
# Save results
test_df['extracted_text'] = extracted

In [None]:
# Mark a row 1 when its label and extracted text match ignoring case, else 0.
test_df['correct'] = (
    test_df['label'].str.lower()
    .eq(test_df['extracted_text'].str.lower())
    .astype(int)
)

# Report the share of matching rows as a percentage
print(f'Classification accuracy {test_df["correct"].mean() *100}%')

In [None]:
# Show only the rows where the extraction did not match the label.
test_df.loc[test_df['correct'].eq(0)]

In [None]:
from sklearn.metrics import classification_report
import warnings

# Silence warnings only while the report is built. A bare
# warnings.filterwarnings('ignore') would stay active for the rest of the
# kernel session and hide unrelated warnings in later cells.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    # Labels are the ground truth (first argument), extractions the predictions.
    print(classification_report(test_df['label'].str.lower(), test_df['extracted_text'].str.lower()))

In [None]:
# =============///===================