In [1]:
import numpy as np
import pandas as pd
from pathlib import Path
import sys
import logging
from pprint import pprint

from sklearn.model_selection import KFold

sys.path.append(Path('.').absolute().parent.resolve().as_posix())
sys.path.append((Path('.').absolute().parent / 'source').resolve().as_posix())
from data_loader import config_loader, data_preprocessing
from logging_util.logger import get_logger
from models import estimators
from evaluation import evaluation

config = config_loader.load_config()
logger = get_logger(__name__)
logging.getLogger("PIL").setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)

2024-01-12 10:04:20,919 - DEBUG - config_loader.py:18 - Loading data configuration
2024-01-12 10:04:20,929 - INFO - __init__.py:19 - Loading HistGradientBoostRegressor
2024-01-12 10:04:21,238 - DEBUG - __init__.py:350 - matplotlib data path: c:\Users\steve\anaconda3\envs\fertilizer_usage\lib\site-packages\matplotlib\mpl-data
2024-01-12 10:04:21,245 - DEBUG - __init__.py:350 - CONFIGDIR=C:\Users\steve\.matplotlib
2024-01-12 10:04:21,249 - DEBUG - __init__.py:1511 - interactive is False
2024-01-12 10:04:21,250 - DEBUG - __init__.py:1512 - platform is win32
2024-01-12 10:04:21,378 - DEBUG - __init__.py:350 - CACHEDIR=C:\Users\steve\.matplotlib
2024-01-12 10:04:21,389 - DEBUG - font_manager.py:1574 - Using fontManager instance from C:\Users\steve\.matplotlib\fontlist-v330.json
2024-01-12 10:04:22,340 - DEBUG - config_loader.py:18 - Loading data configuration
2024-01-12 10:04:22,344 - DEBUG - config_loader.py:18 - Loading data configuration
2024-01-12 10:04:22,348 - INFO - __init__.py:19 - 

In [2]:
estimators

{'HistGradientBoostRegressor': models.HistGradientBoostRegressor.estimator.HistGradientBoostEstimator,
 'XgboostRegressor': models.XgboostRegressor.estimator.XGBoostEstimator}

In [3]:
target = 'K2O_avg_app'
config.target = target

In [4]:
resultspath = Path().resolve().parent.resolve() / 'results' / target
resultspath

WindowsPath('C:/Users/steve/OneDrive - Universiteit Antwerpen/Documenten/00IDLab/projects/crop_database_imputation/results/K2O_avg_app')

In [5]:
# TODO: add type hints

def ohe(input_data, config):
    
    datapath = Path().resolve().parent.resolve() / 'data'
    dtype = data_preprocessing.get_data_types(config)
    all_data = pd.read_csv(f'{datapath}/{config.csv_files["features"]}', dtype=dtype, usecols=dtype.keys()) 

    input_data_ohe = data_preprocessing.one_hot_encoding(input_data)

    return input_data_ohe

def load_input_data(resultspath: Path, it: int = 0, test: bool = False, config=config):
    dtype = data_preprocessing.get_data_types(config)
    X = pd.read_csv(resultspath / f'data/{"test" if test else "train"}_set_{it}.csv', index_col=0, dtype=dtype)
    return X

def make_prediction(estimator, X, config):
    # drop the last column, which is the target
    if config.target in X.columns:
        X = X.drop(config.target, axis=1)
    X = ohe(X, config)
    y_pred = pd.Series(
            estimator.model.predict(X),
            index=X.index,
            name='predicted_' + config.target,
        )
    return y_pred

for estimator_name, estimator_cls in estimators.items():
    print(estimator_name)
    estimator = estimator_cls(outpath=resultspath)
    estimator = estimator.load(path=estimator.output_path, it=0)
    X = load_input_data(resultspath, it=0, test=True, config=config)    
    y_pred = make_prediction(estimator, X, config)

    # check if the predictions match
    print((y_pred != estimator.y_pred).any())



HistGradientBoostRegressor
False
XgboostRegressor


ValueError: feature_names mismatch: ['Year', 'Area_report', 'Fert_perc', 'Area_fert', 'Area_FAOStat', 'Perc_Agr_Area_FAOStat', 'GDP_Capita_WB', 'GDP_perCapita_UN', 'Urea_Price', 'P_rock_Price', 'K2O_Price', 'Crop_Global_Price', 'Country_total_N', 'Country_total_P2O5', 'Country_total_K2O', 'Area_Irrig', 'Edu_Exp', 'pop_pressure', 'Avg_Size_Hold', 'Avg_Size_Hold_stand', 'MAP', 'N_removal', 'P_removal', 'K_removal', 'N_removal_ha', 'P_removal_ha', 'K_removal_ha', 'Avg_sqKm', 'K2O_perc_area', 'P_cost', 'N_cost', 'K_cost', 'soil_clay', 'soil_ph', 'soil_sand', 'soil_silt', 'soil_cec', 'soil_nitrogen', 'soil_ocs', 'AI', 'K2O_ha_cropland', 'P2O5_ha_cropland', 'N_ha_cropland', 'PET', 'TMN', 'Crop_Price_Real', 'Crop_Price_Nominal', 'n_mac_ag', 'Crop_Code_1_1', 'Crop_Code_1_2', 'Crop_Code_1_3', 'Crop_Code_1_4', 'Crop_Code_2_1', 'Crop_Code_2_2', 'Crop_Code_2_3', 'Crop_Code_3_1', 'Crop_Code_3_2', 'Crop_Code_4', 'Crop_Code_5', 'Crop_Code_6', 'Crop_Code_7', 'Region_Name_Australia and New Zealand', 'Region_Name_Caribbean', 'Region_Name_Central America', 'Region_Name_Central Asia', 'Region_Name_Eastern Africa', 'Region_Name_Eastern Asia', 'Region_Name_Eastern Europe', 'Region_Name_Melanesia', 'Region_Name_Middle Africa', 'Region_Name_Northern Africa', 'Region_Name_Northern America', 'Region_Name_Northern Europe', 'Region_Name_South America', 'Region_Name_South-eastern Asia', 'Region_Name_Southern Africa', 'Region_Name_Southern Asia', 'Region_Name_Southern Europe', 'Region_Name_Western Africa', 'Region_Name_Western Asia', 'Region_Name_Western Europe', 'FAOStat_area_code_10', 'FAOStat_area_code_100', 'FAOStat_area_code_101', 'FAOStat_area_code_102', 'FAOStat_area_code_104', 'FAOStat_area_code_105', 'FAOStat_area_code_106', 'FAOStat_area_code_107', 'FAOStat_area_code_11', 'FAOStat_area_code_110', 'FAOStat_area_code_112', 'FAOStat_area_code_114', 'FAOStat_area_code_115', 'FAOStat_area_code_117', 'FAOStat_area_code_118', 'FAOStat_area_code_119', 'FAOStat_area_code_12', 'FAOStat_area_code_120', 'FAOStat_area_code_121', 'FAOStat_area_code_122', 'FAOStat_area_code_126', 'FAOStat_area_code_129', 'FAOStat_area_code_130', 'FAOStat_area_code_131', 'FAOStat_area_code_133', 'FAOStat_area_code_134', 'FAOStat_area_code_137', 'FAOStat_area_code_138', 'FAOStat_area_code_143', 'FAOStat_area_code_146', 'FAOStat_area_code_149', 'FAOStat_area_code_150', 'FAOStat_area_code_156', 'FAOStat_area_code_157', 'FAOStat_area_code_158', 'FAOStat_area_code_159', 'FAOStat_area_code_16', 'FAOStat_area_code_162', 'FAOStat_area_code_165', 'FAOStat_area_code_166', 'FAOStat_area_code_167', 'FAOStat_area_code_169', 'FAOStat_area_code_170', 'FAOStat_area_code_171', 'FAOStat_area_code_173', 'FAOStat_area_code_174', 'FAOStat_area_code_181', 'FAOStat_area_code_183', 'FAOStat_area_code_185', 'FAOStat_area_code_19', 'FAOStat_area_code_194', 'FAOStat_area_code_195', 'FAOStat_area_code_198', 'FAOStat_area_code_199', 'FAOStat_area_code_2', 'FAOStat_area_code_201', 'FAOStat_area_code_202', 'FAOStat_area_code_203', 'FAOStat_area_code_206', 'FAOStat_area_code_208', 'FAOStat_area_code_21', 'FAOStat_area_code_210', 'FAOStat_area_code_211', 'FAOStat_area_code_212', 'FAOStat_area_code_214', 'FAOStat_area_code_215', 'FAOStat_area_code_216', 'FAOStat_area_code_217', 'FAOStat_area_code_221', 'FAOStat_area_code_222', 'FAOStat_area_code_223', 'FAOStat_area_code_225', 'FAOStat_area_code_229', 'FAOStat_area_code_230', 'FAOStat_area_code_231', 'FAOStat_area_code_233', 'FAOStat_area_code_234', 'FAOStat_area_code_235', 'FAOStat_area_code_236', 'FAOStat_area_code_237', 'FAOStat_area_code_238', 'FAOStat_area_code_246', 'FAOStat_area_code_247', 'FAOStat_area_code_249', 'FAOStat_area_code_251', 'FAOStat_area_code_255', 'FAOStat_area_code_27', 'FAOStat_area_code_28', 'FAOStat_area_code_29', 'FAOStat_area_code_3', 'FAOStat_area_code_32', 'FAOStat_area_code_33', 'FAOStat_area_code_37', 'FAOStat_area_code_38', 'FAOStat_area_code_39', 'FAOStat_area_code_4', 'FAOStat_area_code_40', 'FAOStat_area_code_41', 'FAOStat_area_code_44', 'FAOStat_area_code_48', 'FAOStat_area_code_50', 'FAOStat_area_code_52', 'FAOStat_area_code_53', 'FAOStat_area_code_54', 'FAOStat_area_code_56', 'FAOStat_area_code_57', 'FAOStat_area_code_58', 'FAOStat_area_code_59', 'FAOStat_area_code_60', 'FAOStat_area_code_63', 'FAOStat_area_code_66', 'FAOStat_area_code_67', 'FAOStat_area_code_68', 'FAOStat_area_code_7', 'FAOStat_area_code_75', 'FAOStat_area_code_78', 'FAOStat_area_code_79', 'FAOStat_area_code_81', 'FAOStat_area_code_84', 'FAOStat_area_code_89', 'FAOStat_area_code_9', 'FAOStat_area_code_90', 'FAOStat_area_code_91', 'FAOStat_area_code_95', 'FAOStat_area_code_97', 'FAOStat_area_code_98'] ['Year', 'Area_report', 'Fert_perc', 'Area_fert', 'Area_FAOStat', 'Perc_Agr_Area_FAOStat', 'GDP_Capita_WB', 'GDP_perCapita_UN', 'Urea_Price', 'P_rock_Price', 'K2O_Price', 'Crop_Global_Price', 'Country_total_N', 'Country_total_P2O5', 'Country_total_K2O', 'Area_Irrig', 'Edu_Exp', 'pop_pressure', 'Avg_Size_Hold', 'Avg_Size_Hold_stand', 'MAP', 'N_removal', 'P_removal', 'K_removal', 'N_removal_ha', 'P_removal_ha', 'K_removal_ha', 'Avg_sqKm', 'K2O_perc_area', 'P_cost', 'N_cost', 'K_cost', 'soil_clay', 'soil_ph', 'soil_sand', 'soil_silt', 'soil_cec', 'soil_nitrogen', 'soil_ocs', 'AI', 'K2O_ha_cropland', 'P2O5_ha_cropland', 'N_ha_cropland', 'PET', 'TMN', 'Crop_Price_Real', 'Crop_Price_Nominal', 'n_mac_ag', 'Crop_Code_1_1', 'Crop_Code_1_2', 'Crop_Code_1_3', 'Crop_Code_1_4', 'Crop_Code_2_1', 'Crop_Code_2_2', 'Crop_Code_2_3', 'Crop_Code_3_1', 'Crop_Code_3_2', 'Crop_Code_4', 'Crop_Code_5', 'Crop_Code_6', 'Crop_Code_7', 'Region_Name_Australia and New Zealand', 'Region_Name_Caribbean', 'Region_Name_Central America', 'Region_Name_Central Asia', 'Region_Name_Eastern Africa', 'Region_Name_Eastern Asia', 'Region_Name_Eastern Europe', 'Region_Name_Melanesia', 'Region_Name_Micronesia', 'Region_Name_Middle Africa', 'Region_Name_Northern Africa', 'Region_Name_Northern America', 'Region_Name_Northern Europe', 'Region_Name_Polynesia', 'Region_Name_South America', 'Region_Name_South-eastern Asia', 'Region_Name_Southern Africa', 'Region_Name_Southern Asia', 'Region_Name_Southern Europe', 'Region_Name_Western Africa', 'Region_Name_Western Asia', 'Region_Name_Western Europe', 'FAOStat_area_code_1', 'FAOStat_area_code_10', 'FAOStat_area_code_100', 'FAOStat_area_code_101', 'FAOStat_area_code_102', 'FAOStat_area_code_103', 'FAOStat_area_code_104', 'FAOStat_area_code_105', 'FAOStat_area_code_106', 'FAOStat_area_code_107', 'FAOStat_area_code_108', 'FAOStat_area_code_109', 'FAOStat_area_code_11', 'FAOStat_area_code_110', 'FAOStat_area_code_112', 'FAOStat_area_code_113', 'FAOStat_area_code_114', 'FAOStat_area_code_115', 'FAOStat_area_code_116', 'FAOStat_area_code_117', 'FAOStat_area_code_118', 'FAOStat_area_code_119', 'FAOStat_area_code_12', 'FAOStat_area_code_120', 'FAOStat_area_code_121', 'FAOStat_area_code_122', 'FAOStat_area_code_123', 'FAOStat_area_code_124', 'FAOStat_area_code_126', 'FAOStat_area_code_127', 'FAOStat_area_code_128', 'FAOStat_area_code_129', 'FAOStat_area_code_13', 'FAOStat_area_code_130', 'FAOStat_area_code_131', 'FAOStat_area_code_132', 'FAOStat_area_code_133', 'FAOStat_area_code_134', 'FAOStat_area_code_135', 'FAOStat_area_code_136', 'FAOStat_area_code_137', 'FAOStat_area_code_138', 'FAOStat_area_code_14', 'FAOStat_area_code_141', 'FAOStat_area_code_143', 'FAOStat_area_code_144', 'FAOStat_area_code_145', 'FAOStat_area_code_146', 'FAOStat_area_code_147', 'FAOStat_area_code_148', 'FAOStat_area_code_149', 'FAOStat_area_code_150', 'FAOStat_area_code_153', 'FAOStat_area_code_154', 'FAOStat_area_code_155', 'FAOStat_area_code_156', 'FAOStat_area_code_157', 'FAOStat_area_code_158', 'FAOStat_area_code_159', 'FAOStat_area_code_16', 'FAOStat_area_code_160', 'FAOStat_area_code_162', 'FAOStat_area_code_165', 'FAOStat_area_code_166', 'FAOStat_area_code_167', 'FAOStat_area_code_168', 'FAOStat_area_code_169', 'FAOStat_area_code_170', 'FAOStat_area_code_171', 'FAOStat_area_code_173', 'FAOStat_area_code_174', 'FAOStat_area_code_175', 'FAOStat_area_code_176', 'FAOStat_area_code_177', 'FAOStat_area_code_178', 'FAOStat_area_code_179', 'FAOStat_area_code_18', 'FAOStat_area_code_181', 'FAOStat_area_code_182', 'FAOStat_area_code_183', 'FAOStat_area_code_184', 'FAOStat_area_code_185', 'FAOStat_area_code_186', 'FAOStat_area_code_188', 'FAOStat_area_code_189', 'FAOStat_area_code_19', 'FAOStat_area_code_191', 'FAOStat_area_code_193', 'FAOStat_area_code_194', 'FAOStat_area_code_195', 'FAOStat_area_code_196', 'FAOStat_area_code_197', 'FAOStat_area_code_198', 'FAOStat_area_code_199', 'FAOStat_area_code_2', 'FAOStat_area_code_20', 'FAOStat_area_code_200', 'FAOStat_area_code_201', 'FAOStat_area_code_202', 'FAOStat_area_code_203', 'FAOStat_area_code_206', 'FAOStat_area_code_207', 'FAOStat_area_code_208', 'FAOStat_area_code_209', 'FAOStat_area_code_21', 'FAOStat_area_code_210', 'FAOStat_area_code_211', 'FAOStat_area_code_212', 'FAOStat_area_code_213', 'FAOStat_area_code_214', 'FAOStat_area_code_215', 'FAOStat_area_code_216', 'FAOStat_area_code_217', 'FAOStat_area_code_218', 'FAOStat_area_code_219', 'FAOStat_area_code_220', 'FAOStat_area_code_221', 'FAOStat_area_code_222', 'FAOStat_area_code_223', 'FAOStat_area_code_225', 'FAOStat_area_code_226', 'FAOStat_area_code_227', 'FAOStat_area_code_228', 'FAOStat_area_code_229', 'FAOStat_area_code_23', 'FAOStat_area_code_230', 'FAOStat_area_code_231', 'FAOStat_area_code_233', 'FAOStat_area_code_234', 'FAOStat_area_code_235', 'FAOStat_area_code_236', 'FAOStat_area_code_237', 'FAOStat_area_code_238', 'FAOStat_area_code_244', 'FAOStat_area_code_246', 'FAOStat_area_code_247', 'FAOStat_area_code_248', 'FAOStat_area_code_249', 'FAOStat_area_code_25', 'FAOStat_area_code_250', 'FAOStat_area_code_251', 'FAOStat_area_code_255', 'FAOStat_area_code_256', 'FAOStat_area_code_26', 'FAOStat_area_code_27', 'FAOStat_area_code_272', 'FAOStat_area_code_273', 'FAOStat_area_code_276', 'FAOStat_area_code_277', 'FAOStat_area_code_28', 'FAOStat_area_code_29', 'FAOStat_area_code_299', 'FAOStat_area_code_3', 'FAOStat_area_code_32', 'FAOStat_area_code_33', 'FAOStat_area_code_35', 'FAOStat_area_code_351', 'FAOStat_area_code_37', 'FAOStat_area_code_38', 'FAOStat_area_code_39', 'FAOStat_area_code_4', 'FAOStat_area_code_40', 'FAOStat_area_code_41', 'FAOStat_area_code_44', 'FAOStat_area_code_45', 'FAOStat_area_code_46', 'FAOStat_area_code_47', 'FAOStat_area_code_48', 'FAOStat_area_code_49', 'FAOStat_area_code_50', 'FAOStat_area_code_51', 'FAOStat_area_code_52', 'FAOStat_area_code_53', 'FAOStat_area_code_54', 'FAOStat_area_code_55', 'FAOStat_area_code_56', 'FAOStat_area_code_57', 'FAOStat_area_code_58', 'FAOStat_area_code_59', 'FAOStat_area_code_60', 'FAOStat_area_code_61', 'FAOStat_area_code_62', 'FAOStat_area_code_63', 'FAOStat_area_code_64', 'FAOStat_area_code_66', 'FAOStat_area_code_67', 'FAOStat_area_code_68', 'FAOStat_area_code_69', 'FAOStat_area_code_7', 'FAOStat_area_code_70', 'FAOStat_area_code_72', 'FAOStat_area_code_73', 'FAOStat_area_code_74', 'FAOStat_area_code_75', 'FAOStat_area_code_78', 'FAOStat_area_code_79', 'FAOStat_area_code_8', 'FAOStat_area_code_80', 'FAOStat_area_code_81', 'FAOStat_area_code_83', 'FAOStat_area_code_84', 'FAOStat_area_code_86', 'FAOStat_area_code_87', 'FAOStat_area_code_89', 'FAOStat_area_code_9', 'FAOStat_area_code_90', 'FAOStat_area_code_91', 'FAOStat_area_code_93', 'FAOStat_area_code_95', 'FAOStat_area_code_96', 'FAOStat_area_code_97', 'FAOStat_area_code_98', 'FAOStat_area_code_99']
training data did not have the following fields: FAOStat_area_code_45, FAOStat_area_code_189, FAOStat_area_code_70, FAOStat_area_code_55, FAOStat_area_code_256, FAOStat_area_code_46, FAOStat_area_code_273, FAOStat_area_code_228, FAOStat_area_code_73, FAOStat_area_code_124, Region_Name_Polynesia, FAOStat_area_code_250, FAOStat_area_code_272, FAOStat_area_code_147, FAOStat_area_code_96, FAOStat_area_code_196, FAOStat_area_code_93, FAOStat_area_code_186, FAOStat_area_code_153, FAOStat_area_code_1, FAOStat_area_code_51, FAOStat_area_code_26, FAOStat_area_code_227, FAOStat_area_code_244, FAOStat_area_code_35, FAOStat_area_code_47, FAOStat_area_code_14, FAOStat_area_code_86, FAOStat_area_code_193, FAOStat_area_code_103, FAOStat_area_code_219, FAOStat_area_code_154, FAOStat_area_code_148, FAOStat_area_code_218, FAOStat_area_code_109, FAOStat_area_code_213, FAOStat_area_code_176, FAOStat_area_code_226, FAOStat_area_code_99, FAOStat_area_code_209, FAOStat_area_code_49, FAOStat_area_code_175, FAOStat_area_code_178, FAOStat_area_code_277, FAOStat_area_code_61, FAOStat_area_code_207, FAOStat_area_code_25, FAOStat_area_code_69, FAOStat_area_code_13, FAOStat_area_code_168, FAOStat_area_code_276, FAOStat_area_code_135, FAOStat_area_code_20, FAOStat_area_code_80, Region_Name_Micronesia, FAOStat_area_code_200, FAOStat_area_code_64, FAOStat_area_code_116, FAOStat_area_code_127, FAOStat_area_code_83, FAOStat_area_code_299, FAOStat_area_code_72, FAOStat_area_code_136, FAOStat_area_code_18, FAOStat_area_code_108, FAOStat_area_code_184, FAOStat_area_code_62, FAOStat_area_code_23, FAOStat_area_code_145, FAOStat_area_code_182, FAOStat_area_code_351, FAOStat_area_code_113, FAOStat_area_code_132, FAOStat_area_code_160, FAOStat_area_code_188, FAOStat_area_code_87, FAOStat_area_code_74, FAOStat_area_code_141, FAOStat_area_code_144, FAOStat_area_code_155, FAOStat_area_code_177, FAOStat_area_code_123, FAOStat_area_code_8, FAOStat_area_code_128, FAOStat_area_code_197, FAOStat_area_code_220, FAOStat_area_code_191, FAOStat_area_code_179, FAOStat_area_code_248

In [6]:
full_dataset = data_preprocessing.load_all_data(config=config)

for estimator_name, estimator_cls in estimators.items():
    if 'Hist' in estimator_name:
        print(estimator_name)
        estimator = estimator_cls(outpath=resultspath)
        estimator = estimator.load(path=estimator.output_path, it=0)  
        y_pred = make_prediction(estimator, full_dataset, config)

HistGradientBoostRegressor


In [7]:
y_pred

0         15.526538
1          5.672649
2         36.002052
3         22.684568
4          3.375614
            ...    
150018    76.304565
150019    76.973701
150020    73.475768
150021    73.595969
150022    62.826732
Name: predicted_K2O_avg_app, Length: 150023, dtype: float64

In [8]:
full_dataset.shape

(150023, 51)

In [11]:
# add the predictions to the full dataset
# full_dataset = full_dataset.join(y_pred)
it = 0
full_dataset.to_csv(resultspath / f'{target}_{it}_predictions.csv')

In [None]:
# TODO: for the predictions, we need to combine them from the two folds somehow
# we could do this by taking the prediction from folds where samples were test samples
# for those not in the train/test set, we take the average of the two predictions
# or weigh it using the r2/rmse/... of the two predictions