In [None]:
# Reload edited modules automatically before each cell is executed.
%load_ext autoreload
%autoreload 2
# Run the shared preamble script in this namespace. `np` is used on the
# seterr line below without being imported in this cell, so it presumably
# comes from the preamble — TODO confirm.
%run ~/FKMC/notebooks/notebook_preamble.py
%matplotlib inline
# Do not warn on floating point underflow.
np.seterr(under = 'ignore')
from time import time
from munch import munchify

In [None]:
import numpy as np
from time import time
from operator import mul
from functools import reduce
from itertools import count
from munch import Munch
from itertools import zip_longest
import logging

import re
from pathlib import Path

import scipy
from FKMC.general import index_histogram_array, sort_IPRs, smooth, shapes, normalise_IPR
from FKMC.stats import binned_error_estimate_multidim, product
from FKMC.import_funcs import allocate, copy, reshape, execute_script, ProgressReporter, shape_hints, timefmt, update_description

# Classification of the logged variables by the way their size scales:
# with the system size N, with the number of MC steps, or once per run.
N_dependent_size = {
    'IPRs', 'eigenvals', 'state', 'accept_rates',
    'classical_accept_rates', 'last_state', 'proposal_rates',
}
per_step = {'Fc', 'Ff', 'Mf_moments', 'Nc', 'Nf', 'eigenval_bins'}
per_run = {'A', 'N_cumulants', 'time'}

from collections import defaultdict

def datafile_concat(datafiles, Ns):
    '''
    Merge the logs of several chain extensions into one log per system size.

    datafiles: sequence indexed as datafiles[chain_ext][i], with i running
        over Ns. Entries that are None (chains that were never loaded) are
        skipped.
    Ns: sequence of system sizes; only its length and order are used here.

    Returns a list holding one Munch per entry of Ns with the keys
      - 'time': summed over the chain extensions,
      - 'eigenval_bins': taken from the first chain extension,
      - 'IPRs', 'eigenvals', 'Mf_moments', 'accept_rates', 'proposal_rates':
        np.concatenate over the chain extensions along the axis that
        shape_hints labels 'MCstep' for that observable.

    Raises ValueError when there is no chain extension to merge, or when an
    observable that has to be concatenated has no 'MCstep' axis in its shape
    hint (previously this either failed with a NameError or silently reused
    the axis of the preceding observable).
    '''
    chain_logs = [log for log in datafiles if log is not None]
    if not chain_logs:
        raise ValueError('datafile_concat was given no chain logs to merge')

    names = ['IPRs', 'eigenvals', 'Mf_moments', 'eigenval_bins', 'time', 'accept_rates', 'proposal_rates']
    datafile = [Munch() for _ in Ns]

    for name in names:
        if name == 'time':
            #total runtime is additive over chain extensions
            for i, _ in enumerate(Ns):
                datafile[i][name] = np.sum([getattr(log[i], name) for log in chain_logs])
        elif name == 'eigenval_bins':
            #the bins are not stacked; the first chain's copy is used
            first = chain_logs[0]
            for i, _ in enumerate(Ns):
                datafile[i][name] = getattr(first[i], name)
        else:
            #resolve the axis per observable so a stale value is never reused
            shape = shape_hints(name)
            if 'MCstep' not in shape:
                raise ValueError(f"cannot concatenate '{name}': its shape hint {shape} has no 'MCstep' axis")
            axis = shape.index('MCstep')
            for i, _ in enumerate(Ns):
                datafile[i][name] = np.concatenate([getattr(log[i], name) for log in chain_logs], axis = axis)

    return datafile

def get_data_funcmap_chain_ext(this_run,
            functions = None,
            strict_chain_length = True,
            ):
    '''
    Load a batch run whose Markov chains were stored as several extensions
    and reduce it with a list of observable functions.

    The run directory is expected to contain
      - code/*.py : the submission script; it is executed with execute_script
        and its batch_params give structure_names and structure_dimensions,
      - data/<task_id>_<chain_id>.npz : one file per task and chain extension.

    this_run: pathlib.Path of the run directory ('~' is expanded).
    functions: iterable of objects providing allocate(), copy() and reshape().
        The list passed in is not modified; extract('time'),
        mean_over_MCMC('accept_rates') and mean_over_MCMC('proposal_rates')
        are always appended to a private copy.
    strict_chain_length: when True a missing (task_id, chain_id) file raises
        ValueError; when False the chains found before the first gap are used.

    Returns a Munch of observables, or None when the data directory holds no
    .npz files. As a side effect the run summary is logged and passed to
    update_description.

    Raises FileNotFoundError when code/ holds no script, and ValueError when
    strict_chain_length is True and chain data is missing.
    '''
    this_run = this_run.expanduser()
    logger.warning(f'looking in {this_run}')
    data = this_run / 'data'
    code = this_run / 'code'

    #get the batch params from the original script
    py_scripts = list(code.glob('*.py'))
    print(py_scripts)
    if not py_scripts:
        raise FileNotFoundError(f'no *.py script found in {code}')
    py_script = py_scripts[0]
    context = execute_script(py_script)
    batch_params = Munch(context.batch_params)
    structure_names = batch_params.structure_names
    structure_dims = tuple(d.size for d in batch_params.structure_dimensions)

    logger.debug(f'structure_names = {structure_names}')
    logger.debug(f'structure_dims = {structure_dims}')

    #file stems look like '<task_id>_<chain_id>'
    def name2id(n): return tuple(map(int,n.split('_')))

    datafiles = dict()
    task_ids = set()
    chain_ids = defaultdict(set)
    for path in data.glob('*.npz'):
        task_id, chain_id = name2id(path.stem)
        datafiles[(task_id, chain_id)] = path
        task_ids.add(task_id)
        chain_ids[task_id].add(chain_id)

    #this has to come before N_chains: min() of an empty sequence raises
    if len(datafiles) == 0:
        logger.error("NO DATA FILES FOUND");
        return

    #calculate the expected number of jobs
    N_tasks = product(structure_dims)
    # NOTE(review): N_chains is the smallest per-task maximum chain id and
    # range(N_chains) below stops one short of it. If chain ids start at 0
    # the highest numbered chain of every task is never read — confirm that
    # this is intended.
    N_chains = min(max(c) for c in chain_ids.values())
    logger.debug(f'Expected number of tasks {N_tasks}')
    logger.debug(f'Measured number of tasks {len(task_ids)}')
    logger.debug(f'Measured number of chains {N_chains}')

    #work on a copy so neither the default nor the caller's list grows between calls
    functions = list(functions or []) + [extract('time'), mean_over_MCMC('accept_rates'), mean_over_MCMC('proposal_rates')]

    #get stuff from an example datafile; Munch copies every array out of the
    #archive, so the file can be closed straight away
    with np.load(next(iter(datafiles.values())), allow_pickle = True) as npz:
        d = Munch(npz)
    Ns = d['Ns']
    MCMC_params = d['MCMC_params'][()]

    logger.info(f'Logger keys: {list(d.keys())} \n')
    logger.info(f"MCMC_params keys: {list(MCMC_params.keys())} \n")

    original_N_steps = MCMC_params['N_steps']
    thin = MCMC_params['thin']
    N_steps = original_N_steps // thin

    logger.debug(list(zip(count(), structure_names, structure_dims)))

    possible_observables = [s for s in dir(d.logs[0]) if not s.startswith("_")]
    logger.info(f'available observables = {possible_observables}')

    logger.debug(f'Allocating space for the requested observables:')
    observables = Munch()
    for func in functions: func.allocate(observables, example_datafile = d, N_jobs = N_tasks)

    #copy extra info over, note that structure_names might appear as a key in d, but I just overwrite it for now
    observables.update({k : v[()] for k,v in d.items() if k != 'logs'})
    observables.structure_names = structure_names
    observables.structure_dims = structure_dims
    observables.batch_params = batch_params
    observables['hints'] = Munch()

    for name, dim in zip(structure_names, batch_params.structure_dimensions):
        observables[name] = dim

    for task_id in range(N_tasks):
        print(task_id, end = ' ')

        #collect only the chains that were really loaded, so that no empty
        #slot is ever handed to datafile_concat
        chain_logs = []
        for chain_id in range(N_chains):
            if not (task_id, chain_id) in datafiles:
                if strict_chain_length: raise ValueError(f'{(task_id, chain_id)} is expected but missing!')
                break
            with np.load(datafiles[(task_id, chain_id)], allow_pickle = True) as npz:
                chain_logs.append(npz['logs'])

        if not chain_logs:
            message = f'no chain data could be loaded for task {task_id}'
            if strict_chain_length: raise ValueError(message)
            #best effort: this task keeps whatever allocate() put in its slot
            logger.warning(message)
            continue

        #convert all those datafiles to one
        datafile = datafile_concat(chain_logs, Ns)

        for func in functions: func.copy(observables, task_id, datafile)

    for func in functions:
        func.reshape(structure_dims, observables)


    logger.info('########################################################################\n')
    logger.info(f'Observables has keys: {observables.keys()}')

    o = observables = Munch(observables)

    infostring = \
    f"""
    Completed jobs:?
    MCMC Steps: {original_N_steps} with thinning = {thin} for {N_steps} recorded steps
    Burn in: {Munch(MCMC_params).N_burn_in}
    Structure_names: {dict(zip(structure_names, structure_dims))}
    Ns = {Ns}
    Runtimes: 
        Average: {timefmt(np.nanmean(o.time.sum(axis=0)))}
        Min: {timefmt(np.nanmin(o.time.sum(axis=0)))}
        Max: {timefmt(np.nanmax(o.time.sum(axis=0)))}
        Total: {timefmt(np.nansum(o.time))}
    """[1:]
    logger.info(infostring)
    update_description(this_run.stem, infostring)

    return observables

In [None]:
from FKMC.import_funcs import  mean_over_MCMC, IPRandDOS, extract

# Verbose logging through the 'local' logger. get_data_funcmap_chain_ext
# looks up the global name `logger` when it is called, so this has to run
# before the call below.
logging.basicConfig()
logger = logging.getLogger('local')
logger.setLevel(logging.DEBUG)

# Load the run, requesting the IPR/DOS on 10000 energy bins spanning
# [-20, 20] and the MCMC mean of Mf_moments.
oSingle = get_data_funcmap_chain_ext(
    this_run = Path('~/HPC_data/IPR_DOR_U5_J5_above_below_Tc'),
    functions = [
        IPRandDOS(E_bins = np.linspace(-20, 20, 10000 + 1)),
        mean_over_MCMC('Mf_moments'),
    ],
)

In [None]:
# Point the working alias `o` at the unsmoothed results and list what was loaded.
o = oSingle
o.keys()

In [None]:
from copy import deepcopy

# Smooth a deep copy so that oSingle keeps the raw curves and this cell can
# be re-run without smoothing the same data twice.
o = oSmoothed = deepcopy(oSingle)

scale = 0.1  # second argument handed to smooth(); same for every curve
for i in range(len(o.Ns)):
    for name in ['DOS', 'IPR', 'dDOS', 'dIPR']:
        o[name][i] = smooth(o[name][i], scale)


f, axes = plt.subplots(2,1, sharex = True, figsize = (14,7))
T_select = 3  # index into o.Ts of the temperature that is plotted

print(o.Ts)
# E_bins has one more entry than the plotted curves, hence the [1:].
E_over_U = o.E_bins[1:] / o.parameters.U
for i, N in enumerate(o.Ns):
    axes[0].plot(E_over_U, o.DOS[i, T_select, :], label = f'N = {N}')
    axes[1].plot(E_over_U, o.IPR[i, T_select, :],)

axes[0].set(xlim = (0, 1), ylabel = 'DOS')
axes[1].set(xlabel = 'E / U', ylabel = 'IPR')
# The curves were labelled with N but no legend was ever drawn.
axes[0].legend()
# NOTE(review): the M**2 quoted in the title averages over the whole second
# axis of Mf_moments (every temperature) for the last N, not just T_select —
# confirm that this is the intended number.
f.suptitle(f'T = {o.Ts[T_select]}, J = {o.parameters.J}, U = {o.parameters.U}, M**2 = {o.Mf_moments[-1, :, 2].mean():.2f}')

In [None]:
# Inspect the hint stored under Mf_moments next to the shape of sigma_Mf_moments.
o.hints.Mf_moments, o.sigma_Mf_moments.shape

In [None]:
from FKMC.plotting import spread
o = oSingle
f, ax = plt.subplots()

# One band per system size: Mf_moments[..., 2] against temperature with
# sigma_Mf_moments passed to spread as the width.
for i, N in enumerate(o.Ns):
    spread(ax, o.Ts, o.Mf_moments[i, :, 2], o.sigma_Mf_moments[i, :, 2], alpha = 0.3, label = f'N = {N}')

ax.set(xlabel = 'T', ylabel = 'M**2')
# The labels passed to spread were never displayed.
# NOTE(review): assumes spread forwards `label` to the artist it draws — confirm.
ax.legend()

In [None]:
# Lower the verbosity for this load. Only the 'local' logger is adjusted:
# rebinding `logger` to the root logger and changing the root level would
# also switch on INFO output from every other library in the kernel.
logger = logging.getLogger('local')
logger.setLevel(logging.INFO)

oTsweep = get_data_funcmap_chain_ext(Path('~/HPC_data/Tsweep_U5_J5'),
            functions = [
                #IPRandDOS(E_bins = np.linspace(-20, 20, 10000 + 1)), 
                mean_over_MCMC('Mf_moments'),
                            ],
            )

In [None]:
from FKMC.plotting import spread
o = oTsweep
f, axes = plt.subplots(1,2)

# Collect the curves that are plotted here and pickled below.
d = Munch()
d.Ns = o.Ns
d.X = o.Ts
d.M2 = o.Mf_moments[:, :, 2]
d.dM2 = o.sigma_Mf_moments[:, :, 2]
# Ratio of Mf_moments[..., 4] to the square of Mf_moments[..., 2]
# (presumably the fourth and second moments — TODO confirm).
d.B =  o.Mf_moments[:, :, 4] / o.Mf_moments[:, :, 2]**2
d.dB = None  # no error estimate is computed for B

for i, N in enumerate(d.Ns):
    axes[0].plot(d.X, d.M2[i])
    spread(axes[0], d.X, d.M2[i], d.dM2[i], alpha = 0.3, label = f'N = {N}')

    axes[1].plot(d.X, d.B[i], label = f'N = {N}')

axes[0].set(xlabel = 'T', ylabel = 'M**2')
axes[1].set(xlabel = 'T', ylabel = 'M**4 / (M**2)**2')
# The curves were labelled with N but no legend was ever drawn. Both panels
# step through the same colour cycle, so one legend serves both.
axes[1].legend()

import pickle
# Same '~/HPC_data' root that the loading cells use, instead of one user's
# absolute home directory.
binder_path = Path('~/HPC_data/pickled_data/binder_data.pickle').expanduser()
binder_path.parent.mkdir(parents = True, exist_ok = True)
with binder_path.open('wb') as file:
    pickle.dump(d.toDict(), file)


In [None]:
# Scratch check: a Munch built from keyword arguments exposes them as attributes.
Munch(job_id = 11111).job_id