# Pretrained XGB Model Demo

## Load libraries

In [1]:
# Automatically reload imported modules (e.g. the project's src package) when their source changes.
%load_ext autoreload
%autoreload 2
# Tab completion: greedy mode, with jedi disabled.
%config IPCompleter.greedy=True
%config IPCompleter.use_jedi=False

In [11]:
import os
import numpy as np
import xgboost as xgb
from sklearn import metrics
from src.preprocessing.esa_compress import compress_esa, decompress_esa
from src.preprocessing.load_landsat_esa import return_xy_npa, y_npa_to_xr, return_x_y_da
from src.visualisation.ani import animate_prediction

## Load data and set parameters

In [4]:
# Year indices are positions along the "year" dimension; end_year_i is used as the
# exclusive stop of range() when later cells select years.
cfd = dict(
    start_year_i=0,
    mid_year_i=19,  # NOTE(review): not used in this notebook — presumably the train/validation split year; confirm
    end_year_i=24,
    take_esa_coords=True,
    use_ffil=True,
    use_mfd=False,
)

# Load the preprocessed x/y DataArrays (the loader reuses premade netCDF files when present).
loader_kwargs = {key: cfd[key] for key in ("take_esa_coords", "use_ffil", "use_mfd")}
x_da, y_da = return_x_y_da(**loader_kwargs)

['take_esa_coords_True_use_mfd_False_use_ffil_True_x.nc', 'take_esa_coords_True_use_mfd_False_use_ffil_True_y.nc']
x/y values premade. Reusing them.
'return_x_y_da'  1.05360 s



### Look at X values

In [5]:
# Rich display of the input (X) DataArray returned by return_x_y_da.
x_da

### Look at Y values

In [6]:
# Rich display of the label (Y) DataArray returned by return_x_y_da.
y_da

## Load the pretrained model

In [40]:
# Paths for the pretrained XGBoost model (stored under the shared GWS data directory)
# and for the output animation.
# Provenance: the model was copied from wandb run 20210304_012917-1u5o038w (lyric-haze-30).
from src.constants import GWS_DATA_DIR

direc = GWS_DATA_DIR / "xgb-demo"
model_file = direc / "demo_xgb.model"
# The animation is written relative to the current working directory.
video_name = "test_joint_val.mp4"
model_file

PosixPath('/gws/nopw/j04/ai4er/guided-team-challenge/2021/biodiversity/xgb-demo/demo_xgb.model')

In [8]:
# Load the pretrained booster, failing early with a clear message if the file is missing.
if not model_file.exists():
    raise FileNotFoundError(f"Pretrained XGBoost model not found: {model_file}")
bst = xgb.Booster({"nthread": 4})  # init model; XGBoost uses 4 threads
# str() keeps this working on older xgboost releases that only accept str paths.
bst.load_model(str(model_file))  # load model weights

## Predict labels with model

In [9]:
# Year positions to evaluate: [start_year_i, end_year_i).
years = range(cfd["start_year_i"], cfd["end_year_i"])

# Convert the selected years of the DataArrays to numpy feature/label arrays.
x_all, y_all = return_xy_npa(x_da, y_da, year=years)

# Wrap the features in an XGBoost DMatrix; labels are attached in compressed form.
xg_all = xgb.DMatrix(x_all, label=compress_esa(y_all))

# Predict in compressed label space, then decompress to the same label codes as y_all.
y_pr_all = decompress_esa(bst.predict(xg_all))

# Put the flat predictions back onto the coordinates of the matching slice of y_da.
y_pr_da = y_npa_to_xr(y_pr_all, y_da.isel(year=years))

print("\n Finished model predict")

'return_xy_npa'  61.83221 s

'y_npa_to_xr'  0.01585 s


 Finished model predict


### Animate the results

In [41]:
# Animate inputs, labels and predictions over the years [start_year_i, end_year_i)
# and write the result to video_name.
years = range(cfd["start_year_i"], cfd["end_year_i"])
animate_prediction(
    x_da.isel(year=years),
    y_da.isel(year=years),
    y_pr_da,
    video_path=str(video_name),
)


test_joint_val.mp4:   0%|          | 0/24 [00:00<?, ?it/s][A
test_joint_val.mp4:   4%|▍         | 1/24 [00:01<00:39,  1.71s/it][A
test_joint_val.mp4:   8%|▊         | 2/24 [00:03<00:34,  1.56s/it][A
test_joint_val.mp4:  12%|█▎        | 3/24 [00:04<00:31,  1.51s/it][A
test_joint_val.mp4:  17%|█▋        | 4/24 [00:06<00:29,  1.50s/it][A
test_joint_val.mp4:  21%|██        | 5/24 [00:07<00:28,  1.53s/it][A
test_joint_val.mp4:  25%|██▌       | 6/24 [00:09<00:27,  1.50s/it][A
test_joint_val.mp4:  29%|██▉       | 7/24 [00:10<00:25,  1.48s/it][A
test_joint_val.mp4:  33%|███▎      | 8/24 [00:11<00:23,  1.47s/it][A
test_joint_val.mp4:  38%|███▊      | 9/24 [00:13<00:21,  1.46s/it][A
test_joint_val.mp4:  42%|████▏     | 10/24 [00:15<00:20,  1.50s/it][A
test_joint_val.mp4:  46%|████▌     | 11/24 [00:16<00:19,  1.48s/it][A
test_joint_val.mp4:  50%|█████     | 12/24 [00:17<00:17,  1.47s/it][A
test_joint_val.mp4:  54%|█████▍    | 13/24 [00:19<00:16,  1.47s/it][A
test_joint_val.mp4:  58

Video test_joint_val.mp4 made.
'animate_prediction'  37.06147 s



In [42]:
# Display the rendered animation inline; the mp4 is referenced by its path on disk.
from IPython.display import Video

Video(data=video_name)

In [43]:
# Overall fraction of pixels whose predicted class matches the label.
# NOTE(review): this pools every year in [start_year_i, end_year_i), which likely mixes
# years the model was trained on with held-out years — confirm before quoting the score.
accuracy = metrics.accuracy_score(y_all, y_pr_all)
print(f"Classification accuracy: {accuracy}")

Classification accuracy: 0.6008437155484884


In [44]:
# Per-class precision/recall/F1. zero_division=0 scores classes with an undefined
# precision or recall (no predicted / no true samples) as 0 explicitly. That is the same
# value the default uses, but without the UndefinedMetricWarning noise.
print(metrics.classification_report(y_all, y_pr_all, zero_division=0))

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0       0.61      0.01      0.02     32688
          10       0.65      0.88      0.74   6657454
          11       0.55      0.01      0.02   1047529
          30       0.60      0.00      0.00    571329
          40       0.55      0.00      0.01    387972
          60       0.39      0.44      0.41   1325760
          61       0.96      0.02      0.05      5328
          70       0.61      0.80      0.69   4535517
          80       0.00      0.00      0.00        48
          90       0.34      0.20      0.25   1348324
         100       0.45      0.00      0.00    623258
         110       0.99      0.14      0.25      2180
         130       0.57      0.02      0.03    335008
         150       0.81      0.26      0.39      4226
         160       0.48      0.15      0.23    204273
         180       0.77      0.11      0.19     43710
         190       0.57      0.05      0.09    191940
         200       0.88    

  _warn_prf(average, modifier, msg_start, len(result))
