Skip to content

Commit

Permalink
Merge pull request #79 from aburrell/memory_leak
Browse files Browse the repository at this point in the history
Fixed memory leak
  • Loading branch information
aburrell committed Jan 24, 2023
2 parents 0057dbe + 50a1bfb commit e182e2e
Show file tree
Hide file tree
Showing 6 changed files with 171 additions and 57 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Changelog
* Updated unit tests to current pytest standards
* Updated links in the documentation
* Improved the documentation style and added docstring tests
* Fixed memory leak in the array C wrappers

2.6.2 (2020-01-13)
------------------
Expand Down
32 changes: 19 additions & 13 deletions aacgmv2/aacgmv2module.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,13 @@ static PyObject *aacgm_v2_setdatetime(PyObject *self, PyObject *args)

static PyObject *aacgm_v2_convert_arr(PyObject *self, PyObject *args)
{
int i, code, err;
int code, err;

long int in_num;
long int i, in_num;

double in_lat, in_lon, in_h, out_lat, out_lon, out_r;

PyObject *latIn, *lonIn, *hIn, *latOut, *lonOut, *rOut, *badOut, *allOut;
PyObject *badInt, *badFloat;

/* Parse the input as a tuple */
if(!PyArg_ParseTuple(args, "O!O!O!i", &PyList_Type, &latIn, &PyList_Type,
Expand All @@ -73,31 +72,32 @@ static PyObject *aacgm_v2_convert_arr(PyObject *self, PyObject *args)
lonOut = PyList_New(in_num);
rOut = PyList_New(in_num);
badOut = PyList_New(in_num);
badInt = PyLong_FromLong((int long)(-1));
badFloat = PyFloat_FromDouble(-666.0);

/* Cycle through all of the inputs */
for(i=0; i<in_num; i++)
{
/* Read in the input */
/* Read in the input and convert to doubles. GetItem are BORROWED */
in_lat = PyFloat_AsDouble(PyList_GetItem(latIn, i));
in_lon = PyFloat_AsDouble(PyList_GetItem(lonIn, i));
in_h = PyFloat_AsDouble(PyList_GetItem(hIn, i));

/* Call the AACGM routine */
err = AACGM_v2_Convert(in_lat, in_lon, in_h, &out_lat, &out_lon,
&out_r, code);

/* Set the output */
if(err < 0)
{
/* Python 3.7+ raises a SystemError when passing on inf */
/* Python 3.7+ raises a SystemError when passing on inf. */
/* SetItem STEALS the references that are added. */
PyList_SetItem(badOut, i, PyLong_FromLong((int long)i));
PyList_SetItem(latOut, i, badFloat);
PyList_SetItem(lonOut, i, badFloat);
PyList_SetItem(rOut, i, badFloat);
PyList_SetItem(latOut, i, PyFloat_FromDouble(-666.0));
PyList_SetItem(lonOut, i, PyFloat_FromDouble(-666.0));
PyList_SetItem(rOut, i, PyFloat_FromDouble(-666.0));
}
else
{
PyList_SetItem(badOut, i, badInt);
PyList_SetItem(badOut, i, PyLong_FromLong((int long)(-1)));
PyList_SetItem(latOut, i, PyFloat_FromDouble(out_lat));
PyList_SetItem(lonOut, i, PyFloat_FromDouble(out_lon));
PyList_SetItem(rOut, i, PyFloat_FromDouble(out_r));
Expand All @@ -106,7 +106,13 @@ static PyObject *aacgm_v2_convert_arr(PyObject *self, PyObject *args)

/* Set the output tuple */
allOut = PyTuple_Pack(4, latOut, lonOut, rOut, badOut);


/* Free memory for the local-only (not input) variables */
Py_DECREF(latOut);
Py_DECREF(lonOut);
Py_DECREF(rOut);
Py_DECREF(badOut);

return allOut;
}

Expand Down
43 changes: 35 additions & 8 deletions aacgmv2/tests/test_c_aacgmv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def setup_method(self):
self.mlat = None
self.mlon = None
self.rshell = None
self.bad_ind = None
self.mlt = None
self.lat_in = [45.5, 60]
self.lon_in = [-23.5, 0]
Expand All @@ -33,7 +34,7 @@ def teardown_method(self):
"""Run after every method to clean up previous testing."""
del self.date_args, self.long_date, self.mlat, self.mlon, self.mlt
del self.lat_in, self.lon_in, self.alt_in, self.lat_comp, self.lon_comp
del self.r_comp, self.code
del self.r_comp, self.code, self.bad_ind

@pytest.mark.parametrize('mattr,val', [(aacgmv2._aacgmv2.G2A, 0),
(aacgmv2._aacgmv2.A2G, 1),
Expand Down Expand Up @@ -70,9 +71,12 @@ def test_set_datetime(self, idate):
def test_fail_set_datetime(self):
"""Test unsuccessful set_datetime."""
self.long_date[0] = 1013
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError) as rerr:
aacgmv2._aacgmv2.set_datetime(*self.long_date)

if str(rerr).find("AACGM_v2_SetDateTime returned error code -1") < 0:
raise AssertionError('unknown error message: {:}'.format(str(rerr)))

@pytest.mark.parametrize('idate,ckey', [(0, 'G2A'), (1, 'G2A'),
(0, 'A2G'), (1, 'A2G'),
(0, 'TG2A'), (1, 'TG2A'),
Expand Down Expand Up @@ -113,9 +117,9 @@ def test_convert_arr(self, ckey):
"""
aacgmv2._aacgmv2.set_datetime(*self.date_args[0])
(self.mlat, self.mlon, self.rshell,
bad_ind) = aacgmv2._aacgmv2.convert_arr(self.lat_in, self.lon_in,
self.alt_in,
self.code[ckey])
self.bad_ind) = aacgmv2._aacgmv2.convert_arr(self.lat_in, self.lon_in,
self.alt_in,
self.code[ckey])

np.testing.assert_equal(len(self.mlat), len(self.lat_in))
np.testing.assert_almost_equal(self.mlat[0], self.lat_comp[ckey][0],
Expand All @@ -124,22 +128,28 @@ def test_convert_arr(self, ckey):
decimal=4)
np.testing.assert_almost_equal(self.rshell[0], self.r_comp[ckey][0],
decimal=4)
np.testing.assert_equal(bad_ind[0], -1)
np.testing.assert_equal(self.bad_ind[0], -1)

def test_forbidden(self):
"""Test convert failure."""
self.lat_in[0] = 7
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError) as rerr:
aacgmv2._aacgmv2.convert(self.lat_in[0], self.lon_in[0], 0,
aacgmv2._aacgmv2.G2A)

if str(rerr).find("AACGM_v2_Convert returned error code -1") < 0:
raise AssertionError('unknown error message: {:}'.format(str(rerr)))

def test_convert_high_denied(self):
"""Test for failure when converting to high alt geod to mag coords."""
aacgmv2._aacgmv2.set_datetime(*self.date_args[0])
with pytest.raises(RuntimeError):
with pytest.raises(RuntimeError) as rerr:
aacgmv2._aacgmv2.convert(self.lat_in[0], self.lon_in[0], 5500,
aacgmv2._aacgmv2.G2A)

if str(rerr).find("AACGM_v2_Convert returned error code -4") < 0:
raise AssertionError('unknown error message: {:}'.format(str(rerr)))

@pytest.mark.parametrize('code,lat_comp,lon_comp,r_comp',
[(aacgmv2._aacgmv2.G2A + aacgmv2._aacgmv2.TRACE,
59.9753, 57.7294, 1.8626),
Expand Down Expand Up @@ -229,6 +239,15 @@ def test_inv_mlt_convert(self, marg, mlt_comp):
self.mlon = aacgmv2._aacgmv2.inv_mlt_convert(*self.long_date)
np.testing.assert_almost_equal(self.mlon, mlt_comp, decimal=4)

def test_inv_mlt_convert_arr(self):
"""Test array MLT inversion."""
self.date_args = [[ldate for j in range(3)] for ldate in self.long_date]
self.mlt = [12.0, 25.0, -1.0]
self.lon_in = [-153.6033, 41.3967, 11.3967]
self.mlon = aacgmv2._aacgmv2.inv_mlt_convert_arr(*self.date_args,
self.mlt)
np.testing.assert_almost_equal(self.mlon, self.lon_in, decimal=4)

@pytest.mark.parametrize('marg,mlt_comp',
[(12.0, -153.6033), (25.0, 41.3967),
(-1.0, 11.3967)])
Expand Down Expand Up @@ -273,6 +292,14 @@ def test_mlt_convert(self, marg, mlt_comp):
self.mlt = aacgmv2._aacgmv2.mlt_convert(*mlt_args)
np.testing.assert_almost_equal(self.mlt, mlt_comp, decimal=4)

def test_mlt_convert_arr(self):
"""Test array MLT conversion."""
self.date_args = [[ldate for j in range(3)] for ldate in self.long_date]
self.mlon = [-153.6033, 41.3967, 11.3967]
self.lon_in = [12.0, 1.0, 23.0]
self.mlt = aacgmv2._aacgmv2.mlt_convert_arr(*self.date_args, self.mlon)
np.testing.assert_almost_equal(self.mlt, self.lon_in, decimal=4)

@pytest.mark.parametrize('marg,mlt_comp',
[(270.0, 16.2402), (80.0, 3.5736),
(-90.0, 16.2402)])
Expand Down
106 changes: 80 additions & 26 deletions aacgmv2/tests/test_py_aacgmv2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime as dt
from io import StringIO
import logging
import numpy as np
import os
Expand Down Expand Up @@ -153,6 +152,23 @@ def test_convert_latlon_failure(self, in_rep, in_irep, msg):
class TestConvertLatLonArr(TestConvertArray):
"""Unit tests for Lat/Lon array conversion."""

def test_convert_latlon_arr_large(self):
"""Test array latlon conversion for a large array."""
# Update input
self.lat_in = np.full(shape=(6000,), fill_value=self.lat_in[0])
self.lon_in = np.full(shape=(6000,), fill_value=self.lon_in[0])

# Update expected output
self.ref[0] = np.full(shape=(6000,), fill_value=self.ref[0][0])
self.ref[1] = np.full(shape=(6000,), fill_value=self.ref[1][0])
self.ref[2] = np.full(shape=(6000,), fill_value=1.0457)

# Run the conversion and evaluate
self.out = aacgmv2.convert_latlon_arr(self.lat_in, self.lon_in,
self.alt_in[0], self.dtime,
self.method)
self.evaluate_output()

def test_convert_latlon_arr_single_val(self):
"""Test array latlon conversion for a single value."""
self.out = aacgmv2.convert_latlon_arr(self.lat_in[0], self.lon_in[0],
Expand Down Expand Up @@ -388,6 +404,24 @@ def test_get_aacgm_coord_raise_value_error(self, in_index, value):
class TestGetAACGMCoordArr(TestConvertArray):
"""Unit tests for AACGM coordinate array conversion."""

def test_get_aacgm_coord_arr_large(self):
"""Test array conversion for a large array."""
# Update input
self.lat_in = np.full(shape=(6000,), fill_value=self.lat_in[0])
self.lon_in = np.full(shape=(6000,), fill_value=self.lon_in[0])
self.alt_in = np.full(shape=(6000,), fill_value=self.alt_in[0])

# Update expected output
self.ref[0] = np.full(shape=(6000,), fill_value=self.ref[0][0])
self.ref[1] = np.full(shape=(6000,), fill_value=self.ref[1][0])
self.ref[2] = np.full(shape=(6000,), fill_value=self.ref[2][0])

# Run the conversion and evaluate
self.out = aacgmv2.get_aacgm_coord_arr(self.lat_in, self.lon_in,
self.alt_in, self.dtime,
self.method)
self.evaluate_output()

def test_get_aacgm_coord_arr_single_val(self):
"""Test array AACGMV2 calculation for a single value."""
self.out = aacgmv2.get_aacgm_coord_arr(self.lat_in[0], self.lon_in[0],
Expand Down Expand Up @@ -818,50 +852,70 @@ def setup_method(self):
"""Create a clean test environment."""
self.lwarn = ""
self.lout = ""
self.log_capture = StringIO()
aacgmv2.logger.addHandler(logging.StreamHandler(self.log_capture))
aacgmv2.logger.setLevel(logging.INFO)
self.log_name = "aacgmv2_logger"

def teardown_method(self):
"""Clean up the test envrionment."""
self.log_capture.close()
del self.lwarn, self.lout, self.log_capture
del self.lwarn, self.lout, self.log_name

def test_warning_below_ground(self):
def eval_logger_message(self):
"""Evaluate the logger message."""
if self.lout.find(self.lwarn) < 0:
raise AssertionError(
"unknown logger message: {:} not in {:}".format(self.lwarn,
self.lout))

def test_warning_below_ground(self, caplog):
"""Test that a warning is issued if height < 0 for height test."""
self.lwarn = "conversion not intended for altitudes < 0 km"

aacgmv2.wrapper.test_height(-1, 0)
self.lout = self.log_capture.getvalue()
if self.lout.find(self.lwarn) < 0:
raise AssertionError()
with caplog.at_level(logging.WARNING, logger=self.log_name):
aacgmv2.wrapper.test_height(-1, 0)

self.lout = caplog.text
self.eval_logger_message()

def test_warning_magnetosphere(self):
def test_warning_magnetosphere(self, caplog):
"""Test that a warning is issued if altitude is very high."""
self.lwarn = "coordinates are not intended for the magnetosphere"

aacgmv2.wrapper.test_height(70000, aacgmv2._aacgmv2.TRACE)
self.lout = self.log_capture.getvalue()
if self.lout.find(self.lwarn) < 0:
raise AssertionError()
with caplog.at_level(logging.ERROR, logger=self.log_name):
aacgmv2.wrapper.test_height(70000, aacgmv2._aacgmv2.TRACE)

self.lout = caplog.text
self.eval_logger_message()

def test_warning_high_coeff(self):
def test_warning_high_coeff(self, caplog):
"""Test that a warning is issued if altitude is very high."""
self.lwarn = "must either use field-line tracing (trace=True"

aacgmv2.wrapper.test_height(3000, 0)
self.lout = self.log_capture.getvalue()
if self.lout.find(self.lwarn) < 0:
raise AssertionError()
with caplog.at_level(logging.ERROR, logger=self.log_name):
aacgmv2.wrapper.test_height(3000, 0)

def test_warning_single_loc_in_arr(self):
self.lout = caplog.text
self.eval_logger_message()

def test_warning_single_loc_in_arr(self, caplog):
"""Test that user is warned they should be using simpler routine."""
self.lwarn = "for a single location, consider using"

aacgmv2.convert_latlon_arr(60, 0, 300, dt.datetime(2015, 1, 1, 0, 0, 0))
self.lout = self.log_capture.getvalue()
if self.lout.find(self.lwarn) < 0:
raise AssertionError()
with caplog.at_level(logging.INFO, logger=self.log_name):
aacgmv2.convert_latlon_arr(60, 0, 300,
dt.datetime(2015, 1, 1, 0, 0, 0))

self.lout = caplog.text
self.eval_logger_message()

def test_warning_equator(self, caplog):
"""Test that user is warned about undefined coordinates."""
self.lwarn = "unable to perform conversion at"

with caplog.at_level(logging.WARNING, logger=self.log_name):
aacgmv2.convert_latlon(10.0, 10.0, 300,
dt.datetime(2015, 1, 1, 0, 0, 0))

self.lout = caplog.text
self.eval_logger_message()


class TestTimeReturns(object):
Expand Down
11 changes: 5 additions & 6 deletions aacgmv2/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import datetime as dt
import numpy as np
import os
import sys

import aacgmv2
import aacgmv2._aacgmv2 as c_aacgmv2
Expand Down Expand Up @@ -240,11 +239,11 @@ def convert_latlon(in_lat, in_lon, height, dtime, method_code="G2A"):
try:
lat_out, lon_out, r_out = c_aacgmv2.convert(in_lat, in_lon, height,
bit_code)
except Exception:
err = sys.exc_info()[0]
estr = "unable to perform conversion at {:.1f},".format(in_lat)
estr = "{:s}{:.1f} {:.1f} km, {:} ".format(estr, in_lon, height, dtime)
estr = "{:s}using method {:}: {:}".format(estr, bit_code, err)
except Exception as err:
estr = "".join(["unable to perform conversion at {:.1f}".format(in_lat),
", {:.1f} {:.1f} km, {:}".format(in_lon, height, dtime),
" using method {:} <{:}>. Recall".format(bit_code, err),
" that AACGMV2 is undefined near the equator."])
aacgmv2.logger.warning(estr)
pass

Expand Down

0 comments on commit e182e2e

Please sign in to comment.