# Mapping the distribution of fungal richness and proportion under climate change scenarios with Python
# Example: fungal saprotrophs

In [None]:
# Import the Python packages used to map fungal saprotrophs under the current scenario
from pyhdf.SD import SD, SDC # hdf4 
import glob
from sklearn.neighbors import KDTree
from datetime import datetime
from multiprocessing import Pool
import xarray as xr
import numpy as np
import pandas as pd
import os,glob
import matplotlib.pyplot as plt
import pprint
import os
import re
import pyproj

from sklearn.ensemble import RandomForestRegressor
# from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
from statsmodels.formula.api import ols
from sklearn.preprocessing import LabelEncoder

from osgeo import gdal
import warnings
warnings.filterwarnings("ignore")
%matplotlib inline


In [None]:
# Read the comma-separated sampling table (read_csv splits on ',' by default)
df = pd.read_csv('F:/projection_data.csv')

# Environmental variables used as predictors, in model column order
predictor_columns = [
    'band1', 'band2', 'band3', 'band4',
    'band10', 'band11', 'band12', 'band15', 'band18', 'band19',
    'slope', 'gpp',
]
x = df[predictor_columns]

# Richness of fungal saprotrophs: selected by position (column index 7)
# and kept as a one-column DataFrame
saprotrophs = df.iloc[:, [7]]

In [None]:
# Open the raster layers of the environmental variables for the current scenario
# NOTE(review): xr.open_rasterio is deprecated in newer xarray releases -- confirm
# the installed version still provides it.
datapath = "F:/current/"

# Climate layers under clim/, opened in the listed order
(band1, band2, band3, band4, band10,
 band11, band12, band15, band18, band19) = (
    xr.open_rasterio(datapath + 'clim/' + layer_name + '.tif')
    for layer_name in ('band1', 'band2', 'band3', 'band4', 'band10',
                       'band11', 'band12', 'band15', 'band18', 'band19')
)

# Topography and vegetation layers
slope = xr.open_rasterio(datapath + 'topo/slope.tif')
gpp = xr.open_rasterio(datapath + 'vege/gpp.tif')

In [None]:
# Rebind every layer name to the raw array held in its .values attribute
(band1, band2, band3, band4, band10, band11,
 band12, band15, band18, band19, slope, gpp) = (
    layer.values
    for layer in (band1, band2, band3, band4, band10, band11,
                  band12, band15, band18, band19, slope, gpp)
)

In [None]:
# Flatten each layer to 1-D and stack them depth-wise, one layer per position
# on the last axis (same order as the predictor columns of x); squeeze() then
# removes the length-1 axes that dstack adds.
modeldata = np.dstack([
    layer.ravel()
    for layer in (band1, band2, band3, band4, band10, band11,
                  band12, band15, band18, band19, slope, gpp)
]).squeeze()

In [None]:
# Split the samples into training (80%) and test (20%) sets with a fixed seed,
# then fit a 500-tree random forest regression on the training part.
# train_test_split is imported here because its file-level import is commented
# out, which made the call below fail with a NameError.
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    x, saprotrophs,
    test_size=0.2,
    random_state=42,
)

model = RandomForestRegressor(n_estimators=500, n_jobs=-1)
# The response is a one-column table; pass it flattened as a 1-D target
model.fit(x_train, np.ravel(y_train))

In [None]:
# Assess the accuracy (R2) of the random forest regression on both splits.
# The scores are printed explicitly: a notebook cell only echoes a trailing
# bare expression, so the two values were previously computed and discarded.
r2_test = r2_score(y_test, model.predict(x_test))
r2_train = r2_score(y_train, model.predict(x_train))
print("R2 (test):  {:.3f}".format(r2_test))
print("R2 (train): {:.3f}".format(r2_train))

# Predict richness for every row of the stacked raster predictors
rf = model.predict(modeldata)

In [None]:
# Plot the Gini importance of the environmental variables, sorted ascending
features = np.array(x_train.columns)
imps_gini = model.feature_importances_
# Spread of each variable's importance across the individual trees
std_gini = np.std([tree.feature_importances_ for tree in model.estimators_], axis=0)
indices_gini = np.argsort(imps_gini)
bar_positions = range(len(indices_gini))

plt.title('Feature Importance')
# The bars are horizontal, so importance runs along x: the spread belongs in
# xerr (yerr would draw the error bars along the category axis instead).
plt.barh(bar_positions, imps_gini[indices_gini], xerr=std_gini[indices_gini],
         color='c', align='center')
plt.yticks(bar_positions, features[indices_gini])
plt.xlabel('Gini Importance')
# NOTE(review): the file name starts with a space and spells "saprotrohps";
# left unchanged in case other tooling expects this exact path -- confirm.
plt.savefig('F:/ gini_saprotrohps.pdf',dpi=300)
plt.show()

In [None]:
# Define the export information
def writeTiff(im_data,im_geotrans,im_proj,path):
    """Write a 2-D or 3-D array to a GeoTIFF file with GDAL.

    Parameters
    ----------
    im_data : numpy.ndarray
        Pixel values, either ``(height, width)`` for a single band or
        ``(bands, height, width)`` for several bands.
    im_geotrans
        Geotransform handed to ``SetGeoTransform`` of the new dataset.
    im_proj
        Projection handed to ``SetProjection`` of the new dataset.
    path : str
        Path of the GeoTIFF file to create.

    Raises
    ------
    ValueError
        If ``im_data`` is neither 2-D nor 3-D.
    RuntimeError
        If GDAL returns no dataset for ``path`` (for example because the
        target directory does not exist).
    """
    # Validate the layout first so a bad array fails before any file is touched
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]  # treat a single image as one band
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got %d dimension(s)" % im_data.ndim
        )
    im_bands, im_height, im_width = im_data.shape

    # Match on the exact dtype name: a substring test such as
    # "'int16' in name" also catches 'uint16' and stored signed data as unsigned.
    dtype_name = im_data.dtype.name
    if dtype_name == 'uint8':
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int8':
        # GDT_Byte is unsigned; widen signed bytes so negative values survive
        im_data = im_data.astype(np.int16)
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    else:
        # Every other dtype is written as 32-bit float
        datatype = gdal.GDT_Float32

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Fail with a clear message instead of an AttributeError on None below
        raise RuntimeError("GDAL could not create output file: %s" % path)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    # GDAL band numbers start at 1
    for band_number, band_values in enumerate(im_data, start=1):
        dataset.GetRasterBand(band_number).WriteArray(band_values)
    # Dropping the last reference lets GDAL flush and close the file
    del dataset

In [None]:
# Export the predicted richness of fungal saprotrophs under the current scenario.
# A reference raster supplies the grid size, geotransform and projection.
# NOTE(review): this rebinds band2 (an array in the earlier cells) to a path
# string; the name is kept in case later cells rely on it -- confirm.
band2=r'F:/band2.tif'
in_ds=gdal.Open(band2)
if in_ds is None:
    # gdal.Open can return None instead of raising when the file cannot be read
    raise IOError('Could not open reference raster: ' + band2)

tif_width=in_ds.RasterXSize
tif_height=in_ds.RasterYSize
tif_geotrans=in_ds.GetGeoTransform()
tif_proj=in_ds.GetProjection()
# Read the reference raster; only its shape is used below
output_data=in_ds.ReadAsArray(0,0,tif_width,tif_height)

savepath='F:/saprotrophs/current/'
# The GeoTIFF cannot be created inside a missing folder, so make sure it exists
os.makedirs(savepath, exist_ok=True)

# Fold the flat prediction vector back onto the reference grid and write it out
rf=rf.reshape(output_data.shape)
writeTiff(rf,tif_geotrans,tif_proj,savepath+"current"+".tif")