In [None]:
# Load model from pickle file
import random
import numpy as np
import matplotlib.pyplot as plt
import pickle
import copy
import sys
from micrograd.engine import Value
from micrograd.nn import Neuron, Layer, MLP
import cvxpy as cp

# Fix both random number generators so repeated runs are reproducible.
random.seed(1337)
np.random.seed(1337)

# Load the trained network. pickle can execute arbitrary code while
# unpickling, so only point this at a model file you produced yourself.
with open('model.pkl', 'rb') as model_file:
    loaded_model = pickle.load(model_file)
    
def computebounds(custom_x0, custom_x1, eps, model):
    """Evaluate ``model`` on a 2-D input box and run ``ibp()`` on the output.

    Each coordinate is wrapped in a ``Value`` whose ``lower``/``upper`` fields
    span ``[x - eps, x + eps]``. The model is applied to the two wrapped
    inputs, and ``ibp()`` is called on the resulting output node.

    Args:
        custom_x0: Centre of the box along the first input coordinate.
        custom_x1: Centre of the box along the second input coordinate.
        eps: Half-width of the box along each coordinate.
        model: Callable taking a list of two ``Value`` inputs.

    Returns:
        The output node returned by ``model``, after ``ibp()`` has run.
    """
    box = [Value(custom_x0), Value(custom_x1)]
    for node in box:
        centre = node.data
        node.lower, node.upper = centre - eps, centre + eps

    score = model(box)
    score.ibp()
    return score

# TODO: add a dict recording which ReLU nodes have already been split on.

# Collect the unstable ReLU nodes (input bounds straddle zero) in the computation graph of a score node
def collect_relu_nodes(output_node):
    """Return the unstable ReLU nodes reachable from ``output_node``.

    A ReLU node is "unstable" when its input's ``lower`` and ``upper`` bounds
    are both set and strictly straddle zero (``lower * upper < 0``). Only such
    nodes are candidates for branching.

    The graph is walked depth-first in pre-order through ``_prev``. An
    explicit stack is used instead of recursion so that deep graphs cannot
    hit Python's recursion limit. Children are pushed in reverse so nodes are
    discovered in exactly the same order as a recursive walk. Callers index
    into the result, so this order matters.

    Args:
        output_node: Root node of the computation graph.

    Returns:
        List of unstable ReLU nodes in DFS pre-order of discovery.
    """
    relu_nodes = []
    visited = set()
    stack = [output_node]
    while stack:
        v = stack.pop()
        if v in visited:
            continue  # already reached through another path
        visited.add(v)
        if v._op == 'ReLU':
            relu_input = list(v._prev)[0]
            lower, upper = relu_input.lower, relu_input.upper
            if lower is not None and upper is not None and lower * upper < 0:
                relu_nodes.append(v)
        # Reverse so the first child is popped (and fully explored) first,
        # matching the recursive visiting order.
        stack.extend(reversed(list(getattr(v, '_prev', []))))
    return relu_nodes

  
#score = computebounds(2, 0, 0.1, loaded_model)     
#relu_nodes = collect_relu_nodes(score) - debugging


def branch_and_bound(score):
    """Bound ``score`` by recursively splitting unstable ReLUs.

    If no ReLU in ``score``'s graph has input bounds straddling zero, the
    current ``score.lower`` / ``score.upper`` are returned as-is. Otherwise the
    first unstable ReLU is split into an "input >= 0" branch and an
    "input <= 0" branch. Each branch is handled by ``_explore_branch``, and
    the two results are merged into one interval covering both.

    Args:
        score: Output node whose ``lower``/``upper`` were set by ``ibp()``.

    Returns:
        ``(lower, upper)`` covering every feasible branch. An infeasible
        branch contributes ``(inf, -inf)``, so an overall ``(inf, -inf)``
        means no branch was feasible.

    Raises:
        ValueError: if a branch's Planet upper bound is negative.
    """
    relu_nodes = collect_relu_nodes(score)

    # No unstable ReLU left: nothing to split on.
    if not relu_nodes:
        return score.lower, score.upper

    # Deterministically branch on the first unstable ReLU
    # (random.choice(relu_nodes) is an alternative). Any fixed index >= 1
    # raises IndexError once fewer unstable ReLUs remain.
    relu_input = list(relu_nodes[0]._prev)[0]

    # Always explore both sides. Returning as soon as one side is verified
    # would silently drop the other side's bounds.
    lower_pos, upper_pos = _explore_branch(score, relu_input, lower=0)
    lower_neg, upper_neg = _explore_branch(score, relu_input, upper=0)

    return min(lower_pos, lower_neg), max(upper_pos, upper_neg)


def _explore_branch(score, relu_input, lower=None, upper=None):
    """Bound one side of a ReLU split, recursing when the LP is inconclusive.

    Works on a deep copy of ``score`` so the two branches do not share
    state. The copy of ``relu_input`` gets the given ``lower``/``upper``
    override and is flagged ``chosen_relu`` before ``ibp()`` is re-run.

    Returns:
        ``(inf, -inf)`` if the Planet relaxation is infeasible (empty
        branch). The branch's ``lower``/``upper`` if the relaxation's lower
        bound is >= 0. Otherwise, the result of ``branch_and_bound`` on the
        branch.

    Raises:
        ValueError: if the relaxation's upper bound is negative.
    """
    branch = copy.deepcopy(score)
    node = find_corresponding_node(branch, relu_input)
    if lower is not None:
        node.lower = lower
    if upper is not None:
        node.upper = upper
    node.chosen_relu = True  # mark as chosen for this branch
    branch.ibp()

    check_l, check_u = planet_relaxation(branch)
    # cvxpy reports an infeasible minimisation as +inf (and an infeasible
    # maximisation as -inf). Test this first: +inf would otherwise pass the
    # ``>= 0`` check below and be mistaken for a verified branch.
    if check_l == float('inf') or check_u == float('-inf'):
        return float('inf'), float('-inf')
    if check_l >= 0:
        return branch.lower, branch.upper
    if check_u < 0:
        raise ValueError("Relaxation bounds are not valid.")
    return branch_and_bound(branch)
   
def find_corresponding_node(new_score, old_node):
    """Return the node in ``new_score``'s graph whose ``id`` equals ``old_node.id``.

    Used to locate, inside a deep copy of a graph, the counterpart of a node
    from the original graph. The graph is searched depth-first through
    ``_prev``. A node reachable through several parents can be pushed more
    than once, but it is checked and expanded only on its first pop.

    Args:
        new_score: Root of the graph to search.
        old_node: Node whose ``id`` is looked for.

    Returns:
        The first node found with a matching ``id``.

    Raises:
        ValueError: if no node in the graph has that ``id``.
    """
    target_id = old_node.id
    visited = set()
    stack = [new_score]
    while stack:
        node = stack.pop()
        if node in visited:
            continue  # duplicate push via another parent; already handled
        if node.id == target_id:
            return node
        visited.add(node)
        stack.extend(child for child in node._prev if child not in visited)
    raise ValueError("Corresponding node not found")


def planet_relaxation(output: Value):
    """Return ``(min, max)`` of ``output`` over its Planet LP relaxation.

    The nodes of ``output.compute_graph()`` are visited in order. The loop
    relies on that order being topological, so every operand is encoded
    before the node that uses it. Each node is mapped to a cvxpy variable
    or a float constant:

    * leaf with ``input`` set: a variable boxed by its ``lower``/``upper``;
    * any other leaf (weight, bias, constant): its ``data``;
    * ``+``: a variable equal to the sum of the two operands;
    * ``*``: a variable equal to the product. One operand must be a constant
      so that the encoding stays linear;
    * ``ReLU`` whose input has bounds ``[l, u]``: the input is constrained
      to ``[l, u]``, which is how branch-and-bound splits reach the LP. The
      output then equals the input if ``l >= 0``, equals 0 if ``u <= 0``,
      and otherwise lies in the Planet triangle
      ``max(0, x) <= y <= u * (x - l) / (u - l)``.

    Returns:
        ``(lower, upper)`` as returned by ``Problem.solve()``. cvxpy gives
        ``+inf`` / ``-inf`` respectively when the LP is infeasible.

    Raises:
        NotImplementedError: for a product of two non-constant operands, or
            for an unsupported operation.
    """
    env = {}  # maps Value nodes to a cp.Variable or a float constant
    constraints = []
    # numpy scalars (e.g. float32 weights) are constants too, not only Python numbers.
    const_types = (int, float, np.number)

    for v in output.compute_graph():
        if len(v._prev) == 0:
            if v.input:
                # Input node: free variable restricted to its input box.
                var = cp.Variable()
                constraints += [var >= v.lower, var <= v.upper]
                env[v] = var
            else:
                # Weight / bias / constant node.
                env[v] = v.data
            continue

        operands = [env[p] for p in v._prev]

        if v._op in ("+", "*"):
            # micrograd stores children in a set, so ``x + x`` / ``x * x``
            # arrive as a single operand; duplicate it before unpacking.
            if len(operands) == 1:
                operands = operands * 2
            a, b = operands
            if v._op == "*" and not (isinstance(a, const_types) or isinstance(b, const_types)):
                raise NotImplementedError("PLANET relaxation only supports multiplication by constants.")
            var = cp.Variable()
            constraints.append(var == (a + b if v._op == "+" else a * b))
            env[v] = var

        elif v._op == "ReLU":
            inp = operands[0]
            input_node = list(v._prev)[0]
            l = input_node.lower
            u = input_node.upper
            var = cp.Variable()
            if not isinstance(inp, const_types):
                # Keep the input inside its bounds; for a split node these
                # bounds are the branch decision itself.
                constraints += [inp >= l, inp <= u]
            # The triangle's upper line is only valid when l < 0 < u. Applied
            # to a stable ReLU it over-constrains the input (for l > 0 it
            # forces x == u), which makes the LP bounds unsound.
            if l >= 0:
                constraints.append(var == inp)  # stably active: identity
            elif u <= 0:
                constraints.append(var == 0)  # stably inactive: zero
            else:
                constraints += [
                    var >= 0,
                    var >= inp,
                    var <= (u / (u - l)) * (inp - l),
                ]
            env[v] = var

        else:
            raise NotImplementedError(f"Operation {v._op} not supported in PLANET relaxation.")

    prob_lower = cp.Problem(cp.Minimize(env[output]), constraints)
    result_lower = prob_lower.solve()

    prob_upper = cp.Problem(cp.Maximize(env[output]), constraints)
    result_upper = prob_upper.solve()

    return result_lower, result_upper



# Bound the model's score over the box of half-width 0.1 around (0, 0.5),
# then tighten the bounds with branch and bound.
score = computebounds(custom_x0=0, custom_x1=0.5, eps=0.1, model=loaded_model)
global_lower, global_upper = branch_and_bound(score=score)
print(f"Global lower bound: {global_lower}, Global upper bound: {global_upper}")
