In [7]:
%load_ext autoreload
%autoreload 2
import sys
from pathlib import Path
# Make the parent of the current working directory importable, then show it
# as the cell output so the reader can see which checkout is in use.
project_root = str(Path.cwd().parent)
sys.path.insert(1, project_root)
project_root

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


'/home/ximo/Documents/GitHub/skforecast'

In [8]:
import pandas as pd
import numpy as np
from typing import Union, Tuple

from skforecast.utils import preprocess_exog

In [9]:
def _cast_values(values: pd.Series, target_dtype) -> pd.Series:
    """
    Cast a single pandas Series to `target_dtype`.

    When the target is `category` and the current dtype is float, the values
    are cast to int first so the resulting categories are integers instead
    of floats.
    """
    if target_dtype == "category" and values.dtypes == float:
        return values.astype(int).astype("category")
    return values.astype(target_dtype)


def fix_exog_dtypes(
    exog: Union[pd.Series, pd.DataFrame],
    exog_dtypes: dict,
) -> Union[pd.Series, pd.DataFrame]:
    """
    Cast `exog` to the specified dtypes.

    - If `exog` is a pandas Series, only the first value of `exog_dtypes` is
    used; its key is ignored.
    - If `exog` is a pandas DataFrame, every key of `exog_dtypes` that is a
    column of `exog` is cast to the associated dtype. Keys that are not
    columns of `exog` are ignored.
    - If the target dtype is `category` and the current dtype is float, the
    values are cast to int and then to `category`.

    The input object is never modified. A copy is made only when at least one
    cast is needed; otherwise `exog` itself is returned.

    Parameters
    ----------
    exog : pandas Series, pandas DataFrame
        Exogenous variables.
    exog_dtypes : dict
        Dictionary with name and type of the series or data frame columns.

    Returns
    -------
    exog : pandas Series, pandas DataFrame
        Exogenous variables cast to the dtypes in `exog_dtypes`.

    """

    if isinstance(exog, pd.Series):
        # Nothing to apply: an empty mapping leaves the series untouched
        # instead of failing on the lookup of its first value.
        if not exog_dtypes:
            return exog
        target_dtype = next(iter(exog_dtypes.values()))
        if exog.dtypes != target_dtype:
            exog = _cast_values(exog, target_dtype)
        return exog

    if isinstance(exog, pd.DataFrame):
        # Collect only the columns that exist and whose dtype differs.
        pending = {
            col: target_dtype
            for col, target_dtype in exog_dtypes.items()
            if col in exog.columns and exog[col].dtypes != target_dtype
        }
        if pending:
            # Work on a copy so the caller's DataFrame is not altered.
            exog = exog.copy()
            for col, target_dtype in pending.items():
                exog[col] = _cast_values(exog[col], target_dtype)

    return exog

In [10]:
# Toy exogenous data: two integer columns with three rows each.
exog = pd.DataFrame(
    data=dict(col1=[2, 3, 5], col2=[4, 5, 6]),
)

In [11]:
# Start from an integer frame and keep the third value returned by
# `preprocess_exog` as the reference dtypes.
exog = pd.DataFrame(dict(col1=[2, 3, 5], col2=[4, 5, 6]))
_, _, exog_dtypes = preprocess_exog(exog)
print(exog.dtypes, end="\n\n")

# Cast every column to float so the dtypes no longer match the reference.
exog = exog.astype(float)
print(exog.dtypes, end="\n\n")

# Cast back using the reference dtypes.
exog = fix_exog_dtypes(exog, exog_dtypes)
print(exog.dtypes, end="\n\n")

# Ask for `str` on col1 only and cast once more.
exog_dtypes.update(col1=str)
exog = fix_exog_dtypes(exog, exog_dtypes)
print(exog.dtypes)


col1    int64
col2    int64
dtype: object

col1    float64
col2    float64
dtype: object

col1    int64
col2    int64
dtype: object

col1    object
col2     int64
dtype: object


In [12]:
# Same round trip as for the DataFrame, this time with an unnamed Series.
exog = pd.Series(data=[1, 2, 3], dtype=int)
_, _, exog_dtypes = preprocess_exog(exog)
print(exog.dtypes, end="\n\n")

# Cast to float so the dtype differs from the reference.
exog = exog.astype(float)
print(exog.dtypes, end="\n\n")

# Cast back using the reference dtype.
exog = fix_exog_dtypes(exog, exog_dtypes)
print(exog.dtypes, end="\n\n")

# Replace the reference with a single `str` entry and cast once more.
exog_dtypes = {None: str}
exog = fix_exog_dtypes(exog, exog_dtypes)
print(exog.dtypes)

int64

float64

int64

object
