# In [1]:
from sympy import Q, ask, oo, I, S
import sympy as sp
import random
from ramanujantools import Position
import multiprocessing
from sympy.abc import n
import csv
import math
import pickle
from sympy.matrices.dense import Matrix
from itertools import product


def save_csv(data: list[tuple[Position, float]], filename: str) -> None:
    """Write ``(position, delta)`` pairs to ``filename`` as a CSV table.

    The header is the sorted coordinate names taken from the first
    position's ``to_dict()`` keys, followed by a ``delta`` column.  Every
    data row lists that position's coordinate values in header order and
    then its delta.

    Args:
        data: Non-empty list of ``(position, delta)`` pairs.
        filename: Path of the CSV file to create or overwrite.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if not data:
        raise ValueError("No data to save")

    first_position = data[0][0]
    header = sorted(first_position.to_dict().keys())

    with open(filename, "w", newline="") as out:
        table = csv.writer(out)
        table.writerow(header + ["delta"])

        for position, value in data:
            coords = position.to_dict()
            # Coordinates are looked up by header name so every row shares
            # the column order fixed by the first position.
            record = [coords[column] for column in header]
            record.append(value)
            table.writerow(record)


def save_pickle(data, filename: str) -> None:
    """Pickle ``data`` into the file named ``filename`` plus ``.pkl``.

    Args:
        data: Any picklable object.
        filename: Target path *without* the ``.pkl`` extension; the
            extension is always appended.
    """
    target = filename+'.pkl'
    with open(target, "wb") as sink:
        pickle.dump(data, sink)



def normalized_depth(traj, depth):
    """Return the walk depth to use for ``traj``.

    Step-size normalisation (``round(depth / average |step|)``) is currently
    disabled, so ``depth`` is returned unchanged.  The average step size is
    therefore no longer computed: it was never used, and computing it raised
    ``ZeroDivisionError`` for an empty trajectory.

    Args:
        traj: Trajectory the depth is meant for (unused while normalisation
            is disabled; kept so callers do not change).
        depth: Requested walk depth.

    Returns:
        ``depth`` unchanged.
    """
    return depth

def delta(estemated, lim):
    """Return ``-1 - log|estemated - lim| / log(denom(estemated))``.

    Args:
        estemated: Approximation of ``lim``; its denominator (as given by
            ``sympy.denom``) is the base of the comparison.
        lim: The value being approximated.

    Returns:
        The expression above after ``evalf()``.  No special-casing is done
        here for a zero error or a denominator of 1; callers are expected to
        check the result themselves.
    """
    approx_error = sp.Abs(estemated - lim)
    log_error = sp.log(approx_error)
    log_denominator = sp.log(sp.denom(estemated))
    return (-1 - log_error / log_denominator).evalf()



def unit_steps(p: int, q: int):
    """Build every single-coordinate unit step over the x and y symbols.

    The coordinates are the symbols ``x0 .. x{p-1}`` followed by
    ``y0 .. y{q-1}``.  For each coordinate, in that order, a ``+1`` step and
    then a ``-1`` step is produced; all other coordinates of the step are 0.

    Returns:
        A list of ``Position`` objects, two per coordinate.
    """
    axes = [*sp.symbols(f'x0:{p}'), *sp.symbols(f'y0:{q}')]

    moves = []
    for axis in axes:
        for sign in (1, -1):
            # Start from the all-zero step and set only the moving axis.
            offsets = dict.fromkeys(axes, 0)
            offsets[axis] = sign
            moves.append(Position(offsets))
    return moves


def next_steps(p,q,traj,oldPosList):
    """Return the unit-step neighbours of ``traj`` not in ``oldPosList``.

    Each step from ``unit_steps(p, q)`` is added to ``traj``; the resulting
    positions are kept, in step order, unless they already appear in
    ``oldPosList``.
    """
    candidates = (traj + move for move in unit_steps(p, q))
    return [cand for cand in candidates if cand not in oldPosList]


def trajDelta(init, trajectory, ccmf, p, q, expr, trueConst, depth, anchMat):
    """Score one trajectory by the delta of the estimate it produces.

    The trajectory matrix ``ccmf.trajectory_matrix(trajectory, init)`` is
    walked via ``.walk({n: 1}, depth, {n: 0})`` and reduced to a column
    ``col1`` in one of three ways, selected by ``anchMat``:

    * ``0`` -- first column of the inverse-transposed walk, scaled so that
      its ``[0, 0]`` entry is 1;
    * the 2x2 identity matrix -- the walk is left-multiplied by
      ``[[0, 1], [-2, 2]]``, scaled by its ``[1, 0]`` entry, and ``col1`` is
      ``[0, <top-left entry of the scaled matrix>]``;
    * anything else -- first column of ``(walk.inv() * anchMat).T``, scaled
      so that its ``[0, 0]`` entry is 1.

    The symbols ``c1, c2, ...`` in ``expr`` are then replaced by
    ``col1[1], col1[2], ...`` and the result is compared with ``trueConst``
    through ``delta``.

    Args:
        init: Start point passed to ``ccmf.trajectory_matrix``.
        trajectory: Trajectory to evaluate.
        ccmf: Object providing ``trajectory_matrix``.
        p: Number of ``c`` symbols minus one (``c0 .. cp`` are created).
        q: Unused here; kept for interface compatibility.
        expr: Expression in the ``c`` symbols giving the estimate.
        trueConst: Value the estimate is compared against.
        depth: Walk depth.
        anchMat: ``0``, the 2x2 identity, or an anchoring matrix (see above).

    Returns:
        ``{"trajectory": trajectory, "delta": value}`` where ``value`` is the
        computed delta, or ``-2`` when the computation raised an ordinary
        exception or the delta could not be shown to be finite.  The
        trajectory and the value are also printed.
    """
    # Sentinel score for trajectories that cannot be evaluated.
    failed_delta = -2
    c = sp.symbols(f'c:{p+1}')

    try:
        trajMat = ccmf.trajectory_matrix(trajectory, init)
        if anchMat == 0:
            walk = trajMat.walk({n: 1}, depth, {n: 0}).inv().T
            walk = walk / walk[0, 0]
            col1 = walk.col(0)
        elif anchMat == Matrix([[1, 0], [0, 1]]):
            walk = trajMat.walk({n: 1}, depth, {n: 0})
            normalized = Matrix([[0, 1], [-2, 2]]) * walk
            col1 = [0, (normalized / normalized[1, 0]).col(0)[0]]
        else:
            walked = trajMat.walk({n: 1}, depth, {n: 0})
            walked = (walked.inv() * anchMat).T
            walked = walked / walked[0, 0]
            col1 = walked.col(0)

        # Index 0 is skipped on purpose; the bound guards against col1 and c
        # having different lengths.
        subs_dict = {
            c[i]: col1[i]
            for i in range(1, min(len(col1), len(c)))
        }
        estemated = expr.subs(subs_dict)
        deltaCurr = delta(estemated, trueConst)

        # ``ask`` may return None (unknown); that is treated as not finite.
        if not ask(Q.finite(deltaCurr)):
            deltaCurr = failed_delta
    except Exception:
        # Best effort: a trajectory that cannot be evaluated gets the
        # sentinel score.  Only ``Exception`` is caught so KeyboardInterrupt
        # and SystemExit still stop the (possibly pooled) worker.
        deltaCurr = failed_delta

    print(trajectory)
    print(deltaCurr)
    return {"trajectory": trajectory, "delta": deltaCurr}


def _err(exc):
    """Print the full traceback of ``exc`` to standard error.

    Used as the ``error_callback`` of asynchronous pool tasks.
    """
    from sys import stderr
    from traceback import print_exception

    exc_type = type(exc)
    tb = exc.__traceback__
    print_exception(exc_type, exc, tb, file=stderr)



def search_around_traj(init, nextTrajList, sCmf, p, q, expr, trueConst, depth,cores, anchMat=0):
    """Evaluate ``trajDelta`` for every trajectory in ``nextTrajList`` in parallel.

    One asynchronous task per trajectory is submitted to a process pool of
    ``cores`` workers; task errors are reported through ``_err``.

    Returns:
        The ``trajDelta`` results, in the order of ``nextTrajList``.
    """
    with multiprocessing.Pool(processes=cores) as workers:
        pending = []
        for candidate in nextTrajList:
            task_args = (init, candidate, sCmf, p, q, expr, trueConst, depth, anchMat)
            pending.append(
                workers.apply_async(trajDelta, args=task_args, error_callback=_err)
            )
        # Results are collected before leaving the ``with`` block, because
        # exiting the pool's context manager terminates the workers.
        outcomes = [task.get() for task in pending]
    return outcomes


def update_old_list_neighs(oldPosList, neighs):
    """Append ``neighs`` to the visited history, keeping the newest 70 entries.

    Neither input list is modified; a new list is returned.
    """
    capacity = 14*5
    merged = oldPosList+neighs
    excess = len(merged) - capacity
    if excess > 0:
        # Drop the oldest entries from the front.
        return merged[excess:]
    return merged


def update_old_list(oldPosList, curTraj):
    """Append ``curTraj`` to the visited history, keeping the newest 10 entries.

    The result is always trimmed down to the cap, even when ``oldPosList``
    already holds more than 10 entries (previously only a single entry was
    dropped, so an oversized history stayed oversized).  This matches the
    trimming done by ``update_old_list_neighs``.

    Args:
        oldPosList: Current history, oldest entry first.  Not modified.
        curTraj: Entry to append.

    Returns:
        A new list containing at most 10 entries, newest last.
    """
    listSize = 10
    retList = oldPosList+[curTraj]
    overflow = len(retList) - listSize
    if overflow > 0:
        # Drop as many of the oldest entries as needed to respect the cap.
        retList = retList[overflow:]
    return retList


def get_temp(T0, k, type='linear'):
    """Return the annealing temperature at step ``k``.

    Args:
        T0: Initial temperature.
        k: Step counter.
        type: Cooling schedule, ``'linear'`` for ``T0 / (k + 1)`` or
            ``'log'`` for ``T0 / log(k + 1)``.  (The name shadows the
            builtin but is kept because callers pass it by keyword.)

    Returns:
        The temperature for the chosen schedule.

    Raises:
        ValueError: If ``type`` is not a known schedule.  Previously this
            fell through and returned ``None``.
        ZeroDivisionError: For the ``'log'`` schedule at ``k == 0``
            (``log(1) == 0``) and the ``'linear'`` schedule at ``k == -1``.
    """
    if type == 'log':
        return T0/math.log(k + 1)
    if type == 'linear':
        return T0/(k + 1)
    raise ValueError(
        f"unknown cooling schedule {type!r}; expected 'linear' or 'log'"
    )



def search_traj_sa(initPoint, traj, iterations, maxRes, ccmf , p, q, expr,
                   trueConst,depth,
                   cores,
                   T0=1,
                   Tmin=1e-3,
                   anchMat=0):
    """Simulated-annealing search for trajectories with a high delta.

    Starting from ``traj``, each round evaluates all unit-step neighbours of
    the current trajectory that are not in the recent-positions list, then
    looks at the best-scoring neighbour:

    * if its delta is at least the current delta, move there;
    * otherwise move there with probability
      ``exp((newDelta - curDelta) / T)``;
    * if the move is rejected, the current trajectory is doubled instead.

    Only accepted moves consume an iteration and lower the temperature
    (linear schedule).  The loop ends when the iterations are used up, when
    ``T`` is no longer above ``Tmin``, or when no unvisited neighbour is
    left (previously that last case crashed with ``IndexError`` and lost the
    collected results).

    Args:
        initPoint: Start point forwarded to ``trajDelta``.
        traj: Initial trajectory.
        iterations: Maximum number of accepted moves.
        maxRes: Currently unused.
        ccmf, p, q, expr, trueConst, depth, anchMat: Forwarded to the
            trajectory evaluation.
        cores: Number of worker processes per neighbourhood evaluation.
        T0: Initial temperature.
        Tmin: Temperature at or below which the search stops.

    Returns:
        The result dicts (``"trajectory"``/``"delta"``) of every evaluated
        neighbour, in evaluation order.  The starting trajectory's own
        result is not included.
    """
    curTraj = traj
    # Scale factor used only when printing; it is never changed at present.
    trajMul = 1
    T = T0
    iterLeft = iterations
    data = []
    curDelta = trajDelta(initPoint, traj, ccmf, p, q, expr, trueConst, depth, anchMat)['delta']
    bestRes = curDelta
    oldPosList = [traj]

    # NOTE(review): a rejected move changes neither ``iterLeft`` nor ``T``,
    # so the loop relies on a move eventually being accepted.  A give-up
    # rule based on ``maxRes`` existed only as commented-out code -- confirm
    # whether it should be restored.
    while iterLeft and T > Tmin:
        neighs = next_steps(p, q, curTraj, oldPosList)
        if not neighs:
            # Every neighbour is in the recent-positions list; there is
            # nothing left to evaluate, so keep what was collected so far.
            print("no unvisited neighbours around: ", curTraj)
            break

        neighsResults = search_around_traj(initPoint, neighs, ccmf, p, q, expr, trueConst, depth, cores, anchMat)
        data.extend(neighsResults)

        oldPosList = update_old_list_neighs(oldPosList, neighs+[curTraj])

        # First neighbour with the highest delta (ties keep evaluation order).
        best = max(neighsResults, key=lambda d: d['delta'])
        candidate = best['trajectory']
        newDelta = best['delta']

        if newDelta >= curDelta:
            curTraj = candidate
            curDelta = newDelta
            iterLeft -= 1
            print("better delta at: ",trajMul*curTraj)
        else:
            # Move to a worse delta with a temperature-dependent probability.
            prob = math.exp((newDelta - curDelta) / T )
            print(f"transition probability to {candidate} is {prob}")
            if random.random() < prob:
                curTraj = candidate
                curDelta = newDelta
                iterLeft -= 1
                print("jumped to: ",trajMul*curTraj)
            else:
                # Rejected: keep the direction but double the step sizes.
                curTraj = curTraj*2

        if curDelta > bestRes:
            bestRes = curDelta

        print("best delta", bestRes)
        print("cur delta", curDelta)

        T = get_temp(T0, iterations-iterLeft, type='linear')
        print(T)

    return data



def iter_positions_xy(p, q, n):
    """Yield every ``Position`` on the integer grid ``[-n, n]`` per coordinate.

    The coordinates are the symbols ``x0 .. x{p-1}`` followed by
    ``y0 .. y{q-1}``; a non-positive ``p`` or ``q`` contributes no
    coordinates.  Positions are produced in ``itertools.product`` order.
    """
    x_axes = sp.symbols(f'x0:{p}') if p > 0 else []
    y_axes = sp.symbols(f'y0:{q}') if q > 0 else []
    axes = [*x_axes, *y_axes]

    span = range(-n, n+1)
    for combo in product(span, repeat=len(axes)):
        yield Position(dict(zip(axes, combo)))




