In [1]:
import numpy as np
from numpy.random import *
from scipy import ceil, complex64, float64, hamming, zeros
from matplotlib import pylab as plt
import xml.etree.ElementTree as et
import argparse
import time
from cnimf import CNIMF
from matplotlib import pylab as plt
%matplotlib inline

In [2]:
parser = argparse.ArgumentParser(description='structure detection test')
parser.add_argument('-s', '--seed_number', \
                    action='store', \
                    nargs='?', \
                    const=None, \
                    default=0, \
                    type=int, \
                    choices=None, \
                    help='seed_number', \
                    metavar=None)
parser.add_argument('-d', '--dat_dir', \
                    action='store', \
                    nargs='?', \
                    const=None, \
                    default='../dat/structure_detection', \
                    type=str, \
                    choices=None, \
                    help='Directory where npy files will be stored.', \
                    metavar=None)
parser.add_argument('-n', '--dat_npz', \
                    action='store', \
                    nargs='?', \
                    const=None, \
                    default='structure_result_3_2_'+time.strftime('%Y%m%d%H%M')+'.npz', \
                    type=str, \
                    choices=None, \
                    help='npz file name', \
                    metavar=None)


_StoreAction(option_strings=['-n', '--dat_npz'], dest='dat_npz', nargs='?', const=None, default='structure_result_3_2_201701221536.npz', type=<class 'str'>, choices=None, help='npz file name', metavar=None)

In [3]:
args = parser.parse_args([])

In [4]:
seed(args.seed_number)


In [5]:
n_tests = 50
n_criteria = 5
n_samples_list = [100, 200, 300, 400]
missing_rate_list = [0.2, 0.6]
data_dim = 12
convolution_width = 3
true_n_components_list = [2]
gamma_shape = 2.0
gamma_rate = 2.0
base_max = 10.0
best_n_bases = np.zeros([len(true_n_components_list), len(n_samples_list), n_tests, n_criteria])
completion_error = np.zeros([len(true_n_components_list), len(missing_rate_list), len(n_samples_list), n_tests, n_criteria])
accuracy = np.zeros([len(true_n_components_list), len(n_samples_list), n_criteria])

In [6]:
%pdb
import warnings
warnings.filterwarnings('error')
for i_n_components in range(len(true_n_components_list)):
    true_n_components = true_n_components_list[i_n_components]
    for i_n_samples in range(len(n_samples_list)):
        n_samples = n_samples_list[i_n_samples]
        print('n_samples', n_samples)
        for i_test in range(n_tests):
            print('i_test', i_test)
            true_activation = np.random.gamma(gamma_shape, 1.0 / gamma_rate, [n_samples, true_n_components])
            true_base = np.random.uniform(0.0, base_max, [convolution_width, true_n_components, data_dim])
            X = np.random.poisson(CNIMF.convolute(true_activation, true_base))
            arg_dict = dict(
                convolution_max = convolution_width,
                component_max = X.shape[1],
                true_width = convolution_width,
                gamma_shape = gamma_shape,
                gamma_rate = gamma_rate,
                base_max = base_max,
                convergence_threshold = 0.00001,
                loop_max = 1000)
            factorizer = CNIMF(**arg_dict)
            filtre = np.ones(X.shape)
            factorizer.fit(X, None, filtre)
            for i_criterion in range(n_criteria):
                best_n_bases[i_n_components, i_n_samples, i_test, i_criterion] = factorizer.criteria[i_criterion].best_structure[1]
            for i_missing_rate in range(len(missing_rate_list)):
                missing_rate = missing_rate_list[i_missing_rate]
                print(missing_rate)
                factorizer = CNIMF(**arg_dict)
                filtre = np.random.binomial(1, missing_rate, X.shape)
                factorizer.fit(X, None, filtre)
                for i_criterion in range(n_criteria):
                    completion_error[i_n_components, i_missing_rate, i_n_samples, i_test, i_criterion] = factorizer.criteria[i_criterion].completion_error/np.prod(X.shape)
                print(completion_error[i_n_components, i_missing_rate, i_n_samples, i_test, :])
    accuracy[i_n_components, :, :] = np.mean(best_n_bases[i_n_components, :, :, :]==true_n_components, axis = 1)

Automatic pdb calling has been turned ON
n_samples 100
i_test 0
0.2
[ 1.84991348  1.61879796  1.61879796  1.61879796  1.61879796]
0.6
[ 0.61878732  0.53122981  0.53122981  0.53122981  0.53122981]
i_test 1
0.2
[ 2.03992397  1.71032427  1.71032427  1.71032427  1.71032427]
0.6
[ 0.64423282  0.60486615  0.60486615  0.60486615  0.60486615]
i_test 2
0.2
[ 2.20511493  1.46298523  1.46298523  1.46298523  1.46298523]
0.6
[ 0.6410518   0.62218278  0.62218278  0.62218278  0.65029544]
i_test 3
0.2
[ 1.9221473   1.48609773  1.48609773  1.48609773  1.48609773]
0.6
[ 0.65001073  0.59141964  0.65893107  0.59141964  0.70572694]
i_test 4
0.2
[ 2.78011075  2.04175939  2.04175939  2.04175939  2.04175939]
0.6
[ 0.95453926  0.55783574  0.61459661  0.55783574  0.55783574]
i_test 5
0.2
[ 1.94421972  1.51590089  1.51590089  1.51590089  1.51590089]
0.6
[ 0.72821128  0.54248716  0.54248716  0.54248716  0.64082475]
i_test 6
0.2
[ 1.9606439   1.72590681  1.72590681  1.72590681  1.72590681]
0.6
[ 0.63980671  0.6202

In [7]:
np.savez(args.dat_dir + '/' + args.dat_npz,
         accuracy=accuracy,
         completion_error=completion_error,
         best_n_bases=best_n_bases,
         n_samples_list = n_samples_list,
         missing_rate_list = missing_rate_list, 
         true_n_components_list = true_n_components_list,
         convolution_width = convolution_width)

In [8]:
# npz = np.load(args.dat_dir + '/artificial_result_201701210117.npz')

In [9]:
accuracy

array([[[ 0.  ,  0.72,  0.52,  0.72,  0.98],
        [ 0.  ,  0.  ,  0.12,  0.68,  0.84],
        [ 0.  ,  0.  ,  0.02,  0.78,  0.86],
        [ 0.  ,  0.  ,  0.  ,  0.4 ,  0.72]]])

In [10]:
completion_error

array([[[[[ 1.84991348,  1.61879796,  1.61879796,  1.61879796,  1.61879796],
          [ 2.03992397,  1.71032427,  1.71032427,  1.71032427,  1.71032427],
          [ 2.20511493,  1.46298523,  1.46298523,  1.46298523,  1.46298523],
          ..., 
          [ 1.49854175,  1.25669695,  1.25669695,  1.25669695,  1.25669695],
          [ 1.64325071,  1.52595288,  1.52595288,  1.52595288,  1.52595288],
          [ 1.52899464,  1.25937433,  1.25937433,  1.25937433,  1.25937433]],

         [[ 1.61242577,  1.35933235,  1.35933235,  1.35933235,  1.35933235],
          [ 1.92311448,  1.92311448,  1.92311448,  1.30815628,  1.30815628],
          [ 1.52759449,  1.52759449,  1.12129044,  1.12129044,  1.12129044],
          ..., 
          [ 1.68533219,  1.68533219,  1.68533219,  1.01119702,  1.01119702],
          [ 1.49803414,  1.49803414,  1.49803414,  1.06583992,  1.06583992],
          [ 1.45712921,  1.27013032,  1.27013032,  1.27013032,  1.27013032]],

         [[ 1.50000114,  1.5195288 ,  1.