In [1]:
from pathlib import Path
import pickle
import gzip
from matplotlib import pyplot as plt
import seaborn as sns
import polars as plrs
import numpy as np
# Let polars print up to 200 rows and 200 columns when a DataFrame is displayed.
plrs.Config.set_tbl_rows(200)
plrs.Config.set_tbl_cols(200)

%matplotlib inline

In [2]:
# code from rundemc
from matplotlib.ticker import MaxNLocator
import pylab as pl
import os
import sys
from scipy.stats import gaussian_kde, pearsonr
import scipy as sp
import scipy.sparse
import scipy.signal


"""
A faster gaussian kernel density estimate (KDE).
Intended for computing the KDE on a regular grid (different use case than
scipy's original scipy.stats.kde.gaussian_kde()).
-Joe Kington
"""
__license__ = 'MIT License <http://www.opensource.org/licenses/mit-license.php>'


def fast_2d_kde(x, y, gridsize=(200, 200), extents=None,
                nocorrelation=False, weights=None):
    """
    Performs a gaussian kernel density estimate over a regular grid using a
    convolution of the gaussian kernel with a 2D histogram of the data.

    This function is typically several orders of magnitude faster than
    scipy.stats.kde.gaussian_kde for large (>1e7) numbers of points and
    produces an essentially identical result.

    Input:
        x: The x-coords of the input data points (any numeric dtype; the
            values are converted to float internally)
        y: The y-coords of the input data points (same size as x)
        gridsize: (default: 200x200) A (nx,ny) tuple of the size of the output
            grid
        extents: (default: extent of input data) A (xmin, xmax, ymin, ymax)
            tuple of the extents of output grid
        nocorrelation: (default: False) If True, the correlation between the
            x and y coords will be ignored when performing the KDE.
        weights: (default: None) An array of the same shape as x & y that
            weighs each sample (x_i, y_i) by each value in weights (w_i).
            Defaults to an array of ones the same size as x & y.
    Output:
        A gridded 2D kernel density estimate of the input points, as an
        array of shape (ny, nx).
    Raises:
        ValueError: if x and y, or weights and x, differ in size.
    """
    #---- Setup --------------------------------------------------------------
    # Force floating point: the pixel-coordinate conversion below works in
    # place, and an in-place true division on an integer array raises a
    # casting error.
    x = np.squeeze(np.asarray(x, dtype=float))
    y = np.squeeze(np.asarray(y, dtype=float))

    if x.size != y.size:
        raise ValueError('Input x & y arrays must be the same size!')

    nx, ny = gridsize
    n = x.size

    if weights is None:
        # Default: Weight all points equally
        weights = np.ones(n)
    else:
        weights = np.squeeze(np.asarray(weights))
        if weights.size != x.size:
            raise ValueError('Input weights must be an array of the same size'
                             ' as input x & y arrays!')

    # Default extents are the extent of the data
    if extents is None:
        xmin, xmax = x.min(), x.max()
        ymin, ymax = y.min(), y.max()
    else:
        xmin, xmax, ymin, ymax = list(map(float, extents))
    dx = (xmax - xmin) / (nx - 1)
    dy = (ymax - ymin) / (ny - 1)

    #---- Preliminary Calculations -------------------------------------------

    # First convert x & y over to pixel coordinates
    # (Avoiding np.digitize due to excessive memory usage!)
    # np.vstack copies, so the caller's arrays are never modified.
    xyi = np.vstack((x, y)).T
    xyi -= [xmin, ymin]
    xyi /= [dx, dy]
    xyi = np.floor(xyi, xyi).T

    # Next, make a 2D histogram of x & y
    # Avoiding np.histogram2d due to excessive memory usage with many points.
    # Duplicate (row, col) entries are summed when the sparse matrix is
    # densified, which is what turns the point list into a histogram.
    # The floored pixel coordinates are handed over as integer indices.
    grid = sp.sparse.coo_matrix((weights, xyi.astype(np.intp)),
                                shape=(nx, ny)).toarray()

    # Calculate the covariance matrix (in pixel coords)
    cov = np.cov(xyi)

    if nocorrelation:
        cov[1, 0] = 0
        cov[0, 1] = 0

    # Scaling factor for bandwidth
    scotts_factor = np.power(n, -1.0 / 6)  # For 2D

    #---- Make the gaussian kernel -------------------------------------------

    # First, determine how big the kernel needs to be (whole pixels)
    std_devs = np.sqrt(np.diag(cov))
    kern_nx, kern_ny = (int(k) for k in
                        np.round(scotts_factor * 2 * np.pi * std_devs))

    # Determine the bandwidth to use for the gaussian kernel
    inv_cov = np.linalg.inv(cov * scotts_factor**2)

    # x & y (pixel) coords of the kernel grid, with <x,y> = <0,0> in center
    xx = np.arange(kern_nx, dtype=float) - kern_nx / 2.0
    yy = np.arange(kern_ny, dtype=float) - kern_ny / 2.0
    xx, yy = np.meshgrid(xx, yy)

    # Then evaluate the gaussian function on the kernel grid:
    # exp(-0.5 * v^T inv_cov v) for every kernel pixel offset v
    kernel = np.vstack((xx.flatten(), yy.flatten()))
    kernel = np.dot(inv_cov, kernel) * kernel
    kernel = np.sum(kernel, axis=0) / 2.0
    kernel = np.exp(-kernel)
    kernel = kernel.reshape((kern_ny, kern_nx))

    #---- Produce the kernel density estimate --------------------------------

    # Convolve the gaussian kernel with the 2D histogram, producing a gaussian
    # kernel density estimate on a regular grid
    grid = sp.signal.convolve2d(grid, kernel, mode='same', boundary='fill').T

    # Normalization factor to divide result by so that units are in the same
    # units as scipy.stats.kde.gaussian_kde's output.
    norm_factor = 2 * np.pi * cov * scotts_factor**2
    norm_factor = np.linalg.det(norm_factor)
    norm_factor = n * dx * dy * np.sqrt(norm_factor)

    # Normalize the result
    grid /= norm_factor

    return np.flipud(grid)

def kdensity(x, extrema=None, kernel="gaussian",
             binwidth=None, nbins=512, weights=None,
             # bw="nrd0",
             adjust=1.0, cut=3, xx=None):
    """Calculate kernel density via FFT.

    The steps mirror R's ``density()`` routine: bin the samples on a regular
    grid, convolve the binned counts with the kernel through the FFT, then
    interpolate onto the requested points.  The R statement each step was
    modelled on is quoted in the comments.

    NOTE(review): this function calls ``nrd0`` (bandwidth selector) and
    ``normal`` (called as ``normal(std=...)`` and expected to expose
    ``.pdf``).  Neither is defined in the cells visible here -- presumably
    they are RunDEMC helpers; confirm they are in scope before calling.

    Input:
        x: The samples (array-like, any numeric dtype; expected to be 1-D).
        extrema: (default: data range widened by ``cut * bw`` on each side)
            A (low, high) pair bounding the evaluation range.
        kernel: (default: 'gaussian') 'gaussian' or 'epanechnikov'.
        binwidth: Unused; kept for interface compatibility.
        nbins: (default: 512) Number of evaluation points when ``xx`` is
            None.  The internal grid uses ``max(nbins, 512)``, rounded up to
            a power of two when larger than 512.
        weights: (default: None) Per-sample weights.  Defaults to
            ``1 / len(x)`` for every sample; supplied weights are used as
            given, without validation or normalisation.
        adjust: (default: 1.0) Multiplier applied to the bandwidth.
        cut: (default: 3) Number of bandwidths the default extrema extend
            beyond the data.
        xx: (default: None) Points at which to evaluate the density.
            Defaults to ``nbins`` evenly spaced points across ``extrema``.
    Output:
        A ``(pdf, xx)`` tuple: the density values and the points they were
        evaluated at.
    Raises:
        ValueError: if ``kernel`` is not one of the supported names.
    """
    # R: N <- nx <- length(x)
    N = len(x)
    nx = len(x)

    # R: weights <- rep.int(1/nx, nx); totMass <- nx/N
    if weights is None:
        weights = np.ones(nx) / float(nx)
        totMass = nx / float(N)
    else:
        totMass = 1.0

    # R: n.user <- n; n <- max(n, 512); if (n > 512) n <- 2^ceiling(log2(n))
    nbins_user = nbins
    nbins = max(nbins, 512)
    if nbins > 512:
        nbins = int(2**np.ceil(np.log2(nbins)))

    # R: bw <- adjust * bw.nrd0(x)
    bw = nrd0(x)
    bw *= adjust

    # R: from <- min(x) - cut * bw; to <- max(x) + cut * bw
    if extrema is None:
        extrema = (np.min(x) - cut * bw, np.max(x) + cut * bw)

    # R: lo <- from - 4 * bw; up <- to + 4 * bw
    lo = np.squeeze(extrema[0] - 4 * bw)
    up = np.squeeze(extrema[1] + 4 * bw)

    # Map every sample to its (floored) bin index on the padded grid.
    # A float copy is taken explicitly: the in-place arithmetic below fails
    # on integer arrays, and plain lists have no array arithmetic at all.
    xi = np.array(x, dtype=float)
    xi -= lo
    xi /= (up - lo) / (nbins - 1)
    xi = np.floor(xi)

    # R: y <- .Call(C_BinDist, x, weights, lo, up, n) * totMass
    # Weighted histogram of x: duplicate indices are summed when the sparse
    # matrix is densified (avoids np.histogram memory use with many points).
    y = sp.sparse.coo_matrix((weights, np.vstack([xi, np.zeros_like(xi)])),
                             shape=(nbins, 1)).toarray()[:, 0] * totMass

    # R: kords <- seq.int(0, 2 * (up - lo), length.out = 2L * n)
    kords = np.linspace(0, 2 * (up - lo), 2 * nbins)
    # R: kords[(n + 2):(2 * n)] <- -kords[n:2]
    kords[nbins + 1:] = -kords[nbins - 1:0:-1]

    # R: kords <- switch(kernel, gaussian = dnorm(kords, sd = bw), ...)
    # NOTE: bw is doubled here to match doubled width of kernel
    bw2 = bw * 2.
    if kernel == 'gaussian':
        kords = normal(std=bw2).pdf(kords)
    elif kernel == 'epanechnikov':
        # R: a <- bw * sqrt(5); ifelse(ax < a, 3/4 * (1 - (ax/a)^2)/a, 0)
        a = bw2 * np.sqrt(5)
        ax = np.abs(kords)
        ind = ax < a
        ax[ind] = .75 * (1 - (ax[ind] / a)**2) / a
        ax[~ind] = 0.0
        kords = ax
    else:
        raise ValueError("Unknown kernel type.")

    # squeeze to ensure 1d
    kords = np.squeeze(kords)

    # R: kords <- fft(fft(y) * Conj(fft(kords)), inverse = TRUE)
    # The spectrum of y (length nbins) is repeated twice so that it lines up
    # with the length-2*nbins spectrum of the kernel.
    kords = np.fft.ifft(np.concatenate([np.fft.fft(y)] * 2) *
                        np.conj(np.fft.fft(kords)))
    # R: kords <- pmax.int(0, Re(kords)[1L:n]/length(y))
    # Here: keep every second sample of the real part, clip negatives to
    # zero and double the result.
    kords = (np.real(kords)[::2]).clip(0, np.inf) * 2.

    # R: xords <- seq.int(lo, up, length.out = n)
    xords = np.linspace(lo, up, nbins)
    # R: x <- seq.int(from, to, length.out = n.user)
    if xx is None:
        xx = np.linspace(extrema[0], extrema[1], nbins_user)
    # R: approx(xords, kords, x)$y
    pdf = np.interp(xx, xords, kords)

    return pdf, xx

def joint_plot(particles, weights, burnin=50, names=None, legend=False, add_best=True,
               border=.1, sep=0.0, rot=None, fig=None, nxticks=5, nyticks=5,
               take_log=False, grid=False, bold_sig=False, corr_size=(12, 18),
               do_scatter=False, num_contours=25, cmap=None):
    """Draw a pairwise summary grid for a set of sampled parameters.

    The figure is an n_p x n_p grid of axes (n_p = ``particles.shape[-1]``):
    the upper triangle holds joint-density contours (and optionally a
    scatter) for each parameter pair, the diagonal holds the parameter
    names, and the lower triangle holds the Pearson correlation of each
    pair as text.

    Input:
        particles: Array indexed as ``particles[burnin:, :, k]``; the last
            axis enumerates parameters.  (Presumably iterations x chains x
            parameters -- confirm against the sampler.)
        weights: Array indexed as ``weights[burnin:, :]``, matching the first
            two axes of ``particles``.
        burnin: (default: 50) Number of leading entries along the first axis
            that are skipped.
        names: (default: 'p0', 'p1', ...) Labels shown on the diagonal.
        legend: (default: False) Add a colorbar for the scatter points; has
            an effect only together with ``do_scatter=True``.
        add_best: (default: True) Mark the highest-weight sample with a red x.
        border: (default: .1) Margin around the grid, in figure fractions.
        sep: (default: 0.0) Gap between neighbouring axes, in figure fractions.
        rot: (default: None) Rotation applied to the top-row x tick labels.
        fig: (default: current figure) Figure to draw into.
        nxticks, nyticks: (default: 5) Maximum number of major ticks.
        take_log: (default: False) Colour the scatter by log(weights).
        grid: (default: False) Show grid lines on the joint plots.
        bold_sig: (default: False) Bold correlations with p <= .05.
        corr_size: (default: (12, 18)) Font sizes used for |r| = 0 and
            |r| = 1; intermediate values are interpolated linearly.
        do_scatter: (default: False) Overlay the samples as a scatter plot.
        num_contours: (default: 25) Number of contour levels; 0 disables the
            density contours.
        cmap: (default: 'gist_earth_r') Colormap for the density contours.
    Output:
        An (n_p, n_p) object array holding the created axes.
    """
    # get the fig
    if fig is None:
        fig = pl.gcf()

    # Honour a caller-supplied colormap.  Previously the argument was never
    # passed on to contourf, which always used the stock 'gist_earth_r'; the
    # default therefore stays exactly that colormap.
    if cmap is None:
        cmap = pl.get_cmap('gist_earth_r')

    # get num grid
    n_p = particles.shape[-1]

    # names used to default to None and then crash on names[i]
    if names is None:
        names = ['p%d' % k for k in range(n_p)]

    # calc width and height
    width = height = (1.0 - (2. * border) - ((n_p - 1) * sep)) / n_p

    def cell_rect(row, col):
        # Figure-fraction rectangle of grid cell (row, col); row 0 is on top.
        left = border + (col * width) + (col * sep)
        bottom = 1 - (border + (row * height) + (row * sep) + height)
        return (left, bottom, width, height)

    # post-burn-in samples, flattened once per parameter
    samples = [particles[burnin:, :, k].ravel() for k in range(n_p)]

    # get best indiv
    best_ind = weights[burnin:].ravel().argmax()
    indiv = [s[best_ind] for s in samples]

    # scatter colours are the same for every panel, so compute them once
    # (and only when they are actually drawn)
    w = None
    if do_scatter:
        w = weights[burnin:, :].ravel()
        if take_log:
            w = np.log(w)

    # set holder for axes
    ax = np.zeros((n_p, n_p), dtype=object)

    # ---- upper triangle: joint densities ----------------------------------
    for i in range(n_p):
        for j in range(i + 1, n_p):
            # share x with the top axes of the column and y with the left
            # neighbour in the row, so tick ranges line up across the grid
            sharex = ax[0, j] if i > 0 else None
            sharey = ax[i, j - 1] if j > i + 1 else None
            a = ax[i, j] = fig.add_axes(cell_rect(i, j),
                                        sharex=sharex, sharey=sharey)

            # clear it
            a.cla()

            # show joint
            if num_contours > 0:
                # determine extents (data range padded by 2.5% on each side)
                cx = samples[i]
                xdiff = cx.max() - cx.min()
                cy = samples[j]
                ydiff = cy.max() - cy.min()
                gap = .025
                extents = (cy.min() - gap * ydiff, cy.max() + gap * ydiff,
                           cx.min() - gap * xdiff, cx.max() + gap * xdiff)

                grd = fast_2d_kde(cy,
                                  cx,
                                  gridsize=(200, 200),
                                  extents=extents,
                                  nocorrelation=False,
                                  weights=None)
                # draw the contour (fast_2d_kde returns the grid flipped
                # vertically, so flip it back to match the y coordinates)
                a.contourf(np.linspace(extents[0], extents[1], 200),
                           np.linspace(extents[2], extents[3], 200),
                           np.flipud(grd), num_contours,
                           cmap=cmap)
                a.axis('tight')

            pts = None
            if do_scatter:
                pts = a.scatter(samples[j],
                                samples[i],
                                c=w,
                                s=20,
                                linewidth=0,
                                alpha=.1)
            if legend and pts is not None:
                # Attach the colorbar to this figure/axes explicitly rather
                # than relying on pyplot's "current image" state.
                cb = fig.colorbar(pts, ax=a)
                # the points are drawn at alpha=.1; keep the bar opaque
                cb.solids.set_alpha(1.0)
                cb.set_label('log(Weight)')
            if add_best:
                a.plot(indiv[j], indiv[i], 'rx',
                       markersize=10, markeredgewidth=3)

            # set the n-ticks
            a.xaxis.set_major_locator(MaxNLocator(nxticks))
            a.yaxis.set_major_locator(MaxNLocator(nyticks))

            # set the tick loc
            a.xaxis.set_ticks_position('top')
            a.yaxis.set_ticks_position('right')

            # turn on the grid if wanted
            if grid:
                a.grid(True)

            # clean labels: x labels only on the top row, y labels only on
            # the right-most column
            if i != 0:
                for label in a.get_xticklabels():
                    label.set_visible(False)
            elif rot is not None:
                for label in a.get_xticklabels():
                    label.set_rotation(rot)
            if j != n_p - 1:
                for label in a.get_yticklabels():
                    label.set_visible(False)

    # ---- diagonal: parameter names ----------------------------------------
    # add text after scatter plot so we don't mess up ranges
    for i in range(n_p):
        a = ax[i, i] = fig.add_axes(cell_rect(i, i))

        # place the labels
        a.text(0.5, 0.5, names[i],
               horizontalalignment='center',
               verticalalignment='center',
               transform=a.transAxes,
               fontsize=24, fontweight='bold')

        # remove all ticks
        a.set_xticks([])
        a.set_yticks([])

    # ---- lower triangle: correlations -------------------------------------
    for i in range(n_p):
        for j in range(0, i):
            a = ax[i, j] = fig.add_axes(cell_rect(i, j))

            # add in correlation plot
            cc, pp = pearsonr(samples[j], samples[i])
            # font size grows linearly with |r| between the two corr_size ends
            fs = (corr_size[1] - corr_size[0]) * np.abs(cc) + corr_size[0]
            txt = a.text(0.5, 0.5, '%1.2f' % cc,
                         horizontalalignment='center',
                         verticalalignment='center',
                         transform=a.transAxes,
                         fontsize=fs)
            # bold if sig
            if bold_sig and pp <= .05:
                txt.set_fontweight('bold')

            # remove all ticks
            a.set_xticks([])
            a.set_yticks([])

    pl.show()

    return ax


def load_results(filename):
    """
    Load in a simulation that was saved to a pickle.gz.

    Input:
        filename: Path (str or path-like) of the gzip-compressed pickle.
    Output:
        The unpickled object.
    """
    # NOTE(security): unpickling can execute arbitrary code -- only load
    # result files that come from a trusted source.
    # The context manager closes the file even if unpickling raises.
    # encoding='latin1' is what lets pickles written under Python 2 (e.g.
    # ones containing numpy arrays) be read back under Python 3.
    with gzip.open(filename, 'rb') as gf:
        return pickle.load(gf, encoding='latin1')



In [3]:
# Directory that is searched below for the per-subject flanker model result files.
flkr_res_dir = Path('/data/MLDSST/nielsond/cogmood/data/20250611_pilot/model_res/')
# Number of leading entries along the first axis of each result's
# 'particles' / 'weights' arrays that are skipped (burn-in) before the best
# fit is picked and the joint plot is drawn.
burnin = 400

In [None]:
# For every flanker result file: record the highest-weight post-burn-in
# parameter set and draw the joint plot of the sampled parameters.
flkr_records = []
# sorted() makes the row/figure order independent of filesystem listing order
for flkr_res in sorted(flkr_res_dir.glob('flanker_*.tgz')):
    res = load_results(flkr_res)

    # The subject id is the part of the file name between the first '-' and
    # the first '.', i.e. names are assumed to look like
    # 'flanker_...-<sub_id>.<ext>' -- TODO confirm against the fitting script.
    # Computed once and reused; this also avoids nesting the same quote type
    # inside the f-string below, which is a SyntaxError before Python 3.12.
    sub_id = flkr_res.name.split('.')[0].split('-')[1]

    post_weights = res['weights'][burnin:]
    # argmax() without an axis indexes the flattened array, matching the
    # ravel()ed particle slices below
    best_ind = post_weights.argmax()
    indiv = [res['particles'][burnin:, :, i].ravel()[best_ind]
             for i in range(res['particles'].shape[-1])]

    best_ps = dict(zip(res['param_names'], indiv))
    best_ps['sub_id'] = sub_id
    best_ps['weight'] = post_weights.max()
    flkr_records.append(best_ps)

    fig = plt.figure(figsize=(10, 7.5))
    fig.suptitle(f'sub_id: {sub_id}')
    joint_plot(res['particles'], res['weights'], burnin=burnin,
               names=res['param_display_names'], sep=0, fig=fig)

# one row per result file
flkr_reses = plrs.DataFrame(flkr_records)

In [None]:
# Display the table of best-fit parameters (one row per loaded result file).
flkr_reses

r,p,sd0,K,L,thresh,alpha,t0,sub_id,weight
f64,f64,f64,f64,f64,f64,f64,f64,str,f64
1.321696,1.817572,3.860363,0.032474,0.184214,7.712371,21.452495,0.360169,"""11nuj5ty67ojohm39cmzbt23""",19.116851
0.230198,3.153311,3.889936,0.358846,0.555873,3.138032,7.059636,0.522988,"""2upuqdbw3wdpk3q43x89zysp""",52.505726
0.424314,3.052498,3.824852,0.37838,0.716797,2.744588,5.088501,0.615051,"""48juqsgxp4m2o7797zvjxln9""",69.237476
0.55551,2.651472,3.119924,0.153346,0.325574,3.457636,5.243781,0.470854,"""60pixcark57tgonq4abwctvs""",55.5852
1.727149,2.160918,6.996961,0.184604,0.68839,4.756761,11.410604,0.564806,"""81987885tpc29718g2d8evdm""",46.119518
2.303885,1.629935,9.771818,0.108684,0.085128,4.267536,9.286047,0.459661,"""h3q7g3g6za07rl9qnhd87hoq""",12.135177
1.491918,2.384848,5.857088,0.098931,0.181047,7.803392,9.143323,0.314705,"""hvann18ezp9i2kq8bvqivehs""",33.491011
3.345413,1.315552,10.157574,0.852177,0.190372,2.132598,7.289991,0.540351,"""l8eyqget2wsecwew6bwabn1h""",-25.254624
0.942578,4.015424,5.538032,0.403551,0.606488,6.87456,6.036752,0.273175,"""mglomvxjfi6gya3jmrt7o09w""",58.652991
0.340607,3.453391,5.747746,0.420929,0.495114,6.877798,7.583932,0.294129,"""mjff7puqxr95bh6d945ru7z2""",44.022223
