In [None]:
import glob
import pickle as pkl
import h5py

import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt

from sparse_coder import SparseCoder
# from sparse_coder.prep_field_dataset import get_data_matrix

from context import utils
from utils.rf_plot import show_fields
# Render every image in grayscale unless a cell overrides the colormap.
plt.rcParams.update({'image.cmap': 'gray'})

In [None]:
# All saved sparse-coder pickles matching the alpha / overcomp naming scheme,
# in lexicographic order.
files = sorted(glob.glob('output/vh_sparse_coder1_alpha_*_overcomp_*.pkl'))

files

In [None]:
# Pick the last pickle in sorted order. Fail with a clear message instead of
# an opaque IndexError when the glob above matched nothing.
# NOTE(review): `files` is sorted lexicographically, so e.g. "alpha_10" sorts
# before "alpha_2" — confirm files[-1] is the run you intend to inspect.
if not files:
    raise IOError('no dictionary pickles found in output/')
path = files[-1]
path

In [None]:
# Load the saved sparse-coder state and keep its dictionary `D`.
# Pickles are binary: text mode ('r') fails under Python 3 and can corrupt
# the stream on Windows, so open with 'rb'.
# NOTE(review): pickle.load can execute arbitrary code — only load files you
# produced yourself.
with open(path, 'rb') as f:
    out = pkl.load(f)
D = out['D']
# Function-call form works under both Python 2 and Python 3.
print(out['alpha'])

In [None]:
# 512 random indices along the first axis of D (unseeded).
# NOTE(review): `idx` is not used below — `show_fields` is given the full D.
idx = np.random.randint(D.shape[0], size=512)

fig, ax = plt.subplots(figsize=(22, 16))
show_fields(D, fig=fig, ax=ax, normed=True)
dict_title = 'Dictionary with alpha={:.2f} n_sp = {}'.format(out['alpha'], out['n_sp'])
ax.set_title(dict_title)

In [None]:
# Load the first 1000 whitened patches and the patch side length `l_patch`.
# Open read-only: older h5py releases default to mode 'a', which opens the
# file for writing and creates it if it is missing.
# `Dataset.value` was removed in h5py 3.0; `[()]` reads the whole (here
# scalar) dataset in both old and new versions.
with h5py.File('data/final/new_extracted_patches1.h5', 'r') as f:
    data = f['white_patches'][0:1000]
    l_patch = f['l_patch'][()]

In [None]:
# Flatten every patch to a row of l_patch**2 values, then rebuild the sparse
# coder from the pickle at `path`.
# n_bat=1000 matches the 1000 patches sliced above (presumably the batch
# size — confirm against SparseCoder.restore).
flat_patches = data.reshape((-1, l_patch ** 2))
sc = SparseCoder.restore(flat_patches, path, n_bat=1000)

In [None]:
# One pass of `train`. The list is handed to `train`, presumably so it can
# append cost values.
# NOTE(review): eta=0 presumably disables the dictionary update, so only
# coefficients are inferred (n_g_itr=200 inner steps) — confirm.
cost_list = list()
i_idx = sc.train(n_itr=1, eta=0, n_g_itr=200, cost_list=cost_list)

In [None]:
# Presumably plots an example from the batch just processed by `train`.
# NOTE(review): the meaning of the positional arguments (1, i_idx, 32) is
# defined by SparseCoder.plot_example — confirm; 32 may be intended to be
# `l_patch`, which other cells use as the patch side length.
sc.plot_example(1, i_idx, 32)

In [None]:
# Dictionary and coefficients currently held by the model.
# NOTE: this rebinds `D`, replacing the dictionary loaded from the pickle.
D = sc.tc.get_dictionary()
A = sc.tc.get_coefficients()
# Reconstruction of the batch as the product of coefficients and dictionary.
Ih = np.dot(A, D)
# Original rows of the model's data for the indices returned by `train`.
all_rows = sc.tc.t_DATA.get_value()
I = all_rows[i_idx]

In [None]:
# Per-row signal-to-residual energy ratio (linear, not dB); an exact
# reconstruction gives inf.
snrs = np.sum(I ** 2, axis=1) / np.sum((I - Ih) ** 2, axis=1)

In [None]:
# Distribution of reconstruction SNRs; log counts expose the sparse tail.
fig, ax = plt.subplots()
ax.hist(snrs, bins=50)
ax.set_yscale('log')

In [None]:
# Batch indices whose reconstruction SNR (linear ratio) exceeds 10.
orig_idx = i_idx[np.flatnonzero(snrs > 10)]

In [None]:
# Read the patches selected by `orig_idx` (SNR > 10) back from disk.
# NOTE(review): this path/file differs from the earlier load
# ('data/final/new_extracted_patches1.h5'), while `orig_idx` comes from
# `i_idx`, which indexes `sc.tc.t_DATA` (presumably built from that earlier
# load) — confirm this is the intended source.
# Open read-only so older h5py versions (default mode 'a') cannot modify or
# create the file.
with h5py.File('sparse_coder/data/final/new_extracted_patches.h5', 'r') as f:
    white_patches = f['white_patches']
    # Size the buffer from the dataset instead of assuming 32 x 32 patches.
    data1 = np.zeros((len(orig_idx),) + white_patches.shape[1:])
    # Row-by-row reads: h5py fancy indexing needs sorted, unique indices,
    # which orig_idx is not guaranteed to be.
    for i, patch_idx in enumerate(orig_idx):
        data1[i] = white_patches[patch_idx]

In [None]:
# Grid of the well-reconstructed (SNR > 10) patches.
# Flatten each patch by its actual size rather than a hard-coded 32 * 32;
# using prod(shape[1:]) keeps an empty selection reshapeable too.
fig, ax = plt.subplots(1, 1, figsize=(12, 10))
show_fields(data1.reshape(-1, int(np.prod(data1.shape[1:]))), fig=fig, ax=ax)

In [None]:
# Fixed indices into `data` (all below 1000, i.e. within the slice loaded
# above) used by the patch displays below.
# NOTE(review): looks like pasted output of a random draw; it contains
# repeated entries (115, 307, 569 and 246 each appear twice), so those patches
# are shown twice — confirm that is intended.
img_idx = np.array([434, 569, 752, 114,  78,   2, 263, 720, 986,  86, 115, 307,
       385, 767, 959, 692, 399,  92, 886, 488, 100, 606, 209, 148, 646,
       600, 662, 533, 618, 860, 427, 115, 798, 826,  48, 724, 116, 569,
       307, 302, 232, 469, 688, 624, 134, 852, 665,  74, 876, 790,  60,
       246, 405, 549, 123, 938, 227, 829, 888, 438, 353, 992, 158, 685,
       843,  58, 288, 914, 289, 687, 246, 392, 443, 748,  66, 652, 328,
        47,  77, 375, 617, 468, 339, 429, 778, 141, 326, 240, 780, 400,
       951, 212,   4, 185, 671, 127, 305, 324], dtype='int32')

In [None]:
# Show one selected patch. Reshape with the patch side read from the file:
# a hard-coded 32 contradicts the 20 x 20 reshapes used elsewhere, and only
# one of them can match `l_patch`.
plt.imshow(data[img_idx[7]].reshape(l_patch, l_patch))

In [None]:
def normalize(x, smin, smax, axis=(1, 2)):
    """Linearly rescale each sample of ``x`` to the range [smin, smax].

    The minimum and maximum are taken per sample over ``axis`` (default
    ``(1, 2)``, i.e. each 2-D patch of an ``(n, h, w)`` stack), so every
    sample is stretched independently.

    Parameters
    ----------
    x : array_like
        Stack of samples; must have every axis listed in ``axis``.
    smin, smax : float
        Target lower / upper bound.
    axis : int or tuple of int, optional
        Axes reduced when computing each sample's min and max.

    Returns
    -------
    ndarray
        Same shape as ``x``. Constant samples (max == min) have no range to
        stretch and are mapped to the midpoint ``(smin + smax) / 2`` instead
        of NaN.
    """
    x = np.asarray(x)
    xmin = x.min(axis=axis, keepdims=True)
    xmax = x.max(axis=axis, keepdims=True)
    span = xmax - xmin
    # A zero span would give 0/0 = NaN (plus a RuntimeWarning): divide by 1
    # for those samples and overwrite them with 0.5 afterwards.
    flat = span == 0
    u = np.true_divide(x - xmin, np.where(flat, 1, span))
    u = np.where(flat, 0.5, u)
    return u * (smax - smin) + smin

In [None]:
# Shape of the stack of selected patches.
np.shape(data[img_idx])

In [None]:
# Selected patches, each rescaled independently to [-0.5, 0.5].
normalize(data[img_idx], smin=-0.5, smax=0.5)

In [None]:
# Grid of the selected patches after per-patch rescaling to [-0.5, 0.5].
selected_normed = normalize(data[img_idx], -0.5, 0.5)
fig, ax = plt.subplots(1, 1, figsize=(16, 14))
show_fields(selected_normed.reshape(-1, l_patch ** 2), fig=fig, ax=ax)


In [None]:
# 99th percentile of the absolute pixel values over all loaded patches.
pcts = np.percentile(np.abs(data).ravel(), 99)

In [None]:
# Index arrays (one per axis of `data`) of pixels with magnitude above 1.
idx_ = np.nonzero(np.abs(data) > 1)

In [None]:
# Left: one raw patch with an auto-scaled colour range.
# Right: the same patch clipped to [-1, 1] and drawn on a fixed [-1, 1] scale.
# Reshape with `l_patch` from the file; the hard-coded 20 disagreed with the
# 32 x 32 reshapes used in other cells.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
q = 400
im = axes[0].imshow(data[q].reshape(l_patch, l_patch), cmap=plt.cm.gray, interpolation='nearest')
fig.colorbar(im, ax=axes[0])
im = axes[1].imshow(np.clip(data[q], -1, 1).reshape(l_patch, l_patch),
                    cmap=plt.cm.gray, interpolation='nearest', vmin=-1, vmax=1)
fig.colorbar(im, ax=axes[1])


In [None]:
# Histogram of absolute pixel values with log-scaled counts.
_ = plt.hist(np.abs(data).ravel(), bins=50)
plt.gca().set_yscale('log')

In [None]:
def normalize(x, smin, smax, axis=(1, 2)):
    """Linearly rescale each sample of ``x`` to the range [smin, smax].

    The minimum and maximum are taken per sample over ``axis`` (default
    ``(1, 2)``, i.e. each 2-D patch of an ``(n, h, w)`` stack), so every
    sample is stretched independently.

    Parameters
    ----------
    x : array_like
        Stack of samples; must have every axis listed in ``axis``.
    smin, smax : float
        Target lower / upper bound.
    axis : int or tuple of int, optional
        Axes reduced when computing each sample's min and max.

    Returns
    -------
    ndarray
        Same shape as ``x``. Constant samples (max == min) have no range to
        stretch and are mapped to the midpoint ``(smin + smax) / 2`` instead
        of NaN.
    """
    x = np.asarray(x)
    xmin = x.min(axis=axis, keepdims=True)
    xmax = x.max(axis=axis, keepdims=True)
    span = xmax - xmin
    # A zero span would give 0/0 = NaN (plus a RuntimeWarning): divide by 1
    # for those samples and overwrite them with 0.5 afterwards.
    flat = span == 0
    u = np.true_divide(x - xmin, np.where(flat, 1, span))
    u = np.where(flat, 0.5, u)
    return u * (smax - smin) + smin

In [None]:
# Every loaded patch rescaled independently to [-0.5, 0.5].
data_normed = normalize(data, smin=-0.5, smax=0.5)

In [None]:
# Pixel-value distribution after per-patch normalization.
_ = plt.hist(np.ravel(data_normed), bins=50)

In [None]:
# Scratch cell: ten random integers in [0, 100). Unseeded, so the output
# changes on every run.
np.random.randint(low=0, high=100, size=10)

In [None]:
# Distribution of the per-patch standard deviation after normalization.
# `data_normed` is at least 3-D (normalize reduces over axes (1, 2)), so a
# plain axis=1 gives one std per column of each patch, and `hist` then draws
# one overlaid histogram per column. Reduce over every non-sample axis instead.
plt.hist(data_normed.std(axis=tuple(range(1, data_normed.ndim))))

In [None]:
# Left: one raw patch with an auto-scaled colour range.
# Right: the same patch after per-patch normalization, drawn on [-0.5, 0.5]
# (the clip to [-1, 1] is a no-op for values already in [-0.5, 0.5]).
# Reshape with `l_patch` from the file; the hard-coded 20 disagreed with the
# 32 x 32 reshapes used in other cells.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 4))
q = 10
im = axes[0].imshow(data[q].reshape(l_patch, l_patch), cmap=plt.cm.gray, interpolation='nearest')
fig.colorbar(im, ax=axes[0])
im = axes[1].imshow(np.clip(data_normed[q], -1, 1).reshape(l_patch, l_patch),
                    cmap=plt.cm.gray, interpolation='nearest', vmin=-0.5, vmax=0.5)
fig.colorbar(im, ax=axes[1])


In [None]:
# Distribution of the per-patch mean value. `data` is at least 3-D (it is
# passed to normalize, which reduces over axes (1, 2)), so axis=1 alone would
# give one mean per column of each patch and `hist` would overlay one
# histogram per column. Average over every non-sample axis instead.
plt.hist(data.mean(axis=tuple(range(1, data.ndim))), bins=50)