# Evaluation Functions

## Abstract Class

ReactEA Evaluation Functions Abstract Class.

Child Classes must implement all abstract methods (`get_fitness_single` and `method_str`).

In [3]:
from joblib import Parallel, delayed
from rdkit.Chem import Mol
from typing import Union, List
from abc import ABC, abstractmethod


class ChemicalEvaluationFunction(ABC):
    """
    Abstract base for fitness functions that score RDKit Mol objects.

    Concrete subclasses must implement `get_fitness_single` (the score of one
    molecule) and `method_str` (the name of the function). Batch evaluation,
    `str()` and call syntax are all built on top of those two methods.
    """

    def __init__(self, maximize: bool = True, worst_fitness: float = 0.0):
        """
        Store the optimization direction and the worst possible score.

        Parameters
        ----------
        maximize: bool
            True if the problem is a maximization problem.
        worst_fitness: float
            The worst fitness that can be given to a solution.
        """
        self.worst_fitness = worst_fitness
        self.maximize = maximize

    def get_fitness(self, candidates: Union[Mol, List[Mol]]):
        """
        Score one molecule or a collection of molecules.

        A single Mol is wrapped in a one-element list; any other input is
        iterated as given. Every candidate is passed to `get_fitness_single`
        through a joblib `Parallel` pool (`n_jobs=-1`, multiprocessing backend).

        Parameters
        ----------
        candidates: Union[Mol, List[Mol]]
            The molecule(s) to score.

        Returns
        -------
        List[float]
            One fitness value per candidate.
        """
        batch = [candidates] if isinstance(candidates, Mol) else candidates
        score_one = delayed(self.get_fitness_single)
        pool = Parallel(n_jobs=-1, backend="multiprocessing")
        return pool(score_one(mol) for mol in batch)

    @abstractmethod
    def get_fitness_single(self, candidate: Mol):
        """
        Score a single molecule. Must be overridden by subclasses.

        Parameters
        ----------
        candidate: Mol
            The Mol object to score.

        Returns
        -------
        float
            The fitness of the Mol object.
        """
        raise NotImplementedError

    @abstractmethod
    def method_str(self):
        """
        Name of the evaluation function. Must be overridden by subclasses.

        Returns
        -------
        str
            The name of the evaluation function.
        """
        raise NotImplementedError

    def __str__(self):
        # The string form of an instance is the name reported by method_str.
        return self.method_str()

    def __call__(self, candidate: Union[Mol, List[Mol]]):
        # Calling the instance is shorthand for get_fitness.
        return self.get_fitness(candidate)

## Example of how to implement your own evaluation functions

ReactEA already has some [default evaluation functions](https://github.com/BioSystemsUM/ReactEA/blob/main/src/reactea/optimization/evaluation.py); however, it is easy to implement your own.

In ReactEA evaluation functions act on RDKit Mol objects.

### Using the evaluation function wrapper:

You need to provide:
  - a callable that returns the score of a single Mol object;
  - whether the objective is to maximize or minimize the score;
  - the worst fitness when an invalid Mol is generated or when the evaluation function cannot be calculated;
  - the name of the evaluation function.

In [4]:
from rdkit.Chem.QED import qed
from reactea import evaluation_functions_wrapper

# placeholder evaluation function: every molecule receives the same score (1)
def dummy_eval_f(mol):
    """Return the constant score 1; `mol` is not inspected."""
    constant_score = 1
    return constant_score

# wrap the dummy scorer with its direction, worst fitness and name
f1 = evaluation_functions_wrapper(
    dummy_eval_f,
    maximize=True,
    worst_fitness=0,
    name='dummy_eval_f',
)

# evaluation function returning the drug-likeness score (QED) of a molecule
def qed_score(mol):
    """Return the value computed by RDKit's `qed` for `mol`."""
    score = qed(mol)
    return score

# wrap the QED scorer with its direction, worst fitness and name
f2 = evaluation_functions_wrapper(
    qed_score,
    maximize=True,
    worst_fitness=0.0,
    name='qed',
)

[10:58:35] Initializing Normalizer


### Creating your own class:

The class needs to inherit from the `ChemicalEvaluationFunction` class.
It needs to implement its abstract methods (`get_fitness_single` and `method_str`).

In [5]:
import numpy as np
from rdkit.Chem.Descriptors import MolWt

# evaluation class to maximize QED
class QED(ChemicalEvaluationFunction):
    """
    Evaluation function that scores a molecule with RDKit's `qed`.
    """

    def __init__(self, maximize: bool = True, worst_fitness: float = 0.0):
        """
        Initializes the QED evaluation function.

        Parameters
        ----------
        maximize: bool
            If it is a maximization problem.
        worst_fitness: float
            The fitness returned when the QED score cannot be computed.
        """
        super().__init__(maximize, worst_fitness)

    def get_fitness_single(self, candidate: Mol):
        """
        Get the QED score of a single molecule.

        Parameters
        ----------
        candidate: Mol
            Mol object to score.

        Returns
        -------
        float
            The value returned by `qed(candidate)`, or `worst_fitness` if
            `qed` raises an exception.
        """
        try:
            return qed(candidate)
        except Exception:
            # Deliberately not a bare `except:` so that KeyboardInterrupt and
            # SystemExit still propagate instead of being turned into a score.
            return self.worst_fitness

    def method_str(self):
        """
        Get name of the evaluation function.

        Returns
        -------
        str
            The string "QED".
        """
        return "QED"

# evaluation class that rewards molecules whose molecular weight lies within a range
class MolecularWeight(ChemicalEvaluationFunction):
    """
    Evaluation function based on the molecular weight computed by RDKit's `MolWt`.

    Weights in [min_weight, max_weight) score 1.0; weights outside that
    interval are scored by the penalty formulas in `_mol_weight`.
    """

    def __init__(self,
                 min_weight: float = 300.0,
                 max_weight: float = 900,
                 maximize: bool = True,
                 worst_fitness: float = 0.0):
        """
        Initializes the MolecularWeight evaluation function.

        Parameters
        ----------
        min_weight: float
            Lower bound of the molecular weight range that scores 1.0 (inclusive).
        max_weight: float
            Upper bound of the molecular weight range that scores 1.0 (exclusive).
        maximize: bool
            If it is a maximization problem.
        worst_fitness: float
            The fitness returned when the score cannot be computed.
        """
        super().__init__(maximize, worst_fitness)
        self.max_weight = max_weight
        self.min_weight = min_weight

    def _mol_weight(self, mol: Mol):
        """
        Compute the molecular-weight score of a molecule.

        Parameters
        ----------
        mol: Mol
            Mol object to score.

        Returns
        -------
        float
            - cos((mw - min_weight + 200) / 320) when mw < min_weight;
            - 1.0 when min_weight <= mw < max_weight;
            - 1.0 / ln(mw / 250.0) when mw >= max_weight;
            - `worst_fitness` if any exception is raised along the way.

        NOTE(review): the constants 200, 320 and 250.0 are hard-coded. The
        cosine term equals 1 at mw == min_weight - 200 rather than at
        min_weight -- confirm this is the intended penalty shape.
        """
        try:
            weight = MolWt(mol)
            # below the range: cosine-shaped score
            if weight < self.min_weight:
                return np.cos((weight - self.min_weight + 200) / 320)
            # inside the range: full score
            if weight < self.max_weight:
                return 1.0
            # at or above the upper bound: score shrinks as the weight grows
            return 1.0 / np.log(weight / 250.0)
        except Exception:
            return self.worst_fitness

    def get_fitness_single(self, candidate: Mol):
        """
        Get the molecular-weight score of a single molecule.

        Parameters
        ----------
        candidate: Mol
            Mol object to score.

        Returns
        -------
        float
            The value computed by `_mol_weight` for the candidate.
        """
        return self._mol_weight(candidate)

    def method_str(self):
        """
        Get name of the evaluation function.

        Returns
        -------
        str
            The string "MolecularWeight".
        """
        return "MolecularWeight"