In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from collections import defaultdict
from faker import Faker
import random
import torch.nn.functional as F
from typing import Dict, List
import matplotlib.pyplot as plt


def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and,
    when available, all CUDA devices), and configures cuDNN so that kernel
    selection is deterministic across runs.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seeds every visible GPU, including the current one.
        torch.cuda.manual_seed_all(seed)
    # ``deterministic`` alone is not enough: with ``benchmark`` enabled the
    # cuDNN autotuner may still choose different (non-deterministic)
    # algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class BusinessDataset(Dataset):
    """Map-style dataset bundling features, the five task targets and time indices.

    Each constructor argument may already be a tensor (kept as-is) or any
    array-like accepted by the legacy tensor constructors (converted).
    Converted dtypes:

    * ``torch.FloatTensor``: features, revenue, profitability, churn.
    * ``torch.LongTensor``: risk labels, suggestion labels, month and quarter
      indices.

    Items are 8-tuples in the constructor's argument order.
    """

    def __init__(self, X, y_revenue, y_risk, y_profitability, y_churn, y_suggestions, month_nums, quarter_nums):
        def as_tensor(value, factory):
            # Pre-built tensors pass through untouched; everything else is converted.
            if torch.is_tensor(value):
                return value
            return factory(value)

        to_float, to_long = torch.FloatTensor, torch.LongTensor
        self.X = as_tensor(X, to_float)
        self.y_revenue = as_tensor(y_revenue, to_float)
        self.y_risk = as_tensor(y_risk, to_long)
        self.y_profitability = as_tensor(y_profitability, to_float)
        self.y_churn = as_tensor(y_churn, to_float)
        self.y_suggestions = as_tensor(y_suggestions, to_long)
        self.month_nums = as_tensor(month_nums, to_long)
        self.quarter_nums = as_tensor(quarter_nums, to_long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        columns = (
            self.X,
            self.y_revenue,
            self.y_risk,
            self.y_profitability,
            self.y_churn,
            self.y_suggestions,
            self.month_nums,
            self.quarter_nums,
        )
        return tuple(column[idx] for column in columns)


class EnhancedBusinessPredictor(nn.Module):
    """Multi-task network for revenue, risk, profitability, churn and action prediction.

    Tabular features are concatenated with a learned 16-d month embedding and an
    8-d quarter embedding, projected to ``hidden_dim``, passed through a
    2-layer LSTM over a length-1 sequence, and fed to five task heads. The
    churn head ends in a sigmoid, so it emits values in (0, 1); the other heads
    emit raw scores/logits.
    """

    def __init__(self, input_dim: int, hidden_dim: int = 128, num_risk_classes: int = 3, num_suggestion_classes: int = 6):
        super().__init__()
        self.hidden_dim = hidden_dim

        # Tables have one extra row: the processor in this file numbers months
        # 1..12 and quarters 1..4, so index 0 is never looked up.
        self.month_embedding = nn.Embedding(13, 16)
        self.quarter_embedding = nn.Embedding(5, 8)

        # +24 = 16 (month) + 8 (quarter) embedding dims appended to the features.
        self.input_proj = nn.Linear(input_dim + 24, hidden_dim)

        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            dropout=0.2,
            batch_first=True,
        )
        self.dropout = nn.Dropout(0.3)

        # Heads are built in a fixed order; keep it so seeded initialisation
        # and state_dict keys stay stable for saved checkpoints.
        self.revenue_layer = self._mlp_head(hidden_dim, 64, 1)
        self.risk_layer = self._mlp_head(hidden_dim, 64, num_risk_classes)
        self.profitability_layer = self._mlp_head(hidden_dim, 64, 1)
        self.churn_layer = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid(),
        )
        self.suggestion_layer = self._mlp_head(hidden_dim, 128, num_suggestion_classes)

    @staticmethod
    def _mlp_head(in_dim, width, out_dim):
        """Build a Linear -> ReLU -> Dropout(0.2) -> Linear prediction head."""
        return nn.Sequential(
            nn.Linear(in_dim, width),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(width, out_dim),
        )

    def forward(self, x, months, quarters):
        """Run all task heads on a batch.

        Args:
            x: feature tensor whose last dim is ``input_dim`` (presumably
                ``(batch, input_dim)`` — the code adds a length-1 sequence dim).
            months, quarters: integer indices for the embedding tables.

        Returns:
            Dict with keys 'revenue', 'risk', 'profitability', 'churn',
            'suggestions' mapping to each head's output.
        """
        calendar_feats = torch.cat(
            (self.month_embedding(months), self.quarter_embedding(quarters)), dim=-1
        )
        hidden_in = self.input_proj(torch.cat((x, calendar_feats), dim=-1))

        # The LSTM sees a single time step per sample.
        lstm_out, _ = self.lstm(hidden_in.unsqueeze(1))
        shared = self.dropout(lstm_out.squeeze(1))

        heads = (
            ('revenue', self.revenue_layer),
            ('risk', self.risk_layer),
            ('profitability', self.profitability_layer),
            ('churn', self.churn_layer),
            ('suggestions', self.suggestion_layer),
        )
        return {name: head(shared) for name, head in heads}



class ResidualBlock(nn.Module):
    """Two-layer residual MLP block: (Linear -> LayerNorm -> GELU -> Dropout) x2 + skip.

    Despite their names, ``conv1``/``conv2`` are ``nn.Linear`` layers; the names
    are kept so existing state_dicts still load.

    Args:
        in_channels: size of the last input dimension.
        out_channels: size of the last output dimension.
        dropout: dropout probability applied after each normalised layer
            (default 0.2, the previously hard-coded value).
    """

    def __init__(self, in_channels, out_channels, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Linear(in_channels, out_channels)
        self.conv2 = nn.Linear(out_channels, out_channels)
        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)
        self.dropout = nn.Dropout(dropout)

        # Projection shortcut only when the widths differ; otherwise a plain skip.
        self.shortcut = nn.Linear(in_channels, out_channels) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        identity = self.shortcut(x)

        out = self.conv1(x)
        out = self.norm1(out)
        out = F.gelu(out)  # GELU for smoother gradients than ReLU
        out = self.dropout(out)

        out = self.conv2(out)
        out = self.norm2(out)
        out = self.dropout(out)

        # Out-of-place add: an in-place ``+=`` on a tensor in the autograd graph
        # fails at backward time if any upstream op saved that tensor.
        return F.gelu(out + identity)

class MultiHeadAttention(nn.Module):
    """Self-attention with a residual connection and LayerNorm over single vectors.

    Each input row is treated as a length-1 sequence: a sequence dimension is
    inserted at dim 1, attended over, and removed again. Because there is only
    one key, the softmax weights are all 1 and the attention output reduces to
    the projected values.

    Args:
        embed_dim: feature size (last dimension of the input).
        num_heads: number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        seq = x.unsqueeze(1)
        attended = self.attention(seq, seq, seq)[0]
        return self.norm(seq + attended).squeeze(1)

class TaskNetwork(nn.Module):
    """Feed-forward head: [Linear -> LayerNorm -> GELU -> Dropout] per hidden size, then Linear.

    Args:
        input_dim: size of the input features.
        output_dim: size of the final output.
        hidden_dims: widths of the hidden layers, in order; an empty sequence
            yields a single Linear layer. Defaults to ``(64, 32)`` — a tuple
            rather than a list, so the default can never be mutated and
            shared between instances.
        dropout: dropout probability after each hidden layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dims=(64, 32), dropout=0.2):
        super().__init__()
        layers = []
        prev_dim = input_dim

        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            prev_dim = hidden_dim

        layers.append(nn.Linear(prev_dim, output_dim))
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)

def generate_business_data(n_samples, seed=42):
    """Generate ``n_samples`` rows of synthetic business data as a DataFrame.

    Each row contains text P&L and balance-sheet blocks made of
    ``"Label: value"`` lines, a month name, numeric KPIs drawn uniformly from
    fixed ranges, categorical turnover/risk levels and a recommended action.
    All columns are independent random draws; e.g. ``Risk_Assessment`` is not
    derived from any other column.

    Args:
        n_samples: number of rows to generate.
        seed: seed for every generator used. NumPy's *global* RNG is reseeded
            (a side effect on the caller's NumPy state). The Faker instance and
            the action sampler get their own seeded generators, so the whole
            frame is reproducible for a given ``seed``.

    Returns:
        pandas.DataFrame with one row per sample.
    """
    np.random.seed(seed)
    fake = Faker()
    # Faker does not draw from NumPy's RNG, so it must be seeded separately
    # or the month names differ on every run.
    fake.seed_instance(seed)
    # Private generator: reproducible without touching the global `random` state.
    action_rng = random.Random(seed)

    actions = (
        "Expand market reach",
        "Optimize costs",
        "Invest in marketing",
        "Diversify products",
        "Improve customer service",
        "Increase R&D",
        "Enhance loyalty programs",
        "Cut overhead costs",
    )

    def generate_pl_statement():
        revenue = np.random.randint(500000, 5000000)
        cogs = np.random.randint(100000, revenue)  # revenue >= 500000, so the range is never empty
        gross_profit = revenue - cogs
        operating_expenses = np.random.randint(100000, 1000000)
        operating_income = gross_profit - operating_expenses
        other_income = np.random.randint(10000, 100000)
        net_profit = operating_income + other_income
        return f"Revenue: {revenue}\nCOGS: {cogs}\nGross Profit: {gross_profit}\n" \
               f"Operating Expenses: {operating_expenses}\nOperating Income: {operating_income}\n" \
               f"Other Income: {other_income}\nNet Profit: {net_profit}"

    def generate_balance_sheet():
        current_assets = np.random.randint(500000, 2000000)
        fixed_assets = np.random.randint(1000000, 5000000)
        total_assets = current_assets + fixed_assets
        current_liabilities = np.random.randint(100000, 500000)
        long_term_liabilities = np.random.randint(200000, 1000000)
        total_liabilities = current_liabilities + long_term_liabilities
        equity = total_assets - total_liabilities
        return f"Current Assets: {current_assets}\nFixed Assets: {fixed_assets}\n" \
               f"Total Assets: {total_assets}\nCurrent Liabilities: {current_liabilities}\n" \
               f"Long-term Liabilities: {long_term_liabilities}\nTotal Liabilities: {total_liabilities}\n" \
               f"Equity: {equity}"

    # Column construction order fixes the sequence of NumPy draws; keep it
    # stable so the numeric columns stay reproducible for a given seed.
    data = {
        'P&L_Statement': [generate_pl_statement() for _ in range(n_samples)],
        'Balance_Sheet': [generate_balance_sheet() for _ in range(n_samples)],
        'Monthly_Revenue_Turnover': [fake.month_name() for _ in range(n_samples)],
        'Revenue_Generated': np.random.randint(50000, 1000000, n_samples),
        'Customer_Turnover_Rate': np.random.choice(['Low', 'Medium', 'High'], n_samples),
        'Growth_Rate (%)': np.round(np.random.uniform(5, 25, n_samples), 2),
        'Average_Monthly_Expenses': np.random.randint(10000, 500000, n_samples),
        'Customer_Acquisition_Cost (₦)': np.random.randint(500, 5000, n_samples),
        'Lifetime_Value_of_Customer (₦)': np.random.randint(5000, 100000, n_samples),
        'Market_Size_Potential': np.random.randint(500, 50000, n_samples),
        'Risk_Assessment': np.random.choice(['Low', 'Medium', 'High'], n_samples),
        'Predicted_Revenue_Growth (%)': np.round(np.random.uniform(5, 30, n_samples), 2),
        'Profitability_Score': np.round(np.random.uniform(50, 100, n_samples), 2),
        'Churn_Rate (%)': np.round(np.random.uniform(5, 40, n_samples), 2),
        'Recommendation_Actions': [action_rng.choice(actions) for _ in range(n_samples)]
    }

    return pd.DataFrame(data)

def improved_preprocessing(df):
    """Add engineered growth, margin, seasonality and ratio features to ``df``.

    ``df`` is modified in place (new columns are added) and also returned.
    Inputs are validated before any column is written, so a rejected frame is
    left untouched.

    New columns: Revenue_Growth_Rate, Profit_Margin, Operating_Margin,
    Customer_Efficiency, Month_Sin, Month_Cos, CAC_CLV_Ratio,
    Revenue_per_Expense, Rolling_Avg_Revenue.

    Raises:
        ValueError: if a required column is missing or ``Revenue_Generated``
            contains NaNs.
    """
    required = (
        'Month_Num',
        'Revenue_Generated',
        'Net_Profit',
        'Operating_Income',
        'Market_Size_Potential',
        'Customer_Acquisition_Cost (₦)',
        'Lifetime_Value_of_Customer (₦)',
        'Average_Monthly_Expenses',
    )
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"Required columns are missing or contain NaNs. Missing: {missing}")
    if df['Revenue_Generated'].isnull().any():
        raise ValueError("Required columns are missing or contain NaNs. 'Revenue_Generated' has NaNs.")

    # Growth relative to the previous row of the same month (in row order);
    # the first row of every month group is NaN.
    df['Revenue_Growth_Rate'] = df.groupby('Month_Num')['Revenue_Generated'].pct_change()
    df['Profit_Margin'] = df['Net_Profit'] / df['Revenue_Generated']
    df['Operating_Margin'] = df['Operating_Income'] / df['Revenue_Generated']
    df['Customer_Efficiency'] = df['Revenue_Generated'] / df['Market_Size_Potential']

    # Cyclical month encoding so December and January end up close together.
    df['Month_Sin'] = np.sin(2 * np.pi * df['Month_Num'] / 12)
    df['Month_Cos'] = np.cos(2 * np.pi * df['Month_Num'] / 12)

    # Interaction features
    df['CAC_CLV_Ratio'] = df['Customer_Acquisition_Cost (₦)'] / df['Lifetime_Value_of_Customer (₦)']
    df['Revenue_per_Expense'] = df['Revenue_Generated'] / df['Average_Monthly_Expenses']

    # Rolling mean over up to 3 consecutive rows within each month group.
    df['Rolling_Avg_Revenue'] = df.groupby('Month_Num')['Revenue_Generated'].transform(
        lambda x: x.rolling(3, min_periods=1).mean()
    )

    return df


class ImprovedBusinessDataProcessor:
    """Convert the raw synthetic business DataFrame into model-ready arrays.

    The fitted scaler and label encoders live on the instance. After a first
    ``process_data(df)`` call, the same transformation can be reapplied to new
    data (validation / inference) with ``process_data(new_df, fit=False)``.
    """

    # Columns standardised into the feature matrix ``X``, in this order.
    FEATURE_COLUMNS = (
        'Revenue_Generated',
        'Growth_Rate (%)',
        'Average_Monthly_Expenses',
        'Customer_Acquisition_Cost (₦)',
        'Lifetime_Value_of_Customer (₦)',
        'Market_Size_Potential',
        'Customer_Lifetime_Value_Ratio',
        'Revenue_per_Market_Size',
        'Expense_Ratio',
        'Net_Profit',
        'Operating_Income',
    )

    def __init__(self):
        self.feature_scaler = StandardScaler()
        self.label_encoder_turnover = LabelEncoder()
        self.label_encoder_risk = LabelEncoder()
        self.label_encoder_suggestions = LabelEncoder()

    @staticmethod
    def _month_number_map():
        """Map full month names (as rendered by ``strftime('%B')``) to 1..12."""
        # Built from explicit dates instead of pd.date_range(..., freq='M'):
        # the 'M' alias is deprecated in pandas 2.2 in favour of 'ME'.
        return {pd.Timestamp(2024, month, 1).strftime('%B'): month for month in range(1, 13)}

    def process_data(self, df, fit=True):
        """Encode, derive and scale features; return model inputs and targets.

        Args:
            df: raw frame with the columns produced by ``generate_business_data``.
                It is not modified (a copy is processed).
            fit: if True (default), refit the scaler and label encoders on
                ``df``. If False, reuse the previously fitted ones; sklearn
                raises on labels it has not seen.

        Returns:
            Tuple of numpy arrays: (X scaled features, revenue-growth target,
            encoded risk labels, profitability score, churn as a proportion in
            [0, 1], encoded recommendation labels, month numbers 1..12,
            quarter numbers 1..4).

        Raises:
            ValueError: if ``Monthly_Revenue_Turnover`` contains a value that is
                not a month name.
        """
        df = df.copy()

        # Month and quarter numbers. Unknown names would otherwise become NaN
        # and only fail later when converted to integer tensors.
        month_nums = df['Monthly_Revenue_Turnover'].map(self._month_number_map())
        unknown = df.loc[month_nums.isna(), 'Monthly_Revenue_Turnover'].unique()
        if len(unknown) > 0:
            raise ValueError(
                "Unrecognised month names in 'Monthly_Revenue_Turnover': "
                f"{list(unknown)[:5]}"
            )
        df['Month_Num'] = month_nums.astype('int64')
        df['Quarter_Num'] = ((df['Month_Num'] - 1) // 3) + 1

        def encode(encoder, column):
            # fit=True learns the label set from this frame; fit=False reuses it.
            values = df[column]
            return encoder.fit_transform(values) if fit else encoder.transform(values)

        # Categorical variables (turnover is encoded but not part of FEATURE_COLUMNS).
        df['Customer_Turnover_Rate'] = encode(self.label_encoder_turnover, 'Customer_Turnover_Rate')
        df['Risk_Assessment'] = encode(self.label_encoder_risk, 'Risk_Assessment')
        df['Recommendation_Actions'] = encode(self.label_encoder_suggestions, 'Recommendation_Actions')

        # Parse figures out of the "Label: value" P&L text. "Net Profit" is the
        # last line, so everything after its label is the number.
        df['Net_Profit'] = df['P&L_Statement'].apply(lambda x: float(x.split('Net Profit: ')[1]))
        df['Operating_Income'] = df['P&L_Statement'].apply(lambda x: float(x.split('Operating Income: ')[1].split('\n')[0]))

        # Derived ratios
        df['Customer_Lifetime_Value_Ratio'] = df['Lifetime_Value_of_Customer (₦)'] / df['Customer_Acquisition_Cost (₦)']
        df['Revenue_per_Market_Size'] = df['Revenue_Generated'] / df['Market_Size_Potential']
        df['Expense_Ratio'] = df['Average_Monthly_Expenses'] / df['Revenue_Generated']

        features = df[list(self.FEATURE_COLUMNS)]
        X = self.feature_scaler.fit_transform(features) if fit else self.feature_scaler.transform(features)

        return (
            X,
            df['Predicted_Revenue_Growth (%)'].values,
            df['Risk_Assessment'].values,
            df['Profitability_Score'].values,
            df['Churn_Rate (%)'].values / 100,  # percentage -> proportion
            df['Recommendation_Actions'].values,
            df['Month_Num'].values,
            df['Quarter_Num'].values
        )

def calculate_metrics(outputs, y_risk_batch, y_sug_batch):
    """Return ``(risk_accuracy, suggestion_accuracy)`` for one batch as Python floats.

    ``outputs`` is the model's output dict; each accuracy is the fraction of
    rows whose argmax over dim 1 of the 'risk' / 'suggestions' scores equals
    the corresponding integer target.
    """
    def accuracy(scores, targets):
        hits = scores.argmax(dim=1) == targets
        return hits.float().mean().item()

    return accuracy(outputs['risk'], y_risk_batch), accuracy(outputs['suggestions'], y_sug_batch)

def find_learning_rate(model, train_loader, device, start_lr=1e-7, end_lr=1, num_iterations=100):
    """Run a learning-rate range test and suggest a learning rate.

    The LR grows geometrically from ``start_lr`` towards ``end_lr`` over at most
    ``num_iterations`` batches. It stops earlier if the loader runs out, the
    loss becomes non-finite, exceeds 4x the best loss seen, or a RuntimeError
    occurs. The model's weights and train/eval mode are restored afterwards,
    even if an exception escapes.

    Returns:
        ``(suggested_lr, lrs, losses)``. ``suggested_lr`` is the LR at the
        steepest one-step loss drop divided by 10, or ``1e-4`` when fewer than
        two losses were recorded. ``lrs[i]`` is the LR used for the step that
        produced ``losses[i]``.
    """
    import copy

    init_state = copy.deepcopy(model.state_dict())
    was_training = model.training

    lrs = []
    losses = []
    best_loss = float('inf')

    optimizer = torch.optim.AdamW(model.parameters(), lr=start_lr)
    # Per-step factor so that num_iterations steps span start_lr -> end_lr.
    mult = (end_lr / start_lr) ** (1 / num_iterations)

    regression_criterion = nn.MSELoss()
    classification_criterion = nn.CrossEntropyLoss()

    model.train()
    try:
        for batch_idx, batch in enumerate(train_loader):
            if batch_idx >= num_iterations:
                break

            (X_batch, y_revenue_batch, y_risk_batch, y_prof_batch,
             y_churn_batch, y_sug_batch, month_nums, quarter_nums) = [b.to(device) for b in batch]

            optimizer.zero_grad()

            try:
                outputs = model(X_batch, month_nums, quarter_nums)

                # squeeze(-1) (not squeeze()) keeps a (B,) shape even for a
                # final batch of size 1, so predictions never broadcast
                # against the targets.
                total_loss = (
                    regression_criterion(outputs['revenue'].squeeze(-1), y_revenue_batch)
                    + classification_criterion(outputs['risk'], y_risk_batch)
                    + regression_criterion(outputs['profitability'].squeeze(-1), y_prof_batch)
                    + regression_criterion(outputs['churn'].squeeze(-1), y_churn_batch)
                    + classification_criterion(outputs['suggestions'], y_sug_batch)
                )

                if not torch.isfinite(total_loss):
                    break  # diverged; larger LRs will not help

                total_loss.backward()
                optimizer.step()
                loss_value = total_loss.item()
            except RuntimeError as e:
                print(f"Error during LR finding: {e}")
                break

            lrs.append(optimizer.param_groups[0]['lr'])
            losses.append(loss_value)

            for param_group in optimizer.param_groups:
                param_group['lr'] *= mult

            # Stop once the loss clearly explodes relative to the best so far.
            if loss_value > 4 * best_loss:
                break
            best_loss = min(best_loss, loss_value)
    finally:
        # Undo the exploratory updates no matter how the sweep ended.
        model.load_state_dict(init_state)
        model.train(was_training)

    if len(losses) < 2:
        return 1e-4, lrs, losses

    # losses[i + 1] - losses[i] is most negative where the loss falls fastest;
    # use the LR of the step that started that drop, scaled down for safety.
    steepest = int(np.argmin(np.diff(losses)))
    return lrs[steepest] / 10, lrs, losses

def improved_train_model(model, train_loader, val_loader, device, num_epochs=100, max_lr=3e-4, patience=20):
    """Train the multi-task model with AdamW, a one-cycle LR schedule and early stopping.

    Losses per batch: Huber for revenue and profitability, label-smoothed
    cross-entropy for risk and suggestions, binary cross-entropy for churn
    (so the model's churn output must lie in [0, 1]). The five losses are
    summed unweighted.

    Args:
        model: module called as ``model(X, months, quarters)``. It must return a
            dict with 'revenue', 'risk', 'profitability', 'churn' and
            'suggestions'.
        train_loader, val_loader: yield 8-tuples in ``BusinessDataset`` order.
        device: device every batch tensor is moved to.
        num_epochs: maximum number of epochs; also sizes the LR schedule.
        max_lr: peak learning rate of the one-cycle schedule (for example the
            value suggested by ``find_learning_rate``).
        patience: epochs without validation improvement before stopping.

    Returns:
        ``(model, {'best_val_loss': float})``. The model is loaded with the
        weights from the epoch with the lowest mean validation loss.
    """
    import copy

    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
    # OneCycleLR replaces the optimizer's lr with its own schedule spanning
    # num_epochs * len(train_loader) steps: exactly one step per batch.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=max_lr,
        epochs=num_epochs, steps_per_epoch=len(train_loader)
    )

    # Huber is less sensitive to outliers than MSE; label smoothing
    # regularises the classification heads.
    regression_criterion = nn.HuberLoss(delta=1.0)
    classification_criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    best_val_loss = float('inf')
    best_model_state = None
    patience_counter = 0

    def new_metric_log():
        """Fresh per-epoch accumulator of per-batch accuracies and task losses."""
        return {
            'risk_acc': [], 'churn_acc': [], 'suggestions_acc': [],
            'revenue_loss': [], 'risk_loss': [], 'profitability_loss': [],
            'churn_loss': [], 'suggestions_loss': []
        }

    def batch_losses_and_metrics(outputs, targets):
        """Return (loss tensors with graph, loss floats, accuracy floats) for a batch."""
        revenue_batch, risk_batch, prof_batch, churn_batch, sug_batch = targets
        churn_pred = outputs['churn'].squeeze(1)

        risk_acc = (outputs['risk'].argmax(dim=1) == risk_batch).float().mean().item()
        # Churn targets are continuous rates, so threshold both sides at 0.5;
        # comparing a boolean prediction with a raw rate is almost never equal.
        churn_acc = ((churn_pred > 0.5) == (churn_batch > 0.5)).float().mean().item()
        sug_acc = (outputs['suggestions'].argmax(dim=1) == sug_batch).float().mean().item()

        losses = {
            'revenue': regression_criterion(outputs['revenue'].squeeze(1), revenue_batch),
            'risk': classification_criterion(outputs['risk'], risk_batch),
            'profitability': regression_criterion(outputs['profitability'].squeeze(1), prof_batch),
            'churn': F.binary_cross_entropy(churn_pred, churn_batch),
            'suggestions': classification_criterion(outputs['suggestions'], sug_batch)
        }
        loss_values = {k: v.item() for k, v in losses.items()}
        accuracies = {'risk': risk_acc, 'churn': churn_acc, 'suggestions': sug_acc}
        return losses, loss_values, accuracies

    def record(log, loss_values, accuracies):
        """Append one batch's metrics to an epoch log."""
        for k, v in loss_values.items():
            log[f'{k}_loss'].append(v)
        for k, v in accuracies.items():
            log[f'{k}_acc'].append(v)

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_losses = []
        train_metrics = new_metric_log()

        for batch in train_loader:
            X_batch, y_revenue_batch, y_risk_batch, y_prof_batch, \
                y_churn_batch, y_sug_batch, month_nums, quarter_nums = [b.to(device) for b in batch]

            optimizer.zero_grad()
            outputs = model(X_batch, month_nums, quarter_nums)

            losses, loss_values, accuracies = batch_losses_and_metrics(
                outputs,
                (y_revenue_batch, y_risk_batch, y_prof_batch, y_churn_batch, y_sug_batch)
            )

            total_loss = sum(losses.values())
            total_loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            scheduler.step()  # per batch, matching steps_per_epoch above

            train_losses.append(sum(loss_values.values()))
            record(train_metrics, loss_values, accuracies)

        # Validation phase
        model.eval()
        val_losses = []
        val_metrics = new_metric_log()

        with torch.no_grad():
            for batch in val_loader:
                X_batch, y_revenue_batch, y_risk_batch, y_prof_batch, \
                    y_churn_batch, y_sug_batch, month_nums, quarter_nums = [b.to(device) for b in batch]

                outputs = model(X_batch, month_nums, quarter_nums)

                _, loss_values, accuracies = batch_losses_and_metrics(
                    outputs,
                    (y_revenue_batch, y_risk_batch, y_prof_batch, y_churn_batch, y_sug_batch)
                )

                val_losses.append(sum(loss_values.values()))
                record(val_metrics, loss_values, accuracies)

        avg_train_loss = np.mean(train_losses)
        avg_val_loss = np.mean(val_losses)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("Training Metrics:")
        print(f"Total Loss: {avg_train_loss:.4f}")
        for k, v in train_metrics.items():
            print(f"{k}: {np.mean(v):.4f}")

        print("\nValidation Metrics:")
        print(f"Total Loss: {avg_val_loss:.4f}")
        for k, v in val_metrics.items():
            print(f"{k}: {np.mean(v):.4f}")

        # Model checkpointing
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() holds references to the live parameter tensors, so a
            # shallow dict copy would keep tracking later updates. Deep-copy
            # to freeze this epoch's weights.
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"\nEarly stopping at epoch {epoch+1}")
            break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, {'best_val_loss': best_val_loss}

if __name__ == "__main__":
    # End-to-end script: generate data, preprocess, split, find an LR, train, save.
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print("Generating data...")
    df = generate_business_data(80000)

    print("Processing data...")
    processor = ImprovedBusinessDataProcessor()
    try:
        X, y_revenue, y_risk, y_profitability, y_churn, y_suggestions, month_nums, quarter_nums = processor.process_data(df)
    except Exception as e:
        # Top-level boundary: report and stop with a non-zero status. SystemExit
        # is used instead of the site/IPython ``exit`` helper, which is not
        # guaranteed to exist and behaves differently inside Jupyter.
        print(f"Error during data processing: {e}")
        raise SystemExit(1) from e

    print("Splitting data...")
    X_train, X_val, y_revenue_train, y_revenue_val, \
    y_risk_train, y_risk_val, y_profitability_train, y_profitability_val, \
    y_churn_train, y_churn_val, y_suggestions_train, y_suggestions_val, \
    month_train, month_val, quarter_train, quarter_val = train_test_split(
        X, y_revenue, y_risk, y_profitability, y_churn, y_suggestions, month_nums, quarter_nums, test_size=0.2, random_state=42
    )

    # Create datasets and loaders
    train_dataset = BusinessDataset(
        X_train, y_revenue_train, y_risk_train, y_profitability_train, y_churn_train,
        y_suggestions_train, month_train, quarter_train
    )
    val_dataset = BusinessDataset(
        X_val, y_revenue_val, y_risk_val, y_profitability_val, y_churn_val,
        y_suggestions_val, month_val, quarter_val
    )

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

    print("Initializing model...")
    input_dim = X.shape[1]
    num_risk_classes = len(np.unique(y_risk))
    num_suggestion_classes = len(np.unique(y_suggestions))

    model = EnhancedBusinessPredictor(
        input_dim=input_dim,
        num_risk_classes=num_risk_classes,
        num_suggestion_classes=num_suggestion_classes
    ).to(device)

    print("Finding optimal learning rate...")
    # Informational only: the training call below uses its default schedule.
    optimal_lr, lrs, losses = find_learning_rate(model, train_loader, device)
    print(f"Optimal learning rate found: {optimal_lr}")

    print("Training model...")
    trained_model, metrics = improved_train_model(
        model, train_loader, val_loader, device, num_epochs=100
    )

    print("\nTraining completed!")
    print("Best validation loss:", metrics['best_val_loss'])

    # Save the model together with the label order used to encode the targets,
    # so output index i can be mapped back to the right class name when
    # loading. LabelEncoder.classes_[i] is the label encoded as i. The names
    # are stored as plain str so torch.load(..., weights_only=True) accepts them.
    model_save_path = "improved_business_predictor.pth"
    torch.save({
        'model_state_dict': trained_model.state_dict(),
        'input_dim': input_dim,
        'num_risk_classes': num_risk_classes,
        'num_suggestion_classes': num_suggestion_classes,
        'risk_classes': [str(c) for c in processor.label_encoder_risk.classes_],
        'suggestion_classes': [str(c) for c in processor.label_encoder_suggestions.classes_],
    }, model_save_path)
    print(f"Model saved at: {model_save_path}")


In [None]:
# Show the installed PyTorch version (useful when diagnosing environment or
# API-availability issues).
print(torch.__version__)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
from typing import Dict, List, Optional, Tuple
import matplotlib.pyplot as plt
import seaborn as sns


class TimeSeriesBusinessPredictor:
    """Roll a trained ``EnhancedBusinessPredictor`` forward month by month.

    A checkpoint dict holding ``model_state_dict``, ``input_dim``,
    ``num_risk_classes`` and ``num_suggestion_classes`` is loaded at
    construction. ``predict_monthly_metrics`` then repeatedly feeds the model a
    feature row, records revenue / risk / profitability / churn / action
    outputs, and synthesises the next month's features from the prediction.
    """

    def __init__(self, model_path: str, processor: "ImprovedBusinessDataProcessor"):
        """Load the checkpoint at *model_path* and put the model in eval mode.

        Args:
            model_path: Path to a checkpoint written with ``torch.save``.
            processor: Object whose ``process_data`` builds the model's input
                features and whose ``feature_scaler`` rescales synthesised
                features between months.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.processor = processor

        # weights_only=True refuses to unpickle arbitrary objects from the file.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=True)

        # Rebuild the network with the hyper-parameters stored alongside the weights.
        self.model = EnhancedBusinessPredictor(
            input_dim=checkpoint['input_dim'],
            num_risk_classes=checkpoint['num_risk_classes'],
            num_suggestion_classes=checkpoint['num_suggestion_classes']
        ).to(self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Risk class names, falling back to a default order when the checkpoint
        # does not store them. NOTE(review): confirm the default order matches
        # the label encoding used at training time.
        self.risk_classes = checkpoint.get('risk_classes', ['Low', 'Medium', 'High'])

        # Human-readable action for each suggestion-head output index; indices
        # past the end of this list are reported as "Unknown".
        self.suggestion_classes = [
            # Market Expansion & Growth
            'Expand into new geographic markets through targeted marketing campaigns',
            'Develop partnerships with complementary businesses for market penetration',
            'Launch new product lines based on customer feedback and market research',
            'Establish international presence through e-commerce platforms',
            'Create franchise opportunities for rapid market expansion',

            # Operational Optimization
            'Implement automated inventory management system to reduce costs',
            'Streamline supply chain through strategic supplier partnerships',
            'Adopt lean manufacturing principles to minimize waste',
            'Invest in employee training programs for improved productivity',
            'Upgrade technology infrastructure for better operational efficiency',

            # Customer Experience & Retention
            'Develop personalized customer loyalty program with tiered benefits',
            'Implement AI-powered customer service chatbot for 24/7 support',
            'Create customer feedback loops with regular surveys and focus groups',
            'Establish VIP customer program for high-value clients',
            'Launch customer education programs about product features and benefits',

            # Cost Management
            'Negotiate bulk purchase agreements with suppliers for better rates',
            'Optimize energy usage through smart building management systems',
            'Implement zero-based budgeting for all departments',
            'Outsource non-core business functions to reduce overhead',
            'Invest in preventive maintenance to reduce long-term costs',

            # Marketing & Brand Development
            'Launch targeted social media advertising campaigns',
            'Develop content marketing strategy with industry thought leadership',
            'Create referral program with incentives for existing customers',
            'Implement influencer marketing program in key markets',
            'Enhance brand visibility through community engagement events',

            # Digital Transformation
            'Develop mobile app for improved customer engagement',
            'Implement cloud-based solutions for remote work capability',
            'Create digital payment options for customer convenience',
            'Establish omnichannel presence for seamless customer experience',
            'Implement data analytics for better decision-making',

            # Product Innovation
            'Conduct research and development for product improvements',
            'Create sustainable/eco-friendly product alternatives',
            'Develop subscription-based service models',
            'Launch premium product line for high-end market segment',
            'Create bundled product offerings for increased value',

            # Financial Management
            'Implement dynamic pricing strategy based on market demand',
            'Develop alternative revenue streams through complementary services',
            'Optimize working capital through improved inventory management',
            'Establish strategic partnerships for shared resource utilization',
            'Create financial forecasting models for better planning',

            # Human Resources
            'Implement performance-based incentive programs',
            'Develop career advancement programs for employee retention',
            'Create flexible work arrangements for improved work-life balance',
            'Establish mentorship programs for knowledge transfer',
            'Implement employee wellness programs for improved productivity',

            # Risk Management
            'Develop business continuity plans for various scenarios',
            'Implement cybersecurity measures for data protection',
            'Create disaster recovery protocols for critical systems',
            'Establish quality control processes for consistent delivery',
            'Develop compliance monitoring systems for regulatory requirements'
        ]

        # Per-month {'revenue': ..., 'profitability': ...} records of the current
        # forecast run; read by _calculate_stability.
        self.prediction_history = []

    def _validate_prediction(self, pred_idx: int, valid_classes: List[str]) -> str:
        """Return ``valid_classes[pred_idx]``, or "Unknown" when the index is out of range."""
        if 0 <= pred_idx < len(valid_classes):
            return valid_classes[pred_idx]
        return "Unknown"

    def _calculate_stability(self) -> float:
        """Score how much the last two recorded months differ (1.0 = identical).

        The score is ``1 - mean(|relative revenue change|, |relative
        profitability change|)`` floored at 0. A relative change whose previous
        value is zero counts as 0. With fewer than two records, 1.0 is returned.
        """
        if len(self.prediction_history) < 2:
            return 1.0

        prev_pred = self.prediction_history[-2]
        curr_pred = self.prediction_history[-1]

        revenue_change = (curr_pred['revenue'] - prev_pred['revenue']) / prev_pred['revenue'] if prev_pred['revenue'] != 0 else 0
        profitability_change = (curr_pred['profitability'] - prev_pred['profitability']) / prev_pred['profitability'] if prev_pred['profitability'] != 0 else 0

        stability_score = 1 - (abs(revenue_change) + abs(profitability_change)) / 2
        return max(0, stability_score)

    def _calculate_confidence_metrics(self, outputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Summarise model confidence for one forward pass.

        Returns a dict with:
            risk_confidence: largest softmax probability of the risk head.
            action_confidence: largest softmax probability of the suggestion head.
            revenue_uncertainty: population std of the revenue output (0.0 for a
                single element).
            prediction_stability: ``_calculate_stability()`` over the history
                recorded so far.
        """
        risk_probs = F.softmax(outputs['risk'], dim=1)
        action_probs = F.softmax(outputs['suggestions'], dim=1)

        revenue_std = torch.std(outputs['revenue'], unbiased=False).item() if outputs['revenue'].numel() > 1 else 0.0

        return {
            'risk_confidence': risk_probs.max().item(),
            'action_confidence': action_probs.max().item(),
            'revenue_uncertainty': revenue_std,
            'prediction_stability': self._calculate_stability()
        }

    @staticmethod
    def _safe_ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or 0.0 when the denominator is zero.

        Guards the ratio features: a zero denominator would otherwise produce
        inf, which the feature scaler cannot transform.
        """
        return numerator / denominator if denominator != 0 else 0.0

    def _update_features(self, current_features: torch.Tensor, current_metrics: Dict[str, float], month: int) -> torch.Tensor:
        """Build next month's feature row from *current_metrics*.

        Eleven raw features are derived from the metrics, multiplied by a
        yearly sine seasonal factor and a random market factor, rescaled with
        ``self.processor.feature_scaler``, and written into the leading columns
        of row 0 of a copy of *current_features*. Small Gaussian noise
        (std 0.01) is then added to every column.

        Args:
            current_features: Current model input; row 0 is overwritten in the copy.
            current_metrics: Needs 'revenue', 'growth_rate', 'expenses', 'cac',
                'ltv' and 'market_size'.
            month: Calendar month (1-12) driving the seasonal factor.

        Returns:
            A new tensor; *current_features* is not modified.
        """
        revenue = current_metrics['revenue']
        expenses = current_metrics['expenses']
        cac = current_metrics['cac']
        ltv = current_metrics['ltv']
        market_size = current_metrics['market_size']

        # NOTE(review): this ordering must match the columns the scaler was fitted on.
        base_features = [
            revenue,
            current_metrics['growth_rate'],
            expenses,
            cac,
            ltv,
            market_size,
            self._safe_ratio(ltv, cac),
            self._safe_ratio(revenue, market_size),
            self._safe_ratio(expenses, revenue),
            revenue * np.random.normal(0.25, 0.03),  # Variable profit margin
            revenue * np.random.normal(0.2, 0.02),   # Variable operating income
        ]

        seasonal_factor = 1.0 + 0.1 * np.sin(2 * np.pi * month / 12)  # Yearly cycle
        market_noise = np.random.normal(1.0, 0.05)  # Random market fluctuation
        new_features = [f * seasonal_factor * market_noise for f in base_features]

        # Reuse the training-time scaler so the new row is on the model's scale.
        scaled_row = self.processor.feature_scaler.transform([new_features])[0]

        # Small jitter so repeated months do not collapse onto the same input.
        noise = torch.randn_like(current_features) * 0.01

        X_updated = current_features.clone()
        X_updated[0, :len(scaled_row)] = torch.as_tensor(
            scaled_row, dtype=X_updated.dtype, device=X_updated.device
        )
        X_updated += noise
        return X_updated

    def _calculate_risk_level(self,
                              outputs: Dict[str, torch.Tensor],
                              confidence: float,
                              threshold: float,
                              current_metrics: Dict[str, float]) -> Tuple[str, float]:
        """Combine model outputs and metrics into a (label, score) risk assessment.

        ``score = 0.4 * max risk prob + 0.2 * [growth_rate < 0] + 0.2 * churn
        + 0.2 * (1 - profitability / 100)``. The label is "Uncertain" when
        *confidence* < *threshold*. Otherwise it is "Low" (score < 0.3),
        "Medium" (< 0.6) or "High".
        """
        risk_probs = F.softmax(outputs['risk'], dim=1)
        # NOTE(review): this is the probability of whichever class is most
        # likely, so a confident "Low" prediction raises the score as much as a
        # confident "High" one. Consider weighting by high-risk class
        # probability once the class order is confirmed.
        max_prob = risk_probs.max().item()

        growth_risk = 1.0 if current_metrics['growth_rate'] < 0 else 0.0
        churn_risk = outputs['churn'].item()  # assumes a probability in [0, 1] - TODO confirm
        profitability = outputs['profitability'].item()  # presumably a 0-100 score

        risk_score = (
            0.4 * max_prob +
            0.2 * growth_risk +
            0.2 * churn_risk +
            0.2 * (1 - profitability / 100)
        )

        if confidence < threshold:
            return "Uncertain", risk_score
        elif risk_score < 0.3:
            return "Low", risk_score
        elif risk_score < 0.6:
            return "Medium", risk_score
        else:
            return "High", risk_score

    def _get_unique_recommendations(self, top_actions, num_recommendations, recent_actions):
        """Pick up to *num_recommendations* distinct action labels from *top_actions*.

        Args:
            top_actions: Result of ``torch.topk`` over action probabilities; only
                row 0 is used.
            num_recommendations: Maximum number of labels to return.
            recent_actions: Recently recommended labels. Either a flat list of
                labels or a list of per-month label lists. These are skipped on
                the first pass.

        Returns:
            ``(labels, confidences)``. The first pass takes candidates in rank
            order, skipping recent labels, with the model's probability as
            confidence. If that yields too few, the second pass adds recent
            labels back in rank order with a reduced fixed confidence of 0.5.
            Fewer than *num_recommendations* items are returned only when the
            candidates contain fewer distinct labels.
        """
        # Accept both flat and per-month-nested histories.
        recent = set()
        for item in recent_actions or ():
            if isinstance(item, (list, tuple)):
                recent.update(item)
            else:
                recent.add(item)

        candidates = [
            (self._validate_prediction(int(action_idx), self.suggestion_classes), float(prob))
            for action_idx, prob in zip(top_actions.indices[0].tolist(),
                                        top_actions.values[0].tolist())
        ]

        labels: List[str] = []
        confidences: List[float] = []

        # First pass: best-ranked actions that were not recommended recently.
        for label, prob in candidates:
            if len(labels) >= num_recommendations:
                break
            if label not in recent and label not in labels:
                labels.append(label)
                confidences.append(prob)

        # Second pass: allow recent repeats, still in rank order. This is bounded
        # by the candidate list, so it always terminates.
        for label, _ in candidates:
            if len(labels) >= num_recommendations:
                break
            if label not in labels:
                labels.append(label)
                confidences.append(0.5)  # Lower confidence for fallback actions

        return labels, confidences

    def predict_monthly_metrics(
        self,
        initial_data: pd.DataFrame,
        num_months: int = 12,
        confidence_threshold: float = 0.7,
        num_recommendations: int = 3  # Number of recommendations to return
    ) -> Tuple[Dict[str, List], List[Dict[str, float]]]:
        """Forecast *num_months* months, starting from the current calendar month.

        Args:
            initial_data: Frame in the schema expected by ``self.processor``.
                Only its first row is forecast.
            num_months: Number of monthly steps.
            confidence_threshold: Risk confidence below which the risk level is
                reported as "Uncertain".
            num_recommendations: Maximum number of actions recommended per month.

        Returns:
            ``(predictions, confidence_metrics)``. *predictions* maps each key to
            a per-month list. ``recommended_actions`` and ``action_confidences``
            hold one list per month. *confidence_metrics* holds the per-month
            dicts from ``_calculate_confidence_metrics``.
        """
        predictions = {
            'dates': [],
            'revenue': [],
            'revenue_growth': [],
            'risk_levels': [],
            'risk_scores': [],
            'profitability': [],
            'churn_probability': [],
            'recommended_actions': [],  # One list of recommendations per month
            'action_confidences': [],   # Matching confidence list per month
        }
        confidence_metrics = []

        # Each run starts a fresh history so stability never compares against
        # the tail of a previous forecast.
        self.prediction_history = []

        start_date = pd.Timestamp.now().normalize().replace(day=1)
        current_data = initial_data.copy()

        features, _, _, _, _, _, _, _ = self.processor.process_data(current_data)
        # Keep a single row: the outputs are read with .item() and the scalar
        # metrics below come from row 0.
        X = torch.FloatTensor(features).to(self.device)[:1]

        # Inputs that stay constant across the forecast horizon.
        initial_revenue = current_data['Revenue_Generated'].values[0]
        expenses = current_data['Average_Monthly_Expenses'].values[0]
        cac = current_data['Customer_Acquisition_Cost (₦)'].values[0]
        ltv = current_data['Lifetime_Value_of_Customer (₦)'].values[0]
        market_size = current_data['Market_Size_Potential'].values[0]

        for month in range(num_months):
            current_date = start_date + pd.DateOffset(months=month)
            current_month = current_date.month
            current_quarter = (current_month - 1) // 3 + 1

            month_tensor = torch.LongTensor([current_month]).to(self.device)
            quarter_tensor = torch.LongTensor([current_quarter]).to(self.device)

            with torch.no_grad():
                outputs = self.model(X, month_tensor, quarter_tensor)

                # Stability here reflects the months recorded so far, not this one.
                confidence = self._calculate_confidence_metrics(outputs)

                action_probs = F.softmax(outputs['suggestions'], dim=1)
                k = min(10, action_probs.size(1))
                top_actions = torch.topk(action_probs, k, dim=1)

                # Avoid repeating what was recommended in the last three months.
                month_recommendations, month_confidences = self._get_unique_recommendations(
                    top_actions, num_recommendations, predictions['recommended_actions'][-3:]
                )

            # The revenue head is treated as a growth rate in percent.
            base_growth = outputs['revenue'].item()
            seasonal_factor = 1.0 + 0.1 * np.sin(2 * np.pi * current_month / 12)
            market_noise = np.random.normal(1.0, 0.02)
            revenue_growth = base_growth * seasonal_factor * market_noise

            # Month 0 reports the observed revenue; later months compound growth.
            current_revenue = (
                initial_revenue if month == 0
                else predictions['revenue'][-1] * (1 + revenue_growth / 100)
            )

            current_metrics = {
                'revenue': current_revenue,
                'growth_rate': revenue_growth,
                'expenses': expenses,
                'cac': cac,
                'ltv': ltv,
                'market_size': market_size
            }

            risk_label, risk_score = self._calculate_risk_level(
                outputs,
                confidence['risk_confidence'],
                confidence_threshold,
                current_metrics
            )

            # Reported values get multiplicative noise for variability.
            profitability = outputs['profitability'].item() * np.random.normal(1.0, 0.05)
            churn_probability = outputs['churn'].item() * np.random.normal(1.0, 0.03)

            predictions['dates'].append(current_date)
            predictions['revenue'].append(current_revenue)
            predictions['revenue_growth'].append(revenue_growth)
            predictions['risk_levels'].append(risk_label)
            predictions['risk_scores'].append(risk_score)
            predictions['profitability'].append(profitability)
            predictions['churn_probability'].append(churn_probability)
            predictions['recommended_actions'].append(month_recommendations)
            predictions['action_confidences'].append(month_confidences)

            confidence_metrics.append(confidence)
            self.prediction_history.append({
                'revenue': current_revenue,
                'profitability': profitability,
            })

            X = self._update_features(X, current_metrics, current_month)

        return predictions, confidence_metrics


if __name__ == "__main__":
    # Single-row sample in the input schema used by the processor.
    current_metrics = pd.DataFrame({
        'P&L_Statement': ['Revenue: 1000000\nCOGS: 600000\nGross Profit: 400000\n' +
                         'Operating Expenses: 200000\nOperating Income: 200000\n' +
                         'Other Income: 50000\nNet Profit: 250000'],
        'Balance_Sheet': ['Current Assets: 800000\nFixed Assets: 2000000\n' +
                         'Total Assets: 2800000\nCurrent Liabilities: 300000\n' +
                         'Long-term Liabilities: 500000\nTotal Liabilities: 800000\n' +
                         'Equity: 2000000'],
        'Monthly_Revenue_Turnover': ['January'],
        'Revenue_Generated': [1000000],
        'Customer_Turnover_Rate': ['Low'],
        'Growth_Rate (%)': [15.0],
        'Average_Monthly_Expenses': [200000],
        'Customer_Acquisition_Cost (₦)': [2000],
        'Lifetime_Value_of_Customer (₦)': [20000],
        'Market_Size_Potential': [10000],
        'Risk_Assessment': ['Low'],
        'Predicted_Revenue_Growth (%)': [20.0],
        'Profitability_Score': [75.0],
        'Churn_Rate (%)': [10.0],
        'Recommendation_Actions': ['Expand market reach']
    })

    # Load the checkpoint written by the training cell.
    processor = ImprovedBusinessDataProcessor()
    predictor = TimeSeriesBusinessPredictor('improved_business_predictor.pth', processor)

    # Single source of truth for the horizon, used by the call and the header.
    forecast_months = 12
    predictions, confidence_metrics = predictor.predict_monthly_metrics(
        initial_data=current_metrics,
        num_months=forecast_months,
        confidence_threshold=0.7
    )

    print(f"\nPredictions for the next {forecast_months} months:")
    for i, month_date in enumerate(predictions['dates']):
        print(f"\nMonth {month_date.strftime('%Y-%m')}:")
        print(f"Revenue Growth: {predictions['revenue_growth'][i]:.2f}%")
        print(f"Predicted Revenue: ₦{predictions['revenue'][i]:,.2f}")
        print(f"Risk Level: {predictions['risk_levels'][i]}")
        print(f"Profitability Score: {predictions['profitability'][i]:.2f}")
        print(f"Churn Probability: {predictions['churn_probability'][i]*100:.2f}%")
        # Each month carries a list of actions with a parallel confidence list.
        print("Recommended Actions:")
        for action, action_conf in zip(predictions['recommended_actions'][i],
                                       predictions['action_confidences'][i]):
            print(f"  - {action} (confidence: {action_conf:.2f})")