In [None]:
from __future__ import annotations

# stdlib
from math import prod
from typing import Any, Callable, Optional
from typing import Union

# third party
import flax
from flax.struct import dataclass
import jax
from jax import numpy as jnp
import numpy as np
from scipy.optimize import shgo

# syft
from syft.core.adp.data_subject import DataSubjectGroup
from syft.core.common.uid import UID
from syft.core.tensor.passthrough import PassthroughTensor, AcceptableSimpleType


class PhiTensor(PassthroughTensor):
    """A tensor whose values all belong to a single data subject.

    Alongside its raw values (``child``), a PhiTensor carries elementwise
    lower/upper bounds (``min_vals`` / ``max_vals``), the ``data_subject`` it
    belongs to, and a string ``id`` that keys this tensor in the ``sources``
    dict and reconstruction states of any GammaTensor built from it.
    """

    # The original tuple lacked a comma after "id", so Python concatenated
    # the adjacent literals into a single bogus slot named "idchild".
    __slots__ = (
        "id",
        "child",
        "min_vals",
        "max_vals",
        "data_subject",
    )

    def reconstruct(self, state):
        """Return this tensor's value from ``state`` (a dict keyed by tensor id)."""
        return state[self.id]

    def __init__(
        self,
        child: np.ndarray,
        min_vals: np.ndarray,
        max_vals: np.ndarray,
        data_subject: Optional[UID] = None,
    ) -> None:
        """Wrap ``child`` with its bounds and owning data subject.

        Args:
            child: The tensor values.
            min_vals: Lower bound for the values of ``child``.
            max_vals: Upper bound for the values of ``child``.
            data_subject: Owner of the data; a fresh ``UID()`` when omitted.
        """
        super().__init__(child)
        self.min_vals = min_vals
        self.max_vals = max_vals
        # Compare with None explicitly: only a missing subject gets a new UID.
        self.data_subject = data_subject if data_subject is not None else UID()
        # String id, used as this tensor's key in GammaTensor sources/states.
        self.id = str(UID())

    def __add__(
        self, other: Union[GammaTensor, PhiTensor, AcceptableSimpleType]
    ) -> Union[GammaTensor, PhiTensor]:
        """Return ``self + other``.

        * PhiTensor, same data subject -> PhiTensor with values and bounds summed.
        * PhiTensor, different data subject -> GammaTensor whose sources are
          both operands and whose ``func`` adds their reconstructed values.
        * GammaTensor -> delegated to ``other.__radd__(self)``.
        * Anything else (a public scalar/array) -> PhiTensor with the values
          and both bounds shifted by ``other``.
        """
        if isinstance(other, GammaTensor):
            return other.__radd__(self)

        if isinstance(other, PhiTensor):
            if self.data_subject == other.data_subject:
                return PhiTensor(
                    child=self.child + other.child,
                    min_vals=self.min_vals + other.min_vals,
                    max_vals=self.max_vals + other.max_vals,
                    data_subject=self.data_subject,
                )
            return GammaTensor(
                child=self.child + other.child,
                sources={self.id: self, other.id: other},
                data_subjects=DataSubjectGroup([self.data_subject, other.data_subject]),
                func=lambda state: jnp.add(state[self.id], state[other.id]),
                is_linear=True,
            )

        # Adding a public value shifts every element, so the bounds must move
        # with it; leaving them unchanged would misstate the value range.
        return PhiTensor(
            child=self.child + other,
            min_vals=self.min_vals + other,
            max_vals=self.max_vals + other,
            data_subject=self.data_subject,
        )


@dataclass
class GammaTensor:
    """A tensor computed from one or more source PhiTensors.

    ``child`` holds the computed values. ``func`` recomputes them from a
    *state*: a dict mapping each source PhiTensor's ``id`` to an array.
    ``sources`` maps those same ids to the PhiTensors themselves, whose
    ``min_vals``/``max_vals`` bound the state. Fields declared with
    ``pytree_node=False`` are static metadata rather than pytree leaves.
    """

    child: jnp.ndarray
    # DataSubjectGroup (see PhiTensor.__add__) or None.
    data_subjects: Any = flax.struct.field(pytree_node=False)
    func: Callable = flax.struct.field(pytree_node=False)

    is_linear: bool = False
    # String id, consistent with PhiTensor.id (the annotation already said str).
    id: str = flax.struct.field(
        pytree_node=False, default_factory=lambda: str(UID())
    )
    sources: dict = flax.struct.field(pytree_node=False, default_factory=dict)

    def reconstruct(self, state):
        """Evaluate ``func`` on ``state`` (a dict of source id -> array)."""
        return self.func(state)

    def __add__(self, other: Any) -> GammaTensor:
        """Return a GammaTensor representing ``self + other``.

        ``other`` may be a PhiTensor, another GammaTensor, or a public value
        (scalar/array) with no sources. The sources of both operands are
        merged so the result can be reconstructed from a single state dict.
        """
        # The original read the nonexistent ``self.state``.
        sources = dict(self.sources)

        if isinstance(other, PhiTensor):
            sources[other.id] = other
            other_child = other.child
            other_reconstruct = other.reconstruct
            other_is_linear = True  # a source leaf is the identity map
            # TODO(review): merge self.data_subjects with other.data_subject.
            data_subjects = None
        elif isinstance(other, GammaTensor):
            sources.update(other.sources)
            other_child = other.child
            other_reconstruct = other.reconstruct
            other_is_linear = other.is_linear
            # TODO(review): merge the two data-subject groups.
            data_subjects = None
        else:
            # Public value: no new sources, no new data subjects, and it is
            # constant with respect to the state.
            other_child = other
            other_is_linear = True
            data_subjects = self.data_subjects

            def other_reconstruct(state):
                return other

        def add_fn(state):
            return jnp.add(self.reconstruct(state), other_reconstruct(state))

        return GammaTensor(
            # Use the operand's values, not the PhiTensor/GammaTensor object.
            child=self.child + other_child,
            data_subjects=data_subjects,
            func=add_fn,
            sources=sources,
            # A sum is linear only if both operands are.
            is_linear=self.is_linear and other_is_linear,
        )

    def __radd__(self, other) -> GammaTensor:
        """Addition is commutative, so ``other + self`` delegates to ``__add__``."""
        return self.__add__(other)

    def lipschitz_bound(self, *, sampling_method: str = "simplicial") -> float:
        """Estimate a Lipschitz bound of ``sum(func(state))`` over the source bounds.

        The largest absolute partial derivative of the summed output is
        maximised with scipy's ``shgo`` global optimiser. Every source element
        is searched within its ``[min_vals, max_vals]`` range.

        Args:
            sampling_method: Forwarded to ``shgo``. The default "simplicial"
                samples ``2**dim + 1`` points, where ``dim`` is the total number
                of source elements. It is therefore only practical for small
                inputs; "sobol" scales to larger ones.

        Returns:
            The largest absolute gradient entry found by the optimiser.

        Raises:
            ValueError: if this tensor has no sources to optimise over.
        """
        tensors = list(self.sources.values())
        if not tensors:
            raise ValueError("lipschitz_bound requires at least one source tensor")

        # Fixed flattening order for the optimisation vector; bounds follow it.
        shapes = [(tensor.id, tuple(tensor.shape)) for tensor in tensors]
        # Broadcast so scalar min/max values yield one bound per element.
        bounds = [
            bound
            for tensor in tensors
            for bound in zip(
                np.broadcast_to(np.asarray(tensor.min_vals), tensor.shape).ravel(),
                np.broadcast_to(np.asarray(tensor.max_vals), tensor.shape).ravel(),
            )
        ]

        def unflatten(flat_state):
            """Split the flat optimiser vector into a {tensor id: array} state."""
            state = {}
            offset = 0
            for tensor_id, shape in shapes:
                size = prod(shape)
                state[tensor_id] = np.reshape(flat_state[offset:offset + size], shape)
                offset += size
            return state

        # Differentiate first, then compile the whole gradient computation.
        grad_fn = jax.jit(jax.grad(lambda state: jnp.sum(self.func(state))))

        def negated_max_abs_gradient(flat_state):
            grads = grad_fn(unflatten(flat_state))
            # Reduce each leaf on its own: sources may differ in shape, which
            # makes stacking them into one array fail. A Lipschitz constant is
            # a magnitude, so take |grad| (the signed max under-reports
            # negative slopes).
            largest = max(
                float(jnp.max(jnp.abs(leaf)))
                for leaf in jax.tree_util.tree_leaves(grads)
            )
            # shgo minimises, so negate to search for the maximum.
            return -largest

        result = shgo(
            negated_max_abs_gradient, bounds=bounds, sampling_method=sampling_method
        )
        return float(-result.fun)

# Demo: two PhiTensors with different (auto-generated) data subjects are added,
# which yields a GammaTensor whose Lipschitz bound is then estimated.
# lipschitz_bound uses shgo's "simplicial" sampling, which evaluates
# 2**dim + 1 points with dim = total number of source elements (2 * len(N)
# here). N is kept small for that reason: with 20 elements per tensor
# (dim = 40) the search is computationally infeasible.
if __name__ == "__main__":
    N = list(range(5))
    values = jnp.array(N)
    # The bounds must contain every value; the data start at 0, not 1.
    lower_bounds = jnp.full(values.shape, 0)
    upper_bounds = jnp.full(values.shape, 1000)

    phi_tensor_1 = PhiTensor(values, min_vals=lower_bounds, max_vals=upper_bounds)
    phi_tensor_2 = PhiTensor(values, min_vals=lower_bounds, max_vals=upper_bounds)

    gamma1 = phi_tensor_1 + phi_tensor_2
    print(gamma1.lipschitz_bound())
