Skip to content

Commit

Permalink
Add lost function check_estimated_shape (#101)
Browse files Browse the repository at this point in the history
* Add lost function check_estimated_shape

This function transforms 1D output of shape (n,) to a 2D array of shape
(n, 1).

* Add tests for check_estimated_shape
  • Loading branch information
Joanna Jędrzejewska-Szmek committed May 9, 2019
1 parent 8cf977a commit ee322b9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
9 changes: 9 additions & 0 deletions kcsd/sKCSD_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,15 @@ def load_elpos(path):
raise Exception('Unknown electrode position file format.')
return ele_pos


def check_estimated_shape(to_estimate):
if len(to_estimate.shape) == 1:
estimated = np.ndarray((to_estimate.shape[0], 1))
estimated[:, 0] = to_estimate
return estimated
return to_estimate


def _bresenhamline_nslope(slope):
"""
Normalize slope for Bresenham's line algorithm.
Expand Down
20 changes: 20 additions & 0 deletions kcsd/tests/test_sKCSDutils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import print_function, division, absolute_import
from kcsd.sKCSD_utils import check_estimated_shape
import os
import unittest
import numpy as np

class testCheckEstimatedShape(unittest.TestCase):
def test_unchanged(self):
array = np.ones((1, 5))
out = check_estimated_shape(array)
self.assertEqual(array.shape, out.shape)

def test_changed(self):
array = np.ones((5, ))
out = check_estimated_shape(array)
self.assertEqual(out.shape, (5, 1))


if __name__ == '__main__':
unittest.main()

0 comments on commit ee322b9

Please sign in to comment.