In [41]:
import pickle
import os, glob
import rasterio

import xgboost as xgb
import numpy as np
import pandas as pd
from tqdm import tqdm
from rasterio.windows import Window

In [42]:
def predict_window(df_data, model):
    df_data = xgb.DMatrix(df_data)
    X_predict = model.predict(df_data)
    return X_predict.astype('uint16')

In [43]:
def convert_img_to_input_model(array, list_number_band):
    i = 0
    dfObj = pd.DataFrame()
    for band in array:
        band = band.flatten()
        name_band = f"band {list_number_band[i]}"
        dfObj[name_band] = band
        i+=1
    return dfObj

In [44]:
def predict_one_window(array_window, model, list_number_band):
    shape_predict = array_window.shape[1:]
    df_window = convert_img_to_input_model(array_window, list_number_band)
    x_predict = predict_window(df_window, model)
    return np.array([x_predict.reshape(shape_predict)])

In [45]:
def write_window_many_chanel(output_ds, arr_c, s_h, e_h ,s_w, e_w, sw_w, sw_h, size_w_crop, size_h_crop):
    for c, arr in enumerate(arr_c):
        output_ds.write(arr[s_h:e_h,s_w:e_w],window = Window(sw_w, sw_h, size_w_crop, size_h_crop), indexes= c + 1)

In [46]:
def predict_big(out_fp_predict, fp_img_large, input_size, crop_size, model, list_number_band):
    with rasterio.open(fp_img_large) as src:
        h,w = src.height,src.width
        source_crs = src.crs
        source_transform = src.transform
        dtype_or = src.dtypes
        num_band = src.count
        
    with rasterio.open(out_fp_predict, 'w', driver='GTiff',
                                height = h, width = w,
                                count=1, dtype=dtype_or[0],
                                crs=source_crs,
                                transform=source_transform,
                                nodata=0,
                                compress='lzw') as output_ds:
        output_ds = np.empty((1,h,w))
        
        
        
    padding = int((input_size - crop_size)/2)
    list_weight = list(range(0, w, crop_size))
    list_hight = list(range(0, h, crop_size))


    src = rasterio.open(fp_img_large)

    with rasterio.open(out_fp_predict,"r+") as output_ds:
        with tqdm(total=len(list_hight)*len(list_weight)) as pbar:
            for start_h_org in list_hight:
                for start_w_org in list_weight:
                    # print('join')
                    # vi tri bat dau
                    h_crop_start = start_h_org - padding
                    w_crop_start = start_w_org - padding
                    # kich thuoc
                    tmp_img_size_model = np.zeros((num_band, input_size,input_size))
                    # truong hop 0 0
                    if h_crop_start < 0 and w_crop_start < 0:
                        # continue
                        h_crop_start = 0
                        w_crop_start = 0
                        size_h_crop = crop_size + padding
                        size_w_crop = crop_size + padding
                        img_window_crop  = src.read(window=Window(w_crop_start, h_crop_start, size_w_crop, size_h_crop))
                        tmp_img_size_model[:, padding:, padding:] = img_window_crop
                        img_predict = predict_one_window(tmp_img_size_model, model, list_number_band) + 1
                        write_window_many_chanel(output_ds, img_predict, padding, crop_size + padding, padding, crop_size + padding, 
                                                                        start_w_org, start_h_org, crop_size, crop_size)
                    
                    # truong hop h = 0 va w != 0
                    elif h_crop_start < 0:
                        h_crop_start = 0
                        size_h_crop = crop_size + padding
                        size_w_crop = min(crop_size + 2*padding, w - start_w_org + padding)
                        img_window_crop  = src.read(window=Window(w_crop_start, h_crop_start, size_w_crop, size_h_crop))
                        
                        if size_w_crop == w - start_w_org + padding:
                            end_c_index_w =  size_w_crop
                            tmp_img_size_model[:,padding:,:end_c_index_w] = img_window_crop
                        else:
                            end_c_index_w = crop_size + padding
                            tmp_img_size_model[:, padding:,:] = img_window_crop
                        img_predict = predict_one_window(tmp_img_size_model, model, list_number_band) + 1
                        write_window_many_chanel(output_ds, img_predict, padding, crop_size + padding ,padding, end_c_index_w, 
                                                    start_w_org, start_h_org,  min(crop_size, w - start_w_org), crop_size)
                    
                    # Truong hop w = 0, h!=0 
                    elif w_crop_start < 0:
                        w_crop_start = 0
                        size_w_crop = crop_size + padding
                        size_h_crop = min(crop_size + 2*padding, h - start_h_org + padding)
                        img_window_crop  = src.read(window=Window(w_crop_start, h_crop_start, size_w_crop, size_h_crop))
                        
                        if size_h_crop == h - start_h_org + padding:
                            end_c_index_h =  size_h_crop
                            tmp_img_size_model[:,:end_c_index_h,padding:] = img_window_crop
                        else:
                            end_c_index_h = crop_size + padding
                            tmp_img_size_model[:,:, padding:] = img_window_crop
                        img_predict = predict_one_window(tmp_img_size_model, model, list_number_band) + 1
                        write_window_many_chanel(output_ds, img_predict, padding, end_c_index_h, padding, crop_size + padding, 
                                                    start_w_org, start_h_org, crop_size, min(crop_size, h - start_h_org))
                        
                    # Truong hop ca 2 deu khac khong
                    else:
                        size_w_crop = min(crop_size +2*padding, w - start_w_org + padding)
                        size_h_crop = min(crop_size +2*padding, h - start_h_org + padding)
                        img_window_crop  = src.read(window=Window(w_crop_start, h_crop_start, size_w_crop, size_h_crop))
                        # print(img_window_crop.shape, size_w_crop, size_h_crop)
                        if size_w_crop < (crop_size + 2*padding) and size_h_crop < (crop_size + 2*padding):
                            print(img_window_crop.shape, size_w_crop, size_h_crop)
                            end_c_index_h = size_h_crop
                            end_c_index_w = size_w_crop
                            tmp_img_size_model[:,:end_c_index_h,:   end_c_index_w] = img_window_crop
                        elif size_w_crop < (crop_size + 2*padding):
                            end_c_index_h = crop_size + padding
                            end_c_index_w = size_w_crop
                            tmp_img_size_model[:,:,:end_c_index_w] = img_window_crop
                        elif size_h_crop < (crop_size + 2*padding):
                            end_c_index_w = crop_size + padding
                            end_c_index_h = size_h_crop
                            tmp_img_size_model[:,:end_c_index_h,:] = img_window_crop
                        else:
                            end_c_index_w = crop_size + padding
                            end_c_index_h = crop_size + padding
                            tmp_img_size_model[:,:,:] = img_window_crop
                        img_predict = predict_one_window(tmp_img_size_model, model, list_number_band) + 1 
                        write_window_many_chanel(output_ds, img_predict, padding, end_c_index_h, padding, end_c_index_w, 
                                                    start_w_org, start_h_org, min(crop_size, w - start_w_org), min(crop_size, h - start_h_org))
                    pbar.update()
        output_ds.write_colormap(
                    1, {
                        1: (255, 0, 0, 255),
                        2: (255,255,0, 255),
                        3:(128,0,0,255),
                        4:(0,255,0,255),
                        5:(0,128,0,255),
                        6:(0,0,255,255) })

## Predict Mongolia

In [50]:
in_img_dir = r"E:\WORK\Mongodia\Data\Img"
model_path = r"E:\WORK\Mongodia\pixel_base\model_5000_v2_num_round_100_max_depth7_7Band.model"

out_dir_predict = os.path.join(in_img_dir + 'predict_xgboost' + os.path.basename(model_path)[:-6])
if not os.path.exists(out_dir_predict):
    os.makedirs(out_dir_predict)
print(out_dir_predict)

crop_size = 900
input_size = 1000
list_number_band = [1,2,3,4,5,6,7]

model = xgb.Booster({'nthread': 20})
model.load_model(model_path)

list_fp_img = glob.glob(os.path.join(in_img_dir, '*.tif'))
for fp_img_large in list_fp_img:
    out_fp_predict = os.path.join(out_dir_predict, os.path.basename(fp_img_large))
    predict_big(out_fp_predict, fp_img_large, input_size, crop_size, model, list_number_band)
    


E:\WORK\Mongodia\Data\Imgpredict_xgboostmodel_5000_v2_num_round_100_max_depth7_7Band


 99%|█████████▉| 80/81 [01:46<00:01,  1.22s/it]

(7, 861, 751) 751 861


100%|██████████| 81/81 [01:47<00:00,  1.33s/it]
 99%|█████████▉| 80/81 [01:49<00:01,  1.29s/it]

(7, 861, 751) 751 861


100%|██████████| 81/81 [01:50<00:00,  1.37s/it]
 99%|█████████▉| 80/81 [01:49<00:01,  1.29s/it]

(7, 861, 751) 751 861


100%|██████████| 81/81 [01:50<00:00,  1.36s/it]
 99%|█████████▉| 80/81 [01:54<00:01,  1.30s/it]

(7, 861, 761) 761 861


100%|██████████| 81/81 [01:55<00:00,  1.42s/it]
 99%|█████████▉| 80/81 [01:53<00:01,  1.30s/it]

(7, 871, 761) 761 871


100%|██████████| 81/81 [01:54<00:00,  1.42s/it]
 99%|█████████▉| 80/81 [01:50<00:01,  1.27s/it]

(7, 871, 761) 761 871


100%|██████████| 81/81 [01:52<00:00,  1.38s/it]
 99%|█████████▉| 80/81 [01:46<00:01,  1.25s/it]

(7, 861, 761) 761 861


100%|██████████| 81/81 [01:47<00:00,  1.33s/it]
 99%|█████████▉| 80/81 [01:46<00:01,  1.28s/it]

(7, 861, 751) 751 861


100%|██████████| 81/81 [01:48<00:00,  1.33s/it]
 99%|█████████▉| 80/81 [01:46<00:01,  1.22s/it]

(7, 821, 661) 661 821


100%|██████████| 81/81 [01:47<00:00,  1.32s/it]
 99%|█████████▉| 80/81 [01:46<00:01,  1.25s/it]

(7, 851, 751) 751 851


100%|██████████| 81/81 [01:47<00:00,  1.33s/it]
 99%|█████████▉| 80/81 [01:45<00:01,  1.27s/it]

(7, 851, 761) 761 851


100%|██████████| 81/81 [01:46<00:00,  1.32s/it]
 99%|█████████▉| 80/81 [01:46<00:01,  1.25s/it]

(7, 851, 761) 761 851


100%|██████████| 81/81 [01:47<00:00,  1.33s/it]
 99%|█████████▉| 80/81 [01:46<00:01,  1.25s/it]

(7, 851, 761) 761 851


100%|██████████| 81/81 [01:47<00:00,  1.33s/it]
