In [None]:
# !Convert*stats.py*w*sh*

In [None]:
import time
import statistics
import gc

import numpy

from memory_profiler import memory_usage

In [None]:
def relative_error(u, v):
    """Return ``||u - v|| / ||u||``, the error of ``v`` relative to reference ``u``.

    Both arguments go through ``numpy.asarray`` so plain (nested) lists work as
    well as arrays. The norm is ``numpy.linalg.norm``'s default (2-norm for
    vectors, Frobenius norm for 2-D arrays).

    If the reference ``u`` has zero norm the ratio is undefined. Plain division
    would give ``nan``/``inf`` with a RuntimeWarning; instead this returns
    ``0.0`` when ``v`` equals ``u`` exactly and ``inf`` otherwise.
    """
    u = numpy.asarray(u)
    v = numpy.asarray(v)
    diff_norm = numpy.linalg.norm(u - v)
    ref_norm = numpy.linalg.norm(u)
    if ref_norm == 0:
        return 0.0 if diff_norm == 0 else float("inf")
    return diff_norm / ref_norm

In [None]:
# Merge mode per statistic, consumed by merge_stats:
#   "first"            -> value from the first run
#   "list"             -> list of every run's value
#   "mean+-stdev"      -> [mean, sample stdev]
#   "mean+-stdev~list" -> [mean, sample stdev, list of every run's value]
# Key order matters: it is the order in which merged results are built.
general_merge_config = dict(
    prob="first",
    title="first",
    name="first",
    size="list",
    time="mean+-stdev~list",
    setup="mean+-stdev",
    solve="mean+-stdev",
    memory="mean+-stdev",
    vars="list",
    iters="list",
    loss="mean+-stdev",
    check="mean+-stdev~list",
    neg="mean+-stdev~list",
    error_mu="mean+-stdev~list",
    error_nu="mean+-stdev~list",
    error_sx="mean+-stdev",
    error_objx="mean+-stdev~list",
)

In [None]:
# Display label and str.format template for each merged statistic, consumed by
# format_output. Each template receives the merged value as its only argument.
# For "mean+-stdev" style values, "{0[0]}" is the mean, "{0[1]}" the stdev and
# "{0[2]}" the per-run list.
general_output_config = {
    key: [label, template]
    for key, label, template in (
        ("prob", "Problem", "{}"),
        ("title", "Title", "{}"),
        ("name", "Function name", "{}"),
        ("size", "Problem size", "{}"),
        ("time", "Time", "{0[0]:.5f}+-{0[1]:.5f}~{0[2]}"),
        ("setup", "Setup time", "{0[0]:.5f}+-{0[1]:.5f}"),
        ("solve", "Solve time", "{0[0]:.5f}+-{0[1]:.5f}"),
        ("memory", "Memory usage", "{0[0]:.5f}+-{0[1]:.5f}"),
        ("vars", "Variables", "{}"),
        ("iters", "Average iterations", "{}"),
        ("loss", "Loss", "{0[0]:.7e}+-{0[1]:.7e}"),
        ("check", "Check loss", "{0[0]:.7e}+-{0[1]:.7e}~{0[2]}"),
        ("neg", "Negative part of solution", "{0[0]:.7e}+-{0[1]:.7e}~{0[2]}"),
        ("error_mu", "Normalized error of mu", "{0[0]:.7e}+-{0[1]:.7e}~{0[2]}"),
        ("error_nu", "Normalized error of nu", "{0[0]:.7e}+-{0[1]:.7e}~{0[2]}"),
        ("error_sx", "Relative error to known solution", "{0[0]:.7e}+-{0[1]:.7e}"),
        ("error_objx", "Relative error to known Wasserstein distance", "{0[0]:.7e}+-{0[1]:.7e}~{0[2]}"),
    )
}

In [None]:
def _stdev_or_zero(values):
    """Sample standard deviation of ``values``, or ``0.`` for fewer than two values."""
    # statistics.stdev needs at least two data points; one run has no spread.
    return statistics.stdev(values) if len(values) > 1 else 0.


def merge_stats(stats, config):
    """Merge a list of per-run statistic dicts into one summary dict.

    Parameters
    ----------
    stats : list of dict
        One dict per run (e.g. one per test problem).
    config : dict
        Maps each statistic key to a merge mode:

        - ``"mean"``: arithmetic mean of the values.
        - ``"stdev"``: sample standard deviation (``0.`` for a single run).
        - ``"mean+-stdev"``: ``[mean, stdev]``.
        - ``"mean+-stdev~list"``: ``[mean, stdev, list_of_values]``.
        - ``"first"``: the value from the first run.
        - ``"list"``: list of all values.
        - ``"set"``: set of all values.

        Keys with any other mode are ignored.

    Returns
    -------
    dict
        Merged values, inserted in ``config`` order. A key is skipped unless
        every run recorded it; an empty ``stats`` yields ``{}``.
    """
    merged = {}
    if not stats:
        return merged
    for key, mode in config.items():
        # Optional statistics (e.g. "error_sx", only set when a reference
        # solution exists) may be missing from some runs; skip them rather
        # than raising KeyError partway through the merge.
        if not all(key in s for s in stats):
            continue
        values = [s[key] for s in stats]
        if mode == "mean":
            merged[key] = statistics.mean(values)
        elif mode == "stdev":
            merged[key] = _stdev_or_zero(values)
        elif mode == "mean+-stdev":
            merged[key] = [statistics.mean(values), _stdev_or_zero(values)]
        elif mode == "mean+-stdev~list":
            merged[key] = [statistics.mean(values), _stdev_or_zero(values), values]
        elif mode == "first":
            merged[key] = values[0]
        elif mode == "list":
            merged[key] = values
        elif mode == "set":
            merged[key] = set(values)
    return merged

In [None]:
def format_output(res, config, log):
    """Emit one ``"<label>: <formatted value>"`` line per configured key present in ``res``.

    ``config`` maps a key to ``[label, template]``; ``template.format`` is called
    with ``res[key]`` as its single positional argument. Lines go to ``log`` in
    ``config`` order, and keys absent from ``res`` are skipped silently.
    """
    for key, spec in config.items():
        if key not in res:
            continue
        label, template = spec[0], spec[1]
        log("{0}: {1}".format(label, template.format(res[key])))

In [None]:
class Statistics(object):
    """Run a solver over a fixed list of problems and collect per-run statistics.

    Problem objects are used through these members: ``c``, ``s``, ``mu``, ``nu``,
    ``sx``, ``objx`` attributes and ``set_sx()``, ``set_objx()``, ``clean()``
    methods. Solver functions are called as ``func(prob, *args, **kwargs)``. When
    statistics are wanted they are called with ``stat=True`` and must then return
    ``(prob, stat_dict)``.

    Attributes
    ----------
    stats : list of list of dict
        Raw per-problem statistics, one inner list per call to :meth:`test`.
    ress : list of dict
        Statistics merged with ``merge_config``, one per call to :meth:`test`.
    """

    def __init__(
        self,
        probs=None,
        merge_config=general_merge_config,
        output_config=general_output_config,
        prob="",
        log=print,
    ):
        # The default ``probs=None`` used to crash on ``len(None)``; treat it
        # as "no problems". A given list is stored by reference, not copied.
        if probs is None:
            probs = []
        self.len = len(probs)
        self.probs = probs
        self.merge_config = merge_config
        self.output_config = output_config
        self.prob = prob  # label stored under "prob" in every stat dict
        self.log = log  # callable used for all progress/output messages
        self.stats = []
        self.ress = []

    def _solve_each_and_mark(self, func, mark, prog, args, kwargs):
        """Solve every problem with ``func``, then call ``mark(prob)`` and ``prob.clean()``."""
        for i in range(self.len):
            if prog:
                self.log("Setting {0}/{1}".format(i, self.len))
            prob = self.probs[i]
            func(prob, *args, **kwargs)
            mark(prob)
            prob.clean()

    def set_sx(self, func, prog=False, *args, **kwargs):
        """Solve every problem with ``func`` and then call ``prob.set_sx()`` on it.

        Presumably ``set_sx`` stores the current solution as the reference
        solution ``sx`` that :meth:`test_piece` compares against; confirm
        against the problem class. ``prog`` is only checked for truthiness.
        """
        self._solve_each_and_mark(func, lambda p: p.set_sx(), prog, args, kwargs)

    def set_objx(self, func, prog=False, *args, **kwargs):
        """Solve every problem with ``func`` and then call ``prob.set_objx()`` on it.

        Presumably ``set_objx`` stores the current objective as the reference
        value ``objx`` that :meth:`test_piece` compares against; confirm
        against the problem class. ``prog`` is only checked for truthiness.
        """
        self._solve_each_and_mark(func, lambda p: p.set_objx(), prog, args, kwargs)

    def test_piece(self, prob, func, title="", memory=False, clean=True, *args, **kwargs):
        """Run ``func`` on one problem and return its statistics dict.

        ``func(prob, *args, stat=True, **kwargs)`` must return ``(prob, stat)``
        with ``stat`` a dict. This method adds these keys to it:

        - always: "prob", "name", "time", "check", "neg", "error_mu", "error_nu";
        - "title" if ``title`` is non-empty;
        - "memory" if ``memory`` is true;
        - "error_sx" if ``prob.sx`` is set;
        - "error_objx" if ``prob.objx`` is set.

        "time" is the wall-clock duration of the ``func`` call alone. "memory"
        is the peak process memory reported by ``memory_usage`` while running
        minus the memory measured just before. If ``clean`` is true,
        ``prob.clean()`` is called before returning.
        """
        stat = None
        elapsed_time = None

        def run():
            nonlocal prob, stat, elapsed_time
            # Time only the solver call itself, so any sampling set up by
            # memory_usage() around run() is not counted as solver time.
            started = time.perf_counter()
            prob, stat = func(prob, *args, stat=True, **kwargs)
            elapsed_time = time.perf_counter() - started

        gc.collect()

        if memory:
            start_mem = max(memory_usage())  # baseline before the solver runs
            peak_mem = max(memory_usage(run))
        else:
            run()

        s = prob.s
        check_loss = numpy.sum(s * prob.c)
        neg = numpy.minimum(s, 0.).sum()
        # L1 mismatch between the solution's row/column sums and mu/nu.
        error_mu = numpy.linalg.norm(s.sum(axis=1) - prob.mu, 1)
        error_nu = numpy.linalg.norm(s.sum(axis=0) - prob.nu, 1)

        stat["prob"] = self.prob
        if title != "":
            stat["title"] = title
        stat["name"] = func.__name__
        stat["time"] = elapsed_time
        if memory:
            stat["memory"] = peak_mem - start_mem
        stat["check"] = check_loss
        stat["neg"] = neg
        stat["error_mu"] = error_mu
        stat["error_nu"] = error_nu
        if prob.sx is not None:
            stat["error_sx"] = relative_error(prob.sx, s)
        if prob.objx is not None:
            # abs() in the denominator keeps the relative error non-negative
            # even when the reference objective is negative.
            stat["error_objx"] = abs(check_loss - prob.objx) / abs(prob.objx)

        if clean:
            prob.clean()
        return stat

    def test(self, func, title="", memory=False, prog=None, *args, **kwargs):
        """Run :meth:`test_piece` on every problem, then store raw and merged stats.

        Extra ``args``/``kwargs`` are forwarded to ``func``. A ``clean`` keyword
        is consumed here and forwarded to :meth:`test_piece` (default ``True``).
        ``prog`` is only checked for truthiness.
        """
        # test_piece declares title/memory/clean before *args, so pass them
        # positionally; otherwise extra positional args would collide with them.
        clean = kwargs.pop("clean", True)
        ss = []
        for i in range(self.len):
            if prog:
                self.log("Testing {0}/{1}".format(i, self.len))
            ss.append(self.test_piece(self.probs[i], func, title, memory, clean, *args, **kwargs))
        r = merge_stats(ss, self.merge_config)
        self.stats.append(ss)
        self.ress.append(r)

    def clean_last(self):
        """Call ``clean()`` on every problem."""
        for i in range(self.len):
            self.probs[i].clean()

    def output_last(self):
        """Log the most recent merged statistics using ``output_config``."""
        format_output(self.ress[-1], self.output_config, self.log)

In [None]:
# !ConvertEnd*

In [None]:
# !Convert*stats_test.py*w*sehx*

In [None]:
import mosek

# !Switch*
from handler import FigureHandler
from dataset import ot_2d_Caffarelli
# !SwitchCase*
# import font
# from handler import FigureHandler
# from dataset import ot_2d_Caffarelli
# from stats import *
# !SwitchEnd*

In [None]:
# Create the project's FigureHandler; its `write` method is used as the logger
# below. The `!Switch*` markers are converter directives. Presumably the
# commented-out case (saving .pgf files, no display) is the one used in the
# converted script; confirm against the converter.
# !Switch*
fh = FigureHandler(redir=True)
# !SwitchCase*
# fh = FigureHandler(sav=True, disp=False, ext=".pgf", redir=True)
# !SwitchEnd*

In [None]:
import mosek

def mosek_set_model(p, task):
    """Load the transport LP of problem ``p`` into an empty MOSEK ``task``.

    - Variables: one per entry of the ``(m, n)`` cost matrix ``p.c``, indexed
      row-major (entry ``(i, j)`` is variable ``i*n + j``), each with lower
      bound 0.
    - Constraints ``0 .. m-1``: the sum of row ``i`` is fixed to ``p.mu[i]``.
    - Constraints ``m .. m+n-1``: the sum of column ``j`` is fixed to ``p.nu[j]``.
    - Objective: minimize ``sum(p.c * x)``.
    """
    rows, cols = p.c.shape
    nvars = rows * cols

    # With boundkey.lo only the lower bound is active, so this "infinite" upper
    # bound is merely a placeholder value.
    inf = 0.

    task.appendvars(nvars)
    task.appendcons(rows + cols)

    # x >= 0 for every variable.
    task.putvarboundlist(
        range(nvars),
        [mosek.boundkey.lo] * nvars,
        [0.] * nvars,
        [inf] * nvars,
    )

    # Row sums: constraint r covers the contiguous variables of row r.
    for r in range(rows):
        task.putarow(r, range(r * cols, (r + 1) * cols), [1.] * cols)
    task.putconboundlist(range(rows), [mosek.boundkey.fx] * rows, p.mu, p.mu)

    # Column sums: constraint rows+col covers every cols-th variable from col.
    for col in range(cols):
        task.putarow(rows + col, range(col, col + nvars, cols), [1.] * rows)
    task.putconboundlist(range(rows, rows + cols), [mosek.boundkey.fx] * cols, p.nu, p.nu)

    task.putclist(range(nvars), p.c.reshape(nvars))
    task.putobjsense(mosek.objsense.minimize)

def solve_mosek_primal_simplex(
    p,
    log=None, stat=False,
    *args, **kwargs
):
    """Solve problem ``p``'s transport LP with MOSEK's primal simplex optimizer.

    The model is built by ``mosek_set_model``. The basic solution is stored in
    ``p.s`` with the shape of ``p.c``.

    Parameters
    ----------
    p : problem object
        Must provide the cost matrix ``c`` (2-D) plus whatever
        ``mosek_set_model`` reads.
    log : callable or None
        Passed as the MOSEK log stream handler for both the env and the task.
    stat : bool
        If true, also return a statistics dict.

    Returns
    -------
    p, or (p, dict) when ``stat`` is true
        The dict has keys "loss" (primal objective of the basic solution),
        "vars", "iters" (primal simplex iterations), "setup" and "solve"
        (MOSEK's reported optimizer time). "setup" is the wall-clock time from
        entry to this function until the model is loaded.

    Extra ``args``/``kwargs`` are accepted and ignored.
    """
    rows, cols = p.c.shape
    nvars = rows * cols

    t_begin = time.time() if stat else None

    with mosek.Env() as env:
        env.set_Stream(mosek.streamtype.log, log)

        with env.Task() as task:
            task.set_Stream(mosek.streamtype.log, log)
            task.putintparam(mosek.iparam.optimizer, mosek.optimizertype.primal_simplex)

            mosek_set_model(p, task)
            t_model_ready = time.time() if stat else None

            task.optimize()

            # getxx fills the supplied list in place with the basic solution.
            solution = [0.] * nvars
            task.getxx(mosek.soltype.bas, solution)
            p.s = numpy.array(solution).reshape(rows, cols)

            if not stat:
                return p

            info = {
                "loss": task.getprimalobj(mosek.soltype.bas),
                "vars": task.getintinf(mosek.iinfitem.opt_numvar),
                "iters": task.getintinf(mosek.iinfitem.sim_primal_iter),
                "setup": t_model_ready - t_begin,
                "solve": task.getdouinf(mosek.dinfitem.optimizer_time),
            }
            return p, info

In [None]:
# Benchmark harness over a single instance from the project's `dataset` module.
# NOTE(review): the meaning of ot_2d_Caffarelli's arguments (500, 500, 1) is
# defined there. Presumably they are grid sizes plus a parameter; confirm.
# All progress and result lines are logged through fh.write.
stat = Statistics(
    probs=[
        ot_2d_Caffarelli(500, 500, 1)
    ],
    merge_config=general_merge_config,
    output_config=general_output_config,
    prob="Test problems",
    log=fh.write,
)

In [None]:
# Solve every problem with the primal simplex solver and record the result via
# prob.set_sx(). `prog` is only a truthiness flag (messages already go through
# the Statistics logger, fh.write), so pass a bool instead of a callable.
stat.set_sx(solve_mosek_primal_simplex, prog=True)

In [None]:
# Benchmark the solver with memory profiling. clean=False keeps each problem's
# solution so the next cell can plot it. `prog` is only a truthiness flag, so
# pass a bool; progress messages go through the Statistics logger (fh.write).
stat.test(solve_mosek_primal_simplex, title="MOSEK, test", memory=True, prog=True, clean=False)
stat.output_last()

In [None]:
# Plot the first problem's solution with its plot_link method (its solution is
# still present because the test above ran with clean=False). Then call clean()
# on every problem.
fh.fast(stat.probs[0].plot_link, aspect="equal")
stat.clean_last()

In [None]:
# !ConvertEnd*