In [None]:
"""
Foreign Trade Analysis - Kenya's International Trade Performance
===============================================================

Comprehensive analysis of Kenya's foreign trade patterns, including:
- Export and import trends
- Trade balance analysis  
- Trading partner analysis
- Product/commodity analysis
- Trade policy impact assessment
"""

import warnings
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import seaborn as sns
from plotly.subplots import make_subplots
# Silence warnings so notebook output stays readable.
# NOTE(review): this hides *all* warnings (including pandas/NumPy deprecation
# notices) for the whole session; consider filtering specific categories.
warnings.filterwarnings('ignore')

# Set style for better visualizations.  Matplotlib 3.6 renamed its bundled
# seaborn style to 'seaborn-v0_8' (the old 'seaborn' name was later removed),
# so use whichever name this installation provides instead of failing with
# OSError; if neither exists, Matplotlib's default style is kept.
for _style_name in ('seaborn-v0_8', 'seaborn'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break
sns.set_palette("husl")

class ForeignTradeAnalyzer:
    """Advanced Foreign Trade Analysis Engine.

    Loads Kenya foreign-trade tables from a data directory, cleans them, and
    derives trade-balance, trading-partner and export-commodity summaries,
    Plotly figures, and plain-text insights.

    Datasets are stored as DataFrame attributes (``trade_summary``,
    ``exports_africa``, ``exports_row``, ``exports_domestic``,
    ``imports_africa``, ``imports_row``, ``export_volumes``) only when they
    load successfully; every analysis method skips datasets that are absent.
    """

    # (attribute name, display label, file name) for each dataset that
    # load_trade_data reads.  File names must match the files on disk exactly.
    _TRADE_FILES = (
        ('trade_summary', "Foreign Trade Summary",
         "Foreign Trade Summary (Ksh Million).csv"),
        ('exports_africa', "Exports to African Countries",
         "Value of Exports to Selected African Countries (Ksh Million).csv"),
        ('exports_row', "Exports to Rest of World Countries",
         "Value of Exports to Selected Rest of World Countries (Ksh Million).csv"),
        ('exports_domestic', "Domestic Exports",
         "Value of Selected Domestic Exports (Ksh Million).csv"),
        ('imports_africa', "Imports from African Countries",
         "Value of Direct Imports from Selected African Countries (Ksh. Million).xlsx"),
        ('imports_row', "Imports from Rest of World Countries",
         "Value of Direct Imports from Selected Rest of World Countries  (Kshs. Millions).csv"),
        ('export_volumes', "Export Volumes and Prices",
         "Principal Exports Volume, Value and Unit Prices (Ksh Million).csv"),
    )

    # Lower-case column-name fragments identifying time/period columns, which
    # are row labels rather than trade values.
    _TIME_MARKERS = ('period', 'year', 'month', 'date')

    # Lower-case column-name fragments identifying derived columns (balances,
    # ratios, growth rates, shares) that must never be summed into export or
    # import totals.
    _NON_FLOW_MARKERS = ('balance', '%', 'ratio', 'share', 'growth', 'cover')

    def __init__(self):
        # Reserved containers; no method of this class reads or fills them.
        self.trade_data = {}
        self.export_data = {}
        self.import_data = {}

    # ------------------------------------------------------------------
    # Loading and cleaning
    # ------------------------------------------------------------------

    def load_trade_data(self, data_path="data/raw/"):
        """Load all trade-related datasets from ``data_path``.

        Every file is loaded independently, so one missing or unreadable file
        does not prevent the others from loading.  CSV files skip their two
        title rows and parse ``1,234``-style thousands separators; ``.xlsx``
        workbooks are read with ``pd.read_excel``.

        Args:
            data_path: Directory containing the raw files (trailing slash
                optional).

        Returns:
            self, so calls can be chained with ``clean_trade_data()``.
        """
        base = Path(data_path)
        for attr, label, filename in self._TRADE_FILES:
            frame, error = self._read_table(base / filename)
            if error is None:
                setattr(self, attr, frame)
                print(f"✅ Loaded {label}")
            else:
                print(f"❌ Could not load {label}: {error}")
        return self

    @staticmethod
    def _read_table(path):
        """Read one CSV/Excel file; return ``(frame, None)`` or ``(None, error)``."""
        try:
            if path.suffix.lower() in ('.xlsx', '.xls'):
                # read_csv cannot parse Excel workbooks, so dispatch on suffix.
                # NOTE(review): the CSV siblings skip two title rows; confirm
                # whether this workbook needs ``skiprows=2`` as well.
                return pd.read_excel(path), None
            return pd.read_csv(path, skiprows=2, thousands=','), None
        except Exception as err:  # best-effort loader: report and continue
            # Covers missing files, malformed content, corrupt workbooks and a
            # missing Excel engine; the error text is printed by the caller.
            return None, err

    def clean_trade_data(self, min_numeric_fraction=0.5):
        """Clean and standardize every loaded dataset.

        For each dataset: drop rows that are entirely empty, then convert text
        columns that hold numbers (thousands separators removed) to numeric.

        Args:
            min_numeric_fraction: Minimum share of a text column's non-missing
                values that must parse as numbers for the column to be
                converted.  Label columns (e.g. month names) therefore stay
                untouched instead of being coerced to NaN.

        Returns:
            self, to allow chaining.
        """
        for attr, _label, _filename in self._TRADE_FILES:
            frame = getattr(self, attr, None)
            if frame is None:  # dataset was not loaded
                continue
            setattr(self, attr, self._clean_frame(frame, min_numeric_fraction))
        return self

    @staticmethod
    def _clean_frame(df, min_numeric_fraction):
        """Return a cleaned copy of ``df`` (see ``clean_trade_data``)."""
        df = df.dropna(how='all').copy()
        for col in df.select_dtypes(include=[object]).columns:
            raw = df[col]
            converted = pd.to_numeric(
                raw.astype(str).str.replace(',', '', regex=False).str.strip(),
                errors='coerce',
            )
            present = raw.notna().sum()
            if present and converted.notna().sum() >= min_numeric_fraction * present:
                df[col] = converted
        return df

    # ------------------------------------------------------------------
    # Column selection helpers
    # ------------------------------------------------------------------

    def _value_columns(self, df):
        """Return numeric, non-boolean columns that are not time labels or totals."""
        columns = []
        for col in df.columns:
            name = str(col).lower()
            if 'total' in name or any(marker in name for marker in self._TIME_MARKERS):
                continue
            series = df[col]
            if pd.api.types.is_numeric_dtype(series) and not pd.api.types.is_bool_dtype(series):
                columns.append(col)
        return columns

    def _flow_columns(self, df, keyword, exclude):
        """Return numeric columns measuring one trade flow, preferring a total.

        A column qualifies when its lower-cased name contains ``keyword`` but
        neither ``exclude`` (drops "exports/imports"-style ratio columns) nor
        any ``_NON_FLOW_MARKERS`` entry.  If a qualifying column mentions
        "total", only the first such column is returned so a total is never
        summed together with its own components; otherwise every qualifying
        component column is returned.
        """
        candidates = []
        for col in df.columns:
            name = str(col).lower()
            if keyword not in name or exclude in name:
                continue
            if any(marker in name for marker in self._NON_FLOW_MARKERS):
                continue
            if pd.api.types.is_numeric_dtype(df[col]):
                candidates.append(col)
        totals = [col for col in candidates if 'total' in str(col).lower()]
        return totals[:1] if totals else candidates

    # ------------------------------------------------------------------
    # Analyses
    # ------------------------------------------------------------------

    def analyze_trade_balance(self):
        """Analyze Kenya's trade balance trends.

        Returns:
            A copy of ``trade_summary``.  When export and import columns can be
            identified it gains ``Total_Exports``, ``Total_Imports``,
            ``Trade_Balance`` (exports - imports) and ``Trade_Balance_Ratio``
            (exports / imports, NaN where imports are 0).  Re-exports count as
            exports.  Returns None when ``trade_summary`` was not loaded.
        """
        if not hasattr(self, 'trade_summary'):
            print("Trade summary data not available")
            return None

        df = self.trade_summary.copy()
        export_cols = self._flow_columns(df, 'export', 'import')
        import_cols = self._flow_columns(df, 'import', 'export')

        if export_cols and import_cols:
            # min_count=1 keeps rows with no data as NaN instead of 0.
            df['Total_Exports'] = df[export_cols].sum(axis=1, min_count=1)
            df['Total_Imports'] = df[import_cols].sum(axis=1, min_count=1)
            df['Trade_Balance'] = df['Total_Exports'] - df['Total_Imports']
            df['Trade_Balance_Ratio'] = df['Total_Exports'] / df['Total_Imports'].replace(0, np.nan)

        return df

    def analyze_trading_partners(self, top_n=10):
        """Analyze Kenya's key trading partners.

        Each partner table is summed column-wise (time and total columns are
        excluded) and ranked descending.

        Args:
            top_n: Number of partners to keep per ranking.

        Returns:
            Dict mapping ranking name (e.g. ``'top_africa_export_destinations'``)
            to a Series of summed values indexed by column name, one entry per
            loaded partner table.
        """
        sources = (
            ('exports_africa', 'top_africa_export_destinations'),
            ('exports_row', 'top_row_export_destinations'),
            ('imports_africa', 'top_africa_import_sources'),
            ('imports_row', 'top_row_import_sources'),
        )
        partners_analysis = {}
        for attr, key in sources:
            frame = getattr(self, attr, None)
            if frame is None:
                continue
            totals = frame[self._value_columns(frame)].sum().sort_values(ascending=False)
            partners_analysis[key] = totals.head(top_n)
        return partners_analysis

    def analyze_export_commodities(self):
        """Analyze export commodity performance.

        Returns:
            Dict mapping each value column of ``export_volumes`` to its
            ``total_value``, ``average_value``, ``volatility`` (std) and
            ``trend`` (least-squares slope per row, NaN with fewer than two
            observations), or None when ``export_volumes`` was not loaded.

        NOTE(review): the source table mixes volumes, values and unit prices;
        every numeric column is treated alike here. Confirm which columns are
        meant to be compared.
        """
        if not hasattr(self, 'export_volumes'):
            print("Export volumes data not available")
            return None

        df = self.export_volumes
        positions = np.arange(len(df))
        commodity_analysis = {}

        for col in self._value_columns(df):
            series = df[col]
            observed = series.notna().to_numpy()
            if observed.sum() >= 2:
                # Fit against each observation's real row position so that
                # missing periods do not compress the time axis.
                values = series.to_numpy(dtype=float, na_value=np.nan)
                trend = np.polyfit(positions[observed], values[observed], 1)[0]
            else:
                trend = np.nan
            commodity_analysis[col] = {
                'total_value': series.sum(),
                'average_value': series.mean(),
                'volatility': series.std(),
                'trend': trend,
            }

        return commodity_analysis

    # ------------------------------------------------------------------
    # Reporting
    # ------------------------------------------------------------------

    def create_trade_visualizations(self):
        """Create comprehensive trade visualizations.

        Returns:
            Dict of Plotly figures keyed by ``'trade_balance'``,
            ``'africa_exports'``, ``'row_exports'`` and ``'commodities'``; a key
            is present only when its data is available and non-empty.
        """
        visualizations = {}

        # 1. Trade Balance Trend
        trade_balance_df = self.analyze_trade_balance()
        if trade_balance_df is not None and 'Trade_Balance' in trade_balance_df.columns:
            fig_balance = go.Figure()
            fig_balance.add_trace(go.Scatter(
                x=list(range(len(trade_balance_df))),
                y=trade_balance_df['Trade_Balance'],
                mode='lines+markers',
                name='Trade Balance',
                line=dict(color='blue', width=3)
            ))
            fig_balance.update_layout(
                title='Kenya Trade Balance Trend',
                xaxis_title='Period',
                yaxis_title='Trade Balance (Ksh Millions)',
                template='plotly_white'
            )
            visualizations['trade_balance'] = fig_balance

        # 2. Trading Partners Analysis
        partners = self.analyze_trading_partners()
        partner_charts = (
            ('top_africa_export_destinations', 'africa_exports', 'Top African Export Destinations'),
            ('top_row_export_destinations', 'row_exports', 'Top Rest of World Export Destinations'),
        )
        for ranking_key, viz_key, title in partner_charts:
            ranking = partners.get(ranking_key)
            if ranking is None or ranking.empty:
                continue
            fig = px.bar(
                x=ranking.values,
                y=ranking.index,
                orientation='h',
                title=title,
                labels={'x': 'Export Value (Ksh Millions)', 'y': 'Country'}
            )
            fig.update_layout(height=500)
            visualizations[viz_key] = fig

        # 3. Export Commodities Performance
        commodities = self.analyze_export_commodities()
        if commodities:
            # A pie chart can only show positive shares; drop NaN/zero/negative totals.
            commodity_values = {
                name: stats['total_value']
                for name, stats in commodities.items()
                if pd.notna(stats['total_value']) and stats['total_value'] > 0
            }
            top_commodities = sorted(commodity_values.items(), key=lambda kv: kv[1], reverse=True)[:10]
            if top_commodities:
                names, values = zip(*top_commodities)
                visualizations['commodities'] = px.pie(
                    values=list(values),
                    names=list(names),
                    title='Top Export Commodities by Value'
                )

        return visualizations

    def generate_trade_insights(self):
        """Generate key insights from trade analysis.

        Returns:
            List of human-readable insight strings covering the average trade
            balance and its recent direction, top export destinations, and
            export concentration (Herfindahl index over commodity totals).
        """
        insights = []

        # Trade balance insights
        trade_balance_df = self.analyze_trade_balance()
        if trade_balance_df is not None and 'Trade_Balance' in trade_balance_df.columns:
            balance = trade_balance_df['Trade_Balance'].dropna()
            if not balance.empty:
                avg_balance = balance.mean()
                if avg_balance > 0:
                    insights.append(f"Kenya maintains a positive trade balance with an average surplus of Ksh {avg_balance:.1f} million")
                elif avg_balance < 0:
                    insights.append(f"Kenya has a trade deficit with an average deficit of Ksh {abs(avg_balance):.1f} million")
                else:
                    insights.append("Kenya's trade is balanced on average (exports equal imports)")

                # Compare the latest window with the earliest one.  Each window
                # is at most 12 observations and at most half the series, so
                # the two never overlap (overlap made short series compare
                # equal and always read as "deteriorated").
                window = min(12, len(balance) // 2)
                if window:
                    recent_trend = balance.tail(window).mean() - balance.head(window).mean()
                    if recent_trend > 0:
                        insights.append("Trade balance has improved in recent periods")
                    elif recent_trend < 0:
                        insights.append("Trade balance has deteriorated in recent periods")
                    else:
                        insights.append("Trade balance has been stable in recent periods")

        # Trading partners insights
        partners = self.analyze_trading_partners()
        partner_labels = (
            ('top_africa_export_destinations', "Top African export destination"),
            ('top_row_export_destinations', "Top Rest of World export destination"),
        )
        for ranking_key, label in partner_labels:
            ranking = partners.get(ranking_key)
            if ranking is not None and not ranking.empty:
                insights.append(f"{label}: {ranking.index[0]}")

        # Export diversification (Herfindahl index over positive totals)
        commodities = self.analyze_export_commodities()
        if commodities:
            positive_values = [
                stats['total_value'] for stats in commodities.values()
                if pd.notna(stats['total_value']) and stats['total_value'] > 0
            ]
            total_value = sum(positive_values)
            if total_value > 0:
                herfindahl = sum((value / total_value) ** 2 for value in positive_values)
                if herfindahl > 0.25:
                    insights.append("Export portfolio is highly concentrated - diversification recommended")
                elif herfindahl > 0.15:
                    insights.append("Export portfolio shows moderate concentration")
                else:
                    insights.append("Export portfolio is well diversified")

        return insights

# ---------------------------------------------------------------------------
# Notebook driver: build the analyzer, run the pipeline, and report results.
# ---------------------------------------------------------------------------
analyzer = ForeignTradeAnalyzer()

# Banner (the trailing empty entry reproduces the blank separator line).
print("\n".join([
    "🌍 Foreign Trade Analysis Notebook",
    "=" * 50,
    "This notebook provides comprehensive analysis of Kenya's foreign trade performance",
    "including exports, imports, trading partners, and commodity analysis.",
    "",
]))

# Stage 1: read the raw files and normalise them.
print("📊 Loading trade data...")
analyzer.load_trade_data()
analyzer.clean_trade_data()

# Stage 2: derive and list the textual insights, numbered from 1.
print("\n🔍 Generating trade insights...")
insights = analyzer.generate_trade_insights()
for i, insight in enumerate(insights, start=1):
    print("{}. {}".format(i, insight))

# Stage 3: build the interactive figures.
print("\n📈 Creating visualizations...")
visualizations = analyzer.create_trade_visualizations()
print("Generated {} visualizations".format(len(visualizations)))

# Closing summary of what the notebook produced.
print("\n".join([
    "\n✅ Foreign Trade Analysis Complete!",
    "Key outputs:",
    *["- " + output for output in (
        "Trade balance analysis",
        "Trading partner rankings",
        "Export commodity performance",
        "Interactive visualizations",
        "Strategic insights and recommendations",
    )],
]))