In [1]:
import os
import glob
import re
import json
import groq
import torch
import signal
import traceback
import local_settings as S
from transformers import pipeline
from tqdm.notebook import tqdm



In [2]:
# Constants
LLM_MODEL = "llama-3.1-70b-versatile"
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
summarizer = pipeline("summarization", device=DEVICE)
def get_summary(text):
    summary = summarizer(text, truncation=True, max_length=400, min_length=50, do_sample=True, temperature=0.3)
    return summary[0]['summary_text']

No model was supplied, defaulted to sshleifer/distilbart-cnn-12-6 and revision a4f8f3e (https://huggingface.co/sshleifer/distilbart-cnn-12-6).
Using a pipeline without specifying a model name and revision in production is not recommended.


In [4]:
class RanOutOfGroqAPIKeys(Exception):
    pass

class GroqAPI:
    keychain = S.GROQ_API_KEYS
    current_key = 0
    client = groq.Groq(api_key=keychain[current_key])


    @classmethod
    def rotate_key(cls):
        for key in range(cls.current_key+1, len(cls.keychain)):
            cls.current_key = key
            cls.client = groq.Groq(api_key=cls.keychain[cls.current_key])
            print(f"Rotated to KEY[{cls.current_key}]")
            return
        raise RanOutOfGroqAPIKeys("All keys are invalid.")


    @classmethod
    def query(cls, **kwargs):
        def handler(signum, frame):   raise TimeoutError("Query took too long!")
        signal.signal(signal.SIGALRM, handler)
        while True:
            signal.alarm(30)  # Start the timer
            try:
                completion = cls.client.chat.completions.create(**kwargs)
                signal.alarm(0)  # Reset the timer
                return completion
            except TimeoutError as e:
                print("Query took too long!")
                cls.rotate_key()
            except groq.RateLimitError as e:
                print("A [429] status code was received; we should back off a bit.")
                cls.rotate_key()
            except groq.APIStatusError as e:
                print(f"A [{e.status_code}] status code was received:")
                print(e.response)
                print(e.message)
                print(e.body)
                cls.rotate_key()
            except groq.APIConnectionError as e:
                print("The server could not be reached.")
                print(e.__cause__)

In [5]:
def get_favourability_ratings(text):
    completion = GroqAPI.query(
        model=LLM_MODEL,
        messages=[
            {
                "role": "system",
                "content": (
                    "You are an expert political analyst. Read the text provided by the user. "
                    "Describe how favourable is it to democrats, and to republicans, in under 50 words. "
                    "Also provide two scores on a scale of -5 to +5, quantifying this favourability to the two parties "
                    "in a JSON format with two keys \"democrats\" & \"republicans\"."
                )
            },
            {
                "role": "user",
                "content": text
            }
        ],
        temperature=0.25,
        max_tokens=512,
        top_p=1,
        stream=False,
        stop=None,
    )
    explanation_with_ratings = completion.choices[0].message.content
    return explanation_with_ratings, completion.usage

In [6]:
def find_json_objects(input_string):
    json_pattern = r'\{.*?\}'
    potential_jsons = re.findall(json_pattern, input_string, re.DOTALL)
    valid_jsons = []
    for json_str in potential_jsons:
        try:
            json_obj = json.loads(json_str)
            valid_jsons.append(json_obj)
        except json.JSONDecodeError:
            continue
    return valid_jsons

In [7]:
FLAG = False
def generate_ratings(article):
    global FLAG
    if FLAG:    return article
    if not any([article['date_google'],article['date_metadata'],article['date_published']]):   return article
    try:
        summary = get_summary(article['text'])
        explanation_with_ratings, usage = get_favourability_ratings(summary)
        ratings = find_json_objects(explanation_with_ratings)[0]
        article['summary'] = summary
        article['explanation'] = explanation_with_ratings
        article['groq_usage'] = str(usage)
        article['rating_democrats'] = ratings['democrats']
        article['rating_republicans'] = ratings['republicans']
        return article
    except RanOutOfGroqAPIKeys as e:
        print('Ran out of Groq API keys. Aborting.')
        FLAG = True
        return article
    except Exception as e:
        print('Failed to generate ratings:', e)
        traceback.print_exc()
        return article

In [8]:
for filename in sorted(glob.glob('../news_data/data/newsdata_*.json')):
    TS = re.search(r'_(\d+\.\d+)\.', filename).group(1)
    if os.path.exists(f'./data/ratings_{TS}.json'):    continue

    print(f"Processing: {filename}")
    data = json.load(open(filename))
    FLAG = False
    raw = list(tqdm(map(generate_ratings, data), total=len(data)))
    rated = [item for item in raw if 'rating_republicans' in item.keys()]
    redo = [item for item in raw if 'rating_republicans' not in item.keys()]
    with open(f'./data/ratings_{TS}.json', 'w') as f:
        json.dump(rated, f, indent=4, sort_keys=True)
    with open(f'./redo/newsdata_{TS}.json', 'w') as f:
        json.dump(redo, f, indent=4, sort_keys=True)

Processing: ../news_data/data/newsdata_1732602008.020645.json


  0%|          | 0/369 [00:00<?, ?it/s]

Your max_length is set to 400, but your input_length is only 94. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=47)
Your max_length is set to 400, but your input_length is only 84. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=42)
Your max_length is set to 400, but your input_length is only 210. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=105)
Your max_length is set to 400, but your input_length is only 151. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=75)
Y

Processing: ../news_data/data/newsdata_1732602079.182383.json


  0%|          | 0/408 [00:00<?, ?it/s]

Your max_length is set to 400, but your input_length is only 44. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=22)
Your max_length is set to 400, but your input_length is only 6. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=3)
Your max_length is set to 400, but your input_length is only 96. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=48)
Your max_length is set to 400, but your input_length is only 148. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=74)
Your 