In [None]:
'''
Re-implementation of the paper "HybridSN: Exploring 3-D–2-D CNN Feature Hierarchy for Hyperspectral Image Classification".
The official implementation is available at https://github.com/gokriznastic/HybridSN
'''

############ IMPORTS ####################

import sys
sys.path.append("./../")
import os
import numpy as np
import torch.utils.data as dataf
from scipy import io
from sklearn.decomposition import PCA
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix, accuracy_score, cohen_kappa_score
import torch
from operator import truediv
import record
import tensorflow.keras as keras
import keras.backend as K
from keras.layers import Conv2D, Conv3D, Flatten, Dense, Reshape, BatchNormalization
from keras.layers import Dropout, Input
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
from keras.losses import categorical_crossentropy
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical as keras_to_categorical


############# CONFIGS ##########################

# tf.config.experimental_run_functions_eagerly(True)
# os.environ["CUDA_VISIBLE_DEVICES"]="3"

# Dataset name prefixes to process; the main loop below runs once per entry
# and uses each name both for the input folder and the output folder.
datasetNames = ["Trento"]
# Not referenced anywhere in the visible part of this file.
testSizeNumber = 5000
# First two dimensions of the model's Input shape in get_model_compiled.
patchsize1 = 11
patchsize2 = 11
# Mini-batch size and number of epochs passed to clf.fit.
batchsize = 64
EPOCH = 200
# Learning rate handed to the Adam optimizer in get_model_compiled.
LR = 0.001

def AA_andEachClassAccuracy(confusion_matrix):
    """Compute per-class accuracy and average accuracy from a confusion matrix.

    Each class accuracy is the diagonal entry divided by the sum of its row
    (rows are expected to hold the true classes, as produced by
    ``sklearn.metrics.confusion_matrix(y_true, y_pred)``).

    Args:
        confusion_matrix: Square 2-D array-like of counts.

    Returns:
        Tuple ``(each_acc, average_acc)``: the per-class accuracies as a
        1-D array and their mean, both expressed as percentages.
    """
    per_class_hits = np.diag(confusion_matrix)
    row_totals = np.sum(confusion_matrix, axis=1)
    # A class with no samples gives 0/0 -> NaN; nan_to_num maps that to 0.
    # Silence the accompanying RuntimeWarning since the case is handled.
    with np.errstate(divide='ignore', invalid='ignore'):
        each_acc = np.nan_to_num(truediv(per_class_hits, row_totals))
    average_acc = np.mean(each_acc)
    return each_acc*100, average_acc*100

def get_model_compiled(NC, Classes):
    """Build and compile the HybridSN 3-D/2-D CNN classifier.

    The network applies three 3-D convolutions, merges the last two axes of
    the 3-D feature map so it can be fed to a 2-D convolution, and finishes
    with two dropout-regularised dense layers and a softmax output.

    Args:
        NC: Size of the third axis of each input patch (the caller passes the
            number of spectral bands).
        Classes: Number of units in the softmax output layer.

    Returns:
        A Keras ``Model`` taking inputs of shape
        ``(patchsize1, patchsize2, NC, 1)``, compiled with categorical
        cross-entropy loss, the Adam optimizer and the ``accuracy`` metric.
    """
    ## input layer
    input_layer = Input((patchsize1, patchsize2, NC, 1))

    ## convolutional layers
    # NOTE(review): no padding is requested, so (assuming Keras' default
    # 'valid' padding) each layer shrinks the volume; NC and the patch size
    # must be large enough for the stacked kernels -- confirm for new data.
    conv_layer1 = Conv3D(filters=8, kernel_size=(3, 3, 7), activation='relu')(input_layer)
    conv_layer2 = Conv3D(filters=16, kernel_size=(3, 3, 5), activation='relu')(conv_layer1)
    conv_layer3 = Conv3D(filters=32, kernel_size=(3, 3, 3), activation='relu')(conv_layer2)

    # Fold the remaining depth axis and the filter axis into one channel axis
    # so the result is a valid input for Conv2D.
    conv3d_shape = conv_layer3.shape
    conv_layer3 = Reshape((conv3d_shape[1], conv3d_shape[2], conv3d_shape[3]*conv3d_shape[4]))(conv_layer3)
    conv_layer4 = Conv2D(filters=64, kernel_size=(3,3), activation='relu')(conv_layer3)

    flatten_layer = Flatten()(conv_layer4)

    ## fully connected layers
    dense_layer1 = Dense(units=256, activation='relu')(flatten_layer)
    dense_layer1 = Dropout(0.4)(dense_layer1)
    dense_layer2 = Dense(units=128, activation='relu')(dense_layer1)
    dense_layer2 = Dropout(0.4)(dense_layer2)
    output_layer = Dense(units=Classes, activation='softmax')(dense_layer2)

    clf = Model(inputs=input_layer, outputs=output_layer)

    # `learning_rate` is the supported argument name; `lr` is a deprecated
    # alias that newer Keras releases warn about or reject.
    adam = Adam(learning_rate=LR, decay=1e-06)
    clf.compile(loss='categorical_crossentropy', optimizer=adam, metrics=['accuracy'])
    return clf


# Number of independent train/evaluate repetitions per dataset; the metrics
# of every repetition are collected and handed to record.record_output.
NUM_RUNS = 3

for datasetName in datasetNames:
    # Output directory for the best-model checkpoint and the final report.
    os.makedirs(datasetName, exist_ok=True)

    print("----------------------------------Training for ",datasetName,"---------------------------------------------")

    # NOTE(review): the folder name hard-codes "11x11"; presumably it has to
    # match patchsize1/patchsize2 -- confirm before changing either.
    data_dir = './../' + datasetName + '11x11/'

    # Train data
    HSI = io.loadmat(data_dir + 'HSI_Tr.mat')
    TrainPatch = HSI['Data']
    TrainPatch = TrainPatch.astype(np.float32)
    NC = TrainPatch.shape[3] # NC is number of bands

    label = io.loadmat(data_dir + 'TrLabel.mat')
    TrLabel = label['Data']

    # Test data
    HSI = io.loadmat(data_dir + 'HSI_Te.mat')
    TestPatch = HSI['Data']
    TestPatch = TestPatch.astype(np.float32)

    label = io.loadmat(data_dir + 'TeLabel.mat')
    TsLabel = label['Data']

    # Labels are shifted down by one so that they can be used as class
    # indices for one-hot encoding.
    TrainPatch1 = torch.from_numpy(TrainPatch)
    TrainLabel1 = torch.from_numpy(TrLabel)-1
    TrainLabel1 = TrainLabel1.long()

    TestPatch1 = torch.from_numpy(TestPatch)
    TestLabel1 = torch.from_numpy(TsLabel)-1
    TestLabel1 = TestLabel1.long()

    Classes = len(np.unique(TrainLabel1))

    print("Train data shape = ", TrainPatch1.shape)
    print("Train label shape = ", TrainLabel1.shape)
    print("Test data shape = ", TestPatch1.shape)
    print("Test label shape = ", TestLabel1.shape)
    print("Num classes = ", Classes)

    # The network inputs/targets do not change between runs, so convert them
    # once here instead of on every repetition. A trailing singleton axis is
    # appended to match the model's Input shape.
    train_x = TrainPatch1.unsqueeze(-1).cpu().detach().numpy()
    train_y = keras_to_categorical(TrainLabel1.reshape(-1).cpu().detach().numpy(), Classes)
    test_x = TestPatch1.unsqueeze(-1).cpu().detach().numpy()
    y_test = TestLabel1.reshape(-1).cpu().detach().numpy()
    # NOTE(review): the test set doubles as validation data, so the
    # checkpoint that is later evaluated was selected on the very data it is
    # scored on. Kept as in the original; consider a separate validation split.
    valdata = (test_x, keras_to_categorical(y_test, Classes))
    checkpoint_path = datasetName+"/best_model_HSIOnly.h5"

    KAPPA = []
    OA = []
    AA = []
    ELEMENT_ACC = np.zeros((NUM_RUNS, Classes))

    # TF1-compat session/graph setup; statement order is kept exactly as in
    # the original because Keras backend state depends on it.
    # NOTE(review): device_count={'GPU': 0} presumably restricts TF to the
    # CPU, and the session is created before graph `g` -- confirm both are
    # intended.
    tf.compat.v1.keras.backend.clear_session()
    config = tf.compat.v1.ConfigProto( device_count = {'GPU': 0} )
    config.gpu_options.allow_growth = True
    sess = tf.compat.v1.Session(config=config)
    tf.compat.v1.keras.backend.set_session(sess)
    g = tf.Graph()
    with g.as_default():
        for run_idx in range(NUM_RUNS):
            #             tf.compat.v1.set_random_seed(43)
            #             np.random.seed(43)

            clf = get_model_compiled(NC,Classes)
            clf.fit(train_x, train_y,
                    batch_size=batchsize,
                    epochs=EPOCH,
                    verbose=True,
                    validation_data=valdata,
                    callbacks = [ModelCheckpoint(checkpoint_path, monitor='val_accuracy', verbose=0, save_best_only=True)])

            # Evaluate the best checkpoint of this run, not the last epoch.
            clf = load_model(checkpoint_path)
            pred_y = np.argmax(clf.predict(test_x), axis=1)

            oa = accuracy_score(y_test, pred_y)*100
            # Pin the label set so the matrix is always Classes x Classes,
            # even if some class is absent from both y_test and pred_y;
            # otherwise the ELEMENT_ACC row assignment below would fail.
            confusion = confusion_matrix(y_test, pred_y, labels=np.arange(Classes))
            each_acc, aa = AA_andEachClassAccuracy(confusion)
            kappa = cohen_kappa_score(y_test, pred_y)*100
            KAPPA.append(kappa)
            OA.append(oa)
            AA.append(aa)
            ELEMENT_ACC[run_idx, :] = each_acc

    print("--------" + datasetName + " Training Finished-----------")
    record.record_output(OA, AA, KAPPA, ELEMENT_ACC,'./' + datasetName +'/HybridSN_Report_' + datasetName +'.txt')