In [None]:
# Visualization Development - Market Research System v1.0
# File: notebooks/development/visualization_dev.ipynb
# Created: February 2022
# Purpose: Development of comprehensive visualization suite for Indian market analysis

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import yfinance as yf
from datetime import datetime, timedelta
import warnings
warnings.filterwarnings('ignore')

# Configure matplotlib and seaborn.
# The seaborn look is published as 'seaborn-v0_8' only from matplotlib 3.6
# onward; older releases call it 'seaborn', and newer ones no longer ship
# that old name. Use whichever this install provides so the import does not
# fail on either side of the rename (falls back to the default style).
for _style_name in ('seaborn-v0_8', 'seaborn'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

class MarketVisualizer:
    """
    Comprehensive visualization suite for Indian stock market analysis
    Developed for Market Research System v1.0 (2022)

    Plotly-based methods (candlestick_chart, portfolio_performance,
    economic_indicators_dashboard) return a ``plotly.graph_objects.Figure``;
    matplotlib/seaborn-based methods (sector_heatmap, correlation_matrix,
    technical_signals_plot) return a ``matplotlib.figure.Figure``.
    """

    # Default absolute-correlation cutoff used by correlation_matrix().
    HIGH_CORRELATION_THRESHOLD = 0.7

    def __init__(self):
        # Named hex colour palette exposed to callers. The plotting methods
        # below currently use literal colour names rather than this dict.
        self.colors = {
            'primary': '#1f77b4',
            'secondary': '#ff7f0e',
            'success': '#2ca02c',
            'danger': '#d62728',
            'warning': '#ff7f0e',
            'info': '#17a2b8',
            'neutral': '#6c757d'
        }

    def candlestick_chart(self, data, title="Stock Price Analysis", indicators=None):
        """
        Create an interactive three-pane candlestick chart (plotly).

        Pane 1 shows OHLC candles plus optional 20/50 SMAs, pane 2 shows
        volume bars coloured by candle direction, pane 3 shows RSI with
        70 (overbought) / 30 (oversold) reference lines when RSI is supplied.

        Args:
            data: DataFrame indexed by date with 'Open', 'High', 'Low',
                'Close' and 'Volume' columns.
            title: Figure title, also used as the price-pane subtitle.
            indicators: Optional mapping (dict or DataFrame); the keys
                'SMA_20', 'SMA_50' and 'RSI' are plotted when present, each
                expected to align with ``data.index``.

        Returns:
            plotly.graph_objects.Figure
        """
        if indicators is None:
            indicators = {}

        fig = make_subplots(
            rows=3, cols=1,
            shared_xaxes=True,
            vertical_spacing=0.03,
            subplot_titles=(title, 'Volume', 'Technical Indicators'),
            # row_heights is listed top-to-bottom (the legacy row_width
            # argument is bottom-to-top); the price pane gets the most room.
            row_heights=[0.6, 0.2, 0.2],
        )

        # Candlestick chart
        fig.add_trace(
            go.Candlestick(
                x=data.index,
                open=data['Open'],
                high=data['High'],
                low=data['Low'],
                close=data['Close'],
                name="Price",
                increasing_line_color='green',
                decreasing_line_color='red'
            ),
            row=1, col=1
        )

        # Moving-average overlays on the price pane, when supplied.
        for key, label, line_color in (('SMA_20', 'SMA 20', 'orange'),
                                       ('SMA_50', 'SMA 50', 'blue')):
            if key in indicators:
                fig.add_trace(
                    go.Scatter(
                        x=data.index,
                        y=indicators[key],
                        mode='lines',
                        name=label,
                        line=dict(color=line_color, width=2)
                    ),
                    row=1, col=1
                )

        # Volume bars: green when the candle closed at/above its open.
        volume_colors = ['green' if close_px >= open_px else 'red'
                         for close_px, open_px in zip(data['Close'], data['Open'])]

        fig.add_trace(
            go.Bar(
                x=data.index,
                y=data['Volume'],
                name='Volume',
                marker_color=volume_colors,
                opacity=0.7
            ),
            row=2, col=1
        )

        # RSI if provided
        if 'RSI' in indicators:
            fig.add_trace(
                go.Scatter(
                    x=data.index,
                    y=indicators['RSI'],
                    mode='lines',
                    name='RSI',
                    line=dict(color='purple', width=2)
                ),
                row=3, col=1
            )

            # RSI levels
            fig.add_hline(y=70, line_dash="dash", line_color="red",
                          annotation_text="Overbought", row=3, col=1)
            fig.add_hline(y=30, line_dash="dash", line_color="green",
                          annotation_text="Oversold", row=3, col=1)
            # RSI is bounded to [0, 100]; a fixed range keeps both bands visible.
            fig.update_yaxes(range=[0, 100], row=3, col=1)

        fig.update_layout(
            title=title,
            yaxis_title="Price (₹)",
            template="plotly_white",
            height=800,
            showlegend=True,
            # A candlestick trace adds a range slider under its own x-axis by
            # default, which would be drawn between the price and volume panes.
            xaxis_rangeslider_visible=False,
        )
        # Label only the bottom (shared) x-axis so "Date" is not drawn
        # between the price and volume panes.
        fig.update_xaxes(title_text="Date", row=3, col=1)

        return fig

    def sector_heatmap(self, sector_data):
        """
        Create a sector-by-metric performance heatmap (matplotlib/seaborn).

        Args:
            sector_data: Long-format DataFrame with 'Sector', 'Metric' and
                'Value' columns. Each (Sector, Metric) pair must be unique,
                otherwise ``DataFrame.pivot`` raises ValueError.

        Returns:
            matplotlib.figure.Figure holding the heatmap.
        """
        fig = plt.figure(figsize=(14, 10))

        # Reshape to Sector rows x Metric columns for the heatmap.
        pivot_data = sector_data.pivot(index='Sector', columns='Metric', values='Value')

        sns.heatmap(
            pivot_data,
            annot=True,
            fmt='.2f',
            cmap='RdYlGn',
            center=0,
            square=True,
            cbar_kws={"shrink": .8}
        )

        plt.title('Sector Performance Heatmap - Indian Market',
                  fontsize=16, fontweight='bold', pad=20)
        plt.xlabel('Performance Metrics', fontsize=12)
        plt.ylabel('Sectors', fontsize=12)
        # Right-align rotated labels so each ends under its own column.
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()

        return fig

    @staticmethod
    def _high_corr_pairs(correlation_data, stocks, threshold):
        """Return (stock_i, stock_j, corr) for i < j where |corr| > threshold."""
        pairs = []
        for i, first in enumerate(stocks):
            for j in range(i + 1, len(stocks)):
                corr_val = float(correlation_data.iloc[i, j])
                # NaN fails the comparison, so undefined correlations are skipped.
                if abs(corr_val) > threshold:
                    pairs.append((first, stocks[j], corr_val))
        return pairs

    def correlation_matrix(self, correlation_data, stocks, threshold=None):
        """
        Plot a correlation heatmap next to a bar chart of strongly correlated pairs.

        Args:
            correlation_data: Square correlation DataFrame whose row/column
                order matches ``stocks`` (e.g. ``returns.corr()``).
            stocks: Labels for the rows/columns of ``correlation_data``.
            threshold: Absolute-correlation cutoff for the right-hand chart;
                defaults to HIGH_CORRELATION_THRESHOLD (0.7). Both strongly
                positive and strongly negative pairs are shown.

        Returns:
            matplotlib.figure.Figure with two axes.
        """
        if threshold is None:
            threshold = self.HIGH_CORRELATION_THRESHOLD

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

        # Hide the upper triangle (including the diagonal of 1.0s): it
        # mirrors the lower triangle.
        mask = np.triu(np.ones(correlation_data.shape, dtype=bool))
        sns.heatmap(
            correlation_data,
            mask=mask,
            annot=True,
            fmt='.2f',
            cmap='coolwarm',
            center=0,
            square=True,
            ax=ax1,
            cbar_kws={"shrink": .8}
        )
        ax1.set_title('Stock Correlation Matrix', fontsize=14, fontweight='bold')

        high_corr_pairs = self._high_corr_pairs(correlation_data, stocks, threshold)
        ax2.set_title(f'High Correlation Pairs (|r| > {threshold:g})',
                      fontsize=14, fontweight='bold')

        if high_corr_pairs:
            pairs_labels = [f"{a}-{b}" for a, b, _ in high_corr_pairs]
            correlations = [corr for _, _, corr in high_corr_pairs]

            bars = ax2.bar(range(len(pairs_labels)), correlations,
                           color=['green' if x > 0 else 'red' for x in correlations])
            ax2.set_xticks(range(len(pairs_labels)))
            ax2.set_xticklabels(pairs_labels, rotation=45, ha='right')
            ax2.axhline(0, color='black', linewidth=0.8)
            ax2.set_ylabel('Correlation Coefficient')
            ax2.grid(True, alpha=0.3)

            # Value labels just beyond each bar's end: above positive bars,
            # below negative ones (so they are not drawn inside the bar).
            for bar, val in zip(bars, correlations):
                offset = 0.01 if val >= 0 else -0.01
                ax2.text(bar.get_x() + bar.get_width() / 2., val + offset,
                         f'{val:.2f}', ha='center',
                         va='bottom' if val >= 0 else 'top')
        else:
            # Say explicitly that nothing qualified instead of leaving a blank axes.
            ax2.text(0.5, 0.5, f'No pairs with |r| > {threshold:g}',
                     ha='center', va='center', transform=ax2.transAxes)
            ax2.set_axis_off()

        plt.tight_layout()
        return fig

    def portfolio_performance(self, portfolio_data):
        """
        Create a 2x2 portfolio performance dashboard (plotly).

        Args:
            portfolio_data: DataFrame indexed by date with
                'cumulative_returns' and 'drawdown' columns.

        Returns:
            plotly.graph_objects.Figure. Only the top row is populated: the
            'Monthly Returns Heatmap' and 'Risk-Return Scatter' panels are
            laid out, but no traces are added to them yet.
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Cumulative Returns', 'Drawdown',
                            'Monthly Returns Heatmap', 'Risk-Return Scatter'),
            specs=[[{"type": "scatter"}, {"type": "scatter"}],
                   [{"type": "heatmap"}, {"type": "scatter"}]]
        )

        # Cumulative returns
        fig.add_trace(
            go.Scatter(
                x=portfolio_data.index,
                y=portfolio_data['cumulative_returns'],
                mode='lines',
                name='Portfolio',
                line=dict(color='blue', width=3)
            ),
            row=1, col=1
        )

        # Drawdown, shaded down to zero.
        fig.add_trace(
            go.Scatter(
                x=portfolio_data.index,
                y=portfolio_data['drawdown'],
                fill='tozeroy',
                mode='lines',
                name='Drawdown',
                line=dict(color='red'),
                fillcolor='rgba(255,0,0,0.3)'
            ),
            row=1, col=2
        )

        # TODO: populate the monthly-returns heatmap (2,1) and risk-return
        # scatter (2,2) panels once their input data is defined.

        fig.update_layout(
            title="Portfolio Performance Dashboard",
            height=800,
            showlegend=True
        )

        return fig

    def technical_signals_plot(self, data, signals):
        """
        Plot the closing price with buy/sell signal markers (matplotlib).

        Args:
            data: DataFrame indexed by date with a 'Close' column.
            signals: DataFrame indexed by date with an 'Action' column
                ('BUY' / 'SELL') and a 'Price' column giving marker heights.

        Returns:
            matplotlib.figure.Figure
        """
        fig, ax = plt.subplots(figsize=(15, 8))

        # Plot price
        ax.plot(data.index, data['Close'], label='Close Price', linewidth=2, color='black')

        # Buy signals: upward green triangles drawn above the price line.
        buy_signals = signals[signals['Action'] == 'BUY']
        if not buy_signals.empty:
            ax.scatter(buy_signals.index, buy_signals['Price'],
                       color='green', marker='^', s=100, label='Buy Signal', zorder=5)

        # Sell signals: downward red triangles.
        sell_signals = signals[signals['Action'] == 'SELL']
        if not sell_signals.empty:
            ax.scatter(sell_signals.index, sell_signals['Price'],
                       color='red', marker='v', s=100, label='Sell Signal', zorder=5)

        ax.set_title('Trading Signals Analysis', fontsize=16, fontweight='bold')
        ax.set_xlabel('Date')
        ax.set_ylabel('Price (₹)')
        ax.legend()
        ax.grid(True, alpha=0.3)

        return fig

    def economic_indicators_dashboard(self, indicators_data):
        """
        Create a 2x2 economic indicators dashboard (plotly).

        NOTE: this is still a template. ``indicators_data`` is accepted for
        interface stability but is not read. The GDP-growth and inflation
        panels plot random monthly values for 2022 drawn from NumPy's global
        random state, and the 'Interest Rates' and 'Market Indices' panels
        are left empty. The title is marked "(Sample Data)" so the random
        series are not mistaken for real figures.

        Returns:
            plotly.graph_objects.Figure
        """
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('GDP Growth', 'Inflation Rate',
                            'Interest Rates', 'Market Indices'),
            specs=[[{"secondary_y": True}, {"secondary_y": True}],
                   [{"secondary_y": True}, {"secondary_y": True}]]
        )

        # Month-end dates across 2022.
        dates = pd.date_range(start='2022-01-01', end='2022-12-31', freq='M')

        # GDP Growth (sample data)
        gdp_growth = np.random.normal(6.5, 1.5, len(dates))
        fig.add_trace(
            go.Scatter(x=dates, y=gdp_growth, mode='lines+markers',
                       name='GDP Growth %', line=dict(color='green')),
            row=1, col=1
        )

        # Inflation Rate (sample data)
        inflation = np.random.normal(5.8, 1.2, len(dates))
        fig.add_trace(
            go.Scatter(x=dates, y=inflation, mode='lines+markers',
                       name='Inflation %', line=dict(color='red')),
            row=1, col=2
        )

        fig.update_layout(
            title="Economic Indicators Dashboard - India (Sample Data)",
            height=600,
            showlegend=True
        )

        return fig


def create_sample_data(stocks=None, period="1y"):
    """
    Download daily price history for a basket of NSE-listed stocks.

    Despite the name, this fetches real market data from Yahoo Finance via
    ``yfinance`` (network access required). Failures are best-effort:
    a ticker that raises or returns no rows is reported and skipped.

    Args:
        stocks: Iterable of Yahoo Finance tickers. Defaults to
            RELIANCE.NS, TCS.NS, INFY.NS, HDFCBANK.NS and ICICIBANK.NS.
        period: History window passed to ``Ticker.history`` (default "1y").

    Returns:
        dict mapping ticker -> DataFrame returned by ``Ticker.history``,
        containing only the tickers that returned data.
    """
    if stocks is None:
        stocks = ['RELIANCE.NS', 'TCS.NS', 'INFY.NS', 'HDFCBANK.NS', 'ICICIBANK.NS']

    stock_data = {}
    for stock in stocks:
        try:
            data = yf.Ticker(stock).history(period=period)
        except Exception as e:
            # Network/API failures for one ticker should not abort the batch.
            print(f"✗ Error fetching {stock}: {e}")
            continue

        if data.empty:
            # yfinance can return an empty frame (e.g. unknown/delisted ticker)
            # without raising; report it instead of skipping silently.
            print(f"✗ No data returned for {stock}")
            continue

        stock_data[stock] = data
        print(f"✓ Fetched data for {stock}")

    return stock_data


def create_sector_sample_data(seed=None):
    """
    Build a synthetic long-format sector performance table.

    Each of 8 sectors gets one row per metric: four return horizons drawn
    from N(8, 15) and a volatility drawn from N(25, 10) clipped at 0, since
    a negative volatility is meaningless. The result has one row per
    (Sector, Metric) pair, so it can be pivoted into a Sector x Metric grid.

    Args:
        seed: Optional int seed for reproducible output. When None (the
            default), values are drawn from NumPy's global random state.

    Returns:
        DataFrame with columns 'Sector', 'Metric', 'Value' (40 rows).
    """
    # Both the legacy np.random module and a Generator expose normal(loc, scale).
    rng = np.random if seed is None else np.random.default_rng(seed)

    sectors = ['Technology', 'Banking', 'Pharmaceuticals', 'Energy',
               'Automotive', 'FMCG', 'Metals', 'Telecom']
    metrics = ['1M Return %', '3M Return %', '6M Return %', '1Y Return %', 'Volatility %']

    data = []
    for sector in sectors:
        for metric in metrics:
            if 'Return' in metric:
                value = rng.normal(8, 15)  # Random returns
            else:  # Volatility: clip so the sample never shows a negative value
                value = max(rng.normal(25, 10), 0.0)

            data.append({
                'Sector': sector,
                'Metric': metric,
                'Value': value
            })

    return pd.DataFrame(data)


def test_visualizations():
    """
    Smoke-test the visualization suite end to end.

    Downloads price history via create_sample_data() (network access through
    yfinance), builds synthetic sector data, renders the candlestick,
    correlation, sector and economic charts, and writes them under
    ``reports/``. The directory is created first because ``write_html`` and
    ``savefig`` do not create missing parent directories. Matplotlib figures
    are closed after saving to release their memory.
    """
    import os  # only needed to create the output directory

    print("=== Visualization Development - Market Research System v1.0 ===")
    print("Testing visualization suite...")
    print("=" * 70)

    os.makedirs("reports", exist_ok=True)

    visualizer = MarketVisualizer()

    # Get sample data
    stock_data = create_sample_data()
    sector_data = create_sector_sample_data()

    if stock_data:
        # Candlestick chart for the first ticker that returned data.
        first_stock = next(iter(stock_data))
        data = stock_data[first_stock]

        print(f"Creating candlestick chart for {first_stock}...")

        # Calculate simple indicators for testing
        indicators = {
            'SMA_20': data['Close'].rolling(20).mean(),
            'SMA_50': data['Close'].rolling(50).mean(),
            'RSI': calculate_rsi(data['Close'])
        }

        candlestick_fig = visualizer.candlestick_chart(data,
                                                       f"{first_stock} Technical Analysis",
                                                       indicators)
        candlestick_fig.write_html("reports/candlestick_analysis.html")
        print("✓ Candlestick chart created and saved")

        # Test correlation matrix
        if len(stock_data) >= 3:
            print("Creating correlation matrix...")
            # Build all return columns in one constructor call so the index is
            # the union of every ticker's dates; assigning columns one by one
            # onto an empty frame aligns everything to the first ticker's
            # dates and silently drops the others' extra rows.
            returns_data = pd.DataFrame({
                stock: frame['Close'].pct_change().dropna()
                for stock, frame in list(stock_data.items())[:5]  # Limit to 5 stocks
            })
            stock_names = list(returns_data.columns)

            correlation_matrix = returns_data.corr()
            corr_fig = visualizer.correlation_matrix(correlation_matrix, stock_names)
            # Save the returned figure explicitly rather than whatever
            # pyplot currently considers "current".
            corr_fig.savefig("reports/correlation_matrix.png", dpi=300, bbox_inches='tight')
            plt.close(corr_fig)
            print("✓ Correlation matrix created and saved")

    # Test sector heatmap
    print("Creating sector heatmap...")
    sector_fig = visualizer.sector_heatmap(sector_data)
    sector_fig.savefig("reports/sector_heatmap.png", dpi=300, bbox_inches='tight')
    plt.close(sector_fig)
    print("✓ Sector heatmap created and saved")

    # Test economic indicators dashboard
    print("Creating economic indicators dashboard...")
    econ_fig = visualizer.economic_indicators_dashboard({})
    econ_fig.write_html("reports/economic_dashboard.html")
    print("✓ Economic indicators dashboard created and saved")

    print("\n=== Visualization Development Complete ===")
    print("All visualization functions tested successfully!")
    print("Charts saved in reports/ directory")


def calculate_rsi(prices, period=14):
    """
    Return the Relative Strength Index of a price Series.

    Average gain and average loss are plain rolling means over ``period``
    observations (not Wilder's exponential smoothing). The missing first
    difference counts as a zero move, so, assuming no missing prices, the
    first defined value sits at position ``period - 1``. A window with gains
    but no losses yields 100; a window with neither yields NaN.
    """
    change = prices.diff()

    # Non-negative magnitudes of up and down moves; the NaN from diff() fails
    # both comparisons and therefore becomes 0 in each series.
    gains = change.where(change > 0, 0)
    losses = change.where(change < 0, 0).abs()

    avg_gain = gains.rolling(window=period).mean()
    avg_loss = losses.rolling(window=period).mean()

    relative_strength = avg_gain / avg_loss
    return 100 - 100 / (1 + relative_strength)


# Script entry point: run the visualization smoke tests only when this file
# is executed directly, never on import.
if __name__ == '__main__':
    test_visualizations()
    