In [None]:
from __future__ import print_function

import datetime
import logging
import math
import timeit

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pytest import approx

from fastimgproto.bindings.sourcefind import (
    cpp_sourcefind,
    cpp_sourcefind_result_to_islandparams,
)
from fastimgproto.fixtures.image import (
    add_gaussian2d_to_image,
    gaussian_point_source,
)
from fastimgproto.fixtures.sourcefind import random_sources_on_grid
from fastimgproto.fixtures.sourcefind import (
    generate_random_source_params,
    check_single_source_extraction_successful,
)
from fastimgproto.sourcefind.fit import Gaussian2dParams
from fastimgproto.sourcefind.image import SourceFindImage

logger = logging.getLogger(__name__) 

In [None]:
%matplotlib inline

# Plot image pixels in cartesian ordering (i.e. y-positive == upwards):
plt.rcParams['image.origin'] = 'lower'
# Make plots bigger
plt.rcParams['figure.figsize'] = 6,6

In [None]:
# (min, max) ranges forwarded to random_sources_on_grid by generate_test_image.
amplitude_range = (6.0, 42.0)
semiminor_range = (1.2, 2.5)
axis_ratio_range = (1.0, 2.0)
# Seed forwarded to the synthetic-source generator.
seed = 123456

In [None]:
# Side length [pixels] of the demo image, and number of sources per test image.
image_size, n_sources = 1024, 64

In [None]:
# Detection / analysis thresholds and noise estimate passed to the source finders.
detection_n_sigma = 5.0
analysis_n_sigma = 3.0
rms_est = 1.0

In [None]:
def generate_test_image(image_size, n_sources, seed=None):
    """Build a square synthetic image containing random 2D Gaussian sources.

    Source parameters are drawn by ``random_sources_on_grid`` using the
    notebook-level ``amplitude_range``, ``semiminor_range`` and
    ``axis_ratio_range`` settings. Each source is then added to an initially
    all-zero image with ``add_gaussian2d_to_image``.

    Parameters
    ----------
    image_size : int
        Side length of the square image, in pixels.
    n_sources : int
        Number of sources requested from ``random_sources_on_grid``.
    seed : int or None, optional
        Forwarded unchanged to ``random_sources_on_grid``. ``None`` is
        presumably unseeded there (TODO confirm).

    Returns
    -------
    numpy.ndarray
        Float64 array of shape ``(image_size, image_size)``.
    """
    # ``np.float_`` was removed in NumPy 2.0; ``np.float64`` is the same dtype.
    image = np.zeros((image_size, image_size), dtype=np.float64)
    sources = random_sources_on_grid(
        image_size,
        n_sources,
        amplitude_range=amplitude_range,
        semiminor_range=semiminor_range,
        axis_ratio_range=axis_ratio_range,
        seed=seed,
    )
    for source in sources:
        # Return value unused: this presumably adds the source into ``image``
        # in place.
        add_gaussian2d_to_image(source, image)
    return image

In [None]:
# Demo image for the single-run timings below. Size, source count and seed all
# come from the configuration cells instead of being hard-coded here.
image = generate_test_image(image_size=image_size, n_sources=n_sources, seed=seed)

In [None]:
import stp_python
# ??stp_python.source_find_wrapper

In [None]:
import gc


def time_ceres_sourcefinding(image):
    """Time ``stp_python.source_find_wrapper`` on `image`, without and with fitting.

    The wrapper runs twice with identical settings except ``gaussian_fitting``:
    first extraction only, then extraction plus Gaussian fitting. The fitting
    cost is estimated as the difference between the two timings.

    Parameters
    ----------
    image : numpy.ndarray
        2D image to search. A Fortran-ordered version is built once, before
        either timer starts, so the layout conversion is not timed.

    Returns
    -------
    tuple of float
        ``(extraction_duration, fitting_duration)`` in seconds. Because it is
        the difference of two independent timings, ``fitting_duration`` is
        noisy and may come out slightly negative when fitting is very cheap.
    """
    pars = dict(image_data=np.asfortranarray(image),
                detection_n_sigma=detection_n_sigma,
                analysis_n_sigma=analysis_n_sigma,
                rms_est=rms_est,
                find_negative_sources=False,
                sigma_clip_iters=0,
                median_method=stp_python.MedianMethod.ZEROMEDIAN,
                gaussian_fitting=False,
                generate_labelmap=False,
                ceres_diffmethod=stp_python.CeresDiffMethod.AnalyticDiff,
                ceres_solvertype=stp_python.CeresSolverType.TrustRegion_DenseQR)

    # Pass 1: source extraction only (gaussian_fitting=False).
    # timeit.default_timer is the best benchmarking clock available on this
    # Python version (perf_counter on Python 3). Unlike datetime.now(), it is
    # not affected by system clock adjustments there.
    start = timeit.default_timer()
    stp_python.source_find_wrapper(**pars)
    extraction_duration = timeit.default_timer() - start

    # Collect leftover garbage between the timed passes. This is intended to
    # make a collection less likely inside the second timed region.
    gc.collect()

    # Pass 2: extraction followed by Gaussian fitting.
    pars['gaussian_fitting'] = True
    start = timeit.default_timer()
    stp_python.source_find_wrapper(**pars)
    extraction_and_fitting_duration = timeit.default_timer() - start

    fitting_duration = extraction_and_fitting_duration - extraction_duration
    return extraction_duration, fitting_duration



In [None]:
# Single stp_python run on the demo image -> (extraction_seconds, fitting_seconds).
time_ceres_sourcefinding(image)

In [None]:
import sep
def time_sep_extraction(image):
    """Time SEP source extraction (``sep.extract``) on `image`.

    Thresholding uses the notebook-level ``detection_n_sigma`` with
    ``err=rms_est``; according to the SEP docs, ``thresh`` is then interpreted
    relative to ``err``. Cleaning and filtering are disabled. ``deblend_cont``
    is set to ``1.``, which the SEP docs describe as disabling deblending, and
    ``minarea=1``.

    Parameters
    ----------
    image : numpy.ndarray
        2D image to search.

    Returns
    -------
    float
        Elapsed time of the ``sep.extract`` call, in seconds.
    """
    # timeit.default_timer instead of datetime.now(): it is the high-resolution
    # benchmarking clock (monotonic perf_counter on Python 3).
    start = timeit.default_timer()
    # Keep a reference to the returned catalogue so that it is freed after the
    # timer stops, not inside the timed region.
    _objects = sep.extract(image,
                           thresh=detection_n_sigma, err=rms_est,
                           clean=False,
                           filter_kernel=None,
                           deblend_cont=1.,
                           minarea=1,
                           )
    return timeit.default_timer() - start
    

In [None]:
# Single SEP run on the demo image -> extraction time in seconds.
time_sep_extraction(image)

In [None]:
import sep
def time_numpy_fftshift(image):
    """Time ``np.fft.ifftshift`` of `image` followed by ``np.copy``.

    Despite the function name, this calls ``ifftshift``. The NumPy docs state
    that the two are identical for even-length axes and differ by one sample
    for odd lengths, so the amount of work is the same either way.

    NOTE(review): ``ifftshift`` already returns a new array, so the extra
    ``np.copy`` adds a second full copy to the measured time. Confirm that
    this is intended.

    Parameters
    ----------
    image : numpy.ndarray
        Array to shift. It is not modified.

    Returns
    -------
    float
        Elapsed time in seconds.
    """
    start = timeit.default_timer()
    # Bound to a name so the result is freed after the timer stops.
    _shifted = np.copy(np.fft.ifftshift(image))
    return timeit.default_timer() - start

In [None]:
# Scratch check: 2**15 == 32768, the largest side length in v_large_sizes (next cell).
2**15

In [None]:
# Benchmark image side lengths [pixels], in increasing order.
small_sizes = (2**9, 2**10, 2**11, 2**12)
large_sizes = small_sizes + (2**13, int(2**13.5), 2**14)
v_large_sizes = large_sizes + (int(2**14.5), 2**15)


def image_set(image_sizes, seed=seed):
    """Lazily generate one synthetic test image per entry of `image_sizes`.

    Parameters
    ----------
    image_sizes : iterable of int
        Side length of each square image, in pixels.
    seed : int or None, optional
        Forwarded to ``generate_test_image`` for every image. The default is
        the notebook-level ``seed`` (captured when this cell runs), so
        benchmark inputs are presumably reproducible between runs. Pass
        ``None`` to get the earlier behaviour, in which no seed was forwarded.

    Returns
    -------
    generator of numpy.ndarray
        Images are created on demand. A consumer that drops each image before
        requesting the next never holds more than one in memory.
    """
    return (generate_test_image(sz, n_sources=n_sources, seed=seed)
            for sz in image_sizes)

In [None]:
# Megabytes per copy of the largest image size. This is computed from the shape
# and dtype (generate_test_image builds float64 images). Building the list of
# every image in v_large_sizes just to read one `nbytes` would hold all nine
# images in memory at once (~16 GiB).
mbyte = 2**20
v_large_sizes[-1] ** 2 * np.dtype(np.float64).itemsize / mbyte

In [None]:
# Single stp_python run at the largest benchmark size, v_large_sizes[-1] == 2**15.
# NOTE(review): this float64 image is 8 GiB on its own, and the Fortran-order
# copy made inside time_ceres_sourcefinding roughly doubles that. Check the
# available RAM before running.
time_ceres_sourcefinding(generate_test_image(v_large_sizes[-1], n_sources=n_sources))

In [None]:
# Single SEP run at the largest of large_sizes (2**14 pixels per side).
time_sep_extraction(generate_test_image(large_sizes[-1], n_sources=n_sources))

In [None]:
# Time the FFT shift on an image of the largest SEP benchmark size. The argument
# must be the image itself: passing a timing result (a float) would only shift
# a scalar.
time_numpy_fftshift(generate_test_image(large_sizes[-1], n_sources=n_sources))

In [None]:
def run_benchmark(bench, images):
    """Apply a timing function to each image and collect the results by size.

    Parameters
    ----------
    bench : callable
        Called once per image as ``bench(img)``. Its return value is stored.
    images : iterable
        Images to benchmark, e.g. the lazy generator returned by ``image_set``.

    Returns
    -------
    dict
        Maps ``len(img)`` (the number of rows, i.e. the side length of a
        square image) to ``bench(img)``. A later image with the same length
        overwrites an earlier result.
    """
    results = {}
    for img in images:
        results[len(img)] = bench(img)
        # Drop our reference before the iterator builds the next image. This
        # stops a lazy generator from holding two (possibly multi-GiB) images
        # at the same time.
        del img
    return results

In [None]:
# Benchmark stp_python on every size in v_large_sizes.
# Result maps image side length -> (extraction_seconds, fitting_seconds).
ceres_results = run_benchmark(time_ceres_sourcefinding, image_set(v_large_sizes))
ceres_results

In [None]:
# SEP is only benchmarked up to large_sizes[-1] (2**14 pixels per side).
# Result maps image side length -> extraction seconds.
sep_results = run_benchmark(time_sep_extraction, image_set(large_sizes))
sep_results

In [None]:
import collections


def results_to_dataframe(results_dict):
    """Convert a ``run_benchmark`` result dict into a size-sorted DataFrame.

    Parameters
    ----------
    results_dict : dict
        Maps image side length [pixels] to either a single timing in seconds,
        or an ``(extraction, fitting)`` pair of timings.

    Returns
    -------
    pandas.DataFrame
        Indexed by image size in ascending order. It has an ``'extraction'``
        column, plus a ``'fitting'`` column when the values are pairs.

    Raises
    ------
    ValueError
        If `results_dict` is empty, so the column layout cannot be inferred.
    """
    # ``collections.Iterable`` was removed in Python 3.10; use collections.abc
    # where it exists and fall back for Python 2.
    try:
        from collections.abc import Iterable
    except ImportError:  # Python 2
        from collections import Iterable

    if not results_dict:
        raise ValueError('results_dict is empty; nothing to tabulate')

    df = pd.DataFrame.from_dict(data=results_dict, orient='index').sort_index()
    df.index.name = 'Image size [pix/side]'
    # ``dict.values()`` cannot be indexed on Python 3, so take the first value
    # through an iterator. This also works on Python 2.
    first_value = next(iter(results_dict.values()))
    if isinstance(first_value, Iterable):
        df.columns = ['extraction', 'fitting']
    else:
        df.columns = ['extraction']
    return df

In [None]:
# LaTeX tick labels "$2^{p}$" for each benchmarked size. Here p is log2(size)
# rounded to one decimal, shown as an integer when whole (2^{9}, not 2^{9.0}).
ticklabels = np.round(np.log2(v_large_sizes), decimals=1)
simple_ticklabels = [
    "$2^{{{}}}$".format(int(power) if power == int(power) else power)
    for power in ticklabels
]
simple_ticklabels

In [None]:
ceres_df = results_to_dataframe(ceres_results)
sep_df = results_to_dataframe(sep_results)

# Total stp_python time = extraction pass + estimated fitting overhead.
ceres_df['total'] = ceres_df.extraction + ceres_df.fitting

# Use an explicit figure/axes pair rather than the pyplot state machine.
fig, ax = plt.subplots()

ceres_df.plot(y='extraction', ax=ax, label='Ceres (ex.)')
ceres_df.plot(y='total', ax=ax, label='Ceres (ex.+fit)')
sep_df.plot(y='extraction', ax=ax, label='SEP (ex.)')
ax.set_ylabel('Time [s]')
ax.set_title('Source-finding time vs. image size ({} sources)'.format(n_sources))

# Faint guide line at every benchmarked image size.
for sz in v_large_sizes:
    ax.axvline(sz, ls=':', alpha=0.3)

# The smallest size (2**9) gets no tick. This is presumably to avoid crowded
# labels near the origin on this linear x-axis.
xax = ax.get_xaxis()
xax.set_ticks(v_large_sizes[1:])
xax.set_ticklabels(simple_ticklabels[1:], rotation=30)

# Filename tracks n_sources; it is identical to the old name when n_sources == 64.
fig.savefig('extraction_{}_sources_cpp.pdf'.format(n_sources))