In [None]:
import requests
from transformers import pipeline, AutoTokenizer
from collections import Counter

STORAGE_SERVICE_URL = "http://localhost:8001"

class NLPProcessor:
    """
    A processor for performing various NLP tasks such as summarization, sentiment analysis,
    and zero-shot classification on articles retrieved from a storage service.
    
    This class handles long texts by chunking them to fit the model's maximum context length,
    processes each chunk individually, and then aggregates the results.
    """

    def __init__(self):
        """
        Initializes the NLPProcessor with pipelines and tokenizers for each task.
        Uses task-specific models for summarization, sentiment analysis, and classification.
        """
        self.summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
        self.sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
        self.classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
        self.summarizer_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
        self.sentiment_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
        self.classify_tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")
    
    def retrieve_article_content(self, article_id: str) -> str:
        """
        Retrieves an article from the storage service via its API and concatenates its title and paragraphs.
        
        Args:
            article_id (str): The unique identifier for the article to retrieve.
        
        Returns:
            str: A string containing the article's title and paragraphs.
        
        Raises:
            Exception: If the article retrieval fails (non-200 response).
        """
        url = f"{STORAGE_SERVICE_URL}/articles/{article_id}"
        response = requests.get(url)
        if response.status_code != 200:
            raise Exception(f"Failed to retrieve article {article_id}: {response.text}")
        article = response.json()
        title = article.get("Title", "")
        paragraphs = article.get("Paragraphs", [])
        content = title + "\n" + "\n".join(paragraphs)
        return content
    
    def summarize(self, article_id: str) -> str:
        """
        Summarizes the content of an article.
        
        The method retrieves the article content, splits it into manageable chunks based on the
        summarization model's context limit, generates partial summaries for each chunk, and then
        combines and optionally refines these partial summaries into a final summary.
        
        Args:
            article_id (str): The unique identifier for the article to summarize.
            
        Returns:
            str: A summary of the article.
        
        Raises:
            Exception: If the article has no content.
        """
        content = self.retrieve_article_content(article_id)

        if not content:
            raise Exception(f"Article {article_id} has no content to summarize.")

        chunks = self.chunk_text(text=content, task="summarization")

        partial_summaries = [
            self.summarizer(chunk, max_length=200, min_length=100, do_sample=False)[0]["summary_text"]
            for chunk in chunks
        ]

        combined_summary = " ".join(partial_summaries)
        
        if len(self.summarizer_tokenizer.tokenize(combined_summary)) > 1024:
            final_summary = self.summarizer(combined_summary, max_length=400, min_length=200, do_sample=False)[0]["summary_text"]
        else:
            final_summary = combined_summary

        return final_summary
    
    def chunk_text(self, text: str, task: str) -> list:
        """
        Splits the input text into chunks that do not exceed the model's maximum token limit for a given task.
        
        The function uses different tokenizers and token limits based on the task:
          - "summarization": Uses summarizer_tokenizer with a max of 1024 tokens.
          - "sentiment": Uses sentiment_tokenizer with a max of 512 tokens.
          - "classification": Uses classify_tokenizer with a max of 1024 tokens.
        
        Args:
            text (str): The text to be chunked.
            task (str): The task for which the text is being processed. Must be one of "summarization",
                        "sentiment", or "classification".
            
        Returns:
            list: A list of text chunks, each within the specified token limit.
        """
        sentences = text.split(". ")
        chunks, current_chunk = [], ""

        if task == "summarization":
            tokenizer = self.summarizer_tokenizer
            max_tokens=1024
        elif task == "sentiment":
            tokenizer = self.sentiment_tokenizer
            max_tokens=512
        elif task == "classification":
            tokenizer = self.classify_tokenizer
            max_tokens=1024
        else:
            raise ValueError("Unsupported task specified. Use 'summarization', 'sentiment', or 'classification'.")

        for sentence in sentences:
            potential_chunk = current_chunk + sentence + ". "
            token_len = len(tokenizer.tokenize(potential_chunk))

            if token_len > max_tokens:
                chunks.append(current_chunk.strip())
                current_chunk = sentence + ". "
            else:
                current_chunk = potential_chunk

        if current_chunk:
            chunks.append(current_chunk.strip())
        return chunks
    
    def analyze_sentiment(self, article_id: str) -> dict:
        """
        Performs sentiment analysis on an article.
        
        The method retrieves the article content, splits it into chunks to respect the sentiment model's
        context length, performs sentiment analysis on each chunk, and aggregates the results using a
        majority vote for the label and an average of the confidence scores.
        
        Args:
            article_id (str): The unique identifier for the article to analyze.
            
        Returns:
            dict: A dictionary with keys "label" and "score", representing the overall sentiment.
        
        Raises:
            Exception: If the article has no content.
        """
        content = self.retrieve_article_content(article_id)
        if not content:
            raise Exception(f"Article {article_id} has no content for sentiment analysis.")
        
        chunks = self.chunk_text(text=content, task="sentiment")
        sentiments = [self.sentiment_analyzer(chunk)[0] for chunk in chunks]
        
        labels = [result["label"] for result in sentiments]
        avg_score = sum(result["score"] for result in sentiments) / len(sentiments)
        
        majority_label = Counter(labels).most_common(1)[0][0]
        return {"label": majority_label, "score": avg_score}
    
    def classify(self, article_id: str, candidate_labels: list = None) -> dict:
        """
        Classifies an article into one of the candidate labels using zero-shot classification.
        
        The method retrieves the article content, splits it into chunks to respect the classifier's
        context length, performs classification on each chunk, aggregates the scores for each candidate
        label across all chunks, and then normalizes the scores to determine the final classification.
        
        Args:
            article_id (str): The unique identifier for the article to classify.
            candidate_labels (list, optional): A list of candidate labels for classification.
                Defaults to ["economics", "sports", "entertainment", "politics", "technology", "culture", ""].
                
        Returns:
            dict: A dictionary containing the final label under the key "label" and the normalized scores
                  for each candidate label under the key "scores".
            
        Raises:
            Exception: If the article has no content.
        """
        if candidate_labels is None:
            candidate_labels = ["economics", "sports", "entertainment", "politics", "technology", "culture", ""]
    
        content = self.retrieve_article_content(article_id)
        if not content:
            raise Exception(f"Article {article_id} has no content to classify.")
        
        chunks = self.chunk_text(content, task="classification")
        
        aggregated_scores = {label: 0 for label in candidate_labels}
        for chunk in chunks:
            result = self.classifier(chunk, candidate_labels)
            for label, score in zip(result["labels"], result["scores"]):
                aggregated_scores[label] += score
        
        total = sum(aggregated_scores.values())
        normalized_scores = {label: score / total for label, score in aggregated_scores.items()}
        final_label = max(normalized_scores, key=normalized_scores.get)
        return {"label": final_label, "scores": normalized_scores}

In [36]:
nlp_processor = NLPProcessor()
test_article_id = "ec7f043c-7701-4843-aad3-023a283df5d9"
try:
    summary = nlp_processor.summarize(test_article_id)
    print("Summary:", summary)
    
    sentiment = nlp_processor.analyze_sentiment(test_article_id)
    print("Sentiment:", sentiment)
    
    classification = nlp_processor.classify(test_article_id)
    print("Classification:", classification)
except Exception as e:
    print("Error:", str(e))

Device set to use cpu
Device set to use cpu
Device set to use cpu
Token indices sequence length is longer than the specified maximum sequence length for this model (574 > 512). Running this sequence through the model will result in indexing errors


Summary: 6,000 employees were let go at the USDA in February as part of a government-wide purge. The USDA cuts are being felt especially in coastal states home to major shipping ports. Experts warn that the losses could cause food to go rotten while waiting in ports and could lead to even higher grocery prices, in addition to increasing the chances of potentially devastating invasive species getting into the country. Two federal judges and an independent agency that assesses government personnel have already ordered that fired USDA employees be reinstated. Fired USDA workers are still waiting to hear whether they will be reinstated. The Trump administration has signaled it will fight court decisions to reinstate employees. Customs and Border Protection deploys the dogs trained by Copeland and other staffers at the National Dog Detection Training Center. The two agencies run the Agricultural Quarantine Inspection program, but it’s funded by the USDA and not taxpayer dollars. The USDA sa

Token indices sequence length is longer than the specified maximum sequence length for this model (1075 > 1024). Running this sequence through the model will result in indexing errors


Sentiment: {'label': 'NEGATIVE', 'score': 0.9987107714017233}
Classification: {'label': 'economics', 'scores': {'economics': 0.29504014075041657, 'sports': 0.110580816005517, 'entertainment': 0.11551535631807783, 'politics': 0.08210918610448754, 'technology': 0.050974502472549824, 'culture': 0.20001639872887128, '': 0.14576359962007998}}
