diff --git a/atintegrators/atelem.c b/atintegrators/atelem.c index 5822e1abb..6fde04348 100755 --- a/atintegrators/atelem.c +++ b/atintegrators/atelem.c @@ -193,16 +193,18 @@ static long atGetLong(const PyObject *element, const char *name) { const PyObject *attr = PyObject_GetAttrString((PyObject *)element, name); if (!attr) return 0L; + long l = PyLong_AsLong((PyObject *)attr); Py_DECREF(attr); - return PyLong_AsLong((PyObject *)attr); + return l; } static double atGetDouble(const PyObject *element, const char *name) { const PyObject *attr = PyObject_GetAttrString((PyObject *)element, name); if (!attr) return 0.0; + double d = PyFloat_AsDouble((PyObject *)attr); Py_DECREF(attr); - return PyFloat_AsDouble((PyObject *)attr); + return d; } static long atGetOptionalLong(const PyObject *element, const char *name, long default_value) diff --git a/docs/conf.py b/docs/conf.py index ca711bc22..a510e9ab4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,8 +40,9 @@ 'sphinx.ext.intersphinx', 'sphinx.ext.githubpages', 'sphinx.ext.viewcode', - 'myst_parser', + 'myst_nb', 'sphinx_copybutton', + 'sphinx_design', ] intersphinx_mapping = {'python': ('https://docs.python.org/3', None), @@ -56,7 +57,7 @@ # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ["README.rst", "**/*.so"] +exclude_patterns = ["README.rst", "**/*.so", "_build/*"] rst_prolog = """ .. role:: pycode(code) :language: python @@ -92,6 +93,8 @@ "deflist" ] myst_heading_anchors = 3 +nb_execution_mode = "auto" +nb_execution_allow_errors = True # -- Options for HTML output ------------------------------------------------- diff --git a/docs/p/index.rst b/docs/p/index.rst index 2b2a79e97..1e9efcaf3 100644 --- a/docs/p/index.rst +++ b/docs/p/index.rst @@ -26,6 +26,7 @@ Sub-packages howto/Installation howto/Primer + notebooks/variables .. toctree:: :maxdepth: 2 diff --git a/docs/p/notebooks/variables.ipynb b/docs/p/notebooks/variables.ipynb new file mode 100644 index 000000000..e9b6fb05a --- /dev/null +++ b/docs/p/notebooks/variables.ipynb @@ -0,0 +1,942 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "bc479f4a-a609-468f-a430-d71ad22b5cf2", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import at\n", + "import sys\n", + "from importlib.resources import files, as_file\n", + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d993e922-806d-42fc-a793-51e8c5e82995", + "metadata": {}, + "outputs": [], + "source": [ + "from at.future import VariableBase, ElementVariable, RefptsVariable, CustomVariable" + ] + }, + { + "cell_type": "markdown", + "id": "ba73f6c0-aa60-4ced-8158-dfddcade4bd9", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Variables\n", + "\n", + "Variables are **references** to any scalar quantity. Predefined classes are available\n", + "for accessing any scalar attribute of an element, or any item of an array attribute.\n", + "\n", + "Any other quantity may be accessed by either subclassing the {py:class}`~.variables.VariableBase`\n", + "abstract base class, or by using a {py:class}`~.variables.CustomVariable`.\n", + "\n", + "## {py:class}`~.lattice_variables.ElementVariable`\n", + "\n", + "An {py:class}`~.lattice_variables.ElementVariable` refers to a single attribute (or item of an array attribute) of one or several {py:class}`.Element` objects.\n", + "\n", + "We now create a variable pointing to the length of a QF1 magnet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d42d5dc7-f40a-4d49-be7b-3fcf5e1e18ea", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "99fe5abf-a4d5-48c7-b99a-fb8fd8208699", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Quadrupole:\n", + "\tFamName : QF1\n", + "\tLength : 0.5\n", + "\tPassMethod : StrMPoleSymplectic4Pass\n", + "\tNumIntSteps : 10\n", + "\tMaxOrder : 1\n", + "\tPolynomA : [0. 0.]\n", + "\tPolynomB : [0. 2.1]\n", + "\tK : 2.1\n" + ] + } + ], + "source": [ + "qf1 = at.Quadrupole(\"QF1\", 0.5, 2.1)\n", + "print(qf1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e07cfef9-b821-4924-8032-112801723d53", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "lf1: ElementVariable(0.5, name='lf1')\n", + "0.5\n" + ] + } + ], + "source": [ + "lf1 = ElementVariable(qf1, \"Length\", name=\"lf1\")\n", + "print(f\"lf1: {lf1}\")\n", + "print(lf1.value)" + ] + }, + { + "cell_type": "markdown", + "id": "b1271329-08be-4655-8884-77d7eec67558", + "metadata": {}, + "source": [ + "and another variable pointing to the strength of the same magnet:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "db5c3831-467a-468b-aca7-648b45c90887", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kf1: ElementVariable(2.1, name='kf1')\n", + "2.1\n" + ] + } + ], + "source": [ + "kf1 = ElementVariable(qf1, \"PolynomB\", index=1, name=\"kf1\")\n", + "print(\"kf1:\", kf1)\n", + "print(kf1.value)" + ] + }, + { + "cell_type": "markdown", + "id": "0a3bf50d-8a18-4372-9d38-d5143da1807c", + "metadata": {}, + "source": [ + "We can check which elements are concerned by the `kf1` variable. The element container is a set, so that no element may appear twice:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "5fbe7aa7-3264-4de8-a8d7-439b9b7ace2f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{Quadrupole('QF1', 0.5, 2.1)}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kf1.elements" + ] + }, + { + "cell_type": "markdown", + "id": "5dc6ea71-8d0a-4b8e-b00f-bebd0b09a874", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "We can now change the strength of QF1 magnets and check again:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d74f4a63-bacf-4ee4-b42a-de75bfbea193", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Quadrupole:\n", + "\tFamName : QF1\n", + "\tLength : 0.5\n", + "\tPassMethod : StrMPoleSymplectic4Pass\n", + "\tNumIntSteps : 10\n", + "\tMaxOrder : 1\n", + "\tPolynomA : [0. 0.]\n", + "\tPolynomB : [0. 2.5]\n", + "\tK : 2.5\n" + ] + } + ], + "source": [ + "kf1.set(2.5)\n", + "print(qf1)" + ] + }, + { + "cell_type": "markdown", + "id": "baed19af-fd43-44d3-a5df-965eed43f0eb", + "metadata": {}, + "source": [ + "We can look at the history of `kf1` values" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "413145df-30e5-4601-b05a-df2b042028ff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[2.1, 2.5]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kf1.history" + ] + }, + { + "cell_type": "markdown", + "id": "c9b693f9-a04e-4cb6-bbd2-37de0733ccb8", + "metadata": {}, + "source": [ + "And revert to the initial or previous values:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "70a1bca4-af2d-49b9-8462-989d47ad1efe", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Quadrupole:\n", + "\tFamName : QF1\n", + "\tLength : 0.5\n", + "\tPassMethod : StrMPoleSymplectic4Pass\n", + "\tNumIntSteps : 10\n", + "\tMaxOrder : 1\n", + "\tPolynomA : [0. 0.]\n", + "\tPolynomB : [0. 2.1]\n", + "\tK : 2.1\n" + ] + } + ], + "source": [ + "kf1.set_previous()\n", + "print(qf1)" + ] + }, + { + "cell_type": "markdown", + "id": "c650be51-228a-4ee0-ac73-d741601f992b", + "metadata": {}, + "source": [ + "An {py:class}`~.lattice_variables.ElementVariable` is linked to Elements. It will apply wherever the element appears but it will not follow any copy of the element, neither shallow nor deep. So if we make a copy of QF1:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "314c398c-fbbc-43ce-abcc-76ba48cf3c03", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "qf1: 2.1\n", + "qf2: 2.1\n" + ] + } + ], + "source": [ + "qf2 = qf1.deepcopy()\n", + "print(f\"qf1: {qf1.PolynomB[1]}\")\n", + "print(f\"qf2: {qf2.PolynomB[1]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5ea8a7fc-3bb6-4d18-8656-a8f45755c932", + "metadata": {}, + "source": [ + "and modify the `kf1` variable:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "8458ab19-3be1-427c-a55d-0ccb029c9f26", + "metadata": {}, + "outputs": [], + "source": [ + "kf1.set(2.6)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f4c1e653-eeaf-4fd5-be08-2a488bd7df9d", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "qf1: 2.6\n", + "qf2: 2.1\n" + ] + } + ], + "source": [ + "print(f\"qf1: {qf1.PolynomB[1]}\")\n", + "print(f\"qf2: {qf2.PolynomB[1]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "e2c5b959-448a-425f-bf54-f33b8d189330", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "The copy of QF1 in is not affected.\n", + "\n", + "One can set upper and lower bounds on a variable. Trying to set a value out of the bounds will raise a {py:obj}`ValueError`. The default is (-{py:obj}`numpy.inf`, {py:obj}`numpy.inf`)." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "738f6c51-2968-4d77-8599-261e7998d52b", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "lfbound = ElementVariable(qf1, \"Length\", bounds=(0.45, 0.55))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "ba4e728e-fa3e-4d71-9dcb-658d240fd61c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [ + "raises-exception" + ] + }, + "outputs": [ + { + "ename": "ValueError", + "evalue": "set value must be in (0.45, 0.55)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[14], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mlfbound\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mset\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0.2\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/dev/libraries/at/pyat/at/lattice/variables.py:202\u001b[0m, in \u001b[0;36mVariableBase.set\u001b[0;34m(self, value, ring)\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Set the variable value\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \n\u001b[1;32m 196\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;124;03m may be necessary to set the variable.\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 201\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m value \u001b[38;5;241m<\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbounds[\u001b[38;5;241m0\u001b[39m] \u001b[38;5;129;01mor\u001b[39;00m value \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbounds[\u001b[38;5;241m1\u001b[39m]:\n\u001b[0;32m--> 202\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mset value must be in \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbounds\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 203\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_setfun(value, ring\u001b[38;5;241m=\u001b[39mring)\n\u001b[1;32m 204\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m np\u001b[38;5;241m.\u001b[39misnan(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_initial):\n", + "\u001b[0;31mValueError\u001b[0m: set value must be in (0.45, 0.55)" + ] + } + ], + "source": [ + "lfbound.set(0.2)" + ] + }, + { + "cell_type": "markdown", + "id": "ef061e87-74e7-446d-b93c-a5c84269ce0c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "Variables also accept a *delta* keyword argument. Its value is used as the initial step in matching, and in the {py:meth}`~.variables.VariableBase.step_up` and {py:meth}`~.variables.VariableBase.step_down` methods." + ] + }, + { + "cell_type": "markdown", + "id": "7ee77711-5ae6-4e9c-8fa6-78dcbf8d21ca", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## {py:class}`.RefptsVariable`\n", + "\n", + "An {py:class}`.RefptsVariable` is similar to an {py:class}`~.lattice_variables.ElementVariable` but it is not associated with an {py:class}`~.Element`\n", + "itself, but with its location in a Lattice. So it will act on any lattice with the same elements.\n", + "\n", + "But it needs a *ring* keyword in its *set* and *get* methods, to identify the selected lattice.\n", + "\n", + "Let's load a test ring and make a copy of it:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "10bf3120-0126-484f-8704-1b8ee02eabe0", + "metadata": {}, + "outputs": [], + "source": [ + "fname = \"hmba.mat\"\n", + "with as_file(files(\"machine_data\") / fname) as path:\n", + " ring = at.load_lattice(path)\n", + "newring = ring.deepcopy()" + ] + }, + { + "cell_type": "markdown", + "id": "30f5880c-5473-4452-8d00-d4a0853e1562", + "metadata": {}, + "source": [ + "and create a {py:class}`.RefptsVariable`" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8cd07413-331c-4b93-b160-cb896c23cc1e", + "metadata": {}, + "outputs": [], + "source": [ + "kf2 = RefptsVariable(\"QF1[AE]\", \"PolynomB\", index=1, name=\"kf2\")" + ] + }, + { + "cell_type": "markdown", + "id": "fa01f10a-02d5-49f8-b00d-92651227db35", + "metadata": {}, + "source": [ + "We can now use this variable on the two rings:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "8f7b53ee-1d3a-4591-afa5-ec11fefcef8a", + "metadata": {}, + "outputs": [], + "source": [ + "kf2.set(2.55, ring=ring)\n", + "kf2.set(2.45, ring=newring)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "a2c16507-5153-4ff0-a27c-5a292c3b78f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ring: 2.55\n", + "newring: 2.45\n" + ] + } + ], + "source": [ + "print(f\" ring: {ring[5].PolynomB[1]}\")\n", + "print(f\"newring: {newring[5].PolynomB[1]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "204d24e6-1e9b-4838-a950-3a8e1df5cac4", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Custom variables\n", + "Custom variables allow access to almost any quantity in AT. This can be achieved either by subclassing the {py:class}`~.variables.VariableBase` abstract base class, or by using a {py:class}`~.variables.CustomVariable`.\n", + "\n", + "We will take 2 examples:\n", + "\n", + "1. A variable accessing the *DPStep* parameter used in chromaticity computations. It does not look like a very\n", + " useful variable, it's for demonstration purpose,\n", + "2. A variable accessing the energy of a given lattice\n", + "\n", + "### Using the {py:class}`~.variables.CustomVariable`\n", + "\n", + "Using a {py:class}`~.variables.CustomVariable` makes it very easy to define simple variables: we just need\n", + "to define two functions for the \"get\" and \"set\" actions, and give them to the {py:class}`~.variables.CustomVariable` constructor.\n", + "\n", + "#### Example 1\n", + "\n", + "We define 2 functions for setting and getting the variable value:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "55a35ed2-a93b-4611-8a57-d7bce36a39f3", + "metadata": {}, + "outputs": [], + "source": [ + "def setvar1(value, ring=None):\n", + " at.DConstant.DPStep = value\n", + "\n", + "\n", + "def getvar1(ring=None):\n", + " return at.DConstant.DPStep" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "347abd5d-bfc9-469c-aae8-715fdfa11009", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3e-06\n" + ] + } + ], + "source": [ + "dpstep_var = CustomVariable(setvar1, getvar1, bounds=(1.0e-12, 0.1))\n", + "print(dpstep_var.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "f563564d-a862-4a11-9e9d-2e8864fa082d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0002\n" + ] + } + ], + "source": [ + "dpstep_var.value = 2.0e-4\n", + "print(at.DConstant.DPStep)" + ] + }, + { + "cell_type": "markdown", + "id": "a52236a1-bcc2-4040-811e-4a1aadc11a42", + "metadata": {}, + "source": [ + "#### Example 2\n", + "\n", + "We can give to the {py:class}`~.variables.CustomVariable` constructor any positional or keyword argument\n", + "necessary for the *set* and *get* functions. Here we will send the lattice as a positional argument:" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "8b5235c9-089d-46d2-a761-1044b445e583", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def setvar2(value, lattice, ring=None):\n", + " lattice.energy = value\n", + "\n", + "\n", + "def getvar2(lattice, ring=None):\n", + " return lattice.energy\n", + "\n", + "\n", + "energy_var = CustomVariable(setvar2, getvar2, newring)" + ] + }, + { + "cell_type": "markdown", + "id": "24d9b25f-d712-4056-800f-7f9bca6b2749", + "metadata": {}, + "source": [ + "Here, the *newring* positional argument given to the variable constructor is available as a positional argument\n", + "in both the *set* and *get* functions." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "957c8990-d5e8-435d-959d-31109ff7cd17", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6000000000.0\n" + ] + } + ], + "source": [ + "print(energy_var.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "3db01f54-db31-4042-b8b3-ac3289591ccb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6100000000.0\n" + ] + } + ], + "source": [ + "energy_var.value = 6.1e9\n", + "print(energy_var.value)" + ] + }, + { + "cell_type": "markdown", + "id": "70085e74-406b-450b-844b-5880a7847610", + "metadata": {}, + "source": [ + "We can look at the history of the variable" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "786e4424-8840-490c-8ac0-1db350c0ef00", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[6000000000.0, 6100000000.0]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "energy_var.history" + ] + }, + { + "cell_type": "markdown", + "id": "4f158c03-f238-43c8-912f-8bbcdc98efb0", + "metadata": {}, + "source": [ + "and go back to the initial value" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "9df21be4-0cf4-4499-aa0f-0fd7014f8934", + "metadata": {}, + "outputs": [], + "source": [ + "energy_var.reset()" + ] + }, + { + "cell_type": "markdown", + "id": "0dc64ddf-26fd-4c19-a76b-b55f79b89717", + "metadata": {}, + "source": [ + "### By derivation of the {py:class}`~.variables.VariableBase` class\n", + "\n", + "The derivation of {py:class}`~.variables.VariableBase` allows more control on the created variable by using\n", + "the class constuctor and its arguments to setup the variable.\n", + "\n", + "We will write a new variable class based on {py:class}`~.variables.VariableBase` abstract base class. The main task is to implement the `_setfun` and `_getfun` abstract methods.\n", + "\n", + "#### Example 1" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "7474764c-88ee-49df-9f85-b9e7f04aefbb", + "metadata": {}, + "outputs": [], + "source": [ + "class DPStepVariable(VariableBase):\n", + "\n", + " def _setfun(self, value, ring=None):\n", + " at.DConstant.DPStep = value\n", + "\n", + " def _getfun(self, ring=None):\n", + " return at.DConstant.DPStep" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "d838c699-8f99-4eb7-8784-acd503731888", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0002\n" + ] + } + ], + "source": [ + "dpstep_var = DPStepVariable()\n", + "print(dpstep_var.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "13fcbedf-f2e8-46f5-ab01-5a8c80608c76", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3e-06\n" + ] + } + ], + "source": [ + "dpstep_var.value = 3.0e-6\n", + "print(dpstep_var.value)" + ] + }, + { + "cell_type": "markdown", + "id": "03d34984-79a4-4736-b902-7b77bdbd5a89", + "metadata": {}, + "source": [ + "#### Example 2\n", + "\n", + "Here we will store the lattice as an instance variable in the class constructor:" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "547bdb0a-c1ea-4861-88d3-407329478391", + "metadata": {}, + "outputs": [], + "source": [ + "class EnergyVariable(VariableBase):\n", + " def __init__(self, lattice, *args, **kwargs):\n", + " # Store the lattice\n", + " self.lattice = lattice\n", + " # Initialise the parent class\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " def _setfun(self, value, ring=None):\n", + " self.lattice.energy = value\n", + "\n", + " def _getfun(self, ring=None):\n", + " return self.lattice.energy" + ] + }, + { + "cell_type": "markdown", + "id": "4195859c-3455-46ae-a443-f905926c7517", + "metadata": {}, + "source": [ + "We construct the variable:" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "2c61510b-3fab-4d1c-8dd3-5a0dc9e8659f", + "metadata": {}, + "outputs": [], + "source": [ + "energy_var = EnergyVariable(ring)" + ] + }, + { + "cell_type": "markdown", + "id": "bed78285-ae04-426a-9c04-b3796d80533b", + "metadata": {}, + "source": [ + "Look at the initial state:" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "2f75312e-c088-43fb-b7f8-d0179363167d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6000000000.0\n" + ] + } + ], + "source": [ + "print(energy_var.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "8fb40e51-9c5d-4f81-abfe-64e577a74656", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "6100000000.0\n" + ] + } + ], + "source": [ + "energy_var.value = 6.1e9\n", + "print(energy_var.value)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "54d43d08-0480-43bc-a260-631cf44f800a", + "metadata": {}, + "outputs": [], + "source": [ + "energy_var.reset()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyat/at/future.py b/pyat/at/future.py new file mode 100644 index 000000000..4dbf0b4ab --- /dev/null +++ b/pyat/at/future.py @@ -0,0 +1,2 @@ +from .lattice.variables import * +from .lattice.lattice_variables import * diff --git a/pyat/at/lattice/__init__.py b/pyat/at/lattice/__init__.py index 31c123c5a..527b212d2 100644 --- a/pyat/at/lattice/__init__.py +++ b/pyat/at/lattice/__init__.py @@ -9,11 +9,14 @@ from .axisdef import * from .options import DConstant, random from .particle_object import Particle +# from .variables import * +from .variables import VariableList from .elements import * from .rectangular_bend import * from .idtable_element import InsertionDeviceKickMap from .utils import * from .lattice_object import * +# from .lattice_variables import * from .cavity_access import * from .variable_elements import * from .deprecated import * diff --git a/pyat/at/lattice/axisdef.py b/pyat/at/lattice/axisdef.py index ae23141db..6f51608cf 100644 --- a/pyat/at/lattice/axisdef.py +++ b/pyat/at/lattice/axisdef.py @@ -1,6 +1,8 @@ """Helper functions for axis and plane descriptions""" + from __future__ import annotations from typing import Optional, Union + # For sys.version_info.minor < 9: from typing import Tuple @@ -16,31 +18,31 @@ ct=dict(index=5, label=r"$\beta c \tau$", unit=" [m]"), ) for xk, xv in [it for it in _axis_def.items()]: - xv['code'] = xk - _axis_def[xv['index']] = xv + xv["code"] = xk + _axis_def[xv["index"]] = xv _axis_def[xk.upper()] = xv -_axis_def['delta'] = _axis_def['dp'] -_axis_def['xp'] = _axis_def['px'] # For backward compatibility -_axis_def['yp'] = _axis_def['py'] # For backward compatibility -_axis_def['s'] = _axis_def['ct'] -_axis_def['S'] = _axis_def['ct'] -_axis_def[None] = dict(index=slice(None), label="", unit="", code=":") +_axis_def["delta"] = _axis_def["dp"] +_axis_def["xp"] = _axis_def["px"] # For backward compatibility +_axis_def["yp"] = _axis_def["py"] # For backward compatibility +_axis_def["s"] = _axis_def["ct"] +_axis_def["S"] = _axis_def["ct"] +_axis_def[None] = dict(index=None, label="", unit="", code=":") _axis_def[Ellipsis] = dict(index=Ellipsis, label="", unit="", code="...") _plane_def = dict( x=dict(index=0, label="x", unit=" [m]"), y=dict(index=1, label="y", unit=" [m]"), - z=dict(index=2, label="z", unit="") + z=dict(index=2, label="z", unit=""), ) for xk, xv in [it for it in _plane_def.items()]: - xv['code'] = xk - _plane_def[xv['index']] = xv + xv["code"] = xk + _plane_def[xv["index"]] = xv _plane_def[xk.upper()] = xv -_plane_def['h'] = _plane_def['x'] -_plane_def['v'] = _plane_def['y'] -_plane_def['H'] = _plane_def['x'] -_plane_def['V'] = _plane_def['y'] -_plane_def[None] = dict(index=slice(None), label="", unit="", code=":") +_plane_def["h"] = _plane_def["x"] +_plane_def["v"] = _plane_def["y"] +_plane_def["H"] = _plane_def["x"] +_plane_def["V"] = _plane_def["y"] +_plane_def[None] = dict(index=None, label="", unit="", code=":") _plane_def[Ellipsis] = dict(index=Ellipsis, label="", unit="", code="...") diff --git a/pyat/at/lattice/elements.py b/pyat/at/lattice/elements.py index 5d242de3f..4dbfff9bd 100644 --- a/pyat/at/lattice/elements.py +++ b/pyat/at/lattice/elements.py @@ -5,15 +5,21 @@ appropriate attributes. If a different PassMethod is set, it is the caller's responsibility to ensure that the appropriate attributes are present. """ + from __future__ import annotations + import abc import re -import numpy -from copy import copy, deepcopy from abc import ABC from collections.abc import Generator, Iterable +from copy import copy, deepcopy from typing import Any, Optional +import numpy + +# noinspection PyProtectedMember +from .variables import _nop + def _array(value, shape=(-1,), dtype=numpy.float64): # Ensure proper ordering(F) and alignment(A) for "C" access in integrators @@ -25,8 +31,17 @@ def _array66(value): return _array(value, shape=(6, 6)) -def _nop(value): - return value +def _float(value) -> float: + return float(value) + + +def _int(value, vmin: Optional[int] = None, vmax: Optional[int] = None) -> int: + intv = int(value) + if vmin is not None and intv < vmin: + raise ValueError(f"Value must be greater of equal to {vmin}") + if vmax is not None and intv > vmax: + raise ValueError(f"Value must be smaller of equal to {vmax}") + return intv class LongtMotion(ABC): @@ -42,6 +57,7 @@ class LongtMotion(ABC): * ``set_longt_motion(self, enable, new_pass=None, copy=False, **kwargs)`` must enable or disable longitudinal motion. """ + @abc.abstractmethod def _get_longt_motion(self): return False @@ -103,7 +119,8 @@ class _DictLongtMotion(LongtMotion): Defines a class such that :py:meth:`set_longt_motion` will select ``'IdentityPass'`` or ``'IdentityPass'``. - """ + """ + def _get_longt_motion(self): return self.PassMethod != self.default_pass[False] @@ -161,16 +178,20 @@ def set_longt_motion(self, enable, new_pass=None, copy=False, **kwargs): if new_pass is None or new_pass == self.PassMethod: return self if copy else None if enable: + def setpass(el): el.PassMethod = new_pass el.Energy = kwargs['energy'] + else: + def setpass(el): el.PassMethod = new_pass try: del el.Energy except AttributeError: pass + if copy: newelem = deepcopy(self) setpass(newelem) @@ -240,7 +261,7 @@ class Element(object): """Base class for AT elements""" _BUILD_ATTRIBUTES = ['FamName'] - _conversions = dict(FamName=str, PassMethod=str, Length=float, + _conversions = dict(FamName=str, PassMethod=str, Length=_float, R1=_array66, R2=_array66, T1=lambda v: _array(v, (6,)), T2=lambda v: _array(v, (6,)), @@ -248,9 +269,9 @@ class Element(object): EApertures=lambda v: _array(v, (2,)), KickAngle=lambda v: _array(v, (2,)), PolynomB=_array, PolynomA=_array, - BendingAngle=float, - MaxOrder=int, NumIntSteps=int, - Energy=float, + BendingAngle=_float, + MaxOrder=_int, NumIntSteps=lambda v: _int(v, vmin=0), + Energy=_float, ) _entrance_fields = ['T1', 'R1'] @@ -272,30 +293,35 @@ def __init__(self, family_name: str, **kwargs): def __setattr__(self, key, value): try: - super(Element, self).__setattr__( - key, self._conversions.get(key, _nop)(value)) + value = self._conversions.get(key, _nop)(value) except Exception as exc: exc.args = ('In element {0}, parameter {1}: {2}'.format( self.FamName, key, exc),) raise + else: + super(Element, self).__setattr__(key, value) def __str__(self): - first3 = ['FamName', 'Length', 'PassMethod'] + first3 = ["FamName", "Length", "PassMethod"] + # Get values and parameter objects attrs = dict(self.items()) - keywords = ['\t{0} : {1!s}'.format(k, attrs.pop(k)) for k in first3] - keywords += ['\t{0} : {1!s}'.format(k, v) for k, v in attrs.items()] - return '\n'.join((type(self).__name__ + ':', '\n'.join(keywords))) + keywords = [f"\t{k} : {attrs.pop(k)!s}" for k in first3] + keywords += [f"\t{k} : {v!s}" for k, v in attrs.items()] + return "\n".join((type(self).__name__ + ":", "\n".join(keywords))) def __repr__(self): - attrs = dict(self.items()) - arguments = [attrs.pop(k, getattr(self, k)) for k in - self._BUILD_ATTRIBUTES] + # Get values only, even for parameters + attrs = dict((k, getattr(self, k)) for k, v in self.items()) + arguments = [attrs.pop(k) for k in self._BUILD_ATTRIBUTES] defelem = self.__class__(*arguments) - keywords = ['{0!r}'.format(arg) for arg in arguments] - keywords += ['{0}={1!r}'.format(k, v) for k, v in sorted(attrs.items()) - if not numpy.array_equal(v, getattr(defelem, k, None))] - args = re.sub(r'\n\s*', ' ', ', '.join(keywords)) - return '{0}({1})'.format(self.__class__.__name__, args) + keywords = [f"{v!r}" for v in arguments] + keywords += [ + f"{k}={v!r}" + for k, v in sorted(attrs.items()) + if not numpy.array_equal(v, getattr(defelem, k, None)) + ] + args = re.sub(r"\n\s*", " ", ", ".join(keywords)) + return "{0}({1})".format(self.__class__.__name__, args) def equals(self, other) -> bool: """Whether an element is equivalent to another. @@ -326,10 +352,12 @@ def divide(self, frac) -> list[Element]: def swap_faces(self, copy=False): """Swap the faces of an element, alignment errors are ignored""" + def swapattr(element, attro, attri): val = getattr(element, attri) delattr(element, attri) return attro, val + if copy: el = self.copy() else: @@ -360,7 +388,7 @@ def update(self, *args, **kwargs): Update the element attributes with the given arguments """ attrs = dict(*args, **kwargs) - for (key, value) in attrs.items(): + for key, value in attrs.items(): setattr(self, key, value) def copy(self) -> Element: @@ -373,8 +401,8 @@ def deepcopy(self) -> Element: def items(self) -> Generator[tuple[str, Any], None, None]: """Iterates through the data members""" - for k, v in vars(self).items(): - yield k, v + # Properties may be added by overloading this method + yield from vars(self).items() def is_compatible(self, other: Element) -> bool: """Checks if another :py:class:`Element` can be merged""" @@ -384,8 +412,7 @@ def merge(self, other) -> None: """Merge another element""" if not self.is_compatible(other): badname = getattr(other, 'FamName', type(other)) - raise TypeError('Cannot merge {0} and {1}'.format(self.FamName, - badname)) + raise TypeError("Cannot merge {0} and {1}".format(self.FamName, badname)) # noinspection PyMethodMayBeStatic def _get_longt_motion(self): @@ -407,8 +434,8 @@ def is_collective(self) -> bool: class LongElement(Element): - """Base class for long elements - """ + """Base class for long elements""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Length'] def __init__(self, family_name: str, length: float, *args, **kwargs): @@ -442,8 +469,7 @@ def popattr(element, attr): # Remove entrance and exit attributes fin = dict(popattr(el, key) for key in vars(self) if key in self._entrance_fields) - fout = dict(popattr(el, key) for key in vars(self) if - key in self._exit_fields) + fout = dict(popattr(el, key) for key in vars(self) if key in self._exit_fields) # Split element element_list = [el._part(f, numpy.sum(frac)) for f in frac] # Restore entrance and exit attributes @@ -516,6 +542,7 @@ def means(self): class SliceMoments(Element): """Element to compute slices mean and std""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['nslice'] _conversions = dict(Element._conversions, nslice=int) @@ -534,8 +561,7 @@ def __init__(self, family_name: str, nslice: int, **kwargs): kwargs.setdefault('PassMethod', 'SliceMomentsPass') self._startturn = kwargs.pop('startturn', 0) self._endturn = kwargs.pop('endturn', 1) - super(SliceMoments, self).__init__(family_name, nslice=nslice, - **kwargs) + super(SliceMoments, self).__init__(family_name, nslice=nslice, **kwargs) self._nbunch = 1 self.startturn = self._startturn self.endturn = self._endturn @@ -550,45 +576,33 @@ def set_buffers(self, nturns, nbunch): self.endturn = min(self.endturn, nturns) self._dturns = self.endturn - self.startturn self._nbunch = nbunch - self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), - order='F') - self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), - order='F') - self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), - order='F') - self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), - order='F') + self._stds = numpy.zeros((3, nbunch*self.nslice, self._dturns), order="F") + self._means = numpy.zeros((3, nbunch*self.nslice, self._dturns), order="F") + self._spos = numpy.zeros((nbunch*self.nslice, self._dturns), order="F") + self._weights = numpy.zeros((nbunch*self.nslice, self._dturns), order="F") @property def stds(self): """Slices x,y,dp standard deviation""" - return self._stds.reshape((3, self._nbunch, - self.nslice, - self._dturns)) + return self._stds.reshape((3, self._nbunch, self.nslice, self._dturns)) @property def means(self): """Slices x,y,dp center of mass""" - return self._means.reshape((3, self._nbunch, - self.nslice, - self._dturns)) + return self._means.reshape((3, self._nbunch, self.nslice, self._dturns)) @property def spos(self): """Slices s position""" - return self._spos.reshape((self._nbunch, - self.nslice, - self._dturns)) + return self._spos.reshape((self._nbunch, self.nslice, self._dturns)) @property def weights(self): """Slices weights in mA if beam current >0, - otherwise fraction of total number of - particles in the bunch + otherwise fraction of total number of + particles in the bunch """ - return self._weights.reshape((self._nbunch, - self.nslice, - self._dturns)) + return self._weights.reshape((self._nbunch, self.nslice, self._dturns)) @property def startturn(self): @@ -619,6 +633,7 @@ def endturn(self, value): class Aperture(Element): """Aperture element""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['Limits'] _conversions = dict(Element._conversions, Limits=lambda v: _array(v, (4,))) @@ -697,6 +712,7 @@ def insert(self, class Collimator(Drift): """Collimator element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['RApertures'] def __init__(self, family_name: str, length: float, limits, **kwargs): @@ -715,8 +731,8 @@ def __init__(self, family_name: str, length: float, limits, **kwargs): class ThinMultipole(Element): """Thin multipole element""" - _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ['PolynomA', - 'PolynomB'] + + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ["PolynomA", "PolynomB"] def __init__(self, family_name: str, poly_a, poly_b, **kwargs): """ @@ -744,10 +760,13 @@ def lengthen(poly, dl): else: return poly - # Remove MaxOrder, PolynomA and PolynomB - poly_a, len_a, ord_a = getpol(_array(kwargs.pop('PolynomA', poly_a))) - poly_b, len_b, ord_b = getpol(_array(kwargs.pop('PolynomB', poly_b))) + # PolynomA and PolynomB and convert to ParamArray + prmpola = self._conversions["PolynomA"](kwargs.pop("PolynomA", poly_a)) + prmpolb = self._conversions["PolynomB"](kwargs.pop("PolynomB", poly_b)) + poly_a, len_a, ord_a = getpol(prmpola) + poly_b, len_b, ord_b = getpol(prmpolb) deforder = max(getattr(self, 'DefaultOrder', 0), ord_a, ord_b) + # Remove MaxOrder maxorder = kwargs.pop('MaxOrder', deforder) kwargs.setdefault('PassMethod', 'ThinMPolePass') super(ThinMultipole, self).__init__(family_name, **kwargs) @@ -755,36 +774,32 @@ def lengthen(poly, dl): super(ThinMultipole, self).__setattr__('MaxOrder', maxorder) # Adjust polynom lengths and set them len_ab = max(self.MaxOrder + 1, len_a, len_b) - self.PolynomA = lengthen(poly_a, len_ab - len_a) - self.PolynomB = lengthen(poly_b, len_ab - len_b) + self.PolynomA = lengthen(prmpola, len_ab - len_a) + self.PolynomB = lengthen(prmpolb, len_ab - len_b) def __setattr__(self, key, value): """Check the compatibility of MaxOrder, PolynomA and PolynomB""" polys = ('PolynomA', 'PolynomB') if key in polys: - value = _array(value) - lmin = getattr(self, 'MaxOrder') + lmin = self.MaxOrder if not len(value) > lmin: raise ValueError( 'Length of {0} must be larger than {1}'.format(key, lmin)) elif key == 'MaxOrder': - value = int(value) + intval = int(value) lmax = min(len(getattr(self, k)) for k in polys) - if not value < lmax: - raise ValueError( - 'MaxOrder must be smaller than {0}'.format(lmax)) - + if not intval < lmax: + raise ValueError("MaxOrder must be smaller than {0}".format(lmax)) super(ThinMultipole, self).__setattr__(key, value) class Multipole(_Radiative, LongElement, ThinMultipole): """Multipole element""" - _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['PolynomA', - 'PolynomB'] + + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ["PolynomA", "PolynomB"] _conversions = dict(ThinMultipole._conversions, K=float, H=float) - def __init__(self, family_name: str, length: float, poly_a, poly_b, - **kwargs): + def __init__(self, family_name: str, length: float, poly_a, poly_b, **kwargs): """ Args: family_name: Name of the element @@ -804,12 +819,10 @@ def __init__(self, family_name: str, length: float, poly_a, poly_b, """ kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') kwargs.setdefault('NumIntSteps', 10) - super(Multipole, self).__init__(family_name, length, - poly_a, poly_b, **kwargs) + super(Multipole, self).__init__(family_name, length, poly_a, poly_b, **kwargs) def is_compatible(self, other) -> bool: - if super().is_compatible(other) and \ - self.MaxOrder == other.MaxOrder: + if super().is_compatible(other) and self.MaxOrder == other.MaxOrder: for i in range(self.MaxOrder + 1): if self.PolynomB[i] != other.PolynomB[i]: return False @@ -823,7 +836,8 @@ def is_compatible(self, other) -> bool: @property def K(self) -> float: """Focusing strength [mˆ-2]""" - return 0.0 if len(self.PolynomB) < 2 else self.PolynomB[1] + arr = self.PolynomB + return 0.0 if len(arr) < 2 else arr[1] # noinspection PyPep8Naming @K.setter @@ -834,7 +848,8 @@ def K(self, strength: float): @property def H(self) -> float: """Sextupolar strength [mˆ-3]""" - return 0.0 if len(self.PolynomB) < 3 else self.PolynomB[2] + arr = self.PolynomB + return 0.0 if len(arr) < 3 else arr[2] # noinspection PyPep8Naming @H.setter @@ -907,16 +922,15 @@ def __init__(self, family_name: str, length: float, Default PassMethod: :ref:`BndMPoleSymplectic4Pass` """ - poly_b = kwargs.pop('PolynomB', numpy.array([0, k])) kwargs.setdefault('BendingAngle', bending_angle) kwargs.setdefault('EntranceAngle', 0.0) kwargs.setdefault('ExitAngle', 0.0) kwargs.setdefault('PassMethod', 'BndMPoleSymplectic4Pass') - super(Dipole, self).__init__(family_name, length, [], poly_b, **kwargs) + super(Dipole, self).__init__(family_name, length, [], [0.0, k], **kwargs) - def items(self) -> Generator[tuple, None, None]: + def items(self) -> Generator[tuple[str, Any], None, None]: yield from super().items() - yield 'K', self.K + yield "K", vars(self)["PolynomB"][1] def _part(self, fr, sumfr): pp = super(Dipole, self)._part(fr, sumfr) @@ -929,9 +943,9 @@ def is_compatible(self, other) -> bool: def invrho(dip: Dipole): return dip.BendingAngle / dip.Length - return (super().is_compatible(other) and - self.ExitAngle == -other.EntranceAngle and - abs(invrho(self) - invrho(other)) <= 1.e-6) + return (super().is_compatible(other) + and self.ExitAngle == -other.EntranceAngle + and abs(invrho(self) - invrho(other)) <= 1.e-6) def merge(self, other) -> None: super().merge(other) @@ -946,6 +960,7 @@ def merge(self, other) -> None: class Quadrupole(Radiative, Multipole): """Quadrupole element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['K'] _conversions = dict(Multipole._conversions, FringeQuadEntrance=int, FringeQuadExit=int) @@ -983,18 +998,17 @@ def __init__(self, family_name: str, length: float, Default PassMethod: ``StrMPoleSymplectic4Pass`` """ - poly_b = kwargs.pop('PolynomB', numpy.array([0, k])) - kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') - super(Quadrupole, self).__init__(family_name, length, [], poly_b, - **kwargs) + kwargs.setdefault("PassMethod", "StrMPoleSymplectic4Pass") + super(Quadrupole, self).__init__(family_name, length, [], [0.0, k], **kwargs) - def items(self) -> Generator[tuple, None, None]: + def items(self) -> Generator[tuple[str, Any], None, None]: yield from super().items() - yield 'K', self.K + yield "K", vars(self)["PolynomB"][1] class Sextupole(Multipole): """Sextupole element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['H'] DefaultOrder = 2 @@ -1018,14 +1032,18 @@ def __init__(self, family_name: str, length: float, Default PassMethod: ``StrMPoleSymplectic4Pass`` """ - poly_b = kwargs.pop('PolynomB', [0, 0, h]) - kwargs.setdefault('PassMethod', 'StrMPoleSymplectic4Pass') - super(Sextupole, self).__init__(family_name, length, [], poly_b, + kwargs.setdefault("PassMethod", "StrMPoleSymplectic4Pass") + super(Sextupole, self).__init__(family_name, length, [], [0.0, 0.0, h], **kwargs) + def items(self) -> Generator[tuple[str, Any], None, None]: + yield from super().items() + yield "H", vars(self)["PolynomB"][2] + class Octupole(Multipole): """Octupole element, with no changes from multipole at present""" + _BUILD_ATTRIBUTES = Multipole._BUILD_ATTRIBUTES DefaultOrder = 3 @@ -1033,6 +1051,7 @@ class Octupole(Multipole): class RFCavity(LongtMotion, LongElement): """RF cavity element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['Voltage', 'Frequency', 'HarmNumber', @@ -1073,9 +1092,9 @@ def _part(self, fr, sumfr): return pp def is_compatible(self, other) -> bool: - return (super().is_compatible(other) and - self.Frequency == other.Frequency and - self.TimeLag == other.TimeLag) + return (super().is_compatible(other) + and self.Frequency == other.Frequency + and self.TimeLag == other.TimeLag) def merge(self, other) -> None: super().merge(other) @@ -1094,6 +1113,7 @@ def set_longt_motion(self, enable, new_pass=None, **kwargs): class M66(Element): """Linear (6, 6) transfer matrix""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES + ["M66"] _conversions = dict(Element._conversions, M66=_array66) @@ -1104,7 +1124,7 @@ def __init__(self, family_name: str, m66=None, **kwargs): m66: Transfer matrix. Default: Identity matrix Default PassMethod: ``Matrix66Pass`` - """ + """ if m66 is None: m66 = numpy.identity(6) kwargs.setdefault('PassMethod', 'Matrix66Pass') @@ -1120,6 +1140,7 @@ class SimpleQuantDiff(_DictLongtMotion, Element): Note: The damping times are needed to compute the correct kick for the emittance. Radiation damping is NOT applied. """ + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES default_pass = {False: 'IdentityPass', True: 'SimpleQuantDiffPass'} @@ -1144,8 +1165,8 @@ def __init__(self, family_name: str, betax: float = 1.0, tauz: Longitudinal damping time [turns] Default PassMethod: ``SimpleQuantDiffPass`` - """ - kwargs.setdefault('PassMethod', self.default_pass[True]) + """ + kwargs.setdefault("PassMethod", self.default_pass[True]) assert taux >= 0.0, 'taux must be greater than or equal to 0' self.taux = taux @@ -1178,6 +1199,7 @@ def __init__(self, family_name: str, betax: float = 1.0, class SimpleRadiation(_DictLongtMotion, Radiative, Element): """Simple radiation damping and energy loss""" + _BUILD_ATTRIBUTES = Element._BUILD_ATTRIBUTES _conversions = dict(Element._conversions, U0=float, damp_mat_diag=lambda v: _array(v, shape=(6,))) @@ -1199,7 +1221,7 @@ def __init__(self, family_name: str, U0: Energy loss per turn [eV] Default PassMethod: ``SimpleRadiationPass`` - """ + """ assert taux >= 0.0, 'taux must be greater than or equal to 0' if taux == 0.0: dampx = 1 @@ -1228,6 +1250,7 @@ def __init__(self, family_name: str, class Corrector(LongElement): """Corrector element""" + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['KickAngle'] def __init__(self, family_name: str, length: float, kick_angle, **kwargs): @@ -1253,6 +1276,7 @@ class Wiggler(Radiative, LongElement): See atwiggler.m """ + _BUILD_ATTRIBUTES = LongElement._BUILD_ATTRIBUTES + ['Lw', 'Bmax', 'Energy'] _conversions = dict(Element._conversions, Lw=float, Bmax=float, @@ -1295,14 +1319,12 @@ def __init__(self, family_name: str, length: float, wiggle_period: float, for i, b in enumerate(self.By.T): dk = abs(b[3] ** 2 - b[4] ** 2 - b[2] ** 2) / abs(b[4]) if dk > 1e-6: - raise ValueError("Wiggler(H): kx^2 + kz^2 -ky^2 !=0, i = " - "{0}".format(i)) + raise ValueError("Wiggler(H): kx^2 + kz^2 -ky^2 !=0, i = {0}".format(i)) for i, b in enumerate(self.Bx.T): dk = abs(b[2] ** 2 - b[4] ** 2 - b[3] ** 2) / abs(b[4]) if dk > 1e-6: - raise ValueError("Wiggler(V): ky^2 + kz^2 -kx^2 !=0, i = " - "{0}".format(i)) + raise ValueError("Wiggler(V): ky^2 + kz^2 -kx^2 !=0, i = {0}".format(i)) self.NHharm = self.By.shape[1] self.NVharm = self.Bx.shape[1] diff --git a/pyat/at/lattice/lattice_object.py b/pyat/at/lattice/lattice_object.py index 45b972617..357df9d91 100644 --- a/pyat/at/lattice/lattice_object.py +++ b/pyat/at/lattice/lattice_object.py @@ -33,7 +33,7 @@ from .elements import Element from .particle_object import Particle from .utils import AtError, AtWarning, Refpts -from .utils import get_s_pos, get_elements,get_value_refpts, set_value_refpts +from .utils import get_s_pos, get_elements, get_value_refpts, set_value_refpts # noinspection PyProtectedMember from .utils import get_uint32_index, get_bool_index, _refcount, Uint32Refpts from .utils import refpts_iterator, checktype, set_shift, set_tilt, get_geometry @@ -296,11 +296,11 @@ def _addition_filter(self, elems: Iterable[Element], copy_elements=False): if cavities and not hasattr(self, '_cell_harmnumber'): cavities.sort(key=lambda el: el.Frequency) try: - self._cell_harmnumber = getattr(cavities[0], 'HarmNumber') + self._cell_harmnumber = cavities[0].HarmNumber except AttributeError: length += self.get_s_pos(len(self))[0] rev = self.beta * clight / length - frequency = getattr(cavities[0], 'Frequency') + frequency = cavities[0].Frequency self._cell_harmnumber = int(round(frequency / rev)) self._radiation |= params.pop('_radiation') @@ -314,13 +314,13 @@ def insert(self, idx: SupportsIndex, elem: Element, copy_elements=False): If :py:obj:`True` a deep copy of elem is used. """ - # noinspection PyUnusedLocal # scan the new element to update it - elist = list(self._addition_filter([elem], + elist = list(self._addition_filter([elem], # noqa: F841 copy_elements=copy_elements)) super().insert(idx, elem) def extend(self, elems: Iterable[Element], copy_elements=False): + # noinspection PyUnresolvedReferences r"""This method adds all the elements of `elems` to the end of the lattice. The behavior is the same as for a :py:obj:`list` @@ -343,6 +343,7 @@ def extend(self, elems: Iterable[Element], copy_elements=False): super().extend(elems) def append(self, elem: Element, copy_elements=False): + # noinspection PyUnresolvedReferences r"""This method overwrites the inherited method :py:meth:`list.append()`, its behavior is changed, it accepts only AT lattice elements :py:obj:`Element` as input argument. @@ -361,6 +362,7 @@ def append(self, elem: Element, copy_elements=False): self.extend([elem], copy_elements=copy_elements) def repeat(self, n: int, copy_elements: bool = True): + # noinspection SpellCheckingInspection,PyUnresolvedReferences,PyRedeclaration r"""This method allows to repeat the lattice `n` times. If `n` does not divide `ring.periodicity`, the new ring periodicity is set to 1, otherwise it is set to @@ -405,6 +407,7 @@ def copy_fun(elem, copy): def concatenate(self, *lattices: Iterable[Element], copy_elements=False, copy=False): + # noinspection PyUnresolvedReferences,SpellCheckingInspection,PyRedeclaration """Concatenate several `Iterable[Element]` with the lattice Equivalent syntaxes: @@ -439,6 +442,7 @@ def concatenate(self, *lattices: Iterable[Element], return lattice if copy else None def reverse(self, copy=False): + # noinspection PyUnresolvedReferences r"""Reverse the order of the lattice and swapt the faces of elements. Alignment errors are not swapped @@ -516,7 +520,7 @@ def copy(self) -> Lattice: def deepcopy(self) -> Lattice: """Returns a deep copy of the lattice""" return copy.deepcopy(self) - + def slice_elements(self, refpts: Refpts, slices: int = 1) -> Lattice: """Create a new lattice by slicing the elements at refpts @@ -538,7 +542,7 @@ def slice_generator(_): else: yield el - return Lattice(slice_generator, iterator=self.attrs_filter) + return Lattice(slice_generator, iterator=self.attrs_filter) def slice(self, size: Optional[float] = None, slices: Optional[int] = 1) \ -> Lattice: @@ -635,8 +639,8 @@ def energy(self) -> float: def energy(self, energy: float): # Set the Energy attribute of radiating elements for elem in self: - if (isinstance(elem, (elt.RFCavity, elt.Wiggler)) or - elem.PassMethod.endswith('RadPass')): + if (isinstance(elem, (elt.RFCavity, elt.Wiggler)) + or elem.PassMethod.endswith('RadPass')): elem.Energy = energy # Set the energy attribute of the Lattice # Use a numpy scalar to allow division by zero @@ -1471,7 +1475,7 @@ def params_filter(params, elem_filter: Filter, *args) -> Generator[Element, None cavities = [] cell_length = 0 - for idx, elem in enumerate(elem_filter(params, *args)): + for elem in elem_filter(params, *args): if isinstance(elem, elt.RFCavity): cavities.append(elem) elif hasattr(elem, 'Energy'): diff --git a/pyat/at/lattice/lattice_variables.py b/pyat/at/lattice/lattice_variables.py new file mode 100644 index 000000000..aa94069b0 --- /dev/null +++ b/pyat/at/lattice/lattice_variables.py @@ -0,0 +1,137 @@ +"""Variables are **references** to scalar attributes of lattice elements. There are 2 +kinds of element variables: + +- an :py:class:`ElementVariable` is associated to an element object, and acts on all + occurences of this object. But it will not affect any copy, neither shallow nor deep, + of the original object, +- a :py:class:`RefptsVariable` is not associated to an element object, but to an element + location in a :py:class:`.Lattice`. It acts on any copy of the initial lattice. A + *ring* argument must be provided to the *set* and *get* methods to identify the + lattice, which may be a possibly modified copy of the original lattice +""" + +from __future__ import annotations + +__all__ = ["RefptsVariable", "ElementVariable"] + +from collections.abc import Sequence +from typing import Union, Optional + +import numpy as np + +from .elements import Element +from .lattice_object import Lattice +from .utils import Refpts, getval, setval +from .variables import VariableBase + + +class RefptsVariable(VariableBase): + r"""A reference to a scalar attribute of :py:class:`.Lattice` elements. + + It can refer to: + + * a scalar attribute or + * an element of an array attribute + + of one or several :py:class:`.Element`\ s of a lattice. + + A :py:class:`RefptsVariable` is not associated to element objets themselves, but + to the location of these elements in a lattice. So a :py:class:`RefptsVariable` + will act equally on any copy of the initial ring. + As a consequence, a *ring* keyword argument (:py:class:`.Lattice` object) must be + supplied for getting or setting the variable. + """ + + def __init__( + self, refpts: Refpts, attrname: str, index: Optional[int] = None, **kwargs + ): + r""" + Parameters: + refpts: Location of variable :py:class:`.Element`\ s + attrname: Attribute name + index: Index in the attribute array. Use :py:obj:`None` for + scalar attributes + + Keyword Args: + name (str): Name of the Variable. Default: ``''`` + bounds (tuple[float, float]): Lower and upper bounds of the + variable value. Default: (-inf, inf) + delta (float): Step. Default: 1.0 + ring (Lattice): If specified, it is used to get and store the initial + value of the variable. Otherwise, the initial value is set to None + """ + self._getf = getval(attrname, index=index) + self._setf = setval(attrname, index=index) + self.refpts = refpts + super().__init__(**kwargs) + + def _setfun(self, value: float, ring: Lattice = None): + if ring is None: + raise ValueError("Can't set values if ring is None") + for elem in ring.select(self.refpts): + self._setf(elem, value) + + def _getfun(self, ring: Lattice = None) -> float: + if ring is None: + raise ValueError("Can't get values if ring is None") + values = np.array([self._getf(elem) for elem in ring.select(self.refpts)]) + return np.average(values) + + +class ElementVariable(VariableBase): + r"""A reference to a scalar attribute of :py:class:`.Lattice` elements. + + It can refer to: + + * a scalar attribute or + * an element of an array attribute + + of one or several :py:class:`.Element`\ s of a lattice. + + An :py:class:`ElementVariable` is associated to an element object, and acts on all + occurrences of this object. But it will not affect any copy, neither shallow nor + deep, of the original object. + """ + + def __init__( + self, + elements: Union[Element, Sequence[Element]], + attrname: str, + index: Optional[int] = None, + **kwargs, + ): + r""" + Parameters: + elements: :py:class:`.Element` or Sequence[Element] whose + attribute is varied + attrname: Attribute name + index: Index in the attribute array. Use :py:obj:`None` for + scalar attributes + + Keyword Args: + name (str): Name of the Variable. Default: ``''`` + bounds (tuple[float, float]): Lower and upper bounds of the + variable value. Default: (-inf, inf) + delta (float): Step. Default: 1.0 + """ + # Ensure the uniqueness of elements + if isinstance(elements, Element): + self._elements = {elements} + else: + self._elements = set(elements) + self._getf = getval(attrname, index=index) + self._setf = setval(attrname, index=index) + super().__init__(**kwargs) + + def _setfun(self, value: float, **kwargs): + for elem in self._elements: + self._setf(elem, value) + + def _getfun(self, **kwargs) -> float: + values = np.array([self._getf(elem) for elem in self._elements]) + return np.average(values) + + @property + def elements(self): + """Return the set of elements acted upon by the variable""" + return self._elements diff --git a/pyat/at/lattice/utils.py b/pyat/at/lattice/utils.py index 28037c878..49a979048 100644 --- a/pyat/at/lattice/utils.py +++ b/pyat/at/lattice/utils.py @@ -37,6 +37,7 @@ from typing import Union, Tuple, List, Type from enum import Enum from itertools import compress +from operator import attrgetter from fnmatch import fnmatch from .elements import Element, Dipole @@ -58,7 +59,7 @@ 'set_shift', 'set_tilt', 'set_rotation', 'tilt_elem', 'shift_elem', 'rotate_elem', 'get_value_refpts', 'set_value_refpts', 'Refpts', - 'get_geometry'] + 'get_geometry', 'setval', 'getval'] _axis_def = dict( x=dict(index=0, label="x", unit=" [m]"), @@ -113,6 +114,72 @@ def _type_error(refpts, types): "Invalid refpts type {0}. Allowed types: {1}".format(tp, types)) +# setval and getval return pickleable functions: no inner, nested function +# are allowed. So nested functions are replaced be module-level callable +# class instances +class _AttrItemGetter(object): + __slots__ = ["attrname", "index"] + + def __init__(self, attrname: str, index: int): + self.attrname = attrname + self.index = index + + def __call__(self, elem): + return getattr(elem, self.attrname)[self.index] + + +def getval(attrname: str, index: Optional[int] = None) -> Callable: + """Return a callable object which fetches item *index* of + attribute *attrname* of its operand. Examples: + + - After ``f = getval('Length')``, ``f(elem)`` returns ``elem.Length`` + - After ``f = getval('PolynomB, index=1)``, ``f(elem)`` returns + ``elem.PolynomB[1]`` + + """ + if index is None: + return attrgetter(attrname) + else: + return _AttrItemGetter(attrname, index) + + +class _AttrSetter(object): + __slots__ = ["attrname"] + + def __init__(self, attrname: str): + self.attrname = attrname + + def __call__(self, elem, value): + setattr(elem, self.attrname, value) + + +class _AttrItemSetter(object): + __slots__ = ["attrname", "index"] + + def __init__(self, attrname: str, index: int): + self.attrname = attrname + self.index = index + + def __call__(self, elem, value): + getattr(elem, self.attrname)[self.index] = value + + +def setval(attrname: str, index: Optional[int] = None) -> Callable: + """Return a callable object which sets the value of item *index* of + attribute *attrname* of its 1st argument to it 2nd orgument. + + - After ``f = setval('Length')``, ``f(elem, value)`` is equivalent to + ``elem.Length = value`` + - After ``f = setval('PolynomB, index=1)``, ``f(elem, value)`` is + equivalent to ``elem.PolynomB[1] = value`` + + """ + if index is None: + return _AttrSetter(attrname) + else: + return _AttrItemSetter(attrname, index) + + # noinspection PyIncorrectDocstring def axis_descr(*args, key=None) -> Tuple: r"""axis_descr(axis [ ,axis], key=None) @@ -774,13 +841,7 @@ def get_value_refpts(ring: Sequence[Element], refpts: Refpts, Returns: attrvalues: numpy Array of attribute values. """ - if index is None: - def getf(elem): - return getattr(elem, attrname) - else: - def getf(elem): - return getattr(elem, attrname)[index] - + getf = getval(attrname, index=index) return numpy.array([getf(elem) for elem in refpts_iterator(ring, refpts, regex=regex)]) @@ -817,13 +878,7 @@ def set_value_refpts(ring: Sequence[Element], refpts: Refpts, elements are shared with the original lattice. Any further modification will affect both lattices. """ - if index is None: - def setf(elem, value): - setattr(elem, attrname, value) - else: - def setf(elem, value): - getattr(elem, attrname)[index] = value - + setf = setval(attrname, index=index) if increment: attrvalues += get_value_refpts(ring, refpts, attrname, index=index, @@ -836,8 +891,7 @@ def setf(elem, value): # noinspection PyShadowingNames @make_copy(copy) def apply(ring, refpts, values, regex): - for elm, val in zip(refpts_iterator(ring, refpts, - regex=regex), values): + for elm, val in zip(refpts_iterator(ring, refpts, regex=regex), values): setf(elm, val) return apply(ring, refpts, attrvalues, regex) diff --git a/pyat/at/lattice/variables.py b/pyat/at/lattice/variables.py new file mode 100644 index 000000000..5b90e5f5a --- /dev/null +++ b/pyat/at/lattice/variables.py @@ -0,0 +1,458 @@ +r""" +Definition of :py:class:`Variable <.VariableBase>` objects used in matching and +response matrices. + +See :ref:`example-notebooks` for examples of matching and response matrices. + +Each :py:class:`Variable <.VariableBase>` has a scalar value. + +.. rubric:: Class hierarchy + +:py:class:`VariableBase`\ (name, bounds, delta) + +- :py:class:`~.lattice_variables.ElementVariable`\ (elements, attrname, index, ...) +- :py:class:`~.lattice_variables.RefptsVariable`\ (refpts, attrname, index, ...) +- :py:class:`CustomVariable`\ (setfun, getfun, ...) + +.. rubric:: VariableBase methods + +:py:class:`VariableBase` provides the following methods: + +- :py:meth:`~VariableBase.get` +- :py:meth:`~VariableBase.set` +- :py:meth:`~VariableBase.set_previous` +- :py:meth:`~VariableBase.reset` +- :py:meth:`~VariableBase.increment` +- :py:meth:`~VariableBase.step_up` +- :py:meth:`~VariableBase.step_down` + +.. rubric:: VariableBase properties + +:py:class:`.VariableBase` provides the following properties: + +- :py:attr:`~VariableBase.initial_value` +- :py:attr:`~VariableBase.last_value` +- :py:attr:`~VariableBase.previous_value` +- :py:attr:`~VariableBase.history` + +The :py:class:`VariableBase` abstract class may be used as a base class to define +custom variables (see examples). Typically, this consist in overloading the abstract +methods *_setfun* and *_getfun* + +.. rubric:: Examples + +Write a subclass of :py:class:`VariableBase` which varies two drift lengths so +that their sum is constant: + +.. code-block:: python + + class ElementShifter(at.VariableBase): + '''Varies the length of the elements identified by *ref1* and *ref2* + keeping the sum of their lengths equal to *total_length*. + + If *total_length* is None, it is set to the initial total length + ''' + def __init__(self, drift1, drift2, total_length=None, **kwargs): + # store the 2 variable elements + self.drift1 = drift1 + self.drift2 = drift2 + # store the initial total length + if total_length is None: + total_length = drift1.Length + drift2.Length + self.length = total_length + super().__init__(bounds=(0.0, total_length), **kwargs) + + def _setfun(self, value, **kwargs): + self.drift1.Length = value + self.drift2.Length = self.length - value + + def _getfun(self, **kwargs): + return self.drift1.Length + +And create a variable varying the length of drifts *DR_01* and *DR_01* and +keeping their sum constant: + +.. code-block:: python + + drift1 = hmba_lattice["DR_01"] + drift2 = hmba_lattice["DR_02"] + var2 = ElementShifter(drift1, drift2, name="DR_01") + +""" + +from __future__ import annotations + +__all__ = [ + "VariableBase", + "CustomVariable", + "VariableList", +] + +import abc +from collections import deque +from collections.abc import Iterable, Sequence, Callable +from typing import Union + +import numpy as np + +Number = Union[int, float] + + +def _nop(value): + return value + + +class VariableBase(abc.ABC): + """A Variable abstract base class + + Derived classes must implement the :py:meth:`~VariableBase._getfun` and + :py:meth:`~VariableBase._getfun` methods + """ + + _counter = 0 + _prefix = "var" + + def __init__( + self, + *, + name: str = "", + bounds: tuple[Number, Number] = (-np.inf, np.inf), + delta: Number = 1.0, + history_length: int = None, + ring=None, + ): + """ + Parameters: + name: Name of the Variable + bounds: Lower and upper bounds of the variable value + delta: Initial variation step + history_length: Maximum length of the history buffer. :py:obj:`None` + means infinite + ring: provided to an attempt to get the initial value of the + variable + """ + self.name: str = self._setname(name) #: Variable name + self.bounds: tuple[Number, Number] = bounds #: Variable bounds + self.delta: Number = delta #: Increment step + #: Maximum length of the history buffer. :py:obj:`None` means infinite + self.history_length = history_length + self._initial = np.nan + self._history = deque([], self.history_length) + try: + self.get(ring=ring, initial=True) + except ValueError: + pass + + @classmethod + def _setname(cls, name): + cls._counter += 1 + if name: + return name + else: + return f"{cls._prefix}{cls._counter}" + + # noinspection PyUnusedLocal + def _setfun(self, value: Number, ring=None): + classname = self.__class__.__name__ + raise TypeError(f"{classname!r} is read-only") + + @abc.abstractmethod + def _getfun(self, ring=None) -> Number: ... + + @property + def history(self) -> list[Number]: + """History of the values of the variable""" + return list(self._history) + + @property + def initial_value(self) -> Number: + """Initial value of the variable""" + if not np.isnan(self._initial): + return self._initial + else: + raise IndexError(f"{self.name}: No value has been set yet") + + @property + def last_value(self) -> Number: + """Last value of the variable""" + try: + return self._history[-1] + except IndexError as exc: + exc.args = (f"{self.name}: No value has been set yet",) + raise + + @property + def previous_value(self) -> Number: + """Value before the last one""" + try: + return self._history[-2] + except IndexError as exc: + exc.args = (f"{self.name}: history too short",) + raise + + def set(self, value: Number, ring=None) -> None: + """Set the variable value + + Args: + value: New value to be applied on the variable + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to set the variable. + """ + if value < self.bounds[0] or value > self.bounds[1]: + raise ValueError(f"set value must be in {self.bounds}") + self._setfun(value, ring=ring) + if np.isnan(self._initial): + self._initial = value + self._history.append(value) + + def get( + self, ring=None, *, initial: bool = False, check_bounds: bool = False + ) -> Number: + """Get the actual variable value + + Args: + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to get the variable value. + initial: If :py:obj:`True`, clear the history and set the variable + initial value + check_bounds: If :py:obj:`True`, raise a ValueError if the value is out + of bounds + + Returns: + value: Value of the variable + """ + value = self._getfun(ring=ring) + if initial or np.isnan(self._initial): + self._initial = value + self._history = deque([value], self.history_length) + if check_bounds: + if value < self.bounds[0] or value > self.bounds[1]: + raise ValueError(f"value out of {self.bounds}") + return value + + value = property(get, set, doc="Actual value") + + @property + def _safe_value(self): + try: + v = self._history[-1] + except IndexError: + v = np.nan + return v + + def set_previous(self, ring=None) -> None: + """Reset to the value before the last one + + Args: + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to set the variable. + """ + if len(self._history) >= 2: + self._history.pop() # Remove the last value + value = self._history.pop() # retrieve the previous value + self.set(value, ring=ring) + else: + raise IndexError(f"{self.name}: history too short") + + def reset(self, ring=None) -> None: + """Reset to the initial value and clear the history buffer + + Args: + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to reset the variable. + """ + iniv = self._initial + if not np.isnan(iniv): + self._history = deque([], self.history_length) + self.set(iniv, ring=ring) + else: + raise IndexError(f"reset {self.name}: No value has been set yet") + + def increment(self, incr: Number, ring=None) -> None: + """Increment the variable value + + Args: + incr: Increment value + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to increment the variable. + """ + if self._initial is None: + self.get(ring=ring, initial=True) + self.set(self.last_value + incr, ring=ring) + + def _step(self, step: Number, ring=None) -> None: + if self._initial is None: + self.get(ring=ring, initial=True) + self.set(self._initial + step, ring=ring) + + def step_up(self, ring=None) -> None: + """Set to initial_value + delta + + Args: + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to set the variable. + """ + self._step(self.delta, ring=ring) + + def step_down(self, ring=None) -> None: + """Set to initial_value - delta + + Args: + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to set the variable. + """ + self._step(-self.delta, ring=ring) + + @staticmethod + def _header(): + return "\n{:>12s}{:>13s}{:>16s}{:>16s}\n".format( + "Name", "Initial", "Final ", "Variation" + ) + + def _line(self): + vnow = self._safe_value + vini = self._initial + + return "{:>12s}{: 16e}{: 16e}{: 16e}".format( + self.name, vini, vnow, (vnow - vini) + ) + + def status(self): + """Return a string describing the current status of the variable + + Returns: + status: Variable description + """ + return "\n".join((self._header(), self._line())) + + def __float__(self): + return float(self._safe_value) + + def __int__(self): + return int(self._safe_value) + + def __str__(self): + return f"{self.__class__.__name__}({self._safe_value}, name={self.name!r})" + + def __repr__(self): + return repr(self._safe_value) + + +class CustomVariable(VariableBase): + r"""A Variable with user-defined get and set functions + + This is a convenience function allowing user-defined *get* and *set* + functions. But subclassing :py:class:`.Variable` should always be preferred + for clarity and efficiency. + + """ + + def __init__( + self, + setfun: Callable, + getfun: Callable, + *args, + name: str = "", + bounds: tuple[Number, Number] = (-np.inf, np.inf), + delta: Number = 1.0, + history_length: int = None, + ring=None, + **kwargs, + ): + """ + Parameters: + getfun: Function for getting the variable value. Called as + :pycode:`getfun(*args, ring=ring, **kwargs) -> Number` + setfun: Function for setting the variable value. Called as + :pycode:`setfun(value: Number, *args, ring=ring, **kwargs): None` + name: Name of the Variable + bounds: Lower and upper bounds of the variable value + delta: Initial variation step + *args: Variable argument list transmitted to both the *getfun* + and *setfun* functions. Such arguments can always be avoided by + using :py:func:`~functools.partial` or callable class objects. + + Keyword Args: + **kwargs: Keyword arguments transmitted to both the *getfun* + and *setfun* functions. Such arguments can always be avoided by + using :py:func:`~functools.partial` or callable class objects. + """ + self.getfun = getfun + self.setfun = setfun + self.args = args + self.kwargs = kwargs + super().__init__( + name=name, + bounds=bounds, + delta=delta, + history_length=history_length, + ring=ring, + ) + + def _getfun(self, ring=None) -> Number: + return self.getfun(*self.args, ring=ring, **self.kwargs) + + def _setfun(self, value: Number, ring=None): + self.setfun(value, *self.args, ring=ring, **self.kwargs) + + +class VariableList(list): + """Container for Variable objects + + :py:class:`VariableList` supports all :py:class:`list` methods, like + appending, insertion or concatenation with the "+" operator. + """ + + def get(self, ring=None, **kwargs) -> Sequence[float]: + r"""Get the current values of Variables + + Args: + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to set the variable. + + Keyword Args: + initial: If :py:obj:`True`, set the Variables' + initial value + check_bounds: If :py:obj:`True`, raise a ValueError if the value is out + of bounds + + Returns: + values: 1D array of values of all variables + """ + return np.array([var.get(ring=ring, **kwargs) for var in self]) + + def set(self, values: Iterable[float], ring=None) -> None: + r"""Set the values of Variables + + Args: + values: Iterable of values + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to set the variable. + """ + for var, val in zip(self, values): + var.set(val, ring=ring) + + def increment(self, increment: Iterable[float], ring=None) -> None: + r"""Increment the values of Variables + + Args: + increment: Iterable of values + ring: Depending on the variable type, a :py:class:`.Lattice` argument + may be necessary to increment the variable. + """ + for var, incr in zip(self, increment): + var.increment(incr, ring=ring) + + # noinspection PyProtectedMember + def status(self, **kwargs) -> str: + """String description of the variables""" + values = "\n".join(var._line(**kwargs) for var in self) + return "\n".join((VariableBase._header(), values)) + + def __str__(self) -> str: + return self.status() + + @property + def deltas(self) -> Sequence[Number]: + """delta values of the variables""" + return np.array([var.delta for var in self])