In [1]:
import os
import random

import numpy as np
import scipy.io as sio
import tensorflow as tf

In [2]:
# Enable on-demand GPU memory allocation so TensorFlow does not reserve all
# device memory up front.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth must be configured before the GPU is initialised;
        # re-running this cell after TensorFlow has used the GPU raises
        # RuntimeError, which should not abort a Run-All.
        print(err)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


# 定义一些常量

In [3]:
# --- Patch / training configuration ---
window_size = 9     # side length of the extracted data cube (window_size x window_size)
batch_size = 32
train_num = 200     # nominal training samples per class; NOTE: later rebound to the total training count
seed = 666
fix_seed = False    # when True, the seed cell seeds the RNGs with `seed`
savedata = True     # when False, savePreprocessedData skips writing .npy files

# --- Dataset file names (Pavia University, Salinas) ---
name1, name1_gt = "paviaU.mat", "paviaU_gt.mat"
name2, name2_gt = "Salinas_corrected.mat", "Salinas_gt.mat"

# 定义数据处理所需要的函数

In [4]:
# Data normalisation
def max_min(x):
    """Min-max scale ``x`` into [0, 1] using its global minimum and maximum.

    Integer input is promoted to float64 before any arithmetic so that
    ``x - min`` and ``max - min`` cannot overflow the integer dtype; floating
    input keeps its dtype. A constant array (max == min) yields all zeros
    instead of dividing by zero (which would produce NaN).
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)
    x_min = np.min(x)
    x_range = np.max(x) - x_min
    if x_range == 0:
        return np.zeros_like(x)
    return (x - x_min) / x_range

In [5]:
def loadData(data_path=r"E:\Eric_HSI\hyperspectral_datasets\Salinas_corrected.mat",
             gt_path=r"E:\Eric_HSI\hyperspectral_datasets\Salinas_gt.mat"):
    """Load a hyperspectral image and its ground-truth label map from .mat files.

    Parameters
    ----------
    data_path, gt_path : str
        Paths to the image and ground-truth .mat files. The defaults are the
        original hard-coded Salinas locations, so ``loadData()`` is unchanged;
        pass other paths to load a different dataset.

    Returns
    -------
    data : np.ndarray (float32), min-max normalised with ``max_min``.
    data_gt : np.ndarray (int32) label map.
    """
    data_dict = sio.loadmat(data_path)
    data_gt_dict = sio.loadmat(gt_path)
    # loadmat also returns metadata entries whose keys start with '__'
    # (header/version/globals); the first remaining key is taken as the array.
    data_name = [t for t in data_dict.keys() if not t.startswith('__')][0]
    data_gt_name = [t for t in data_gt_dict.keys() if not t.startswith('__')][0]
    data = data_dict[data_name]
    data_gt = data_gt_dict[data_gt_name].astype(np.int32)
    # Normalise the whole cube to [0, 1].
    data = max_min(data).astype(np.float32)
    class_num = np.max(data_gt)
    print('DataSet %s shape is %s class_num is %s'%(data_name,data.shape,class_num))
    return data, data_gt

In [6]:
# data, data_gt = loadData()

In [7]:
# Optionally fix the random number generators the pipeline uses:
# Python's `random` drives the train/test split (random.sample) and
# TensorFlow drives Dataset.shuffle. numpy is seeded too, for any later
# cells that may use np.random.
if fix_seed:
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

In [8]:
# Group pixel positions into a dict; only (row, col) coordinates are stored,
# not pixel values.
def split_background(removeZeroLabels=False, gt=None):
    """Return ``{0: [[row, col], ...]}`` listing pixel positions in row-major order.

    This is the class-agnostic version: every selected pixel goes under key 0.

    Parameters
    ----------
    removeZeroLabels : bool
        If True, skip pixels whose label is 0 (background).
    gt : array-like, optional
        2-D label map; defaults to the notebook-level ``data_gt``.
    """
    if gt is None:
        gt = data_gt
    gt = np.asarray(gt)

    data_pos = {0: []}
    print(data_pos)

    if removeZeroLabels:
        mask = gt != 0
    else:
        mask = np.ones(gt.shape, dtype=bool)
    # argwhere returns indices in row-major (C) order, i.e. the same order as
    # scanning rows then columns; tolist() yields plain Python ints.
    data_pos[0] = np.argwhere(mask).tolist()
    return data_pos

In [9]:
# data_pos = split_background(removeZeroLabels=False)

In [10]:
# data_pos = split_background(removeZeroLabels=True)

In [11]:
# Split labelled pixels into per-class train/test dicts keyed by class id.
# Only (row, col) coordinates are stored, not pixel values.
# Background (label 0) is excluded: classes run from 1 to max(label).
def split_train_test(train_num=200, gt=None, small_class_num=15):
    """Randomly split each class's pixel positions into train and test lists.

    Parameters
    ----------
    train_num : int
        Training samples drawn per class when the class has at least this many.
    gt : array-like, optional
        2-D label map; defaults to the notebook-level ``data_gt``.
    small_class_num : int
        Samples drawn from a class smaller than ``train_num`` (capped at the
        class size so ``random.sample`` cannot fail).

    Returns
    -------
    (train_pos, test_pos) : dicts mapping class id -> list of [row, col].
    Uses the module-level ``random`` generator.
    """
    if gt is None:
        gt = data_gt
    gt = np.asarray(gt)
    class_num = int(np.max(gt))

    data_pos = {k: [] for k in range(1, class_num + 1)}
    print(data_pos)
    for k in data_pos:
        # Row-major order, identical to scanning rows then columns.
        data_pos[k] = np.argwhere(gt == k).tolist()

    train_pos = {}
    test_pos = {}
    for k, v in data_pos.items():
        if len(v) < train_num:
            n_select = min(small_class_num, len(v))
        else:
            n_select = train_num
        train_pos[k] = random.sample(v, int(n_select))
        # Set lookup instead of list membership: O(n) rather than O(n * m).
        chosen = {tuple(p) for p in train_pos[k]}
        test_pos[k] = [p for p in v if tuple(p) not in chosen]
    return train_pos, test_pos

In [12]:
# train_pos, test_pos = split_train_test(train_num=200)

In [13]:
# Convert the position dicts to flat lists; the entries still carry the dict's
# information as [class_id, [row, col]].
def dict_to_list(data_pos_dict=None, train_pos_dict=None, test_pos_dict=None):
    """Flatten ``{class_id: [[row, col], ...]}`` dicts into ``[[class_id, [row, col]], ...]``.

    Each argument defaults to the notebook-level dict with the same role
    (``data_pos``, ``train_pos``, ``test_pos``), so the zero-argument call
    behaves as before. Output order follows dict iteration order, then list
    order within each class.

    Returns
    -------
    (data_pos_all, train_pos_all, test_pos_all)
    """
    if data_pos_dict is None:
        data_pos_dict = data_pos
    if train_pos_dict is None:
        train_pos_dict = train_pos
    if test_pos_dict is None:
        test_pos_dict = test_pos

    def _flatten(pos_dict):
        # One [class_id, position] pair per stored position.
        return [[k, t] for k, v in pos_dict.items() for t in v]

    return _flatten(data_pos_dict), _flatten(train_pos_dict), _flatten(test_pos_dict)

In [14]:
# data_pos_all, train_pos_all, test_pos_all = dict_to_list()

In [15]:
# Helper: print per-class train/test sample counts.
def classinfo(train_pos, test_pos, data_pos):
    """Print per-class train/test counts and totals.

    ``train_pos`` and ``test_pos`` are walked in parallel (paired by iteration
    order). Returns ``(data_num, train_num, test_num, train_test_num)`` where
    ``data_num`` is the total number of entries across ``data_pos``.
    """
    train_num = test_num = 0
    paired = zip(train_pos.items(), test_pos.items())
    for (train_id, train_list), (test_id, test_list) in paired:
        print('traindata-ID %s: %s; testdata-ID %s: %s' % (train_id, len(train_list), test_id, len(test_list)))
        train_num += len(train_list)
        test_num += len(test_list)

    train_test_num = train_num + test_num
    print('total train %s, total test %s, train_test_num %s' % (train_num, test_num, train_test_num))

    data_num = sum(len(positions) for positions in data_pos.values())
    print('total data %s' % data_num)

    return data_num, train_num, test_num, train_test_num

In [16]:
# data_num, train_num, test_num, train_test_num  = classinfo(train_pos, test_pos, data_pos)

# 第一次调用一部分函数：到创建 patch 之前

In [17]:
# First pass: load data and build position lists, up to (not including) patch extraction.
data, data_gt = loadData()
# removeZeroLabels=False keeps every pixel (background included) under key 0.
data_pos = split_background(removeZeroLabels=False)
# NOTE: passes the literal 200 rather than the `train_num` config constant.
train_pos, test_pos = split_train_test(train_num=200)
data_pos_all, train_pos_all, test_pos_all = dict_to_list()
# NOTE: this rebinds the config name `train_num` (per-class count) to the total
# number of training samples; later cells use train_num/test_num as totals, so
# re-running earlier cells that read `train_num` after this one sees the total.
data_num, train_num, test_num, train_test_num  = classinfo(train_pos, test_pos, data_pos)

DataSet salinas_corrected shape is (512, 217, 204) class_num is 16
{0: []}
{1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: [], 10: [], 11: [], 12: [], 13: [], 14: [], 15: [], 16: []}
traindata-ID 1: 200; testdata-ID 1: 1809
traindata-ID 2: 200; testdata-ID 2: 3526
traindata-ID 3: 200; testdata-ID 3: 1776
traindata-ID 4: 200; testdata-ID 4: 1194
traindata-ID 5: 200; testdata-ID 5: 2478
traindata-ID 6: 200; testdata-ID 6: 3759
traindata-ID 7: 200; testdata-ID 7: 3379
traindata-ID 8: 200; testdata-ID 8: 11071
traindata-ID 9: 200; testdata-ID 9: 6003
traindata-ID 10: 200; testdata-ID 10: 3078
traindata-ID 11: 200; testdata-ID 11: 868
traindata-ID 12: 200; testdata-ID 12: 1727
traindata-ID 13: 200; testdata-ID 13: 716
traindata-ID 14: 200; testdata-ID 14: 870
traindata-ID 15: 200; testdata-ID 15: 7068
traindata-ID 16: 200; testdata-ID 16: 1607
total train 3200, total test 50929, train_test_num 54129
total data 111104


In [18]:
# Preallocate the arrays that the create_*_all functions fill in place.
# Allocate with the target dtype directly: np.zeros defaults to float64, so
# np.zeros(...).astype(np.float32) would transiently hold a float64 array twice
# the size of the result plus the float32 copy. For data_all in the recorded run
# (111104 x 9 x 9 x 204) that is ~14.7 GB of float64 on top of ~7.3 GB float32.
data_all = np.zeros((data_num, window_size, window_size, data.shape[2]), dtype=np.float32)
data_label_all = np.zeros(data_num, dtype=np.int32)

train_all = np.zeros((train_num, window_size, window_size, data.shape[2]), dtype=np.float32)
train_label_all = np.zeros(train_num, dtype=np.int32)

test_all = np.zeros((test_num, window_size, window_size, data.shape[2]), dtype=np.float32)
test_label_all = np.zeros(test_num, dtype=np.int32)

In [19]:
# data_all.shape, data_label_all.shape, train_all.shape, train_label_all.shape, test_all.shape, test_label_all.shape

In [20]:
# Patch extraction (create_patch = neighbor_add)
def neighbor_add(row, col, window_size=3, img=None):
    """Return the ``window_size x window_size`` patch centred on ``(row, col)``.

    Parameters
    ----------
    row, col : int
        Centre pixel position.
    window_size : int
        Odd patch side length.
    img : np.ndarray, optional
        Image indexed as ``img[row, col]``; defaults to the notebook-level ``data``.

    Returns
    -------
    np.ndarray (float64) of shape ``(window_size, window_size) + img.shape[2:]``.
    Neighbours that fall outside the image are filled with the centre pixel's values.

    Raises
    ------
    ValueError
        If ``window_size`` is not a positive odd integer (an even window has no
        centre pixel and previously failed with an IndexError).
    """
    if img is None:
        img = data
    if window_size < 1 or window_size % 2 == 0:
        raise ValueError('window_size must be a positive odd integer, got %r' % (window_size,))

    t = window_size // 2
    offsets = np.arange(-t, t + 1)
    rows = row + offsets
    cols = col + offsets
    inside = (((rows >= 0) & (rows < img.shape[0]))[:, None]
              & ((cols >= 0) & (cols < img.shape[1]))[None, :])

    # Clip so every index is valid, gather the whole window at once, then
    # overwrite the cells that lie outside the image with the centre pixel.
    rows_c = np.clip(rows, 0, img.shape[0] - 1)
    cols_c = np.clip(cols, 0, img.shape[1] - 1)
    patch = img[np.ix_(rows_c, cols_c)].astype(np.float64)
    patch[~inside] = img[row, col]
    return patch

In [21]:
# Build patches for every position in data_pos_all.
def create_data_all():
    """Fill the preallocated ``data_all`` / ``data_label_all`` arrays in place.

    Walks ``data_pos_all`` (entries ``[class_id, [row, col]]``), extracts a
    ``window_size`` patch around each position with ``neighbor_add``, stores it
    as float32, and stores ``class_id - 1`` as the label (e.g. class id 0 -> -1).
    Returns the two mutated arrays.
    """
    for idx, (class_id, (r, c)) in enumerate(data_pos_all):
        patch = neighbor_add(r, c, window_size=window_size)
        data_all[idx] = patch.astype(np.float32)
        # Shift labels down by one (class 1 -> 0).
        data_label_all[idx] = np.int32(class_id - 1)
    return data_all, data_label_all

In [22]:
# Build patches for every position in train_pos_all.
def create_train_all():
    """Fill the preallocated ``train_all`` / ``train_label_all`` arrays in place.

    Walks ``train_pos_all`` (entries ``[class_id, [row, col]]``), extracts a
    ``window_size`` patch around each position with ``neighbor_add``, stores it
    as float32, and stores ``class_id - 1`` as the label. Returns the two
    mutated arrays.
    """
    for idx, (class_id, (r, c)) in enumerate(train_pos_all):
        patch = neighbor_add(r, c, window_size=window_size)
        train_all[idx] = patch.astype(np.float32)
        # Shift labels down by one (class 1 -> 0).
        train_label_all[idx] = np.int32(class_id - 1)
    return train_all, train_label_all

In [23]:
# Build patches for every position in test_pos_all.
def create_test_all():
    """Fill the preallocated ``test_all`` / ``test_label_all`` arrays in place.

    Walks ``test_pos_all`` (entries ``[class_id, [row, col]]``), extracts a
    ``window_size`` patch around each position with ``neighbor_add``, stores it
    as float32, and stores ``class_id - 1`` as the label. Returns the two
    mutated arrays.
    """
    for idx, (class_id, (r, c)) in enumerate(test_pos_all):
        patch = neighbor_add(r, c, window_size=window_size)
        test_all[idx] = patch.astype(np.float32)
        # Shift labels down by one (class 1 -> 0).
        test_label_all[idx] = np.int32(class_id - 1)
    return test_all, test_label_all

In [24]:
# data_all, data_label_all = create_data_all()
# train_all, train_label_all = create_train_all()
# test_all, test_label_all = create_test_all()

In [25]:
# data_all.shape, data_label_all.shape, train_all.shape, train_label_all.shape, test_all.shape, test_label_all.shape

In [26]:
def savePreprocessedData(path, data, data_label, train, train_label, test, test_label, save=None):
    """Write the given arrays as .npy files into ``path``.

    ``path`` is resolved relative to the current working directory and created
    if missing. The six array arguments are written to data.npy,
    data_label.npy, train.npy, train_label.npy, test.npy and test_label.npy.
    (Previously the arguments were ignored and the notebook globals were
    written instead.)

    Parameters
    ----------
    save : bool, optional
        Whether to write anything; defaults to the notebook-level ``savedata`` flag.
    """
    data_path = os.path.join(os.getcwd(), path)
    print(data_path)

    if save is None:
        save = savedata
    if not save:
        return

    os.makedirs(data_path, exist_ok=True)
    outputs = {
        'data.npy': data,
        'data_label.npy': data_label,
        'train.npy': train,
        'train_label.npy': train_label,
        'test.npy': test,
        'test_label.npy': test_label,
    }
    for file_name, array in outputs.items():
        with open(os.path.join(data_path, file_name), 'wb') as outfile:
            np.save(outfile, array)

# 第二次调用这些函数：数据预处理结束并保存数据

In [27]:
# Second pass: extract a window_size x window_size patch around every position
# and fill the preallocated arrays (stored labels are class id - 1).
data_all, data_label_all = create_data_all()
train_all, train_label_all = create_train_all()
test_all, test_label_all = create_test_all()

In [28]:
# Display the shapes of all six preprocessed arrays as a sanity check.
data_all.shape, data_label_all.shape, train_all.shape, train_label_all.shape, test_all.shape, test_label_all.shape

((111104, 9, 9, 204),
 (111104,),
 (3200, 9, 9, 204),
 (3200,),
 (50929, 9, 9, 204),
 (50929,))

In [30]:
# Output folder name encodes the dataset, patch size and per-class training count.
save_dir = 'Salinas_w_size_' + str(window_size) + '_num_200_for_2D'
# makedirs(exist_ok=True) is idempotent, so re-running this cell is safe.
os.makedirs(save_dir, exist_ok=True)
savePreprocessedData(save_dir, data_all, data_label_all, train_all, train_label_all, test_all, test_label_all)

e:\Eric_HSI\hyper_data_preprocess\Salinas_w_size_9_num_200_for_2D


# 创建 dataset

In [31]:
# Build tf.data pipelines from the in-memory arrays.
db_train = tf.data.Dataset.from_tensor_slices((train_all, train_label_all))
db_test = tf.data.Dataset.from_tensor_slices((test_all, test_label_all))

# No .repeat(): the custom training loop iterates over epochs itself.
# train_all is filled class by class, so the shuffle buffer must span the whole
# training set for a uniform shuffle. Use its actual length rather than
# `train_num`, whose meaning depends on which cells have already run.
db_train = db_train.shuffle(len(train_all)).batch(batch_size=batch_size)
db_test = db_test.batch(batch_size=batch_size)

In [32]:
# Inspect the element specs (shapes/dtypes) of the batched datasets.
db_train, db_test

(<BatchDataset shapes: ((None, 9, 9, 204), (None,)), types: (tf.float32, tf.int32)>,
 <BatchDataset shapes: ((None, 9, 9, 204), (None,)), types: (tf.float32, tf.int32)>)