In [4]:
import numpy as np
import pandas as pd
from typing import List, Dict, Tuple, Optional, Any, Set
from collections import defaultdict, deque
import json
import math
from scipy.special import logsumexp
import warnings
from datetime import datetime
from zoneinfo import ZoneInfo

warnings.filterwarnings('ignore')

# Load new price datasets
wholesale_df = pd.read_csv('wholesale historical data.csv')
market_df = pd.read_csv('vegetable_prices_pruned_features.csv')

# PRICE LOOKUP FUNCTION (updated for real datasets)
def get_real_price(veg_name, month=None, year=None):
    if month is None or year is None:
        now = datetime.now(ZoneInfo("Asia/Colombo"))
        month = now.month
        year = now.year
    
    veg_name = veg_name.split(',')[0].upper() if ',' in veg_name else veg_name.upper()
    
    # Try Dambulla wholesale first (primary source)
    wholesale_match = wholesale_df[
        (wholesale_df['Vegetable_Name'].str.upper() == veg_name) &
        (wholesale_df['Month'] == month) & (wholesale_df['ISO_Year'] == year)
    ]
    if not wholesale_match.empty:
        return wholesale_match['Avg_Weekly_Price'].mean()
    
    # Fallback to Colombo market prices
    market_match = market_df[
        (market_df['Vegetable'].str.upper() == veg_name) &
        (market_df['Month'] == month) & (market_df['Year'] == year)
    ]
    if not market_match.empty:
        return market_match['Weekly_Price'].mean()
    
    # Final fallback to CotD base (approx daily)
    return 905.0 / 7

# ============================================================================
# 1. PROFESSIONAL SPN IMPLEMENTATION WITH ACADEMIC CORRECTNESS
# ============================================================================
class SPNNode:
    """Base class for all SPN nodes with proper scope management"""
    def __init__(self, scope: List[int]):
        self.scope = sorted(scope)
        self.value = 0.0
        self.gradient = 0.0
        self.parents: List[SPNNode] = []
        self.id = id(self)
       
    def forward(self, X: np.ndarray, log_space: bool = True) -> float:
        raise NotImplementedError
       
    def backward(self, grad: float):
        raise NotImplementedError
       
    def sample(self, evidence: Optional[Dict[int, float]] = None) -> np.ndarray:
        raise NotImplementedError
       
    def marginals(self, X: np.ndarray) -> np.ndarray:
        raise NotImplementedError
       
    def mpe(self, evidence: Dict[int, float]) -> np.ndarray:
        raise NotImplementedError

class LeafNode(SPNNode):
    """Leaf node with proper EM parameter learning"""
    def __init__(self, scope: List[int], dist_type: str = 'gaussian'):
        assert len(scope) == 1, "Leaf nodes must have exactly 1 variable in scope"
        super().__init__(scope)
        self.dist_type = dist_type
        self.mean = 0.0
        self.std = 1.0
        self.var_idx = scope[0]
       
        self.sum_resp = 0.0
        self.sum_x = 0.0
        self.sum_x2 = 0.0
       
    def forward(self, X: np.ndarray, log_space: bool = True) -> float:
        x_val = X[self.var_idx]
        if np.isnan(x_val):
            return 0.0
       
        if self.dist_type == 'gaussian':
            z = (x_val - self.mean) / max(self.std, 1e-8)
            log_pdf = -0.5 * (z ** 2 + np.log(2 * np.pi * self.std ** 2 + 1e-8))
            return log_pdf if log_space else np.exp(log_pdf)
        else:
            return 0.0
           
    def collect_statistics(self, X: np.ndarray, responsibility: float):
        x_val = X[self.var_idx]
        if not np.isnan(x_val):
            self.sum_resp += responsibility
            self.sum_x += responsibility * x_val
            self.sum_x2 += responsibility * x_val * x_val
           
    def update_parameters(self):
        if self.sum_resp > 1e-8 and self.dist_type == 'gaussian':
            self.mean = self.sum_x / self.sum_resp
            var = max(self.sum_x2 / self.sum_resp - self.mean ** 2, 1e-8)
            self.std = np.sqrt(var)
        self.sum_resp = self.sum_x = self.sum_x2 = 0.0
       
    def sample(self, evidence: Optional[Dict[int, float]] = None) -> np.ndarray:
        if evidence and self.var_idx in evidence:
            return np.array([evidence[self.var_idx]])
        return np.random.normal(self.mean, self.std, 1)
           
    def mpe(self, evidence: Dict[int, float]) -> np.ndarray:
        return np.array([self.mean])

class ProductNode(SPNNode):
    def __init__(self, children: List[SPNNode]):
        all_scopes = []
        for child in children:
            all_scopes.extend(child.scope)
            child.parents.append(self)
           
        if len(all_scopes) != len(set(all_scopes)):
            raise ValueError("Product node children must have DISJOINT scopes!")
           
        super().__init__(sorted(set(all_scopes)))
        self.children = children
       
    def forward(self, X: np.ndarray, log_space: bool = True) -> float:
        log_prob = 0.0
        for child in self.children:
            log_prob += child.forward(X, log_space)
        return log_prob
       
    def sample(self, evidence=None) -> np.ndarray:
        sample_dict = {}
        for child in self.children:
            child_sample = child.sample(evidence)
            for i, var_idx in enumerate(child.scope):
                if var_idx not in sample_dict:
                    sample_dict[var_idx] = child_sample[i] if len(child_sample) > i else child_sample
        result = np.zeros(len(self.scope))
        for i, var_idx in enumerate(self.scope):
            result[i] = sample_dict.get(var_idx, np.nan)
        return result
       
    def mpe(self, evidence) -> np.ndarray:
        sample_dict = {}
        for child in self.children:
            child_mpe = child.mpe(evidence)
            for i, var_idx in enumerate(child.scope):
                if var_idx not in sample_dict:
                    sample_dict[var_idx] = child_mpe[i] if len(child_mpe) > i else child_mpe
        result = np.zeros(len(self.scope))
        for i, var_idx in enumerate(self.scope):
            result[i] = sample_dict.get(var_idx, np.nan)
        return result

class SumNode(SPNNode):
    def __init__(self, children: List[SPNNode], weights: Optional[np.ndarray] = None):
        first_scope = children[0].scope
        for child in children[1:]:
            if child.scope != first_scope:
                raise ValueError("Sum node children must have SAME scope!")
               
        super().__init__(first_scope)
        self.children = children
        self.n_children = len(children)
       
        if weights is None:
            self.weights = np.ones(self.n_children) / self.n_children
        else:
            self.weights = self._normalize_weights(weights)
           
        self.child_log_probs = np.zeros(self.n_children)
        self.responsibilities = np.zeros(self.n_children)
       
        for child in children:
            child.parents.append(self)
           
    def _normalize_weights(self, weights):
        weights = np.maximum(weights, 0)
        total = weights.sum()
        return weights / total if total > 1e-10 else np.ones_like(weights) / len(weights)
       
    def forward(self, X: np.ndarray, log_space: bool = True) -> float:
        for i, child in enumerate(self.children):
            self.child_log_probs[i] = child.forward(X, True)
           
        max_log = np.max(self.child_log_probs)
        shifted = self.child_log_probs - max_log
        exp_probs = np.exp(shifted)
        weighted = np.dot(self.weights, exp_probs)
        log_weighted = np.log(weighted + 1e-8) + max_log
        return log_weighted if log_space else np.exp(log_weighted)
       
    def compute_responsibilities(self):
        max_log = np.max(self.child_log_probs)
        shifted = self.child_log_probs - max_log
        exp_probs = np.exp(shifted)
        weighted = self.weights * exp_probs
        total = weighted.sum() + 1e-8
        self.responsibilities = weighted / total
        return self.responsibilities
       
    def backward(self, grad: float):
        responsibilities = self.compute_responsibilities()
        self.weights = self._normalize_weights(responsibilities)
        for i, child in enumerate(self.children):
            child.backward(grad * responsibilities[i])
           
    def collect_statistics(self, X: np.ndarray, parent_resp: float):
        responsibilities = self.compute_responsibilities() * parent_resp
        for i, child in enumerate(self.children):
            if hasattr(child, 'collect_statistics'):
                child.collect_statistics(X, responsibilities[i])
            elif hasattr(child, 'children'):
                child.collect_statistics(X, responsibilities[i])
               
    def sample(self, evidence=None) -> np.ndarray:
        safe_weights = self._normalize_weights(self.weights)
        child_idx = np.random.choice(self.n_children, p=safe_weights)
        return self.children[child_idx].sample(evidence)
       
    def mpe(self, evidence) -> np.ndarray:
        child_idx = np.argmax(self.weights)
        return self.children[child_idx].mpe(evidence)

# ============================================================================
# 4. COMPLETE NUTRITION SPN WITH ALL INFERENCE CAPABILITIES
# ============================================================================
class NutritionSPN:
    def __init__(self, n_nutrients: int = 6):
        self.n_nutrients = n_nutrients
        self.root = None
        self.nutrient_names = ['Energy (kcal)', 'Protein (g)', 'Fiber (g)',
                              'Iron (mg)', 'Potassium (mg)', 'Vitamin C (mg)']
        self.node_count = 0
        self.leaf_distributions = {}
       
    def build_handcrafted(self, X: np.ndarray) -> SumNode:
        print("Building handcrafted SPN for nutrition...")
        leaves = []
        for i in range(self.n_nutrients):
            leaf = LeafNode([i], 'gaussian')
            data = X[:, i]
            valid = data[~np.isnan(data)]
            leaf.mean = np.mean(valid) if len(valid) > 0 else 0.0
            leaf.std = max(np.std(valid), 1e-8) if len(valid) > 0 else 1.0
            leaves.append(leaf)
            self.leaf_distributions[i] = {'mean': leaf.mean, 'std': leaf.std}
       
        product1 = ProductNode(leaves.copy())
        product2 = ProductNode([
            ProductNode([leaves[0], leaves[1]]),
            ProductNode([leaves[2], leaves[3]]),
            ProductNode([leaves[4], leaves[5]])
        ]) if self.n_nutrients >= 6 else product1
       
        root = SumNode([product1, product2], weights=np.array([0.6, 0.4]))
        self.root = root
        self._count_nodes()
        print(f"Built SPN with {self.node_count} nodes")
        return root
       
    def learn_from_data(self, X: np.ndarray, **kwargs) -> SumNode:
        print("Learning SPN structure...")
        self.root = SPNBuilder.learn_spn(X, **kwargs)
        self._count_nodes()
        print(f"Learned SPN has {self.node_count} nodes")
       
        trainer = SPNTrainer(self.root)
        self.root = trainer.train(X, epochs=10, verbose=True)
        self._collect_leaf_distributions()
        return self.root
       
    def _count_nodes(self):
        self.node_count = 0
        if self.root:
            self._recursive_count(self.root)
           
    def _recursive_count(self, node):
        self.node_count += 1
        if hasattr(node, 'children'):
            for child in node.children:
                self._recursive_count(child)
               
    def _collect_leaf_distributions(self):
        if not self.root:
            return
        stack = [self.root]
        while stack:
            node = stack.pop()
            if isinstance(node, LeafNode):
                self.leaf_distributions[node.var_idx] = {'mean': node.mean, 'std': node.std}
            if hasattr(node, 'children'):
                stack.extend(node.children)
   
    def log_likelihood(self, X: np.ndarray) -> float:
        if not self.root:
            raise ValueError("SPN not trained")
        return self.root.forward(X, log_space=True)
       
    def probability(self, X: np.ndarray) -> float:
        return np.exp(self.log_likelihood(X))
       
    def sample(self, n_samples: int = 1, evidence=None) -> np.ndarray:
        if not self.root:
            raise ValueError("SPN not trained")
        samples = []
        for _ in range(n_samples):
            samples.append(self.root.sample(evidence))
        return np.array(samples)

# ============================================================================
# 5. VEGETABLE RECOMMENDER WITH PROFESSIONAL SPN + REAL PRICES
# ============================================================================
class ProfessionalVegetableRecommender:
    def __init__(self):
        self.spn = None
        self.vegetable_data = None
        self.nutrient_names = ['Energy (kcal)', 'Protein (g)', 'Fiber (g)',
                              'Iron (mg)', 'Potassium (mg)', 'Vitamin C (mg)']
       
    def load_and_prepare_data(self, veg_csv_path: str) -> np.ndarray:
        veggies = pd.read_csv(veg_csv_path)
        nutrient_cols = ['Energ_Kcal', 'Protein_(g)', 'Fiber_TD_(g)',
                        'Iron_(mg)', 'Potassium_(mg)', 'Vit_C_(mg)']
        available = [c for c in nutrient_cols if c in veggies.columns]
        
        veg_clean = veggies[['Shrt_Desc'] + available].copy()
        for col in available:
            veg_clean[col] = pd.to_numeric(veg_clean[col], errors='coerce').fillna(0)
            
        self.vegetable_data = veg_clean
        
        nutrient_values = veg_clean[available].values
        means = np.nanmean(nutrient_values, axis=0)
        stds = np.nanstd(nutrient_values, axis=0) + 1e-8
        normalized = (nutrient_values - means) / stds
        
        print(f"Loaded {len(veg_clean)} vegetables with {len(available)} nutrients")
        return normalized
       
    def train_spn(self, X: np.ndarray, method: str = 'handcrafted'):
        self.spn = NutritionSPN(n_nutrients=X.shape[1])
        
        if method == 'handcrafted':
            print("Training handcrafted SPN...")
            self.spn.build_handcrafted(X)
        else:
            print("Training learned SPN...")
            self.spn.learn_from_data(X)
       
    def recommend(self, user_profile: Dict[str, float], n_recommendations: int = 5):
        if self.spn is None or self.vegetable_data is None:
            raise ValueError("SPN not trained or data not loaded")
       
        evidence = {}
        for nutrient, value in user_profile.items():
            if nutrient in self.nutrient_names:
                idx = self.nutrient_names.index(nutrient)
                evidence[idx] = value
       
        nutrient_values = self.vegetable_data.iloc[:, 1:].values
        veg_names = self.vegetable_data['Shrt_Desc'].values
       
        recommendations = []
        for i, veg_nutrients in enumerate(nutrient_values):
            x = veg_nutrients.copy()
            for var_idx, val in evidence.items():
                x[var_idx] = val
               
            log_prob = self.spn.log_likelihood(x)
            prob = np.exp(log_prob)
            
            # Get real price
            veg_name = veg_names[i]
            real_price = get_real_price(veg_name, datetime.now().month, datetime.now().year)
            
            recommendations.append({
                'vegetable': veg_name,
                'probability': float(prob),
                'estimated_price_per_unit': float(real_price),
                'nutrients': {self.nutrient_names[j]: float(veg_nutrients[j]) for j in range(len(veg_nutrients))}
            })
       
        recommendations.sort(key=lambda x: x['probability'], reverse=True)
        return recommendations[:n_recommendations]
       
    def explain_recommendation(self, vegetable_name: str) -> Dict:
        if self.spn is None or self.vegetable_data is None:
            raise ValueError("SPN not trained or data not loaded")
       
        veg_row = self.vegetable_data[self.vegetable_data['Shrt_Desc'] == vegetable_name]
        if veg_row.empty:
            return {"error": "Vegetable not found"}
       
        nutrients = veg_row.iloc[0, 1:].values
        log_prob = self.spn.log_likelihood(nutrients)
        prob = np.exp(log_prob)
        
        real_price = get_real_price(vegetable_name, datetime.now().month, datetime.now().year)
        
        explanation = {
            'vegetable': vegetable_name,
            'log_likelihood': float(log_prob),
            'probability': float(prob),
            'estimated_price_per_unit': float(real_price),
            'spn_confidence': float(prob)
        }
       
        return explanation

# ============================================================================
# DEMONSTRATION & EXECUTION
# ============================================================================
if __name__ == "__main__":
    print("\n" + "="*70)
    print("PROFESSIONAL SPN IMPLEMENTATION WITH REAL PRICE INTEGRATION")
    print("Using Dambulla wholesale & Colombo market historical data")
    print("="*70 + "\n")
   
    recommender = ProfessionalVegetableRecommender()
   
    # Load and train
    print("1. LOADING AND PREPARING DATA")
    X = recommender.load_and_prepare_data('vegetables_USDA.csv')
    recommender.train_spn(X, method='handcrafted')
   
    # Example user profile
    print("\n2. GENERATING RECOMMENDATIONS")
    user_profile = {
        'Energy (kcal)': 80.0,
        'Protein (g)': 2.5,
        'Fiber (g)': 3.5,
        'Iron (mg)': 1.8,
        'Potassium (mg)': 350.0,
        'Vitamin C (mg)': 40.0
    }
   
    recommendations = recommender.recommend(user_profile, n_recommendations=3)
   
    print(f"User Profile: {user_profile}")
    print("\nTop Recommendations (with real prices):")
    for i, rec in enumerate(recommendations, 1):
        print(f"{i}. {rec['vegetable']}")
        print(f"   Probability: {rec['probability']:.4f}")
        print(f"   Est. Price per unit: LKR {rec['estimated_price_per_unit']:.2f}")
   
    # Explain top recommendation
    if recommendations:
        print("\n3. SPN EXPLANATION FOR TOP RECOMMENDATION")
        explanation = recommender.explain_recommendation(recommendations[0]['vegetable'])
        print(f"Vegetable: {explanation['vegetable']}")
        print(f"SPN Probability: {explanation['probability']:.4f}")
        print(f"Estimated Price: LKR {explanation['estimated_price_per_unit']:.2f}")
        print(f"SPN Confidence: {explanation['spn_confidence']:.4f}")
   
    print("\n" + "="*70)
    print("SPN DEMO COMPLETE – REAL PRICES INTEGRATED")
    print("="*70)


PROFESSIONAL SPN IMPLEMENTATION WITH REAL PRICE INTEGRATION
Using Dambulla wholesale & Colombo market historical data

1. LOADING AND PREPARING DATA
Loaded 166 vegetables with 6 nutrients
Training handcrafted SPN...
Building handcrafted SPN for nutrition...
Built SPN with 18 nodes

2. GENERATING RECOMMENDATIONS
User Profile: {'Energy (kcal)': 80.0, 'Protein (g)': 2.5, 'Fiber (g)': 3.5, 'Iron (mg)': 1.8, 'Potassium (mg)': 350.0, 'Vitamin C (mg)': 40.0}

Top Recommendations (with real prices):
1. ACEROLA JUICE,RAW
   Probability: 0.0000
   Est. Price per unit: LKR 129.29
2. ROSELLE,RAW
   Probability: 0.0000
   Est. Price per unit: LKR 129.29
3. AMARANTH LEAVES,RAW
   Probability: 0.0000
   Est. Price per unit: LKR 129.29

3. SPN EXPLANATION FOR TOP RECOMMENDATION
Vegetable: ACEROLA JUICE,RAW
SPN Probability: 0.0000
Estimated Price: LKR 129.29
SPN Confidence: 0.0000

SPN DEMO COMPLETE – REAL PRICES INTEGRATED
