In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import sklearn
import xgboost as xgb
from xgboost import XGBClassifier
from sklearn.metrics import accuracy_score
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
import warnings
warnings.filterwarnings('ignore')

import seaborn as sns

from osgeo import gdal

from sklearn.svm import SVC
import joblib

ModuleNotFoundError: No module named 'xgboost'

In [None]:
def read_tif(input_path):
    """Open a raster with GDAL and return its pixels plus georeferencing info.

    Parameters
    ----------
    input_path : str
        Path handed to ``gdal.Open``.

    Returns
    -------
    tuple
        ``(gdal_array, width, height, band, proj, geotrans, dataset)`` where
        ``gdal_array`` is ``ReadAsArray`` over the full raster extent,
        ``width``/``height`` are ``RasterXSize``/``RasterYSize``, ``band`` is
        ``RasterCount``, ``proj`` is the projection string, ``geotrans`` is
        the geotransform and ``dataset`` is the open GDAL dataset itself.

    Raises
    ------
    RuntimeError
        If ``gdal.Open`` returns ``None`` (file missing or unreadable).
    """
    dataset = gdal.Open(input_path)
    # Without GDAL exceptions enabled, a failed open returns None and the
    # attribute access below would raise an unhelpful AttributeError.
    if dataset is None:
        raise RuntimeError(f"GDAL could not open raster: {input_path!r}")

    width = dataset.RasterXSize
    height = dataset.RasterYSize
    gdal_array = dataset.ReadAsArray(0, 0, width, height)
    band = dataset.RasterCount
    proj = dataset.GetProjection()
    geotrans = dataset.GetGeoTransform()

    return gdal_array, width, height, band, proj, geotrans, dataset


def re_label(data_y):
    """Map raw label codes to consecutive class ids 0..15.

    Negative values become 0 and each listed raw code is replaced by its
    class id. Values that are neither negative nor listed are returned
    unchanged. The codes are presumably NLCD land-cover classes (the notebook
    reads ``./data/nlcd_extract.tif``) -- verify against the data source.

    Parameters
    ----------
    data_y : array_like
        Raw label values.

    Returns
    -------
    numpy.ndarray
        A new array with the same shape and dtype as ``data_y``. The input is
        not modified, so a reshape view of the label raster stays intact.
    """
    # (raw code, class id) pairs.
    code_to_class = (
        (11, 1), (21, 2), (22, 3), (23, 4), (24, 5),
        (31, 6), (41, 7), (42, 8), (43, 9), (52, 10),
        (71, 11), (81, 12), (82, 13), (90, 14), (95, 15),
    )

    source = np.asarray(data_y)
    relabeled = source.copy()

    # Every mask is computed on the untouched source values, so an id that
    # has just been written (e.g. 11) can never be picked up by another rule.
    relabeled[source < 0] = 0
    for raw_code, class_id in code_to_class:
        relabeled[source == raw_code] = class_id

    return relabeled


def write_geotiff(filename, arr, projection, geotransform):
    """Write a 2-D array to a single-band GeoTIFF.

    Parameters
    ----------
    filename : str
        Output path passed to the GTiff driver.
    arr : numpy.ndarray
        2-D array; ``arr.shape[0]`` is used as the raster height and
        ``arr.shape[1]`` as the width.
    projection : str
        Projection string passed to ``SetProjection``.
    geotransform : sequence
        Geotransform passed to ``SetGeoTransform``.

    Raises
    ------
    RuntimeError
        If the GTiff driver cannot create ``filename``.

    Notes
    -----
    ``float32`` arrays are stored as ``GDT_Float32``; every other dtype is
    stored as ``GDT_Int32``.
    """
    if arr.dtype == np.float32:
        arr_type = gdal.GDT_Float32
    else:
        arr_type = gdal.GDT_Int32

    driver = gdal.GetDriverByName("GTiff")
    out_ds = driver.Create(filename, arr.shape[1], arr.shape[0], 1, arr_type)
    # Create returns None on failure when GDAL exceptions are not enabled
    # (e.g. the target directory does not exist).
    if out_ds is None:
        raise RuntimeError(f"GDAL could not create GeoTIFF: {filename!r}")

    band = None
    try:
        out_ds.SetProjection(projection)
        out_ds.SetGeoTransform(geotransform)
        band = out_ds.GetRasterBand(1)
        band.WriteArray(arr)
        band.FlushCache()
        band.ComputeStatistics(False)
        out_ds.FlushCache()
    finally:
        # GDAL finalises the file when the last Python reference goes away;
        # drop the references explicitly so the file is closed on return,
        # including when one of the calls above raised.
        band = None
        out_ds = None
    
    
def ohe(cluster_label:np.ndarray, k:int):
    """One-hot encode a 2-D array of integer labels.

    Parameters
    ----------
    cluster_label : numpy.ndarray
        2-D integer array of shape ``(m, n)``; each entry is used as an index
        into the last axis of the result.
    k : int
        Length of the one-hot axis.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(m, n, k)`` that is 1 at
        ``[i, j, cluster_label[i, j]]`` and 0 elsewhere.
    """
    m, n = cluster_label.shape
    res = np.zeros((m, n, k))
    # Single vectorised scatter along the last axis instead of a Python loop
    # over every pixel.
    np.put_along_axis(res, cluster_label[:, :, np.newaxis], 1, axis=2)
    return res


def NDVI_calculation(gdal_array, red_index=2, nir_index=3):
    """Append an NDVI band to an image stored as (rows, cols, bands).

    NDVI is computed per pixel as ``(nir - red) / (nir + red)``.

    Parameters
    ----------
    gdal_array : numpy.ndarray
        3-D image array with the bands on the last axis.
    red_index : int, optional
        Index of the red band on the last axis (default 2).
    nir_index : int, optional
        Index of the near-infrared band on the last axis (default 3).

    Returns
    -------
    numpy.ndarray
        ``gdal_array`` with one extra band (the NDVI) concatenated on axis 2.

    Notes
    -----
    Pixels where ``nir + red`` is zero yield ``nan``/``inf``, as produced by
    numpy's division.
    """
    red_band = gdal_array[:, :, red_index]
    nir_band = gdal_array[:, :, nir_index]

    # Integer rasters must be promoted before the arithmetic: with unsigned
    # dtypes ``nir - red`` wraps around whenever red > nir, and narrow signed
    # dtypes can overflow in ``nir + red``. float64 matches the dtype numpy's
    # true division would have produced for integer inputs anyway.
    if not np.issubdtype(gdal_array.dtype, np.floating):
        red_band = red_band.astype(np.float64)
        nir_band = nir_band.astype(np.float64)

    height, width = red_band.shape

    ndvi_band = (nir_band - red_band) / (nir_band + red_band)
    ndvi_band = ndvi_band.reshape(height, width, 1)

    gdal_array_process = np.concatenate((gdal_array, ndvi_band), axis=2)
    return gdal_array_process


def run_with_clustering(num_clusters=None):
    """Build the per-pixel feature matrix for the Landsat extract.

    The image at ``./data/landsat_extract.tif`` is read, moved to
    (rows, cols, bands) layout and extended with an NDVI band. When
    ``num_clusters`` is 5, 7 or 15, one-hot encoded cluster labels loaded
    from ``./data/k_cluster_labels.save`` are appended as extra features;
    any other value yields the spectral bands plus NDVI only.

    Parameters
    ----------
    num_clusters : int or None
        Cluster count selecting which saved labels to append.

    Returns
    -------
    tuple
        ``(data_x, dataset)``: the feature matrix with ``width * height``
        rows, and the open GDAL dataset of the Landsat image.
    """
    with_clusters = num_clusters in (5, 7, 15)

    one_hot_clusters = None
    if with_clusters:
        # NOTE: joblib.load unpickles the file, so it must come from a
        # trusted source. The loaded object is indexed by cluster count
        # (presumably a dict keyed by k -- verify against the producer).
        saved_labels = joblib.load('./data/k_cluster_labels.save')
        one_hot_clusters = ohe(saved_labels[num_clusters], num_clusters)

    raster, width, height, band, _proj, _geotrans, dataset = read_tif('./data/landsat_extract.tif')

    # GDAL returns (bands, rows, cols); move the band axis to the end.
    features = NDVI_calculation(np.rollaxis(raster, 0, 3))
    feature_count = band + 1  # spectral bands + NDVI

    if with_clusters:
        features = np.concatenate((features, one_hot_clusters), axis=2)
        feature_count = band + num_clusters + 1

    data_x = np.reshape(features, (width * height, feature_count))
    return data_x, dataset

    
def get_scores(y_label, ypred):
    """Compute classification scores for a set of predictions.

    Parameters
    ----------
    y_label : array_like
        Ground-truth labels.
    ypred : array_like
        Predicted labels.

    Returns
    -------
    dict
        Keys ``accuracy``, ``f1_micro``, ``f1_macro``, ``precision_micro``
        and ``precision_macro`` mapped to the corresponding scikit-learn
        metric values.
    """
    return {
        'accuracy': accuracy_score(y_label, ypred),
        'f1_micro': f1_score(y_label, ypred, average='micro'),
        'f1_macro': f1_score(y_label, ypred, average='macro'),
        'precision_micro': precision_score(y_label, ypred, average='micro'),
        'precision_macro': precision_score(y_label, ypred, average='macro'),
    }


def export_tiff(data_x, model, height, width, num_clusters,
                projection=None, geotransform=None):
    """Predict a class for every pixel and write the result as a GeoTIFF.

    The output file is ``./output/predict_k_<num_clusters>.tiff``.

    Parameters
    ----------
    data_x : array_like
        Feature matrix with ``height * width`` rows. It must be in the same
        feature space the model was trained on (same columns, same
        preprocessing such as standardisation).
    model : xgboost.Booster
        Trained booster used for prediction.
    height, width : int
        Raster dimensions used to reshape the flat prediction vector.
    num_clusters : int
        Only used to build the output file name.
    projection : str, optional
        Projection for the output raster. Defaults to the notebook-level
        ``proj`` variable for backward compatibility.
    geotransform : sequence, optional
        Geotransform for the output raster. Defaults to the notebook-level
        ``geotrans`` variable for backward compatibility.
    """
    import os

    # Fall back to the notebook globals only when the caller did not supply
    # the georeferencing explicitly.
    if projection is None:
        projection = proj
    if geotransform is None:
        geotransform = geotrans

    dtest_whole = xgb.DMatrix(data_x)
    ypred_whole = model.predict(dtest_whole)
    ypred_image = np.reshape(ypred_whole, (height, width))

    output_path = './output/predict_k_'+str(num_clusters)+'.tiff'
    # The GTiff driver cannot create missing directories itself.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    write_geotiff(output_path, ypred_image, projection, geotransform)

    print('finished.')

In [None]:
# Build the per-pixel feature matrix (spectral bands + NDVI + one-hot cluster
# labels) and keep the open Landsat dataset for its georeferencing.
NUM_CLUSTERS = 7
data_x, dataset = run_with_clustering(NUM_CLUSTERS)
data_x.shape

## Label

In [None]:
# Open the label raster. Size and georeferencing are taken from the Landsat
# `dataset`, so the labels are read over the same pixel window as the image.
labelset = gdal.Open(r'./data/nlcd_extract.tif')
if labelset is None:
    # gdal.Open returns None on failure unless GDAL exceptions are enabled.
    raise RuntimeError("GDAL could not open label raster './data/nlcd_extract.tif'")

width = dataset.RasterXSize
height = dataset.RasterYSize
band = dataset.RasterCount
# `proj` and `geotrans` are also read as globals by export_tiff.
proj = dataset.GetProjection()
geotrans = dataset.GetGeoTransform()

# The read window below must lie inside the label raster.
if labelset.RasterXSize < width or labelset.RasterYSize < height:
    raise ValueError(
        "Label raster (%d x %d) is smaller than the image (%d x %d)"
        % (labelset.RasterXSize, labelset.RasterYSize, width, height)
    )
label_array = labelset.ReadAsArray(0, 0, width, height)  # read the label pixels

# Flatten to one label per pixel, matching the row order of data_x.
data_y = np.reshape(label_array, (width * height))

data_y = re_label(data_y)
data_y.shape

## Model

In [None]:
# Stratified split: 10% of the pixels are used for training and 90% are held
# out. random_state pins the shuffle so Restart & Run All reproduces the same
# split (0 matches the 'seed' used in the XGBoost params).
X_train, X_test, y_train, y_test = train_test_split(
    data_x, data_y, test_size=0.9, stratify=data_y, random_state=0
)

In [None]:
# Standardise the features. The scaler is fitted on the training split only;
# the test split is transformed with those same statistics. Refitting on the
# test data would standardise the two splits differently and leak test-set
# statistics into the evaluation.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Free the unscaled copies; only the scaled arrays are used from here on.
del X_train
del X_test

In [None]:
# Wrap the standardised splits in XGBoost's DMatrix container, attaching the
# class labels.
dtrain = xgb.DMatrix(data=X_train_scaled, label=y_train)
dtest = xgb.DMatrix(data=X_test_scaled, label=y_test)

In [None]:
# The DMatrix objects are used from here on; drop the notebook's references
# to the scaled arrays.
del X_train_scaled, X_test_scaled

In [None]:
params = {'booster': 'gbtree',
          # multi:softmax makes predict() return class ids directly.
          'objective': 'multi:softmax',
          'num_class': 16,
          # The XGBoost docs state that 'auc' on multi-class problems should
          # be paired with multi:softprob rather than multi:softmax, so the
          # training log uses multi-class log loss instead. The metric only
          # affects the evaluation log, not the fitted trees.
          'eval_metric': 'mlogloss',
          'max_depth': 7,
          'lambda': 15,
          'subsample': 0.75,
          'colsample_bytree': 0.75,
          'min_child_weight': 1,
          'seed': 0,
          'nthread': 8,
          'verbosity': 1,
          'gamma': 0.15,
          # 'eta' and 'learning_rate' are aliases for the same parameter. The
          # dict previously set both (eta=0.025, learning_rate=0.01), which
          # is ambiguous; a single value is kept. NOTE(review): 0.01 is
          # believed to be the value that was taking effect -- confirm.
          'learning_rate': 0.01}

# Only the training set is monitored during boosting.
watchlist = [(dtrain, 'train')]

In [None]:
# Fit the booster for 30 rounds, logging the evaluation metric on the
# training set after each round.
model = xgb.train(
    params=params,
    dtrain=dtrain,
    num_boost_round=30,
    evals=watchlist,
)

In [None]:
# Persist the trained booster to disk.
MODEL_PATH = 'xgboost.model'
model.save_model(MODEL_PATH)

# Reload the saved file into a fresh Booster and attach it to an
# XGBClassifier wrapper. `clf` is not used by the prediction below, which
# goes through the in-memory `model`.
booster = xgb.Booster()
booster.load_model(MODEL_PATH)
clf = xgb.XGBClassifier()
clf._Booster = booster

# Predicted class ids for the held-out pixels.
ypred = model.predict(dtest)

In [None]:
# Accuracy plus micro/macro F1 and precision on the held-out pixels.
score = get_scores(y_label=y_test, ypred=ypred)
score

In [None]:
# Fix the label order explicitly so that row/column i always corresponds to
# class i. Without `labels`, scikit-learn orders the matrix by the labels
# that actually occur, and the [1:, 1:] slice below would drop the wrong
# class whenever class 0 is absent from y_test and ypred.
NUM_CLASSES = 16
metric = metrics.confusion_matrix(y_test, ypred, labels=np.arange(NUM_CLASSES))

# Class 0 is left out of the plot; tick labels show the real class ids.
shown_classes = np.arange(1, NUM_CLASSES)
fig, ax = plt.subplots()
sns.heatmap(metric[1:, 1:], cmap='Blues',
            xticklabels=shown_classes, yticklabels=shown_classes, ax=ax)
ax.set(xlabel='Predicted class', ylabel='True class',
       title='Confusion matrix (class 0 omitted)');

In [None]:
# Bar chart of per-feature importance for the trained booster.
xgb.plot_importance(booster=model)

## Export Tiff

In [None]:
# The booster was trained on standardised features, so the full-image feature
# matrix has to go through the same scaler before prediction; feeding the raw
# values would compare them against split thresholds learned in scaled space.
export_tiff(scaler.transform(data_x), model, height, width, 7)