In [20]:
import numpy as np
import scipy.io as sio
import random
import os

# 一、导入数据加数据处理阶段

In [21]:
batch_size = 32
train_num = 200
seed = 666
fix_seed = False
savedata = True

In [22]:
# 是否用随机种子
if fix_seed:
    random.seed(seed)

In [23]:
# 数据标准化
def max_min(x):
    return (x-np.min(x))/(np.max(x)-np.min(x))

In [24]:
# 导入数据
def loadData():
    data_dict = sio.loadmat(r"E:\Eric_HSI\hyperspectral_datasets\paviaU.mat")
    data_gt_dict = sio.loadmat(r"E:\Eric_HSI\hyperspectral_datasets\paviaU_gt.mat")
    # startswith 检查字符串是否以 "————" 开头, 取出数据集
    data_name = [t for t in list(data_dict.keys()) if not t.startswith('__')][0]
    data_gt_name = [t for t in list(data_gt_dict.keys()) if not t.startswith('__')][0]
    data = data_dict[data_name]
    data_gt = data_gt_dict[data_gt_name].astype(np.int32)
    # 标准化
    data = max_min(data).astype(np.float32)
    class_num = np.max(data_gt)
    print('DataSet %s shape is %s class_num is %s'%(data_name,data.shape,class_num))
    return data, data_gt

In [25]:
# data, data_gt = loadData()

In [26]:
# 数据分类,将各类别放到以各类别名为键的字典中,这里只保存的坐标位置，并不是像素值

def split_background(removeZeroLabels=False):
    # 这个是未分类版本
    class_num = np.max(data_gt)
    data_pos = {i: [] for i in range(0, 1)}
    print(data_pos)

    for i in range(data_gt.shape[0]):
        for j in range(data_gt.shape[1]):
            if removeZeroLabels:
                if data_gt[i, j]:
                    data_pos[0].append([i, j])
            else:
                data_pos[0].append([i, j])

    return data_pos

In [27]:
# data_pos = split_background(removeZeroLabels=False)

In [28]:
# 取出测试集和训练集,此过程后训练集和测试集为字典，和上面构造data_pos类似

def split_train_test(train_num=200):
    class_num = np.max(data_gt)

    data_pos = {i: [] for i in range(1, class_num + 1)}
    print(data_pos)

    train_pos = {i: [] for i in range(1, class_num + 1)}
    test_pos = {i: [] for i in range(1, class_num + 1)}

    for i in range(data_gt.shape[0]):
        for j in range(data_gt.shape[1]):
            for k in range(1, class_num + 1):
                if data_gt[i, j] == k:
                    data_pos[k].append([i, j])

    for k, v in data_pos.items():
        if len(v)<train_num:
            train_seclect = 15
        else:
            train_seclect = train_num
        train_pos[k] = random.sample(v, int(train_seclect))
        test_pos[k] = [i for i in v if i not in train_pos[k]]
    return train_pos, test_pos

In [29]:
# train_pos, test_pos = split_train_test(train_num=train_num)

In [30]:
# 辅助函数，查看各类别信息
def classinfo(data_pos, train_pos, test_pos):
    data_num = 0
    train_num = 0
    test_num = 0
    train_test_num = 0
    for (k1,v1),(k2,v2) in zip(train_pos.items(), test_pos.items()):
        print('traindata-ID %s: %s; testdata-ID %s: %s'%(k1,len(v1),k2,len(v2)))
        train_num += len(v1)
        test_num += len(v2)
    train_test_num = train_num + test_num
    print('total train %s, total test %s, train_test_num %s'%(train_num, test_num, train_test_num))

    for k,v in data_pos.items():
        data_num += len(v)
    print('total data %s'%data_num)

    return data_num, train_num, test_num, train_test_num

In [31]:
# data_num, train_num, test_num, train_test_num = classinfo(data_pos, train_pos, test_pos)

In [32]:
# 将位置信息转化光谱值信息，即将坐标值转化为其对应的光谱值，此过程得到一个为列表
# 准确的说是列表内嵌套 ndarray 的结构，ndarray有 1 维，个数103
def dict_to_list():
    data_all = []
    data_label_all = []
    train = []
    train_label = []
    test = []
    test_label = []
    for i in range(len(data_pos)):
        for j in range(len(data_pos[i])):
            row,col = data_pos[i][j]
            data_all.append(data[row,col])
            data_label_all.append(i)


    for i in range(1,len(train_pos)+1):   # 9个类
        for j in range(len(train_pos[i])):   # 200个样本
            row,col = train_pos[i][j]
            train.append(data[row,col])   #### 一下子传入103维 ####
            train_label.append(i)
            
    for i in range(1,len(test_pos)+1):
        for j in range(len(test_pos[i])):
            row,col = test_pos[i][j]
            test.append(data[row,col])
            test_label.append(i)
    return data_all, data_label_all, train, train_label, test, test_label

In [33]:
# data_all, data_label_all, train, train_label, test, test_label = dict_to_list()

In [34]:
def savePreprocessedData(path, data, data_label, train, train_label, test, test_label):
    data_path = os.path.join(os.getcwd(), path)
    print(data_path)

    if savedata:
        with open(os.path.join(data_path, 'data.npy'), 'bw') as outfile:
            np.save(outfile, data_all)
        with open(os.path.join(data_path, 'data_label.npy'), 'bw') as outfile:
            np.save(outfile, data_label_all)  

        with open(os.path.join(data_path, 'train.npy'), 'bw') as outfile:
            np.save(outfile, train)
        with open(os.path.join(data_path, 'train_label.npy'), 'bw') as outfile:
            np.save(outfile, train_label)
        
        with open(os.path.join(data_path, 'test.npy'), 'bw') as outfile:
            np.save(outfile, test)
        with open(os.path.join(data_path, 'test_label.npy'), 'bw') as outfile:
            np.save(outfile, test_label)

# 调用以上函数

In [35]:
data, data_gt = loadData()
data_pos = split_background(removeZeroLabels=False)
train_pos, test_pos = split_train_test(train_num=train_num)
data_all, data_label_all, train, train_label, test, test_label = dict_to_list()
# data_num, train_num, test_num, train_test_num = classinfo(data_pos, train_pos, test_pos)

DataSet data shape is (610, 340, 103) class_num is 9
{0: []}
{1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}


### np.array & np.asarray
- array和asarray都可以将结构数据转化为ndarray，
- 但是主要区别就是当数据源是ndarray时，array仍然会copy出一个副本，占用新的内存，但asarray不会。

In [36]:
# last step: 化为ndarray!
data_all = np.asarray(data_all)
data_label_all = np.asarray(data_label_all)
train = np.asarray(train)
train_label = np.asarray(train_label)
test = np.asarray(test)
test_label = np.asarray(test_label)

In [37]:
data_all.shape, data_label_all.shape, train.shape, test.shape, train_label.shape, test_label.shape

((207400, 103), (207400,), (1800, 103), (40976, 103), (1800,), (40976,))

# 保存数据

In [38]:
if not os.path.exists('paviaU_num_'+ str(train_num) +'_for_squence'):
    os.mkdir('paviaU_num_'+ str(train_num) +'_for_squence')
savePreprocessedData('paviaU_num_'+ str(train_num) +'_for_squence', data_all, data_label_all, train, train_label, test, test_label)

e:\Eric_HSI\hyper_data_preprocess\paviaU_num_200_for_squence
