# Training

## Training diagram

````{div} full-width
```{mermaid}
sequenceDiagram
    autonumber
    participant Agent
    participant RL Method
        Note left of RL Method: SVR, Actor-Critic...
    participant Environment

    loop Episode
        Agent-->>+RL Method: Start training (Data, Initial State)
        loop Step
            RL Method-->>+Environment: Select an action following its exploration strategy
            Environment-->>-RL Method: Return next state, action, reward and done flag
            RL Method->>RL Method: Store transition to memory
        end
        RL Method->>RL Method: Update model
        RL Method-->>-Agent: Returns episode reward
    end
```
````
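The sequence above follows the usual episode/step structure. The sketch below is a toy version of that loop, not ostatslib's implementation. `ToyEnvironment`, `AverageRewardMethod` and `Transition` are hypothetical names. An epsilon-greedy rule stands in for the method's exploration strategy, and a per-action average reward stands in for the SVR model.

```python
import random
from dataclasses import dataclass


@dataclass
class Transition:
    """One step of experience, as stored in the RL method's memory."""
    state: int
    action: int
    reward: float
    next_state: int
    done: bool


class ToyEnvironment:
    """Hypothetical environment: action 1 ends the episode with reward 1."""

    def __init__(self, max_steps: int = 5):
        self.max_steps = max_steps
        self.steps = 0

    def reset(self) -> int:
        self.steps = 0
        return 0  # initial state

    def step(self, action: int):
        self.steps += 1
        reward = 1.0 if action == 1 else 0.0
        done = action == 1 or self.steps >= self.max_steps
        return self.steps, reward, done  # next state, reward, done flag


class AverageRewardMethod:
    """Hypothetical RL method: a per-action mean reward stands in for the SVR model."""

    def __init__(self, n_actions: int = 2, epsilon: float = 0.2, seed: int = 0):
        self.memory = []  # stored Transition objects
        self.values = [0.0] * n_actions
        self.epsilon = epsilon
        self.rng = random.Random(seed)

    def select_action(self, state: int) -> int:
        # Epsilon-greedy exploration: random action with probability epsilon
        if self.rng.random() < self.epsilon:
            return self.rng.randrange(len(self.values))
        return max(range(len(self.values)), key=self.values.__getitem__)

    def update_model(self) -> None:
        # Re-estimate every action's value from all stored transitions
        for action in range(len(self.values)):
            rewards = [t.reward for t in self.memory if t.action == action]
            if rewards:
                self.values[action] = sum(rewards) / len(rewards)

    def train_episode(self, env: ToyEnvironment) -> float:
        state, done, episode_reward = env.reset(), False, 0.0
        while not done:  # Step loop
            action = self.select_action(state)
            next_state, reward, done = env.step(action)
            self.memory.append(Transition(state, action, reward, next_state, done))
            episode_reward += reward
            state = next_state
        self.update_model()  # Update the model once per episode
        return episode_reward  # Returned to the caller (the Agent)


# Episode loop
env = ToyEnvironment()
method = AverageRewardMethod()
episode_rewards = [method.train_episode(env) for _ in range(100)]
print(method.values)
```

Updating the model once per episode mirrors the diagram. A real method could instead update after every step or only once after all episodes.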

## Example

Training an Agent powered by a Support Vector Regression (SVR) model on 6000 datasets split among binary classification, linear regression and Poisson regression problems.

```python
from docs.workflows.utils.generate_training_datasets import generate_training_datasets

# Generate 6000 (dataset_type, dataset) pairs
datasets = generate_training_datasets(6000)
```
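The training and selection code unpacks each item of `datasets` as a `(dataset_type, dataset)` pair. Under that assumption, `collections.Counter` shows how the datasets are split across types. `count_dataset_types` is a hypothetical helper. The stand-in list uses strings in place of the datacooker recipe classes.

```python
from collections import Counter


def count_dataset_types(datasets):
    """Count how many (dataset_type, dataset) pairs exist for each dataset type."""
    return Counter(dataset_type for dataset_type, _ in datasets)


# Hypothetical stand-in for the output of generate_training_datasets
toy_datasets = [("LogitRecipe", None), ("Recipe", None), ("PoissonRecipe", None), ("Recipe", None)]
print(count_dataset_types(toy_datasets))
# Counter({'Recipe': 2, 'LogitRecipe': 1, 'PoissonRecipe': 1})
```

In the notebook, `count_dataset_types(datasets)` would report the actual split.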

Agent training:

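```python
import os
# Silence TensorFlow's C++ logging (level 3 hides INFO, WARNING and ERROR messages)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import warnings
# Suppress Python warnings emitted during training
warnings.filterwarnings('ignore')

from ostatslib.agents import Agent
from ostatslib.reinforcement_learning_methods import SupportVectorRegression

agent = Agent(rl_method=SupportVectorRegression())

# Each dataset is one training episode; the dataset type is not needed here
for _dataset_type, dataset in datasets:
    agent.train(dataset)

# Fit the SVR model once all episodes have run
agent.rl_method.fit()
```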

Checking the agent's analysis.

First, get one dataset of each type the agent was trained on:

```python
from datacooker.recipes import LogitRecipe, PoissonRecipe, Recipe

# Take the first dataset generated by each recipe
# (`Recipe` is the recipe used for the linear regression datasets)
logistic_regression_dataset = next(dataset for dataset_type, dataset in datasets if dataset_type == LogitRecipe)
linear_regression_dataset = next(dataset for dataset_type, dataset in datasets if dataset_type == Recipe)
poisson_regression_dataset = next(dataset for dataset_type, dataset in datasets if dataset_type == PoissonRecipe)
```

- Binary classification:

```python
analysis = agent.analyze(logistic_regression_dataset)
print(analysis.summary())
```

```text
Analysis executed at 2023-01-17 23:10:10.852220
Final status is Complete
Initial State known features:

Steps:
  Order  Step                             Reward  State Change
-------  -----------------------------  --------  --------------------------
      1  is_response_dichotomous_check  0.5       is_response_dichotomous  1
      2  LogisticRegressionCV(cv=5)     0.778261  score  0.678261
```


- Linear regression:

```python
analysis = agent.analyze(linear_regression_dataset)
print(analysis.summary())
```

```text
Analysis executed at 2023-01-17 23:10:11.148657
Final status is Complete
Initial State known features:

Steps:
  Order  Step                             Reward  State Change
-------  -----------------------------  --------  ---------------------------
      1  is_response_dichotomous_check       0.5  is_response_dichotomous  -1
      2  is_response_quantitative            0.5  is_response_quantitative  1
      3  is_response_discrete_check          0.5  is_response_discrete  -1
      4  DecisionTreeRegressor()             1    score  0.978191
```
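In these summaries, the State Change column records each check's outcome as `1` or `-1`. The binary classification dataset gets `is_response_dichotomous  1`. The linear regression dataset gets `-1` for a dichotomous or discrete response and `1` for a quantitative one. Reading `1` as "yes" and `-1` as "no" is an interpretation of this output, not documented ostatslib behavior. Under that assumption, the hypothetical helper `known_features` below turns those entries into booleans:

```python
# Assumed encoding, inferred from the summaries: 1 = check passed, -1 = check failed
ENCODING = {1: True, -1: False}


def known_features(state_changes):
    """Map (feature, value) pairs from the State Change column to booleans."""
    return {feature: ENCODING[value] for feature, value in state_changes}


# State changes reported for the linear regression dataset
linear_regression_changes = [
    ("is_response_dichotomous", -1),
    ("is_response_quantitative", 1),
    ("is_response_discrete", -1),
]
print(known_features(linear_regression_changes))
# {'is_response_dichotomous': False, 'is_response_quantitative': True, 'is_response_discrete': False}
```

The final `score` row is a model metric rather than a check, so it is left out.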


- Poisson regression:

```python
analysis = agent.analyze(poisson_regression_dataset)
print(analysis.summary())
```


Analysis executed at 2023-01-17 23:10:11.381655
Final status is Complete
Initial State known features:

Steps:
  Order  Step                                                                                        Reward  State Change
-------  ----------------------------------------------------------------------------------------  --------  -----------------------------------
      1  is_response_dichotomous_check                                                             0.5       is_response_dichotomous  -1
      2  is_response_quantitative                                                                  0.5       is_response_quantitative  1
      3  is_response_discrete_check                                                                0.5       is_response_discrete  1
      4  time_convertable_variable                                                                 0.5       time_convertable_variable
      5  is_response_positive_values_only_check                                