In [None]:
import os
import re
from collections import Counter

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.signal import savgol_filter
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError

# Connection string for the Oracle database. It is read from the
# MERCHANT_DB_URL environment variable so real credentials do not have to be
# hard-coded in the notebook; the fallback is the original local default.
DB_URL = os.environ.get(
    "MERCHANT_DB_URL",
    "oracle+oracledb://root:password@localhost:1521/?service_name=FREEPDB1",
)
engine = create_engine(DB_URL)

# Canonical chain names that extract_chain() looks for inside a merchant
# name before it falls back to its word-based heuristic.
KNOWN_CHAINS = [
    "KFC",
    "McDonald's",
    "Burger King",
    "Pizza Hut",
    "Starbucks",
    "Domino's",
    "Subway",
    "Taco Bell",
    "Wendy's",
    "Chick-fil-A",
]

# Punctuation, digits and underscores are removed before matching. \w is
# Unicode-aware, so letters such as "ö", "ş" or "ə" survive (an ASCII-only
# class like [^a-zA-Z\s] would turn "Köfteci" into "Kfteci").
_NON_LETTER_RE = re.compile(r"[^\w\s]|[\d_]")

# Words describing a place rather than a brand.
_LOCATION_KEYWORDS = frozenset({
    'istanbul', 'baku', 'new', 'york', 'london', 'city', 'town',
    'square', 'mall', 'branch', 'the', 'kitchen',
})
# Generic venue words that are not a chain name on their own.
_GENERIC_KEYWORDS = frozenset({'bar', 'cafe', 'restaurant', 'coffee'})
_SKIP_WORDS = _LOCATION_KEYWORDS | _GENERIC_KEYWORDS


def extract_chain(name, known_chains=None):
    """Return a best-guess chain name for a raw merchant name.

    The name is stripped of punctuation, digits and underscores. If the
    result contains one of *known_chains* (compared case-insensitively after
    the same normalisation), that chain's canonical spelling is returned.
    Otherwise the first word that is not a location or generic venue keyword
    is returned, falling back to the first word.

    Args:
        name: Merchant name. Non-string values (e.g. None or NaN) yield "".
        known_chains: Canonical chain names to look for. Defaults to the
            module-level KNOWN_CHAINS list.

    Returns:
        The chain name, or "" when nothing remains after cleaning.
    """
    if not isinstance(name, str):
        return ""

    cleaned = _NON_LETTER_RE.sub("", name)
    words = cleaned.split()
    if not words:
        return ""

    if known_chains is None:
        known_chains = KNOWN_CHAINS

    # Normalise the chain names exactly like the merchant name. The merchant
    # name has already lost its apostrophes/hyphens, so comparing against the
    # raw "McDonald's" or "Chick-fil-A" could never succeed.
    haystack = cleaned.lower()
    for chain in known_chains:
        needle = _NON_LETTER_RE.sub("", chain).lower()
        # An empty needle would be a substring of every name; skip it.
        if needle and needle in haystack:
            return chain

    for word in words:
        if word.lower() not in _SKIP_WORDS:
            return word

    # Every word was a keyword: fall back to the first one.
    return words[0]

def _smooth_ratings(series, window_length=5, polyorder=2):
    """Savitzky-Golay smooth the observed values of one chain/category series.

    After the per-group forward fill, the only NaNs left are the months before
    a chain's first rating in that category. savgol_filter would spread those
    NaNs through its window, so only the non-missing values are filtered and
    the leading NaNs are kept. Series with fewer observed values than
    ``window_length`` are returned unchanged.
    """
    observed = series.dropna()
    if len(observed) < window_length:
        return series
    smoothed = series.copy()
    smoothed.loc[observed.index] = savgol_filter(
        observed.to_numpy(), window_length=window_length, polyorder=polyorder
    )
    return smoothed


try:
    with engine.connect() as connection:
        # Fetch all restaurant names
        query = "SELECT id, name_val FROM MERCHANT"
        df = pd.read_sql(query, con=connection)

    # ✅ Debug: Print first few names
    print("Raw Data from Database:")
    print(df.head())

    if df.empty or 'name_val' not in df.columns:
        print("⚠️ No restaurant names found or 'name_val' column is missing!")
    else:
        # Clean restaurant names. .copy() makes each filtered frame independent,
        # so later column assignments don't raise SettingWithCopyWarning.
        df['name_val'] = df['name_val'].str.strip()
        df = df[df['name_val'].notna() & (df['name_val'] != '')].copy()

        # Remove invalid placeholders (case-insensitive substring match)
        invalid_keywords = ['test', 'placeholder', 'empty', 'admin', 'sample', 'restaurant', 'location']
        df = df[~df['name_val'].str.lower().str.contains('|'.join(invalid_keywords))].copy()

        # Extract restaurant chains. extract_chain returns "" for names with
        # nothing usable left; drop those so "" can never rank as a chain.
        df['Chain'] = df['name_val'].apply(extract_chain)
        df = df[df['Chain'] != '']

        # Get top 10 restaurant chains
        chain_counts = Counter(df['Chain'])
        top_chains = chain_counts.most_common(10)
        df_chains = pd.DataFrame(top_chains, columns=['Chain', 'Count'])

        print("\nTop 10 most popular restaurant chains")
        print(df_chains)

        # Fetch ratings and categories
        rate_query = """
        SELECT 
            MR.MERCHANT_FK, 
            MR.rate, 
            MR.created_at, 
            SFA.EN_VAL AS category
        FROM MERCHANT_RATE MR
        JOIN MERCHANT_RATES_FAST_ANSWERS MRFA 
            ON MRFA.MERCHANT_RATE_FK = MR.id
        JOIN SURVEY_FAST_ANSWER SFA
            ON MRFA.FAST_ANSWER_FK = SFA.id
        """
        with engine.connect() as connection:
            df_rate = pd.read_sql(rate_query, con=connection)

        # Map restaurant IDs to chains and parse timestamps
        chain_map = df.set_index('id')['Chain'].to_dict()
        df_rate['Chain'] = df_rate['merchant_fk'].map(chain_map)
        df_rate['created_at'] = pd.to_datetime(df_rate['created_at'])

        # Keep ratings of the top chains that can be placed on a plot: rows
        # without a category or timestamp cannot be grouped by either.
        keep = (
            df_rate['Chain'].isin(df_chains['Chain'].tolist())
            & df_rate['category'].notna()
            & df_rate['created_at'].notna()
        )
        df_rate = df_rate[keep].copy()

        if df_rate.empty:
            # Without this guard, pd.date_range below fails on NaT bounds.
            print("⚠️ No ratings found for the top restaurant chains!")
        else:
            # Bucket ratings by calendar month (first day of the month)
            df_rate['YearMonth'] = df_rate['created_at'].dt.to_period('M').dt.to_timestamp()
            df_rate = df_rate.sort_values(by='YearMonth')

            # Average rating per chain, category and month
            avg_monthly_ratings = (
                df_rate.groupby(['Chain', 'category', 'YearMonth'])['rate']
                .mean()
                .reset_index()
            )

            # Full monthly range covered by the data
            all_months = pd.date_range(start=avg_monthly_ratings['YearMonth'].min(),
                                       end=avg_monthly_ratings['YearMonth'].max(),
                                       freq='MS')

            # Every (chain, category, month) combination. Months vary fastest,
            # so each chain/category series is contiguous and in time order.
            categories = df_rate['category'].unique()
            all_combinations = pd.MultiIndex.from_product(
                [df_chains['Chain'], categories, all_months],
                names=['Chain', 'category', 'YearMonth'],
            )
            all_combinations_df = pd.DataFrame(index=all_combinations).reset_index()

            # Left merge keeps the combination order and inserts missing months
            avg_monthly_ratings = pd.merge(all_combinations_df, avg_monthly_ratings,
                                           on=['Chain', 'category', 'YearMonth'], how='left')

            # Forward-fill gaps within each chain/category series only. A
            # frame-wide ffill would carry one series' last rating into the next
            # series' leading months (and draw lines for chains with no data).
            avg_monthly_ratings['rate'] = (
                avg_monthly_ratings.groupby(['Chain', 'category'])['rate'].ffill()
            )

            # Smooth the ratings per chain/category
            avg_monthly_ratings['smoothed_rate'] = (
                avg_monthly_ratings.groupby(['Chain', 'category'])['rate']
                .transform(_smooth_ratings)
            )

            # Set Seaborn color palette
            sns.set_palette("tab20")

            # One subplot per category. squeeze=False always returns a 2-D
            # array of Axes, so a single category needs no special case.
            num_categories = len(categories)
            fig, axes = plt.subplots(num_categories, 1, figsize=(12, 5 * num_categories),
                                     sharex=True, squeeze=False)

            for ax, category in zip(axes[:, 0], categories):
                cat_data = avg_monthly_ratings[avg_monthly_ratings['category'] == category]

                for chain in df_chains['Chain']:
                    chain_data = cat_data[cat_data['Chain'] == chain]

                    # Skip chains with no ratings at all in this category
                    if chain_data['smoothed_rate'].notna().any():
                        ax.plot(chain_data['YearMonth'],
                                chain_data['smoothed_rate'],
                                label=chain,
                                marker='o', linestyle='-', alpha=0.8)

                ax.set_title(f'Monthly Average Ratings for {category}', fontsize=14)
                ax.set_ylabel('Average Rating', fontsize=12)
                ax.set_ylim(1, 5)  # Assuming a rating scale of 1-5
                if ax.lines:  # a legend without plotted lines only emits a warning
                    ax.legend(title="Restaurant Chains", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
                ax.grid(True, linestyle='--', alpha=0.5)

            # Shared x-axis: label and rotate ticks on the bottom subplot only
            bottom_ax = axes[-1, 0]
            bottom_ax.set_xlabel('Month', fontsize=12)
            bottom_ax.tick_params(axis='x', labelrotation=45)
            fig.tight_layout()
            plt.show()

except SQLAlchemyError as e:
    print(f"❌ Database Error: {e}")

except Exception as e:
    print(f"❌ An unexpected error occurred: {e}")
