In [33]:
!pip install polars requests tqdm



In [34]:
import os
import polars as pl
import requests
import json
from tqdm import tqdm
import time
import re
# Removed: # Removed: import google.generativeai as genai
from concurrent.futures import ThreadPoolExecutor

In [35]:
class KeyManager:
    """Manages multiple API keys and rotates them when rate limits are reached"""
    
    def __init__(self, env_prefix='MISTRAL_API_KEY'):
        """
        Initialize the key manager
        
        Parameters:
        -----------
        env_prefix : str
            Prefix for environment variables storing API keys.
            Keys should be named like MISTRAL_API_KEY, MISTRAL_API_KEY_1, MISTRAL_API_KEY_2, etc.
        """
        self.env_prefix = env_prefix
        self.api_keys = self._load_api_keys()
        self.current_index = 0
        self.rate_limited_keys = {}  # Track which keys hit rate limits and when they can be used again
        self.base_url = "https://api.mistral.ai/v1/chat/completions"
        self.headers = {
            "Content-Type": "application/json"
        }
        
        if not self.api_keys:
            print("⚠️ No API keys found. Please set at least one API key.")
            key = input(f"Enter your Mistral AI API key: ")
            if key:
                self.api_keys.append(key)
            else:
                raise ValueError("No API key provided. Cannot continue.")
        
        print(f"Loaded {len(self.api_keys)} Mistral API keys.")
        
    def _load_api_keys(self):
        """Load API keys from environment variables"""
        api_keys = []
        
        # Try the base key first
        base_key = os.getenv(self.env_prefix)
        if base_key:
            api_keys.append(base_key)
        
        # Try numbered keys (MISTRAL_API_KEY_1, MISTRAL_API_KEY_2, etc.)
        for i in range(1, 10):  # Check for up to 10 keys
            key = os.getenv(f"{self.env_prefix}_{i}")
            if key:
                api_keys.append(key)
        
        return api_keys
    
    def get_current_key(self):
        """Get the current active API key"""
        return self.api_keys[self.current_index]
    
    def get_current_headers(self):
        """Get headers with the current API key"""
        headers = self.headers.copy()
        headers["Authorization"] = f"Bearer {self.get_current_key()}"
        return headers
    
    def rotate_key(self, rate_limited=False, retry_after=60):
        """
        Rotate to the next available API key
        
        Parameters:
        -----------
        rate_limited : bool
            Whether the current key hit a rate limit
        retry_after : int
            Seconds until the rate-limited key can be used again
        
        Returns:
        --------
        str : The new API key
        """
        # Mark the current key as rate limited if needed
        if rate_limited:
            self.rate_limited_keys[self.current_index] = time.time() + retry_after
            print(f"API key {self.current_index + 1} rate limited. Will retry after {retry_after} seconds.")
        
        # Try to find a key that's not rate limited
        original_index = self.current_index
        while True:
            self.current_index = (self.current_index + 1) % len(self.api_keys)
            
            # Check if this key is rate limited
            if self.current_index in self.rate_limited_keys:
                # Check if enough time has passed
                if time.time() > self.rate_limited_keys[self.current_index]:
                    # Key is no longer rate limited
                    del self.rate_limited_keys[self.current_index]
                    break
            else:
                # Key is not rate limited
                break
                
            # If we've checked all keys and they're all rate limited, use the least recently rate limited one
            if self.current_index == original_index:
                # Find the key that will be available soonest
                soonest_available = min(self.rate_limited_keys.items(), key=lambda x: x[1])
                self.current_index = soonest_available[0]
                wait_time = max(0, self.rate_limited_keys[self.current_index] - time.time())
                
                if wait_time > 0:
                    print(f"All API keys are rate limited. Waiting {wait_time:.1f} seconds for the next available key.")
                    time.sleep(wait_time)
                    del self.rate_limited_keys[self.current_index]
                break
        
        # Return the new key
        key = self.api_keys[self.current_index]
        print(f"Switched to Mistral API key {self.current_index + 1}")
        return key

In [36]:
# Initialize the Key Manager
key_manager = KeyManager()

# Set the model name
MODEL = "mistral-small-latest"  # Using Mistral's large model instead of Gemini


Loaded 1 Mistral API keys.



> **⚠️ API Key Setup**
>
> This notebook uses the Mistral AI API with support for multiple API keys:
>
> 1. Set your primary API key as `MISTRAL_API_KEY` environment variable
> 2. For additional keys, use `MISTRAL_API_KEY_1`, `MISTRAL_API_KEY_2`, etc.
> 3. The system will automatically rotate between keys if rate limits are encountered
>
> Keys can be created at [Mistral AI Platform](https://console.mistral.ai/)
>
> **Mistral AI Free Tier Limits:**
> - 1 request per second (60 requests per minute)
> - 500,000 tokens per minute
> - 1 billion tokens per month


In [37]:
def setup_prompt(text):
    """Configure the prompt for Mistral"""
    return [
        {"role": "system", "content": """
            You are a financial sentiment analyzer. Classify the given tweet's sentiment into one of these categories:

            STRONGLY_POSITIVE - Very bullish, highly confident optimistic outlook
            POSITIVE - Generally optimistic, bullish view
            NEUTRAL - Factual, balanced, or no clear sentiment
            NEGATIVE - Generally pessimistic, bearish view
            STRONGLY_NEGATIVE - Very bearish, highly confident pessimistic outlook

            Examples:
            "Breaking: Company XYZ doubles profit forecast!" -> STRONGLY_POSITIVE
            "Expecting modest gains next quarter" -> POSITIVE
            "Market closed at 35,000" -> NEUTRAL
            "Concerned about rising rates" -> NEGATIVE
            "Crash incoming, sell everything!" -> STRONGLY_NEGATIVE

            Format: Return only one word from: STRONGLY_POSITIVE, POSITIVE, NEUTRAL, NEGATIVE, STRONGLY_NEGATIVE
        """},
        {"role": "user", "content": f"Analyze the sentiment of this tweet: {text}"}
    ]

In [38]:
def get_sentiment(text, retries=3):
    """Get sentiment from Gemini with retry logic"""
    if not text or len(str(text).strip()) < 3:
        return 'NEUTRAL'
    
    # Create a generative model object
    model = genai.GenerativeModel(MODEL, generation_config={
        "temperature": 0.0,  # Deterministic output
        "max_output_tokens": 10,  # We only need one word
        "top_p": 1.0
    })
    
    for attempt in range(retries):
        try:
            # Generate content with the model
            response = model.generate_content(setup_prompt(text))
            
            # Extract the sentiment from the response
            if hasattr(response, 'text'):
                sentiment = response.text.strip().upper()
                
                # Validate the response
                valid_labels = [
                    'STRONGLY_POSITIVE', 'POSITIVE', 'NEUTRAL', 'NEGATIVE', 'STRONGLY_NEGATIVE'
                ]
                
                if sentiment in valid_labels:
                    return sentiment
                else:
                    print(f"Invalid sentiment received: {sentiment}, defaulting to NEUTRAL")
                    return 'NEUTRAL'
            else:
                print(f"Unexpected response format: {response}")
                if attempt < retries - 1:
                    time.sleep(2)  # Wait before retry
                    continue
                else:
                    return 'NEUTRAL'
                    
        except Exception as e:
            error_str = str(e).lower()
            # Check for rate limiting errors
            if "quota" in error_str or "rate" in error_str or "429" in error_str:
                # Extract retry time if available (default to 60 seconds if not found)
                retry_after = 60
                if "retryafter" in error_str or "retry-after" in error_str or "retry_after" in error_str:
                    try:
                        # Try to extract the retry time
                        import re
                        matches = re.findall(r'retry.*?(\d+)', error_str)
                        if matches:
                            retry_after = int(matches[0])
                    except:
                        pass
                
                # Switch to another API key if there are multiple keys
                if len(key_manager.api_keys) > 1:
                    key_manager.rotate_key(rate_limited=True, retry_after=retry_after)
                    if attempt < retries - 1:
                        continue
                else:
                    # Only one key, just wait
                    wait_time = min(2 ** attempt * 5, retry_after)  # Exponential backoff with max retry_after
                    print(f"Rate limit hit - waiting {wait_time}s before retry ({attempt+1}/{retries})")
                    time.sleep(wait_time)
                    if attempt < retries - 1:
                        continue
                    
            if attempt == retries - 1:
                print(f"Error processing text: {str(text)[:50]}...\nError: {str(e)}")
                return 'NEUTRAL'
            time.sleep(2)  # Wait before retry
    
    return 'NEUTRAL'

In [39]:
# Test the sentiment analysis with key rotation
test_tweet = "Breaking: Tesla stock hits all-time high after unexpected profit surge"
sentiment = get_sentiment(test_tweet)
print(f"Test tweet: '{test_tweet}'")
print(f"Sentiment: {sentiment}")
print(f"Using Mistral API key index: {key_manager.current_index}")


NameError: name 'genai' is not defined

In [None]:
# Load data from Hugging Face
print("Loading stock market tweets dataset from Hugging Face...")

# Check if the huggingface datasets library is installed
try:
    import huggingface_hub
except ImportError:
    print("Installing huggingface_hub...")
    !pip install huggingface_hub
    import huggingface_hub

# Load the dataset using Polars
df = pl.read_csv('hf://datasets/mjw/stock_market_tweets/stock_market_tweets.csv')

print(f"Loaded {df.shape[0]} tweets")
print("\nSample tweets:")
df.head(5)

In [None]:
# Prepare the dataset for sentiment analysis
# Let's make sure we have the 'body' column which contains the tweet text
if 'body' in df.columns:
    tweet_column = 'body'
elif 'full_text' in df.columns:
    tweet_column = 'full_text'
else:
    raise ValueError("Could not find tweet text column in the dataset")

print(f"Using '{tweet_column}' column for tweet text")

# For demonstration, let's use a small subset of the data
# Use all data instead of a sample - WARNING: This will process 1.7M tweets!
sample_size = df.shape[0]  # Adjust based on your needs
sample_df = df.sample(sample_size, seed=42)

print(f"\nAnalyzing sentiment for {sample_size} tweets")


## Rate Limits for Mistral AI API

Mistral AI's free tier has the following limits:
- 1 request per second (60 requests per minute)
- 500,000 tokens per minute
- 1 billion tokens per month

This notebook implements:
1. Key rotation to handle multiple API keys
2. Automatic retry with exponential backoff
3. Batch processing to optimize throughput
4. Error handling to ensure robust processing

If you need to process many tweets, consider:
- Creating multiple API keys
- Adjusting batch size and workers based on your needs
- Processing tweets in smaller batches with appropriate delays


In [None]:
def process_tweets(tweets, batch_size=4, max_workers=2):
    """Process tweets in batches with Google Gemini API
    
    Google Gemini API free tier allows:
    - Up to 1,500 requests per day
    - Higher throughput than OpenRouter's free tier
    """
    all_sentiments = []
    
    # For demonstration, we'll process a reasonable number of tweets
    # Adjust max_tweets if needed - 200 is safe for the free tier
    max_tweets = min(200, len(tweets))
    tweets = tweets[:max_tweets]
    
    print(f"Processing {max_tweets} tweets using Google's Gemini API with {len(key_manager.api_keys)} API keys")
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in tqdm(range(0, len(tweets), batch_size), desc="Processing tweet batches"):
            batch = tweets[i:i+batch_size]
            
            try:
                # Process tweets in parallel
                results = list(executor.map(get_sentiment, batch))
                all_sentiments.extend(results)
                
                # Add a short delay between batches to avoid rate limiting
                time.sleep(2)
                
            except Exception as e:
                print(f"Error processing batch: {e}")
                # Add neutral sentiments for this batch in case of failure
                all_sentiments.extend(['NEUTRAL'] * len(batch))
                time.sleep(5)  # Wait a bit longer after an error
                
    # If we didn't get enough sentiments (due to errors), fill with NEUTRAL
    if len(all_sentiments) < len(tweets):
        all_sentiments.extend(['NEUTRAL'] * (len(tweets) - len(all_sentiments)))
            
    return all_sentiments

In [None]:

# Test the sentiment analysis with key rotation
test_tweet = "Breaking: Tesla stock hits all-time high after unexpected profit surge"
sentiment = get_sentiment(test_tweet)
print(f"Test tweet: '{test_tweet}'")
print(f"Sentiment: {sentiment}")
print(f"Using API key index: {key_manager.current_index}")


In [10]:
def process_tweets(tweets, batch_size=4, max_workers=2):
    """Process tweets in batches with Mistral AI API
    
    Mistral AI free tier allows:
    - 1 request per second (60 requests per minute)
    - 500,000 tokens per minute
    - 1 billion tokens per month
    """
    all_sentiments = []
    
    # Process all tweets
    # Removed limitation: max_tweets = min(200, len(tweets))
    # Removed limitation: tweets = tweets[:max_tweets]
    
    print(f"Processing {len(tweets)} tweets using Mistral AI API with {len(key_manager.api_keys)} API keys")
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for i in tqdm(range(0, len(tweets), batch_size), desc="Processing tweet batches"):
            batch = tweets[i:i+batch_size]
            
            try:
                # Process tweets in parallel
                results = list(executor.map(get_sentiment, batch))
                all_sentiments.extend(results)
                
                # Add a delay between batches to respect Mistral's 1 req/sec rate limit
                # With batch_size and max_workers, adjust sleep time accordingly
                time.sleep(batch_size * max_workers)  # Conservative approach
                
            except Exception as e:
                print(f"Error processing batch: {e}")
                # Add neutral sentiments for this batch in case of failure
                all_sentiments.extend(['NEUTRAL'] * len(batch))
                time.sleep(5)  # Wait a bit longer after an error
                
    # If we didn't get enough sentiments (due to errors), fill with NEUTRAL
    if len(all_sentiments) < len(tweets):
        all_sentiments.extend(['NEUTRAL'] * (len(tweets) - len(all_sentiments)))
            
    return all_sentiments

In [None]:
# Process a subset of tweets
tweets = sample_df[tweet_column].to_list()
sentiments = process_tweets(tweets, batch_size=4, max_workers=8)

# Add sentiments to the DataFrame
sample_df = sample_df.with_columns(pl.Series(name='sentiment', values=sentiments))

# Display the results
sample_df.select([tweet_column, 'sentiment']).head(10)

In [None]:
# Save the labeled data
output_path = "../data/labeled_stock_tweets.csv"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
sample_df.write_csv(output_path)
print(f"\nSaved labeled data to {output_path}")

# Display some examples of each sentiment
print("\nExamples for each sentiment:")
for sentiment in ['STRONGLY_POSITIVE', 'POSITIVE', 'NEUTRAL', 'NEGATIVE', 'STRONGLY_NEGATIVE']:
    examples = sample_df.filter(pl.col("sentiment") == sentiment).sample(1, seed=42)
    if examples.shape[0] > 0:
        print(f"\n{sentiment}:\n{examples[0, tweet_column]}")