# 0) Initialization

In [None]:
# IPython magics: load the autoreload extension and select mode 2, which
# (per the IPython documentation) reloads imported modules before each cell
# executes, so edits to imported modules take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import time
import sampling, utils, param

In [None]:
# utils.print_basic_info() # Brief information about the environment

In [None]:
# Show every column when a DataFrame is displayed (None lifts the column limit).
pd.options.display.max_columns = None

# 1) One particular sampled dataset

In [None]:
# Create one MySample from param.sdict / param.s_seed with verbose output off,
# then call GenDist -> GenSmpl -> SmplSummary in that order.
# NOTE(review): presumably these generate the distribution, draw the sample,
# and summarize it (inferred from the method names) — confirm in sampling.py.
mysample = sampling.MySample(param.sdict, param.s_seed, verbose = False)
mysample.GenDist()
mysample.GenSmpl()
mysample.SmplSummary()

In [None]:
# Call VisDimOneTwo for two sample types: 'original' and 'box_2'.
# NOTE(review): presumably plots the first two dimensions of each sample
# (inferred from the method name) — confirm in sampling.py.
mysample.VisDimOneTwo(smpl_type = 'original')
mysample.VisDimOneTwo(smpl_type = 'box_2')

## 1.1) Grid Search

In [None]:
# Initialize the optimization settings on the sample from param.odict and
# param.o_seed before running the grid search below.
mysample.OptInit(param.odict, param.o_seed)

In [None]:
# Run the grid search with verbose output enabled, then visualize the result.
# NOTE: GridSearch only works for 2-dimensional samples (i.e. mysample.d == 2)
# with non-zero k.
mysample.verbose = True
try:
    mysample.GridSearch(param.gdict)
finally:
    # Always switch verbosity back off, even if the search raises, so that
    # later cells sharing this object are not left in verbose mode.
    mysample.verbose = False
mysample.VisGridSearch()

In [None]:
# Save mysample.k, draw the heatmaps, then put k back.
# NOTE(review): presumably DrawHeatmaps overwrites mysample.k while iterating
# over [0, 1, 2] — confirm in utils.py.
remember_k = mysample.k
try:
    utils.DrawHeatmaps([0, 1, 2], mysample, param.gdict)
finally:
    # Restore k even if drawing fails, so later cells see the original value.
    mysample.k = remember_k

## 1.2) MMD calculation

In [None]:
# Re-run OptInit with param.odict / param.o_seed before the MMD optimization
# (NOTE(review): presumably this resets state changed by the grid-search
# cells — confirm), then print the optimizer, tAll, lamb, device, and k
# attributes now set on the sample.
mysample.OptInit(param.odict, param.o_seed)
print(mysample.optimizer, mysample.tAll, mysample.lamb, mysample.device, mysample.k)

In [None]:
# Run the optimization with verbose output enabled.
mysample.verbose = True
try:
    mysample.OptSolve()
finally:
    # Always switch verbosity back off, even if the solve raises, so that
    # later cells sharing this object are not left in verbose mode.
    mysample.verbose = False

In [None]:
# Visualize the optimization 'history' with start=20.
# The commented-out calls below are alternative views: the same 'history'
# view with start=0, or the 'all' view with start=20.
mysample.OptVis('history', start=20)
# mysample.OptVis('history', start=0)
# mysample.OptVis('all', start=20)

### 1.2.1) The last result of the optimization

In [None]:
# Select the 'last' optimization result, print tAll together with the selected
# index 'i' and its 'IPM' value, then call OptMajor with verb_grid enabled.
mysample.OptNorm('last')
print(mysample.tAll, mysample.opt_norm['i'], mysample.opt_norm['IPM'], "\n") # sanity check: the value of 'i' should be equal to 'mysample.tAll - 1'
mysample.OptMajor(verb_grid = True)

In [None]:
# Choose one result via OptChooseOne and print it (mysample.opt_one).
# NOTE(review): verbose is set to False both before and after the call, so the
# toggle has no effect here; other cells set it to True around the call —
# confirm whether verbose output was intended for OptChooseOne.
mysample.verbose = False
mysample.OptChooseOne()
print(mysample.opt_one)
mysample.verbose = False

### 1.2.2) The best result throughout the whole optimization process

In [None]:
# Select the 'best' optimization result, print tAll together with the selected
# index 'i' and its 'IPM' value, then call OptMajor with verb_grid enabled.
mysample.OptNorm('best')
print(mysample.tAll, mysample.opt_norm['i'], mysample.opt_norm['IPM'], "\n")
mysample.OptMajor(verb_grid = True)

In [None]:
# Choose one result via OptChooseOne and print it (mysample.opt_one).
# NOTE(review): verbose is set to False both before and after the call, so the
# toggle has no effect here; other cells set it to True around the call —
# confirm whether verbose output was intended for OptChooseOne.
mysample.verbose = False
mysample.OptChooseOne()
print(mysample.opt_one)
mysample.verbose = False

# 2) Alternative and null hypotheses: multiple resampled datasets

In [None]:
# Create a separate MySample (same param.sdict / param.s_seed, verbose off)
# for the repeated alt/null experiment and initialize its optimization
# settings from param.odict / param.o_seed.
altnull_repeat = sampling.MySample(param.sdict, param.s_seed, verbose = False)
altnull_repeat.OptInit(param.odict, param.o_seed)

In [None]:
# Run AltNullRepeat (5 sample repetitions x 3 optimization repetitions,
# task 'logonly') and print the elapsed time in seconds.
# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval is not distorted by system clock adjustments during a long run.
start = time.perf_counter()
altnull_repeat_output = altnull_repeat.AltNullRepeat(rep_sample=5, rep_optim=3, altnull_task='logonly')
# output_lnl = altnull_repeat.AltNullRepeat(rep_sample = 5, rep_optim = 3, altnull_task = 'lognolog')
print(time.perf_counter() - start)

In [None]:
# Print the shape of the result, then display up to the first 40 rows whose
# 'hypo' column equals 'alt_hypo'. The commented-out line counts missing
# values per column.
print(altnull_repeat_output.shape)
# print(altnull_repeat_output.isna().astype(int).sum(axis = 0))
altnull_repeat_output[altnull_repeat_output['hypo'] == 'alt_hypo'].head(40)

In [None]:
# Plot the repeated alt/null results from the output table above.
# NOTE(review): the meaning of the 'log' and 12 arguments is not visible
# here — confirm against utils.PlotAltNullRepeat.
utils.PlotAltNullRepeat(altnull_repeat, altnull_repeat_output, 'log', 12)

# 3) MMD values during the whole optimization

In [None]:
# Create a separate MySample (same param.sdict / param.s_seed, verbose off)
# for the MMD-curve experiment and initialize its optimization settings from
# param.odict / param.o_seed.
mmd_curve = sampling.MySample(param.sdict, param.s_seed, verbose = False)
mmd_curve.OptInit(param.odict, param.o_seed)

In [None]:
# Run MMDCurveAltNullRepeat (5 sample repetitions x 3 optimization
# repetitions) and print the elapsed time in seconds.
# perf_counter() is a monotonic, high-resolution clock, so the measured
# interval is not distorted by system clock adjustments during a long run.
start = time.perf_counter()
mmd_curve_output = mmd_curve.MMDCurveAltNullRepeat(rep_sample=5, rep_optim=3)
print(time.perf_counter() - start)

In [None]:
# Print the shape of the result, then display up to the first 6 rows whose
# 'log_nolog' column equals 'log'.
print(mmd_curve_output.shape)
mmd_curve_output[mmd_curve_output['log_nolog'] == 'log'].head(6)