In [1]:
import sys
import numpy as np
import pandas as pd
import gc
import re

from catboost import CatBoostClassifier
from sklearn.metrics import f1_score

import warnings
warnings.filterwarnings(action='ignore')

In [2]:
def process_input_df(fpath):
    """Load a feature parquet file and clean it for training/evaluation.

    Processing steps:
      * ``CT`` values equal to 0 or missing are filled with the row's ``OUTPUT``.
      * Rows that still contain any missing value are dropped.
      * Rows with ``OUTPUT == 0`` are dropped.
      * ``start_date`` / ``end_date`` / ``valid_date`` (when present) are
        parsed with ``pd.to_datetime``.
      * A boolean ``is_crop`` label is added (``OUTPUT == 11``).
      * The frame is indexed by (``location_id``, ``ref_id``, ``pixelids``).

    Parameters
    ----------
    fpath : str or path-like
        Path of the parquet file to read.

    Returns
    -------
    pandas.DataFrame
        The cleaned frame with the three-level index described above.
    """
    tdf = pd.read_parquet(fpath)

    # Assign the result back instead of calling replace/fillna with
    # inplace=True on the tdf['CT'] sub-selection: that is chained assignment,
    # which under pandas Copy-on-Write silently leaves tdf unmodified.
    tdf['CT'] = tdf['CT'].replace(0, np.nan).fillna(tdf['OUTPUT'])

    # Keep only fully populated rows, then drop rows whose OUTPUT code is 0.
    tdf = tdf.dropna()
    tdf = tdf[tdf['OUTPUT'] != 0].reset_index(drop=True)

    for tcol in ['start_date', 'end_date', 'valid_date']:
        if tcol in tdf.columns:
            tdf[tcol] = pd.to_datetime(tdf[tcol])

    # OUTPUT code 11 is the positive class of the binary is_crop label.
    tdf['is_crop'] = tdf['OUTPUT'] == 11

    return tdf.set_index(['location_id', 'ref_id', 'pixelids'])

In [3]:
# Root directory of the annual_CIB feature tables: the CAL/ file is loaded as
# the training frame, the VAL/ file as the validation frame.
tpath = '/vitodata/worldcereal/features/features-presto-monthly-nointerp/annual_CIB/'

trn_df = process_input_df(f'{tpath}CAL/training_df_LC.parquet')
val_df = process_input_df(f'{tpath}VAL/training_df_LC.parquet')

In [4]:
# Load the presto embeddings for both splits and index them on the same keys
# as trn_df/val_df. reset_index() first moves whatever index the parquet file
# carries into regular columns (presumably the key columns) before re-indexing.
presto_df_trn = (
    pd.read_parquet(f'{tpath}CAL/training_df_LC_presto-worldcereal.parquet')
    .reset_index()
    .set_index(['location_id', 'ref_id', 'pixelids'])
)
presto_df_val = (
    pd.read_parquet(f'{tpath}VAL/training_df_LC_presto-worldcereal.parquet')
    .reset_index()
    .set_index(['location_id', 'ref_id', 'pixelids'])
)

# Embedding feature columns are identified by the 'presto_ft' name fragment.
presto_emb_colnames = [col for col in presto_df_val.columns if 'presto_ft' in col]

In [5]:
# Attach the embedding columns side by side; join='inner' keeps only the
# (location_id, ref_id, pixelids) keys present in both frames.
trn_df = pd.concat([trn_df, presto_df_trn[presto_emb_colnames]], axis=1, join='inner')
val_df = pd.concat([val_df, presto_df_val[presto_emb_colnames]], axis=1, join='inner')

In [6]:
# The embeddings now live in trn_df/val_df; release the standalone frames.
del presto_df_trn
del presto_df_val
gc.collect()

0

In [7]:
n_months = 12


def select_monthly_features(columns, band_prefix, n_months=12):
    """Return the column names matching ``<band_prefix>.*ts<m>-`` for m in [0, n_months).

    ``band_prefix`` is used as a regex fragment. The trailing ``-`` keeps e.g.
    ``ts1`` from matching ``ts10``/``ts11``. Column order is preserved.
    """
    month_alternation = '|'.join(str(month) for month in range(n_months))
    pattern = re.compile(r'{}.*ts({})-'.format(band_prefix, month_alternation))
    return [col for col in columns if pattern.search(col)]


optical12_feats = select_monthly_features(trn_df.columns, 'OPTICAL', n_months)
sar12_feats = select_monthly_features(trn_df.columns, 'SAR', n_months)
temp12_feats = select_monthly_features(trn_df.columns, 'METEO-temp', n_months)
prcp12_feats = select_monthly_features(trn_df.columns, 'METEO-precip', n_months)
dem_feats = ['DEM-alt-20m', 'DEM-slo-20m']
latlon_feats = ['lat', 'lon']
# Same 'presto_ft' filter that selected the embedding columns merged into trn_df.
presto_emb_colnames = [col for col in trn_df.columns if 'presto_ft' in col]

In [19]:
label_colname = 'is_crop'

# Alternative feature sets used for the comparison runs further down:
#   raw only:    optical12_feats + sar12_feats + temp12_feats + prcp12_feats + dem_feats + latlon_feats
#   presto only: presto_emb_colnames
tfeatures = optical12_feats + sar12_feats + temp12_feats + prcp12_feats + dem_feats + latlon_feats + presto_emb_colnames

train_years = [2017, 2018, 2019, 2020]
test_aez = [22190]

# CAL rows used for training: end_date year in train_years AND outside the
# held-out AEZ(s). Every other CAL row (~train_mask) is appended after the VAL
# rows to form the test set.
train_mask = (
    trn_df['end_date'].dt.year.isin(train_years)
    & ~trn_df['aez_zoneid'].isin(test_aez)
)

X_trn_df = trn_df.loc[train_mask, tfeatures]
y_trn_df = trn_df.loc[train_mask, label_colname]

# Both parts keep their (location_id, ref_id, pixelids) index so test rows stay
# traceable (resetting only one part produced a mixed tuple/int index).
X_test = pd.concat([val_df[tfeatures], trn_df.loc[~train_mask, tfeatures]], axis=0)
y_test = pd.concat([val_df[label_colname], trn_df.loc[~train_mask, label_colname]], axis=0)

In [20]:
model = CatBoostClassifier(
    iterations=500,
    depth=8,
    eval_metric='F1',
    learning_rate=0.3,
    l2_leaf_reg=100,
    verbose=50,
    random_seed=42,
)

model.fit(X_trn_df, y_trn_df)

# Convert the predicted labels to a boolean array. The previous comparison
# against the string 'True' implies predict() returned string labels here;
# going through str() also handles a genuinely boolean return value, which
# the bare string comparison would have mapped entirely to False.
pred = np.asarray(model.predict(X_test)).flatten().astype(str) == 'True'

0:	learn: 0.6050283	total: 492ms	remaining: 4m 5s
50:	learn: 0.7609113	total: 18.2s	remaining: 2m 40s
100:	learn: 0.7864034	total: 36s	remaining: 2m 22s
150:	learn: 0.8010536	total: 53.9s	remaining: 2m 4s
200:	learn: 0.8127373	total: 1m 12s	remaining: 1m 47s
250:	learn: 0.8217994	total: 1m 30s	remaining: 1m 29s
300:	learn: 0.8298410	total: 1m 48s	remaining: 1m 11s
350:	learn: 0.8364297	total: 2m 6s	remaining: 53.8s
400:	learn: 0.8421372	total: 2m 24s	remaining: 35.7s
450:	learn: 0.8473681	total: 2m 42s	remaining: 17.7s
499:	learn: 0.8524563	total: 3m	remaining: 0us


In [21]:
# Overall F1 of the positive (is_crop == True) class on the combined test set.
f1_score(y_true=y_test.values, y_pred=pred)

0.6656457653534164

In [22]:
# Test rows are ordered as in the split cell: all VAL rows first, then the CAL
# rows that were not used for training. Attributes are rebuilt in that order.
not_used_mask = (
    ~trn_df['end_date'].dt.year.isin(train_years)
    | trn_df['aez_zoneid'].isin(test_aez)
)

# Dict constructor keeps each column's dtype (a transpose can coerce to object).
preds_df = pd.DataFrame({'true': y_test.values, 'pred': pred})

for attr in ['aez_zoneid', 'end_date']:
    # concat on Series preserves dtype (incl. datetime/tz); ignore_index aligns
    # positionally with preds_df's RangeIndex.
    attr_values = pd.concat(
        [val_df[attr], trn_df.loc[not_used_mask, attr]], ignore_index=True
    )
    assert len(attr_values) == len(preds_df), 'attribute rows do not match test rows'
    preds_df[attr] = attr_values

preds_df['year'] = preds_df['end_date'].dt.year

In [None]:
# run with raw features
# Per-AEZ F1 for the 5 AEZs with the most test pixels.
# NOTE: the stored output comes from a run with the raw-feature tfeatures set;
# re-running evaluates whichever tfeatures is currently active.
# Apply on the explicitly selected value columns: reading the grouping column
# inside groupby.apply is deprecated in pandas 2.2 and unavailable in pandas 3.
preds_df.groupby('aez_zoneid')[['true', 'pred']].apply(lambda grp: pd.Series({
    'n_pixels': len(grp),
    'f1': f1_score(grp['true'], grp['pred']),
})).sort_values(by='n_pixels', ascending=False).head(5)

Unnamed: 0_level_0,n_pixels,f1
aez_zoneid,Unnamed: 1_level_1,Unnamed: 2_level_1
22190,255347.0,0.748207
46172,85541.0,0.779596
43153,42640.0,0.609698
12048,15700.0,0.821351
46173,13921.0,0.733736


In [18]:
# run with presto features
# Per-AEZ F1 for the 5 AEZs with the most test pixels.
# NOTE: the stored output comes from a run with tfeatures = presto_emb_colnames;
# re-running evaluates whichever tfeatures is currently active.
# Apply on the explicitly selected value columns: reading the grouping column
# inside groupby.apply is deprecated in pandas 2.2 and unavailable in pandas 3.
preds_df.groupby('aez_zoneid')[['true', 'pred']].apply(lambda grp: pd.Series({
    'n_pixels': len(grp),
    'f1': f1_score(grp['true'], grp['pred']),
})).sort_values(by='n_pixels', ascending=False).head(5)

Unnamed: 0_level_0,n_pixels,f1
aez_zoneid,Unnamed: 1_level_1,Unnamed: 2_level_1
22190,255347.0,0.573395
46172,85541.0,0.753528
43153,42640.0,0.591042
12048,15700.0,0.802526
46173,13921.0,0.731675


In [23]:
# run with raw+presto features
# Per-AEZ F1 for the 5 AEZs with the most test pixels.
# NOTE: the stored output comes from a run with the raw+presto tfeatures set;
# re-running evaluates whichever tfeatures is currently active.
# Apply on the explicitly selected value columns: reading the grouping column
# inside groupby.apply is deprecated in pandas 2.2 and unavailable in pandas 3.
preds_df.groupby('aez_zoneid')[['true', 'pred']].apply(lambda grp: pd.Series({
    'n_pixels': len(grp),
    'f1': f1_score(grp['true'], grp['pred']),
})).sort_values(by='n_pixels', ascending=False).head(5)

Unnamed: 0_level_0,n_pixels,f1
aez_zoneid,Unnamed: 1_level_1,Unnamed: 2_level_1
22190,255347.0,0.679833
46172,85541.0,0.78138
43153,42640.0,0.612898
12048,15700.0,0.834195
46173,13921.0,0.738312
