Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix `simulation_export` function to handle DataFrame output correctly
- Fix `detailed_pandas_compiler` function to support new numpy versions
- Fix probability adjustements in `cpm.optimisation.minimise.LogLikelihood` method to ensure correct parameter estimates
- Fix NaN handling in `cpm.models.decision.Softmax`, and `cpm.models.decision.Sigmoid` due to infinities in the exponential function for out-of-bounds parameters

### Changed

Expand Down
2 changes: 1 addition & 1 deletion cpm/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.24.7"
__version__ = "0.24.8"
34 changes: 28 additions & 6 deletions cpm/models/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Softmax:
>>> softmax.compute()
array([0.30719589, 0.18632372, 0.50648039])
>>> softmax.choice() # This will randomly choose one of the actions based on the computed probabilities.
2
2
>>> Softmax(temperature=temperature, activations=activations).compute()
array([0.30719589, 0.18632372, 0.50648039])
"""
Expand All @@ -72,6 +72,15 @@ def __init__(self, temperature=None, xi=None, activations=None, **kwargs):
self.activations = activations.copy()
else:
self.activations = np.zeros(1)
## Throw error if activations contain missing values of infities
if np.isnan(self.activations).any():
raise ValueError(
"Activations contain NaN values. Please remove or impute missing values."
)
if np.isinf(self.activations).any():
raise ValueError(
"Activations contain infinite values. Please remove or impute infinite values."
)
self.policies = np.zeros(self.activations.shape[0])
self.shape = self.activations.shape
if len(self.shape) > 1:
Expand All @@ -81,7 +90,6 @@ def __init__(self, temperature=None, xi=None, activations=None, **kwargs):
"Activations should be a 1D array, but a 2D array was provided. "
"Flattening the activations to a 1D array."
)


self.__run__ = False

Expand All @@ -97,6 +105,14 @@ def compute(self):
np.exp(self.activations * self.temperature)
)
self.policies = output

if np.isnan(self.policies).any():
self.policies[np.isnan(self.policies)] = 1
self.policies /= self.policies.sum()
warnings.warn(
"NaN values found in policies. Replacing NaN values with 1 and normalising the policies to sum to 1."
)

self.__run__ = True
return self.policies

Expand Down Expand Up @@ -196,6 +212,7 @@ def __init__(self, temperature=None, activations=None, beta=0, **kwargs):
"Activations should be a 1D array, but a 2D array was provided. "
"Flattening the activations to a 1D array."
)

self.__run__ = False

def compute(self):
Expand All @@ -207,11 +224,16 @@ def compute(self):
output: ndarray
A 2D array of outputs computed using the sigmoid function.
"""
output = 1 / (
1
+ np.exp((self.activations - self.beta) * -self.temperature)
)
output = 1 / (1 + np.exp((self.activations - self.beta) * -self.temperature))
self.policies = output

if np.isnan(self.policies).any():
self.policies[np.isnan(self.policies)] = 1
self.policies /= self.policies.sum()
warnings.warn(
"NaN values found in policies. Replacing NaN values with 1 and normalising the policies to sum to 1."
)

self.__run__ = True
return output

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "cpm-toolbox"
version = "0.24.7"
version = "0.24.8"
description = "A fundamental scientific toolbox for computational psychiatry and psychology."
readme = "README.md"
authors = [
Expand Down
2 changes: 1 addition & 1 deletion test/applications/test_reinforcement_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_model_warnings(setup_data):
assert len(w) > 0, "Warnings not raised"
assert issubclass(w[-1].category, UserWarning), "Warning is not a UserWarning"
assert (
"NaN in policy with parameters: 0.5, 1000, \nand with policy: [nan 0.]\n"
"NaN values found in policies. Replacing NaN values with 1 and normalising the policies to sum to 1."
in str(w[-1].message)
), "Warning message mismatch"

Expand Down
22 changes: 20 additions & 2 deletions test/models/test_decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,28 @@ def test_softmax_choice():
activations = np.array([0.1, 0, 0.2])
softmax = Softmax(temperature=1, activations=activations)
choice = softmax.choice()
assert choice in [0, 1, 2], "The Softmax.choice output is not in the expected range."
assert choice in [
0,
1,
2,
], "The Softmax.choice output is not in the expected range."


def test_softmax_input_shape():
activations = np.array([[0.1, 0, 0.2], [0.3, 0.4, 0.5]])
softmax = Softmax(temperature=1, activations=activations)
assert len(softmax.shape) == 1, "The Softmax model should flatten 2D input arrays."


def test_softmax_nan_handling():
activations = np.array([0.1, np.nan, 0.2])
softmax = Softmax(temperature=1, activations=activations)
policies = softmax.compute()
assert not np.isnan(
policies
).any(), "The Softmax model should handle NaN values in activations."


def test_sigmoid():
expected = np.array([0.52497919, 0.5, 0.549834])
activations = np.array([0.1, 0, 0.2])
Expand All @@ -58,7 +72,11 @@ def test_sigmoid_choice():
activations = np.array([0.1, 0, 0.2])
sigmoid = Sigmoid(temperature=1, activations=activations)
choice = sigmoid.choice()
assert choice in [0, 1, 2], "The Sigmoid.choice output is not in the expected range."
assert choice in [
0,
1,
2,
], "The Sigmoid.choice output is not in the expected range."


def test_greedy_rule():
Expand Down