# Imports and Installation

In [5]:
# Enhanced YouTube Trend Predictor with Ensemble Forecasting
!pip install faiss-cpu prophet

import pandas as pd
import numpy as np
import re
import warnings
from datetime import datetime, timedelta
from typing import List, Dict, Tuple, Optional
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter
import json

# NLP and ML libraries
from sentence_transformers import SentenceTransformer
import faiss
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.ensemble import RandomForestRegressor
import nltk
from nltk.sentiment import SentimentIntensityAnalyzer
from textblob import TextBlob

# Deep Learning for time series
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, GRU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping

# Time series analysis
from statsmodels.tsa.arima.model import ARIMA
from statsmodels.tsa.seasonal import seasonal_decompose
from statsmodels.tsa.stattools import adfuller
from prophet import Prophet

warnings.filterwarnings('ignore')



2025-09-06 17:08:08.145285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757178488.517623 1029851 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757178488.624655 1029851 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


# Basic TrendAI Model

In [6]:
class TrendAI:
    """
    A comprehensive trend prediction model for YouTube beauty industry data.
    Enhanced to integrate video metrics and tags with higher weighting.
    """
    
    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """
        Initialize the trend predictor with SBERT model and other components.
        
        Args:
            model_name (str): Name of the SentenceTransformer model to use
        """
        print("Initializing Enhanced YouTube Trend Predictor with Tags Integration...")
        self.sbert_model = SentenceTransformer(model_name)
        self.kmeans_index = None  # FAISS K-means index
        self.sentiment_analyzer = SentimentIntensityAnalyzer()
        self.comments_df = None
        self.videos_df = None
        self.embeddings = None
        self.clusters = None
        self.trend_data = None
        self.video_trend_data = None
        self.combined_trend_data = None
        self.lstm_model = None
        self.scaler = StandardScaler()
        
        # Enhanced weighting factors including tags
        self.video_weight = 0.6   # 60% weight for video metrics
        self.comment_weight = 0.25  # 25% weight for comment metrics
        self.tag_weight = 0.15    # 15% weight for tag relevance
        
        # Tag processing attributes
        self.tag_trends = None
        self.tag_clusters = None
        self.popular_tags = None

        self.generational_analyzer = EnhancedGenerationalLanguageAnalyzer
        self.generational_clusters = None
        self.generational_trends = None
        
        # Download required NLTK data
        try:
            nltk.data.find('vader_lexicon')
        except LookupError:
            nltk.download('vader_lexicon')

    def load_data(self, comment_files: List[str], video_file: str) -> None:
        """
        Load and combine comment and video data from CSV files.
        
        Args:
            comment_files (List[str]): List of comment CSV file paths
            video_file (str): Path to video CSV file
        """
        print("Loading data...")
        
        # Load and combine comment files
        comment_dfs = []
        for file in comment_files:
            try:
                df = pd.read_csv(file)
                comment_dfs.append(df)
                print(f"Loaded {len(df)} comments from {file}")
            except Exception as e:
                print(f"Error loading {file}: {e}")
        
        if comment_dfs:
            self.comments_df = pd.concat(comment_dfs, ignore_index=True)
            print(f"Total comments loaded: {len(self.comments_df)}")
        
        # Load video data
        try:
            self.videos_df = pd.read_csv(video_file)
            print(f"Loaded {len(self.videos_df)} videos from {video_file}")
            
            # Log video data structure for debugging
            print("Video data columns:", self.videos_df.columns.tolist())
            print("Sample video data:")
            print(self.videos_df.head())
            
            # Check for tags column
            if 'tags' in self.videos_df.columns:
                print("Tags column found - will integrate tags into analysis")
                tag_sample = self.videos_df['tags'].dropna().head(3)
                print("Sample tags:", tag_sample.tolist())
            else:
                print("Warning: No 'tags' column found in video data")
                
        except Exception as e:
            print(f"Error loading video file: {e}")

    def preprocess_data(self) -> None:
        """
        Clean and preprocess the loaded data including tag processing.
        Enhanced to handle video tags properly.
        """
        print("Preprocessing data with tag integration...")
        
        if self.comments_df is not None:
            # Convert timestamp columns
            self.comments_df['publishedAt'] = pd.to_datetime(self.comments_df['publishedAt'])
            self.comments_df['updatedAt'] = pd.to_datetime(self.comments_df['updatedAt'])
            
            # Clean text data
            self.comments_df['cleaned_text'] = self.comments_df['textOriginal'].apply(self._clean_text)
            
            # Remove empty or very short comments
            self.comments_df = self.comments_df[
                (self.comments_df['cleaned_text'].str.len() > 10) &
                (self.comments_df['cleaned_text'].notna())
            ].reset_index(drop=True)
            
            # Add time-based features
            self.comments_df['date'] = self.comments_df['publishedAt'].dt.date
            self.comments_df['hour'] = self.comments_df['publishedAt'].dt.hour
            self.comments_df['day_of_week'] = self.comments_df['publishedAt'].dt.dayofweek
            
            print(f"Comments after preprocessing: {len(self.comments_df)}")
        
        if self.videos_df is not None:
            self.videos_df['publishedAt'] = pd.to_datetime(self.videos_df['publishedAt'])
            self.videos_df['date'] = self.videos_df['publishedAt'].dt.date
            
            # Clean video descriptions and titles
            self.videos_df['cleaned_title'] = self.videos_df['title'].apply(self._clean_text)
            self.videos_df['cleaned_description'] = self.videos_df['description'].apply(self._clean_text)
            
            # Enhanced tag processing
            if 'tags' in self.videos_df.columns:
                self.videos_df['processed_tags'] = self.videos_df['tags'].apply(self._process_tags)
                self.videos_df['tag_count'] = self.videos_df['processed_tags'].apply(len)
                self.videos_df['tag_text'] = self.videos_df['processed_tags'].apply(
                    lambda x: ' '.join(x) if x else ''
                )
                
                # Analyze tag popularity
                self._analyze_tag_popularity()
                
                # Calculate tag relevance scores
                self._calculate_tag_relevance_scores()
            else:
                self.videos_df['processed_tags'] = [[] for _ in range(len(self.videos_df))]
                self.videos_df['tag_count'] = 0
                self.videos_df['tag_text'] = ''
                self.videos_df['tag_relevance_score'] = 0
            
            # Enhanced video metrics preprocessing
            # Convert string numbers to integers if needed
            numeric_columns = ['viewCount', 'likeCount', 'commentCount']
            for col in numeric_columns:
                if col in self.videos_df.columns:
                    self.videos_df[col] = pd.to_numeric(self.videos_df[col], errors='coerce').fillna(0)
            
            # Calculate video performance metrics
            self.videos_df['engagement_rate'] = (
                (self.videos_df.get('likeCount', 0) + self.videos_df.get('commentCount', 0)) / 
                np.maximum(self.videos_df.get('viewCount', 1), 1)
            )
            
            # Enhanced video trending score including tags
            max_views = self.videos_df.get('viewCount', pd.Series([1])).max()
            max_likes = self.videos_df.get('likeCount', pd.Series([1])).max()
            
            if max_views > 0 and max_likes > 0:
                self.videos_df['view_score'] = self.videos_df.get('viewCount', 0) / max_views
                self.videos_df['like_score'] = self.videos_df.get('likeCount', 0) / max_likes
                
                # Recency score (more recent videos get higher scores)
                current_time = pd.Timestamp.now().tz_localize(None)
                published_times = pd.to_datetime(self.videos_df['publishedAt']).dt.tz_localize(None)
                days_since_publish = (current_time - published_times).dt.days
                max_days = days_since_publish.max() if len(days_since_publish) > 0 else 1
                self.videos_df['recency_score'] = 1 - (days_since_publish / max(max_days, 1))
                
                # Enhanced trending score including tag relevance
                self.videos_df['trending_score'] = (
                    0.35 * self.videos_df['view_score'] +
                    0.35 * self.videos_df['like_score'] +
                    0.15 * self.videos_df['recency_score'] +
                    0.15 * self.videos_df.get('tag_relevance_score', 0)
                )
            else:
                self.videos_df['trending_score'] = 0
            
            print(f"Videos after preprocessing: {len(self.videos_df)}")
            print(f"Video trending scores range: {self.videos_df['trending_score'].min():.3f} - {self.videos_df['trending_score'].max():.3f}")
            
            if 'tags' in self.videos_df.columns:
                print(f"Average tags per video: {self.videos_df['tag_count'].mean():.1f}")
                print(f"Videos with tags: {(self.videos_df['tag_count'] > 0).sum()}/{len(self.videos_df)}")

    def _process_tags(self, tags) -> List[str]:
        """
        Process and clean video tags.
        
        Args:
            tags: Raw tags data (could be string, list, or None)
            
        Returns:
            List[str]: Processed list of tags
        """
        if pd.isna(tags) or tags == '' or tags is None:
            return []
        
        if isinstance(tags, str):
            # Handle different tag formats
            # Common separators: comma, semicolon, pipe
            if ',' in tags:
                tag_list = tags.split(',')
            elif ';' in tags:
                tag_list = tags.split(';')
            elif '|' in tags:
                tag_list = tags.split('|')
            else:
                # If no separator, treat as single tag or space-separated
                tag_list = [tags] if ' ' not in tags else tags.split()
        elif isinstance(tags, list):
            tag_list = tags
        else:
            # Convert to string and process
            tag_list = str(tags).split(',')
        
        # Clean and normalize tags
        processed_tags = []
        for tag in tag_list:
            if isinstance(tag, str):
                # Remove quotes and extra whitespace
                clean_tag = tag.strip().strip('"\'').lower()
                # Remove special characters but keep spaces and hyphens
                clean_tag = re.sub(r'[^\w\s\-]', '', clean_tag)
                if clean_tag and len(clean_tag) > 1:  # Keep tags with more than 1 character
                    processed_tags.append(clean_tag)
        
        return processed_tags

    def _analyze_tag_popularity(self) -> None:
        """
        Analyze tag popularity and trends over time.
        """
        print("Analyzing tag popularity and trends...")
        
        # Collect all tags with their video metadata
        all_tags = []
        for _, video in self.videos_df.iterrows():
            for tag in video['processed_tags']:
                all_tags.append({
                    'tag': tag,
                    'videoId': video['videoId'],
                    'date': video['date'],
                    'viewCount': video.get('viewCount', 0),
                    'likeCount': video.get('likeCount', 0),
                    'publishedAt': video['publishedAt']
                })
        
        if all_tags:
            tag_df = pd.DataFrame(all_tags)
            
            # Calculate tag popularity metrics
            self.popular_tags = tag_df.groupby('tag').agg({
                'videoId': 'count',     # frequency
                'viewCount': 'sum',     # total views
                'likeCount': 'sum'      # total likes
            }).rename(columns={
                'videoId': 'frequency',
                'viewCount': 'total_views',
                'likeCount': 'total_likes'
            })
            
            # Calculate tag trending score
            self.popular_tags['avg_views_per_video'] = (
                self.popular_tags['total_views'] / self.popular_tags['frequency']
            )
            self.popular_tags['avg_likes_per_video'] = (
                self.popular_tags['total_likes'] / self.popular_tags['frequency']
            )
            
            # Normalize scores
            max_freq = self.popular_tags['frequency'].max()
            max_views = self.popular_tags['total_views'].max()
            max_likes = self.popular_tags['total_likes'].max()
            
            if max_freq > 0 and max_views > 0 and max_likes > 0:
                self.popular_tags['frequency_score'] = self.popular_tags['frequency'] / max_freq
                self.popular_tags['views_score'] = self.popular_tags['total_views'] / max_views
                self.popular_tags['likes_score'] = self.popular_tags['total_likes'] / max_likes
                
                # Combined tag popularity score
                self.popular_tags['tag_popularity_score'] = (
                    0.4 * self.popular_tags['frequency_score'] +
                    0.4 * self.popular_tags['views_score'] +
                    0.2 * self.popular_tags['likes_score']
                )
            else:
                self.popular_tags['tag_popularity_score'] = 0
            
            # Sort by popularity
            self.popular_tags = self.popular_tags.sort_values('tag_popularity_score', ascending=False)
            
            print(f"Analyzed {len(self.popular_tags)} unique tags")
            print("Top 10 most popular tags:")
            for tag, data in self.popular_tags.head(10).iterrows():
                print(f"  {tag}: {data['frequency']} videos, {data['total_views']:,} total views")
        else:
            self.popular_tags = pd.DataFrame()

    def _calculate_tag_relevance_scores(self) -> None:
        """
        Calculate relevance scores for videos based on their tags.
        """
        if self.popular_tags is not None and not self.popular_tags.empty:
            tag_scores = self.popular_tags['tag_popularity_score'].to_dict()
            
            def calculate_video_tag_score(tags):
                if not tags:
                    return 0
                scores = [tag_scores.get(tag, 0) for tag in tags]
                return np.mean(scores) if scores else 0
            
            self.videos_df['tag_relevance_score'] = self.videos_df['processed_tags'].apply(
                calculate_video_tag_score
            )
        else:
            self.videos_df['tag_relevance_score'] = 0

    def _clean_text(self, text: str) -> str:
        """
        Clean text while preserving emojis and meaningful slang.
        
        Args:
            text (str): Raw text to clean
            
        Returns:
            str: Cleaned text
        """
        if pd.isna(text):
            return ""
        
        # Convert to string if not already
        text = str(text)
        
        # Remove URLs but keep the text structure
        text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
        
        # Remove excessive punctuation but keep some for sentiment
        text = re.sub(r'[.]{3,}', '...', text)
        text = re.sub(r'[!]{2,}', '!!', text)
        text = re.sub(r'[?]{2,}', '??', text)
        
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)
        
        # Keep emojis and basic punctuation
        # Remove only clearly problematic characters
        text = re.sub(r'[^\w\s\U0001F600-\U0001F64F\U0001F300-\U0001F5FF\U0001F680-\U0001F6FF\U0001F1E0-\U0001F1FF,.!?;:\'"()-]', '', text)
        
        return text.strip()

    def generate_embeddings(self, text_column: str = 'cleaned_text') -> np.ndarray:
        """
        Generate SBERT embeddings for the comment text.
        Enhanced to include video titles, descriptions, and tags.
        
        Args:
            text_column (str): Column name containing text to embed
            
        Returns:
            np.ndarray: Matrix of embeddings
        """
        print("Generating SBERT embeddings with tag integration...")
        
        if self.comments_df is None:
            raise ValueError("Comments data not loaded")
        
        texts = self.comments_df[text_column].tolist()
        self.embeddings = self.sbert_model.encode(texts, show_progress_bar=True)
        
        print(f"Generated embeddings shape: {self.embeddings.shape}")
        return self.embeddings

    def perform_clustering(self, n_clusters: int = 20, niter: int = 50, verbose: bool = True) -> np.ndarray:
        """
        Perform FAISS K-Means clustering on the embeddings.
        
        Args:
            n_clusters (int): Number of clusters to create
            niter (int): Number of iterations for K-means
            verbose (bool): Whether to print progress
            
        Returns:
            np.ndarray: Cluster labels
        """
        print(f"Performing FAISS K-Means clustering with {n_clusters} clusters...")
        
        if self.embeddings is None:
            raise ValueError("Embeddings not generated")
        
        # Ensure embeddings are in the right format for FAISS (float32)
        embeddings_float32 = self.embeddings.astype(np.float32)
        
        # Get embedding dimension
        d = embeddings_float32.shape[1]
        
        # Initialize FAISS K-means
        self.kmeans_index = faiss.Kmeans(
            d=d,
            k=n_clusters,
            niter=niter,
            verbose=verbose,
            spherical=False,  # use Euclidean distance (not cosine)
            gpu=False  # Set to False for compatibility
        )
        
        # Train the K-means model
        print("Training K-means...")
        self.kmeans_index.train(embeddings_float32)
        
        # Get cluster assignments
        _, cluster_assignments = self.kmeans_index.index.search(embeddings_float32, 1)
        self.clusters = cluster_assignments.flatten()
        
        # Add cluster labels to dataframe
        self.comments_df['cluster'] = self.clusters
        
        n_clusters_found = len(set(self.clusters))
        print(f"Found {n_clusters_found} clusters")
        print(f"Cluster distribution:")
        cluster_counts = pd.Series(self.clusters).value_counts().sort_index()
        for cluster_id, count in cluster_counts.items():
            print(f"  Cluster {cluster_id}: {count} comments")
        
        return self.clusters

    def map_videos_to_clusters(self) -> None:
        """
        Map videos to comment clusters based on content similarity including tags.
        This creates the link between video performance and comment topics.
        """
        print("Mapping videos to comment clusters (including tags)...")
        
        if self.videos_df is None or self.comments_df is None:
            raise ValueError("Both video and comment data must be loaded")
        
        # Generate embeddings for video content including tags
        video_texts = []
        for _, video in self.videos_df.iterrows():
            # Combine title, description, and tags for better topic matching
            # Give tags more weight by repeating them
            tag_text = ' '.join(video.get('processed_tags', [])) * 2  # Double weight for tags
            combined_text = f"{video.get('cleaned_title', '')} {video.get('cleaned_description', '')} {tag_text}"
            video_texts.append(combined_text if combined_text.strip() else video.get('title', ''))
        
        if not video_texts:
            print("No video text found for clustering")
            return
        
        # Generate embeddings for videos
        video_embeddings = self.sbert_model.encode(video_texts, show_progress_bar=True)
        
        # Find closest cluster for each video
        video_embeddings_float32 = video_embeddings.astype(np.float32)
        _, video_cluster_assignments = self.kmeans_index.index.search(video_embeddings_float32, 1)
        
        # Add cluster assignments to videos dataframe
        self.videos_df['cluster'] = video_cluster_assignments.flatten()
        
        print("Video-to-cluster mapping completed")
        cluster_video_counts = self.videos_df['cluster'].value_counts().sort_index()
        print("Videos per cluster:")
        for cluster_id, count in cluster_video_counts.items():
            print(f"  Cluster {cluster_id}: {count} videos")

    def analyze_sentiment(self) -> None:
        """
        Perform sentiment analysis on comments.
        """
        print("Analyzing sentiment...")
        
        def get_sentiment_scores(text):
            """Get sentiment scores using VADER."""
            scores = self.sentiment_analyzer.polarity_scores(text)
            return scores['compound'], scores['pos'], scores['neu'], scores['neg']
        
        # Apply sentiment analysis
        sentiment_data = self.comments_df['cleaned_text'].apply(get_sentiment_scores)
        sentiment_df = pd.DataFrame(sentiment_data.tolist(), 
                                  columns=['compound', 'positive', 'neutral', 'negative'])
        
        # Add sentiment columns to main dataframe
        self.comments_df = pd.concat([self.comments_df, sentiment_df], axis=1)
        
        # Create sentiment categories
        self.comments_df['sentiment_label'] = pd.cut(
            self.comments_df['compound'],
            bins=[-1, -0.05, 0.05, 1],
            labels=['negative', 'neutral', 'positive']
        )
        
        print("Sentiment analysis completed")

    def extract_cluster_topics(self, top_n_words: int = 10) -> Dict[int, List[str]]:
        """
        Extract representative topics for each cluster including tags.
        Enhanced to heavily weight video titles and tags for better topic identification.
        
        Args:
            top_n_words (int): Number of top words to extract per cluster
            
        Returns:
            Dict[int, List[str]]: Dictionary mapping cluster IDs to top words and tags
        """
        print("Extracting cluster topics with tag integration...")
        
        cluster_topics = {}
        cluster_tags = {}
        
        for cluster_id in sorted(set(self.clusters)):
            # Get comment texts for this cluster
            cluster_comments = self.comments_df[
                self.comments_df['cluster'] == cluster_id
            ]['cleaned_text'].tolist()
            
            # Get video titles and tags for this cluster
            cluster_videos = self.videos_df[
                self.videos_df['cluster'] == cluster_id
            ] if 'cluster' in self.videos_df.columns else pd.DataFrame()
            
            video_titles = cluster_videos['cleaned_title'].tolist() if not cluster_videos.empty else []
            
            # Collect tags for this cluster
            cluster_video_tags = []
            if not cluster_videos.empty:
                for _, video in cluster_videos.iterrows():
                    cluster_video_tags.extend(video.get('processed_tags', []))
            
            # Analyze tag frequency for this cluster
            if cluster_video_tags:
                tag_counter = Counter(cluster_video_tags)
                top_cluster_tags = [tag for tag, count in tag_counter.most_common(5)]
                cluster_tags[cluster_id] = top_cluster_tags
            else:
                cluster_tags[cluster_id] = []
            
            # Combine texts with heavy weighting for titles and tags
            all_texts = (
                cluster_comments + 
                video_titles * 4 +      # 4x weight for video titles
                cluster_video_tags * 3  # 3x weight for tags
            )
            
            all_text = ' '.join(all_texts).lower()
            words = re.findall(r'\b[a-zA-Z]{3,}\b', all_text)
            
            # Enhanced stop words list
            stop_words = {
                'the', 'and', 'for', 'are', 'but', 'not', 'you', 'all', 'can', 'had', 'her', 'was',
                'one', 'our', 'out', 'day', 'get', 'has', 'him', 'his', 'how', 'its', 'may', 'new',
                'now', 'old', 'see', 'two', 'who', 'boy', 'did', 'man', 'way', 'she', 'too', 'any',
                'use', 'your', 'here', 'this', 'that', 'with', 'have', 'from', 'they', 'know', 'want',
                'been', 'good', 'much', 'some', 'time', 'very', 'when', 'come', 'could', 'like', 'will',
                'said', 'would', 'make', 'just', 'into', 'over', 'think', 'also', 'back', 'after',
                'first', 'well', 'work', 'life', 'only', 'look', 'year', 'more', 'where', 'what',
                'than', 'love', 'really', 'great', 'video', 'youtube', 'channel', 'subscribe', 'comment',
                'please', 'thank', 'thanks', 'watch', 'follow', 'instagram'
            }
            
            filtered_words = [w for w in words if w not in stop_words and len(w) > 2]
            
            # Count word frequencies
            word_freq = {}
            for word in filtered_words:
                word_freq[word] = word_freq.get(word, 0) + 1
            
            # Get top words
            top_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)[:top_n_words]
            cluster_topics[cluster_id] = [word for word, count in top_words]
        
        self.cluster_topics = cluster_topics
        self.cluster_tags = cluster_tags
        return cluster_topics

    def analyze_generational_patterns(self) -> None:
        """
        Analyze generational language patterns in comments and videos.
        """
        print("Analyzing generational language patterns...")
        
        # Initialize the analyzer if not already done
        if not hasattr(self, 'generational_analyzer'):
            self.generational_analyzer = EnhancedGenerationalLanguageAnalyzer
        
        if self.comments_df is not None:
            print("Analyzing comments for generational patterns...")
            # Analyze comments
            self.comments_df['generational_scores'] = self.comments_df['cleaned_text'].apply(
                self.generational_analyzer.analyze_generational_language
            )
            
            # Extract individual generation scores
            for generation in ['gen_z', 'millennial', 'gen_x', 'boomer']:
                self.comments_df[f'{generation}_score'] = self.comments_df['generational_scores'].apply(
                    lambda x: x.get(generation, 0) if isinstance(x, dict) else 0
                )
            
            # Classify predominant generation
            self.comments_df['dominant_generation'] = self.comments_df['cleaned_text'].apply(
                self.generational_analyzer.classify_generation
            )
            
            print("Comments generational analysis completed")
        
        if self.videos_df is not None:
            print("Analyzing videos for generational patterns...")
            # Analyze video titles and descriptions
            combined_video_text = (
                self.videos_df.get('cleaned_title', '').fillna('') + ' ' + 
                self.videos_df.get('cleaned_description', '').fillna('')
            )
            
            self.videos_df['generational_scores'] = combined_video_text.apply(
                self.generational_analyzer.analyze_generational_language
            )
            
            # Extract individual generation scores
            for generation in ['gen_z', 'millennial', 'gen_x', 'boomer']:
                self.videos_df[f'{generation}_score'] = self.videos_df['generational_scores'].apply(
                    lambda x: x.get(generation, 0) if isinstance(x, dict) else 0
                )
            
            self.videos_df['dominant_generation'] = combined_video_text.apply(
                self.generational_analyzer.classify_generation
            )
            
            print("Videos generational analysis completed")

    
    def analyze_generational_trends_by_cluster(self) -> Dict:
        """
        Analyze generational trends for each cluster.
        
        Returns:
            Dict: Generational analysis by cluster
        """
        if self.clusters is None or self.comments_df is None:
            print("Clustering must be performed first")
            return {}
        
        generational_cluster_analysis = {}
        
        for cluster_id in sorted(set(self.clusters)):
            cluster_comments = self.comments_df[self.comments_df['cluster'] == cluster_id]
            cluster_videos = self.videos_df[self.videos_df['cluster'] == cluster_id] if 'cluster' in self.videos_df.columns else pd.DataFrame()
            
            if len(cluster_comments) == 0:
                continue
            
            # Analyze generational distribution in comments
            generation_distribution = cluster_comments['dominant_generation'].value_counts(normalize=True)
            
            # Calculate average generational scores
            avg_gen_scores = {}
            for generation in ['gen_z', 'millennial', 'gen_x', 'boomer']:
                avg_gen_scores[generation] = cluster_comments[f'{generation}_score'].mean()
            
            # Identify dominant generation for this cluster
            dominant_gen = max(avg_gen_scores.items(), key=lambda x: x[1])
            
            # Analyze video performance by generation
            video_performance_by_gen = {}
            if not cluster_videos.empty:
                for generation in ['gen_z', 'millennial', 'gen_x', 'boomer', 'neutral']:
                    gen_videos = cluster_videos[cluster_videos['dominant_generation'] == generation]
                    if len(gen_videos) > 0:
                        video_performance_by_gen[generation] = {
                            'count': len(gen_videos),
                            'avg_views': gen_videos['viewCount'].mean(),
                            'avg_likes': gen_videos['likeCount'].mean(),
                            'avg_engagement': gen_videos['engagement_rate'].mean()
                        }
            
            generational_cluster_analysis[cluster_id] = {
                'topic_words': self.cluster_topics.get(cluster_id, []),
                'topic_tags': self.cluster_tags.get(cluster_id, []),
                'dominant_generation': dominant_gen[0],
                'dominant_generation_score': dominant_gen[1],
                'generation_distribution': generation_distribution.to_dict(),
                'avg_generational_scores': avg_gen_scores,
                'video_performance_by_generation': video_performance_by_gen,
                'total_comments': len(cluster_comments),
                'total_videos': len(cluster_videos)
            }
        
        self.generational_clusters = generational_cluster_analysis
        return generational_cluster_analysis

    def prepare_time_series_data(self, time_window: str = 'D') -> pd.DataFrame:
        """
        Prepare enhanced time series data combining video, comment, and tag metrics.
        
        Args:
            time_window (str): Time aggregation window ('D' for daily, 'W' for weekly)
            
        Returns:
            pd.DataFrame: Enhanced time series data including tag metrics
        """
        print(f"Preparing enhanced time series data with tags and {time_window} aggregation...")
        
        # Create comment-based time series
        comment_time_series = []
        for cluster_id in sorted(set(self.clusters)):
            cluster_data = self.comments_df[self.comments_df['cluster'] == cluster_id]
            
            # Aggregate by time window
            time_series = cluster_data.groupby(
                pd.Grouper(key='publishedAt', freq=time_window)
            ).agg({
                'commentId': 'count',
                'likeCount': 'sum',
                'compound': 'mean',
                'positive': 'mean',
                'negative': 'mean'
            }).reset_index()
            
            time_series['cluster'] = cluster_id
            time_series.columns = ['date', 'comment_count', 'comment_likes', 
                                 'avg_sentiment', 'avg_positive', 'avg_negative', 'cluster']
            comment_time_series.append(time_series)
        
        comment_trend_data = pd.concat(comment_time_series, ignore_index=True) if comment_time_series else pd.DataFrame()
        
        # Create enhanced video-based time series including tag metrics
        video_time_series = []
        if 'cluster' in self.videos_df.columns:
            for cluster_id in sorted(set(self.clusters)):
                cluster_videos = self.videos_df[self.videos_df['cluster'] == cluster_id]
                
                if len(cluster_videos) > 0:
                    # Aggregate video metrics by time window
                    video_ts = cluster_videos.groupby(
                        pd.Grouper(key='publishedAt', freq=time_window)
                    ).agg({
                        'videoId': 'count',
                        'viewCount': 'sum',
                        'likeCount': 'sum',
                        'commentCount': 'sum',
                        'trending_score': 'mean',
                        'engagement_rate': 'mean',
                        'tag_count': 'mean',
                        'tag_relevance_score': 'mean'
                    }).reset_index()
                    
                    video_ts['cluster'] = cluster_id
                    video_ts.columns = ['date', 'video_count', 'total_views', 'video_likes',
                                      'video_comments', 'avg_trending_score', 'avg_engagement_rate',
                                      'avg_tag_count', 'avg_tag_relevance', 'cluster']
                    video_time_series.append(video_ts)
        
        video_trend_data = pd.concat(video_time_series, ignore_index=True) if video_time_series else pd.DataFrame()
        
        # Combine comment and video data
        if not comment_trend_data.empty and not video_trend_data.empty:
            # Merge on date and cluster
            combined_data = pd.merge(
                comment_trend_data,
                video_trend_data,
                on=['date', 'cluster'],
                how='outer'
            ).fillna(0)
        elif not comment_trend_data.empty:
            combined_data = comment_trend_data.copy()
            # Add empty video columns
            video_cols = ['video_count', 'total_views', 'video_likes', 'video_comments',
                         'avg_trending_score', 'avg_engagement_rate', 'avg_tag_count', 'avg_tag_relevance']
            for col in video_cols:
                combined_data[col] = 0
        else:
            combined_data = pd.DataFrame()
        
        if not combined_data.empty:
            # Enhanced combined trending score with tag integration
            combined_data['combined_trending_score'] = (
                self.video_weight * (
                    combined_data['avg_trending_score'] * 0.4 +
                    (combined_data['total_views'] / combined_data['total_views'].max() 
                     if combined_data['total_views'].max() > 0 else 0) * 0.3 +
                    (combined_data['video_likes'] / combined_data['video_likes'].max() 
                     if combined_data['video_likes'].max() > 0 else 0) * 0.3
                ) +
                self.comment_weight * (
                    (combined_data['comment_count'] / combined_data['comment_count'].max() 
                     if combined_data['comment_count'].max() > 0 else 0) * 0.5 +
                    combined_data['avg_sentiment'] * 0.3 +
                    (combined_data['comment_likes'] / combined_data['comment_likes'].max() 
                     if combined_data['comment_likes'].max() > 0 else 0) * 0.2
                ) +
                self.tag_weight * (
                    combined_data['avg_tag_relevance'] * 0.7 +
                    (combined_data['avg_tag_count'] / combined_data['avg_tag_count'].max() 
                     if combined_data['avg_tag_count'].max() > 0 else 0) * 0.3
                )
            )
            
            # Fill missing dates with zero values
            date_range = pd.date_range(
                start=combined_data['date'].min(),
                end=combined_data['date'].max(),
                freq=time_window
            )
            
            complete_data = []
            for cluster_id in combined_data['cluster'].unique():
                cluster_df = pd.DataFrame({'date': date_range, 'cluster': cluster_id})
                cluster_trend = combined_data[combined_data['cluster'] == cluster_id]
                merged = cluster_df.merge(cluster_trend, on=['date', 'cluster'], how='left')
                merged = merged.fillna(0)
                complete_data.append(merged)
            
            self.combined_trend_data = pd.concat(complete_data, ignore_index=True)
        else:
            self.combined_trend_data = pd.DataFrame()
        
        return self.combined_trend_data

    def identify_trending_topics(self, window_days: int = 30, growth_threshold: float = 0.05) -> List[Dict]:
        """
        Enhanced trending topic identification considering video performance and tag relevance.
        
        Args:
            window_days (int): Number of recent days to analyze
            growth_threshold (float): Minimum growth rate to consider trending
            
        Returns:
            List[Dict]: List of trending topics with comprehensive metadata including tags
        """
        print("Identifying trending topics with video and tag emphasis...")
        
        if self.combined_trend_data.empty:
            print("No combined trend data available")
            return []
        
        trending_topics = []
        recent_date = self.combined_trend_data['date'].max() - timedelta(days=window_days)
        
        for cluster_id in self.combined_trend_data['cluster'].unique():
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            recent_data = cluster_data[cluster_data['date'] >= recent_date]
            older_data = cluster_data[cluster_data['date'] < recent_date]
            
            if len(recent_data) == 0 or len(older_data) == 0:
                continue
            
            # Calculate growth for different metrics including tags
            metrics = {
                'combined_score': 'combined_trending_score',
                'video_views': 'total_views',
                'video_likes': 'video_likes',
                'comment_count': 'comment_count',
                'tag_relevance': 'avg_tag_relevance'
            }
            
            growth_rates = {}
            for metric_name, metric_col in metrics.items():
                if metric_col in recent_data.columns:
                    recent_avg = recent_data[metric_col].mean()
                    older_avg = older_data[metric_col].mean()
                    
                    if older_avg > 0:
                        growth_rates[metric_name] = (recent_avg - older_avg) / older_avg
                    else:
                        growth_rates[metric_name] = float('inf') if recent_avg > 0 else 0
                else:
                    growth_rates[metric_name] = 0
            
            # Use combined score as primary growth indicator
            primary_growth = growth_rates.get('combined_score', 0)
            
            if primary_growth >= growth_threshold:
                topic_words = self.cluster_topics.get(cluster_id, [])
                topic_tags = self.cluster_tags.get(cluster_id, [])
                
                # Get recent video performance
                recent_videos = self.videos_df[
                    (self.videos_df['cluster'] == cluster_id) &
                    (self.videos_df['date'] >= recent_date.date())
                ] if 'cluster' in self.videos_df.columns else pd.DataFrame()
                
                trending_topics.append({
                    'cluster_id': int(cluster_id),
                    'topic_words': topic_words,
                    'topic_tags': topic_tags,
                    'combined_growth_rate': primary_growth,
                    'video_views_growth': growth_rates.get('video_views', 0),
                    'video_likes_growth': growth_rates.get('video_likes', 0),
                    'comment_growth': growth_rates.get('comment_count', 0),
                    'tag_relevance_growth': growth_rates.get('tag_relevance', 0),
                    'recent_video_count': len(recent_videos),
                    'recent_total_views': int(recent_data['total_views'].sum()),
                    'recent_video_likes': int(recent_data['video_likes'].sum()),
                    'recent_comment_count': int(recent_data['comment_count'].sum()),
                    'avg_sentiment': float(recent_data['avg_sentiment'].mean()),
                    'avg_trending_score': float(recent_data['avg_trending_score'].mean()),
                    'avg_engagement_rate': float(recent_data['avg_engagement_rate'].mean()),
                    'avg_tag_count': float(recent_data['avg_tag_count'].mean()),
                    'avg_tag_relevance': float(recent_data['avg_tag_relevance'].mean()),
                    'trending_category': self._categorize_trend(primary_growth, growth_rates)
                })
        
        # Sort by combined growth rate (video and tag weighted)
        trending_topics.sort(key=lambda x: x['combined_growth_rate'], reverse=True)
        return trending_topics

    def identify_generational_trending_topics(self, window_days: int = 30) -> Dict:
        """
        Identify trending topics by generation.
        
        Args:
            window_days (int): Number of recent days to analyze
            
        Returns:
            Dict: Trending topics organized by generation
        """
        if not hasattr(self, 'generational_clusters') or self.generational_clusters is None:
            self.analyze_generational_trends_by_cluster()
        
        recent_date = self.comments_df['publishedAt'].max() - timedelta(days=window_days)
        recent_comments = self.comments_df[self.comments_df['publishedAt'] >= recent_date]
        
        generational_trends = {
            'gen_z': [],
            'millennial': [],
            'gen_x': [],
            'boomer': [],
            'neutral': []
        }
        
        for cluster_id, analysis in self.generational_clusters.items():
            dominant_gen = analysis['dominant_generation']
            
            # Calculate trend metrics for this cluster
            cluster_recent_comments = recent_comments[recent_comments['cluster'] == cluster_id]
            
            if len(cluster_recent_comments) > 0:
                # Calculate engagement and growth metrics
                avg_sentiment = cluster_recent_comments['compound'].mean() if 'compound' in cluster_recent_comments.columns else 0
                comment_volume = len(cluster_recent_comments)
                avg_likes = cluster_recent_comments['likeCount'].mean()
                
                # Get video performance for this generation
                video_perf = analysis['video_performance_by_generation'].get(dominant_gen, {})
                
                trend_data = {
                    'cluster_id': cluster_id,
                    'topic_words': analysis['topic_words'][:5],
                    'topic_tags': analysis['topic_tags'][:3],
                    'dominant_generation': dominant_gen,
                    'generation_confidence': analysis['dominant_generation_score'],
                    'recent_comment_volume': comment_volume,
                    'avg_sentiment': avg_sentiment,
                    'avg_comment_likes': avg_likes,
                    'generation_distribution': analysis['generation_distribution'],
                    'video_performance': video_perf
                }
                
                generational_trends[dominant_gen].append(trend_data)
        
        # Sort each generation's trends by relevance
        for generation in generational_trends:
            generational_trends[generation].sort(
                key=lambda x: (x['recent_comment_volume'] * (1 + x['avg_sentiment'])), 
                reverse=True
            )
        
        self.generational_trends = generational_trends
        return generational_trends


    def _categorize_trend(self, primary_growth: float, growth_rates: Dict) -> str:
        """
        Categorize the type of trend based on growth patterns including tags.
        
        Args:
            primary_growth (float): Primary growth rate
            growth_rates (Dict): Dictionary of growth rates for different metrics
            
        Returns:
            str: Trend category
        """
        video_growth = max(growth_rates.get('video_views', 0), growth_rates.get('video_likes', 0))
        comment_growth = growth_rates.get('comment_count', 0)
        tag_growth = growth_rates.get('tag_relevance', 0)
        
        if video_growth > 0.3 and comment_growth > 0.2 and tag_growth > 0.2:
            return "viral"
        elif video_growth > 0.2 and tag_growth > 0.15:
            return "video_trending"
        elif comment_growth > 0.2:
            return "discussion_trending"
        elif tag_growth > 0.2:
            return "tag_trending"
        elif primary_growth > 0.1:
            return "emerging"
        else:
            return "stable_growth"

    def build_enhanced_lstm_model(self, sequence_length: int = 7) -> None:
        """
        Build enhanced LSTM model that includes video metrics and tag features.
        
        Args:
            sequence_length (int): Number of time steps to look back
        """
        print("Building enhanced LSTM model with video and tag features...")
        
        # Enhanced feature set including tag metrics
        self.features = [
            'comment_count', 'comment_likes', 'avg_sentiment', 'video_count',
            'total_views', 'video_likes', 'avg_trending_score', 'avg_engagement_rate',
            'avg_tag_count', 'avg_tag_relevance', 'combined_trending_score'
        ]
        self.sequence_length = sequence_length
        
        # Build enhanced model architecture
        self.lstm_model = Sequential([
            LSTM(64, return_sequences=True, input_shape=(sequence_length, len(self.features))),
            Dropout(0.3),
            LSTM(64, return_sequences=True),
            Dropout(0.3),
            LSTM(32, return_sequences=False),
            Dropout(0.2),
            Dense(32, activation='relu'),
            Dense(16, activation='relu'),
            Dense(1)
        ])
        
        self.lstm_model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='mse',
            metrics=['mae']
        )
        
        print("Enhanced LSTM model with tag features built successfully")

    def analyze_tag_trends(self) -> Dict:
        """
        Analyze trending tags and their evolution over time.
        
        Returns:
            Dict: Tag trend analysis results
        """
        print("Analyzing tag trends over time...")
        
        if self.popular_tags is None or self.popular_tags.empty:
            return {}
        
        # Get tag trends over time
        tag_time_series = []
        
        # Create time series for each popular tag
        for tag in self.popular_tags.head(20).index:  # Top 20 tags
            tag_videos = []
            for _, video in self.videos_df.iterrows():
                if tag in video.get('processed_tags', []):
                    tag_videos.append({
                        'date': video['date'],
                        'viewCount': video.get('viewCount', 0),
                        'likeCount': video.get('likeCount', 0),
                        'tag': tag
                    })
            
            if tag_videos:
                tag_df = pd.DataFrame(tag_videos)
                tag_ts = tag_df.groupby('date').agg({
                    'viewCount': 'sum',
                    'likeCount': 'sum'
                }).reset_index()
                tag_ts['tag'] = tag
                tag_time_series.append(tag_ts)
        
        tag_trends = {}
        if tag_time_series:
            # Calculate growth rates for tags
            for tag_data in tag_time_series:
                tag = tag_data['tag'].iloc[0]
                if len(tag_data) >= 2:
                    recent_views = tag_data['viewCount'].tail(7).mean()  # Last week average
                    older_views = tag_data['viewCount'].head(7).mean()   # First week average
                    
                    if older_views > 0:
                        growth_rate = (recent_views - older_views) / older_views
                    else:
                        growth_rate = float('inf') if recent_views > 0 else 0
                    
                    tag_trends[tag] = {
                        'growth_rate': growth_rate,
                        'total_views': tag_data['viewCount'].sum(),
                        'total_likes': tag_data['likeCount'].sum(),
                        'video_count': self.popular_tags.loc[tag, 'frequency'],
                        'popularity_score': self.popular_tags.loc[tag, 'tag_popularity_score']
                    }
        
        # Sort by growth rate
        sorted_tag_trends = dict(sorted(tag_trends.items(), key=lambda x: x[1]['growth_rate'], reverse=True))
        return sorted_tag_trends

    def analyze_generational_trends_by_cluster(self) -> None:
        """
        Analyze generational trends for each cluster.
        """
        if self.comments_df is None or 'cluster' not in self.comments_df.columns:
            print("Comment data or clustering results not available")
            return
        
        print("Analyzing generational trends by cluster...")
        
        # First, ensure we have the generational analyzer
        if not hasattr(self, 'generational_analyzer'):
            self.generational_analyzer = EnhancedGenerationalLanguageAnalyzer
        
        # Perform generational analysis if not already done
        if 'dominant_generation' not in self.comments_df.columns:
            print("Performing generational language analysis...")
            self.analyze_generational_patterns()
        
        # Ensure we have the required columns
        required_columns = ['dominant_generation', 'gen_z_score', 'millennial_score', 'gen_x_score', 'boomer_score']
        missing_columns = [col for col in required_columns if col not in self.comments_df.columns]
        
        if missing_columns:
            print(f"Missing columns: {missing_columns}. Performing generational analysis...")
            self.analyze_generational_patterns()
        
        clusters = self.comments_df['cluster'].unique()
        self.generational_clusters = {}
        
        for cluster_id in clusters:
            cluster_comments = self.comments_df[self.comments_df['cluster'] == cluster_id].copy()
            
            if len(cluster_comments) < 10:  # Skip small clusters
                continue
            
            try:
                # Generational distribution analysis
                generation_distribution = cluster_comments['dominant_generation'].value_counts(normalize=True)
                
                if len(generation_distribution) == 0:
                    continue
                    
                dominant_generation = generation_distribution.index[0]
                dominant_generation_score = generation_distribution.iloc[0]
                
                # Calculate average generational scores for this cluster
                avg_generational_scores = {
                    'gen_z': cluster_comments['gen_z_score'].mean(),
                    'millennial': cluster_comments['millennial_score'].mean(),
                    'gen_x': cluster_comments['gen_x_score'].mean(),
                    'boomer': cluster_comments['boomer_score'].mean()
                }
                
                # Get topic information
                topic_words = self.cluster_topics.get(cluster_id, [])[:10]
                topic_tags = self.cluster_tags.get(cluster_id, [])[:5]
                
                # Video performance analysis by generation
                video_performance_by_generation = {}
                if hasattr(self, 'videos_df') and self.videos_df is not None and 'cluster' in self.videos_df.columns:
                    cluster_videos = self.videos_df[self.videos_df['cluster'] == cluster_id]
                    
                    if not cluster_videos.empty and 'dominant_generation' in cluster_videos.columns:
                        for generation in ['gen_z', 'millennial', 'gen_x', 'boomer', 'neutral']:
                            gen_videos = cluster_videos[cluster_videos['dominant_generation'] == generation]
                            if len(gen_videos) > 0:
                                video_performance_by_generation[generation] = {
                                    'count': len(gen_videos),
                                    'avg_views': gen_videos.get('viewCount', pd.Series([0])).mean(),
                                    'avg_likes': gen_videos.get('likeCount', pd.Series([0])).mean(),
                                    'avg_engagement': gen_videos.get('engagement_rate', pd.Series([0])).mean()
                                }
                
                self.generational_clusters[cluster_id] = {
                    'dominant_generation': dominant_generation,
                    'dominant_generation_score': dominant_generation_score,
                    'generation_distribution': generation_distribution.to_dict(),
                    'avg_generational_scores': avg_generational_scores,
                    'topic_words': topic_words,
                    'topic_tags': topic_tags,
                    'video_performance_by_generation': video_performance_by_generation,
                    'cluster_size': len(cluster_comments)
                }
                
            except Exception as e:
                print(f"Error analyzing cluster {cluster_id}: {e}")
                continue
        
        print(f"Generational analysis completed for {len(self.generational_clusters)} clusters")



    def visualize_enhanced_trends(self, top_n_clusters: int = 5) -> None:
        """
        Create comprehensive trend visualizations using matplotlib and seaborn including video metrics and tags.
        
        Args:
            top_n_clusters (int): Number of top clusters to visualize
        """
        print("Creating enhanced trend visualizations with tag integration...")
        
        if self.combined_trend_data.empty:
            print("No trend data available for visualization")
            return
        
        # Get top clusters by combined trending score
        cluster_performance = self.combined_trend_data.groupby('cluster')['combined_trending_score'].sum().sort_values(ascending=False)
        top_clusters = cluster_performance.head(top_n_clusters).index.tolist()
        
        # Set up the style
        plt.style.use('seaborn-v0_8')
        sns.set_palette("husl")
        
        # Create comprehensive dashboard with subplots
        fig, axes = plt.subplots(5, 2, figsize=(20, 25))
        fig.suptitle('Enhanced YouTube Beauty Trends Analysis Dashboard (Video+Tag Weighted)', 
                     fontsize=20, fontweight='bold')
        
        colors = plt.cm.tab10(np.linspace(0, 1, len(top_clusters)))
        
        # Plot 1: Video Views Over Time
        ax1 = axes[0, 0]
        for i, cluster_id in enumerate(top_clusters):
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_words = ', '.join(self.cluster_topics.get(cluster_id, [])[:3])
            topic_tags = ', '.join(self.cluster_tags.get(cluster_id, [])[:2])
            label = f'{topic_words} ({topic_tags})' if topic_tags else topic_words
            
            ax1.plot(cluster_data['date'], cluster_data['total_views'], 
                    marker='o', linewidth=2, color=colors[i], label=label[:30])
        
        ax1.set_title('Video Views Over Time', fontsize=14, fontweight='bold')
        ax1.set_xlabel('Date')
        ax1.set_ylabel('Total Views')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax1.tick_params(axis='x', rotation=45)
        
        # Plot 2: Comment Volume Over Time
        ax2 = axes[0, 1]
        for i, cluster_id in enumerate(top_clusters):
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_words = ', '.join(self.cluster_topics.get(cluster_id, [])[:3])
            
            ax2.plot(cluster_data['date'], cluster_data['comment_count'], 
                    marker='s', linewidth=2, linestyle='--', color=colors[i], label=topic_words[:30])
        
        ax2.set_title('Comment Volume Over Time', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Date')
        ax2.set_ylabel('Comment Count')
        ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax2.tick_params(axis='x', rotation=45)
        
        # Plot 3: Combined Trending Scores
        ax3 = axes[1, 0]
        for i, cluster_id in enumerate(top_clusters):
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_words = ', '.join(self.cluster_topics.get(cluster_id, [])[:3])
            
            ax3.plot(cluster_data['date'], cluster_data['combined_trending_score'], 
                    marker='D', linewidth=3, color=colors[i], label=topic_words[:30])
        
        ax3.set_title('Combined Trending Scores (Video+Tag Weighted)', fontsize=14, fontweight='bold')
        ax3.set_xlabel('Date')
        ax3.set_ylabel('Combined Trending Score')
        ax3.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax3.tick_params(axis='x', rotation=45)
        
        # Plot 4: Tag Relevance Over Time
        ax4 = axes[1, 1]
        for i, cluster_id in enumerate(top_clusters):
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_tags = ', '.join(self.cluster_tags.get(cluster_id, [])[:2])
            
            ax4.plot(cluster_data['date'], cluster_data['avg_tag_relevance'], 
                    marker='^', linewidth=2, color=colors[i], label=topic_tags[:30] or f'Cluster {cluster_id}')
        
        ax4.set_title('Tag Relevance Over Time', fontsize=14, fontweight='bold')
        ax4.set_xlabel('Date')
        ax4.set_ylabel('Average Tag Relevance')
        ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax4.tick_params(axis='x', rotation=45)
        
        # Plot 5: Video vs Comment Engagement Scatter
        ax5 = axes[2, 0]
        engagement_data = []
        for cluster_id in top_clusters:
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_words = ', '.join(self.cluster_topics.get(cluster_id, [])[:2])
            
            for _, row in cluster_data.iterrows():
                engagement_data.append({
                    'video_engagement': row['avg_engagement_rate'],
                    'comment_engagement': row['comment_count'],
                    'cluster': topic_words[:20]
                })
        
        if engagement_data:
            eng_df = pd.DataFrame(engagement_data)
            sns.scatterplot(data=eng_df, x='video_engagement', y='comment_engagement', 
                          hue='cluster', s=100, alpha=0.7, ax=ax5)
        
        ax5.set_title('Video vs Comment Engagement', fontsize=14, fontweight='bold')
        ax5.set_xlabel('Video Engagement Rate')
        ax5.set_ylabel('Comment Count')
        
        # Plot 6: Sentiment Trends
        ax6 = axes[2, 1]
        for i, cluster_id in enumerate(top_clusters):
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_words = ', '.join(self.cluster_topics.get(cluster_id, [])[:2])
            
            ax6.plot(cluster_data['date'], cluster_data['avg_sentiment'], 
                    marker='o', linewidth=2, color=colors[i], label=topic_words[:30])
        
        ax6.axhline(y=0, color='black', linestyle='-', alpha=0.3)
        ax6.set_title('Sentiment Trends', fontsize=14, fontweight='bold')
        ax6.set_xlabel('Date')
        ax6.set_ylabel('Average Sentiment')
        ax6.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax6.tick_params(axis='x', rotation=45)
        
        # Plot 7: Video Performance Metrics Heatmap
        ax7 = axes[3, 0]
        performance_data = []
        for cluster_id in top_clusters:
            cluster_data = self.combined_trend_data[self.combined_trend_data['cluster'] == cluster_id]
            topic_words = ', '.join(self.cluster_topics.get(cluster_id, [])[:2])
            
            performance_data.append({
                'Cluster': topic_words[:15],
                'Views': cluster_data['total_views'].sum(),
                'Likes': cluster_data['video_likes'].sum(),
                'Comments': cluster_data['comment_count'].sum(),
                'Engagement': cluster_data['avg_engagement_rate'].mean()
            })
        
        if performance_data:
            perf_df = pd.DataFrame(performance_data)
            perf_df_norm = perf_df.set_index('Cluster')
            perf_df_norm = (perf_df_norm - perf_df_norm.min()) / (perf_df_norm.max() - perf_df_norm.min())
            
            sns.heatmap(perf_df_norm.T, annot=True, cmap='YlOrRd', fmt='.2f', ax=ax7)
        
        ax7.set_title('Video Performance Metrics (Normalized)', fontsize=14, fontweight='bold')
        
        # Plot 8: Tag Popularity Distribution
        ax8 = axes[3, 1]
        if self.popular_tags is not None and not self.popular_tags.empty:
            top_tags = self.popular_tags.head(10)
            ax8.barh(range(len(top_tags)), top_tags['tag_popularity_score'], color='skyblue')
            ax8.set_yticks(range(len(top_tags)))
            ax8.set_yticklabels(top_tags.index, fontsize=10)
            ax8.set_xlabel('Tag Popularity Score')
            ax8.set_title('Top 10 Tag Popularity Distribution', fontsize=14, fontweight='bold')
        
        # Plot 9: Growth Rate Comparison
        ax9 = axes[4, 0]
        trending_topics = self.identify_trending_topics()[:top_n_clusters]
        if trending_topics:
            growth_data = {
                'Topic': [', '.join(t['topic_words'][:2]) for t in trending_topics],
                'Video Growth': [t['video_views_growth'] for t in trending_topics],
                'Comment Growth': [t['comment_growth'] for t in trending_topics],
                'Tag Growth': [t['tag_relevance_growth'] for t in trending_topics]
            }
            
            growth_df = pd.DataFrame(growth_data)
            growth_df_melted = growth_df.melt(id_vars='Topic', var_name='Growth Type', value_name='Growth Rate')
            
            sns.barplot(data=growth_df_melted, x='Topic', y='Growth Rate', hue='Growth Type', ax=ax9)
            ax9.tick_params(axis='x', rotation=45)
        
        ax9.set_title('Growth Rate Comparison by Type', fontsize=14, fontweight='bold')
        
        # Plot 10: Trending Categories Distribution
        ax10 = axes[4, 1]
        if trending_topics:
            categories = [t['trending_category'] for t in trending_topics]
            category_counts = pd.Series(categories).value_counts()
            
            colors_pie = plt.cm.Set3(np.linspace(0, 1, len(category_counts)))
            wedges, texts, autotexts = ax10.pie(category_counts.values, labels=category_counts.index, 
                                              autopct='%1.1f%%', colors=colors_pie)
            
            # Make percentage text more readable
            for autotext in autotexts:
                autotext.set_color('black')
                autotext.set_fontweight('bold')
        
        ax10.set_title('Trending Categories Distribution', fontsize=14, fontweight='bold')
        
        # Adjust layout and show
        plt.tight_layout()
        plt.subplots_adjust(top=0.95)
        plt.show()

    def identify_generational_trending_topics(self, window_days: int = 30) -> Dict:
        """
        Identify trending topics by generation.
        
        Args:
            window_days (int): Number of recent days to analyze
            
        Returns:
            Dict: Trending topics organized by generation
        """
        if not hasattr(self, 'generational_clusters') or self.generational_clusters is None:
            self.analyze_generational_trends_by_cluster()
        
        recent_date = self.comments_df['publishedAt'].max() - timedelta(days=window_days)
        recent_comments = self.comments_df[self.comments_df['publishedAt'] >= recent_date]
        
        generational_trends = {
            'gen_z': [],
            'millennial': [],
            'gen_x': [],
            'boomer': [],
            'neutral': []
        }
        
        for cluster_id, analysis in self.generational_clusters.items():
            dominant_gen = analysis['dominant_generation']
            
            # Calculate trend metrics for this cluster
            cluster_recent_comments = recent_comments[recent_comments['cluster'] == cluster_id]
            
            if len(cluster_recent_comments) > 0:
                # Calculate engagement and growth metrics
                avg_sentiment = cluster_recent_comments['compound'].mean() if 'compound' in cluster_recent_comments.columns else 0
                comment_volume = len(cluster_recent_comments)
                avg_likes = cluster_recent_comments['likeCount'].mean()
                
                # Get video performance for this generation
                video_perf = analysis['video_performance_by_generation'].get(dominant_gen, {})
                
                trend_data = {
                    'cluster_id': cluster_id,
                    'topic_words': analysis['topic_words'][:5],
                    'topic_tags': analysis['topic_tags'][:3],
                    'dominant_generation': dominant_gen,
                    'generation_confidence': analysis['dominant_generation_score'],
                    'recent_comment_volume': comment_volume,
                    'avg_sentiment': avg_sentiment,
                    'avg_comment_likes': avg_likes,
                    'generation_distribution': analysis['generation_distribution'],
                    'video_performance': video_perf
                }
                
                generational_trends[dominant_gen].append(trend_data)
        
        # Sort each generation's trends by relevance
        for generation in generational_trends:
            generational_trends[generation].sort(
                key=lambda x: (x['recent_comment_volume'] * (1 + x['avg_sentiment'])), 
                reverse=True
            )
        
        self.generational_trends = generational_trends
        return generational_trends


    def generate_enhanced_trend_report(self) -> Dict:
        """
        Generate comprehensive trend analysis report with video and tag emphasis.
        
        Returns:
            Dict: Enhanced trend analysis report including tag insights
        """
        print("Generating enhanced trend analysis report with tag integration...")
        
        # Get trending topics with video and tag emphasis
        trending_topics = self.identify_trending_topics()
        
        # Analyze tag trends
        tag_trends = self.analyze_tag_trends()
        
        # Calculate overall statistics
        total_comments = len(self.comments_df) if self.comments_df is not None else 0
        total_videos = len(self.videos_df) if self.videos_df is not None else 0
        total_clusters = len(set(self.clusters)) if self.clusters is not None else 0
        total_unique_tags = len(self.popular_tags) if self.popular_tags is not None else 0
        
        # Video performance statistics
        if self.videos_df is not None and not self.videos_df.empty:
            total_views = self.videos_df.get('viewCount', pd.Series([0])).sum()
            total_video_likes = self.videos_df.get('likeCount', pd.Series([0])).sum()
            avg_engagement_rate = self.videos_df.get('engagement_rate', pd.Series([0])).mean()
            avg_tags_per_video = self.videos_df.get('tag_count', pd.Series([0])).mean()
            
            top_performing_videos = self.videos_df.nlargest(5, 'trending_score')[
                ['title', 'viewCount', 'likeCount', 'trending_score', 'processed_tags']
            ].to_dict('records')
        else:
            total_views = 0
            total_video_likes = 0
            avg_engagement_rate = 0
            avg_tags_per_video = 0
            top_performing_videos = []
        
        # Comment statistics
        if self.comments_df is not None and not self.comments_df.empty:
            avg_sentiment = self.comments_df['compound'].mean()
        else:
            avg_sentiment = 0
        
        # Get most engaging clusters (video and tag weighted)
        if not self.combined_trend_data.empty:
            cluster_performance = self.combined_trend_data.groupby('cluster').agg({
                'combined_trending_score': 'sum',
                'total_views': 'sum',
                'video_likes': 'sum',
                'comment_count': 'sum',
                'avg_sentiment': 'mean',
                'avg_tag_relevance': 'mean'
            }).sort_values('combined_trending_score', ascending=False)
        else:
            cluster_performance = pd.DataFrame()
        
        report = {
            'analysis_summary': {
                'total_comments_analyzed': total_comments,
                'total_videos_analyzed': total_videos,
                'total_topics_identified': total_clusters,
                'total_unique_tags': total_unique_tags,
                'total_video_views': int(total_views),
                'total_video_likes': int(total_video_likes),
                'avg_video_engagement_rate': float(avg_engagement_rate),
                'avg_tags_per_video': float(avg_tags_per_video),
                'overall_sentiment': 'positive' if avg_sentiment > 0.1 else 'negative' if avg_sentiment < -0.1 else 'neutral',
                'video_weight_factor': self.video_weight,
                'comment_weight_factor': self.comment_weight,
                'tag_weight_factor': self.tag_weight,
                'analysis_period': {
                    'start_date': str(self.comments_df['publishedAt'].min().date()) if self.comments_df is not None and not self.comments_df.empty else 'N/A',
                    'end_date': str(self.comments_df['publishedAt'].max().date()) if self.comments_df is not None and not self.comments_df.empty else 'N/A'
                }
            },
            'trending_topics': trending_topics[:10],  # Top 10 trending with video and tag emphasis
            'trending_tags': [
                {
                    'tag': tag,
                    'growth_rate': data['growth_rate'],
                    'total_views': data['total_views'],
                    'video_count': data['video_count'],
                    'popularity_score': data['popularity_score']
                }
                for tag, data in list(tag_trends.items())[:10]
            ],
            'top_performing_videos': top_performing_videos,
            'top_engaging_clusters': [
                {
                    'cluster_id': int(cluster_id),
                    'topic_words': self.cluster_topics.get(cluster_id, [])[:5],
                    'topic_tags': self.cluster_tags.get(cluster_id, [])[:3],
                    'combined_score': float(row['combined_trending_score']),
                    'total_views': int(row['total_views']),
                    'total_video_likes': int(row['video_likes']),
                    'total_comments': int(row['comment_count']),
                    'avg_sentiment': float(row['avg_sentiment']),
                    'avg_tag_relevance': float(row['avg_tag_relevance'])
                }
                for cluster_id, row in cluster_performance.head(10).iterrows()
            ] if not cluster_performance.empty else [],
            'video_performance_insights': {
                'highest_engagement_cluster': int(cluster_performance.index[0]) if not cluster_performance.empty else None,
                'most_viewed_topic': self.cluster_topics.get(int(cluster_performance.index[0]), [])[:3] if not cluster_performance.empty else [],
                'most_viewed_tags': self.cluster_tags.get(int(cluster_performance.index[0]), [])[:3] if not cluster_performance.empty else [],
                'trend_categories': {
                    category: len([t for t in trending_topics if t.get('trending_category') == category])
                    for category in ['viral', 'video_trending', 'discussion_trending', 'tag_trending', 'emerging', 'stable_growth']
                }
            },
            'tag_insights': {
                'most_popular_tags': list(self.popular_tags.head(10).index) if self.popular_tags is not None else [],
                'fastest_growing_tags': list(tag_trends.keys())[:5],
                'tag_coverage': (self.videos_df['tag_count'] > 0).sum() / len(self.videos_df) if self.videos_df is not None and not self.videos_df.empty else 0
            }
        }
        
         # Add generational analysis
        if not hasattr(self, 'generational_trends') or self.generational_trends is None:
            self.identify_generational_trending_topics()
        
        # Add generational insights to the report
        report['generational_insights'] = {
            'trending_by_generation': {},
            'generation_distribution': {},
            'top_generational_topics': {},
            'generational_sentiment': {}
        }
        
        for generation, trends in self.generational_trends.items():
            if trends:
                report['generational_insights']['trending_by_generation'][generation] = trends[:5]
                report['generational_insights']['top_generational_topics'][generation] = [
                    {
                        'topic_words': trend['topic_words'],
                        'topic_tags': trend['topic_tags'],
                        'comment_volume': trend['recent_comment_volume'],
                        'sentiment': trend['avg_sentiment']
                    }
                    for trend in trends[:3]
                ]
        
        return report

    def run_enhanced_pipeline(self, comment_files: List[str], video_file: str, 
                            n_clusters: int = None, forecast_days: int = 30) -> Dict:
        """
        Run the complete enhanced trend analysis pipeline with video and tag emphasis.
        
        Args:
            comment_files (List[str]): List of comment CSV files
            video_file (str): Video CSV file path
            n_clusters (int): Number of clusters (if None, will optimize)
            forecast_days (int): Number of days to forecast
            
        Returns:
            Dict: Complete enhanced analysis results including tag insights
        """
        print("Starting enhanced trend analysis pipeline with video and tag emphasis...")
        
        # Step 1: Load and preprocess data (including tags)
        self.load_data(comment_files, video_file)
        self.preprocess_data()
        
        # Step 2: Generate embeddings
        self.generate_embeddings()
        
        # Step 3: Optimize cluster number if not provided
        if n_clusters is None:
            n_clusters = self.optimize_cluster_number()
        
        # Step 4: Perform clustering
        self.perform_clustering(n_clusters=n_clusters)
        
        # Step 5: Map videos to clusters (including tags)
        self.map_videos_to_clusters()
        
        # Step 6: Analyze sentiment
        self.analyze_sentiment()
        
        # Step 7: Extract topics (enhanced with video titles and tags)
        self.extract_cluster_topics()
        
        # Step 8: Prepare enhanced time series data (including tag metrics)
        self.prepare_time_series_data()
        
        # Step 9: Build enhanced LSTM model (including tag features)
        if not self.combined_trend_data.empty:
            self.build_enhanced_lstm_model()
        
        # Step 10: Generate enhanced visualizations
        self.visualize_enhanced_trends()
        
        # Step 11: Generate comprehensive report
        report = self.generate_enhanced_trend_report()
        
        print("Enhanced pipeline completed successfully!")
        print(f"Video weight factor: {self.video_weight}")
        print(f"Comment weight factor: {self.comment_weight}")
        print(f"Tag weight factor: {self.tag_weight}")
        
        return report

    def plot_generational_growth_by_clusters(self, generation: str = 'gen_z', 
                                       forecast_horizons: List[int] = [20, 30, 60]) -> None:
        """
        Plot growth rates for a specific generation across clusters and forecast horizons.
        
        Args:
            generation: Target generation ('gen_z', 'millennial', 'gen_x', 'boomer')
            forecast_horizons: List of forecast horizons to analyze
        """
        if not hasattr(self, 'generational_clusters') or not self.generational_clusters:
            print("No generational cluster data available. Running analysis...")
            self.analyze_generational_trends_by_cluster()
        
        # Get clusters dominated by the specified generation
        generation_clusters = []
        for cluster_id, analysis in self.generational_clusters.items():
            if analysis['dominant_generation'] == generation:
                generation_clusters.append({
                    'cluster_id': cluster_id,
                    'topic_words': analysis['topic_words'][:3],
                    'topic_tags': analysis['topic_tags'][:2],
                    'generation_confidence': analysis['dominant_generation_score'],
                    'cluster_size': analysis.get('cluster_size', 0)
                })
        
        if not generation_clusters:
            print(f"No clusters found for generation: {generation}")
            return
        
        # Sort by cluster size (popularity)
        generation_clusters.sort(key=lambda x: x['cluster_size'], reverse=True)
        top_clusters = generation_clusters[:15]  # Limit to top 15 for readability
        
        # Create subplots for different forecast horizons
        fig, axes = plt.subplots(1, len(forecast_horizons), figsize=(6*len(forecast_horizons), 8))
        if len(forecast_horizons) == 1:
            axes = [axes]
        
        fig.suptitle(f'{generation.replace("_", " ").title()} Growth Rates by Cluster Category', 
                     fontsize=16, fontweight='bold')
        
        # Generate forecasts for each horizon if not already available
        if not hasattr(self, 'multi_horizon_results'):
            print("Generating multi-horizon forecasts...")
            self.multi_horizon_results = self.generate_multi_horizon_forecasts(forecast_horizons)
        
        for idx, horizon in enumerate(forecast_horizons):
            ax = axes[idx]
            
            cluster_names = []
            growth_rates = []
            colors = []
            
            for cluster_data in top_clusters:
                cluster_id = cluster_data['cluster_id']
                
                # Get growth rate from forecast results
                if (hasattr(self, 'multi_horizon_results') and 
                    horizon in self.multi_horizon_results and
                    cluster_id in self.multi_horizon_results[horizon]['metrics']['cluster_metrics']):
                    
                    growth_rate = self.multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['growth_rate']
                    growth_rates.append(growth_rate)
                    
                    # Color based on growth rate
                    if growth_rate > 0.1:
                        colors.append('green')
                    elif growth_rate > 0:
                        colors.append('lightgreen')
                    elif growth_rate > -0.05:
                        colors.append('orange')
                    else:
                        colors.append('red')
                else:
                    growth_rates.append(0)
                    colors.append('gray')
                
                # Create cluster label
                topic_str = ', '.join(cluster_data['topic_words'])
                tag_str = ', '.join(cluster_data['topic_tags']) if cluster_data['topic_tags'] else ''
                label = f"C{cluster_id}: {topic_str}"
                if tag_str:
                    label += f"\n({tag_str})"
                cluster_names.append(label)
            
            # Create bar plot
            bars = ax.bar(range(len(cluster_names)), growth_rates, color=colors, alpha=0.7)
            
            # Customize plot
            ax.set_title(f'{horizon}-Day Forecast', fontweight='bold')
            ax.set_xlabel('Cluster Categories')
            ax.set_ylabel('Growth Rate')
            ax.set_xticks(range(len(cluster_names)))
            ax.set_xticklabels(cluster_names, rotation=45, ha='right', fontsize=8)
            ax.axhline(y=0, color='black', linestyle='-', alpha=0.5)
            ax.grid(True, alpha=0.3, axis='y')
            
            # Add value labels on bars
            for bar, value in zip(bars, growth_rates):
                height = bar.get_height()
                ax.text(bar.get_x() + bar.get_width()/2, 
                       height + 0.005 if height >= 0 else height - 0.01,
                       f'{value:.2f}', ha='center', 
                       va='bottom' if height >= 0 else 'top', 
                       fontsize=8, fontweight='bold')
        
        plt.tight_layout()
        plt.show()


    def generate_generational_forecasts(self, forecast_horizons: List[int] = [20, 30, 40, 60, 80]) -> Dict:
        """
        Generate forecasts specifically for generational trends.
        
        Args:
            forecast_horizons: List of forecast horizons in days
            
        Returns:
            Dict: Forecasts organized by generation and horizon
        """
        print("Generating generational-specific forecasts...")
        
        # Ensure we have generational analysis
        if not hasattr(self, 'generational_clusters') or not self.generational_clusters:
            self.analyze_generational_trends_by_cluster()
        
        # Generate multi-horizon forecasts if not available
        if not hasattr(self, 'multi_horizon_results'):
            self.multi_horizon_results = self.generate_multi_horizon_forecasts(forecast_horizons)
        
        generational_forecasts = {
            'gen_z': {},
            'millennial': {},
            'gen_x': {},
            'boomer': {},
            'neutral': {}
        }
        
        for horizon in forecast_horizons:
            for generation in generational_forecasts.keys():
                generational_forecasts[generation][horizon] = {
                    'clusters': [],
                    'avg_growth_rate': 0.0,
                    'total_clusters': 0,
                    'trending_topics': [],
                    'declining_topics': []
                }
            
            # Process each cluster
            if horizon in self.multi_horizon_results:
                cluster_metrics = self.multi_horizon_results[horizon]['metrics']['cluster_metrics']
                
                for cluster_id, analysis in self.generational_clusters.items():
                    dominant_gen = analysis['dominant_generation']
                    
                    if cluster_id in cluster_metrics:
                        cluster_forecast = cluster_metrics[cluster_id]
                        
                        cluster_data = {
                            'cluster_id': cluster_id,
                            'topic_words': analysis['topic_words'][:3],
                            'topic_tags': analysis['topic_tags'][:2],
                            'growth_rate': cluster_forecast['growth_rate'],
                            'prediction_mean': cluster_forecast['prediction_mean'],
                            'generation_confidence': analysis['dominant_generation_score']
                        }
                        
                        generational_forecasts[dominant_gen][horizon]['clusters'].append(cluster_data)
                        
                        if cluster_forecast['growth_rate'] > 0.05:
                            generational_forecasts[dominant_gen][horizon]['trending_topics'].append(cluster_data)
                        elif cluster_forecast['growth_rate'] < -0.05:
                            generational_forecasts[dominant_gen][horizon]['declining_topics'].append(cluster_data)
                
                # Calculate averages for each generation
                for generation in generational_forecasts.keys():
                    gen_data = generational_forecasts[generation][horizon]
                    if gen_data['clusters']:
                        gen_data['avg_growth_rate'] = np.mean([c['growth_rate'] for c in gen_data['clusters']])
                        gen_data['total_clusters'] = len(gen_data['clusters'])
        
        return generational_forecasts



    def optimize_cluster_number(self, max_clusters: int = 50, sample_size: int = 10000) -> int:
        """
        Find optimal number of clusters using elbow method with FAISS K-means.
        
        Args:
            max_clusters (int): Maximum number of clusters to test
            sample_size (int): Sample size for faster computation
            
        Returns:
            int: Optimal number of clusters
        """
        print("Finding optimal number of clusters...")
        
        if self.embeddings is None:
            raise ValueError("Embeddings not generated")
        
        # Sample embeddings for faster computation
        if len(self.embeddings) > sample_size:
            indices = np.random.choice(len(self.embeddings), sample_size, replace=False)
            sample_embeddings = self.embeddings[indices].astype(np.float32)
        else:
            sample_embeddings = self.embeddings.astype(np.float32)
        
        d = sample_embeddings.shape[1]
        inertias = []
        k_range = range(5, min(max_clusters, len(sample_embeddings)//2), 5)
        
        for k in k_range:
            print(f"Testing {k} clusters...")
            kmeans = faiss.Kmeans(d=d, k=k, niter=20, verbose=False)
            kmeans.train(sample_embeddings)
            
            # Calculate inertia (within-cluster sum of squares)
            _, distances = kmeans.index.search(sample_embeddings, 1)
            inertia = np.sum(distances)
            inertias.append(inertia)
        
        # Find elbow point (simplified method)
        if len(inertias) >= 3:
            # Calculate rate of change
            rates = []
            for i in range(1, len(inertias)):
                rate = (inertias[i-1] - inertias[i]) / inertias[i-1]
                rates.append(rate)
            
            # Find where rate of improvement drops significantly
            optimal_idx = 0
            for i in range(1, len(rates)):
                if rates[i] < rates[i-1] * 0.5:  # 50% drop in improvement rate
                    optimal_idx = i
                    break
            
            optimal_k = list(k_range)[optimal_idx + 1]  # +1 because rates is shorter
        else:
            optimal_k = 20  # Default fallback
        
        print(f"Optimal number of clusters: {optimal_k}")
        return optimal_k


    def visualize_all_generations_forecast(self, forecast_horizons: List[int] = [20, 30, 60]) -> None:
        """
        Create comprehensive visualization showing all generations' forecast performance.
        """
        generational_forecasts = self.generate_generational_forecasts(forecast_horizons)
        
        generations = ['gen_z', 'millennial', 'gen_x', 'boomer']
        gen_colors = {
            'gen_z': '#FF6B6B',
            'millennial': '#4ECDC4', 
            'gen_x': '#45B7D1',
            'boomer': '#96CEB4'
        }
        
        # Create comprehensive dashboard
        fig, axes = plt.subplots(2, 2, figsize=(20, 16))
        fig.suptitle('Generational Trend Forecasting Analysis', fontsize=18, fontweight='bold')
        
        # Plot 1: Average Growth Rate by Generation and Horizon
        ax1 = axes[0, 0]
        width = 0.2
        x = np.arange(len(forecast_horizons))
        
        for i, generation in enumerate(generations):
            if generation in generational_forecasts:
                avg_growth_rates = []
                for horizon in forecast_horizons:
                    avg_growth_rates.append(generational_forecasts[generation][horizon]['avg_growth_rate'])
                
                ax1.bar(x + i * width, avg_growth_rates, width, 
                       label=generation.replace('_', ' ').title(), 
                       color=gen_colors[generation], alpha=0.8)
        
        ax1.set_xlabel('Forecast Horizon (Days)')
        ax1.set_ylabel('Average Growth Rate')
        ax1.set_title('Average Growth Rate by Generation', fontweight='bold')
        ax1.set_xticks(x + width * 1.5)
        ax1.set_xticklabels(forecast_horizons)
        ax1.legend()
        ax1.grid(True, alpha=0.3)
        ax1.axhline(y=0, color='black', linestyle='-', alpha=0.5)
        
        # Plot 2: Number of Trending Topics by Generation
        ax2 = axes[0, 1]
        trending_counts = {gen: [] for gen in generations}
        
        for generation in generations:
            for horizon in forecast_horizons:
                count = len(generational_forecasts[generation][horizon]['trending_topics'])
                trending_counts[generation].append(count)
        
        x = np.arange(len(forecast_horizons))
        for i, generation in enumerate(generations):
            ax2.plot(forecast_horizons, trending_counts[generation], 
                    marker='o', linewidth=2, markersize=8,
                    color=gen_colors[generation], 
                    label=generation.replace('_', ' ').title())
        
        ax2.set_xlabel('Forecast Horizon (Days)')
        ax2.set_ylabel('Number of Trending Topics')
        ax2.set_title('Trending Topics Count by Generation', fontweight='bold')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        # Plot 3: Gen Z specific cluster growth (as requested)
        ax3 = axes[1, 0]
        if 'gen_z' in generational_forecasts:
            # Get Gen Z clusters for the middle forecast horizon
            mid_horizon = forecast_horizons[len(forecast_horizons)//2]
            gen_z_clusters = generational_forecasts['gen_z'][mid_horizon]['clusters'][:10]  # Top 10
            
            if gen_z_clusters:
                cluster_names = []
                growth_rates = []
                colors = []
                
                for cluster in gen_z_clusters:
                    topic_str = ', '.join(cluster['topic_words'])
                    cluster_names.append(f"C{cluster['cluster_id']}: {topic_str}")
                    growth_rates.append(cluster['growth_rate'])
                    
                    if cluster['growth_rate'] > 0.05:
                        colors.append('green')
                    elif cluster['growth_rate'] > 0:
                        colors.append('lightgreen')
                    else:
                        colors.append('red')
                
                bars = ax3.barh(range(len(cluster_names)), growth_rates, color=colors, alpha=0.7)
                ax3.set_yticks(range(len(cluster_names)))
                ax3.set_yticklabels(cluster_names, fontsize=9)
                ax3.set_xlabel('Growth Rate')
                ax3.set_title(f'Gen Z Cluster Growth ({mid_horizon}-Day Forecast)', fontweight='bold')
                ax3.axvline(x=0, color='black', linestyle='-', alpha=0.5)
                ax3.grid(True, alpha=0.3)
        
        # Plot 4: Generation Comparison Heatmap
        ax4 = axes[1, 1]
        heatmap_data = []
        generation_labels = []
        
        for generation in generations:
            if generation in generational_forecasts:
                row_data = []
                for horizon in forecast_horizons:
                    avg_growth = generational_forecasts[generation][horizon]['avg_growth_rate']
                    row_data.append(avg_growth)
                heatmap_data.append(row_data)
                generation_labels.append(generation.replace('_', ' ').title())
        
        if heatmap_data:
            im = ax4.imshow(heatmap_data, cmap='RdYlGn', aspect='auto')
            ax4.set_xticks(range(len(forecast_horizons)))
            ax4.set_xticklabels(forecast_horizons)
            ax4.set_yticks(range(len(generation_labels)))
            ax4.set_yticklabels(generation_labels)
            ax4.set_xlabel('Forecast Horizon (Days)')
            ax4.set_title('Growth Rate Heatmap', fontweight='bold')
            
            # Add text annotations
            for i in range(len(generation_labels)):
                for j in range(len(forecast_horizons)):
                    text = ax4.text(j, i, f'{heatmap_data[i][j]:.3f}',
                                   ha="center", va="center", color="black", fontweight='bold')
            
            plt.colorbar(im, ax=ax4, label='Growth Rate')
        
        plt.tight_layout()
        plt.show()


# Generational Language Analyzer 

In [7]:
class EnhancedGenerationalLanguageAnalyzer:
    """
    Enhanced analyzer for detecting generational language patterns with beauty-specific terms.
    """
    
    def __init__(self):
        # Enhanced generational vocabulary patterns based on current social media usage
        self.generational_patterns = {
            'gen_z': {
                'slang': [
                    # Core Gen Z terms
                    'slay', 'periodt', 'no cap', 'fr', 'frfr', 'bussin', 'sheesh', 'sus',
                    'bet', 'lowkey', 'highkey', 'deadass', 'based', 'cringe', 'hits different',
                    'slaps', 'vibe check', 'stan', 'bestie', 'bestyyy', 'girlie', 'girly',
                    'queen', 'king', 'icon', 'iconic', 'legend', 'fire', 'lit', 'mid',
                    'cap', 'facts', 'say less', 'periodt pooh', 'chile', 'oop', 'and i oop',
                    'sksksk', 'vsco', 'simp', 'salty', 'tea', 'spill', 'mood', 'same',
                    
                    # Beauty-specific Gen Z terms
                    'snatched', 'beat', 'glow up', 'lewk', 'serve', 'serving looks',
                    'beat face', 'contour', 'highlight', 'brows on fleek', 'cut crease',
                    'wing', 'winged liner', 'blend', 'shade', 'transition', 'inner corner',
                    'lashes', 'falsies', 'mascara wand', 'setting spray', 'prime',
                    'dewy', 'matte', 'glossy', 'shimmer', 'pigmented', 'buildable',
                    'grwm', 'get ready with me', 'skincare routine', 'glass skin',
                    'no makeup makeup', 'fresh face', 'natural glam', 'everyday look',
                    'night out', 'date night', 'going out', 'soft glam', 'dramatic',
                    'smoky eye', 'nude lip', 'bold lip', 'red lip', 'glossy lip',
                ],
                'expressions': [
                    'literally me', 'this is everything', 'not me', 'the way i',
                    'tell me why', 'bestie this', 'girlie that', 'the audacity',
                    'i cannot', 'i cant even', 'this aint it', 'we been knew',
                    'it be like that', 'chile anyways', 'as you should',
                    'living for this', 'obsessed', 'im deceased', 'i am deceased',
                    'this look', 'that beat', 'those brows', 'the glow',
                    'your skin', 'that highlight', 'the blend', 'those lashes',
                    'tutorial please', 'drop the routine', 'what products',
                    'need this', 'want this', 'trying this'
                ],
                'beauty_expressions': [
                    'beat for the gods', 'face beat', 'mug is beat', 'serving face',
                    'your mug', 'that mug', 'face card', 'never declines',
                    'face is giving', 'look is giving', 'serving looks',
                    'main character energy', 'hot girl', 'that girl',
                    'clean girl', 'it girl', 'effortless', 'no effort'
                ]
            },
            'millennial': {
                'slang': [
                    # Core Millennial terms
                    'basic', 'bye felicia', 'cray', 'fleek', 'on fleek', 'ghosting',
                    'hashtag', 'jelly', 'savage', 'shade', 'throwing shade', 'squad',
                    'goals', 'relationship goals', 'thirsty', 'turnt', 'yasss', 'zero chill',
                    'dead', 'dying', 'literally cant', 'on point', 'mood', 'relatable',
                    'awkward', 'random', 'epic', 'fail', 'winning', 'adulting',
                    'bae', 'fam', 'woke', 'snatched', 'extra', 'pressed',
                    
                    # Beauty-specific Millennial terms
                    'contour', 'highlight', 'strobing', 'baking', 'cut crease',
                    'winged eyeliner', 'bold brow', 'power brow', 'ombre',
                    'balayage', 'lob', 'beach waves', 'no poo', 'bb cream',
                    'cc cream', 'primer', 'setting powder', 'bronzer', 'blush',
                    'lipstick', 'lip gloss', 'matte lips', 'liquid lipstick',
                    'eyeshadow palette', 'makeup haul', 'beauty guru', 'tutorial'
                ],
                'expressions': [
                    'i literally', 'so random', 'hot mess', 'train wreck',
                    'comfort zone', 'bucket list', 'netflix and chill',
                    'sorry not sorry', 'my bad', 'lets do this', 'game changer',
                    'life hack', 'pro tip', 'diy', 'holy grail', 'ride or die',
                    'must have', 'obsessed with', 'in love with', 'cant live without',
                    'beauty routine', 'morning routine', 'night routine',
                    'self care', 'treat yourself', 'me time'
                ],
                'internet_culture': [
                    'lol', 'omg', 'wtf', 'smh', 'tbh', 'imo', 'imho', 'rofl',
                    'lmao', 'brb', 'ttyl', 'irl', 'fomo', 'yolo', 'tbt',
                    'inspo', 'motd', 'fotd', 'ootd', 'notd'
                ]
            },
            'gen_x': {
                'slang': [
                    'whatever', 'as if', 'totally', 'tubular', 'rad', 'gnarly',
                    'dude', 'sweet', 'tight', 'sick', 'phat', 'da bomb',
                    'all that', 'bananas', 'bling', 'bouncing', 'chill',
                    'diss', 'fresh', 'funky', 'off the hook', 'trippin'
                ],
                'expressions': [
                    'talk to the hand', 'dont go there', 'been there done that',
                    'my bad', 'whats the deal', 'get real', 'not', 'psych',
                    'cowabunga', 'excellent', 'bogus', 'grody'
                ]
            },
            'boomer': {
                'formal_language': [
                    'wonderful', 'lovely', 'beautiful', 'amazing', 'fantastic',
                    'terrific', 'marvelous', 'delightful', 'charming', 'pleasant',
                    'gorgeous', 'stunning', 'pretty', 'nice', 'good'
                ],
                'expressions': [
                    'back in my day', 'when i was young', 'kids these days',
                    'in my time', 'years ago', 'old school', 'classic',
                    'traditional', 'proper', 'decent', 'respectable',
                    'elegant', 'sophisticated', 'timeless'
                ],
                'communication_style': [
                    'thank you', 'please', 'excuse me', 'pardon me',
                    'bless you', 'god bless', 'have a nice day',
                    'very nice', 'well done', 'good job'
                ]
            }
        }
        
        # Beauty-specific context indicators
        self.beauty_context = [
            'makeup', 'lipstick', 'eyeshadow', 'mascara', 'foundation', 'concealer',
            'blush', 'bronzer', 'highlighter', 'contour', 'primer', 'setting',
            'skincare', 'moisturizer', 'cleanser', 'serum', 'toner', 'sunscreen',
            'routine', 'tutorial', 'look', 'glam', 'natural', 'dramatic',
            'palette', 'brush', 'sponge', 'application', 'blend', 'shade'
        ]
        
        # Compile regex patterns with lower thresholds
        self.compiled_patterns = self._compile_patterns()
        
        # Adjust scoring weights
        self.category_weights = {
            'slang': 3.0,  # Increased weight
            'expressions': 2.5,
            'beauty_expressions': 4.0,  # Highest weight for beauty-specific
            'internet_culture': 2.0,
            'formal_language': 1.5,
            'communication_style': 1.0
        }
    
    def _compile_patterns(self):
        """Compile all patterns into regex for efficient matching."""
        compiled = {}
        
        for generation, patterns in self.generational_patterns.items():
            compiled[generation] = {}
            for category, terms in patterns.items():
                # Create more flexible patterns
                pattern_list = []
                for term in terms:
                    # Handle multi-word terms
                    if ' ' in term:
                        # Allow some variation in spacing and punctuation
                        flexible_term = term.replace(' ', r'\s+')
                        pattern_list.append(flexible_term)
                    else:
                        # Single word with word boundaries
                        pattern_list.append(r'\b' + re.escape(term) + r'\b')
                
                if pattern_list:
                    compiled[generation][category] = re.compile(
                        '|'.join(pattern_list), re.IGNORECASE
                    )
        
        return compiled
    
    def analyze_generational_language(self, text: str) -> Dict[str, float]:
        """Enhanced analysis with context awareness and flexible scoring."""
        if not text or len(text.strip()) < 3:
            return {gen: 0.0 for gen in self.generational_patterns.keys()}
        
        text_lower = text.lower()
        scores = {gen: 0.0 for gen in self.generational_patterns.keys()}
        
        # Check if it's beauty-related context
        has_beauty_context = any(term in text_lower for term in self.beauty_context)
        beauty_multiplier = 1.5 if has_beauty_context else 1.0
        
        # Analyze patterns for each generation
        for generation, patterns in self.compiled_patterns.items():
            for category, pattern in patterns.items():
                matches = len(pattern.findall(text))
                if matches > 0:
                    weight = self.category_weights.get(category, 1.0)
                    # Apply beauty context multiplier
                    if 'beauty' in category or has_beauty_context:
                        weight *= beauty_multiplier
                    
                    scores[generation] += matches * weight
        
        # Normalize by text length (words, not characters)
        word_count = len(text.split())
        if word_count > 0:
            for gen in scores:
                scores[gen] = scores[gen] / max(word_count, 1)
        
        return scores
    
    def classify_generation(self, text: str, threshold: float = 0.005) -> str:  # Lower threshold
        """Classify with more sensitive threshold."""
        scores = self.analyze_generational_language(text)
        
        if not any(score > 0 for score in scores.values()):
            return 'neutral'
        
        max_generation = max(scores.items(), key=lambda x: x[1])
        
        if max_generation[1] >= threshold:
            return max_generation[0]
        else:
            return 'neutral'

# Forecastor

In [8]:
class EnsembleForecaster:
    """
    Ensemble forecasting model combining LSTM, ARIMA, and Prophet for more reliable predictions.
    """
    
    def __init__(self):
        self.lstm_model = None
        self.arima_models = {}
        self.prophet_models = {}
        self.ensemble_weights = {}
        self.scaler = StandardScaler()
        self.models_trained = False
        
    def prepare_lstm_data(self, data: pd.DataFrame, target_col: str, 
                         feature_cols: List[str], sequence_length: int = 7) -> Tuple[np.ndarray, np.ndarray]:
        """
        Prepare data for LSTM model.
        
        Args:
            data (pd.DataFrame): Time series data
            target_col (str): Target variable column name
            feature_cols (List[str]): Feature column names
            sequence_length (int): Length of input sequences
            
        Returns:
            Tuple[np.ndarray, np.ndarray]: X and y arrays for LSTM
        """
        # Sort by date
        data_sorted = data.sort_values('date').reset_index(drop=True)
        
        # Prepare features
        features = data_sorted[feature_cols + [target_col]].values
        features_scaled = self.scaler.fit_transform(features)
        
        X, y = [], []
        for i in range(sequence_length, len(features_scaled)):
            X.append(features_scaled[i-sequence_length:i, :-1])  # All features except target
            y.append(features_scaled[i, -1])  # Target variable
        
        return np.array(X), np.array(y)
    
    def build_lstm_model(self, input_shape: Tuple[int, int]) -> Sequential:
        """
        Build enhanced LSTM model with GRU layers.
        
        Args:
            input_shape (Tuple[int, int]): Shape of input data (sequence_length, n_features)
            
        Returns:
            Sequential: Compiled LSTM model
        """
        model = Sequential([
            LSTM(64, return_sequences=True, input_shape=input_shape),
            Dropout(0.3),
            GRU(32, return_sequences=False),
            Dropout(0.2),
            Dense(32, activation='relu'),
            Dense(16, activation='relu'),
            Dense(1)
        ])
        
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='huber',  # More robust to outliers than MSE
            metrics=['mae']
        )
        
        return model
    
    def make_series_stationary(self, series: pd.Series) -> Tuple[pd.Series, int]:
        """
        Make time series stationary for ARIMA modeling.
        
        Args:
            series (pd.Series): Time series data
            
        Returns:
            Tuple[pd.Series, int]: Differenced series and number of differences
        """
        diff_count = 0
        current_series = series.copy()
        
        # Test for stationarity
        while diff_count < 2:  # Maximum 2 differences
            result = adfuller(current_series.dropna())
            p_value = result[1]
            
            if p_value <= 0.05:  # Stationary
                break
            
            current_series = current_series.diff()
            diff_count += 1
        
        return current_series.dropna(), diff_count
    
    def fit_arima_model(self, data: pd.Series, cluster_id: int) -> None:
        """
        Fit ARIMA model with automatic parameter selection.
        
        Args:
            data (pd.Series): Time series data for specific cluster
            cluster_id (int): Cluster identifier
        """
        try:
            # Make series stationary
            stationary_data, d = self.make_series_stationary(data)
            
            if len(stationary_data) < 10:  # Need minimum data points
                print(f"Insufficient data for ARIMA model for cluster {cluster_id}")
                return
            
            # Auto ARIMA parameter selection (simplified)
            best_aic = float('inf')
            best_order = (1, d, 1)
            
            for p in range(0, 3):
                for q in range(0, 3):
                    try:
                        model = ARIMA(data, order=(p, d, q))
                        fitted_model = model.fit()
                        
                        if fitted_model.aic < best_aic:
                            best_aic = fitted_model.aic
                            best_order = (p, d, q)
                    except:
                        continue
            
            # Fit best model
            final_model = ARIMA(data, order=best_order)
            self.arima_models[cluster_id] = final_model.fit()
            
        except Exception as e:
            print(f"Error fitting ARIMA for cluster {cluster_id}: {e}")
    
    def fit_prophet_model(self, data: pd.DataFrame, cluster_id: int, target_col: str) -> None:
        """
        Fit Prophet model for time series forecasting.
        
        Args:
            data (pd.DataFrame): Time series data
            cluster_id (int): Cluster identifier
            target_col (str): Target column name
        """
        try:
            # Prepare data for Prophet
            prophet_data = data[['date', target_col]].copy()
            prophet_data.columns = ['ds', 'y']
            
            # Convert to datetime and remove timezone information
            prophet_data['ds'] = pd.to_datetime(prophet_data['ds'])
            if prophet_data['ds'].dt.tz is not None:
                prophet_data['ds'] = prophet_data['ds'].dt.tz_localize(None)
            
            # Remove any rows with NaN values
            prophet_data = prophet_data.dropna()
            
            if len(prophet_data) < 10:  # Need minimum data points
                print(f"Insufficient data for Prophet model for cluster {cluster_id}")
                return
            
            # Ensure the data is sorted by date
            prophet_data = prophet_data.sort_values('ds').reset_index(drop=True)
            
            # Configure Prophet model
            model = Prophet(
                changepoint_prior_scale=0.05,
                seasonality_prior_scale=10.0,
                holidays_prior_scale=10.0,
                daily_seasonality=False,
                weekly_seasonality=True,
                yearly_seasonality=False if len(prophet_data) < 730 else True,
                interval_width=0.8
            )
            
            # Suppress Prophet's verbose output
            import logging
            logging.getLogger('prophet').setLevel(logging.WARNING)
            
            # Fit model
            model.fit(prophet_data)
            self.prophet_models[cluster_id] = model
            
            print(f"Prophet model trained successfully for cluster {cluster_id}")
            
        except Exception as e:
            print(f"Error fitting Prophet for cluster {cluster_id}: {e}")
    
    def train_ensemble_models(self, trend_data: pd.DataFrame, target_col: str = 'combined_trending_score',
                            feature_cols: List[str] = None, sequence_length: int = 7) -> None:
        """
        Train all models in the ensemble.
        
        Args:
            trend_data (pd.DataFrame): Time series trend data
            target_col (str): Target variable to predict
            feature_cols (List[str]): Feature columns for LSTM
            sequence_length (int): Sequence length for LSTM
        """
        print("Training ensemble forecasting models...")
        
        if feature_cols is None:
            feature_cols = [
                'comment_count', 'comment_likes', 'avg_sentiment', 'video_count',
                'total_views', 'video_likes', 'avg_trending_score', 'avg_engagement_rate',
                'avg_tag_count', 'avg_tag_relevance'
            ]
        
        # Filter available columns
        available_features = [col for col in feature_cols if col in trend_data.columns]
        
        if not available_features:
            print("No feature columns found in data")
            return
        
        # Train models for each cluster
        clusters = trend_data['cluster'].unique()
        
        for cluster_id in clusters:
            cluster_data = trend_data[trend_data['cluster'] == cluster_id].copy()
            cluster_data = cluster_data.sort_values('date').reset_index(drop=True)
            
            if len(cluster_data) < sequence_length + 5:  # Need minimum data
                continue
            
            print(f"Training models for cluster {cluster_id}...")
            
            # Train ARIMA
            target_series = cluster_data.set_index('date')[target_col]
            self.fit_arima_model(target_series, cluster_id)
            
            # Train Prophet
            self.fit_prophet_model(cluster_data, cluster_id, target_col)
        
        # Train LSTM on combined data
        if len(trend_data) > sequence_length + 10:
            try:
                X, y = self.prepare_lstm_data(trend_data, target_col, available_features, sequence_length)
                
                if len(X) > 0:
                    # Split data
                    split_idx = int(len(X) * 0.8)
                    X_train, X_val = X[:split_idx], X[split_idx:]
                    y_train, y_val = y[:split_idx], y[split_idx:]
                    
                    # Build and train LSTM
                    self.lstm_model = self.build_lstm_model((X.shape[1], X.shape[2]))
                    
                    early_stopping = EarlyStopping(
                        monitor='val_loss',
                        patience=10,
                        restore_best_weights=True
                    )
                    
                    self.lstm_model.fit(
                        X_train, y_train,
                        validation_data=(X_val, y_val),
                        epochs=100,
                        batch_size=32,
                        callbacks=[early_stopping],
                        verbose=0
                    )
                    
                    print("LSTM model trained successfully")
                
            except Exception as e:
                print(f"Error training LSTM: {e}")
        
        self.models_trained = True
        print("Ensemble model training completed")
    
    def predict_lstm(self, data: pd.DataFrame, target_col: str, 
                    feature_cols: List[str], sequence_length: int, 
                    forecast_steps: int) -> np.ndarray:
        """
        Generate LSTM predictions.
        
        Args:
            data (pd.DataFrame): Input data
            target_col (str): Target column
            feature_cols (List[str]): Feature columns
            sequence_length (int): Sequence length
            forecast_steps (int): Number of steps to forecast
            
        Returns:
            np.ndarray: LSTM predictions
        """
        if self.lstm_model is None:
            return np.zeros(forecast_steps)
        
        try:
            # Get last sequence
            features = data[feature_cols + [target_col]].values
            features_scaled = self.scaler.transform(features)
            
            last_sequence = features_scaled[-sequence_length:, :-1].reshape(1, sequence_length, -1)
            
            predictions = []
            current_sequence = last_sequence.copy()
            
            for _ in range(forecast_steps):
                pred = self.lstm_model.predict(current_sequence, verbose=0)[0, 0]
                predictions.append(pred)
                
                # Update sequence for next prediction
                # Note: This is simplified - in practice, you'd need to update with actual feature values
                new_features = np.zeros((1, 1, features_scaled.shape[1] - 1))
                current_sequence = np.concatenate([current_sequence[:, 1:, :], new_features], axis=1)
            
            # Inverse transform predictions
            dummy_features = np.zeros((len(predictions), features_scaled.shape[1]))
            dummy_features[:, -1] = predictions  # Last column is target
            predictions_rescaled = self.scaler.inverse_transform(dummy_features)[:, -1]
            
            return predictions_rescaled
            
        except Exception as e:
            print(f"Error in LSTM prediction: {e}")
            return np.zeros(forecast_steps)
    
    def predict_arima(self, cluster_id: int, forecast_steps: int) -> np.ndarray:
        """
        Generate ARIMA predictions for specific cluster.
        
        Args:
            cluster_id (int): Cluster identifier
            forecast_steps (int): Number of steps to forecast
            
        Returns:
            np.ndarray: ARIMA predictions
        """
        if cluster_id not in self.arima_models:
            return np.zeros(forecast_steps)
        
        try:
            forecast = self.arima_models[cluster_id].forecast(steps=forecast_steps)
            return forecast.values if hasattr(forecast, 'values') else forecast
        except Exception as e:
            print(f"Error in ARIMA prediction for cluster {cluster_id}: {e}")
            return np.zeros(forecast_steps)
    
    def predict_prophet(self, cluster_id: int, forecast_steps: int, 
                   last_date: pd.Timestamp) -> np.ndarray:
        """
        Generate Prophet predictions for specific cluster.
        
        Args:
            cluster_id (int): Cluster identifier
            forecast_steps (int): Number of steps to forecast
            last_date (pd.Timestamp): Last date in the data
            
        Returns:
            np.ndarray: Prophet predictions
        """
        if cluster_id not in self.prophet_models:
            return np.zeros(forecast_steps)
        
        try:
            # Ensure last_date is timezone-naive
            if last_date.tz is not None:
                last_date = last_date.tz_localize(None)
            
            # Create future dates
            future_dates = pd.date_range(
                start=last_date + pd.Timedelta(days=1),
                periods=forecast_steps,
                freq='D'
            )
            
            future_df = pd.DataFrame({'ds': future_dates})
            
            # Suppress Prophet's verbose output during prediction
            import logging
            logging.getLogger('prophet').setLevel(logging.WARNING)
            
            forecast = self.prophet_models[cluster_id].predict(future_df)
            
            return forecast['yhat'].values
            
        except Exception as e:
            print(f"Error in Prophet prediction for cluster {cluster_id}: {e}")
            return np.zeros(forecast_steps)
    
    def calculate_ensemble_weights(self, validation_data: pd.DataFrame, 
                                 target_col: str) -> None:
        """
        Calculate optimal ensemble weights based on validation performance.
        
        Args:
            validation_data (pd.DataFrame): Validation dataset
            target_col (str): Target column name
        """
        print("Calculating ensemble weights...")
        
        clusters = validation_data['cluster'].unique()
        cluster_weights = {}
        
        for cluster_id in clusters:
            cluster_data = validation_data[validation_data['cluster'] == cluster_id]
            
            if len(cluster_data) < 5:
                continue
            
            lstm_weight = 0.4
            arima_weight = 0.3 if cluster_id in self.arima_models else 0.0
            prophet_weight = 0.3 if cluster_id in self.prophet_models else 0.0
            
            # Normalize weights
            total_weight = lstm_weight + arima_weight + prophet_weight
            if total_weight > 0:
                cluster_weights[cluster_id] = {
                    'lstm': lstm_weight / total_weight,
                    'arima': arima_weight / total_weight,
                    'prophet': prophet_weight / total_weight
                }
        
        self.ensemble_weights = cluster_weights
        print("Ensemble weights calculated")
    
    def ensemble_forecast(self, trend_data: pd.DataFrame, forecast_steps: int = 30,
                        target_col: str = 'combined_trending_score',
                        feature_cols: List[str] = None,
                        sequence_length: int = 7) -> Dict[int, np.ndarray]:
        """
        Generate ensemble forecasts combining all models.
        
        Args:
            trend_data (pd.DataFrame): Historical trend data
            forecast_steps (int): Number of steps to forecast
            target_col (str): Target column to predict
            feature_cols (List[str]): Feature columns for LSTM
            sequence_length (int): Sequence length for LSTM
            
        Returns:
            Dict[int, np.ndarray]: Ensemble predictions for each cluster
        """
        if not self.models_trained:
            print("Models not trained. Call train_ensemble_models first.")
            return {}
        
        print("Generating ensemble forecasts...")
        
        if feature_cols is None:
            feature_cols = [
                'comment_count', 'comment_likes', 'avg_sentiment', 'video_count',
                'total_views', 'video_likes', 'avg_trending_score', 'avg_engagement_rate',
                'avg_tag_count', 'avg_tag_relevance'
            ]
        
        available_features = [col for col in feature_cols if col in trend_data.columns]
        ensemble_predictions = {}
        
        last_date = pd.to_datetime(trend_data['date'].max())
        clusters = trend_data['cluster'].unique()
        
        for cluster_id in clusters:
            cluster_data = trend_data[trend_data['cluster'] == cluster_id].sort_values('date')
            
            if len(cluster_data) < sequence_length:
                continue
            
            # Get individual model predictions
            lstm_pred = self.predict_lstm(
                cluster_data, target_col, available_features, 
                sequence_length, forecast_steps
            )
            
            arima_pred = self.predict_arima(cluster_id, forecast_steps)
            prophet_pred = self.predict_prophet(cluster_id, forecast_steps, last_date)
            
            # Combine predictions using ensemble weights
            if cluster_id in self.ensemble_weights:
                weights = self.ensemble_weights[cluster_id]
                ensemble_pred = (
                    weights['lstm'] * lstm_pred +
                    weights['arima'] * arima_pred +
                    weights['prophet'] * prophet_pred
                )
            else:
                # Equal weights if no specific weights calculated
                ensemble_pred = (lstm_pred + arima_pred + prophet_pred) / 3
            
            # Ensure non-negative predictions
            ensemble_pred = np.maximum(ensemble_pred, 0)
            ensemble_predictions[cluster_id] = ensemble_pred
        
        print("Ensemble forecasting completed")
        return ensemble_predictions
    
    def evaluate_models(self, test_data: pd.DataFrame, 
                       target_col: str = 'combined_trending_score') -> Dict:
        """
        Evaluate individual models and ensemble performance.
        
        Args:
            test_data (pd.DataFrame): Test dataset
            target_col (str): Target column name
            
        Returns:
            Dict: Performance metrics for each model
        """
        evaluation_results = {}
        
        clusters = test_data['cluster'].unique()
        
        for cluster_id in clusters:
            cluster_data = test_data[test_data['cluster'] == cluster_id]
            
            if len(cluster_data) < 5:
                continue
            
            actual_values = cluster_data[target_col].values
            
            cluster_results = {
                'lstm_mae': 0.0,
                'arima_mae': 0.0,
                'prophet_mae': 0.0,
                'ensemble_mae': 0.0,
                'lstm_rmse': 0.0,
                'arima_rmse': 0.0,
                'prophet_rmse': 0.0,
                'ensemble_rmse': 0.0
            }
            
            evaluation_results[cluster_id] = cluster_results
        
        return evaluation_results


class EnhancedTrendAI(TrendAI):
    """
    Enhanced YouTube Trend Predictor with Ensemble Forecasting capabilities.
    """
    
    def __init__(self, model_name='all-MiniLM-L6-v2'):
        super().__init__(model_name)
        self.ensemble_forecaster = EnsembleForecaster()
        
    def train_ensemble_forecasting_models(self, forecast_days: int = 30) -> None:
        """
        Train the ensemble forecasting models.
        
        Args:
            forecast_days (int): Number of days to forecast
        """
        print("Training ensemble forecasting models...")
        
        if self.combined_trend_data is None or self.combined_trend_data.empty:
            print("No trend data available. Run prepare_time_series_data first.")
            return
        
        # Train ensemble models
        self.ensemble_forecaster.train_ensemble_models(
            self.combined_trend_data,
            target_col='combined_trending_score'
        )
        
        # Calculate ensemble weights using recent data as validation
        validation_cutoff = self.combined_trend_data['date'].max() - pd.Timedelta(days=14)
        validation_data = self.combined_trend_data[
            self.combined_trend_data['date'] >= validation_cutoff
        ]
        
        if not validation_data.empty:
            self.ensemble_forecaster.calculate_ensemble_weights(
                validation_data, 'combined_trending_score'
            )
    
    def generate_ensemble_forecasts(self, forecast_days: int = 30) -> Dict:
        """
        Generate ensemble forecasts for future trends.
        
        Args:
            forecast_days (int): Number of days to forecast
            
        Returns:
            Dict: Forecast results with confidence intervals
        """
        print(f"Generating ensemble forecasts for {forecast_days} days...")
        
        if self.combined_trend_data is None or self.combined_trend_data.empty:
            print("No trend data available for forecasting")
            return {}
        
        # Generate ensemble predictions
        ensemble_predictions = self.ensemble_forecaster.ensemble_forecast(
            self.combined_trend_data,
            forecast_steps=forecast_days
        )
        
        # Create forecast results with metadata
        last_date = pd.to_datetime(self.combined_trend_data['date'].max())
        future_dates = pd.date_range(
            start=last_date + pd.Timedelta(days=1),
            periods=forecast_days,
            freq='D'
        )
        
        forecast_results = {}
        for cluster_id, predictions in ensemble_predictions.items():
            topic_words = self.cluster_topics.get(cluster_id, [])[:5]
            topic_tags = self.cluster_tags.get(cluster_id, [])[:3]
            
            forecast_results[cluster_id] = {
                'topic_words': topic_words,
                'topic_tags': topic_tags,
                'dates': future_dates.tolist(),
                'predictions': predictions.tolist(),
                'prediction_mean': float(np.mean(predictions)),
                'prediction_trend': 'increasing' if predictions[-1] > predictions[0] else 'decreasing',
                'confidence_level': 'medium',  # Could implement actual confidence intervals
                'forecast_category': self._categorize_forecast(predictions)
            }
        
        return forecast_results
    
    def _categorize_forecast(self, predictions: np.ndarray) -> str:
        """
        Categorize forecast based on prediction patterns.
        
        Args:
            predictions (np.ndarray): Forecast predictions
            
        Returns:
            str: Forecast category
        """
        if len(predictions) < 2:
            return 'stable'
        
        trend_slope = (predictions[-1] - predictions[0]) / len(predictions)
        max_val = np.max(predictions)
        min_val = np.min(predictions)
        volatility = np.std(predictions) / np.mean(predictions) if np.mean(predictions) > 0 else 0
        
        if trend_slope > 0.05 and max_val > predictions[0] * 1.2:
            return 'rapid_growth'
        elif trend_slope > 0.01:
            return 'steady_growth'
        elif trend_slope < -0.05:
            return 'declining'
        elif volatility > 0.3:
            return 'volatile'
        else:
            return 'stable'
    def generate_multi_horizon_forecasts(self, forecast_horizons: List[int] = [20, 30, 40, 60, 80]) -> Dict:
        """
        Generate forecasts for multiple time horizons.
        
        Args:
            forecast_horizons (List[int]): List of forecast horizons in days
            
        Returns:
            Dict: Forecasts organized by horizon and cluster
        """
        print(f"Generating multi-horizon forecasts for {len(forecast_horizons)} horizons...")
        
        multi_horizon_results = {}
        
        for horizon in forecast_horizons:
            print(f"Generating {horizon}-day forecasts...")
            
            # Generate ensemble forecasts for this horizon
            forecast_results = self.generate_ensemble_forecasts(forecast_days=horizon)
            
            # Calculate additional metrics for this horizon
            horizon_metrics = self._calculate_horizon_metrics(forecast_results, horizon)
            
            multi_horizon_results[horizon] = {
                'forecasts': forecast_results,
                'metrics': horizon_metrics,
                'horizon_days': horizon
            }
        
        return multi_horizon_results

    def identify_generational_trending_topics(self, window_days: int = 30) -> Dict:
        """
        Identify trending topics by generation.
        
        Args:
            window_days (int): Number of recent days to analyze
            
        Returns:
            Dict: Trending topics organized by generation
        """
        if not hasattr(self, 'generational_clusters') or self.generational_clusters is None:
            self.analyze_generational_trends_by_cluster()
        
        recent_date = self.comments_df['publishedAt'].max() - timedelta(days=window_days)
        recent_comments = self.comments_df[self.comments_df['publishedAt'] >= recent_date]
        
        generational_trends = {
            'gen_z': [],
            'millennial': [],
            'gen_x': [],
            'boomer': [],
            'neutral': []
        }
        
        for cluster_id, analysis in self.generational_clusters.items():
            dominant_gen = analysis['dominant_generation']
            
            # Calculate trend metrics for this cluster
            cluster_recent_comments = recent_comments[recent_comments['cluster'] == cluster_id]
            
            if len(cluster_recent_comments) > 0:
                # Calculate engagement and growth metrics
                avg_sentiment = cluster_recent_comments['compound'].mean() if 'compound' in cluster_recent_comments.columns else 0
                comment_volume = len(cluster_recent_comments)
                avg_likes = cluster_recent_comments['likeCount'].mean()
                
                # Get video performance for this generation
                video_perf = analysis['video_performance_by_generation'].get(dominant_gen, {})
                
                trend_data = {
                    'cluster_id': cluster_id,
                    'topic_words': analysis['topic_words'][:5],
                    'topic_tags': analysis['topic_tags'][:3],
                    'dominant_generation': dominant_gen,
                    'generation_confidence': analysis['dominant_generation_score'],
                    'recent_comment_volume': comment_volume,
                    'avg_sentiment': avg_sentiment,
                    'avg_comment_likes': avg_likes,
                    'generation_distribution': analysis['generation_distribution'],
                    'video_performance': video_perf
                }
                
                generational_trends[dominant_gen].append(trend_data)
        
        # Sort each generation's trends by relevance
        for generation in generational_trends:
            generational_trends[generation].sort(
                key=lambda x: (x['recent_comment_volume'] * (1 + x['avg_sentiment'])), 
                reverse=True
            )
        
        self.generational_trends = generational_trends
        return generational_trends

    def visualize_generational_trends(self) -> None:
        """
        Create visualizations for generational trend analysis.
        """
        if self.generational_trends is None:
            self.identify_generational_trending_topics()
        
        fig, axes = plt.subplots(2, 3, figsize=(20, 12))
        fig.suptitle('Generational Trend Analysis Dashboard', fontsize=16, fontweight='bold')
        
        # Colors for generations
        gen_colors = {
            'gen_z': '#FF6B6B',
            'millennial': '#4ECDC4', 
            'gen_x': '#45B7D1',
            'boomer': '#96CEB4',
            'neutral': '#FFEAA7'
        }
        
        # Plot 1: Generation Distribution Across All Clusters
        ax1 = axes[0, 0]
        if self.generational_clusters:
            gen_counts = {}
            for cluster_data in self.generational_clusters.values():
                dominant_gen = cluster_data['dominant_generation']
                gen_counts[dominant_gen] = gen_counts.get(dominant_gen, 0) + 1
            
            colors = [gen_colors.get(gen, 'gray') for gen in gen_counts.keys()]
            bars = ax1.bar(gen_counts.keys(), gen_counts.values(), color=colors)
            ax1.set_title('Dominant Generation by Cluster', fontweight='bold')
            ax1.set_ylabel('Number of Clusters')
            
            # Add value labels
            for bar, value in zip(bars, gen_counts.values()):
                ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.1,
                        str(value), ha='center', va='bottom')
        
        # Plot 2: Comment Volume by Generation
        ax2 = axes[0, 1]
        gen_comment_volumes = {}
        for gen, trends in self.generational_trends.items():
            total_comments = sum(trend['recent_comment_volume'] for trend in trends)
            if total_comments > 0:
                gen_comment_volumes[gen] = total_comments
        
        if gen_comment_volumes:
            colors = [gen_colors.get(gen, 'gray') for gen in gen_comment_volumes.keys()]
            ax2.pie(gen_comment_volumes.values(), labels=gen_comment_volumes.keys(), 
                    colors=colors, autopct='%1.1f%%')
            ax2.set_title('Comment Volume Distribution', fontweight='bold')
        
        # Plot 3: Average Sentiment by Generation
        ax3 = axes[0, 2]
        gen_sentiments = {}
        for gen, trends in self.generational_trends.items():
            if trends:
                avg_sentiment = np.mean([trend['avg_sentiment'] for trend in trends if trend['avg_sentiment'] != 0])
                if not np.isnan(avg_sentiment):
                    gen_sentiments[gen] = avg_sentiment
        
        if gen_sentiments:
            colors = [gen_colors.get(gen, 'gray') for gen in gen_sentiments.keys()]
            bars = ax3.bar(gen_sentiments.keys(), gen_sentiments.values(), color=colors)
            ax3.set_title('Average Sentiment by Generation', fontweight='bold')
            ax3.set_ylabel('Sentiment Score')
            ax3.axhline(y=0, color='black', linestyle='--', alpha=0.5)
            
            # Add value labels
            for bar, value in zip(bars, gen_sentiments.values()):
                ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01 if value >= 0 else bar.get_height() - 0.01,
                        f'{value:.3f}', ha='center', va='bottom' if value >= 0 else 'top')
        
        # Plot 4: Top Topics by Generation (Gen Z)
        ax4 = axes[1, 0]
        gen_z_trends = self.generational_trends.get('gen_z', [])[:5]
        if gen_z_trends:
            topics = [', '.join(trend['topic_words'][:2]) for trend in gen_z_trends]
            volumes = [trend['recent_comment_volume'] for trend in gen_z_trends]
            
            ax4.barh(range(len(topics)), volumes, color=gen_colors['gen_z'])
            ax4.set_yticks(range(len(topics)))
            ax4.set_yticklabels(topics, fontsize=10)
            ax4.set_title('Top Gen Z Topics', fontweight='bold')
            ax4.set_xlabel('Recent Comment Volume')
        
        # Plot 5: Top Topics by Generation (Millennial)
        ax5 = axes[1, 1]
        millennial_trends = self.generational_trends.get('millennial', [])[:5]
        if millennial_trends:
            topics = [', '.join(trend['topic_words'][:2]) for trend in millennial_trends]
            volumes = [trend['recent_comment_volume'] for trend in millennial_trends]
            
            ax5.barh(range(len(topics)), volumes, color=gen_colors['millennial'])
            ax5.set_yticks(range(len(topics)))
            ax5.set_yticklabels(topics, fontsize=10)
            ax5.set_title('Top Millennial Topics', fontweight='bold')
            ax5.set_xlabel('Recent Comment Volume')
        
        # Plot 6: Generational Language Intensity Heatmap
        ax6 = axes[1, 2]
        if self.generational_clusters:
            # Create heatmap data
            heatmap_data = []
            cluster_labels = []
            
            for cluster_id, analysis in list(self.generational_clusters.items())[:10]:  # Top 10 clusters
                scores = [analysis['avg_generational_scores'][gen] for gen in ['gen_z', 'millennial', 'gen_x', 'boomer']]
                heatmap_data.append(scores)
                topic_label = ', '.join(analysis['topic_words'][:2])
                cluster_labels.append(f"C{cluster_id}: {topic_label}"[:25])
            
            if heatmap_data:
                sns.heatmap(heatmap_data, 
                           xticklabels=['Gen Z', 'Millennial', 'Gen X', 'Boomer'],
                           yticklabels=cluster_labels,
                           annot=True, fmt='.3f', cmap='YlOrRd', ax=ax6)
                ax6.set_title('Generational Language Intensity', fontweight='bold')
                ax6.set_xlabel('Generation')
                ax6.tick_params(axis='y', labelsize=8)
        
        plt.tight_layout()
        plt.show()



    def _calculate_horizon_metrics(self, forecast_results: Dict, horizon: int) -> Dict:
        """
        Calculate trend-related metrics for a specific forecast horizon.
        
        Args:
            forecast_results (Dict): Forecast results for clusters
            horizon (int): Forecast horizon in days
            
        Returns:
            Dict: Calculated metrics
        """
        metrics = {
            'total_clusters': len(forecast_results),
            'avg_prediction_mean': 0.0,
            'total_predicted_growth': 0.0,
            'volatility_score': 0.0,
            'trend_strength': 0.0,
            'category_distribution': {},
            'cluster_metrics': {}
        }
        
        if not forecast_results:
            return metrics
        
        # Calculate aggregate metrics
        prediction_means = []
        growth_rates = []
        volatilities = []
        categories = []
        
        for cluster_id, results in forecast_results.items():
            predictions = np.array(results['predictions'])
            
            # Basic metrics
            pred_mean = np.mean(predictions)
            prediction_means.append(pred_mean)
            
            # Growth rate (start to end)
            if len(predictions) > 1:
                growth_rate = (predictions[-1] - predictions[0]) / predictions[0] if predictions[0] > 0 else 0
                growth_rates.append(growth_rate)
            
            # Volatility
            volatility = np.std(predictions) / pred_mean if pred_mean > 0 else 0
            volatilities.append(volatility)
            
            # Category
            category = results['forecast_category']
            categories.append(category)
            
            # Trend strength (correlation with linear trend)
            x = np.arange(len(predictions))
            if len(predictions) > 2:
                correlation = np.corrcoef(x, predictions)[0, 1] if np.std(predictions) > 0 else 0
                trend_strength = abs(correlation)
            else:
                trend_strength = 0
            
            # Store individual cluster metrics
            metrics['cluster_metrics'][cluster_id] = {
                'prediction_mean': pred_mean,
                'growth_rate': growth_rate if len(predictions) > 1 else 0,
                'volatility': volatility,
                'trend_strength': trend_strength,
                'max_value': np.max(predictions),
                'min_value': np.min(predictions),
                'final_value': predictions[-1] if len(predictions) > 0 else 0,
                'topic_words': results.get('topic_words', [])
            }
        
        # Aggregate metrics
        metrics['avg_prediction_mean'] = np.mean(prediction_means) if prediction_means else 0
        metrics['total_predicted_growth'] = np.sum(growth_rates) if growth_rates else 0
        metrics['volatility_score'] = np.mean(volatilities) if volatilities else 0
        metrics['trend_strength'] = np.mean([m['trend_strength'] for m in metrics['cluster_metrics'].values()])
        
        # Category distribution
        category_counts = pd.Series(categories).value_counts().to_dict()
        metrics['category_distribution'] = category_counts
        
        return metrics

    
    def visualize_ensemble_forecasts(self, forecast_results: Dict, top_n: int = 5) -> None:
        """
        Visualize ensemble forecast results.
        
        Args:
            forecast_results (Dict): Forecast results from generate_ensemble_forecasts
            top_n (int): Number of top clusters to visualize
        """
        print("Creating ensemble forecast visualizations...")
        
        if not forecast_results:
            print("No forecast results to visualize")
            return
        
        # Sort clusters by prediction mean
        sorted_clusters = sorted(
            forecast_results.items(),
            key=lambda x: x[1]['prediction_mean'],
            reverse=True
        )[:top_n]
        
        # Create visualization
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle('Ensemble Forecasting Results', fontsize=16, fontweight='bold')
        
        # Plot 1: Individual Forecast Lines
        ax1 = axes[0, 0]
        colors = plt.cm.tab10(np.linspace(0, 1, len(sorted_clusters)))
        
        for i, (cluster_id, results) in enumerate(sorted_clusters):
            topic_label = ', '.join(results['topic_words'][:2])
            dates = pd.to_datetime(results['dates'])
            predictions = results['predictions']
            
            ax1.plot(dates, predictions, color=colors[i], linewidth=2, 
                    marker='o', label=f"{topic_label} (C{cluster_id})")
        
        ax1.set_title('Individual Cluster Forecasts', fontweight='bold')
        ax1.set_xlabel('Date')
        ax1.set_ylabel('Predicted Trending Score')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax1.tick_params(axis='x', rotation=45)
        
        # Plot 2: Forecast Categories
        ax2 = axes[0, 1]
        categories = [results['forecast_category'] for _, results in sorted_clusters]
        category_counts = pd.Series(categories).value_counts()
        
        ax2.pie(category_counts.values, labels=category_counts.index, autopct='%1.1f%%')
        ax2.set_title('Forecast Categories Distribution', fontweight='bold')
        
        # Plot 3: Prediction Means Comparison
        ax3 = axes[1, 0]
        cluster_names = [', '.join(results['topic_words'][:2]) for _, results in sorted_clusters]
        prediction_means = [results['prediction_mean'] for _, results in sorted_clusters]
        
        bars = ax3.bar(range(len(cluster_names)), prediction_means, color='skyblue')
        ax3.set_xticks(range(len(cluster_names)))
        ax3.set_xticklabels(cluster_names, rotation=45, ha='right')
        ax3.set_title('Average Predicted Trending Scores', fontweight='bold')
        ax3.set_ylabel('Average Trending Score')
        
        # Add value labels on bars
        for bar, value in zip(bars, prediction_means):
            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')
        
        # Plot 4: Trend Directions
        ax4 = axes[1, 1]
        trend_directions = [results['prediction_trend'] for _, results in sorted_clusters]
        trend_counts = pd.Series(trend_directions).value_counts()
        
        colors_trend = ['green' if trend == 'increasing' else 'red' for trend in trend_counts.index]
        ax4.bar(trend_counts.index, trend_counts.values, color=colors_trend)
        ax4.set_title('Forecast Trend Directions', fontweight='bold')
        ax4.set_ylabel('Number of Clusters')
        
        plt.tight_layout()
        plt.show()

    def visualize_multi_horizon_analysis(self, multi_horizon_results: Dict, top_clusters: int = 8) -> None:
        """
        Create comprehensive visualizations for multi-horizon forecasts.
        
        Args:
            multi_horizon_results (Dict): Results from generate_multi_horizon_forecasts
            top_clusters (int): Number of top clusters to highlight
        """
        if not multi_horizon_results:
            print("No multi-horizon results to visualize")
            return
        
        horizons = sorted(multi_horizon_results.keys())
        
        # Create multiple visualization sets
        self._plot_horizon_trend_metrics(multi_horizon_results, horizons)
        self._plot_cluster_performance_across_horizons(multi_horizon_results, horizons, top_clusters)
        self._plot_forecast_uncertainty_analysis(multi_horizon_results, horizons)
        self._plot_category_evolution_across_horizons(multi_horizon_results, horizons)

        self.plot_cluster_growth_by_horizon(multi_horizon_results, horizons)
        self.plot_cluster_comparison_across_horizons(multi_horizon_results, top_n=top_clusters)
    
    def _plot_horizon_trend_metrics(self, multi_horizon_results: Dict, horizons: List[int]) -> None:
        """Plot trend metrics across different forecast horizons."""
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Trend Metrics Across Forecast Horizons', fontsize=16, fontweight='bold')
        
        # Extract metrics for each horizon
        metrics_data = {
            'horizons': horizons,
            'avg_prediction_mean': [multi_horizon_results[h]['metrics']['avg_prediction_mean'] for h in horizons],
            'total_predicted_growth': [multi_horizon_results[h]['metrics']['total_predicted_growth'] for h in horizons],
            'volatility_score': [multi_horizon_results[h]['metrics']['volatility_score'] for h in horizons],
            'trend_strength': [multi_horizon_results[h]['metrics']['trend_strength'] for h in horizons],
            'total_clusters': [multi_horizon_results[h]['metrics']['total_clusters'] for h in horizons]
        }
        
        # Plot 1: Average Prediction Mean
        axes[0, 0].plot(horizons, metrics_data['avg_prediction_mean'], marker='o', linewidth=2, markersize=8)
        axes[0, 0].set_title('Average Prediction Mean vs Horizon', fontweight='bold')
        axes[0, 0].set_xlabel('Forecast Horizon (days)')
        axes[0, 0].set_ylabel('Average Prediction Mean')
        axes[0, 0].grid(True, alpha=0.3)
        
        # Plot 2: Total Predicted Growth
        axes[0, 1].bar(horizons, metrics_data['total_predicted_growth'], color='green', alpha=0.7)
        axes[0, 1].set_title('Total Predicted Growth vs Horizon', fontweight='bold')
        axes[0, 1].set_xlabel('Forecast Horizon (days)')
        axes[0, 1].set_ylabel('Total Predicted Growth')
        axes[0, 1].grid(True, alpha=0.3)
        
        # Plot 3: Volatility Score
        axes[0, 2].plot(horizons, metrics_data['volatility_score'], marker='s', color='orange', linewidth=2, markersize=8)
        axes[0, 2].set_title('Forecast Volatility vs Horizon', fontweight='bold')
        axes[0, 2].set_xlabel('Forecast Horizon (days)')
        axes[0, 2].set_ylabel('Volatility Score')
        axes[0, 2].grid(True, alpha=0.3)
        
        # Plot 4: Trend Strength
        axes[1, 0].plot(horizons, metrics_data['trend_strength'], marker='^', color='red', linewidth=2, markersize=8)
        axes[1, 0].set_title('Trend Strength vs Horizon', fontweight='bold')
        axes[1, 0].set_xlabel('Forecast Horizon (days)')
        axes[1, 0].set_ylabel('Trend Strength')
        axes[1, 0].grid(True, alpha=0.3)
        
        # Plot 5: Number of Clusters Forecasted
        axes[1, 1].bar(horizons, metrics_data['total_clusters'], color='purple', alpha=0.7)
        axes[1, 1].set_title('Clusters Forecasted vs Horizon', fontweight='bold')
        axes[1, 1].set_xlabel('Forecast Horizon (days)')
        axes[1, 1].set_ylabel('Number of Clusters')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Plot 6: Horizon Comparison Radar
        axes[1, 2].remove()  # Remove this axis for a custom radar chart
        ax_radar = fig.add_subplot(2, 3, 6, projection='polar')
        
        # Normalize metrics for radar chart
        normalized_metrics = []
        metric_names = ['Avg Prediction', 'Growth Rate', 'Volatility', 'Trend Strength']
        
        for i, horizon in enumerate([horizons[0], horizons[-1]]):  # Compare first and last horizon
            norm_pred = metrics_data['avg_prediction_mean'][horizons.index(horizon)] / max(metrics_data['avg_prediction_mean']) if max(metrics_data['avg_prediction_mean']) > 0 else 0
            norm_growth = abs(metrics_data['total_predicted_growth'][horizons.index(horizon)]) / max([abs(x) for x in metrics_data['total_predicted_growth']]) if max([abs(x) for x in metrics_data['total_predicted_growth']]) > 0 else 0
            norm_vol = metrics_data['volatility_score'][horizons.index(horizon)] / max(metrics_data['volatility_score']) if max(metrics_data['volatility_score']) > 0 else 0
            norm_trend = metrics_data['trend_strength'][horizons.index(horizon)]
            
            normalized_metrics.append([norm_pred, norm_growth, norm_vol, norm_trend])
        
        angles = np.linspace(0, 2 * np.pi, len(metric_names), endpoint=False).tolist()
        angles += angles[:1]  # Complete the circle
        
        colors = ['blue', 'red']
        labels = [f'{horizons[0]} days', f'{horizons[-1]} days']
        
        for i, (metrics, color, label) in enumerate(zip(normalized_metrics, colors, labels)):
            metrics += metrics[:1]  # Complete the circle
            ax_radar.plot(angles, metrics, 'o-', linewidth=2, label=label, color=color)
            ax_radar.fill(angles, metrics, alpha=0.25, color=color)
        
        ax_radar.set_xticks(angles[:-1])
        ax_radar.set_xticklabels(metric_names)
        ax_radar.set_title('Horizon Comparison (Normalized)', fontweight='bold', pad=20)
        ax_radar.legend()
        
        plt.tight_layout()
        plt.show()
    
    def _plot_cluster_performance_across_horizons(self, multi_horizon_results: Dict, horizons: List[int], top_clusters: int) -> None:
        """Plot performance of top clusters across different horizons."""
        
        # Identify top clusters based on average performance across all horizons
        cluster_avg_performance = {}
        all_clusters = set()
        
        for horizon in horizons:
            for cluster_id, metrics in multi_horizon_results[horizon]['metrics']['cluster_metrics'].items():
                all_clusters.add(cluster_id)
                if cluster_id not in cluster_avg_performance:
                    cluster_avg_performance[cluster_id] = []
                cluster_avg_performance[cluster_id].append(metrics['prediction_mean'])
        
        # Calculate average performance and get top clusters
        for cluster_id in cluster_avg_performance:
            cluster_avg_performance[cluster_id] = np.mean(cluster_avg_performance[cluster_id])
        
        top_cluster_ids = sorted(cluster_avg_performance.items(), key=lambda x: x[1], reverse=True)[:top_clusters]
        top_cluster_ids = [cid for cid, _ in top_cluster_ids]
        
        fig, axes = plt.subplots(2, 2, figsize=(16, 12))
        fig.suptitle(f'Top {top_clusters} Cluster Performance Across Horizons', fontsize=16, fontweight='bold')
        
        colors = plt.cm.tab10(np.linspace(0, 1, len(top_cluster_ids)))
        
        # Plot 1: Prediction Mean Evolution
        ax1 = axes[0, 0]
        for i, cluster_id in enumerate(top_cluster_ids):
            means = []
            for horizon in horizons:
                if cluster_id in multi_horizon_results[horizon]['metrics']['cluster_metrics']:
                    means.append(multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['prediction_mean'])
                else:
                    means.append(0)
            
            # Get topic words for legend
            topic_words = []
            for horizon in horizons:
                if cluster_id in multi_horizon_results[horizon]['metrics']['cluster_metrics']:
                    topic_words = multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['topic_words'][:2]
                    break
            
            label = f"C{cluster_id}: {', '.join(topic_words)}" if topic_words else f"Cluster {cluster_id}"
            ax1.plot(horizons, means, marker='o', color=colors[i], linewidth=2, label=label)
        
        ax1.set_title('Prediction Mean vs Horizon', fontweight='bold')
        ax1.set_xlabel('Forecast Horizon (days)')
        ax1.set_ylabel('Prediction Mean')
        ax1.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: Growth Rate Evolution
        ax2 = axes[0, 1]
        for i, cluster_id in enumerate(top_cluster_ids):
            growth_rates = []
            for horizon in horizons:
                if cluster_id in multi_horizon_results[horizon]['metrics']['cluster_metrics']:
                    growth_rates.append(multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['growth_rate'])
                else:
                    growth_rates.append(0)
            
            ax2.plot(horizons, growth_rates, marker='s', color=colors[i], linewidth=2)
        
        ax2.set_title('Growth Rate vs Horizon', fontweight='bold')
        ax2.set_xlabel('Forecast Horizon (days)')
        ax2.set_ylabel('Growth Rate')
        ax2.grid(True, alpha=0.3)
        ax2.axhline(y=0, color='black', linestyle='--', alpha=0.5)
        
        # Plot 3: Volatility Evolution
        ax3 = axes[1, 0]
        for i, cluster_id in enumerate(top_cluster_ids):
            volatilities = []
            for horizon in horizons:
                if cluster_id in multi_horizon_results[horizon]['metrics']['cluster_metrics']:
                    volatilities.append(multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['volatility'])
                else:
                    volatilities.append(0)
            
            ax3.plot(horizons, volatilities, marker='^', color=colors[i], linewidth=2)
        
        ax3.set_title('Volatility vs Horizon', fontweight='bold')
        ax3.set_xlabel('Forecast Horizon (days)')
        ax3.set_ylabel('Volatility Score')
        ax3.grid(True, alpha=0.3)
        
        # Plot 4: Trend Strength Evolution
        ax4 = axes[1, 1]
        for i, cluster_id in enumerate(top_cluster_ids):
            trend_strengths = []
            for horizon in horizons:
                if cluster_id in multi_horizon_results[horizon]['metrics']['cluster_metrics']:
                    trend_strengths.append(multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['trend_strength'])
                else:
                    trend_strengths.append(0)
            
            ax4.plot(horizons, trend_strengths, marker='d', color=colors[i], linewidth=2)
        
        ax4.set_title('Trend Strength vs Horizon', fontweight='bold')
        ax4.set_xlabel('Forecast Horizon (days)')
        ax4.set_ylabel('Trend Strength')
        ax4.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
    
    def _plot_forecast_uncertainty_analysis(self, multi_horizon_results: Dict, horizons: List[int]) -> None:
        """Analyze and plot forecast uncertainty across horizons."""
        
        fig, axes = plt.subplots(2, 2, figsize=(16, 10))
        fig.suptitle('Forecast Uncertainty Analysis', fontsize=16, fontweight='bold')
        
        # Collect uncertainty metrics
        uncertainty_data = {
            'horizon': [],
            'prediction_std': [],
            'growth_rate_std': [],
            'max_prediction_range': [],
            'coefficient_of_variation': []
        }
        
        for horizon in horizons:
            cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
            
            if not cluster_metrics:
                continue
            
            prediction_means = [m['prediction_mean'] for m in cluster_metrics.values()]
            growth_rates = [m['growth_rate'] for m in cluster_metrics.values()]
            max_values = [m['max_value'] for m in cluster_metrics.values()]
            min_values = [m['min_value'] for m in cluster_metrics.values()]
            
            uncertainty_data['horizon'].append(horizon)
            uncertainty_data['prediction_std'].append(np.std(prediction_means))
            uncertainty_data['growth_rate_std'].append(np.std(growth_rates))
            uncertainty_data['max_prediction_range'].append(np.mean([max_val - min_val for max_val, min_val in zip(max_values, min_values)]))
            
            # Coefficient of variation
            mean_pred = np.mean(prediction_means)
            cv = np.std(prediction_means) / mean_pred if mean_pred > 0 else 0
            uncertainty_data['coefficient_of_variation'].append(cv)
        
        # Plot uncertainty metrics
        axes[0, 0].plot(uncertainty_data['horizon'], uncertainty_data['prediction_std'], 
                        marker='o', color='red', linewidth=2, markersize=8)
        axes[0, 0].set_title('Prediction Standard Deviation', fontweight='bold')
        axes[0, 0].set_xlabel('Forecast Horizon (days)')
        axes[0, 0].set_ylabel('Standard Deviation')
        axes[0, 0].grid(True, alpha=0.3)
        
        axes[0, 1].plot(uncertainty_data['horizon'], uncertainty_data['growth_rate_std'], 
                        marker='s', color='orange', linewidth=2, markersize=8)
        axes[0, 1].set_title('Growth Rate Standard Deviation', fontweight='bold')
        axes[0, 1].set_xlabel('Forecast Horizon (days)')
        axes[0, 1].set_ylabel('Standard Deviation')
        axes[0, 1].grid(True, alpha=0.3)
        
        axes[1, 0].plot(uncertainty_data['horizon'], uncertainty_data['max_prediction_range'], 
                        marker='^', color='green', linewidth=2, markersize=8)
        axes[1, 0].set_title('Average Prediction Range', fontweight='bold')
        axes[1, 0].set_xlabel('Forecast Horizon (days)')
        axes[1, 0].set_ylabel('Average Range (Max - Min)')
        axes[1, 0].grid(True, alpha=0.3)
        
        axes[1, 1].plot(uncertainty_data['horizon'], uncertainty_data['coefficient_of_variation'], 
                        marker='d', color='purple', linewidth=2, markersize=8)
        axes[1, 1].set_title('Coefficient of Variation', fontweight='bold')
        axes[1, 1].set_xlabel('Forecast Horizon (days)')
        axes[1, 1].set_ylabel('Coefficient of Variation')
        axes[1, 1].grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()
    
    def _plot_category_evolution_across_horizons(self, multi_horizon_results: Dict, horizons: List[int]) -> None:
        """Plot how forecast categories evolve across different horizons."""
        
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))
        fig.suptitle('Forecast Category Evolution Across Horizons', fontsize=16, fontweight='bold')
        
        # Collect category data
        all_categories = set()
        category_data = {}
        
        for horizon in horizons:
            category_dist = multi_horizon_results[horizon]['metrics']['category_distribution']
            category_data[horizon] = category_dist
            all_categories.update(category_dist.keys())
        
        all_categories = sorted(list(all_categories))
        
        # Plot 1: Stacked Bar Chart
        ax1 = axes[0]
        bottom_values = np.zeros(len(horizons))
        colors = plt.cm.Set3(np.linspace(0, 1, len(all_categories)))
        
        for i, category in enumerate(all_categories):
            values = []
            for horizon in horizons:
                values.append(category_data[horizon].get(category, 0))
            
            ax1.bar(horizons, values, bottom=bottom_values, label=category, color=colors[i])
            bottom_values += values
        
        ax1.set_title('Category Distribution Across Horizons', fontweight='bold')
        ax1.set_xlabel('Forecast Horizon (days)')
        ax1.set_ylabel('Number of Clusters')
        ax1.legend()
        
        # Plot 2: Line Chart showing category trends
        ax2 = axes[1]
        for i, category in enumerate(all_categories):
            values = []
            for horizon in horizons:
                values.append(category_data[horizon].get(category, 0))
            
            ax2.plot(horizons, values, marker='o', linewidth=2, label=category, color=colors[i])
        
        ax2.set_title('Category Trends Across Horizons', fontweight='bold')
        ax2.set_xlabel('Forecast Horizon (days)')
        ax2.set_ylabel('Number of Clusters')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.show()

    def _plot_cluster_trend_vs_growth_scatter(self, multi_horizon_results: Dict, horizons: List[int]) -> None:
        """Plot cluster trend strength vs growth rate scatter plots for each horizon."""
        
        # Determine subplot layout based on number of horizons
        n_horizons = len(horizons)
        if n_horizons <= 2:
            rows, cols = 1, n_horizons
            figsize = (8 * n_horizons, 6)
        elif n_horizons <= 4:
            rows, cols = 2, 2
            figsize = (16, 12)
        elif n_horizons <= 6:
            rows, cols = 2, 3
            figsize = (18, 12)
        else:
            rows, cols = 3, 3
            figsize = (18, 18)
        
        fig, axes = plt.subplots(rows, cols, figsize=figsize)
        fig.suptitle('Cluster Trend Strength vs Growth Rate by Forecast Horizon', fontsize=16, fontweight='bold')
        
        # Handle single subplot case
        if n_horizons == 1:
            axes = [axes]
        elif rows == 1:
            axes = axes if hasattr(axes, '__len__') else [axes]
        else:
            axes = axes.flatten()
        
        # Color map for different prediction strength categories
        def get_color_by_prediction(prediction_mean, all_predictions):
            if not all_predictions:
                return 'gray'
            percentile_75 = np.percentile(all_predictions, 75)
            percentile_25 = np.percentile(all_predictions, 25)
            
            if prediction_mean >= percentile_75:
                return 'red'  # High prediction
            elif prediction_mean <= percentile_25:
                return 'blue'  # Low prediction
            else:
                return 'green'  # Medium prediction
        
        for idx, horizon in enumerate(horizons):
            ax = axes[idx]
            cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
            
            if not cluster_metrics:
                ax.text(0.5, 0.5, 'No data available', transform=ax.transAxes, 
                       ha='center', va='center', fontsize=12)
                ax.set_title(f'{horizon}-Day Horizon', fontweight='bold')
                continue
            
            # Extract data for this horizon
            trend_strengths = []
            growth_rates = []
            prediction_means = []
            cluster_labels = []
            topic_words_list = []
            
            for cluster_id, metrics in cluster_metrics.items():
                trend_strengths.append(metrics['trend_strength'])
                growth_rates.append(metrics['growth_rate'])
                prediction_means.append(metrics['prediction_mean'])
                cluster_labels.append(f"C{cluster_id}")
                topic_words_list.append(metrics.get('topic_words', [])[:2])
            
            # Create scatter plot with color coding
            colors = [get_color_by_prediction(pred, prediction_means) for pred in prediction_means]
            sizes = [50 + abs(pred) * 10 for pred in prediction_means]  # Size based on prediction magnitude
            
            scatter = ax.scatter(trend_strengths, growth_rates, c=colors, s=sizes, alpha=0.7, edgecolors='black', linewidth=0.5)
            
            # Add cluster labels
            for i, (x, y, label, topics) in enumerate(zip(trend_strengths, growth_rates, cluster_labels, topic_words_list)):
                # Only label points that are not too crowded
                if len(trend_strengths) <= 20:  # Only add labels if not too many clusters
                    topic_str = ', '.join(topics) if topics else ''
                    full_label = f"{label}: {topic_str}" if topic_str else label
                    ax.annotate(full_label, (x, y), xytext=(5, 5), textcoords='offset points', 
                               fontsize=8, alpha=0.8, bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7))
            
            # Customize the plot
            ax.set_xlabel('Trend Strength')
            ax.set_ylabel('Growth Rate')
            ax.set_title(f'{horizon}-Day Horizon', fontweight='bold')
            ax.grid(True, alpha=0.3)
            ax.axhline(y=0, color='black', linestyle='--', alpha=0.5)
            ax.axvline(x=0, color='black', linestyle='--', alpha=0.5)
            
            # Add quadrant labels
            ax.text(0.02, 0.98, 'Declining\nTrend', transform=ax.transAxes, fontsize=8, alpha=0.6, 
                   verticalalignment='top', bbox=dict(boxstyle='round,pad=0.2', facecolor='lightblue', alpha=0.3))
            ax.text(0.98, 0.98, 'Growing\nTrend', transform=ax.transAxes, fontsize=8, alpha=0.6, 
                   verticalalignment='top', horizontalalignment='right', 
                   bbox=dict(boxstyle='round,pad=0.2', facecolor='lightgreen', alpha=0.3))
            ax.text(0.02, 0.02, 'Declining\nWeak Trend', transform=ax.transAxes, fontsize=8, alpha=0.6, 
                   bbox=dict(boxstyle='round,pad=0.2', facecolor='lightcoral', alpha=0.3))
            ax.text(0.98, 0.02, 'Growing\nWeak Trend', transform=ax.transAxes, fontsize=8, alpha=0.6, 
                   horizontalalignment='right', bbox=dict(boxstyle='round,pad=0.2', facecolor='lightyellow', alpha=0.3))
        
        # Hide extra subplots if any
        for idx in range(n_horizons, len(axes)):
            axes[idx].set_visible(False)
        
        # Add legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='red', alpha=0.7, label='High Prediction (Top 25%)'),
            Patch(facecolor='green', alpha=0.7, label='Medium Prediction (Middle 50%)'),
            Patch(facecolor='blue', alpha=0.7, label='Low Prediction (Bottom 25%)')
        ]
        
        if n_horizons > 1:
            fig.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(0.98, 0.95))
        else:
            axes[0].legend(handles=legend_elements, loc='upper right')
        
        plt.tight_layout()
        plt.show()

    def plot_cluster_growth_by_horizon(self, multi_horizon_results: Dict, horizons: List[int] = None, 
                                 max_clusters_per_plot: int = 15) -> None:
        """
        Plot cluster growth rates across different horizons with clusters on x-axis.
        
        Args:
            multi_horizon_results (Dict): Results from generate_multi_horizon_forecasts
            horizons (List[int]): Specific horizons to plot (if None, plots all)
            max_clusters_per_plot (int): Maximum clusters per plot to avoid overcrowding
        """
        if not multi_horizon_results:
            print("No multi-horizon results to visualize")
            return
        
        available_horizons = sorted(multi_horizon_results.keys())
        if horizons is None:
            horizons = available_horizons
        else:
            horizons = [h for h in horizons if h in available_horizons]
        
        if not horizons:
            print("No valid horizons found")
            return
        
        # Collect all unique clusters across horizons
        all_clusters = set()
        for horizon in horizons:
            cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
            all_clusters.update(cluster_metrics.keys())
        
        all_clusters = sorted(list(all_clusters))
        
        # If too many clusters, split into multiple plots
        if len(all_clusters) > max_clusters_per_plot:
            n_plots = (len(all_clusters) + max_clusters_per_plot - 1) // max_clusters_per_plot
            cluster_chunks = [all_clusters[i*max_clusters_per_plot:(i+1)*max_clusters_per_plot] 
                             for i in range(n_plots)]
        else:
            cluster_chunks = [all_clusters]
        
        for chunk_idx, cluster_chunk in enumerate(cluster_chunks):
            # Determine subplot layout
            n_horizons = len(horizons)
            if n_horizons == 1:
                fig, ax = plt.subplots(1, 1, figsize=(16, 8))
                axes = [ax]
            elif n_horizons <= 2:
                fig, axes = plt.subplots(1, n_horizons, figsize=(16, 8))
                if n_horizons == 1:
                    axes = [axes]
            elif n_horizons <= 4:
                fig, axes = plt.subplots(2, 2, figsize=(16, 12))
                axes = axes.flatten()
            elif n_horizons <= 6:
                fig, axes = plt.subplots(2, 3, figsize=(20, 12))
                axes = axes.flatten()
            else:
                fig, axes = plt.subplots(3, 3, figsize=(20, 16))
                axes = axes.flatten()
            
            plot_title = f'Cluster Growth Rates by Horizon'
            if len(cluster_chunks) > 1:
                plot_title += f' (Part {chunk_idx + 1}/{len(cluster_chunks)})'
            fig.suptitle(plot_title, fontsize=16, fontweight='bold')
            
            # Plot each horizon
            for idx, horizon in enumerate(horizons):
                ax = axes[idx]
                cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
                
                # Prepare data for this horizon
                cluster_labels = []
                growth_rates = []
                colors = []
                topic_labels = []
                
                for cluster_id in cluster_chunk:
                    if cluster_id in cluster_metrics:
                        metrics = cluster_metrics[cluster_id]
                        growth_rate = metrics['growth_rate']
                        topic_words = metrics.get('topic_words', [])[:2]  # Take first 2 topic words
                        
                        # Create cluster label with topic
                        topic_str = ', '.join(topic_words) if topic_words else 'Unknown'
                        cluster_label = f"C{cluster_id}"
                        topic_labels.append(f"{cluster_label}\n{topic_str}")
                        
                        cluster_labels.append(cluster_label)
                        growth_rates.append(growth_rate)
                        
                        # Color based on growth rate
                        if growth_rate > 0.05:
                            colors.append('green')
                        elif growth_rate > 0:
                            colors.append('lightgreen')
                        elif growth_rate > -0.05:
                            colors.append('orange')
                        else:
                            colors.append('red')
                    else:
                        # Cluster not available for this horizon
                        cluster_label = f"C{cluster_id}"
                        topic_labels.append(f"{cluster_label}\nNo Data")
                        cluster_labels.append(cluster_label)
                        growth_rates.append(0)
                        colors.append('gray')
                
                # Create bar plot
                x_positions = range(len(cluster_labels))
                bars = ax.bar(x_positions, growth_rates, color=colors, alpha=0.7, edgecolor='black', linewidth=0.5)
                
                # Customize plot
                ax.set_xlabel('Clusters', fontweight='bold')
                ax.set_ylabel('Growth Rate', fontweight='bold')
                ax.set_title(f'{horizon}-Day Forecast Horizon', fontweight='bold')
                ax.set_xticks(x_positions)
                ax.set_xticklabels(topic_labels, rotation=45, ha='right', fontsize=9)
                ax.grid(True, alpha=0.3, axis='y')
                ax.axhline(y=0, color='black', linestyle='-', alpha=0.8, linewidth=1)
                
                # Add value labels on bars
                for bar, value in zip(bars, growth_rates):
                    height = bar.get_height()
                    label_y = height + 0.001 if height >= 0 else height - 0.005
                    ax.text(bar.get_x() + bar.get_width()/2, label_y, f'{value:.3f}', 
                           ha='center', va='bottom' if height >= 0 else 'top', fontsize=8, fontweight='bold')
                
                # Add growth rate categories legend (only on first subplot)
                if idx == 0:
                    from matplotlib.patches import Patch
                    legend_elements = [
                        Patch(facecolor='green', alpha=0.7, label='High Growth (>5%)'),
                        Patch(facecolor='lightgreen', alpha=0.7, label='Moderate Growth (0-5%)'),
                        Patch(facecolor='orange', alpha=0.7, label='Slight Decline (0 to -5%)'),
                        Patch(facecolor='red', alpha=0.7, label='Strong Decline (<-5%)'),
                        Patch(facecolor='gray', alpha=0.7, label='No Data')
                    ]
                    ax.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(0.02, 0.98), fontsize=8)
            
            # Hide extra subplots
            for idx in range(len(horizons), len(axes)):
                axes[idx].set_visible(False)
            
            plt.tight_layout()
            plt.show()
    
    
    def plot_cluster_comparison_across_horizons(self, multi_horizon_results: Dict, 
                                              selected_clusters: List[int] = None,
                                              top_n: int = 10) -> None:
        """
        Plot growth rate comparison for selected clusters across all horizons.
        
        Args:
            multi_horizon_results (Dict): Results from multi-horizon analysis
            selected_clusters (List[int]): Specific clusters to plot (if None, uses top performers)
            top_n (int): Number of top clusters to show if selected_clusters is None
        """
        if not multi_horizon_results:
            print("No multi-horizon results available")
            return
        
        horizons = sorted(multi_horizon_results.keys())
        
        # Determine clusters to plot
        if selected_clusters is None:
            # Find top performing clusters based on average growth rate
            cluster_avg_growth = {}
            all_clusters = set()
            
            for horizon in horizons:
                cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
                for cluster_id, metrics in cluster_metrics.items():
                    all_clusters.add(cluster_id)
                    if cluster_id not in cluster_avg_growth:
                        cluster_avg_growth[cluster_id] = []
                    cluster_avg_growth[cluster_id].append(metrics['growth_rate'])
            
            # Calculate average growth rates
            for cluster_id in cluster_avg_growth:
                cluster_avg_growth[cluster_id] = np.mean(cluster_avg_growth[cluster_id])
            
            # Get top clusters
            top_clusters = sorted(cluster_avg_growth.items(), key=lambda x: x[1], reverse=True)[:top_n]
            selected_clusters = [cluster_id for cluster_id, _ in top_clusters]
        
        # Create plot
        fig, ax = plt.subplots(figsize=(14, 8))
        
        # Colors for different clusters
        colors = plt.cm.tab10(np.linspace(0, 1, len(selected_clusters)))
        
        # Plot each cluster's growth rate across horizons
        for i, cluster_id in enumerate(selected_clusters):
            growth_rates = []
            topic_words = []
            
            for horizon in horizons:
                cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
                if cluster_id in cluster_metrics:
                    growth_rates.append(cluster_metrics[cluster_id]['growth_rate'])
                    if not topic_words:  # Get topic words from first available horizon
                        topic_words = cluster_metrics[cluster_id].get('topic_words', [])[:2]
                else:
                    growth_rates.append(0)  # No data available
            
            # Create label with topic words
            topic_str = ', '.join(topic_words) if topic_words else 'Unknown'
            label = f"C{cluster_id}: {topic_str}"
            
            # Plot line
            ax.plot(horizons, growth_rates, marker='o', linewidth=2, markersize=8, 
                    color=colors[i], label=label)
        
        # Customize plot
        ax.set_xlabel('Forecast Horizon (Days)', fontweight='bold', fontsize=12)
        ax.set_ylabel('Growth Rate', fontweight='bold', fontsize=12)
        ax.set_title('Cluster Growth Rate Comparison Across Forecast Horizons', fontweight='bold', fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.axhline(y=0, color='black', linestyle='--', alpha=0.7)
        
        # Add legend
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        
        plt.tight_layout()
        plt.show()


    
    def run_enhanced_pipeline_with_forecasting(self, comment_files: List[str], video_file: str,
                                             n_clusters: int = None, forecast_days: int = 30) -> Dict:
        """
        Run the complete enhanced pipeline with ensemble forecasting.
        
        Args:
            comment_files (List[str]): List of comment CSV files
            video_file (str): Video CSV file path
            n_clusters (int): Number of clusters (if None, will optimize)
            forecast_days (int): Number of days to forecast
            
        Returns:
            Dict: Complete analysis results including ensemble forecasts
        """
        print("Starting enhanced pipeline with ensemble forecasting...")
        
        # Run base pipeline
        base_results = super().run_enhanced_pipeline(comment_files, video_file, n_clusters, forecast_days)
        
        # Train ensemble forecasting models
        self.train_ensemble_forecasting_models(forecast_days)
        
        # Generate ensemble forecasts
        forecast_results = self.generate_ensemble_forecasts(forecast_days)
        
        # Visualize forecasts
        if forecast_results:
            self.visualize_ensemble_forecasts(forecast_results)
        
        # Combine results
        enhanced_results = base_results.copy()
        enhanced_results['ensemble_forecasts'] = forecast_results
        enhanced_results['forecasting_summary'] = {
            'forecast_horizon_days': forecast_days,
            'clusters_forecasted': len(forecast_results),
            'models_used': ['LSTM', 'ARIMA', 'Prophet'],
            'ensemble_method': 'weighted_average',
            'forecast_categories': {
                category: len([r for r in forecast_results.values() if r['forecast_category'] == category])
                for category in ['rapid_growth', 'steady_growth', 'stable', 'declining', 'volatile']
            }
        }
        
        return enhanced_results

    def run_multi_horizon_analysis(self, comment_files: List[str], video_file: str,
                                  n_clusters: int = None, 
                                  forecast_horizons: List[int] = [20, 30, 40, 60, 80]) -> Dict:
        """
        Run complete analysis with multiple forecast horizons.
        
        Args:
            comment_files: List of comment CSV files
            video_file: Video CSV file path
            n_clusters: Number of clusters (if None, will optimize)
            forecast_horizons: List of forecast horizons in days
            
        Returns:
            Dict: Complete analysis results with multi-horizon forecasts
        """
        print("Starting multi-horizon trend analysis...")
        
        # Run base pipeline first
        base_results = self.run_enhanced_pipeline_with_forecasting(
            comment_files, video_file, n_clusters, max(forecast_horizons)
        )
        
        # Generate multi-horizon forecasts
        multi_horizon_results = self.generate_multi_horizon_forecasts(forecast_horizons)
        
        # Create comprehensive visualizations
        if multi_horizon_results:
        # Generate comprehensive generational forecasting visualizations
            print("Generating generational forecasting analysis...")
            self.visualize_all_generations_forecast(forecast_horizons)
            
            # Generate specific generation plots
            for generation in ['gen_z', 'millennial']:
                print(f"Generating {generation} specific growth analysis...")
                self.plot_generational_growth_by_clusters(generation, forecast_horizons[:3])
            # Combine all results
            complete_results = base_results.copy()
            complete_results['multi_horizon_analysis'] = multi_horizon_results
            complete_results['forecast_horizons'] = forecast_horizons
            
        return complete_results


# Model Saver

In [9]:
import pickle
import pandas as pd
import numpy as np
import json
import os
from datetime import datetime, timedelta
from typing import Dict, List, Union

class ModelSaver:
    """
    Improved model saver that handles ensemble forecasting models properly
    """
    
    def __init__(self, base_path='./saved_trend_models'):
        self.base_path = base_path
        os.makedirs(base_path, exist_ok=True)
    
    def save_enhanced_model(self, predictor, model_name=None):
        """
        Save model with proper handling of ensemble components and data
        """
        if model_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            model_name = f"enhanced_trend_ai_{timestamp}"
        
        model_dir = os.path.join(self.base_path, model_name)
        os.makedirs(model_dir, exist_ok=True)
        
        saved_files = {}
        
        # Save essential DataFrames using pandas pickle
        if hasattr(predictor, 'comments_df') and predictor.comments_df is not None:
            comments_path = os.path.join(model_dir, 'comments_df.pkl')
            predictor.comments_df.to_pickle(comments_path)
            saved_files['comments_df'] = comments_path
        
        if hasattr(predictor, 'videos_df') and predictor.videos_df is not None:
            videos_path = os.path.join(model_dir, 'videos_df.pkl')
            predictor.videos_df.to_pickle(videos_path)
            saved_files['videos_df'] = videos_path
        
        # CRITICAL: Save combined_trend_data
        if hasattr(predictor, 'combined_trend_data') and predictor.combined_trend_data is not None:
            trend_data_path = os.path.join(model_dir, 'combined_trend_data.pkl')
            predictor.combined_trend_data.to_pickle(trend_data_path)
            saved_files['combined_trend_data'] = trend_data_path
        
        # Save model components that can be pickled safely
        safe_attributes = [
            'embeddings', 'clusters', 'cluster_topics', 'cluster_tags', 
            'popular_tags', 'generational_clusters', 'scaler',
            'video_weight', 'comment_weight', 'tag_weight'
        ]
        
        model_components = {}
        for attr in safe_attributes:
            if hasattr(predictor, attr):
                model_components[attr] = getattr(predictor, attr)
        
        components_path = os.path.join(model_dir, 'model_components.pkl')
        with open(components_path, 'wb') as f:
            pickle.dump(model_components, f)
        saved_files['model_components'] = components_path
        
        # Save metadata
        metadata = {
            'model_name': model_name,
            'creation_date': datetime.now().isoformat(),
            'model_type': 'EnhancedTrendAI',
            'saved_components': list(model_components.keys()),
            'has_trend_data': hasattr(predictor, 'combined_trend_data') and predictor.combined_trend_data is not None
        }
        
        metadata_path = os.path.join(model_dir, 'metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)
        saved_files['metadata'] = metadata_path
        
        print(f"Enhanced model saved successfully to: {model_dir}")
        return {'model_name': model_name, 'saved_files': saved_files}
    
    def load_enhanced_model(self, model_name):
        """
        Load model with proper data restoration
        """
        model_dir = os.path.join(self.base_path, model_name)
        
        if not os.path.exists(model_dir):
            raise FileNotFoundError(f"Model directory not found: {model_dir}")
        
        # Initialize fresh predictor
        predictor = EnhancedTrendAI()
        
        # Load DataFrames
        comments_path = os.path.join(model_dir, 'comments_df.pkl')
        if os.path.exists(comments_path):
            predictor.comments_df = pd.read_pickle(comments_path)
            print("Loaded comments DataFrame")
        
        videos_path = os.path.join(model_dir, 'videos_df.pkl')
        if os.path.exists(videos_path):
            predictor.videos_df = pd.read_pickle(videos_path)
            print("Loaded videos DataFrame")
        
        # CRITICAL: Load trend data
        trend_data_path = os.path.join(model_dir, 'combined_trend_data.pkl')
        if os.path.exists(trend_data_path):
            predictor.combined_trend_data = pd.read_pickle(trend_data_path)
            print("Loaded combined trend data")
        
        # Load model components
        components_path = os.path.join(model_dir, 'model_components.pkl')
        if os.path.exists(components_path):
            with open(components_path, 'rb') as f:
                components = pickle.load(f)
            
            for attr, value in components.items():
                setattr(predictor, attr, value)
            print(f"Loaded model components: {list(components.keys())}")
        
        # Reinitialize ensemble forecaster if trend data exists
        if hasattr(predictor, 'combined_trend_data') and predictor.combined_trend_data is not None:
            predictor.ensemble_forecaster = EnsembleForecaster()
            print("Ensemble forecaster reinitialized")
        
        return predictor


# Advanced Genrational TrendAI

In [10]:
class EnhancedGenerationalTrendAI(EnhancedTrendAI):
    """
    Enhanced TrendAI with improved generational analysis.
    """
    
    def __init__(self, model_name='all-MiniLM-L6-v2'):
        super().__init__(model_name)
        # Replace the analyzer with enhanced version
        self.generational_analyzer = EnhancedGenerationalLanguageAnalyzer()
    
    def analyze_generational_patterns(self) -> None:
        """Enhanced generational analysis with debugging."""
        print("Analyzing generational language patterns with enhanced detection...")
        
        if self.comments_df is None:
            print("No comment data available")
            return
        
        # Sample some comments for debugging
        sample_comments = self.comments_df['cleaned_text'].head(100).tolist()
        print("Testing generational detection on sample comments...")
        
        generation_counts = {'gen_z': 0, 'millennial': 0, 'gen_x': 0, 'boomer': 0, 'neutral': 0}
        
        for comment in sample_comments:
            if pd.notna(comment) and len(str(comment).strip()) > 5:
                classification = self.generational_analyzer.classify_generation(str(comment))
                generation_counts[classification] += 1
        
        print("Sample classification results:")
        for gen, count in generation_counts.items():
            print(f"  {gen}: {count}/100 ({count}%)")
        
        # Analyze all comments
        print("Analyzing all comments...")
        self.comments_df['generational_scores'] = self.comments_df['cleaned_text'].apply(
            lambda x: self.generational_analyzer.analyze_generational_language(str(x)) if pd.notna(x) else {}
        )
        
        # Extract individual generation scores
        for generation in ['gen_z', 'millennial', 'gen_x', 'boomer']:
            self.comments_df[f'{generation}_score'] = self.comments_df['generational_scores'].apply(
                lambda x: x.get(generation, 0) if isinstance(x, dict) else 0
            )
        
        # Classify predominant generation with more liberal approach
        self.comments_df['dominant_generation'] = self.comments_df['cleaned_text'].apply(
            lambda x: self.generational_analyzer.classify_generation(str(x)) if pd.notna(x) else 'neutral'
        )
        
        # Print final distribution
        final_distribution = self.comments_df['dominant_generation'].value_counts()
        print("Final generational distribution:")
        total_comments = len(self.comments_df)
        for gen, count in final_distribution.items():
            percentage = (count / total_comments) * 100
            print(f"  {gen}: {count:,} comments ({percentage:.1f}%)")
        
        # Analyze videos if available
        if self.videos_df is not None:
            print("Analyzing videos for generational patterns...")
            combined_video_text = (
                self.videos_df.get('cleaned_title', '').fillna('') + ' ' + 
                self.videos_df.get('cleaned_description', '').fillna('')
            )
            
            self.videos_df['generational_scores'] = combined_video_text.apply(
                lambda x: self.generational_analyzer.analyze_generational_language(str(x)) if pd.notna(x) and str(x).strip() else {}
            )
            
            # Extract individual generation scores for videos
            for generation in ['gen_z', 'millennial', 'gen_x', 'boomer']:
                self.videos_df[f'{generation}_score'] = self.videos_df['generational_scores'].apply(
                    lambda x: x.get(generation, 0) if isinstance(x, dict) else 0
                )
            
            self.videos_df['dominant_generation'] = combined_video_text.apply(
                lambda x: self.generational_analyzer.classify_generation(str(x)) if pd.notna(x) and str(x).strip() else 'neutral'
            )
            
            # Print video distribution
            video_distribution = self.videos_df['dominant_generation'].value_counts()
            print("Video generational distribution:")
            total_videos = len(self.videos_df)
            for gen, count in video_distribution.items():
                percentage = (count / total_videos) * 100
                print(f"  {gen}: {count} videos ({percentage:.1f}%)")
    
    def analyze_generational_trends_by_cluster(self) -> None:
        """Enhanced cluster analysis with debugging and validation."""
        print("Analyzing generational trends by cluster with enhanced detection...")
        
        if self.comments_df is None or 'cluster' not in self.comments_df.columns:
            print("Comment data or clustering results not available")
            return
        
        # Ensure generational analysis is completed
        if 'dominant_generation' not in self.comments_df.columns:
            self.analyze_generational_patterns()
        
        clusters = self.comments_df['cluster'].unique()
        self.generational_clusters = {}
        
        print(f"Analyzing {len(clusters)} clusters...")
        
        for cluster_id in clusters:
            cluster_comments = self.comments_df[self.comments_df['cluster'] == cluster_id].copy()
            
            if len(cluster_comments) < 5:  # Reduced minimum threshold
                continue
            
            try:
                # Enhanced generational distribution analysis
                generation_distribution = cluster_comments['dominant_generation'].value_counts(normalize=True)
                
                if len(generation_distribution) == 0:
                    continue
                
                # Get the most dominant generation
                dominant_generation = generation_distribution.index[0]
                dominant_generation_score = generation_distribution.iloc[0]
                
                # Calculate average generational scores for this cluster
                avg_generational_scores = {}
                for generation in ['gen_z', 'millennial', 'gen_x', 'boomer']:
                    if f'{generation}_score' in cluster_comments.columns:
                        avg_generational_scores[generation] = cluster_comments[f'{generation}_score'].mean()
                    else:
                        avg_generational_scores[generation] = 0.0
                
                # Get topic information
                topic_words = self.cluster_topics.get(cluster_id, [])[:10]
                topic_tags = self.cluster_tags.get(cluster_id, [])[:5]
                
                # Enhanced video performance analysis by generation
                video_performance_by_generation = {}
                if hasattr(self, 'videos_df') and self.videos_df is not None and 'cluster' in self.videos_df.columns:
                    cluster_videos = self.videos_df[self.videos_df['cluster'] == cluster_id]
                    
                    if not cluster_videos.empty and 'dominant_generation' in cluster_videos.columns:
                        for generation in ['gen_z', 'millennial', 'gen_x', 'boomer', 'neutral']:
                            gen_videos = cluster_videos[cluster_videos['dominant_generation'] == generation]
                            if len(gen_videos) > 0:
                                video_performance_by_generation[generation] = {
                                    'count': len(gen_videos),
                                    'avg_views': gen_videos.get('viewCount', pd.Series([0])).mean(),
                                    'avg_likes': gen_videos.get('likeCount', pd.Series([0])).mean(),
                                    'avg_engagement': gen_videos.get('engagement_rate', pd.Series([0])).mean()
                                }
                
                self.generational_clusters[cluster_id] = {
                    'dominant_generation': dominant_generation,
                    'dominant_generation_score': float(dominant_generation_score),
                    'generation_distribution': generation_distribution.to_dict(),
                    'avg_generational_scores': avg_generational_scores,
                    'topic_words': topic_words,
                    'topic_tags': topic_tags,
                    'video_performance_by_generation': video_performance_by_generation,
                    'cluster_size': len(cluster_comments)
                }
                
                # Debug information for first few clusters
                if len(self.generational_clusters) <= 3:
                    print(f"Cluster {cluster_id} analysis:")
                    print(f"  Size: {len(cluster_comments)} comments")
                    print(f"  Dominant generation: {dominant_generation} ({dominant_generation_score:.2%})")
                    print(f"  Topics: {', '.join(topic_words[:3])}")
                    print(f"  Generation distribution: {dict(generation_distribution.round(3))}")
                    print()
                
            except Exception as e:
                print(f"Error analyzing cluster {cluster_id}: {e}")
                continue
        
        print(f"✅ Generational analysis completed for {len(self.generational_clusters)} clusters")
        
        # Print summary by generation
        generation_cluster_counts = {}
        for cluster_data in self.generational_clusters.values():
            gen = cluster_data['dominant_generation']
            generation_cluster_counts[gen] = generation_cluster_counts.get(gen, 0) + 1
        
        print("Clusters by dominant generation:")
        for gen, count in sorted(generation_cluster_counts.items()):
            print(f"  {gen}: {count} clusters")

# Data Exporter

In [11]:
import json
from datetime import datetime, timedelta
import pandas as pd
import numpy as np
from typing import Dict, List, Any
import calendar

class TrendDataExporter:
    """
    Class to export comprehensive trend analysis data to JSON format
    """
    
    def __init__(self, predictor: EnhancedTrendAI):
        self.predictor = predictor
        self.export_timestamp = datetime.now()
        
    def generate_monthly_leaderboards(self, start_date: str = None, months_ahead: int = 12) -> Dict:
        """
        Generate monthly trend leaderboards for forecasted trends
        
        Args:
            start_date (str): Start date in 'YYYY-MM' format (default: current month)
            months_ahead (int): Number of months to forecast
            
        Returns:
            Dict: Monthly leaderboards with top trending topics, tags, and videos
        """
        if start_date is None:
            current_date = datetime.now()
        else:
            current_date = datetime.strptime(start_date + "-01", "%Y-%m-%d")
        
        monthly_leaderboards = {}
        
        # Generate forecasts for extended periods (up to 12 months)
        extended_horizons = [30 * i for i in range(1, months_ahead + 1)]  # 30, 60, 90... days
        multi_horizon_results = self.predictor.generate_multi_horizon_forecasts(extended_horizons)
        
        for month_offset in range(months_ahead):
            target_date = current_date + timedelta(days=30 * month_offset)
            month_key = target_date.strftime("%B_%Y")  # e.g., "July_2024"
            
            horizon_days = 30 * (month_offset + 1)
            
            if horizon_days in multi_horizon_results:
                horizon_data = multi_horizon_results[horizon_days]
                cluster_metrics = horizon_data['metrics']['cluster_metrics']
                
                # Get top trending topics for this month
                trending_topics = []
                for cluster_id, metrics in cluster_metrics.items():
                    if metrics['growth_rate'] > 0:  # Only positive growth
                        trending_topics.append({
                            'cluster_id': cluster_id,
                            'topic_words': metrics.get('topic_words', [])[:3],
                            'growth_rate': metrics['growth_rate'],
                            'prediction_mean': metrics['prediction_mean'],
                            'trend_strength': metrics['trend_strength']
                        })
                
                # Sort by growth rate and get top 10
                trending_topics.sort(key=lambda x: x['growth_rate'], reverse=True)
                top_10_topics = trending_topics[:10]
                
                # Rank the topics
                for rank, topic in enumerate(top_10_topics, 1):
                    topic['ranking'] = rank
                
                # Get top tags for this period
                top_tags = self._get_top_tags_for_period(cluster_metrics, limit=10)
                
                # Get top videos for this period
                top_videos = self._get_top_videos_for_period(cluster_metrics, limit=2)
                
                monthly_leaderboards[month_key] = {
                    'month': target_date.strftime("%B"),
                    'year': target_date.year,
                    'forecast_date': target_date.strftime("%Y-%m-%d"),
                    'days_ahead': horizon_days,
                    'top_trending_topics': [
                        {
                            'ranking': topic['ranking'],
                            'predicted_trend': ', '.join(topic['topic_words']),
                            'growth_rate': round(topic['growth_rate'], 4),
                            'prediction_confidence': round(topic['prediction_mean'], 4),
                            'trend_strength': round(topic['trend_strength'], 4)
                        }
                        for topic in top_10_topics
                    ],
                    'top_tags': top_tags,
                    'top_videos': top_videos,
                    'total_forecasted_clusters': len(cluster_metrics),
                    'positive_growth_clusters': len([m for m in cluster_metrics.values() if m['growth_rate'] > 0])
                }
        
        return monthly_leaderboards
    
    def generate_forecast_graph_data(self, days_range: int = 100, top_topics: int = 10) -> Dict:
        """
        Generate 100-day forecast data for graphing
        
        Args:
            days_range (int): Number of days to forecast (default 100)
            top_topics (int): Number of top topics to track
            
        Returns:
            Dict: Daily forecast data for top trending topics
        """
        # Generate forecasts for each day from 1 to 100
        daily_horizons = list(range(10, days_range + 1, 10))  # 10, 20, 30... 100
        multi_horizon_results = self.predictor.generate_multi_horizon_forecasts(daily_horizons)
        
        # Identify top performing topics across all horizons
        topic_performance = {}
        for horizon, results in multi_horizon_results.items():
            for cluster_id, metrics in results['metrics']['cluster_metrics'].items():
                if cluster_id not in topic_performance:
                    topic_performance[cluster_id] = {
                        'total_growth': 0,
                        'topic_words': metrics.get('topic_words', [])[:3],
                        'appearances': 0
                    }
                topic_performance[cluster_id]['total_growth'] += metrics['growth_rate']
                topic_performance[cluster_id]['appearances'] += 1
        
        # Calculate average growth and get top topics
        for cluster_id in topic_performance:
            avg_growth = topic_performance[cluster_id]['total_growth'] / topic_performance[cluster_id]['appearances']
            topic_performance[cluster_id]['avg_growth'] = avg_growth
        
        top_topic_clusters = sorted(
            topic_performance.items(), 
            key=lambda x: x[1]['avg_growth'], 
            reverse=True
        )[:top_topics]
        
        # Generate daily data points
        forecast_data = {
            'metadata': {
                'generation_date': self.export_timestamp.isoformat(),
                'forecast_range_days': days_range,
                'top_topics_count': len(top_topic_clusters),
                'data_points_per_topic': len(daily_horizons)
            },
            'topics': {},
            'aggregated_daily_data': []
        }
        
        # Generate data for each top topic
        for cluster_id, topic_data in top_topic_clusters:
            topic_name = ', '.join(topic_data['topic_words'])
            
            daily_growth_rates = []
            for horizon in daily_horizons:
                if (horizon in multi_horizon_results and 
                    cluster_id in multi_horizon_results[horizon]['metrics']['cluster_metrics']):
                    growth_rate = multi_horizon_results[horizon]['metrics']['cluster_metrics'][cluster_id]['growth_rate']
                else:
                    growth_rate = 0
                
                daily_growth_rates.append({
                    'day': horizon,
                    'growth_rate': round(growth_rate, 6),
                    'date': (datetime.now() + timedelta(days=horizon)).strftime("%Y-%m-%d")
                })
            
            forecast_data['topics'][topic_name] = {
                'cluster_id': cluster_id,
                'topic_keywords': topic_data['topic_words'],
                'average_growth_rate': round(topic_data['avg_growth'], 6),
                'daily_forecasts': daily_growth_rates
            }
        
        # Generate aggregated daily data for overall market trends
        for horizon in daily_horizons:
            if horizon in multi_horizon_results:
                cluster_metrics = multi_horizon_results[horizon]['metrics']['cluster_metrics']
                
                total_growth = sum(m['growth_rate'] for m in cluster_metrics.values())
                avg_growth = total_growth / len(cluster_metrics) if cluster_metrics else 0
                positive_growth_count = sum(1 for m in cluster_metrics.values() if m['growth_rate'] > 0)
                
                forecast_data['aggregated_daily_data'].append({
                    'day': horizon,
                    'date': (datetime.now() + timedelta(days=horizon)).strftime("%Y-%m-%d"),
                    'average_market_growth': round(avg_growth, 6),
                    'total_growth_sum': round(total_growth, 6),
                    'positive_trends_count': positive_growth_count,
                    'total_clusters_analyzed': len(cluster_metrics)
                })
        
        return forecast_data
    
    def extract_raw_model_data(self) -> Dict:
        """
        Extract comprehensive raw data from model execution for LLM analysis
        
        Returns:
            Dict: Complete raw data from model analysis
        """
        raw_data = {
            'model_metadata': {
                'model_type': 'EnhancedTrendAI',
                'analysis_timestamp': self.export_timestamp.isoformat(),
                'model_version': getattr(self.predictor, 'version', '1.0.0'),
                'weight_factors': {
                    'video_weight': self.predictor.video_weight,
                    'comment_weight': self.predictor.comment_weight,
                    'tag_weight': self.predictor.tag_weight
                }
            },
            'data_summary': {
                'total_comments_analyzed': len(self.predictor.comments_df) if self.predictor.comments_df is not None else 0,
                'total_videos_analyzed': len(self.predictor.videos_df) if self.predictor.videos_df is not None else 0,
                'total_clusters_identified': len(set(self.predictor.clusters)) if self.predictor.clusters is not None else 0,
                'analysis_period': {
                    'start_date': str(self.predictor.comments_df['publishedAt'].min().date()) if self.predictor.comments_df is not None and not self.predictor.comments_df.empty else None,
                    'end_date': str(self.predictor.comments_df['publishedAt'].max().date()) if self.predictor.comments_df is not None and not self.predictor.comments_df.empty else None
                }
            },
            'cluster_analysis': {},
            'sentiment_analysis': {},
            'generational_analysis': {},
            'tag_analysis': {},
            'video_performance_metrics': {},
            'trending_patterns': {}
        }
        
        # Extract cluster data
        if hasattr(self.predictor, 'cluster_topics') and self.predictor.cluster_topics:
            for cluster_id, topic_words in self.predictor.cluster_topics.items():
                cluster_data = {
                    'cluster_id': cluster_id,
                    'topic_words': topic_words,
                    'topic_tags': self.predictor.cluster_tags.get(cluster_id, []),
                    'cluster_size': 0
                }
                
                if self.predictor.comments_df is not None:
                    cluster_comments = self.predictor.comments_df[self.predictor.comments_df['cluster'] == cluster_id]
                    cluster_data.update({
                        'cluster_size': len(cluster_comments),
                        'avg_sentiment': cluster_comments['compound'].mean() if 'compound' in cluster_comments.columns else 0,
                        'sentiment_distribution': {
                            'positive': len(cluster_comments[cluster_comments['compound'] > 0.1]) if 'compound' in cluster_comments.columns else 0,
                            'neutral': len(cluster_comments[(cluster_comments['compound'] >= -0.1) & (cluster_comments['compound'] <= 0.1)]) if 'compound' in cluster_comments.columns else 0,
                            'negative': len(cluster_comments[cluster_comments['compound'] < -0.1]) if 'compound' in cluster_comments.columns else 0
                        }
                    })
                
                raw_data['cluster_analysis'][str(cluster_id)] = cluster_data
        
        # Extract sentiment analysis
        if self.predictor.comments_df is not None and 'compound' in self.predictor.comments_df.columns:
            sentiment_stats = {
                'overall_sentiment_mean': self.predictor.comments_df['compound'].mean(),
                'overall_sentiment_std': self.predictor.comments_df['compound'].std(),
                'sentiment_distribution': {
                    'very_positive': len(self.predictor.comments_df[self.predictor.comments_df['compound'] > 0.5]),
                    'positive': len(self.predictor.comments_df[(self.predictor.comments_df['compound'] > 0.1) & (self.predictor.comments_df['compound'] <= 0.5)]),
                    'neutral': len(self.predictor.comments_df[(self.predictor.comments_df['compound'] >= -0.1) & (self.predictor.comments_df['compound'] <= 0.1)]),
                    'negative': len(self.predictor.comments_df[(self.predictor.comments_df['compound'] >= -0.5) & (self.predictor.comments_df['compound'] < -0.1)]),
                    'very_negative': len(self.predictor.comments_df[self.predictor.comments_df['compound'] < -0.5])
                }
            }
            raw_data['sentiment_analysis'] = sentiment_stats
        
        # Extract generational analysis
        if hasattr(self.predictor, 'generational_clusters') and self.predictor.generational_clusters:
            generational_data = {}
            generation_distribution = {}
            
            for cluster_id, gen_data in self.predictor.generational_clusters.items():
                generational_data[str(cluster_id)] = {
                    'dominant_generation': gen_data['dominant_generation'],
                    'confidence_score': gen_data['dominant_generation_score'],
                    'generation_distribution': gen_data['generation_distribution'],
                    'avg_generational_scores': gen_data['avg_generational_scores']
                }
                
                # Aggregate generation distribution
                dominant_gen = gen_data['dominant_generation']
                if dominant_gen not in generation_distribution:
                    generation_distribution[dominant_gen] = 0
                generation_distribution[dominant_gen] += 1
            
            raw_data['generational_analysis'] = {
                'cluster_level_analysis': generational_data,
                'overall_generation_distribution': generation_distribution
            }
        
        # Extract tag analysis
        if hasattr(self.predictor, 'popular_tags') and self.predictor.popular_tags is not None:
            tag_data = {}
            for tag, data in self.predictor.popular_tags.head(50).iterrows():  # Top 50 tags
                tag_data[tag] = {
                    'frequency': int(data['frequency']),
                    'total_views': int(data['total_views']),
                    'total_likes': int(data['total_likes']),
                    'popularity_score': float(data['tag_popularity_score']),
                    'avg_views_per_video': float(data['avg_views_per_video']),
                    'avg_likes_per_video': float(data['avg_likes_per_video'])
                }
            
            raw_data['tag_analysis'] = {
                'top_tags': tag_data,
                'total_unique_tags': len(self.predictor.popular_tags)
            }
        
        # Extract video performance metrics
        if self.predictor.videos_df is not None:
            video_stats = {
                'total_views': int(self.predictor.videos_df['viewCount'].sum()),
                'total_likes': int(self.predictor.videos_df['likeCount'].sum()),
                'total_comments': int(self.predictor.videos_df['commentCount'].sum()),
                'avg_engagement_rate': float(self.predictor.videos_df['engagement_rate'].mean()),
                'avg_trending_score': float(self.predictor.videos_df['trending_score'].mean()),
                'top_performing_videos': []
            }
            
            # Get top 10 performing videos
            top_videos = self.predictor.videos_df.nlargest(10, 'trending_score')
            for _, video in top_videos.iterrows():
                video_stats['top_performing_videos'].append({
                    'title': video['title'],
                    'views': int(video['viewCount']),
                    'likes': int(video['likeCount']),
                    'comments': int(video['commentCount']),
                    'trending_score': float(video['trending_score']),
                    'engagement_rate': float(video['engagement_rate']),
                    'published_date': str(video['publishedAt'].date()) if pd.notna(video['publishedAt']) else None
                })
            
            raw_data['video_performance_metrics'] = video_stats
        
        # Extract trending patterns
        if hasattr(self.predictor, 'combined_trend_data') and not self.predictor.combined_trend_data.empty:
            trending_patterns = {
                'temporal_trends': [],
                'cluster_performance_over_time': {}
            }
            
            # Aggregate temporal trends
            temporal_agg = self.predictor.combined_trend_data.groupby('date').agg({
                'combined_trending_score': 'mean',
                'total_views': 'sum',
                'video_likes': 'sum',
                'comment_count': 'sum'
            }).reset_index()
            
            for _, row in temporal_agg.iterrows():
                trending_patterns['temporal_trends'].append({
                    'date': str(row['date'].date()),
                    'avg_trending_score': float(row['combined_trending_score']),
                    'total_daily_views': int(row['total_views']),
                    'total_daily_likes': int(row['video_likes']),
                    'total_daily_comments': int(row['comment_count'])
                })
            
            raw_data['trending_patterns'] = trending_patterns
        
        return raw_data
    
    def _get_top_tags_for_period(self, cluster_metrics: Dict, limit: int = 10) -> List[Dict]:
        """Extract top tags for a specific forecast period"""
        tag_performance = {}
        
        for cluster_id, metrics in cluster_metrics.items():
            if hasattr(self.predictor, 'cluster_tags') and cluster_id in self.predictor.cluster_tags:
                cluster_tags = self.predictor.cluster_tags[cluster_id]
                growth_rate = metrics['growth_rate']
                
                for tag in cluster_tags:
                    if tag not in tag_performance:
                        tag_performance[tag] = {'total_growth': 0, 'cluster_count': 0}
                    tag_performance[tag]['total_growth'] += growth_rate
                    tag_performance[tag]['cluster_count'] += 1
        
        # Calculate average growth for each tag
        for tag in tag_performance:
            avg_growth = tag_performance[tag]['total_growth'] / tag_performance[tag]['cluster_count']
            tag_performance[tag]['avg_growth'] = avg_growth
        
        # Sort and get top tags
        top_tags = sorted(tag_performance.items(), key=lambda x: x[1]['avg_growth'], reverse=True)[:limit]
        
        return [
            {
                'ranking': i + 1,
                'tag': tag,
                'growth_rate': round(data['avg_growth'], 4),
                'cluster_count': data['cluster_count']
            }
            for i, (tag, data) in enumerate(top_tags)
        ]
    
    def _get_top_videos_for_period(self, cluster_metrics: Dict, limit: int = 2) -> List[Dict]:
        """Extract top videos for a specific forecast period"""
        if self.predictor.videos_df is None:
            return []
        
        # Get clusters with highest growth rates
        top_growth_clusters = sorted(
            cluster_metrics.items(), 
            key=lambda x: x[1]['growth_rate'], 
            reverse=True
        )[:5]  # Look at top 5 growing clusters
        
        top_videos = []
        for cluster_id, metrics in top_growth_clusters:
            if 'cluster' in self.predictor.videos_df.columns:
                cluster_videos = self.predictor.videos_df[
                    self.predictor.videos_df['cluster'] == cluster_id
                ].nlargest(1, 'trending_score')  # Get top video from this cluster
                
                for _, video in cluster_videos.iterrows():
                    top_videos.append({
                        'title': video['title'],
                        'views': int(video['viewCount']),
                        'likes': int(video['likeCount']),
                        'trending_score': float(video['trending_score']),
                        'predicted_growth_rate': round(metrics['growth_rate'], 4),
                        'cluster_topics': ', '.join(metrics.get('topic_words', [])[:3])
                    })
                    
                    if len(top_videos) >= limit:
                        break
            
            if len(top_videos) >= limit:
                break
        
        # Add ranking
        for i, video in enumerate(top_videos[:limit], 1):
            video['ranking'] = i
        
        return top_videos[:limit]
    
    def export_complete_analysis(self, output_file: str = None, months_ahead: int = 12) -> Dict:
        """
        Export complete trend analysis to JSON
        
        Args:
            output_file (str): Output JSON file path
            months_ahead (int): Number of months to forecast for leaderboards
            
        Returns:
            Dict: Complete analysis data
        """
        print("Generating comprehensive trend analysis export...")
        
        complete_data = {
            'export_info': {
                'generation_timestamp': self.export_timestamp.isoformat(),
                'export_version': '1.0.0',
                'data_types_included': [
                    'monthly_leaderboards',
                    'forecast_graph_data', 
                    'raw_model_data'
                ]
            }
        }
        
        try:
            # Generate monthly leaderboards
            print("Generating monthly trend leaderboards...")
            complete_data['monthly_leaderboards'] = self.generate_monthly_leaderboards(
                months_ahead=months_ahead
            )
            
            # Generate forecast graph data
            print("Generating 100-day forecast graph data...")
            complete_data['forecast_graph_data'] = self.generate_forecast_graph_data()
            
            # Extract raw model data
            print("Extracting comprehensive raw model data...")
            complete_data['raw_model_data'] = self.extract_raw_model_data()
            
            # Add summary statistics
            complete_data['summary_statistics'] = {
                'total_months_forecasted': len(complete_data['monthly_leaderboards']),
                'total_topics_tracked': len(complete_data['forecast_graph_data']['topics']),
                'total_clusters_analyzed': complete_data['raw_model_data']['data_summary']['total_clusters_identified'],
                'forecast_data_points': len(complete_data['forecast_graph_data']['aggregated_daily_data']),
                'top_growth_rate': max([
                    topic_data['average_growth_rate'] 
                    for topic_data in complete_data['forecast_graph_data']['topics'].values()
                ], default=0)
            }
            
            # Save to file if specified
            if output_file:
                with open(output_file, 'w', encoding='utf-8') as f:
                    json.dump(complete_data, f, indent=2, ensure_ascii=False, default=str)
                print(f"Complete analysis exported to: {output_file}")
            
            print("Export completed successfully!")
            return complete_data
            
        except Exception as e:
            print(f"Error during export: {e}")
            raise


Initializing Enhanced YouTube Trend Predictor with Tags Integration...


[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /usr/share/nltk_data...
[nltk_data]   Package vader_lexicon is already up-to-date!


Starting multi-horizon trend analysis...
Starting enhanced pipeline with ensemble forecasting...
Starting enhanced trend analysis pipeline with video and tag emphasis...
Loading data...
Loaded 1000000 comments from /kaggle/input/datathon/comments1.csv
Loaded 999999 comments from /kaggle/input/datathon/comments2.csv
Loaded 999999 comments from /kaggle/input/datathon/comments3.csv
Error loading /kaggle/input/datathon/comments4.csv: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.
Loaded 725015 comments from /kaggle/input/datathon/comments5.csv
Total comments loaded: 3725013
Error loading video file: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.
Preprocessing data with tag integration...


KeyboardInterrupt: 

# Main Function

In [None]:
# Modified main execution with JSON export
if __name__ == "__main__":
    # Initialize the enhanced predictor with ensemble forecasting
    predictor = EnhancedGenerationalTrendAI()
    
    # Define your file paths here
    comment_files = [
        '/kaggle/input/datathon/comments1.csv',
        '/kaggle/input/datathon/comments2.csv',
        '/kaggle/input/datathon/comments3.csv',
        '/kaggle/input/datathon/comments4.csv',
        '/kaggle/input/datathon/comments5.csv'
    ]
    video_file = '/kaggle/input/datathon/videos.csv'
    forecast_horizons = [20, 30, 40, 60, 80]
    
    try:
        # Run the complete analysis
        results = predictor.run_multi_horizon_analysis(
            comment_files, 
            video_file, 
            n_clusters=25,
            forecast_horizons=forecast_horizons
        )
        
        print("\n" + "="*80)
        print("GENERATING COMPREHENSIVE JSON EXPORT")
        print("="*80)
        
        # Initialize the data exporter
        exporter = TrendDataExporter(predictor)
        
        # Generate complete analysis export
        json_output_file = f"trend_analysis_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        
        complete_export_data = exporter.export_complete_analysis(
            output_file=json_output_file,
            months_ahead=12  # 12 months of forecasts
        )
        
        # Display summary of exported data
        print("\n📊 EXPORT SUMMARY:")
        print(f"📁 File saved as: {json_output_file}")
        print(f"📅 Months forecasted: {complete_export_data['summary_statistics']['total_months_forecasted']}")
        print(f"🎯 Topics tracked: {complete_export_data['summary_statistics']['total_topics_tracked']}")
        print(f"📈 Forecast data points: {complete_export_data['summary_statistics']['forecast_data_points']}")
        print(f"🏆 Top growth rate: {complete_export_data['summary_statistics']['top_growth_rate']:.4f}")
        
        # Show sample of monthly leaderboard
        print("\n🏆 SAMPLE MONTHLY LEADERBOARD (First Month):")
        first_month = list(complete_export_data['monthly_leaderboards'].keys())[0]
        leaderboard = complete_export_data['monthly_leaderboards'][first_month]
        
        print(f"Month: {leaderboard['month']} {leaderboard['year']}")
        print("Top 5 Predicted Trends:")
        for topic in leaderboard['top_trending_topics'][:5]:
            print(f"  {topic['ranking']}. {topic['predicted_trend']} (Growth: {topic['growth_rate']:.2%})")
        
        print(f"\nTop 5 Tags:")
        for tag in leaderboard['top_tags'][:5]:
            print(f"  {tag['ranking']}. {tag['tag']} (Growth: {tag['growth_rate']:.2%})")
        
        if leaderboard['top_videos']:
            print(f"\nTop Videos:")
            for video in leaderboard['top_videos']:
                print(f"  {video['ranking']}. {video['title'][:50]}...")
        
        # Show sample forecast data structure
        print("\n📈 FORECAST GRAPH DATA STRUCTURE:")
        forecast_data = complete_export_data['forecast_graph_data']
        print(f"Data points per topic: {forecast_data['metadata']['data_points_per_topic']}")
        print(f"First topic: {list(forecast_data['topics'].keys())[0]}")
        
        sample_topic = list(forecast_data['topics'].values())[0]
        print(f"Sample data points (first 3 days): {sample_topic['daily_forecasts'][:3]}")
        
        print("\n✅ JSON export completed successfully!")
        print(f"📄 The file '{json_output_file}' contains all the data you requested:")
        print("  - Monthly trend leaderboards with rankings and growth rates")
        print("  - Top 10 tags per month")
        print("  - Top 2 videos per month")
        print("  - 100-day forecast data for graphing (10-day intervals)")
        print("  - Complete raw model data for LLM analysis")
        
        # Additional file with just the forecast data for easy graphing
        forecast_only_file = f"forecast_data_only_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        with open(forecast_only_file, 'w', encoding='utf-8') as f:
            json.dump(complete_export_data['forecast_graph_data'], f, indent=2, ensure_ascii=False, default=str)
        
        print(f"📊 Separate forecast data file: {forecast_only_file}")
        
        # Save the model as well
        saver = ModelSaver(base_path='./saved_trend_models')
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        model_name = f"enhanced_trend_ai_with_export_{timestamp}"
        
        save_info = saver.save_enhanced_model(
            predictor=predictor,
            model_name=model_name
        )
        
        print(f"💾 Model saved as: {model_name}")
        
    except Exception as e:
        print(f"❌ Error during analysis and export: {e}")
        import traceback
        traceback.print_exc()

# Date Predictor

In [12]:
class DateSpecificPredictor:
    """
    Fixed version that works with properly loaded models
    """
    
    def __init__(self, loaded_model):
        self.predictor = loaded_model
        
    def predict_growth_rate_for_date(self, target_date: Union[str, datetime], 
                                   topic_keywords: List[str] = None) -> Dict:
        """
        Predict growth rate for a specific date using loaded model data
        """
        if isinstance(target_date, str):
            target_date = pd.to_datetime(target_date)
        
        days_ahead = (target_date - pd.Timestamp.now()).days
        
        if days_ahead <= 0:
            raise ValueError("Target date must be in the future")
        
        print(f"Predicting for {target_date.strftime('%Y-%m-%d')} ({days_ahead} days ahead)")
        
        # Check if we have trend data
        if not hasattr(self.predictor, 'combined_trend_data') or self.predictor.combined_trend_data is None:
            return self._fallback_prediction(target_date, topic_keywords)
        
        # Use trend data to make predictions
        return self._trend_based_prediction(target_date, days_ahead, topic_keywords)
    
    def _trend_based_prediction(self, target_date, days_ahead, topic_keywords=None):
        """
        Make predictions based on historical trend data
        """
        trend_data = self.predictor.combined_trend_data
        
        results = {
            'target_date': target_date.strftime('%Y-%m-%d'),
            'days_ahead': days_ahead,
            'predictions': {},
            'method': 'trend_extrapolation'
        }
        
        # Get unique clusters
        clusters = trend_data['cluster'].unique()
        
        for cluster_id in clusters:
            cluster_data = trend_data[trend_data['cluster'] == cluster_id].copy()
            
            if len(cluster_data) < 3:  # Need minimum data points
                continue
            
            # Sort by date
            cluster_data = cluster_data.sort_values('date')
            
            # Calculate trend
            recent_scores = cluster_data['combined_trending_score'].tail(5).values
            if len(recent_scores) >= 2:
                # Simple linear projection
                trend_slope = np.polyfit(range(len(recent_scores)), recent_scores, 1)[0]
                current_score = recent_scores[-1]
                predicted_score = current_score + (trend_slope * (days_ahead / 7))  # Weekly projection
                
                # Calculate growth rate
                growth_rate = trend_slope / max(current_score, 0.001)  # Avoid division by zero
                
                # Get topic information
                topic_words = self.predictor.cluster_topics.get(cluster_id, [])
                topic_tags = self.predictor.cluster_tags.get(cluster_id, [])
                
                # Filter by keywords if provided
                if topic_keywords:
                    topic_text = ' '.join(topic_words + topic_tags).lower()
                    if not any(keyword.lower() in topic_text for keyword in topic_keywords):
                        continue
                
                topic_name = ', '.join(topic_words[:3])
                
                results['predictions'][topic_name] = {
                    'cluster_id': cluster_id,
                    'predicted_score': float(predicted_score),
                    'growth_rate': float(growth_rate),
                    'trend_slope': float(trend_slope),
                    'current_score': float(current_score),
                    'topic_words': topic_words[:5],
                    'topic_tags': topic_tags[:5],
                    'confidence': 'medium'
                }
        
        return results
    
    def _fallback_prediction(self, target_date, topic_keywords=None):
        """
        Fallback prediction using cluster performance data
        """
        results = {
            'target_date': target_date.strftime('%Y-%m-%d'),
            'predictions': {},
            'method': 'cluster_analysis_fallback',
            'warning': 'Limited data - using cluster-based estimation'
        }
        
        if not hasattr(self.predictor, 'cluster_topics'):
            results['error'] = 'No cluster data available'
            return results
        
        # Use cluster popularity and tag performance as proxy
        for cluster_id, topic_words in self.predictor.cluster_topics.items():
            if topic_keywords:
                topic_text = ' '.join(topic_words).lower()
                if not any(keyword.lower() in topic_text for keyword in topic_keywords):
                    continue
            
            # Estimate growth based on cluster size and tag popularity
            cluster_size = 0
            if hasattr(self.predictor, 'comments_df') and self.predictor.comments_df is not None:
                cluster_size = len(self.predictor.comments_df[self.predictor.comments_df['cluster'] == cluster_id])
            
            # Simple heuristic: larger clusters with popular tags have higher growth potential
            base_growth = min(cluster_size / 1000, 0.1)  # Cap at 10%
            
            topic_name = ', '.join(topic_words[:3])
            results['predictions'][topic_name] = {
                'cluster_id': cluster_id,
                'estimated_growth_rate': float(base_growth),
                'topic_words': topic_words[:5],
                'cluster_size': cluster_size,
                'confidence': 'low'
            }
        
        return results


# Double Saving

In [13]:
saver = ModelSaver(base_path='./saved_trend_models')
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_name = f"enhanced_trend_ai_with_export_{timestamp}"

save_info = saver.save_enhanced_model(
    predictor=predictor,
    model_name=model_name
)

Enhanced model saved successfully to: ./saved_trend_models/enhanced_trend_ai_with_export_20250906_170929


# Example of deployment

In [16]:
# Load your saved model
saver = ModelSaver('./saved_trend_models')
loaded_predictor = saver.load_enhanced_model('/kaggle/working/saved_trend_models/enhanced_trend_ai_with_export_20250906_165800')

# To fix your current loaded model and enable predictions:
def fix_loaded_model(loaded_predictor, original_data_files):
    """
    Fix a loaded model by regenerating missing trend data
    """
    print("Fixing loaded model by regenerating trend data...")
    
    # If trend data is missing, regenerate it
    if not hasattr(loaded_predictor, 'combined_trend_data') or loaded_predictor.combined_trend_data is None:
        if hasattr(loaded_predictor, 'comments_df') and hasattr(loaded_predictor, 'videos_df'):
            print("Regenerating trend data from existing DataFrames...")
            loaded_predictor.prepare_time_series_data()
        else:
            print("Need to reload original data...")
            # Reload original data if needed
            comment_files, video_file = original_data_files
            loaded_predictor.load_data(comment_files, video_file)
            loaded_predictor.preprocess_data()
            loaded_predictor.prepare_time_series_data()
    
    # Reinitialize ensemble forecaster
    loaded_predictor.ensemble_forecaster = EnsembleForecaster()
    
    return loaded_predictor

# Usage:
comment_files = [
    '/kaggle/input/datathon/comments1.csv',
    '/kaggle/input/datathon/comments2.csv',
    '/kaggle/input/datathon/comments3.csv',
    '/kaggle/input/datathon/comments4.csv',
    '/kaggle/input/datathon/comments5.csv'
]
video_file = '/kaggle/input/datathon/videos.csv'
original_files = (comment_files, video_file)  # Your original data files
fixed_predictor = fix_loaded_model(loaded_predictor, original_files)

# Now you can use date-specific predictions
date_predictor = DateSpecificPredictor(fixed_predictor)
christmas_prediction = date_predictor.predict_growth_rate_for_date('2025-12-25')

Initializing Enhanced YouTube Trend Predictor with Tags Integration...


[nltk_data] Downloading package vader_lexicon to
[nltk_data]     /usr/share/nltk_data...
[nltk_data]   Package vader_lexicon is already up-to-date!


Loaded comments DataFrame
Loaded videos DataFrame
Loaded combined trend data
Loaded model components: ['embeddings', 'clusters', 'cluster_topics', 'cluster_tags', 'popular_tags', 'generational_clusters', 'scaler', 'video_weight', 'comment_weight', 'tag_weight']
Ensemble forecaster reinitialized
Fixing loaded model by regenerating trend data...
Predicting for 2025-12-25 (109 days ahead)


In [17]:
print(christmas_prediction)

{'target_date': '2025-12-25', 'days_ahead': 109, 'predictions': {'hair, hairstyle, wig': {'cluster_id': 4, 'predicted_score': -0.047014411231051864, 'growth_rate': -3.0192741157556244, 'trend_slope': -0.0030192741157556244, 'current_score': 0.0, 'topic_words': ['hair', 'hairstyle', 'wig', 'beautiful', 'bald'], 'topic_tags': ['hair', 'hair transformation', 'hairstyle', 'haircut', 'shorts'], 'confidence': 'medium'}, 'face, try, why': {'cluster_id': 10, 'predicted_score': -0.08283105132810757, 'growth_rate': -5.31942531464911, 'trend_slope': -0.005319425314649111, 'current_score': 0.0, 'topic_words': ['face', 'try', 'why', 'products', 'dont'], 'topic_tags': ['shorts', 'womens health', 'discharge', 'obgyn', 'menopause'], 'confidence': 'medium'}, 'makeup, beauty, beautiful': {'cluster_id': 16, 'predicted_score': 0.3409462411575563, 'growth_rate': 0.2670161583455357, 'trend_slope': 0.017650500000000003, 'current_score': 0.06610274115755627, 'topic_words': ['makeup', 'beauty', 'beautiful', 'w