In [None]:
!pip install gradio pandas numpy scikit-learn matplotlib seaborn



In [8]:
# Customer Satisfaction Prediction Project - COMPLETE FIXED VERSION
# All visualization functions implemented and working

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import gradio as gr
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LinearRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, mean_squared_error
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

class CustomerSatisfactionPredictor:
    """Load, analyse and visualise customer-support ticket data.

    Bundles CSV / synthetic data loading, label-encoding of the categorical
    ticket columns, tabular summaries and matplotlib charts. Every
    ``create_*_viz`` method returns a matplotlib ``Figure``. When something
    goes wrong, it returns a figure showing the error message instead of
    raising, so a UI can always display the result.
    """

    # Display order and fixed colour per priority level. Colours stay attached
    # to the level (Low = green ... Critical = red) even when some levels are absent.
    PRIORITY_ORDER = ['Low', 'Medium', 'High', 'Critical']
    PRIORITY_COLORS = {'Low': '#28A745', 'Medium': '#FFC107', 'High': '#FD7E14', 'Critical': '#DC3545'}
    # Fallback colour for priority values outside PRIORITY_ORDER.
    OTHER_PRIORITY_COLOR = '#6C757D'
    # Colour and verbal label for satisfaction ratings 1..5 (list index = rating - 1).
    SATISFACTION_COLORS = ['#FF4444', '#FF8844', '#FFDD44', '#88FF44', '#44FF88']
    SATISFACTION_LABELS = ['Very Dissatisfied', 'Dissatisfied', 'Neutral', 'Satisfied', 'Very Satisfied']

    def __init__(self):
        self.model = None
        self.scaler = StandardScaler()
        self.label_encoders = {}  # categorical column name -> LabelEncoder fitted on first use
        self.feature_names = []
        self.is_trained = False

    # ------------------------------------------------------------------ data

    def load_data(self, file_path=None, uploaded_file=None):
        """Load a customer-support ticket dataset.

        Args:
            file_path: Path to a CSV file. Used when ``uploaded_file`` is None.
            uploaded_file: An uploaded CSV. It may be a path (str / PathLike),
                an object exposing its on-disk path as ``.name`` (for example a
                temp-file wrapper), or a readable file-like object. Takes
                precedence over ``file_path``.

        Returns:
            pandas.DataFrame read from the CSV. When neither argument is given,
            returns synthetic sample data instead.
        """
        import os

        if uploaded_file is not None:
            source = uploaded_file
            if not isinstance(source, (str, os.PathLike)):
                # Upload widgets may hand over a wrapper whose `.name` is the temp path.
                name = getattr(source, 'name', None)
                if isinstance(name, str):
                    source = name
            return pd.read_csv(source)
        if file_path is not None:
            return pd.read_csv(file_path)
        return self.generate_sample_data()

    def generate_sample_data(self, n_samples=1000, seed=42):
        """Generate a synthetic customer-support dataset for demonstration.

        Reseeds NumPy's *global* random generator with ``seed``, so the output
        is reproducible. This also affects later unseeded ``np.random`` calls.

        Args:
            n_samples: Number of tickets (rows) to generate.
            seed: Seed passed to ``np.random.seed``.

        Returns:
            pandas.DataFrame with one row per ticket. Purchase and ticket dates
            are drawn independently from every day of 2024, so a ticket may
            predate its purchase.
        """
        np.random.seed(seed)

        date_range = pd.date_range(start=datetime(2024, 1, 1), end=datetime(2024, 12, 31), freq='D')
        ids = range(1, n_samples + 1)

        # NOTE: the order of the np.random calls below fixes the generated values for a seed.
        data = {
            'Ticket ID': ids,
            'Customer Name': [f'Customer_{i}' for i in ids],
            'Customer Email': [f'customer{i}@example.com' for i in ids],
            'Customer Age': np.random.randint(18, 71, n_samples),
            'Customer Gender': np.random.choice(['Male', 'Female', 'Other'], n_samples),
            'Product Purchased': np.random.choice([
                'Dell XPS', 'LG Smart TV', 'GoPro Hero', 'Microsoft Office',
                'Autodesk AutoCAD', 'iPhone 13', 'Samsung Galaxy', 'MacBook Pro',
                'iPad', 'Sony PlayStation', 'Nintendo Switch', 'Adobe Creative Suite'
            ], n_samples),
            'Ticket Type': np.random.choice([
                'Technical issue', 'Billing inquiry', 'Product inquiry',
                'Refund request', 'Cancellation request', 'Installation help',
                'Account issue', 'Feature request'
            ], n_samples),
            'Ticket Subject': np.random.choice([
                'Product setup', 'Network problem', 'Software bug',
                'Hardware issue', 'Battery life', 'Refund request',
                'Login issues', 'Performance problems', 'Compatibility issue',
                'Update problems', 'Audio/Video issues', 'Shipping inquiry'
            ], n_samples),
            'Ticket Status': np.random.choice([
                'Open', 'Closed', 'Pending Customer Response', 'In Progress',
                'Escalated', 'Resolved'
            ], n_samples),
            'Ticket Priority': np.random.choice(['Low', 'Medium', 'High', 'Critical'], n_samples),
            'Ticket Channel': np.random.choice(['Email', 'Chat', 'Phone', 'Social media', 'Web form'], n_samples),
            'Customer Satisfaction Rating': np.random.randint(1, 6, n_samples),
            'Date of Purchase': np.random.choice(date_range, n_samples),
            'Ticket Creation Date': np.random.choice(date_range, n_samples)
        }

        return pd.DataFrame(data)

    def preprocess_data(self, data, is_prediction=False):
        """Label-encode the categorical ticket columns.

        For each known categorical column present in ``data``, an integer
        ``<column>_encoded`` column is added. Missing values are first filled
        with ``'Unknown'``.

        The first time a column is seen, its ``LabelEncoder`` is fitted and
        stored in ``self.label_encoders``. Later calls reuse that encoder and
        replace categories it has never seen with its first class
        (``classes_[0]``), so ``transform`` cannot fail on them. If encoding
        still fails, the error is printed and the encoded column is set to 0.

        Args:
            data: Input DataFrame. It is not modified.
            is_prediction: When False, rows without a satisfaction rating are dropped.

        Returns:
            A new DataFrame with the ``*_encoded`` columns added.
        """
        processed_data = data.copy()

        if not is_prediction and 'Customer Satisfaction Rating' in processed_data.columns:
            processed_data = processed_data.dropna(subset=['Customer Satisfaction Rating'])

        categorical_columns = ['Customer Gender', 'Product Purchased', 'Ticket Type',
                               'Ticket Subject', 'Ticket Status', 'Ticket Priority', 'Ticket Channel']

        for column in categorical_columns:
            if column not in processed_data.columns:
                continue
            encoded_column = column + '_encoded'
            processed_data[column] = processed_data[column].fillna('Unknown')

            if column not in self.label_encoders:
                self.label_encoders[column] = LabelEncoder()
                processed_data[encoded_column] = self.label_encoders[column].fit_transform(processed_data[column])
                continue

            encoder = self.label_encoders[column]
            try:
                known = processed_data[column].isin(encoder.classes_)
                if not known.all():
                    # transform() raises on unseen labels; map them onto the first known class.
                    processed_data[column] = processed_data[column].where(known, encoder.classes_[0])
                processed_data[encoded_column] = encoder.transform(processed_data[column])
            except Exception as e:
                print(f"Error encoding {column}: {e}")
                processed_data[encoded_column] = 0

        return processed_data

    def create_age_groups(self, data):
        """Return a copy of ``data`` with a categorical 'Age Group' column.

        Bands are 18-25 (which also holds ages 0-17), 26-35, 36-45, 46-55 and
        56+. The 56+ band is open-ended. Negative ages get NaN. ``data`` is
        returned unchanged (copied) when it has no 'Customer Age' column.
        """
        data = data.copy()
        if 'Customer Age' in data.columns:
            bins = [0, 25, 35, 45, 55, np.inf]
            labels = ['18-25', '26-35', '36-45', '46-55', '56+']
            data['Age Group'] = pd.cut(data['Customer Age'], bins=bins, labels=labels, include_lowest=True)
        return data

    # -------------------------------------------------------------- analysis

    def exploratory_data_analysis(self, data):
        """Compute basic EDA summaries for the ticket dataset.

        Returns:
            dict with the following keys:

            * 'basic_stats' (``describe()``)
            * 'missing_values' (null count per column)
            * 'satisfaction_distribution' (rating -> count, empty Series when
              there is no rating column)

            When the matching columns exist, it also contains:

            * 'common_issues' (top 10 'Ticket Subject')
            * 'ticket_type_segmentation'
            * 'gender_segmentation'
        """
        analysis_results = {
            'basic_stats': data.describe(),
            'missing_values': data.isnull().sum(),
        }

        if 'Customer Satisfaction Rating' in data.columns:
            analysis_results['satisfaction_distribution'] = (
                data['Customer Satisfaction Rating'].value_counts().sort_index()
            )
        else:
            analysis_results['satisfaction_distribution'] = pd.Series(dtype='int64')

        if 'Ticket Subject' in data.columns:
            analysis_results['common_issues'] = data['Ticket Subject'].value_counts().head(10)
        if 'Ticket Type' in data.columns:
            analysis_results['ticket_type_segmentation'] = data['Ticket Type'].value_counts()
        if 'Customer Gender' in data.columns:
            analysis_results['gender_segmentation'] = data['Customer Gender'].value_counts()

        return analysis_results

    def analyze_ticket_types(self, data):
        """Aggregate ticket statistics per 'Ticket Type'.

        Returns a DataFrame with a 'Ticket Type' column plus whichever of these
        columns the input supports:

        * ``ticket_count``: the number of non-null satisfaction ratings when a
          'Customer Satisfaction Rating' column exists. Otherwise the number of
          'Ticket ID' values, or of rows when there is no ID column.
        * ``avg_satisfaction`` / ``satisfaction_std``: mean and std of the ratings.
        * ``avg_customer_age``: mean of 'Customer Age'.
        * ``total_tickets``: the number of 'Ticket ID' values. Only present
          alongside satisfaction data.

        Absent optional columns are skipped rather than failing the whole
        analysis. Values are rounded to 2 decimals and rows are sorted by
        ``ticket_count`` in descending order. An empty DataFrame is returned
        when 'Ticket Type' is missing or on error.
        """
        try:
            if 'Ticket Type' not in data.columns:
                return pd.DataFrame()

            grouped = data.groupby('Ticket Type')
            has_satisfaction = 'Customer Satisfaction Rating' in data.columns
            has_ticket_id = 'Ticket ID' in data.columns
            stats = {}  # output column name -> per-type Series (insertion order = column order)

            if has_satisfaction:
                ratings = grouped['Customer Satisfaction Rating']
                stats['ticket_count'] = ratings.count()
                stats['avg_satisfaction'] = ratings.mean()
                stats['satisfaction_std'] = ratings.std()
            elif has_ticket_id:
                stats['ticket_count'] = grouped['Ticket ID'].count()
            else:
                stats['ticket_count'] = grouped.size()

            if 'Customer Age' in data.columns:
                stats['avg_customer_age'] = grouped['Customer Age'].mean()
            if has_satisfaction and has_ticket_id:
                stats['total_tickets'] = grouped['Ticket ID'].count()

            ticket_analysis = pd.DataFrame(stats).rename_axis('Ticket Type').round(2)
            return ticket_analysis.sort_values('ticket_count', ascending=False).reset_index()

        except Exception as e:
            print(f"Error in analyze_ticket_types: {str(e)}")
            return pd.DataFrame()

    def analyze_top_products(self, data, gender_filter='all'):
        """Summarise the 10 most-purchased products, optionally for one gender.

        Args:
            data: Ticket DataFrame with a 'Product Purchased' column.
            gender_filter: Use 'all' for every customer. Any other value keeps
                only rows whose 'Customer Gender' equals it.

        Returns:
            DataFrame with at most 10 rows, sorted by ``purchase_count``. Its
            columns are 'Product Purchased', 'purchase_count',
            'avg_satisfaction', 'satisfaction_std', 'avg_customer_age' and
            'total_tickets'.

            When the satisfaction or age column is missing, the placeholders
            3.0 / 1.0 / 40.0 are used. Without 'Ticket ID', total_tickets
            equals purchase_count. An empty DataFrame is returned when nothing
            matches or on error.
        """
        try:
            filtered_data = data
            if gender_filter != 'all':
                if 'Customer Gender' not in data.columns:
                    return pd.DataFrame()
                filtered_data = data[data['Customer Gender'] == gender_filter]

            if len(filtered_data) == 0 or 'Product Purchased' not in filtered_data.columns:
                return pd.DataFrame()

            product_counts = filtered_data['Product Purchased'].value_counts()
            product_analysis = pd.DataFrame({
                'Product Purchased': product_counts.index,
                'purchase_count': product_counts.values
            })
            grouped = filtered_data.groupby('Product Purchased')

            if 'Customer Satisfaction Rating' in filtered_data.columns:
                satisfaction_stats = (grouped['Customer Satisfaction Rating']
                                      .agg(['mean', 'std'])
                                      .round(2)
                                      .rename(columns={'mean': 'avg_satisfaction', 'std': 'satisfaction_std'}))
                product_analysis = product_analysis.merge(
                    satisfaction_stats.reset_index(), on='Product Purchased', how='left'
                )
            else:
                product_analysis['avg_satisfaction'] = 3.0
                product_analysis['satisfaction_std'] = 1.0

            if 'Customer Age' in filtered_data.columns:
                age_stats = grouped['Customer Age'].mean().round(2).rename('avg_customer_age')
                product_analysis = product_analysis.merge(
                    age_stats.reset_index(), on='Product Purchased', how='left'
                )
            else:
                product_analysis['avg_customer_age'] = 40.0

            if 'Ticket ID' in filtered_data.columns:
                ticket_counts = grouped['Ticket ID'].count()
                product_analysis['total_tickets'] = (product_analysis['Product Purchased']
                                                     .map(ticket_counts)
                                                     .fillna(product_analysis['purchase_count']))
            else:
                # Without ticket IDs every purchase row counts as one ticket.
                product_analysis['total_tickets'] = product_analysis['purchase_count']

            return product_analysis.nlargest(10, 'purchase_count').reset_index(drop=True)

        except Exception as e:
            print(f"Error in analyze_top_products: {str(e)}")
            return pd.DataFrame()

    # --------------------------------------------------------- plot helpers

    def _create_error_plot(self, message):
        """Return a blank figure that shows ``message`` in place of a chart."""
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, message, ha='center', va='center',
                transform=ax.transAxes, fontsize=14)
        ax.set_title('Visualization Error')
        ax.axis('off')
        return fig

    @staticmethod
    def _annotate_bars(ax, bars, texts, fontsize=None):
        """Write ``texts[i]`` in bold just above ``bars[i]``.

        The gap is given in screen points rather than data units, so it is the
        same whatever the axis scale. Empty texts and bars with a non-finite
        height are skipped.
        """
        text_kwargs = {} if fontsize is None else {'fontsize': fontsize}
        for bar, text in zip(bars, texts):
            height = bar.get_height()
            if not text or not np.isfinite(height):
                continue
            ax.annotate(text, xy=(bar.get_x() + bar.get_width() / 2, height),
                        xytext=(0, 3), textcoords='offset points',
                        ha='center', va='bottom', fontweight='bold', **text_kwargs)

    # ------------------------------------------------------- visualisations

    def create_priority_viz(self, data):
        """Bar chart of tickets per priority, with mean satisfaction as a line.

        Levels are shown in the order Low -> Critical, and each keeps its own
        colour even when other levels are absent. Values outside
        ``PRIORITY_ORDER`` are not drawn, but percentages are still relative
        to all rows. The satisfaction line is drawn on a secondary y-axis and
        only appears when a rating column exists.
        """
        try:
            if 'Ticket Priority' not in data.columns:
                return self._create_error_plot("No Ticket Priority data available")

            priority_counts = data['Ticket Priority'].value_counts()
            ordered_labels = [p for p in self.PRIORITY_ORDER if p in priority_counts.index]
            if not ordered_labels:
                return self._create_error_plot("No priority data to display")
            ordered_counts = [priority_counts[p] for p in ordered_labels]
            positions = list(range(len(ordered_labels)))

            fig, ax = plt.subplots(figsize=(12, 8))
            bars = ax.bar(positions, ordered_counts,
                          color=[self.PRIORITY_COLORS[p] for p in ordered_labels],
                          edgecolor='black', linewidth=1.5)
            self._annotate_bars(ax, bars,
                                [f'{count}\n({count/len(data)*100:.1f}%)' for count in ordered_counts],
                                fontsize=12)
            ax.margins(y=0.15)  # headroom for the two-line labels

            # Configure the primary axis through its handle. With pyplot calls, twinx()
            # would become the "current" axes and steal these labels.
            ax.set_xlabel('Priority Level', fontsize=14, fontweight='bold')
            ax.set_ylabel('Number of Tickets', fontsize=14, fontweight='bold')
            ax.set_title('Ticket Priority Level Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.set_xticks(positions)
            ax.set_xticklabels(ordered_labels, fontsize=12)
            ax.grid(axis='y', alpha=0.3)

            if 'Customer Satisfaction Rating' in data.columns:
                avg_satisfaction = data.groupby('Ticket Priority')['Customer Satisfaction Rating'].mean()
                ax2 = ax.twinx()
                satisfaction_values = [avg_satisfaction.get(p, 0) for p in ordered_labels]
                ax2.plot(positions, satisfaction_values,
                         color='purple', marker='o', linewidth=3, markersize=10,
                         label='Avg Satisfaction', markeredgecolor='white', markeredgewidth=2)
                ax2.set_ylabel('Average Satisfaction Rating', fontsize=12, fontweight='bold', color='purple')
                ax2.tick_params(axis='y', labelcolor='purple')
                ax2.set_ylim(1, 5)
                ax2.legend(loc='upper right')

            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating priority visualization: {str(e)}")

    def create_channel_viz(self, data):
        """Pie chart of the share of tickets per 'Ticket Channel'. The largest slice is pulled out."""
        try:
            if 'Ticket Channel' not in data.columns:
                return self._create_error_plot("No Ticket Channel data available")

            channel_counts = data['Ticket Channel'].value_counts()
            if channel_counts.empty:
                return self._create_error_plot("No Ticket Channel data available")

            colors = plt.cm.Set3(np.linspace(0, 1, len(channel_counts)))
            # value_counts() sorts in descending order, so slice 0 is the most used channel.
            explode = [0.05 if i == 0 else 0 for i in range(len(channel_counts))]

            fig, ax = plt.subplots(figsize=(12, 8))
            _, texts, autotexts = ax.pie(channel_counts.values, labels=channel_counts.index,
                                         autopct='%1.1f%%', startangle=90, colors=colors,
                                         explode=explode)

            for autotext in autotexts:
                autotext.set_color('white')
                autotext.set_fontweight('bold')
                autotext.set_fontsize(12)
            for text in texts:
                text.set_fontsize(12)
                text.set_fontweight('bold')

            ax.set_title('Ticket Channel Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.axis('equal')
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating channel visualization: {str(e)}")

    def create_status_viz(self, data):
        """Bar chart of tickets per 'Ticket Status', most frequent first."""
        try:
            if 'Ticket Status' not in data.columns:
                return self._create_error_plot("No Ticket Status data available")

            status_counts = data['Ticket Status'].value_counts()
            if status_counts.empty:
                return self._create_error_plot("No Ticket Status data available")

            colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD']
            positions = list(range(len(status_counts)))

            fig, ax = plt.subplots(figsize=(12, 8))
            bars = ax.bar(positions, status_counts.values,
                          color=colors[:len(status_counts)], edgecolor='black', linewidth=1.5)
            self._annotate_bars(ax, bars, [str(int(c)) for c in status_counts.values], fontsize=12)

            ax.set_xlabel('Ticket Status', fontsize=14, fontweight='bold')
            ax.set_ylabel('Number of Tickets', fontsize=14, fontweight='bold')
            ax.set_title('Ticket Status Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.set_xticks(positions)
            ax.set_xticklabels([str(s) for s in status_counts.index], rotation=45, ha='right')
            ax.grid(axis='y', alpha=0.3)
            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating status visualization: {str(e)}")

    def create_customer_gender_viz(self, data):
        """Pie chart of the customer share per 'Customer Gender'."""
        try:
            if 'Customer Gender' not in data.columns:
                return self._create_error_plot("No Customer Gender data available")

            gender_counts = data['Customer Gender'].value_counts()
            if gender_counts.empty:
                return self._create_error_plot("No Customer Gender data available")

            colors = ['#FF9999', '#66B2FF', '#99FF99']

            fig, ax = plt.subplots(figsize=(10, 8))
            ax.pie(gender_counts.values, labels=gender_counts.index, autopct='%1.1f%%',
                   startangle=90, colors=colors[:len(gender_counts)],
                   explode=[0.05] * len(gender_counts),
                   textprops={'fontsize': 12, 'fontweight': 'bold'})

            ax.set_title('Customer Gender Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.axis('equal')
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating gender visualization: {str(e)}")

    def create_customer_age_viz(self, data):
        """Histogram of 'Customer Age' with mean and median lines.

        Non-numeric and missing ages are ignored.
        """
        try:
            if 'Customer Age' not in data.columns:
                return self._create_error_plot("No Customer Age data available")

            # np.histogram cannot auto-range NaN, so drop missing / unparseable ages first.
            ages = pd.to_numeric(data['Customer Age'], errors='coerce').dropna()
            if ages.empty:
                return self._create_error_plot("No valid Customer Age values available")

            mean_age = ages.mean()
            median_age = ages.median()

            fig, ax = plt.subplots(figsize=(12, 8))
            ax.hist(ages, bins=20, color='skyblue', alpha=0.7, edgecolor='black')
            ax.axvline(mean_age, color='red', linestyle='--', linewidth=2,
                       label=f'Mean: {mean_age:.1f}')
            ax.axvline(median_age, color='green', linestyle='--', linewidth=2,
                       label=f'Median: {median_age:.1f}')

            ax.set_xlabel('Customer Age', fontsize=14, fontweight='bold')
            ax.set_ylabel('Frequency', fontsize=14, fontweight='bold')
            ax.set_title('Customer Age Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.legend(fontsize=12)
            ax.grid(axis='y', alpha=0.3)
            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating age visualization: {str(e)}")

    def create_customer_age_groups_viz(self, data):
        """Two panels: customers per age band and mean satisfaction per age band.

        Bands are drawn in age order, and both panels use the same positions
        and colours. Bands with no ratings get an empty, unlabelled bar on the
        right panel.
        """
        try:
            if 'Customer Age' not in data.columns:
                return self._create_error_plot("No Customer Age data available")

            data_with_groups = self.create_age_groups(data)
            # 'Age Group' is an ordered categorical. value_counts() lists every band, and
            # sort_index() puts the bands in age order rather than frequency order.
            age_group_counts = data_with_groups['Age Group'].value_counts().sort_index()
            group_labels = [str(g) for g in age_group_counts.index]
            positions = list(range(len(group_labels)))
            colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

            bars1 = ax1.bar(positions, age_group_counts.values,
                            color=colors[:len(positions)], edgecolor='black')
            self._annotate_bars(ax1, bars1, [str(int(c)) for c in age_group_counts.values])
            ax1.set_xlabel('Age Groups', fontsize=12, fontweight='bold')
            ax1.set_ylabel('Number of Customers', fontsize=12, fontweight='bold')
            ax1.set_title('Customer Distribution by Age Group', fontsize=14, fontweight='bold')
            ax1.set_xticks(positions)
            ax1.set_xticklabels(group_labels)
            ax1.grid(axis='y', alpha=0.3)

            if 'Customer Satisfaction Rating' in data.columns:
                # observed=False keeps every band (NaN mean when empty), so bars align with ax1.
                satisfaction_by_age = (data_with_groups
                                       .groupby('Age Group', observed=False)['Customer Satisfaction Rating']
                                       .mean())
                bars2 = ax2.bar(positions, satisfaction_by_age.fillna(0).values,
                                color=colors[:len(positions)], edgecolor='black')
                self._annotate_bars(ax2, bars2,
                                    ['' if pd.isna(v) else f'{v:.2f}' for v in satisfaction_by_age.values])
                ax2.set_xlabel('Age Groups', fontsize=12, fontweight='bold')
                ax2.set_ylabel('Average Satisfaction Rating', fontsize=12, fontweight='bold')
                ax2.set_title('Average Satisfaction by Age Group', fontsize=14, fontweight='bold')
                ax2.set_xticks(positions)
                ax2.set_xticklabels(group_labels)
                ax2.set_ylim(1, 5)
                ax2.grid(axis='y', alpha=0.3)
            else:
                ax2.text(0.5, 0.5, 'No Satisfaction Data Available', ha='center', va='center',
                         transform=ax2.transAxes, fontsize=14)
                ax2.set_title('Satisfaction by Age Group', fontsize=14, fontweight='bold')

            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating age groups visualization: {str(e)}")

    def create_satisfaction_by_gender_viz(self, data):
        """Bar chart of mean satisfaction rating per 'Customer Gender'.

        Rows missing either the rating or the gender are ignored.
        """
        try:
            if 'Customer Gender' not in data.columns or 'Customer Satisfaction Rating' not in data.columns:
                return self._create_error_plot("Missing gender or satisfaction data")

            clean_data = data.dropna(subset=['Customer Satisfaction Rating', 'Customer Gender'])
            if clean_data.empty:
                return self._create_error_plot("No valid satisfaction or gender data available")

            satisfaction_by_gender = clean_data.groupby('Customer Gender')['Customer Satisfaction Rating'].mean()
            colors = ['#FF9999', '#66B2FF', '#99FF99']
            positions = list(range(len(satisfaction_by_gender)))

            fig, ax = plt.subplots(figsize=(12, 8))
            bars = ax.bar(positions, satisfaction_by_gender.values,
                          color=colors[:len(satisfaction_by_gender)], edgecolor='black', linewidth=1.5)
            self._annotate_bars(ax, bars, [f'{v:.2f}' for v in satisfaction_by_gender.values], fontsize=12)

            ax.set_xlabel('Customer Gender', fontsize=14, fontweight='bold')
            ax.set_ylabel('Average Satisfaction Rating', fontsize=14, fontweight='bold')
            ax.set_title('Customer Satisfaction by Gender', fontsize=16, fontweight='bold', pad=20)
            ax.set_xticks(positions)
            ax.set_xticklabels([str(g) for g in satisfaction_by_gender.index])
            ax.set_ylim(1, 5)
            ax.grid(axis='y', alpha=0.3)
            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating satisfaction by gender visualization: {str(e)}")

    def create_satisfaction_distribution_viz(self, data):
        """Bar chart of how many customers gave each satisfaction rating (1-5).

        Non-numeric, missing and out-of-range ratings are ignored. Colours and
        verbal labels are tied to the rating value itself, so rating 2 stays
        orange / 'Dissatisfied' even when nobody answered 1.
        """
        try:
            if 'Customer Satisfaction Rating' not in data.columns:
                return self._create_error_plot("No Customer Satisfaction Rating data available")

            ratings = pd.to_numeric(data['Customer Satisfaction Rating'], errors='coerce').dropna()
            if ratings.empty:
                return self._create_error_plot("No valid satisfaction rating data available")

            ratings = ratings[ratings.between(1, 5)]
            if ratings.empty:
                return self._create_error_plot("No satisfaction ratings in valid range (1-5)")

            satisfaction_counts = ratings.value_counts().sort_index()
            positions = list(range(len(satisfaction_counts)))
            # After the range filter int(r) is 1..5. Fractional ratings take the colour of their floor.
            bar_colors = [self.SATISFACTION_COLORS[int(r) - 1] for r in satisfaction_counts.index]
            total = len(ratings)

            fig, ax = plt.subplots(figsize=(12, 8))
            bars = ax.bar(positions, satisfaction_counts.values,
                          color=bar_colors, edgecolor='black', linewidth=1.5)
            self._annotate_bars(ax, bars,
                                [f'{int(c)}\n({c/total*100:.1f}%)' for c in satisfaction_counts.values],
                                fontsize=11)
            ax.margins(y=0.15)  # headroom for the two-line labels

            tick_labels = []
            for rating in satisfaction_counts.index:
                if float(rating).is_integer():
                    tick_labels.append(f'{int(rating)}\n{self.SATISFACTION_LABELS[int(rating) - 1]}')
                else:
                    tick_labels.append(f'{rating:g}')

            ax.set_xlabel('Satisfaction Rating', fontsize=14, fontweight='bold')
            ax.set_ylabel('Number of Customers', fontsize=14, fontweight='bold')
            ax.set_title('Customer Satisfaction Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.set_xticks(positions)
            ax.set_xticklabels(tick_labels)
            ax.grid(axis='y', alpha=0.3)
            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating satisfaction distribution visualization: {str(e)}")

    def create_top_products_viz(self, data, gender_filter='all'):
        """Two panels: purchase counts and mean satisfaction for the top 10 products.

        ``gender_filter`` is forwarded to ``analyze_top_products``.
        """
        try:
            top_products = self.analyze_top_products(data, gender_filter)

            if top_products.empty:
                return self._create_error_plot(f"No product data available for {gender_filter}")

            positions = list(range(len(top_products)))
            product_names = [str(p) for p in top_products['Product Purchased']]

            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

            bars1 = ax1.bar(positions, top_products['purchase_count'],
                            color='steelblue', edgecolor='black')
            self._annotate_bars(ax1, bars1, [str(int(c)) for c in top_products['purchase_count']])
            ax1.set_xlabel('Products', fontsize=12, fontweight='bold')
            ax1.set_ylabel('Purchase Count', fontsize=12, fontweight='bold')
            ax1.set_title(f'Top 10 Products - {gender_filter} Customers', fontsize=14, fontweight='bold')
            ax1.set_xticks(positions)
            ax1.set_xticklabels(product_names, rotation=45, ha='right')
            ax1.grid(axis='y', alpha=0.3)

            if 'avg_satisfaction' in top_products.columns:
                avg_satisfaction = top_products['avg_satisfaction']
                # Products whose ratings are all missing get an empty, unlabelled bar.
                bars2 = ax2.bar(positions, avg_satisfaction.fillna(0),
                                color='orange', edgecolor='black')
                self._annotate_bars(ax2, bars2,
                                    ['' if pd.isna(v) else f'{v:.2f}' for v in avg_satisfaction])
                ax2.set_xlabel('Products', fontsize=12, fontweight='bold')
                ax2.set_ylabel('Average Satisfaction', fontsize=12, fontweight='bold')
                ax2.set_title('Average Satisfaction by Product', fontsize=14, fontweight='bold')
                ax2.set_xticks(positions)
                ax2.set_xticklabels(product_names, rotation=45, ha='right')
                ax2.set_ylim(1, 5)
                ax2.grid(axis='y', alpha=0.3)

            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating top products visualization: {str(e)}")

    def create_ticket_type_viz(self, data):
        """Bar chart of tickets per 'Ticket Type', most frequent first."""
        try:
            if 'Ticket Type' not in data.columns:
                return self._create_error_plot("No Ticket Type data available")

            ticket_type_counts = data['Ticket Type'].value_counts()
            if ticket_type_counts.empty:
                return self._create_error_plot("No Ticket Type data available")

            colors = plt.cm.Set3(np.linspace(0, 1, len(ticket_type_counts)))
            positions = list(range(len(ticket_type_counts)))

            fig, ax = plt.subplots(figsize=(14, 8))
            bars = ax.bar(positions, ticket_type_counts.values,
                          color=colors, edgecolor='black', linewidth=1.5)
            self._annotate_bars(ax, bars, [str(int(c)) for c in ticket_type_counts.values])

            ax.set_xlabel('Ticket Type', fontsize=14, fontweight='bold')
            ax.set_ylabel('Number of Tickets', fontsize=14, fontweight='bold')
            ax.set_title('Ticket Type Distribution', fontsize=16, fontweight='bold', pad=20)
            ax.set_xticks(positions)
            ax.set_xticklabels([str(t) for t in ticket_type_counts.index], rotation=45, ha='right')
            ax.grid(axis='y', alpha=0.3)
            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating ticket type visualization: {str(e)}")

    def create_ticket_trends_viz(self, data):
        """2x2 dashboard of monthly trends.

        The four panels show volume, priority mix, mean satisfaction and
        channel usage. Months come from 'Ticket Creation Date', or from 'Date
        of Purchase' when the former is absent. Unparseable dates are dropped.
        Panels whose source column is missing or entirely empty show a
        placeholder message instead.
        """
        try:
            if 'Ticket Creation Date' not in data.columns and 'Date of Purchase' not in data.columns:
                return self._create_error_plot("No date data available for trends analysis")

            date_col = 'Ticket Creation Date' if 'Ticket Creation Date' in data.columns else 'Date of Purchase'
            data_copy = data.copy()
            data_copy[date_col] = pd.to_datetime(data_copy[date_col], errors='coerce')
            data_copy = data_copy.dropna(subset=[date_col])
            if data_copy.empty:
                return self._create_error_plot(f"No valid dates in '{date_col}' for trends analysis")

            months = data_copy[date_col].dt.to_period('M')
            fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

            # Monthly ticket volume
            monthly_tickets = data_copy.groupby(months).size()
            ax1.plot(monthly_tickets.index.astype(str), monthly_tickets.values, marker='o', linewidth=2, markersize=8)
            ax1.set_title('Monthly Ticket Volume', fontweight='bold')
            ax1.set_xlabel('Month')
            ax1.set_ylabel('Number of Tickets')
            ax1.tick_params(axis='x', rotation=45)
            ax1.grid(True, alpha=0.3)

            # Priority mix. unstack() orders columns alphabetically, so restore Low -> Critical
            # and pick each level's colour by name. Unknown levels go last, in grey.
            if 'Ticket Priority' in data_copy.columns and data_copy['Ticket Priority'].notna().any():
                priority_trends = data_copy.groupby([months, 'Ticket Priority']).size().unstack(fill_value=0)
                ordered = ([p for p in self.PRIORITY_ORDER if p in priority_trends.columns]
                           + [p for p in priority_trends.columns if p not in self.PRIORITY_ORDER])
                priority_trends = priority_trends[ordered]
                priority_trends.plot(kind='bar', stacked=True, ax=ax2,
                                     color=[self.PRIORITY_COLORS.get(p, self.OTHER_PRIORITY_COLOR) for p in ordered])
                ax2.set_title('Priority Distribution Over Time', fontweight='bold')
                ax2.set_xlabel('Month')
                ax2.set_ylabel('Number of Tickets')
                ax2.legend(title='Priority', bbox_to_anchor=(1.05, 1), loc='upper left')
                ax2.tick_params(axis='x', rotation=45)
            else:
                ax2.text(0.5, 0.5, 'No Priority Data', ha='center', va='center', transform=ax2.transAxes)
                ax2.set_title('Priority Trends', fontweight='bold')

            # Mean satisfaction per month
            if 'Customer Satisfaction Rating' in data_copy.columns:
                satisfaction_trends = data_copy.groupby(months)['Customer Satisfaction Rating'].mean()
                ax3.plot(satisfaction_trends.index.astype(str), satisfaction_trends.values,
                         marker='s', linewidth=2, markersize=8, color='purple')
                ax3.set_title('Average Satisfaction Over Time', fontweight='bold')
                ax3.set_xlabel('Month')
                ax3.set_ylabel('Average Satisfaction')
                ax3.set_ylim(1, 5)
                ax3.tick_params(axis='x', rotation=45)
                ax3.grid(True, alpha=0.3)
            else:
                ax3.text(0.5, 0.5, 'No Satisfaction Data', ha='center', va='center', transform=ax3.transAxes)
                ax3.set_title('Satisfaction Trends', fontweight='bold')

            # Channel usage per month
            if 'Ticket Channel' in data_copy.columns and data_copy['Ticket Channel'].notna().any():
                channel_trends = data_copy.groupby([months, 'Ticket Channel']).size().unstack(fill_value=0)
                channel_trends.plot(kind='area', ax=ax4, alpha=0.7)
                ax4.set_title('Channel Usage Over Time', fontweight='bold')
                ax4.set_xlabel('Month')
                ax4.set_ylabel('Number of Tickets')
                ax4.legend(title='Channel', bbox_to_anchor=(1.05, 1), loc='upper left')
                ax4.tick_params(axis='x', rotation=45)
            else:
                ax4.text(0.5, 0.5, 'No Channel Data', ha='center', va='center', transform=ax4.transAxes)
                ax4.set_title('Channel Trends', fontweight='bold')

            fig.tight_layout()
            return fig

        except Exception as e:
            return self._create_error_plot(f"Error creating ticket trends visualization: {str(e)}")

# Module-level state shared by the callback functions below.
# Most recently loaded dataset: None until load_and_analyze_data assigns it,
# and read by show_data_preview.
global_data = None
# Single shared predictor instance. load_and_analyze_data uses it for sample data and EDA.
predictor = CustomerSatisfactionPredictor()

def _describe_date_range(data):
    """Return "first to last" text for the first available date column, or a fallback note."""
    # Same column preference as the ticket-trends analysis.
    for col in ('Ticket Creation Date', 'Date of Purchase'):
        if col in data.columns:
            dates = pd.to_datetime(data[col], errors='coerce').dropna()
            if dates.empty:
                return f"No valid dates in '{col}'"
            return f"{dates.min().date()} to {dates.max().date()} ({col})"
    return "No date column available"


def load_and_analyze_data(uploaded_file, use_sample_data):
    """Load a dataset into ``global_data`` and return a plain-text EDA summary.

    Args:
        uploaded_file: Path of the uploaded CSV (``gr.File(type="filepath")``),
            or a file-like wrapper exposing ``.name``; may be None.
        use_sample_data: When True the generated sample dataset is used and
            ``uploaded_file`` is ignored.

    Returns:
        str: the summary report, or a guidance / error message.

    Side effects:
        Replaces ``global_data`` and marks ``predictor`` as untrained, because a
        model fitted on the previous dataset does not describe the new one.
    """
    import os

    global global_data

    try:
        if use_sample_data:
            data = predictor.generate_sample_data()
            source = "Sample Dataset"
        elif uploaded_file is not None:
            # Accept a plain path as well as an object with .name (as load_data does).
            file_path = getattr(uploaded_file, 'name', uploaded_file)
            data = pd.read_csv(file_path)
            # basename handles both '/' and OS-specific separators.
            source = f"Uploaded file: {os.path.basename(str(file_path))}"
        else:
            return "Please either upload a CSV file or check 'Use Sample Data'"

        global_data = data
        # Any model trained earlier belongs to the old dataset; force retraining.
        predictor.is_trained = False

        # Perform EDA
        analysis = predictor.exploratory_data_analysis(global_data)
        date_range = _describe_date_range(global_data)

        result = f"""Data Loaded Successfully from {source}

Dataset Overview:
• Total records: {len(global_data):,}
• Total columns: {len(global_data.columns)}
• Date range: {date_range}

Column Information:
{list(global_data.columns)}

Basic Statistics:
{global_data.describe().round(2) if not global_data.empty else 'No numeric data available'}

Missing Values Analysis:
{analysis['missing_values'][analysis['missing_values'] > 0] if not analysis['missing_values'].empty else 'No missing values found'}

Customer Satisfaction Distribution:
{analysis['satisfaction_distribution'] if not analysis['satisfaction_distribution'].empty else 'No satisfaction data available'}

Top 5 Common Issues:
{analysis.get('common_issues', 'No ticket subject data available')}

Customer Segmentation by Gender:
{analysis.get('gender_segmentation', 'No gender data available')}

Ticket Type Distribution:
{analysis.get('ticket_type_segmentation', 'No ticket type data available')}

Data is now ready for visualization and model training!
"""

        return result

    except Exception as e:
        return f"Error loading data: {str(e)}"

def show_data_preview():
    """Return the first ten rows of the loaded dataset for the preview grid.

    Before any data is loaded, a one-row placeholder frame with a single
    ``Message`` column is returned so the DataFrame component still renders.
    """
    if global_data is None:
        return pd.DataFrame({"Message": ["No data loaded. Please upload data first."]})
    return global_data.head(10)

def train_model_interface():
    """Train the Random Forest satisfaction classifier on the loaded dataset.

    Features are every ``*_encoded`` column produced by
    ``predictor.preprocess_data`` plus ``Customer Age`` when present; the target
    is ``Customer Satisfaction Rating``. Only rows that actually have a rating
    are used. On success the fitted model, scaler and feature list are stored
    on ``predictor`` and ``predictor.is_trained`` is set.

    Returns:
        str: a training report (split sizes, accuracy, classification report,
        feature importances) or a guidance / error message.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    try:
        if 'Customer Satisfaction Rating' not in global_data.columns:
            return "Error: Customer Satisfaction Rating column not found in the dataset."

        processed_data = predictor.preprocess_data(global_data)

        feature_columns = [col for col in processed_data.columns if col.endswith('_encoded')] + ['Customer Age']
        feature_columns = [col for col in feature_columns if col in processed_data.columns]

        if not feature_columns:
            return "Error: No suitable features found for training. Please ensure your data has categorical columns."

        # Train only on tickets that carry a rating. Imputing the target (e.g. with
        # the median) would invent labels and can even create a non-existent class
        # such as 3.5 for the classifier.
        labelled = processed_data['Customer Satisfaction Rating'].notna()
        if labelled.sum() < 2:
            return "Error: Not enough rows with a Customer Satisfaction Rating to train a model (need at least 2)."

        X = processed_data.loc[labelled, feature_columns].fillna(0)
        # A column with gaps is float-typed; cast back so classes are 1..5, not 1.0..5.0.
        y = processed_data.loc[labelled, 'Customer Satisfaction Rating'].round().astype(int)

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Invalidate the previous model before the shared scaler is refit, so a
        # failure below cannot leave an old model paired with a new scaler.
        predictor.is_trained = False
        X_train_scaled = predictor.scaler.fit_transform(X_train)
        X_test_scaled = predictor.scaler.transform(X_test)

        model = RandomForestClassifier(n_estimators=100, random_state=42)
        model.fit(X_train_scaled, y_train)
        # Publish the model only once it has been fitted successfully.
        predictor.model = model
        predictor.feature_names = feature_columns
        predictor.is_trained = True

        y_pred = model.predict(X_test_scaled)

        accuracy = accuracy_score(y_test, y_pred)
        report = classification_report(y_test, y_pred)

        feature_importance = pd.DataFrame({
            'feature': feature_columns,
            'importance': model.feature_importances_
        }).sort_values('importance', ascending=False)

        result = f"""Model Training Completed Successfully!

Training Results:
• Training samples: {len(X_train):,}
• Test samples: {len(X_test):,}
• Features used: {len(feature_columns)}

Model Performance:
• Accuracy: {accuracy:.3f} ({accuracy*100:.1f}%)

Detailed Classification Report:
{report}

Feature Importance Ranking:
{feature_importance}

Model is now ready for predictions!
"""

        return result

    except Exception as e:
        return f"Error training model: {str(e)}"

def predict_single(age, gender, product, ticket_type, subject, status, priority, channel):
    """Predict the satisfaction rating (1-5) for one hand-entered ticket.

    Args:
        age: Customer age.
        gender, product, ticket_type, subject, status, priority, channel:
            Categorical ticket attributes as chosen in the prediction form.

    Returns:
        str: report with the predicted rating, its probability, the probability
        of every rating 1-5 and a recommendation, or a guidance / error message.
    """
    if global_data is None or not predictor.is_trained:
        return "Please load data and train the model first."

    try:
        input_data = pd.DataFrame({
            'Customer Age': [age],
            'Customer Gender': [gender],
            'Product Purchased': [product],
            'Ticket Type': [ticket_type],
            'Ticket Subject': [subject],
            'Ticket Status': [status],
            'Ticket Priority': [priority],
            'Ticket Channel': [channel]
        })

        processed_input = predictor.preprocess_data(input_data, is_prediction=True)

        available_features = [col for col in predictor.feature_names if col in processed_input.columns]

        if not available_features:
            return "Error: Unable to process input features for prediction."

        # Align to the training feature order: absent features become 0 and NaNs are
        # filled with 0, mirroring the fillna(0) applied to the training features.
        X_input = processed_input.reindex(columns=predictor.feature_names, fill_value=0).fillna(0)
        X_input_scaled = predictor.scaler.transform(X_input)

        prediction = predictor.model.predict(X_input_scaled)[0]
        prediction_proba = predictor.model.predict_proba(X_input_scaled)[0]

        # predict_proba columns follow model.classes_, which only holds the ratings
        # seen during training, so map by class label rather than assuming [1..5].
        proba_by_rating = {
            int(label): float(prob)
            for label, prob in zip(predictor.model.classes_, prediction_proba)
        }
        predicted_rating = int(prediction)
        predicted_prob = proba_by_rating.get(predicted_rating, 0.0) * 100

        rating_descriptions = {
            1: "Very Dissatisfied",
            2: "Dissatisfied",
            3: "Neutral",
            4: "Satisfied",
            5: "Very Satisfied"
        }

        # Ratings never seen in training are reported as 0.0%.
        distribution = "\n".join(
            f"• {description} ({rating}): {proba_by_rating.get(rating, 0.0) * 100:.1f}%"
            for rating, description in rating_descriptions.items()
        )
        if predicted_rating < 4:
            recommendation = "Focus on improving customer experience"
        else:
            recommendation = "Maintain current service quality"

        result = f"""Prediction Results for Customer Profile:

Input Summary:
• Age: {age} years old {gender}
• Product: {product}
• Issue Type: {ticket_type}
• Subject: {subject}
• Status: {status}
• Priority: {priority}
• Channel: {channel}

Predicted Satisfaction Rating: {predicted_rating}/5
Description: {rating_descriptions.get(predicted_rating, 'Unknown')}
Confidence: {predicted_prob:.1f}%

Probability Distribution:
{distribution}

Recommendation: {recommendation}
"""

        return result

    except Exception as e:
        return f"Error making prediction: {str(e)}"

# Interface functions for visualizations
#
# Every Gradio plot callback below has the same contract: when no dataset has
# been loaded yet it returns a placeholder figure; otherwise it delegates to the
# matching CustomerSatisfactionPredictor method with the loaded dataset.


def _no_data_figure():
    """Build the placeholder figure shown when a plot is requested before data is loaded."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.text(0.5, 0.5, 'Please upload data first', ha='center', va='center',
            transform=ax.transAxes, fontsize=16)
    ax.set_title('No Data Available')
    return fig


def create_priority_viz():
    """Ticket priority plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_priority_viz(global_data)


def create_channel_viz():
    """Ticket channel plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_channel_viz(global_data)


def create_status_viz():
    """Ticket status plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_status_viz(global_data)


def create_customer_gender_viz():
    """Customer gender plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_customer_gender_viz(global_data)


def create_customer_age_viz():
    """Customer age plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_customer_age_viz(global_data)


def create_customer_age_groups_viz():
    """Customer age-group plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_customer_age_groups_viz(global_data)


def create_satisfaction_by_gender_viz():
    """Satisfaction-by-gender plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_satisfaction_by_gender_viz(global_data)


def create_satisfaction_distribution_viz():
    """Satisfaction distribution plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_satisfaction_distribution_viz(global_data)


def create_top_products_viz(gender_filter='all'):
    """Top-products plot for the loaded dataset, filtered by gender ('all' for everyone)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_top_products_viz(global_data, gender_filter)


def create_ticket_type_viz():
    """Ticket type plot for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_ticket_type_viz(global_data)


def create_ticket_trends_viz():
    """Ticket trends dashboard for the loaded dataset (placeholder if none loaded)."""
    if global_data is None:
        return _no_data_figure()
    return predictor.create_ticket_trends_viz(global_data)

# Analysis interface functions
def analyze_ticket_types_interface():
    """Render ``predictor.analyze_ticket_types`` output as a plain-text report.

    Each ticket type gets its ticket count plus, when the summary frame
    provides those columns, its average satisfaction and average customer age.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    try:
        summary = predictor.analyze_ticket_types(global_data)
        if summary.empty:
            return "No ticket type data available for analysis."

        chunks = ["TICKET TYPE ANALYSIS RESULTS\n" + "=" * 50 + "\n\n"]
        for _, record in summary.iterrows():
            chunks.append(f"Ticket Type: {record['Ticket Type']}\n")
            chunks.append(f"• Total Tickets: {record['ticket_count']}\n")
            # Optional columns are only reported when the summary includes them.
            if 'avg_satisfaction' in record:
                chunks.append(f"• Average Satisfaction: {record['avg_satisfaction']:.2f}/5\n")
            if 'avg_customer_age' in record:
                chunks.append(f"• Average Customer Age: {record['avg_customer_age']:.1f} years\n")
            chunks.append("\n")

        return "".join(chunks)
    except Exception as e:
        return f"Error analyzing ticket types: {str(e)}"

def analyze_ticket_trends_interface():
    """Summarise monthly ticket volume of the loaded dataset as plain text.

    Uses ``Ticket Creation Date`` when present, otherwise ``Date of Purchase``.
    Entries that cannot be parsed as dates are ignored rather than aborting
    the whole analysis.

    Returns:
        str: period covered, per-month counts, peak / lowest month and monthly
        average, or a guidance / error message.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    try:
        date_col = next(
            (col for col in ('Ticket Creation Date', 'Date of Purchase') if col in global_data.columns),
            None,
        )
        if date_col is None:
            return "No date columns available for trend analysis."

        # errors='coerce' turns malformed entries into NaT; groupby drops NaT keys,
        # so a single bad cell no longer makes the whole analysis fail.
        dates = pd.to_datetime(global_data[date_col], errors='coerce')

        # Monthly trends
        monthly_tickets = dates.groupby(dates.dt.to_period('M')).size()
        if monthly_tickets.empty:
            return f"No valid dates found in '{date_col}' for trend analysis."

        result = "TICKET TRENDS ANALYSIS RESULTS\n" + "="*50 + "\n\n"
        result += f"Analysis Period: {monthly_tickets.index.min()} to {monthly_tickets.index.max()}\n\n"
        result += "Monthly Ticket Volume:\n"
        for month, count in monthly_tickets.items():
            result += f"• {month}: {count} tickets\n"

        result += f"\nPeak Month: {monthly_tickets.idxmax()} ({monthly_tickets.max()} tickets)\n"
        result += f"Lowest Month: {monthly_tickets.idxmin()} ({monthly_tickets.min()} tickets)\n"
        result += f"Average Monthly Tickets: {monthly_tickets.mean():.1f}\n"

        return result
    except Exception as e:
        return f"Error analyzing ticket trends: {str(e)}"

def analyze_top_products_interface(gender_filter='all'):
    """Format the top-products table for one gender filter as a ranked text report.

    Args:
        gender_filter: 'all' or a gender value such as 'Female'; passed through
            to ``predictor.analyze_top_products``.

    Returns:
        str: ranked report (purchase count plus any of average satisfaction,
        average customer age and ticket count that the table provides), or a
        guidance / error message.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    try:
        analysis = predictor.analyze_top_products(global_data, gender_filter)
        if analysis.empty:
            return f"No product data available for {gender_filter} customers."

        result = f"TOP 10 PRODUCTS ANALYSIS - {gender_filter.upper()} CUSTOMERS\n" + "="*60 + "\n\n"

        # Rank by row position: the frame's index is not guaranteed to be 0..n-1
        # (e.g. sorted without reset_index), so it must not drive the numbering.
        for rank, (_, row) in enumerate(analysis.iterrows(), start=1):
            result += f"{rank}. {row['Product Purchased']}\n"
            result += f"   • Purchase Count: {row['purchase_count']}\n"
            if 'avg_satisfaction' in row:
                result += f"   • Average Satisfaction: {row['avg_satisfaction']:.2f}/5\n"
            if 'avg_customer_age' in row:
                result += f"   • Average Customer Age: {row['avg_customer_age']:.1f} years\n"
            if 'total_tickets' in row:
                result += f"   • Support Tickets: {row['total_tickets']}\n"
            result += "\n"

        return result
    except Exception as e:
        return f"Error analyzing top products: {str(e)}"

def generate_insights():
    """Build the "Enhanced Data Insights" text report for the tickets tab.

    Covers dataset size, unique customers, whether a date column exists, the
    share of High / Critical priority tickets and headline satisfaction numbers.

    Returns:
        str: the report, or a guidance / error message.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    def _pct(part, whole):
        # Percentage that degrades to 0.0 instead of raising on an empty denominator.
        return part / whole * 100 if whole else 0.0

    try:
        total = len(global_data)
        columns = global_data.columns
        result = "ENHANCED DATA INSIGHTS\n" + "="*50 + "\n\n"

        # Basic stats
        result += "Dataset Overview:\n"
        result += f"• Total Records: {total:,}\n"
        result += f"• Unique Customers: {global_data['Customer Name'].nunique() if 'Customer Name' in columns else 'N/A'}\n"
        result += f"• Date Range: {'Available' if any(col in columns for col in ['Date of Purchase', 'Ticket Creation Date']) else 'Not Available'}\n\n"

        # Top insights
        if 'Ticket Priority' in columns:
            priorities = global_data['Ticket Priority']
            high_priority = (priorities == 'High').sum()
            critical_priority = (priorities == 'Critical').sum()
            result += "Priority Analysis:\n"
            result += f"• High Priority Tickets: {high_priority} ({_pct(high_priority, total):.1f}%)\n"
            result += f"• Critical Priority Tickets: {critical_priority} ({_pct(critical_priority, total):.1f}%)\n\n"

        if 'Customer Satisfaction Rating' in columns:
            ratings = global_data['Customer Satisfaction Rating']
            rated = ratings.notna().sum()
            dissatisfied = (ratings <= 2).sum()
            result += "Satisfaction Insights:\n"
            result += f"• Average Satisfaction: {ratings.mean():.2f}/5\n"
            # Unrated tickets cannot count as dissatisfied, so measure against rated ones only.
            result += f"• Dissatisfied Customers: {dissatisfied} ({_pct(dissatisfied, rated):.1f}% of rated tickets)\n\n"

        return result
    except Exception as e:
        return f"Error generating insights: {str(e)}"

def generate_customer_insights():
    """Summarise customer age statistics and the gender mix as plain text.

    Age figures (mean, min-max range, median) appear only when a
    ``Customer Age`` column exists; gender counts and their share of all rows
    appear only when a ``Customer Gender`` column exists.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    try:
        parts = ["CUSTOMER DEMOGRAPHICS INSIGHTS\n" + "=" * 50 + "\n\n"]
        columns = global_data.columns

        if 'Customer Age' in columns:
            ages = global_data['Customer Age']
            parts.append("Age Analysis:\n")
            parts.append(f"• Average Age: {ages.mean():.1f} years\n")
            parts.append(f"• Age Range: {ages.min()} - {ages.max()} years\n")
            parts.append(f"• Median Age: {ages.median():.1f} years\n\n")

        if 'Customer Gender' in columns:
            total_rows = len(global_data)
            parts.append("Gender Distribution:\n")
            parts.extend(
                f"• {gender}: {count} ({count/total_rows*100:.1f}%)\n"
                for gender, count in global_data['Customer Gender'].value_counts().items()
            )
            parts.append("\n")

        return "".join(parts)
    except Exception as e:
        return f"Error generating customer insights: {str(e)}"

def generate_satisfaction_insights():
    """Summarise customer satisfaction ratings as plain text.

    Reports the average rating, how many tickets carry a rating, the share of
    each rating among rated tickets and, when available, the average rating per
    gender.

    Returns:
        str: the report, or a guidance / error message.
    """
    if global_data is None:
        return "No data loaded. Please upload data first."

    try:
        if 'Customer Satisfaction Rating' not in global_data.columns:
            return "No satisfaction rating data available."

        ratings = global_data['Customer Satisfaction Rating']
        result = "CUSTOMER SATISFACTION INSIGHTS\n" + "="*50 + "\n\n"

        # value_counts() drops missing ratings, so its sum is the number of rated tickets.
        satisfaction_dist = ratings.value_counts().sort_index()
        rated_total = int(satisfaction_dist.sum())
        avg_satisfaction = ratings.mean()

        result += "Overall Satisfaction Metrics:\n"
        result += f"• Average Rating: {avg_satisfaction:.2f}/5\n"
        result += f"• Rated Tickets: {rated_total:,} of {len(global_data):,}\n"
        result += "• Satisfaction Distribution:\n"

        labels = ['Very Dissatisfied', 'Dissatisfied', 'Neutral', 'Satisfied', 'Very Satisfied']
        for rating, count in satisfaction_dist.items():
            rating_int = int(rating)
            # Out-of-scale ratings get a generic label instead of an IndexError.
            label = labels[rating_int - 1] if 1 <= rating_int <= 5 else f"Rating {rating_int}"
            # Share of rated tickets, so the distribution sums to 100%
            # (rated_total > 0 whenever this loop runs).
            share = count / rated_total * 100
            result += f"  - {label} ({rating_int}): {count} ({share:.1f}%)\n"

        # Satisfaction by gender if available
        if 'Customer Gender' in global_data.columns:
            result += "\nSatisfaction by Gender:\n"
            gender_satisfaction = global_data.groupby('Customer Gender')['Customer Satisfaction Rating'].mean()
            for gender, rating in gender_satisfaction.items():
                result += f"• {gender}: {rating:.2f}/5\n"

        return result
    except Exception as e:
        return f"Error generating satisfaction insights: {str(e)}"


# Define dropdown options
# Choices offered by the "Make Predictions" form; Gradio lists them in this order.

GENDER_OPTIONS = [
    "Male",
    "Female",
    "Other",
]

PRODUCT_OPTIONS = [
    "Dell XPS",
    "LG Smart TV",
    "GoPro Hero",
    "Microsoft Office",
    "Autodesk AutoCAD",
    "iPhone 13",
    "Samsung Galaxy",
    "MacBook Pro",
    "iPad",
    "Sony PlayStation",
    "Nintendo Switch",
    "Adobe Creative Suite",
    "HP Laptop",
    "Canon Camera",
    "Surface Pro",
    "Google Pixel",
    "OnePlus",
    "Xiaomi Phone",
    "Tesla Model 3",
]

TICKET_TYPE_OPTIONS = [
    "Technical issue",
    "Billing inquiry",
    "Product inquiry",
    "Refund request",
    "Cancellation request",
    "Installation help",
    "Account issue",
    "Feature request",
    "Warranty claim",
    "Upgrade request",
]

TICKET_SUBJECT_OPTIONS = [
    "Product setup",
    "Network problem",
    "Software bug",
    "Hardware issue",
    "Battery life",
    "Refund request",
    "Login issues",
    "Performance problems",
    "Compatibility issue",
    "Update problems",
    "Audio/Video issues",
    "Shipping inquiry",
    "Payment failed",
    "Account locked",
    "Data recovery",
    "Screen issues",
]

TICKET_STATUS_OPTIONS = [
    "Open",
    "Closed",
    "Pending Customer Response",
    "In Progress",
    "Escalated",
    "Resolved",
]

PRIORITY_OPTIONS = [
    "Low",
    "Medium",
    "High",
    "Critical",
]

CHANNEL_OPTIONS = [
    "Email",
    "Chat",
    "Phone",
    "Social media",
    "Web form",
]

# Create Gradio Interface
with gr.Blocks(title="Customer Satisfaction Prediction System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Customer Satisfaction Prediction System")
    gr.Markdown("Upload your customer support dataset and predict satisfaction ratings using machine learning!")

    with gr.Tab("Data Upload & Analysis"):
        gr.Markdown("## Upload Your Customer Support Dataset")
        gr.Markdown("""
        ### Required CSV Columns:
        **Essential:** `Customer Satisfaction Rating` (1-5 scale)

        **Recommended columns:**
        - Customer Age, Customer Gender, Product Purchased
        - Ticket Type, Ticket Subject, Ticket Status, Ticket Priority
        - Ticket Channel, Date of Purchase, etc.
        """)

        with gr.Row():
            file_upload = gr.File(
                label="Upload CSV Dataset",
                file_types=[".csv"],
                type="filepath"
            )
            use_sample_checkbox = gr.Checkbox(
                label="Use Sample Data",
                value=False
            )

        analyze_btn = gr.Button("Load Data & Analyze", variant="primary", size="lg")

        analysis_output = gr.Textbox(
            label="Analysis Results",
            lines=20,
            placeholder="Upload a CSV file or check 'Use Sample Data' and click 'Load Data & Analyze' to see results..."
        )

        analyze_btn.click(
            fn=load_and_analyze_data,
            inputs=[file_upload, use_sample_checkbox],
            outputs=[analysis_output]
        )

        gr.Markdown("## Data Preview")
        preview_btn = gr.Button("Show Data Preview")
        data_preview = gr.DataFrame(
            label="First 10 rows of loaded data"
        )

        preview_btn.click(
            fn=show_data_preview,
            outputs=data_preview
        )

    with gr.Tab("Customer Distribution"):
        gr.Markdown("## Customer Demographics & Distribution Analysis")
        gr.Markdown("Comprehensive analysis of customer demographics including gender, age, and age group distributions.")

        gr.Markdown("""
        ### Available Distributions in this Page:
        - **Customer Gender Distribution**
        - **Customer Age Distribution**
        - **Customer Age Groups Distribution**
        - **Customer Demographics Insights**
        """)

        with gr.Row():
            gender_viz_btn = gr.Button("Gender Distribution", variant="primary")
            age_viz_btn = gr.Button("Age Distribution", variant="primary")
            age_groups_viz_btn = gr.Button("Age Groups", variant="primary")

        with gr.Row():
            with gr.Column():
                gender_plot = gr.Plot(label="Customer Gender Distribution")
            with gr.Column():
                age_plot = gr.Plot(label="Customer Age Distribution")

        with gr.Row():
            age_groups_plot = gr.Plot(label="Customers by Age Group with Satisfaction")

        gender_viz_btn.click(fn=create_customer_gender_viz, outputs=gender_plot)
        age_viz_btn.click(fn=create_customer_age_viz, outputs=age_plot)
        age_groups_viz_btn.click(fn=create_customer_age_groups_viz, outputs=age_groups_plot)

        gr.Markdown("## Customer Demographics Insights")
        customer_insights_btn = gr.Button("Generate Customer Insights", variant="secondary", size="lg")
        customer_insights_output = gr.Textbox(
            label="Customer Demographics Analysis",
            lines=20,
            placeholder="Click 'Generate Customer Insights' to see detailed customer demographics and age group analysis..."
        )

        customer_insights_btn.click(fn=generate_customer_insights, outputs=customer_insights_output)

    with gr.Tab("Tickets Related Distribution"):
        gr.Markdown("## Enhanced Ticekts Data Visualizations")
        gr.Markdown("Comprehensive visual analysis of your customer support data including ticket types and trends")

        gr.Markdown("""
        ### Available Distributions in this Page:
        - **Ticket Priority Level Distribution**
        - **Ticket Channel Distribution**
        - **Ticket Status Distribution**
        - **Ticket Type Distribution**
        - **Ticket Trends Over Time**
        - **Enhanced Data Insights**
        """)

        # Original visualizations
        with gr.Row():
            priority_viz_btn = gr.Button("Priority Level", variant="primary")
            channel_viz_btn = gr.Button("Channel Distribution", variant="primary")
            status_viz_btn = gr.Button("Status Distribution", variant="primary")

        with gr.Row():
            priority_plot = gr.Plot(label="Ticket Priority Level Distribution")

        with gr.Row():
            with gr.Column():
                channel_plot = gr.Plot(label="Ticket Channel Distribution")
            with gr.Column():
                status_plot = gr.Plot(label="Ticket Status Distribution")

        priority_viz_btn.click(fn=create_priority_viz, outputs=priority_plot)
        channel_viz_btn.click(fn=create_channel_viz, outputs=channel_plot)
        status_viz_btn.click(fn=create_status_viz, outputs=status_plot)

        # Added ticket-related visualizations
        gr.Markdown("## Ticket Type & Trends Analysis")

        # Ticket Type Analysis Section
        gr.Markdown("### Ticket Type Distribution")
        with gr.Row():
            ticket_type_btn = gr.Button("Analyze Ticket Types", variant="primary", size="lg")

        with gr.Row():
            ticket_type_plot = gr.Plot(label="Ticket Type Distribution with Satisfaction")

        ticket_type_analysis = gr.Textbox(
            label="Ticket Type Analysis Results",
            lines=15,
            placeholder="Click 'Analyze Ticket Types' to see distribution and satisfaction by ticket type..."
        )

        ticket_type_btn.click(
            fn=lambda: [create_ticket_type_viz(), analyze_ticket_types_interface()],
            outputs=[ticket_type_plot, ticket_type_analysis]
        )

        # Ticket Trends Analysis Section
        gr.Markdown("### Customer Support Ticket Trends Over Time")
        with gr.Row():
            ticket_trends_btn = gr.Button("Analyze Ticket Trends", variant="primary", size="lg")

        with gr.Row():
            ticket_trends_plot = gr.Plot(label="Ticket Trends Over Time Dashboard")

        ticket_trends_analysis = gr.Textbox(
            label="Ticket Trends Analysis Results",
            lines=20,
            placeholder="Click 'Analyze Ticket Trends' to see temporal trends and patterns..."
        )

        ticket_trends_btn.click(
            fn=lambda: [create_ticket_trends_viz(), analyze_ticket_trends_interface()],
            outputs=[ticket_trends_plot, ticket_trends_analysis]
        )

        gr.Markdown("## Data Insights")
        insights_btn = gr.Button("Generate Enhanced Insights", variant="secondary", size="lg")
        insights_output = gr.Textbox(
            label="Enhanced Data Insights",
            lines=15,
            placeholder="Click 'Generate Enhanced Insights' to see detailed statistics and analysis..."
        )

        insights_btn.click(fn=generate_insights, outputs=insights_output)

    with gr.Tab("Satisfaction Distribution"):
        gr.Markdown("## Customer Satisfaction Analysis")
        gr.Markdown("Detailed analysis of customer satisfaction patterns and distribution across different demographics.")

        gr.Markdown("""
        ### Available Distributions in this Page:
        - **Customer Satisfaction by Gender**
        - **Overall Satisfaction Distribution**
        - **Satisfaction Analytics & Insights**
        """)

        with gr.Row():
            satisfaction_gender_btn = gr.Button("Satisfaction by Gender", variant="primary")
            satisfaction_dist_btn = gr.Button("Satisfaction Distribution", variant="primary")

        with gr.Row():
            with gr.Column():
                satisfaction_gender_plot = gr.Plot(label="Customer Satisfaction by Gender")
            with gr.Column():
                satisfaction_distribution_plot = gr.Plot(label="Overall Satisfaction Distribution")

        satisfaction_gender_btn.click(fn=create_satisfaction_by_gender_viz, outputs=satisfaction_gender_plot)
        satisfaction_dist_btn.click(fn=create_satisfaction_distribution_viz, outputs=satisfaction_distribution_plot)

        gr.Markdown("## Satisfaction Insights & Analytics")
        satisfaction_insights_btn = gr.Button("Generate Satisfaction Insights", variant="secondary", size="lg")
        satisfaction_insights_output = gr.Textbox(
            label="Customer Satisfaction Analytics",
            lines=20,
            placeholder="Click 'Generate Satisfaction Insights' to see detailed satisfaction analysis and statistics..."
        )

        satisfaction_insights_btn.click(fn=generate_satisfaction_insights, outputs=satisfaction_insights_output)

    with gr.Tab("Top Product Distribution"):
        gr.Markdown("## Top 10 Most Purchased Products Analysis")
        gr.Markdown("Analyze the most popular products overall and by customer gender with detailed purchase statistics and satisfaction ratings.")

        gr.Markdown("""
        ### Available Distributions in this Page:
        - **Top Products - All Customers**
        - **Top Products - Female Customers**
        - **Top Products - Male Customers**
        - **Top Products - Other Gender**
        - **Detailed Product Analysis**
        """)

        gr.Markdown("### Filter by Customer Gender")
        with gr.Row():
            all_btn = gr.Button("All Customers", variant="primary")
            female_btn = gr.Button("Female", variant="secondary")
            male_btn = gr.Button("Male", variant="secondary")
            other_btn = gr.Button("Other", variant="secondary")

        with gr.Row():
            products_plot = gr.Plot(label="Top 10 Products by Purchase Volume")

        gr.Markdown("### Detailed Product Analysis")
        products_analysis = gr.Textbox(
            label="Top Products Analysis Results",
            lines=25,
            placeholder="Click on a gender filter button to see the top 10 most purchased products..."
        )

        all_btn.click(
            fn=lambda: [create_top_products_viz('all'), analyze_top_products_interface('all')],
            outputs=[products_plot, products_analysis]
        )
        female_btn.click(
            fn=lambda: [create_top_products_viz('Female'), analyze_top_products_interface('Female')],
            outputs=[products_plot, products_analysis]
        )
        male_btn.click(
            fn=lambda: [create_top_products_viz('Male'), analyze_top_products_interface('Male')],
            outputs=[products_plot, products_analysis]
        )
        other_btn.click(
            fn=lambda: [create_top_products_viz('Other'), analyze_top_products_interface('Other')],
            outputs=[products_plot, products_analysis]
        )

    with gr.Tab(label="Model Training"):
        # Tab heading plus a one-line description of what the button does.
        gr.Markdown("## Train Customer Satisfaction Prediction Model")
        gr.Markdown("Train a Random Forest model on your uploaded dataset.")

        # The click handler takes no UI inputs; whatever it returns is shown
        # in the read-out textbox below.
        train_btn = gr.Button(
            "Train Model",
            size="lg",
            variant="primary",
        )
        training_output = gr.Textbox(
            placeholder="Upload data first, then click 'Train Model' to see results...",
            lines=20,
            label="Training Results",
        )

        train_btn.click(train_model_interface, outputs=training_output)

    with gr.Tab("Make Predictions"):
        # Single-ticket prediction form: one numeric field plus seven dropdowns.
        # The components are passed positionally to predict_single in the order
        # listed in `inputs=` below, so that order must match its signature.
        gr.Markdown("## Predict Customer Satisfaction")
        gr.Markdown("Enter the customer's age and select the ticket details from the dropdown menus to predict satisfaction rating.")

        # NOTE(review): each dropdown default is assumed to be present in its
        # *_OPTIONS list (defined elsewhere in the notebook) — confirm.
        with gr.Row():
            with gr.Column():
                # precision=0 makes Gradio round the value and deliver it as an
                # int; without it the age arrives as a float (e.g. 35.0) and
                # fractional ages such as 35.5 are accepted.
                age_input = gr.Number(
                    label="Customer Age",
                    value=35,
                    minimum=18,
                    maximum=100,
                    precision=0
                )
                gender_input = gr.Dropdown(
                    label="Gender",
                    choices=GENDER_OPTIONS,
                    value="Male"
                )
                product_input = gr.Dropdown(
                    label="Product Purchased",
                    choices=PRODUCT_OPTIONS,
                    value="iPhone 13"
                )
                ticket_type_input = gr.Dropdown(
                    label="Ticket Type",
                    choices=TICKET_TYPE_OPTIONS,
                    value="Technical issue"
                )

            with gr.Column():
                subject_input = gr.Dropdown(
                    label="Ticket Subject",
                    choices=TICKET_SUBJECT_OPTIONS,
                    value="Product setup"
                )
                status_input = gr.Dropdown(
                    label="Ticket Status",
                    choices=TICKET_STATUS_OPTIONS,
                    value="Open"
                )
                priority_input = gr.Dropdown(
                    label="Priority",
                    choices=PRIORITY_OPTIONS,
                    value="Medium"
                )
                channel_input = gr.Dropdown(
                    label="Channel",
                    choices=CHANNEL_OPTIONS,
                    value="Email"
                )

        predict_btn = gr.Button("Predict Satisfaction", variant="primary", size="lg")
        prediction_output = gr.Textbox(
            label="Prediction Result",
            lines=10,
            placeholder="Select the details from dropdowns and click 'Predict Satisfaction'..."
        )

        predict_btn.click(
            fn=predict_single,
            inputs=[age_input, gender_input, product_input, ticket_type_input,
                    subject_input, status_input, priority_input, channel_input],
            outputs=prediction_output
        )

# Launch the interface.
# In a notebook kernel __name__ is "__main__", so this runs whenever the cell
# executes. share=True also publishes a temporary public *.gradio.live link;
# debug=True keeps the cell blocking so server errors and logs remain visible
# until the kernel is interrupted.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://127082f41f0cde8466.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://127082f41f0cde8466.gradio.live
