In [1]:
from collections import defaultdict
import pathlib
import pickle

import numpy as np
import pandas as pd
import random
import seaborn as sns
import torch

from train import train
from utils.training import make_stats_dataframe, TrainConfig

  return torch._C._cuda_getDeviceCount() > 0


In [2]:
# Seed every random number generator used by the notebook (PyTorch, the stdlib
# `random` module and NumPy) with the same value so that reruns are repeatable.
SEED = 42
for seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    seed_rng(SEED)

In [3]:
# Number of time steps counted per month of a series; the training cells
# multiply it by n_months to obtain the sequencelength given to TrainConfig.
TIMESTAMPS_PER_MONTH = 10

In [4]:
# Directory that receives the pickled results of every training run below.
dumps_path = pathlib.Path("dumps")
# exist_ok=True replaces the separate exists() check: no check-then-create race
# and the cell stays safe to rerun. parents=True keeps it working if the path is
# later changed to a nested location.
dumps_path.mkdir(parents=True, exist_ok=True)

# Plan

## Data

1. Years: 2018–2022.
2. Each year: N fields.
3. Each field:
  - features: a time series of 10 Sentinel-2 bands (median value over each field); the number of timestamps varies;
  - target: crop class label (13 classes).

## Models

1. Classical ML:
  - Random Forest;
  - Catboost;
  - LightGBM.
2. Deep learning:
  - Transformer;
  - TempCNN;
  - *EarlyRNN*.
 
## Training & evaluation workflow

1. For all models except EarlyRNN: build datasets of reduced length (1–6 months, where 6 months is the full-length time series) and train a separate model on each of them.
2. For EarlyRNN: train using full-length time series.
3. Compare accuracy, precision, recall, f1-score, kappa.

# Part 1. Training

## Classical Machine Learning

### Random Forest

In [7]:
# Search space for the random forest: 10 tree counts (10, 60, ..., 460)
# combined with 3 maximum depths (3, 6, 9).
rf_hyperparameters = dict(
    n_estimators=range(10, 510, 50),
    max_depth=range(3, 10, 3),
)

# Maps n_months -> {"best_model": ..., "stats": ...}.
rf_results = defaultdict(dict)

# Train one model per truncated series length, from 1 month up to 6 months.
for n_months in range(1, 7):
    sequencelength = n_months * TIMESTAMPS_PER_MONTH
    rf_train_params = TrainConfig(
        model="rf",
        year=2018,
        n_months=n_months,
        sequencelength=sequencelength,
        hyperparameters=rf_hyperparameters,
    )
    best_model, stats = train(rf_train_params)
    rf_results[n_months].update(best_model=best_model, stats=stats)

Cache is activated and will be used if possible
Data: train, year: 2018
Trying to use cache
Loading X and y from cache
Russia dataset for 2018 year (train part) is loaded. It contains 7367 fields
Cache is activated and will be used if possible
Data: test, year: 2018
Trying to use cache
Loading X and y from cache
Russia dataset for 2018 year (test part) is loaded. It contains 1566 fields
X shape: (7367, 100) y shape: (7367,)
Fitting 3 folds for each of 30 candidates, totalling 90 fits
[CV] END .......................max_depth=3, n_estimators=10; total time=   0.2s
[CV] END .......................max_depth=3, n_estimators=10; total time=   0.2s
[CV] END .......................max_depth=3, n_estimators=10; total time=   0.2s
[CV] END .......................max_depth=3, n_estimators=60; total time=   1.2s
[CV] END .......................max_depth=3, n_estimators=60; total time=   1.1s
[CV] END .......................max_depth=3, n_estimators=60; total time=   1.2s
[CV] END ................

[CV] END .......................max_depth=3, n_estimators=10; total time=   0.3s
[CV] END .......................max_depth=3, n_estimators=10; total time=   0.3s
[CV] END .......................max_depth=3, n_estimators=10; total time=   0.3s
[CV] END .......................max_depth=3, n_estimators=60; total time=   1.6s
[CV] END .......................max_depth=3, n_estimators=60; total time=   1.5s
[CV] END .......................max_depth=3, n_estimators=60; total time=   1.5s
[CV] END ......................max_depth=3, n_estimators=110; total time=   2.7s
[CV] END ......................max_depth=3, n_estimators=110; total time=   2.7s
[CV] END ......................max_depth=3, n_estimators=110; total time=   2.7s
[CV] END ......................max_depth=3, n_estimators=160; total time=   3.9s
[CV] END ......................max_depth=3, n_estimators=160; total time=   3.9s
[CV] END ......................max_depth=3, n_estimators=160; total time=   3.9s
[CV] END ...................

[CV] END ......................max_depth=3, n_estimators=110; total time=   3.2s
[CV] END ......................max_depth=3, n_estimators=110; total time=   3.2s
[CV] END ......................max_depth=3, n_estimators=110; total time=   3.2s
[CV] END ......................max_depth=3, n_estimators=160; total time=   4.7s
[CV] END ......................max_depth=3, n_estimators=160; total time=   4.6s
[CV] END ......................max_depth=3, n_estimators=160; total time=   4.6s
[CV] END ......................max_depth=3, n_estimators=210; total time=   6.0s
[CV] END ......................max_depth=3, n_estimators=210; total time=   6.0s
[CV] END ......................max_depth=3, n_estimators=210; total time=   6.0s
[CV] END ......................max_depth=3, n_estimators=260; total time=   7.4s
[CV] END ......................max_depth=3, n_estimators=260; total time=   7.4s
[CV] END ......................max_depth=3, n_estimators=260; total time=   7.7s
[CV] END ...................

[CV] END ......................max_depth=3, n_estimators=210; total time=   5.5s
[CV] END ......................max_depth=3, n_estimators=210; total time=   5.6s
[CV] END ......................max_depth=3, n_estimators=210; total time=   5.7s
[CV] END ......................max_depth=3, n_estimators=260; total time=   7.1s
[CV] END ......................max_depth=3, n_estimators=260; total time=   7.0s
[CV] END ......................max_depth=3, n_estimators=260; total time=   6.9s
[CV] END ......................max_depth=3, n_estimators=310; total time=   9.7s
[CV] END ......................max_depth=3, n_estimators=310; total time=   8.6s
[CV] END ......................max_depth=3, n_estimators=310; total time=   9.0s
[CV] END ......................max_depth=3, n_estimators=360; total time=   9.6s
[CV] END ......................max_depth=3, n_estimators=360; total time=   9.7s
[CV] END ......................max_depth=3, n_estimators=360; total time=   9.6s
[CV] END ...................

[CV] END ......................max_depth=3, n_estimators=310; total time=   6.8s
[CV] END ......................max_depth=3, n_estimators=310; total time=   6.8s
[CV] END ......................max_depth=3, n_estimators=310; total time=   6.9s
[CV] END ......................max_depth=3, n_estimators=360; total time=   8.1s
[CV] END ......................max_depth=3, n_estimators=360; total time=   8.2s
[CV] END ......................max_depth=3, n_estimators=360; total time=   8.1s
[CV] END ......................max_depth=3, n_estimators=410; total time=   9.0s
[CV] END ......................max_depth=3, n_estimators=410; total time=   9.2s
[CV] END ......................max_depth=3, n_estimators=410; total time=   9.0s
[CV] END ......................max_depth=3, n_estimators=460; total time=  10.6s
[CV] END ......................max_depth=3, n_estimators=460; total time=  10.5s
[CV] END ......................max_depth=3, n_estimators=460; total time=  14.3s
[CV] END ...................

[CV] END ......................max_depth=3, n_estimators=410; total time=   8.4s
[CV] END ......................max_depth=3, n_estimators=410; total time=   8.4s
[CV] END ......................max_depth=3, n_estimators=410; total time=   8.5s
[CV] END ......................max_depth=3, n_estimators=460; total time=   9.4s
[CV] END ......................max_depth=3, n_estimators=460; total time=   9.5s
[CV] END ......................max_depth=3, n_estimators=460; total time=   9.7s
[CV] END .......................max_depth=6, n_estimators=10; total time=   0.4s
[CV] END .......................max_depth=6, n_estimators=10; total time=   0.4s
[CV] END .......................max_depth=6, n_estimators=10; total time=   0.4s
[CV] END .......................max_depth=6, n_estimators=60; total time=   2.3s
[CV] END .......................max_depth=6, n_estimators=60; total time=   2.2s
[CV] END .......................max_depth=6, n_estimators=60; total time=   2.2s
[CV] END ...................

In [8]:
# Persist the random forest results (model and stats per n_months) to disk.
rf_dump_file = dumps_path / "rf_results.dump"
with rf_dump_file.open("wb") as dump_stream:
    pickle.dump(rf_results, dump_stream)

### LightGBM

In [None]:
# Search space for LightGBM: DART boosting only, 6 tree counts (10, 60, ..., 260),
# 3 maximum depths (3, 6, 9) and 4 learning rates.
lgbm_hyperparameters = dict(
    boosting_type=("dart",),
    n_estimators=range(10, 310, 50),
    max_depth=range(3, 10, 3),
    learning_rate=[0.001, 0.01, 0.1, 1],
)

# Maps n_months -> {"best_model": ..., "stats": ...}.
lgbm_results = defaultdict(dict)

# Train one model per truncated series length, from 1 month up to 6 months.
for n_months in range(1, 7):
    sequencelength = n_months * TIMESTAMPS_PER_MONTH
    lgbm_train_params = TrainConfig(
        model="lightgbm",
        year=2018,
        n_months=n_months,
        sequencelength=sequencelength,
        hyperparameters=lgbm_hyperparameters,
    )
    best_model, stats = train(lgbm_train_params)
    lgbm_results[n_months].update(best_model=best_model, stats=stats)

Cache is activated and will be used if possible
Data: train, year: 2018
Trying to use cache
Loading X and y from cache
Russia dataset for 2018 year (train part) is loaded. It contains 7367 fields
Cache is activated and will be used if possible
Data: test, year: 2018
Trying to use cache
Loading X and y from cache
Russia dataset for 2018 year (test part) is loaded. It contains 1566 fields
X shape: (7367, 100) y shape: (7367,)
Fitting 3 folds for each of 72 candidates, totalling 216 fits
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=10; total time=   0.4s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=10; total time=   0.4s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=10; total time=   0.4s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=60; total time=   2.1s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=60; total time=   2.0s
[CV] END boosti

[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=110; total time=  19.1s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=110; total time=  22.7s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=110; total time=  21.9s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=160; total time=  31.0s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=160; total time=  18.6s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=160; total time=  18.8s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=210; total time=  23.9s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=210; total time=  23.2s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=210; total time=  25.1s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=260; total time=  31.6s
[CV] END b

[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=10; total time=   0.5s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=10; total time=   0.4s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=10; total time=   0.4s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=60; total time=   1.1s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=60; total time=   1.2s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=60; total time=   1.3s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=110; total time=   1.8s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=110; total time=   2.2s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=110; total time=   1.9s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=160; total time=   2.5s
[CV] END boosting_type=dart, learning_rate=1, 

[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=110; total time=  23.3s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=160; total time=  37.0s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=160; total time=  34.6s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=160; total time=  36.0s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=210; total time=  49.0s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=210; total time=  48.2s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=210; total time= 1.1min
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=260; total time= 1.3min
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=260; total time= 1.1min
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=6, n_estimators=260; total time=  55.0s


[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=10; total time=   1.0s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=10; total time=   0.9s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=60; total time=   5.7s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=60; total time=   4.6s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=60; total time=   4.7s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=110; total time=   9.5s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=110; total time=  10.1s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=110; total time=   8.5s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=160; total time=  17.3s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=3, n_estimators=160; total time=  15.0s
[CV] END boosting_type=da

[CV] END boosting_type=dart, learning_rate=1, max_depth=6, n_estimators=210; total time=   5.1s
[CV] END boosting_type=dart, learning_rate=1, max_depth=6, n_estimators=260; total time=   6.5s
[CV] END boosting_type=dart, learning_rate=1, max_depth=6, n_estimators=260; total time=   6.9s
[CV] END boosting_type=dart, learning_rate=1, max_depth=6, n_estimators=260; total time=   5.9s
[CV] END boosting_type=dart, learning_rate=1, max_depth=9, n_estimators=10; total time=   1.9s
[CV] END boosting_type=dart, learning_rate=1, max_depth=9, n_estimators=10; total time=   1.9s
[CV] END boosting_type=dart, learning_rate=1, max_depth=9, n_estimators=10; total time=   1.8s
[CV] END boosting_type=dart, learning_rate=1, max_depth=9, n_estimators=60; total time=   3.1s
[CV] END boosting_type=dart, learning_rate=1, max_depth=9, n_estimators=60; total time=   3.3s
[CV] END boosting_type=dart, learning_rate=1, max_depth=9, n_estimators=60; total time=   2.8s
[CV] END boosting_type=dart, learning_rate=1, 

[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=60; total time=   5.6s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=60; total time=   5.5s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=60; total time=   5.6s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=110; total time=  10.4s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=110; total time=  10.3s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=110; total time=  10.3s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=160; total time=  15.5s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=160; total time=  15.2s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=160; total time=  15.3s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=3, n_estimators=210; total time=  20.8s
[CV] END boos

[CV] END boosting_type=dart, learning_rate=0.1, max_depth=6, n_estimators=260; total time= 1.2min
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=6, n_estimators=260; total time= 1.2min
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=6, n_estimators=260; total time= 1.2min
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=10; total time=   3.4s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=10; total time=   3.4s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=10; total time=   3.3s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=60; total time=  19.0s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=60; total time=  18.8s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=60; total time=  18.4s
[CV] END boosting_type=dart, learning_rate=0.1, max_depth=9, n_estimators=110; total time=  35.3s
[CV] END boosting_type=dar

[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=60; total time=   6.4s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=110; total time=  10.9s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=110; total time=  10.9s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=110; total time=  11.1s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=160; total time=  15.9s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=160; total time=  19.7s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=160; total time=  22.5s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=210; total time=  26.8s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=210; total time=  31.5s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=3, n_estimators=210; total time=  24.5s
[

[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=260; total time= 1.3min
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=6, n_estimators=260; total time= 1.4min
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=10; total time=   3.6s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=10; total time=   3.6s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=10; total time=   3.4s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=60; total time=  20.3s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=60; total time=  20.1s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=60; total time=  20.7s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=110; total time=  38.1s
[CV] END boosting_type=dart, learning_rate=0.01, max_depth=9, n_estimators=110; total time=  37.0s
[CV] END boostin

[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=160; total time=   5.8s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=160; total time=   6.4s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=210; total time=   7.8s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=210; total time=   7.3s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=210; total time=   7.7s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=260; total time=   9.4s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=260; total time=   9.1s
[CV] END boosting_type=dart, learning_rate=1, max_depth=3, n_estimators=260; total time=   9.4s
[CV] END boosting_type=dart, learning_rate=1, max_depth=6, n_estimators=10; total time=   2.1s
[CV] END boosting_type=dart, learning_rate=1, max_depth=6, n_estimators=10; total time=   2.2s
[CV] END boosting_type=dart, learning_rate

[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=10; total time=   4.1s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=10; total time=   4.5s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=10; total time=   6.0s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=60; total time=  21.7s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=60; total time=  26.2s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=60; total time=  21.2s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=110; total time=  41.1s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=110; total time=  45.1s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=110; total time=  46.1s
[CV] END boosting_type=dart, learning_rate=0.001, max_depth=9, n_estimators=160; total time= 1.2min
[CV] E

In [None]:
# Persist the LightGBM results (model and stats per n_months) to disk.
lgbm_dump_file = dumps_path / "lgbm_results.dump"
with lgbm_dump_file.open("wb") as dump_stream:
    pickle.dump(lgbm_results, dump_stream)

### Catboost

In [None]:
# Search space for CatBoost: 6 iteration counts (10, 60, ..., 260),
# 3 tree depths (3, 6, 9) and 4 learning rates.
catboost_hyperparameters = {
    "iterations": range(10, 310, 50),
    "depth": range(3, 10, 3),
    "learning_rate": [0.001, 0.01, 0.1, 1],
}

# Maps n_months -> {"best_model": ..., "stats": ...}.
catboost_results = defaultdict(dict)

# Train one model per truncated series length, from 1 month up to 6 months.
for n_months in range(1, 7):
    sequencelength = n_months * TIMESTAMPS_PER_MONTH
    catboost_train_params = TrainConfig(model="catboost", year=2018,
                       n_months=n_months,
                       sequencelength=sequencelength,
                       hyperparameters=catboost_hyperparameters)
    best_model, stats = train(catboost_train_params)
    # Fixed: this line used to write to the misspelled name `catbpost_results`,
    # which raised NameError after the first (expensive) training run.
    catboost_results[n_months]["best_model"] = best_model
    catboost_results[n_months]["stats"] = stats

In [None]:
# Persist the CatBoost results (model and stats per n_months) to disk.
catboost_dump_file = dumps_path / "catboost_results.dump"
with catboost_dump_file.open("wb") as dump_stream:
    pickle.dump(catboost_results, dump_stream)

## Deep learning approaches

### Transformer

Code for the model adapted from [BreizhCrops paper](https://arxiv.org/abs/1905.11893).

In [None]:
# Maps n_months -> {"best_model": ..., "stats": ...}.
transformer_results = defaultdict(dict)

# Train one Transformer per truncated series length (1 to 6 months), 100 epochs each.
for n_months in range(1, 7):
    sequencelength = n_months * TIMESTAMPS_PER_MONTH
    transformer_train_params = TrainConfig(model="transformer", year=2018, epochs=100,
                                           n_months=n_months, sequencelength=sequencelength)
    best_model, stats = train(transformer_train_params)
    transformer_results[n_months].update(best_model=best_model, stats=stats)

In [None]:
# Persist the Transformer results (model and stats per n_months) to disk.
transformer_dump_file = dumps_path / "transformer_results.dump"
with transformer_dump_file.open("wb") as dump_stream:
    pickle.dump(transformer_results, dump_stream)

### TempCNN

Code adapted from the [BreizhCrops paper](https://arxiv.org/abs/1905.11893). The model was originally introduced in the [paper on temporal convolutional neural networks for satellite time series classification](https://arxiv.org/abs/1811.10166).

In [None]:
# Maps n_months -> {"best_model": ..., "stats": ...}.
tempcnn_results = defaultdict(dict)

# Train one TempCNN per truncated series length (1 to 6 months), 100 epochs each.
for n_months in range(1, 7):
    sequencelength = n_months * TIMESTAMPS_PER_MONTH
    tempcnn_train_params = TrainConfig(model="tempcnn", year=2018, epochs=100,
                                       n_months=n_months, sequencelength=sequencelength)
    best_model, stats = train(tempcnn_train_params)
    tempcnn_results[n_months].update(best_model=best_model, stats=stats)

In [None]:
# Persist the TempCNN results (model and stats per n_months) to disk.
tempcnn_dump_file = dumps_path / "tempcnn_results.dump"
with tempcnn_dump_file.open("wb") as dump_stream:
    pickle.dump(tempcnn_results, dump_stream)

### EarlyRNN

Code for the model adapted from the [paper on early classification of time series for crop type mapping](https://arxiv.org/pdf/1901.10681.pdf).

In [None]:
# EarlyRNN is trained once, on the full-length series only.
N_MONTHS = 6

# Maps n_months -> {"best_model": ..., "stats": ...}; only the N_MONTHS key is filled.
earlyrnn_results = defaultdict(dict)

earlyrnn_args = TrainConfig(
    epochs=100,
    model="earlyrnn",
    # The year is now passed explicitly, as in every other training cell of this
    # notebook, so the comparison does not depend on TrainConfig's default year.
    # NOTE(review): confirm that 2018 is the intended year for EarlyRNN as well.
    year=2018,
    n_months=N_MONTHS,
    sequencelength=N_MONTHS * TIMESTAMPS_PER_MONTH,
)
best_model, stats = train(earlyrnn_args)
earlyrnn_results[N_MONTHS]["best_model"] = best_model
earlyrnn_results[N_MONTHS]["stats"] = stats

In [None]:
# Persist the EarlyRNN results (model and stats for the full-length run) to disk.
earlyrnn_dump_file = dumps_path / "earlyrnn_results.dump"
with earlyrnn_dump_file.open("wb") as dump_stream:
    pickle.dump(earlyrnn_results, dump_stream)

# Part 2. Models Comparison