## A simple tutorial for basic uses of classes (object-oriented programming)

```python
import numpy as np
```

### Classes as containers (with `Enum`)

You can also define classes simply as containers (with `Enum`). In this case you don't write `self`, because you never create instances yourself: each member of an `Enum` is created exactly once, when the class is defined, and you just access it by name.
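
A minimal sketch of what "a single instance" means for an `Enum`: members are created once, so you only access, iterate over, or look up the existing ones. The `Color` enum is a hypothetical stand-in, not part of the tutorial.

```python
from enum import Enum


class Color(Enum):  # hypothetical example enum
    RED = 'red'
    GREEN = 'green'


# Members are created once, at class-definition time: every access returns the same object
assert Color.RED is Color.RED

# Each member carries a name and a value
print(Color.RED.name, Color.RED.value)  # RED red

# Iterating over the class yields the members in definition order
names = [c.name for c in Color]
print(names)  # ['RED', 'GREEN']

# Lookup by name (square brackets) returns the existing member, not a new instance
print(Color['GREEN'] is Color.GREEN)  # True
```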

```python
from enum import Enum
from typing import Type, Optional
```

`Regularization` acts, so to speak, as a "default dictionary": a dictionary that can include two keys, `force_field_reg` and `forward_model_reg`, which specify the regularization of the force-field correction and of the forward model, respectively:

- the value of the first key is either a string (among `plain l2`, `constraint 1`, `constraint 2`, `KL divergence`) or a user-defined function that takes `pars_ff` as input and returns the regularization term to be multiplied by the hyperparameter `beta`;
- the value of the second key is a user-defined function that takes `pars_fm` and `forward_coeffs_0` as input (the current and refined forward-model coefficients) and returns the regularization term to be multiplied by the hyperparameter `gamma`.
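
A minimal sketch of the plain-dictionary version of this idea. It assumes that `'plain l2'` means the squared Euclidean norm of `pars_ff` and that the user-defined forward-model term is the squared distance from `forward_coeffs_0`; `ff_regularization_term`, `fm_reg`, and all numbers are hypothetical, and the other named regularizations are deliberately left unimplemented.

```python
import numpy as np


def ff_regularization_term(force_field_reg, pars_ff):
    """Hypothetical helper: force-field regularization term (to be multiplied by beta)."""
    if callable(force_field_reg):
        # user-defined function of pars_ff
        return force_field_reg(pars_ff)
    if force_field_reg == 'plain l2':
        # assumption: 'plain l2' is the squared Euclidean norm of pars_ff
        return float(np.sum(np.asarray(pars_ff) ** 2))
    if force_field_reg in ('constraint 1', 'constraint 2', 'KL divergence'):
        raise NotImplementedError(f'{force_field_reg!r} is not sketched here')
    raise ValueError(f'unknown force-field regularization: {force_field_reg!r}')


def fm_reg(pars_fm, forward_coeffs_0):
    # hypothetical user-defined forward-model regularization: squared distance from forward_coeffs_0
    return float(np.sum((np.asarray(pars_fm) - np.asarray(forward_coeffs_0)) ** 2))


# The "dictionary" with the two keys described above
regularization = {'force_field_reg': 'plain l2', 'forward_model_reg': fm_reg}

beta, gamma = 0.1, 0.5  # hypothetical hyperparameter values
pars_ff = np.array([1.0, 2.0])
pars_fm = np.array([1.0, 1.0])
forward_coeffs_0 = np.array([0.0, 1.0])

ff_term = beta * ff_regularization_term(regularization['force_field_reg'], pars_ff)
fm_term = gamma * regularization['forward_model_reg'](pars_fm, forward_coeffs_0)
print(ff_term, fm_term)
```

Accepting either a string or a callable keeps the common cases short while leaving room for custom terms; the enum-based classes in the rest of this tutorial make the allowed strings explicit instead.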

#### 1. define `Force_field_reg` and `Forward_model_reg` with `Enum`

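```python
class Force_field_reg(Enum):
    PLAIN_L2 = 'plain l2'
    CONSTRAINT_1 = 'constraint 1'
    CONSTRAINT_2 = 'constraint 2'
    KL_DIVERGENCE = 'KL divergence'
    # caution: Enum treats functions (including lambdas) as methods, so CUSTOM is a plain
    # function attribute, NOT an enum member (it does not appear in list(Force_field_reg))
    CUSTOM = lambda x: np.linalg.norm(x) ** 2  # example of custom function
```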
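
A caveat about the `CUSTOM` entry: `Enum` treats any descriptor, including functions and lambdas, as a method rather than a member. The sketch below checks this and shows one portable workaround, wrapping the function in a small callable object; `NaiveReg`, `WrappedReg`, and `Callable_reg` are hypothetical names. On Python 3.11+ the standard library also provides `enum.member(...)` for the same purpose.

```python
from enum import Enum

import numpy as np


class NaiveReg(Enum):
    PLAIN_L2 = 'plain l2'
    CUSTOM = lambda x: np.linalg.norm(x) ** 2  # a function: Enum turns it into a method


print([m.name for m in NaiveReg])  # ['PLAIN_L2']
print(isinstance(NaiveReg.CUSTOM, NaiveReg))  # False


class Callable_reg:
    """Hypothetical wrapper: its instances are not descriptors, so Enum keeps them as member values."""

    def __init__(self, fun):
        self.fun = fun

    def __call__(self, *args):
        return self.fun(*args)


class WrappedReg(Enum):
    PLAIN_L2 = 'plain l2'
    CUSTOM = Callable_reg(lambda x: np.linalg.norm(x) ** 2)


print([m.name for m in WrappedReg])  # ['PLAIN_L2', 'CUSTOM']
print(WrappedReg.CUSTOM.value(np.array([3.0, 4.0])))  # 25.0
```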

```python
Force_field_reg.PLAIN_L2 = 7  # try to overwrite an existing member

# as expected, enum members cannot be reassigned
```

```
AttributeError: Cannot reassign members.
```

```python
ff_reg = Force_field_reg.PLAIN_L2

if ff_reg.value == 'plain l2':
    print('yes')
```

```
yes
```
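
Comparing `ff_reg.value` with a string works, but since members are singletons the idiomatic test is identity (`is`). Going the other way, calling the enum with a value, e.g. `Force_field_reg('plain l2')`, returns the existing member and raises `ValueError` for unknown strings, which is handy for validating user input. A minimal sketch (the enum is redefined without `CUSTOM`, and `parse_ff_reg` is a hypothetical helper, not part of the original code):

```python
from enum import Enum


class Force_field_reg(Enum):
    PLAIN_L2 = 'plain l2'
    CONSTRAINT_1 = 'constraint 1'
    CONSTRAINT_2 = 'constraint 2'
    KL_DIVERGENCE = 'KL divergence'


def parse_ff_reg(name: str) -> Force_field_reg:
    """Hypothetical helper: turn a user string into an enum member, rejecting unknown values."""
    try:
        return Force_field_reg(name)  # lookup by value
    except ValueError:
        allowed = [m.value for m in Force_field_reg]
        raise ValueError(f'{name!r} is not one of {allowed}') from None


ff_reg = parse_ff_reg('plain l2')
if ff_reg is Force_field_reg.PLAIN_L2:  # members are singletons, so `is` is the idiomatic test
    print('yes')
```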


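```python
class Forward_model_reg(Enum):
    PLAIN_L2 = 'plain l2'
    # as in Force_field_reg, this lambda becomes a method, not an enum member
    CUSTOM = lambda x: np.linalg.norm(x) ** 2

# a first (discarded) attempt: plain class attributes do not enforce any type
# class Regularization:
#     force_field_reg = Force_field_reg
#     forward_model_reg = Forward_model_reg
#
#     assert type(force_field_reg)
```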

#### 2. define `Regularization`

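Type hints (e.g. `Optional`, or `Type` from `typing`) document the expected types, but Python does not enforce them at runtime. A more robust way to define the class is to check the types explicitly in `__init__`: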

```python
class Regularization:
    def __init__(self, force_field_reg: Optional[Force_field_reg] = None,
                 forward_model_reg: Optional[Forward_model_reg] = None):

        # accept either a Force_field_reg member or None
        if not (isinstance(force_field_reg, Force_field_reg) or force_field_reg is None):
            raise TypeError("force_field_reg must be an instance of Force_field_reg")
        self.force_field_reg = force_field_reg

        # accept either a Forward_model_reg member or None
        if not (isinstance(forward_model_reg, Forward_model_reg) or forward_model_reg is None):
            raise TypeError("forward_model_reg must be an instance of Forward_model_reg")
        self.forward_model_reg = forward_model_reg
```
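
If you prefer less boilerplate, a `dataclass` can generate `__init__` and `__repr__` for you. A minimal sketch keeping the same runtime checks in `__post_init__`; the trimmed-down enums are stand-ins for the ones defined above:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class Force_field_reg(Enum):
    PLAIN_L2 = 'plain l2'
    KL_DIVERGENCE = 'KL divergence'


class Forward_model_reg(Enum):
    PLAIN_L2 = 'plain l2'


@dataclass
class Regularization:
    force_field_reg: Optional[Force_field_reg] = None
    forward_model_reg: Optional[Forward_model_reg] = None

    def __post_init__(self):
        # the annotations above are not enforced at runtime, so check the types explicitly
        if self.force_field_reg is not None and not isinstance(self.force_field_reg, Force_field_reg):
            raise TypeError('force_field_reg must be an instance of Force_field_reg')
        if self.forward_model_reg is not None and not isinstance(self.forward_model_reg, Forward_model_reg):
            raise TypeError('forward_model_reg must be an instance of Forward_model_reg')


reg = Regularization(force_field_reg=Force_field_reg.KL_DIVERGENCE)
print(reg)  # the generated __repr__ shows both fields
```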


#### 3. usage

```python
reg = Regularization(force_field_reg=Force_field_reg.KL_DIVERGENCE)

print(reg.force_field_reg.value)
print(reg.forward_model_reg)
```

```
KL divergence
None
```


```python
reg = Regularization(forward_model_reg=Forward_model_reg.PLAIN_L2)

print(reg.forward_model_reg.value)
```

```
plain l2
```


```python
reg1 = Regularization(Force_field_reg.KL_DIVERGENCE)
reg2 = Regularization(None)  # None is accepted: __init__ explicitly allows it
print(reg1.force_field_reg)
print(reg2.force_field_reg)
```

```
Force_field_reg.KL_DIVERGENCE
None
```


```python
reg = Regularization(None)

print(reg.force_field_reg)
```

```
None
```


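```python
def my_fun(x):
    # hypothetical custom regularization, defined here so that the cell runs
    return np.linalg.norm(x) ** 2

try:
    # enum members cannot be reassigned; note that reg.force_field_reg itself is not
    # protected after __init__, so it is left untouched here
    Force_field_reg.PLAIN_L2 = my_fun
except AttributeError:
    print('correctly, it returns error')
```

```
correctly, it returns error
```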
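
Note that `Regularization` only checks the types in `__init__`: after construction, `reg.force_field_reg = ...` accepts any value. A minimal sketch (one attribute only, for brevity) that validates on every assignment with a `property` setter; the trimmed enum is a stand-in for the one above:

```python
from enum import Enum
from typing import Optional


class Force_field_reg(Enum):
    PLAIN_L2 = 'plain l2'
    KL_DIVERGENCE = 'KL divergence'


class Regularization:
    """Sketch: validate force_field_reg on every assignment, not only in __init__."""

    def __init__(self, force_field_reg: Optional[Force_field_reg] = None):
        self.force_field_reg = force_field_reg  # goes through the property setter below

    @property
    def force_field_reg(self) -> Optional[Force_field_reg]:
        return self._force_field_reg

    @force_field_reg.setter
    def force_field_reg(self, value: Optional[Force_field_reg]) -> None:
        if value is not None and not isinstance(value, Force_field_reg):
            raise TypeError('force_field_reg must be an instance of Force_field_reg')
        self._force_field_reg = value


reg = Regularization(None)
reg.force_field_reg = Force_field_reg.PLAIN_L2  # accepted
try:
    reg.force_field_reg = 'plain l2'  # a bare string is rejected
except TypeError as err:
    print('correctly, it returns error:', err)
print(reg.force_field_reg)
```

The rejected assignment leaves the previous value in place, since the setter raises before storing anything.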


```python
print(reg.force_field_reg)
```

```
None
```
