In [None]:
import numpy as np
import scipy.io as sio
import random
import os

In [None]:
class HSIDataset:
    """Per-pixel hyperspectral dataset with a seeded, per-class train/test split.

    Only pixels whose label is in ``1..nc`` are used (``nc`` is the maximum label);
    label 0 is therefore never sampled. For each class ``c``,
    ``int(count * test_ratio)`` pixels are drawn at random as test samples and the
    remaining pixels become training samples.

    Attributes:
        data: min-max normalized float32 copy of the input cube, shape (m, n, bands).
        label: label map cast to ``int``.
        m, n, bands: dimensions of ``data``.
        nc: number of classes (maximum label value).
        data_pos: dict class -> list of (row, col) positions in row-major order.
        train_pos, test_pos: dict class -> list of (row, col) positions per split.
        X_train, y_train, X_test, y_test: lists of spectra (``data[i, j]``) and their
            class labels, grouped by class in ascending order.
    """

    def __init__(self, data, label, test_ratio=0.5, seed=2333):
        self.test_ratio = test_ratio
        self.seed = seed
        self.X_train, self.y_train, self.X_test, self.y_test = [], [], [], []

        # Always work on a float32 copy so normalization never mutates the caller's array
        # (previously a float32 input was modified in place).
        data = np.array(data, dtype=np.float32)
        label = np.asarray(label)
        if label.dtype != int:
            label = label.astype(int)
        # Keep the cast labels so class positions and ``nc`` come from the same array.
        self.label = label

        self.m, self.n, self.bands = data.shape
        self.nc = int(np.max(label))
        self.data = self.normalize(data)

        self.get_data_corr()
        self.trainTestSplit()

        for c in range(1, self.nc + 1):
            for i, j in self.train_pos[c]:
                self.X_train.append(self.data[i, j])
                self.y_train.append(c)

        for c in range(1, self.nc + 1):
            for i, j in self.test_pos[c]:
                self.X_test.append(self.data[i, j])
                self.y_test.append(c)

    def normalize(self, data):
        """Return ``data`` min-max scaled to [0, 1] as a new array (input untouched).

        A constant input (max == min) maps to all zeros instead of NaN.
        """
        shifted = data - np.min(data)
        value_range = np.max(shifted)
        if value_range > 0:
            shifted /= value_range
        return shifted

    def get_data_corr(self):
        """Collect, for each class 1..nc, the (row, col) positions of its pixels.

        Positions are listed in row-major order; the seeded split in
        ``trainTestSplit`` indexes into these lists, so the order matters for
        reproducibility.
        """
        # Restrict to the data's spatial extent, matching the original m x n scan.
        label_view = self.label[:self.m, :self.n]
        self.data_pos = {}
        for c in range(1, self.nc + 1):
            # np.argwhere returns indices in row-major (C) order.
            self.data_pos[c] = [(int(i), int(j)) for i, j in np.argwhere(label_view == c)]

    def trainTestSplit(self):
        """Randomly split each class's positions into ``train_pos`` and ``test_pos``.

        Seeds Python's global ``random`` module with ``self.seed`` so the split is
        reproducible; ``int(count * test_ratio)`` positions per class go to test.
        """
        random.seed(self.seed)
        self.train_pos = {}
        self.test_pos = {}
        for c in range(1, self.nc + 1):
            positions = self.data_pos[c]
            total_num = len(positions)
            test_num = int(total_num * self.test_ratio)
            # A set gives O(1) membership checks (a list made this O(n^2) per class).
            test_ind = set(random.sample(range(total_num), test_num))
            self.train_pos[c] = [p for k, p in enumerate(positions) if k not in test_ind]
            self.test_pos[c] = [p for k, p in enumerate(positions) if k in test_ind]

In [None]:
def get_confusion_matrix(clf, data, label, nc):
    """Build an ``nc x nc`` confusion matrix for ``clf`` on labelled samples.

    Rows index the predicted class and columns the true class, with 1-based class
    labels mapped to 0-based indices: ``cm[pred - 1, true - 1]`` holds the count.

    Args:
        clf: classifier exposing ``predict(X)`` for a 2-D ``X`` of shape
            (num_samples, num_features), returning one label per row.
        data: sequence of 1-D feature vectors (or a 2-D array of them).
        label: true class labels in ``1..nc``, one per sample.
        nc: number of classes.

    Returns:
        float ndarray of shape (nc, nc) with sample counts.
    """
    cm = np.zeros((nc, nc))
    if len(data) == 0:
        return cm
    # One batched predict call instead of one call per sample.
    samples = np.asarray(data).reshape(len(data), -1)
    pred = np.ravel(clf.predict(samples)).astype(int)
    true = np.ravel(np.asarray(label)).astype(int)
    # np.add.at accumulates correctly when the same (pred, true) pair repeats.
    np.add.at(cm, (pred - 1, true - 1), 1)
    return cm

In [None]:
def kappa(cm):
    """Cohen's kappa coefficient of a square confusion matrix.

    kappa = (p_o - p_e) / (1 - p_e), where p_o is the observed agreement
    (trace / total) and p_e is the chance agreement, the sum over classes of
    row_total * column_total divided by total squared.
    """
    total = np.sum(cm)
    size = cm.shape[0]
    row_totals = [np.sum(cm[r]) for r in range(size)]
    col_totals = [np.sum(cm[:, r]) for r in range(size)]
    chance = sum(a * b for a, b in zip(row_totals, col_totals)) / total ** 2
    observed = np.trace(cm) / total
    return (observed - chance) / (1 - chance)

In [None]:
def report(cm, save_path):
    """Print per-class accuracy, OA, AA and kappa for ``cm`` and save them to .mat.

    ``cm`` uses rows = predicted class and columns = true class, so the accuracy of
    class ``i`` is ``cm[i, i] / cm[:, i].sum()``. AA is the mean of those per-class
    accuracies; OA is ``trace(cm) / sum(cm)``.

    Results are written to ``result/result.mat`` under keys ``oa``, ``aa``,
    ``kappa`` and ``acc_list``; the ``result`` directory is created if missing.

    NOTE(review): ``save_path`` is accepted but unused; the output location is
    hard-coded to ``result/``. It is kept for backward compatibility, so confirm intent.
    """
    total_right = np.trace(cm)
    total_test = np.sum(cm)
    overall_acc = total_right / total_test
    kap = kappa(cm)
    acc_list = []
    for i in range(len(cm)):
        class_total = np.sum(cm[:, i])
        acc = cm[i, i] / class_total
        acc_list.append(acc)
        print('class', i+1, ':(', cm[i, i], '/', class_total, ')', acc)
    ave_acc = np.mean(acc_list)
    print('confusion matrix:')
    print(np.int_(cm))
    print('total right num:')
    print(total_right)
    print('total test num:')
    print(total_test)
    print('overal acc:')
    print(overall_acc)
    print('average acc:')
    print(ave_acc)
    print('kappa:')
    print(kap)
    os.makedirs('result', exist_ok=True)
    # Save the computed kappa value (previously the ``kappa`` function object was saved).
    sio.savemat(os.path.join('result', 'result.mat'),
                {'oa': overall_acc, 'aa': ave_acc, 'kappa': kap, 'acc_list': acc_list})

In [None]:
def get_pred_map(clf, data, label):
    """Classify every labelled pixel and save the prediction map as a PNG.

    Pixels with ``label == 0`` stay 0 (not predicted). Every other pixel is set to
    ``clf.predict`` of its spectrum ``data[i, j]``. The map has the same dtype as
    ``label``. It is drawn with ``pcolor`` (jet colormap, no axes) and saved to
    ``result/pred_map.png`` at 600 dpi; the ``result`` directory is created if
    missing.

    Args:
        clf: classifier exposing ``predict(X)`` for a 2-D ``X`` (one row per sample).
        data: array-like indexed as ``data[i, j]`` -> spectrum of pixel (i, j);
            presumably shape (height, width, bands).
        label: 2-D label map; 0 marks pixels that are skipped.

    NOTE(review): relies on ``plt`` (matplotlib.pyplot) being in scope; the imports
    cell shown does not import it.
    """
    label = np.asarray(label)
    pred_map = np.zeros_like(label)
    mask = label != 0
    if mask.any():
        # One batched predict call over all labelled pixels instead of one per pixel.
        pixels = np.asarray(data)[mask]
        pred_map[mask] = np.ravel(clf.predict(pixels))
    height, width = label.shape
    fig, ax = plt.subplots()
    # 1 inch per 100 pixels of the label map.
    fig.set_size_inches(width / 100.0, height / 100.0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    fig.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    ax.axis('off')
    ax.axis('equal')
    ax.pcolor(pred_map, cmap='jet')
    os.makedirs('result', exist_ok=True)
    fig.savefig(os.path.join('result', 'pred_map.png'), format='png', dpi=600)
    plt.close(fig)
    print('decode map get finished')

In [None]:
def get_gt_map(label):
    """Render the ground-truth label map and save it as ``result/gt_map.png``.

    The map is drawn with ``pcolor`` (jet colormap, no axes). The figure is sized
    at 1 inch per 100 pixels and saved at 600 dpi. The ``result`` directory is
    created if missing.

    Args:
        label: 2-D label map (must expose ``.shape`` as (height, width)).

    NOTE(review): relies on ``plt`` (matplotlib.pyplot) being in scope; the imports
    cell shown does not import it.
    """
    height, width = label.shape
    fig, ax = plt.subplots()
    fig.set_size_inches(width / 100.0, height / 100.0)
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())
    fig.subplots_adjust(top=1, bottom=0, left=0, right=1, hspace=0, wspace=0)
    ax.axis('off')
    ax.axis('equal')
    ax.pcolor(label, cmap='jet')
    os.makedirs('result', exist_ok=True)
    fig.savefig(os.path.join('result', 'gt_map.png'), format='png', dpi=600)
    # Close this specific figure rather than whichever figure is "current".
    plt.close(fig)
    print('gt map get finished')