In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')


In [None]:
class FinbertBackbone(nn.Module):
    """Pretrained transformer encoder that yields one feature vector per sequence.

    The encoder is created with ``AutoModel.from_pretrained`` and the hidden
    state at sequence position 0 is returned as the sequence representation.
    """

    def __init__(self, modelName: str = "ProsusAI/finbert"):
        """Load the encoder named ``modelName`` and record its hidden width."""
        super().__init__()
        # The attribute name doubles as the state-dict key prefix ("encoder.*").
        self.encoder = AutoModel.from_pretrained(modelName)
        # Width of the feature vector produced by forward(); original note: 768.
        self.hiddenSize = self.encoder.config.hidden_size

    def forward(self, input_ids, attention_mask):
        """Encode a batch and return ``last_hidden_state[:, 0]``.

        Both arguments are passed to the encoder unchanged, as keyword
        arguments, together with ``return_dict=True``.
        """
        encoded = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )
        hidden_states = encoded.last_hidden_state
        return hidden_states[:, 0]  # [CLS] position -> [batch, hidden]
    
class BinaryHead(nn.Module):
    """Dropout followed by a linear layer that emits one logit per example."""

    def __init__(self, inFeatures: int, pDrop: float = 0.1):
        """Build the head.

        Args:
            inFeatures: Size of the last dimension of the incoming features.
            pDrop: Dropout probability applied before the linear layer.
        """
        super().__init__()
        self.dropout = nn.Dropout(pDrop)
        self.fc = nn.Linear(inFeatures, 1)  # single logit

    def forward(self, x):
        """Return the logits with the trailing size-1 dimension squeezed away."""
        return self.fc(self.dropout(x)).squeeze(-1)  # [batch]

class FinbertBinaryClf(nn.Module):
    """Binary classifier: a ``FinbertBackbone`` feeding a ``BinaryHead``."""

    def __init__(self, modelName: str = "ProsusAI/finbert", pDrop: float = 0.1):
        """Create the backbone for ``modelName`` and a head sized to its output.

        Args:
            modelName: Checkpoint name handed to ``FinbertBackbone``.
            pDrop: Dropout probability handed to ``BinaryHead``.
        """
        super().__init__()
        self.backbone = FinbertBackbone(modelName)
        self.head = BinaryHead(self.backbone.hiddenSize, pDrop)

    def forward(self, input_ids, attention_mask):
        """Return the head's logits for the backbone features of the batch."""
        return self.head(self.backbone(input_ids, attention_mask))


In [25]:
def clean_text(text):
    """Normalize raw text before it is tokenized for sentiment analysis.

    The input is coerced with ``str`` and then, in this order: tokens starting
    with ``http`` are removed, ``@mentions`` are removed, a leading ``user:``
    and then a leading ``user`` (both case-insensitive) are removed,
    double-quote characters are removed, and every whitespace run is collapsed
    to a single space with both ends trimmed.

    Note that the leading ``user`` rule has no word boundary, so it also trims
    the first four letters of a word such as "username".
    """
    # (pattern, replacement, flags); the order of application matters.
    substitutions = (
        (r"http\S+", "", 0),
        (r"@\w+", "", 0),
        (r"^user:\s*", "", re.IGNORECASE),
        (r"^user\s*", "", re.IGNORECASE),
        (r"[\"]+", "", 0),
        (r"\s+", " ", 0),
    )

    cleaned = str(text)
    for pattern, replacement, flags in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned, flags=flags)
    return cleaned.strip()

def lables_zero_one(y: int) -> int:
    """Map a sentiment label to {0, 1}: 1 when ``int(y)`` equals 1, otherwise 0."""
    return int(int(y) == 1)


## 4. Load Models and Tokenizers


In [None]:
# Set device
# Every later cell reads the globals created here: device, stock_model,
# stock_tokenizer, stock_tickers, sentiment_model and sentiment_tokenizer.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

print("Loading stock classification model...")
stock_model_path = './results/checkpoint-5017/checkpoint-5017'  # Fixed path to nested directory
# NOTE(review): the tokenizer comes from the hub name, not from stock_model_path —
# confirm it is the tokenizer the checkpoint was fine-tuned with.
stock_tokenizer = AutoTokenizer.from_pretrained('vinai/bertweet-base')
stock_model = AutoModelForSequenceClassification.from_pretrained(stock_model_path)
stock_model.to(device)
stock_model.eval()

# predict_stocks maps output position i to stock_tickers[i], so this list must be
# in the same order as the classifier's labels.
# NOTE(review): confirm the order against the label encoding used in training.
stock_tickers = ['AAPL', 'AMD', 'AMZN', 'BA', 'COST', 'DIS', 'GOOG', 'KO', 'META', 'MSFT', 'NFLX', 'NIO', 'Other', 'PG', 'PYPL', 'TSLA']
print(f"Stock tickers: {stock_tickers}")

# Load sentiment analysis model
print("Loading sentiment analysis model...")
sentiment_model = FinbertBinaryClf("ProsusAI/finbert", pDrop=0.1)
sentiment_tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; consider weights_only=True if the installed torch supports it
# and the checkpoint holds nothing but tensors and plain containers.
checkpoint = torch.load('finbert_finetuned.pt', map_location=device)
# The fine-tuned weights are stored under the 'model_state_dict' key.
sentiment_model.load_state_dict(checkpoint['model_state_dict'])
sentiment_model.to(device)
sentiment_model.eval()

print("Both models loaded successfully!")


Using device: cpu
Loading stock classification model...
Stock tickers: ['AAPL', 'AMD', 'AMZN', 'BA', 'COST', 'DIS', 'GOOG', 'KO', 'META', 'MSFT', 'NFLX', 'NIO', 'Other', 'PG', 'PYPL', 'TSLA']
Loading sentiment analysis model...
✓ Both models loaded successfully!


In [None]:
@torch.no_grad()
def predict_stocks(tweet_text, model, tokenizer, threshold=0.5, tickers=None):
    """Predict which stock tickers a single tweet refers to (multi-label).

    Args:
        tweet_text: Text of one tweet.
        model: Classifier called as ``model(**inputs)`` whose result exposes
            ``.logits``; only the first row of the logits is used. The
            function does not switch the model to eval mode itself.
        tokenizer: Callable invoked with ``return_tensors='pt'``,
            ``truncation=True`` and ``max_length=128``; must return a mapping
            of tensors.
        threshold: A ticker is reported when its sigmoid probability is
            strictly greater than this value.
        tickers: Label names in the model's output order. Defaults to the
            notebook-level ``stock_tickers`` list, so existing callers are
            unaffected.

    Returns:
        ``(predicted_stocks, probs)`` where ``predicted_stocks`` is the list
        of tickers above the threshold and ``probs`` is a NumPy array holding
        one probability per ticker.

    Raises:
        ValueError: If the number of model outputs differs from the number of
            tickers (previously this mis-mapped labels or raised IndexError).
    """
    if tickers is None:
        tickers = stock_tickers

    device = next(model.parameters()).device

    inputs = tokenizer(tweet_text, return_tensors='pt', truncation=True, max_length=128)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits)[0].cpu()

    if probs.numel() != len(tickers):
        raise ValueError(
            f"Model produced {probs.numel()} outputs but {len(tickers)} tickers were given"
        )

    # Compare on the tensor so the threshold is applied in the model's dtype.
    above_threshold = (probs > threshold).tolist()
    predicted_stocks = [ticker for ticker, hit in zip(tickers, above_threshold) if hit]

    return predicted_stocks, probs.numpy()

@torch.no_grad()
def predict_sentiment(text, model, tokenizer, max_len=128, threshold=0.5, confidence_threshold=0.7):
    """Classify one text as positive (1), negative (0) or unsure (-1).

    The text is passed through ``clean_text``, tokenized with padding to
    ``max_len``, and the sigmoid of the model's first output is taken as the
    probability of the positive class. The model is switched to eval mode.

    Args:
        text: Raw input text.
        model: Called positionally as ``model(input_ids, attention_mask)``.
        tokenizer: Callable returning ``input_ids`` and ``attention_mask``.
        max_len: Length used for truncation and ``max_length`` padding.
        threshold: Probabilities at or above this value count as positive.
        confidence_threshold: When the probability of the chosen class is
            below this value the label becomes -1.

    Returns:
        ``(label, prob)`` where ``prob`` is always the positive-class
        probability, regardless of the label.
    """
    device = next(model.parameters()).device
    model.eval()

    encoded = tokenizer(
        clean_text(text),
        truncation=True,
        padding="max_length",
        max_length=max_len,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    prob = torch.sigmoid(model(input_ids, attention_mask)).cpu().numpy()[0]

    is_positive = prob >= threshold
    # Confidence is the probability mass on whichever side was chosen.
    confidence = prob if is_positive else 1 - prob

    if confidence < confidence_threshold:
        return -1, prob  # -1 represents "unsure"
    return (1 if is_positive else 0), prob


In [None]:
def analyze_financial_text(text, stock_threshold=0.7, sentiment_threshold=0.5, confidence_threshold=0.7):
    """Run both notebook models on one text and merge their outputs.

    Uses the notebook-level ``stock_model`` / ``stock_tokenizer`` and
    ``sentiment_model`` / ``sentiment_tokenizer``.

    Args:
        text: Raw text to analyze.
        stock_threshold: Probability a ticker must exceed to appear in
            ``'stocks'``. The default of 0.7 is stricter than the 0.5 used by
            other cells, so ``'stock_probabilities'`` can show tickers above
            0.5 that are absent from ``'stocks'``.
        sentiment_threshold: Forwarded to ``predict_sentiment`` as
            ``threshold``.
        confidence_threshold: Forwarded to ``predict_sentiment``; below it the
            sentiment is reported as ``'unsure'``.

    Returns:
        dict with keys ``'stocks'`` (list of tickers), ``'sentiment'``
        (``'positive'``, ``'negative'`` or ``'unsure'``),
        ``'sentiment_confidence'`` (float) and ``'stock_probabilities'``
        (ticker -> float, keyed by ``stock_tickers``).
    """
    predicted_stocks, stock_probs = predict_stocks(
        text, stock_model, stock_tokenizer, threshold=stock_threshold
    )

    sentiment_label, sentiment_prob = predict_sentiment(
        text,
        sentiment_model,
        sentiment_tokenizer,
        threshold=sentiment_threshold,
        confidence_threshold=confidence_threshold,
    )

    # sentiment_prob is always the positive-class probability; the reported
    # confidence is the probability of whichever class is being shown.
    if sentiment_label == 1:
        sentiment = 'positive'
        sentiment_confidence = sentiment_prob
    elif sentiment_label == 0:
        sentiment = 'negative'
        sentiment_confidence = 1 - sentiment_prob
    else:
        sentiment = 'unsure'
        sentiment_confidence = max(sentiment_prob, 1 - sentiment_prob)

    stock_probabilities = {ticker: float(prob) for ticker, prob in zip(stock_tickers, stock_probs)}

    return {
        'stocks': predicted_stocks,
        'sentiment': sentiment,
        'sentiment_confidence': float(sentiment_confidence),
        'stock_probabilities': stock_probabilities
    }


In [None]:
# Sample inputs used to smoke-test the combined stock + sentiment pipeline
test_texts = [
    "Tesla and Apple crushing it! $TSLA $AAPL 🚀",
    "Microsoft stock is plummeting, terrible earnings report",
    "Amazon and Google both showing strong growth this quarter",
    "Netflix subscription numbers are declining rapidly",
    "Meta's new VR headset is revolutionary! $META",
    "Disney's streaming service is struggling with competition",
    "NVIDIA and AMD GPUs are in high demand for AI workloads",
    "PayPal's new features are disappointing users"
]

print("=" * 80, "FINANCIAL TEXT ANALYSIS RESULTS", "=" * 80, sep="\n")

for i, text in enumerate(test_texts, 1):
    print(f"\nExample {i}:")
    print(f"Text: {text}")

    result = analyze_financial_text(text)

    print(f"Predicted Stocks: {result['stocks']}")
    print(f"Sentiment: {result['sentiment'].upper()} (confidence: {result['sentiment_confidence']:.3f})")

    # Tickers whose probability exceeds 0.5; this cut-off is looser than the
    # one behind 'stocks', so the two lists can differ. Only names are shown.
    top_stocks = [stock for stock, prob in result['stock_probabilities'].items() if prob > 0.5] or ["Unsure"]
    print(f"Top Stock Probabilities: {top_stocks}")
    print("-" * 60)


FINANCIAL TEXT ANALYSIS RESULTS

Example 1:
Text: Tesla and Apple crushing it! $TSLA $AAPL 🚀
Predicted Stocks: ['AAPL']
Sentiment: UNSURE (confidence: 0.661)
Top Stock Probabilities: ['AAPL', 'TSLA']
------------------------------------------------------------

Example 2:
Text: Microsoft stock is plummeting, terrible earnings report
Predicted Stocks: []
Sentiment: NEGATIVE (confidence: 0.920)
Top Stock Probabilities: ['Unsure']
------------------------------------------------------------

Example 3:
Text: Amazon and Google both showing strong growth this quarter
Predicted Stocks: ['AMZN']
Sentiment: POSITIVE (confidence: 0.940)
Top Stock Probabilities: ['AMZN']
------------------------------------------------------------

Example 4:
Text: Netflix subscription numbers are declining rapidly
Predicted Stocks: ['NFLX']
Sentiment: NEGATIVE (confidence: 0.914)
Top Stock Probabilities: ['NFLX']
------------------------------------------------------------

Example 5:
Text: Meta's new VR headse

In [None]:
def interactive_analysis():
    """Prompt for text in a loop and print the analysis of each entry.

    Typing ``quit``, ``exit`` or ``q`` (any case) ends the loop; blank input
    re-prompts. Any exception raised while analyzing or printing a result is
    reported on one line and the loop continues.
    """
    exit_commands = ('quit', 'exit', 'q')

    while True:
        text = input("\nEnter text: ").strip()

        if text.lower() in exit_commands:
            print("Goodbye!")
            return

        if text == "":
            print("Please enter some text.")
            continue

        try:
            result = analyze_financial_text(text)

            print("\nANALYSIS RESULTS:")
            detected = result['stocks']
            stocks_line = ', '.join(detected) if detected else 'None detected'
            print(f"Stocks: {stocks_line}")
            print(f"Sentiment: {result['sentiment'].upper()} ({result['sentiment_confidence']:.1%} confidence)")

            # Three most probable tickers, highest first.
            ranked = sorted(
                result['stock_probabilities'].items(),
                key=lambda pair: pair[1],
                reverse=True,
            )
            print("Top Stock Probabilities:")
            for stock, prob in ranked[:3]:
                print(f"   {stock}: {prob:.1%}")

        except Exception as e:
            print(f"Error analyzing text: {e}")

# interactive_analysis()


In [None]:
def batch_analyze(texts):
    """Analyze each text in ``texts`` and return one compact record per text.

    Each record holds the original ``'text'`` plus the ``'stocks'``,
    ``'sentiment'`` and ``'sentiment_confidence'`` fields produced by
    ``analyze_financial_text``; the per-ticker probabilities are dropped.
    """
    def _summarize(text):
        analysis = analyze_financial_text(text)
        return {
            'text': text,
            'stocks': analysis['stocks'],
            'sentiment': analysis['sentiment'],
            'sentiment_confidence': analysis['sentiment_confidence']
        }

    return [_summarize(text) for text in texts]

print("\n" + "=" * 60, "BATCH ANALYSIS EXAMPLE", "=" * 60, sep="\n")

# Three sample headlines pushed through batch_analyze in one call
batch_texts = [
    "Apple's new iPhone sales are breaking records!",
    "Tesla stock is crashing after the recall news",
    "Microsoft Azure cloud revenue is growing rapidly"
]

batch_results = batch_analyze(batch_texts)

for i, result in enumerate(batch_results, 1):
    print(
        f"\nText {i}: {result['text']}",
        f"  Stocks: {result['stocks']}",
        f"  Sentiment: {result['sentiment']} ({result['sentiment_confidence']:.1%})",
        sep="\n",
    )



BATCH ANALYSIS EXAMPLE

Text 1: Apple's new iPhone sales are breaking records!
  Stocks: ['AAPL']
  Sentiment: positive (95.4%)

Text 2: Tesla stock is crashing after the recall news
  Stocks: ['TSLA']
  Sentiment: negative (88.9%)

Text 3: Microsoft Azure cloud revenue is growing rapidly
  Stocks: []
  Sentiment: positive (93.6%)


In [None]:
import pandas as pd
from datetime import datetime
import numpy as np
from collections import defaultdict, Counter

def _overall_sentiment(positive_count, negative_count, total_tweets):
    """Return the majority label (ties are 'neutral') and its share of all tweets."""
    if positive_count > negative_count:
        return 'positive', positive_count / total_tweets
    if negative_count > positive_count:
        return 'negative', negative_count / total_tweets
    return 'neutral', 0.5


def _price_direction(daily_return):
    """Return 'up', 'down' or 'flat' from the sign of ``daily_return``."""
    if daily_return > 0:
        return 'up'
    if daily_return < 0:
        return 'down'
    return 'flat'


def analyze_sentiment_stock_correlation(target_date, combined_data_path='data/processed/filtered_tweets_with_stock_data.csv'):
    """Compare per-stock tweet sentiment with the stock's price move on one date.

    Rows of the CSV whose ``date_only`` equals ``target_date`` are run through
    ``analyze_financial_text``. A row counts toward its ``Stock Name`` only
    when the stock classifier also predicts that ticker for the tweet. For
    each such stock the majority sentiment is compared with the sign of
    ``daily_return``. The "correlation rate" is the percentage of stocks where
    the two agree; it is not a statistical correlation coefficient.

    A stock whose ``daily_return`` is missing (NaN) is reported as having no
    price data instead of being treated as a flat day.

    Args:
        target_date: Date to analyze; anything ``pd.to_datetime`` accepts.
        combined_data_path: CSV with the columns ``date_only``, ``Tweet``,
            ``Stock Name``, ``daily_return``, ``Open``, ``Close``, ``High``,
            ``Low`` and ``Volume``.

    Returns:
        ``{"error": ...}`` when no rows exist for the date. Otherwise a dict
        with ``'date'``, ``'total_records_analyzed'``, ``'stocks_analyzed'``,
        ``'price_correlations'`` and ``'correlation_summary'``.
    """
    print("=" * 80)

    # Load combined dataset
    print("Loading combined dataset")
    df = pd.read_csv(combined_data_path)

    df['date_only'] = pd.to_datetime(df['date_only']).dt.date
    target_date_obj = pd.to_datetime(target_date).date()

    day_data = df[df['date_only'] == target_date_obj]

    print(f"Found {len(day_data)} tweet-stock records for {target_date}")

    if len(day_data) == 0:
        return {"error": f"No data found for {target_date}"}

    print(" Analyzing tweets with AI models...")

    stock_analysis = defaultdict(lambda: {
        'tweets': [],
        'sentiments': [],
        'sentiment_confidences': [],
        'positive_count': 0,
        'negative_count': 0,
        'unsure_count': 0,
        'total_tweets': 0,
        'daily_return': None,
        'price_data': None
    })

    # Which counter each sentiment label feeds; anything else is "unsure".
    counter_for = {'positive': 'positive_count', 'negative': 'negative_count'}

    for _, row in day_data.iterrows():
        tweet_text = row['Tweet']
        stock_name = row['Stock Name']
        daily_return = row['daily_return']

        # Price data for this stock, taken from the same row as the tweet
        price_data = {
            'open': row['Open'],
            'close': row['Close'],
            'high': row['High'],
            'low': row['Low'],
            'volume': row['Volume'],
            'daily_return': daily_return
        }

        try:
            result = analyze_financial_text(tweet_text)

            # Count the tweet only when the classifier agrees it is about
            # the stock this row is tagged with.
            if stock_name not in result['stocks']:
                continue

            entry = stock_analysis[stock_name]
            entry['tweets'].append({
                'text': tweet_text,
                'sentiment': result['sentiment'],
                'confidence': result['sentiment_confidence'],
                'mentioned_stocks': result['stocks']
            })
            entry['sentiments'].append(result['sentiment'])
            entry['sentiment_confidences'].append(result['sentiment_confidence'])
            entry['total_tweets'] += 1
            # The last matching row wins; rows of one stock/date are expected
            # to carry the same price fields.
            entry['daily_return'] = daily_return
            entry['price_data'] = price_data
            entry[counter_for.get(result['sentiment'], 'unsure_count')] += 1

        except Exception as e:
            # Best effort: one bad tweet must not abort the whole day.
            print(f"Error analyzing tweet: {e}")
            continue

    print("\nSENTIMENT ANALYSIS RESULTS")
    print("=" * 80)

    results = {
        'date': target_date,
        'total_records_analyzed': len(day_data),
        'stocks_analyzed': {},
        'price_correlations': {}
    }

    for stock, analysis in stock_analysis.items():
        total_tweets = analysis['total_tweets']
        if total_tweets == 0:
            continue

        print(f"\n {stock} Analysis:")
        print(f"    Tweets mentioning {stock}: {total_tweets}")
        for label, key in (('Positive', 'positive_count'),
                           ('Negative', 'negative_count'),
                           ('Unsure', 'unsure_count')):
            print(f"    {label}: {analysis[key]} ({analysis[key]/total_tweets*100:.1f}%)")

        overall_sentiment, sentiment_strength = _overall_sentiment(
            analysis['positive_count'], analysis['negative_count'], total_tweets
        )
        print(f"   Overall Sentiment: {overall_sentiment.upper()} (strength: {sentiment_strength:.2f})")

        # Fields shared by stocks with and without usable price data.
        stock_summary = {
            'tweet_count': total_tweets,
            'sentiment_breakdown': {
                'positive': analysis['positive_count'],
                'negative': analysis['negative_count'],
                'unsure': analysis['unsure_count']
            },
            'overall_sentiment': overall_sentiment,
            'sentiment_strength': sentiment_strength,
            'price_data': None,
            'correlation_match': None
        }

        price_data = analysis['price_data']
        daily_return = analysis['daily_return']

        # A NaN return has no direction; treating it as "flat" would let it
        # count as a match with neutral sentiment.
        if price_data is not None and not pd.isna(daily_return):
            daily_return_pct = daily_return * 100

            print(f"    Stock Price: ${price_data['open']:.2f} → ${price_data['close']:.2f}")
            print(f"    Daily Return: {daily_return_pct:+.2f}%")

            price_direction = _price_direction(daily_return)

            sentiment_price_match = (
                (overall_sentiment == 'positive' and price_direction == 'up') or
                (overall_sentiment == 'negative' and price_direction == 'down') or
                (overall_sentiment == 'neutral' and price_direction == 'flat')
            )

            correlation_status = " MATCH" if sentiment_price_match else "NO MATCH"
            print(f"   Correlation: {correlation_status}")

            volume = price_data['volume']
            stock_summary['price_data'] = {
                'open': float(price_data['open']),
                'close': float(price_data['close']),
                'high': float(price_data['high']),
                'low': float(price_data['low']),
                # int(NaN) raises, so a missing volume is reported as None.
                'volume': None if pd.isna(volume) else int(volume),
                'daily_return': float(daily_return),
                'daily_return_pct': float(daily_return_pct),
                'direction': price_direction
            }
            stock_summary['correlation_match'] = sentiment_price_match

            results['price_correlations'][stock] = {
                'sentiment': overall_sentiment,
                'price_direction': price_direction,
                'match': sentiment_price_match
            }
        else:
            print(f"   No price data found for {stock}")

        results['stocks_analyzed'][stock] = stock_summary

    print(f"\nCORRELATION SUMMARY")
    print("=" * 80)

    total_stocks = sum(
        1 for summary in results['stocks_analyzed'].values()
        if summary['price_data'] is not None
    )
    matches = sum(1 for s in results['price_correlations'].values() if s['match'])
    correlation_rate = matches / total_stocks * 100 if total_stocks > 0 else 0

    if total_stocks > 0:
        print(f" Overall Correlation Rate: {correlation_rate:.1f}% ({matches}/{total_stocks} stocks)")

        if correlation_rate >= 70:
            print(" Strong correlation between sentiment and stock prices!")
        elif correlation_rate >= 50:
            print("Moderate correlation between sentiment and stock prices")
        else:
            print("Weak correlation between sentiment and stock prices")
    else:
        print("No stock price data available for correlation analysis")

    results['correlation_summary'] = {
        'total_stocks_with_prices': total_stocks,
        'matches': matches,
        'correlation_rate': correlation_rate
    }

    return results


In [None]:
def analyze_multiple_dates(date_list, combined_data_path='data/processed/filtered_tweets_with_stock_data.csv'):
    """Run ``analyze_sentiment_stock_correlation`` for every date and summarize.

    Args:
        date_list: Dates to analyze, in the order they should be processed.
        combined_data_path: CSV path forwarded to the per-date analysis.

    Returns:
        dict mapping each date to its per-date result. Dates without data map
        to the ``{"error": ...}`` dict and are left out of the printed average.
    """
    print(f" Analyzing {len(date_list)} dates for sentiment-stock correlation")
    print("=" * 80)

    per_date_results = []
    for date in date_list:
        print(f"\n Analyzing {date}.")
        per_date_results.append(analyze_sentiment_stock_correlation(date, combined_data_path))

    all_results = dict(zip(date_list, per_date_results))

    # Only runs that produced a summary contribute a rate.
    total_correlations = [
        outcome['correlation_summary']['correlation_rate']
        for outcome in per_date_results
        if 'correlation_summary' in outcome
    ]

    if not total_correlations:
        return all_results

    # Calculate overall statistics
    avg_correlation = np.mean(total_correlations)
    print(f"\nOVERALL ANALYSIS SUMMARY")
    print("=" * 80)
    print(f"Dates analyzed: {len(date_list)}")
    print(f"Average correlation rate: {avg_correlation:.1f}%")
    print(f"Correlation range: {min(total_correlations):.1f}% - {max(total_correlations):.1f}%")

    if avg_correlation >= 70:
        verdict = "Strong overall correlation between sentiment and stock prices!"
    elif avg_correlation >= 50:
        verdict = "Moderate overall correlation between sentiment and stock prices"
    else:
        verdict = "Weak overall correlation between sentiment and stock prices"
    print(verdict)

    return all_results

def get_available_dates(combined_data_path='data/processed/filtered_tweets_with_stock_data.csv'):
    """Summarize which dates the combined tweet/stock CSV covers.

    Args:
        combined_data_path: CSV with at least the columns ``date_only`` and
            ``Stock Name``.

    Returns:
        dict with ``'available_dates'`` (sorted ``YYYY-MM-DD`` strings),
        ``'total_dates'``, ``'total_records'`` (row count),
        ``'unique_stocks'`` and ``'date_range'`` (``'start'`` / ``'end'``).
        For a dataset without any usable date the list is empty and both
        range entries are ``None``, so callers can test the list instead of
        getting an ``IndexError``.
    """
    df = pd.read_csv(combined_data_path)

    # Missing dates are dropped before sorting; they cannot be ordered against
    # real dates and have no string form.
    date_values = pd.to_datetime(df['date_only']).dt.date.dropna()
    available_dates = sorted(date_values.unique())
    date_strings = [d.strftime('%Y-%m-%d') for d in available_dates]

    return {
        'available_dates': date_strings,
        'total_dates': len(available_dates),
        'total_records': len(df),
        'unique_stocks': df['Stock Name'].nunique(),
        'date_range': {
            'start': date_strings[0] if date_strings else None,
            'end': date_strings[-1] if date_strings else None
        }
    }


In [None]:
# Summarize the date coverage of the combined dataset; later cells reuse
# the `available_dates` dict created here.
print(" CHECKING AVAILABLE DATES", "=" * 60, sep="\n")

available_dates = get_available_dates()
print(f" Total dates: {available_dates['total_dates']}")
print(f"Total records: {available_dates['total_records']:,}")
print(f"Unique stocks: {available_dates['unique_stocks']}")
print(f"Date range: {available_dates['date_range']['start']} to {available_dates['date_range']['end']}")

# Show at most five example dates, then how many were left out
if not available_dates['available_dates']:
    print("No dates found in the combined dataset")
else:
    print("\nExample available dates:")
    for date in available_dates['available_dates'][:5]:
        print(f"   {date}")
    if len(available_dates['available_dates']) > 5:
        print(f"   ... and {len(available_dates['available_dates']) - 5} more dates")


In [None]:
# Example: run the single-date analysis on the earliest available date
if not available_dates['available_dates']:
    print("No dates available for analysis")
else:
    example_date = available_dates['available_dates'][0]
    print(f"\nEXAMPLE ANALYSIS FOR {example_date}", "=" * 80, sep="\n")

    result = analyze_sentiment_stock_correlation(example_date)

    if 'error' in result:
        print(f"Error: {result['error']}")
    else:
        print(f"\nSUMMARY FOR {example_date}:")
        print(f" Total records analyzed: {result['total_records_analyzed']}")
        print(f"Stocks analyzed: {len(result['stocks_analyzed'])}")
        if 'correlation_summary' in result:
            print(f"Correlation rate: {result['correlation_summary']['correlation_rate']:.1f}%")


In [None]:
# Multi-date example: needs at least three dates in the dataset
if len(available_dates['available_dates']) < 3:
    print("Not enough dates available for multi-date analysis")
else:
    print("\nMULTIPLE DATE ANALYSIS EXAMPLE", "=" * 80, sep="\n")

    dates_to_analyze = available_dates['available_dates'][:3]
    print(f"Analyzing dates: {dates_to_analyze}")

    multi_results = analyze_multiple_dates(dates_to_analyze)

    print("\nMULTI-DATE SUMMARY:")
    for date, result in multi_results.items():
        # Dates without rows come back as an error dict with no summary.
        print(
            f"   {date}: {result['correlation_summary']['correlation_rate']:.1f}% correlation"
            if 'correlation_summary' in result
            else f"   {date}: No data available"
        )


In [None]:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from datetime import datetime, timedelta

def _format_date_axis(ax):
    """Put one YYYY-MM-DD tick per day on the x axis, rotated 45 degrees."""
    ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=1))
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)


def _describe_correlation(avg_correlation):
    """Return the interpretation line for an average correlation (NaN-aware)."""
    # NaN fails every comparison below and would otherwise fall through to
    # the "strong negative" message.
    if np.isnan(avg_correlation):
        return "Correlation undefined: sentiment or returns were constant on every analyzed day"
    if avg_correlation > 0.3:
        return "Strong positive correlation between sentiment and price"
    if avg_correlation > 0.1:
        return " Moderate positive correlation between sentiment and price"
    if avg_correlation > -0.1:
        return "  Weak correlation between sentiment and price"
    if avg_correlation > -0.3:
        return "Moderate negative correlation between sentiment and price"
    return "Strong negative correlation between sentiment and price"


def plot_stock_correlation(stock_name, start_date=None, num_days=10, combined_data_path='data/processed/filtered_tweets_with_stock_data.csv'):
    """Plot daily sentiment/return correlation, mean sentiment and returns for one stock.

    For every selected date, each row's tweet is classified with
    ``predict_stocks`` (threshold 0.5). Tweets predicted to mention
    ``stock_name`` are scored with ``predict_sentiment`` as +1 (positive),
    -1 (negative) or 0 (unsure). The day's correlation is ``np.corrcoef``
    between those scores and the ``daily_return`` of the corresponding rows;
    it is 0 when only one tweet matched and NaN when either series is
    constant. NaN days stay in the returned lists (shown as gaps in the plot)
    but are excluded from the average / max / min statistics.

    The figure is saved under ``correlation_plots/`` and shown.

    Args:
        stock_name: Ticker to look for in the classifier's predictions.
        start_date: First date to analyze. Must be present in the dataset.
            When ``None`` the most recent ``num_days`` dates are used.
        num_days: Maximum number of dataset dates to analyze.
        combined_data_path: CSV with the columns ``date_only``, ``Tweet``,
            ``Stock Name`` and ``daily_return``.

    Returns:
        dict with ``'stock'``, ``'dates'``, ``'correlations'``,
        ``'daily_returns'`` (percent), ``'daily_sentiment'`` and the three
        summary statistics. ``None`` when the dataset has no dates, the start
        date is unknown, no date is selected, or no tweet mentioning the
        stock could be scored.
    """
    df = pd.read_csv(combined_data_path)
    df['date_only'] = pd.to_datetime(df['date_only']).dt.date

    # Missing dates cannot be sorted against real ones, so drop them first.
    all_dates = sorted(df['date_only'].dropna().unique())
    if not all_dates:
        print("No dates found in the combined dataset")
        return None

    # Determine which dates to analyze
    if start_date is None:
        # Use most recent dates
        if len(all_dates) < num_days:
            print(f" Only {len(all_dates)} days available, using all available dates")
            dates_to_analyze = all_dates
        else:
            dates_to_analyze = all_dates[-num_days:]  # Get the most recent dates
    else:
        # Use specified start date
        start_date_obj = pd.to_datetime(start_date).date()

        # Find the start date in available dates
        try:
            start_idx = all_dates.index(start_date_obj)
        except ValueError:
            print(f" Start date {start_date} not found in dataset")
            print(f"Available date range: {all_dates[0]} to {all_dates[-1]}")
            return None

        # Get the specified number of days starting from start_date
        end_idx = min(start_idx + num_days, len(all_dates))
        dates_to_analyze = all_dates[start_idx:end_idx]

        if len(dates_to_analyze) < num_days:
            print(f" Only {len(dates_to_analyze)} days available from {start_date}")

    if not dates_to_analyze:
        print(f"No dates selected for analysis (num_days={num_days})")
        return None

    print(f"Analyzing period: {dates_to_analyze[0]} to {dates_to_analyze[-1]}")

    print(f"Analyzing {stock_name} for {len(dates_to_analyze)} days")
    print(f"Using AI classifier to identify tweets mentioning {stock_name}")

    # Maps predict_sentiment labels to a signed score; -1 ("unsure") becomes 0.
    score_for_label = {1: 1, 0: -1}

    # Analyze each date
    correlation_data = []
    dates = []
    daily_returns_data = []
    daily_sentiment_data = []

    for date in dates_to_analyze:
        day_data = df[df['date_only'] == date]

        if len(day_data) == 0:
            continue

        sentiments = []
        daily_returns = []
        tweets_analyzed = 0
        tweets_with_target_stock = 0
        failed_tweets = 0
        first_error = None

        for _, row in day_data.iterrows():
            tweet_text = row['Tweet']
            daily_return = row['daily_return']

            try:
                predicted_stocks, _stock_probs = predict_stocks(tweet_text, stock_model, stock_tokenizer, threshold=0.5)

                if stock_name in predicted_stocks:
                    tweets_with_target_stock += 1

                    sentiment_label, _sentiment_prob = predict_sentiment(
                        tweet_text, sentiment_model, sentiment_tokenizer,
                        threshold=0.5, confidence_threshold=0.7
                    )

                    sentiments.append(score_for_label.get(sentiment_label, 0))
                    daily_returns.append(daily_return)

                tweets_analyzed += 1

            except Exception as e:
                # Still best effort per tweet, but failures are counted and
                # reported instead of disappearing silently.
                failed_tweets += 1
                if first_error is None:
                    first_error = e
                continue

        print(f"{date}: Analyzed {tweets_analyzed} tweets, found {tweets_with_target_stock} mentioning {stock_name}")
        if failed_tweets:
            print(f"   {failed_tweets} tweets could not be analyzed (first error: {first_error})")

        if sentiments and daily_returns:
            # Calculate correlation for this day. A single sample has no
            # defined correlation; 0 is kept as the existing convention.
            if len(sentiments) > 1:
                correlation = np.corrcoef(sentiments, daily_returns)[0, 1]
            else:
                correlation = 0

            correlation_data.append(correlation)
            dates.append(date)
            daily_sentiment_data.append(np.mean(sentiments))

            # Get daily return for this stock on this date
            stock_data_for_date = df[(df['date_only'] == date) & (df['Stock Name'] == stock_name)]
            if len(stock_data_for_date) > 0:
                daily_returns_data.append(stock_data_for_date['daily_return'].iloc[0] * 100)  # Convert to percentage
            else:
                daily_returns_data.append(0)

    if not correlation_data:
        print(f"No correlation data available for {stock_name}")
        print(f" This could mean:")
        print(f" - No tweets were detected mentioning {stock_name}")
        print(f" - No sentiment data was available")
        return None

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(12, 15))

    date_objects = [datetime.combine(d, datetime.min.time()) for d in dates]

    # Panel 1: per-day correlation coefficient
    ax1.plot(date_objects, correlation_data, marker='o', linewidth=2, markersize=6, color='blue')
    ax1.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='No Correlation')
    ax1.set_title(f'{stock_name} Sentiment-Price Correlation Over Time (AI-Detected Mentions)', fontsize=14, fontweight='bold')
    ax1.set_ylabel('Correlation Coefficient', fontsize=12)
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    _format_date_axis(ax1)

    # Panel 2: mean sentiment score per day, bounded to [-1, 1]
    ax2.plot(date_objects, daily_sentiment_data, marker='o', linewidth=2, markersize=6, color='purple', label='Daily Sentiment')
    ax2.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='Neutral Sentiment')
    ax2.axhline(y=0.5, color='green', linestyle=':', alpha=0.7, label='Positive Threshold')
    ax2.axhline(y=-0.5, color='orange', linestyle=':', alpha=0.7, label='Negative Threshold')
    ax2.set_title(f'{stock_name} Daily Sentiment Over Time', fontsize=14, fontweight='bold')
    ax2.set_ylabel('Average Sentiment', fontsize=12)
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    ax2.set_ylim(-1.1, 1.1)
    _format_date_axis(ax2)

    # Panel 3: the stock's own daily return in percent
    ax3.bar(date_objects, daily_returns_data, alpha=0.7, color='green', label='Daily Return %')
    ax3.axhline(y=0, color='red', linestyle='-', alpha=0.5)
    ax3.set_title(f'{stock_name} Daily Returns', fontsize=14, fontweight='bold')
    ax3.set_ylabel('Daily Return (%)', fontsize=12)
    ax3.set_xlabel('Date', fontsize=12)
    ax3.grid(True, alpha=0.3)
    ax3.legend()
    _format_date_axis(ax3)

    plt.tight_layout()

    import os
    os.makedirs('correlation_plots', exist_ok=True)

    start_str = dates[0].strftime('%Y%m%d')
    end_str = dates[-1].strftime('%Y%m%d')
    filename = f'correlation_plots/{stock_name}_correlation_{start_str}_to_{end_str}.png'

    plt.savefig(filename, dpi=300, bbox_inches='tight')
    print(f"Plot saved as: {filename}")

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

    # Calculate and display statistics over the days with a defined value
    valid_correlations = [c for c in correlation_data if not np.isnan(c)]
    if valid_correlations:
        avg_correlation = np.mean(valid_correlations)
        max_correlation = np.max(valid_correlations)
        min_correlation = np.min(valid_correlations)
    else:
        avg_correlation = max_correlation = min_correlation = np.nan

    print(f"\n{stock_name} Correlation Statistics:")
    print(f" Average Correlation: {avg_correlation:.3f}")
    print(f" Max Correlation: {max_correlation:.3f}")
    print(f" Min Correlation: {min_correlation:.3f}")
    print(f" Days Analyzed: {len(dates)}")
    if len(valid_correlations) != len(correlation_data):
        print(f" Days with undefined correlation (excluded from statistics): {len(correlation_data) - len(valid_correlations)}")

    # Interpretation
    print(_describe_correlation(avg_correlation))

    return {
        'stock': stock_name,
        'dates': dates,
        'correlations': correlation_data,
        'daily_returns': daily_returns_data,
        'daily_sentiment': daily_sentiment_data,
        'avg_correlation': avg_correlation,
        'max_correlation': max_correlation,
        'min_correlation': min_correlation
    }


In [None]:


# Tesla: five dataset dates starting at 2022-09-19
tesla_results = plot_stock_correlation(
    'TSLA',
    start_date='2022-09-19',
    num_days=5,
)


In [None]:
# Apple: seven dataset dates starting at 2022-09-29
print("\nAPPLE CORRELATION ANALYSIS - SPECIFIC DATE RANGE", "=" * 60, sep="\n")

apple_results = plot_stock_correlation(
    'AAPL',
    start_date='2022-09-29',
    num_days=7,
)


In [None]:
# Example: Compare multiple stocks over same period
print("\nMULTI-STOCK COMPARISON - SAME PERIOD")
print("=" * 60)

stocks_to_analyze = ['TSLA', 'AAPL', 'MSFT', 'GOOG', 'AMZN']
comparison_results = {}
start_date = '2022-09-29'  # Same start date for all stocks

for stock in stocks_to_analyze:
    print(f"\nAnalyzing {stock} from {start_date}...")
    results = plot_stock_correlation(stock, start_date=start_date, num_days=7)
    if not results:
        continue
    if np.isnan(results['avg_correlation']):
        # NaN cannot be ordered, so it would make the ranking below (and the
        # best/worst picks) arbitrary; leave the stock out and say so.
        print(f"   {stock}: average correlation is undefined, excluded from ranking")
        continue
    comparison_results[stock] = results['avg_correlation']

if comparison_results:
    print(f"\nCORRELATION COMPARISON SUMMARY:")
    print("=" * 60)

    # Highest average correlation first
    sorted_stocks = sorted(comparison_results.items(), key=lambda x: x[1], reverse=True)

    for stock, correlation in sorted_stocks:
        print(f"   {stock}: {correlation:.3f}")

    best_stock = sorted_stocks[0]
    worst_stock = sorted_stocks[-1]

    print(f"\n Best correlation: {best_stock[0]} ({best_stock[1]:.3f})")
    print(f"Worst correlation: {worst_stock[0]} ({worst_stock[1]:.3f})")
