## Imports

In [37]:

# number stuff
import pandas as pd
import numpy as np

# visualization
import matplotlib.pyplot as plt
import seaborn as sns

# TF

from tensorflow.keras import layers, Sequential

In [38]:
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer

In [39]:
# Baseline Model

def get_baseline_model(encoded_y_df):
    '''
    Takes a one-hot-encoded y,
    predicts the probability of each species as a proportion of total species
    '''
    new_df = encoded_y_df.copy()
    for column in new_df.columns:
        new_df[column] = encoded_y_df[column].sum()/encoded_y_df.sum().sum()
    return new_df

In [40]:
# Metrics

def compute_average(y_true, y_pred, t):
    """Returns the average number of species observed correctly predicted given a threshold value t"""
    assert t <= 1
    assert t >= 0
    N, C = y_pred.shape
    temp = y_pred[y_true == 1].applymap(lambda x: 1 if x >= t else 0)
    average = temp.values.sum()/N
    return average

def find_t_min(y_true, y_pred, K, rate, t):
    """Returns the minimum threshold t and corresponding average satisfying the condition average <= K. The minimum t is found iteratively, with tuning parameter rate [0-1]"""
    assert rate <= 1
    assert rate >= 0
    assert K > 0
    average = compute_average(y_true, y_pred, t)
    while average <= K:
        t = rate*t
        average = compute_average(y_true, y_pred, t)
    t_min = t/rate
    average = compute_average(y_true, y_pred, t_min)
    return t_min, average

def compute_accuracy(y_true, y_pred, t_min):
    N, C = y_pred.shape
    temp = y_pred[y_true == 1].applymap(lambda x: 1 if x >= t_min else 0)
    return temp.values.sum()/(N*C)

def custom_metric(y_true, y_pred, K, rate, t):
    t_min, average = find_t_min(y_true, y_pred, K, rate, t)
    accuracy = compute_accuracy(y_true, y_pred, t_min)
    return t_min, average, accuracy


# First Modelling

## Import Data

In [41]:
X = pd.read_csv('../raw_data/Experiment_data/coordinates_1000_features.csv')

In [42]:
y = pd.read_csv('../raw_data/Experiment_data/occurences_1000_encoded.csv').dropna()

In [43]:
y['coords'] = y['latitude'].astype(str) + '-' + y['longitude'].astype(str)

In [44]:
X['coords'] = X['latitude'].astype(str) + '-' + X['longitude'].astype(str)

In [107]:
X

Unnamed: 0,latitude,longitude,bio_1,bio_2,bio_3,bio_4,bio_5,bio_6,bio_7,bio_8,...,silt_30-60cm,silt_5-15cm,silt_60-100cm,soc_0-5cm,soc_100-200cm,soc_15-30cm,soc_30-60cm,soc_5-15cm,soc_60-100cm,coords
0,49.800575,7.749696,9.270833e+00,8.125000e+00,3.276210e+01,6.376821e+02,2.330000e+01,-1.500000e+00,2.480000e+01,1.568333e+01,...,447,469,435,623,68,200,76,337,76,49.800575-7.749696
1,50.412086,10.038344,7.483334e+00,7.150000e+00,2.894737e+01,6.752452e+02,2.110000e+01,-3.600000e+00,2.470000e+01,3.500001e-01,...,394,426,392,630,73,216,122,421,77,50.412086-10.038344
2,48.066643,8.995314,6.520834e+00,8.475000e+00,3.284884e+01,6.565386e+02,2.070000e+01,-5.100000e+00,2.580000e+01,1.303333e+01,...,483,493,491,555,77,179,134,382,90,48.066643-8.995314
3,51.443483,7.792170,9.241667e+00,8.366667e+00,3.471646e+01,5.867469e+02,2.270000e+01,-1.400000e+00,2.410000e+01,1.665000e+01,...,0,0,0,0,0,0,0,0,0,51.443483-7.79217
4,54.465248,12.530766,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,...,220,220,241,1039,553,716,691,857,686,54.465248-12.530766
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1037,47.765477,11.615015,7.983333e+00,9.216667e+00,3.413580e+01,6.828261e+02,2.220000e+01,-4.800000e+00,2.700000e+01,1.626667e+01,...,453,481,450,1129,198,700,263,517,217,47.765477-11.615015
1038,53.798893,7.291886,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,-3.400000e+38,...,0,0,0,0,0,0,0,0,0,53.798893-7.291886
1039,52.035636,13.745804,8.970834e+00,8.441667e+00,3.197601e+01,6.904559e+02,2.370000e+01,-2.700000e+00,2.640000e+01,1.758333e+01,...,126,131,141,516,47,165,111,193,90,52.035636-13.745804
1040,47.913877,8.077230,5.908333e+00,9.400000e+00,3.700787e+01,5.984633e+02,1.990000e+01,-5.500000e+00,2.540000e+01,-6.666660e-02,...,369,372,377,1089,233,547,343,745,286,47.913877-8.07723


In [46]:
y

Unnamed: 0,latitude,longitude,10071055,11071158,2650625,2650999,2672680,2673408,2679707,2681972,...,9177060,9182154,9206251,9220780,9349855,9458333,9485490,9490132,9557223,coords
0,49.800575,7.749696,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,49.800575-7.749696
1,50.412086,10.038344,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,50.412086-10.038344
2,48.066643,8.995314,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,48.066643-8.995314
3,51.443483,7.792170,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,51.443483-7.79217
4,54.465248,12.530766,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,54.465248-12.530766
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
988,48.197379,11.535155,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,48.197379-11.535155
989,49.294091,12.855801,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,49.294091-12.855801
990,50.615097,6.442054,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,50.615097-6.442054
991,49.768185,8.340547,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,49.768185-8.340547


In [108]:
merged_df = pd.merge(left = y, right = X, on = 'coords', how = 'inner')

In [109]:
merged_df

Unnamed: 0,latitude_x,longitude_x,10071055,11071158,2650625,2650999,2672680,2673408,2679707,2681972,...,silt_15-30cm,silt_30-60cm,silt_5-15cm,silt_60-100cm,soc_0-5cm,soc_100-200cm,soc_15-30cm,soc_30-60cm,soc_5-15cm,soc_60-100cm
0,49.800575,7.749696,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,478,447,469,435,623,68,200,76,337,76
1,50.412086,10.038344,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,409,394,426,392,630,73,216,122,421,77
2,48.066643,8.995314,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,492,483,493,491,555,77,179,134,382,90
3,51.443483,7.792170,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
4,54.465248,12.530766,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,214,220,220,241,1039,553,716,691,857,686
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1086,48.197379,11.535155,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0,0,0,0,0,0,0,0,0,0
1087,49.294091,12.855801,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,388,367,390,350,392,31,127,46,209,28
1088,50.615097,6.442054,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,498,475,540,470,1029,115,290,164,779,130
1089,49.768185,8.340547,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,511,495,512,511,397,43,142,89,202,55


In [110]:
y_used = merged_df[y.columns[2:-1]]

In [111]:
X.columns[:-1]

Index(['latitude', 'longitude', 'bio_1', 'bio_2', 'bio_3', 'bio_4', 'bio_5',
       'bio_6', 'bio_7', 'bio_8', 'bio_9', 'bio_10', 'bio_11', 'bio_12',
       'bio_13', 'bio_14', 'bio_15', 'bio_16', 'bio_17', 'bio_18', 'bio_19',
       'elevation', 'slope', 'bdod_0-5cm', 'bdod_100-200cm', 'bdod_15-30cm',
       'bdod_30-60cm', 'bdod_5-15cm', 'bdod_60-100cm', 'cec_0-5cm',
       'cec_100-200cm', 'cec_15-30cm', 'cec_30-60cm', 'cec_5-15cm',
       'cec_60-100cm', 'cfvo_0-5cm', 'cfvo_100-200cm', 'cfvo_15-30cm',
       'cfvo_30-60cm', 'cfvo_5-15cm', 'cfvo_60-100cm', 'clay_0-5cm',
       'clay_100-200cm', 'clay_15-30cm', 'clay_30-60cm', 'clay_5-15cm',
       'clay_60-100cm', 'nitrogen_0-5cm', 'nitrogen_100-200cm',
       'nitrogen_15-30cm', 'nitrogen_30-60cm', 'nitrogen_5-15cm',
       'nitrogen_60-100cm', 'ocd_0-5cm', 'ocd_100-200cm', 'ocd_15-30cm',
       'ocd_30-60cm', 'ocd_5-15cm', 'ocd_60-100cm', 'ocs_0-30cm',
       'phh2o_0-5cm', 'phh2o_100-200cm', 'phh2o_15-30cm', 'phh2o_30-60cm',
    

In [114]:
X_used = merged_df[['latitude_x', 'longitude_x', 'bio_1', 'bio_2', 'bio_3', 'bio_4', 'bio_5',
       'bio_6', 'bio_7', 'bio_8', 'bio_9', 'bio_10', 'bio_11', 'bio_12',
       'bio_13', 'bio_14', 'bio_15', 'bio_16', 'bio_17', 'bio_18', 'bio_19',
       'elevation', 'slope', 'bdod_0-5cm', 'bdod_100-200cm', 'bdod_15-30cm',
       'bdod_30-60cm', 'bdod_5-15cm', 'bdod_60-100cm', 'cec_0-5cm',
       'cec_100-200cm', 'cec_15-30cm', 'cec_30-60cm', 'cec_5-15cm',
       'cec_60-100cm', 'cfvo_0-5cm', 'cfvo_100-200cm', 'cfvo_15-30cm',
       'cfvo_30-60cm', 'cfvo_5-15cm', 'cfvo_60-100cm', 'clay_0-5cm',
       'clay_100-200cm', 'clay_15-30cm', 'clay_30-60cm', 'clay_5-15cm',
       'clay_60-100cm', 'nitrogen_0-5cm', 'nitrogen_100-200cm',
       'nitrogen_15-30cm', 'nitrogen_30-60cm', 'nitrogen_5-15cm',
       'nitrogen_60-100cm', 'ocd_0-5cm', 'ocd_100-200cm', 'ocd_15-30cm',
       'ocd_30-60cm', 'ocd_5-15cm', 'ocd_60-100cm', 'ocs_0-30cm',
       'phh2o_0-5cm', 'phh2o_100-200cm', 'phh2o_15-30cm', 'phh2o_30-60cm',
       'phh2o_5-15cm', 'phh2o_60-100cm', 'sand_0-5cm', 'sand_100-200cm',
       'sand_15-30cm', 'sand_30-60cm', 'sand_5-15cm', 'sand_60-100cm',
       'silt_0-5cm', 'silt_100-200cm', 'silt_15-30cm', 'silt_30-60cm',
       'silt_5-15cm', 'silt_60-100cm', 'soc_0-5cm', 'soc_100-200cm',
       'soc_15-30cm', 'soc_30-60cm', 'soc_5-15cm', 'soc_60-100cm']]

## Baseline Model

### Creation

In [115]:
y_pred_baseline = get_baseline_model(y_used)

In [116]:
y_pred_baseline

Unnamed: 0,10071055,11071158,2650625,2650999,2672680,2673408,2679707,2681972,2683866,2685484,...,9172281,9177060,9182154,9206251,9220780,9349855,9458333,9485490,9490132,9557223
0,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
1,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
2,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
3,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
4,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1086,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
1087,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
1088,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911
1089,0.000911,0.001821,0.008197,0.000911,0.000911,0.000911,0.000911,0.001821,0.000911,0.000911,...,0.000911,0.000911,0.001821,0.000911,0.004554,0.000911,0.000911,0.001821,0.000911,0.000911


In [117]:
def compute_k(y_only_relevant_columns):
    
    '''
    only include columns with animal species, no lang/lat etc. Perhaps use:
    y[y.columns[2:-1]]
    '''
    return y_only_relevant_columns.sum(axis=1).mean().round().astype(int)

In [118]:
def get_K_most_probable(y_pred_proba, k):
    
    only_highest = y_pred_baseline.stack().groupby(level=0).nlargest(k).unstack().reset_index(level=1, drop=True).reindex(columns=y_pred_baseline.columns)
    
    fillna_highest =  only_highest.fillna(0)
    
    return fillna_highest.astype(bool).astype(int)
    
    
    

In [119]:
k_most = get_K_most_probable(y_pred_baseline, compute_k(y_used))

### Evaluation

In [120]:
baseline_acc = accuracy_score(y_used, k_most)
baseline_acc

0.026581118240146653

In [None]:
custom_metric(y_used, k_most, compute_k(y_used), 0.98, 0.05)

## Data Cleaning

### Standardizer

In [133]:
ll = X_used[X_used.columns[:2]]

In [134]:
wc = X_used[X.columns[2:21]]

In [135]:
gee = X_used[X.columns[21:23]]

In [136]:
sg = X_used[X.columns[23:-1]]

In [137]:
ll_transformer = Pipeline([
    ('scaler', MinMaxScaler())])

wc_transformer = Pipeline([
    ('scaler', RobustScaler())])

gee_transformer = Pipeline([
    ('scaler', StandardScaler())])

sg_transformer = Pipeline([
    ('scaler', StandardScaler())])


In [138]:
preprocessor = ColumnTransformer([
    ('ll', ll_transformer, ll.columns),
    ('wc', wc_transformer, wc.columns),
    ('gee', gee_transformer, gee.columns),
    ('sg', sg_transformer, sg.columns)])

preprocessor

ColumnTransformer(transformers=[('ll',
                                 Pipeline(steps=[('scaler', MinMaxScaler())]),
                                 Index(['latitude_x', 'longitude_x'], dtype='object')),
                                ('wc',
                                 Pipeline(steps=[('scaler', RobustScaler())]),
                                 Index(['bio_1', 'bio_2', 'bio_3', 'bio_4', 'bio_5', 'bio_6', 'bio_7', 'bio_8',
       'bio_9', 'bio_10', 'bio_11', 'bio_12', 'bio_13', 'bio_14', 'bio_15',
       'bio_16', 'bio_17', 'bio_18', 'bio_1...
       'ocd_60-100cm', 'ocs_0-30cm', 'phh2o_0-5cm', 'phh2o_100-200cm',
       'phh2o_15-30cm', 'phh2o_30-60cm', 'phh2o_5-15cm', 'phh2o_60-100cm',
       'sand_0-5cm', 'sand_100-200cm', 'sand_15-30cm', 'sand_30-60cm',
       'sand_5-15cm', 'sand_60-100cm', 'silt_0-5cm', 'silt_100-200cm',
       'silt_15-30cm', 'silt_30-60cm', 'silt_5-15cm', 'silt_60-100cm',
       'soc_0-5cm', 'soc_100-200cm', 'soc_15-30cm', 'soc_30-60cm',
       'soc_5-1

In [139]:
preprocessor.fit(X_used)



ColumnTransformer(transformers=[('ll',
                                 Pipeline(steps=[('scaler', MinMaxScaler())]),
                                 Index(['latitude_x', 'longitude_x'], dtype='object')),
                                ('wc',
                                 Pipeline(steps=[('scaler', RobustScaler())]),
                                 Index(['bio_1', 'bio_2', 'bio_3', 'bio_4', 'bio_5', 'bio_6', 'bio_7', 'bio_8',
       'bio_9', 'bio_10', 'bio_11', 'bio_12', 'bio_13', 'bio_14', 'bio_15',
       'bio_16', 'bio_17', 'bio_18', 'bio_1...
       'ocd_60-100cm', 'ocs_0-30cm', 'phh2o_0-5cm', 'phh2o_100-200cm',
       'phh2o_15-30cm', 'phh2o_30-60cm', 'phh2o_5-15cm', 'phh2o_60-100cm',
       'sand_0-5cm', 'sand_100-200cm', 'sand_15-30cm', 'sand_30-60cm',
       'sand_5-15cm', 'sand_60-100cm', 'silt_0-5cm', 'silt_100-200cm',
       'silt_15-30cm', 'silt_30-60cm', 'silt_5-15cm', 'silt_60-100cm',
       'soc_0-5cm', 'soc_100-200cm', 'soc_15-30cm', 'soc_30-60cm',
       'soc_5-1

In [140]:
transformed_X = pd.DataFrame(preprocessor.transform(X_used))

In [150]:
transformed_X

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,74,75,76,77,78,79,80,81,82,83
0,0.325664,0.186997,1.256545e-01,4.102644e-02,5.182379e-02,-8.750195e-02,1.764706e-01,6.250000e-02,-8.333333e-02,-1.630438e-01,...,0.969546,0.853344,0.854345,0.825814,0.422566,-0.061279,0.170641,-0.215458,0.447578,-0.045743
1,0.409919,0.447096,-1.371728e+00,-1.158973e+00,-2.050674e+00,4.414463e-01,-1.117647e+00,-1.250000e+00,-1.250000e-01,-5.163043e+00,...,0.654063,0.608424,0.664991,0.624272,0.440931,-0.019306,0.264334,0.116249,0.833052,-0.037578
2,0.086761,0.328559,-2.178011e+00,4.717955e-01,9.963299e-02,1.780283e-01,-1.352941e+00,-2.187500e+00,3.333338e-01,-1.027174e+00,...,1.033557,1.019705,0.960031,1.088287,0.244166,0.014273,0.047668,0.202781,0.654082,0.068570
3,0.552027,0.191824,1.012222e-01,3.384626e-01,1.128977e+00,-8.047489e-01,-1.764706e-01,1.250000e-01,-3.750000e-01,1.521737e-01,...,-1.215977,-1.212307,-1.210939,-1.213041,-1.211892,-0.632123,-1.000529,-0.763496,-1.098905,-0.666304
4,0.968370,0.730353,-2.848169e+38,-4.184614e+38,-1.873919e+38,-4.787737e+36,-2.000000e+38,-2.125000e+38,-1.416667e+38,-1.108696e+38,...,-0.237521,-0.195655,-0.242149,-0.083469,1.513953,4.010179,3.192258,4.219322,2.833843,4.935071
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1086,0.104774,0.617205,-2.478183e-01,6.256403e-01,-1.858719e-01,5.801639e-01,5.882353e-02,-1.187500e+00,6.666667e-01,3.478258e-01,...,-1.215977,-1.212307,-1.210939,-1.213041,-1.211892,-0.632123,-1.000529,-0.763496,-1.098905,-0.666304
1087,0.255880,0.767293,-1.347295e+00,9.435910e-01,1.075189e-02,7.732457e-01,-2.941176e-01,-1.875000e+00,8.750000e-01,5.434565e-03,...,0.558046,0.483653,0.506461,0.427417,-0.183469,-0.371886,-0.256836,-0.431789,-0.139810,-0.437676
1088,0.437890,0.038387,-8.062831e-01,-1.979486e+00,-1.613780e+00,-1.033454e+00,-1.588235e+00,6.250000e-02,-1.333333e+00,-4.804348e+00,...,1.060991,0.982736,1.167000,0.989860,1.487718,0.333274,0.697667,0.419112,2.475904,0.395181
1089,0.321202,0.254146,1.071555e+00,6.051284e-01,1.900110e-01,3.531491e-01,1.235294e+00,4.375000e-01,4.166667e-01,3.369564e-01,...,1.120429,1.075159,1.043699,1.182028,-0.170351,-0.271148,-0.168998,-0.121715,-0.171933,-0.217214


In [173]:
y_used

Unnamed: 0,10071055,11071158,2650625,2650999,2672680,2673408,2679707,2681972,2683866,2685484,...,9172281,9177060,9182154,9206251,9220780,9349855,9458333,9485490,9490132,9557223
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1086,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1087,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1088,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1089,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## Create Model

In [164]:
y_used.shape[1]

492

In [170]:
def init_model(input_dim, output_dim, metrics):
    model = Sequential([
        layers.Dense(10, input_dim=input_dim, activation='relu'),
        layers.Dense(10, activation='relu'),
        layers.Dense(10, activation='relu'),
        layers.Dense(output_dim, activation='softmax'),
    ])

    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=[metrics])

    return model

In [171]:
model = init_model(transformed_X.shape[1], y_used.shape[1], 'accuracy')

In [172]:
model.fit(transformed_X, y_used, batch_size=16, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<tensorflow.python.keras.callbacks.History at 0x190efb820>

In [159]:
    model = Sequential()
    
    model.add(layers.Dense(100, input_dim=transformed_X.shape[1], activation='relu'))  # /!\ Must specify input size
    model.add(layers.Dense(10, activation='relu'))
    model.add(layers.Dense(10, activation='relu'))
    model.add(layers.Dense(y_used.shape[1], activation='softmax')) # /!\ Must correspond to the task at hand

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics = ['accuracy'])

    model.fit(transformed_X, y_used, batch_size=2, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

Epoch 80/100
Epoch 81/100
Epoch 82/100
Epoch 83/100
Epoch 84/100
Epoch 85/100
Epoch 86/100
Epoch 87/100
Epoch 88/100
Epoch 89/100
Epoch 90/100
Epoch 91/100
Epoch 92/100
Epoch 93/100
Epoch 94/100
Epoch 95/100
Epoch 96/100
Epoch 97/100
Epoch 98/100
Epoch 99/100
Epoch 100/100


<tensorflow.python.keras.callbacks.History at 0x190dd7a30>

In [151]:
model.summary()

Model: "sequential_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_44 (Dense)             (None, 100)               8500      
_________________________________________________________________
dense_45 (Dense)             (None, 10)                1010      
_________________________________________________________________
dense_46 (Dense)             (None, 10)                110       
_________________________________________________________________
dense_47 (Dense)             (None, 492)               5412      
Total params: 15,032
Trainable params: 15,032
Non-trainable params: 0
_________________________________________________________________


In [155]:
transformed_X.head(1)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,74,75,76,77,78,79,80,81,82,83
0,0.325664,0.186997,0.125655,0.041026,0.051824,-0.087502,0.176471,0.0625,-0.083333,-0.163044,...,0.969546,0.853344,0.854345,0.825814,0.422566,-0.061279,0.170641,-0.215458,0.447578,-0.045743


In [157]:
pd.DataFrame(model.predict(transformed_X.head(1)))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,482,483,484,485,486,487,488,489,490,491
0,,,,,,,,,,,...,,,,,,,,,,
