In [18]:
import scipy.io as sio
import numpy as np
import os
from Hyperspectral.patchsize import patch_size
from random import shuffle

**1. 加载数据**

In [19]:
# Project root holding the `resources/data` folder. The original absolute path
# stays the default so the existing setup keeps working; set the
# INDIAN_PINES_ROOT environment variable to run the notebook from another machine.
project_root = os.environ.get("INDIAN_PINES_ROOT", "E:\\pythonProject\\DeepLearning")
# Join the components separately so the separator is correct on every OS
# (on Windows this yields exactly the same string as before).
data_path = os.path.join(project_root, 'resources', 'data')

# Hyperspectral cube and its per-pixel ground-truth label map.
input_data = sio.loadmat(os.path.join(data_path, 'Indian_pines.mat'))['indian_pines']
out_data = sio.loadmat(os.path.join(data_path, 'Indian_pines_gt.mat'))['indian_pines_gt']
print("data:", input_data.shape, out_data.shape, np.min(out_data), np.max(out_data))

data: (145, 145, 220) (145, 145) 0 16


In [20]:
# Spatial extent and number of spectral bands, read off the cube's first three axes.
height, width, band = input_data.shape[:3]
# Number of labelled classes (ground-truth value 0 is skipped when patches are collected).
classes = 16

In [21]:
# Display the patch edge length imported from Hyperspectral.patchsize.
patch_size

11

**2. 图像归一化**

In [22]:

# Min-max scale the whole cube (all bands together) into [0, 1].
input_data = input_data.astype(float)
lowest = np.min(input_data)
input_data -= lowest
highest = np.max(input_data)
input_data /= highest

print("normalization data: [",np.min(input_data),'-',np.max(input_data),']')

normalization data: [ 0.0 - 1.0 ]


**3. 图像去均值**

In [23]:
# Mean of every spectral band taken over all pixels, one entry per band.
mean_array = np.array(
    [np.mean(input_data[:, :, band_index]) for band_index in range(band)],
    dtype=float,
)

print("mean_value:", len(mean_array))

# Next: zero-center the data with these means and split the image into patches.

mean_value: 220


In [24]:
def patch(i, j, data=None, size=None, means=None):
    """Cut one mean-centred spatial window out of the hyperspectral cube.

    Parameters
    ----------
    i, j : int
        Row and column of the window's top-left corner.
    data : ndarray of shape (height, width, bands), optional
        Cube to cut from. Defaults to the notebook-level ``input_data``.
    size : int, optional
        Window edge length. Defaults to the notebook-level ``patch_size``.
    means : array-like of shape (bands,), optional
        Per-band means to subtract. Defaults to the notebook-level ``mean_array``.

    Returns
    -------
    ndarray of shape (bands, size, size)
        Channels-first window with each band's mean removed. A window that
        runs past the image border is silently truncated by the slicing.
    """
    data = input_data if data is None else data
    size = patch_size if size is None else size
    means = mean_array if means is None else means

    # Channels-first view of the window; transpose and basic slicing copy nothing.
    window = np.transpose(data, (2, 0, 1))[:, i:i + size, j:j + size]
    # One broadcast subtraction replaces the per-band Python loop (which also
    # reused the name of parameter `i`); the result is a new array.
    return window - np.asarray(means, dtype=float)[:, None, None]


**4. 返回每个类的patch**

In [25]:

# One bucket of patches per class; bucket k collects ground-truth label k + 1.
ClASSES = [[] for _ in range(classes)]
# Offset from a patch's top-left corner to its centre pixel.
centre_offset = int((patch_size - 1) / 2)
for row in range(height - patch_size + 1):
    for col in range(width - patch_size + 1):
        # A patch is labelled by the ground truth of its centre pixel.
        centre_label = out_data[row + centre_offset, col + centre_offset]
        # Label 0 carries no class, so such patches are not collected.
        if centre_label == 0:
            continue
        ClASSES[centre_label - 1].append(patch(row, col))

In [26]:
# Report how many patches were collected for each class, in class order.
for class_index in range(len(ClASSES)):
    print(len(ClASSES[class_index]))

46
1428
685
221
423
730
28
478
20
924
2350
561
205
1265
265
93


**5. test 数据和 train 数据的划分**


In [27]:
# Fraction of every class that is held out for testing.
test_split_size=0.25
train_patch=[]   # one list of training patches per class
test_patch=[]    # flat list of held-out patches, class after class
test_label=[]    # zero-based class index for every entry of test_patch

for c in range(classes):
    class_patches = ClASSES[c]
    test_frac = int(len(class_patches) * test_split_size)
    # Split at an explicit index. The former `[:-test_frac]` / `[-test_frac:]`
    # slices misbehave when test_frac is 0: `[:-0]` is empty and `[-0:]` is the
    # whole list, so a tiny class lost all training data and produced test
    # patches without labels.
    split_at = len(class_patches) - test_frac
    train_patch.append(class_patches[:split_at])
    test_patch.extend(class_patches[split_at:])
    test_label.extend(np.full(test_frac, c, dtype=int))

# Training patches left per class after the hold-out.
for c in range(classes):
    print(len(train_patch[c]))

35
1071
514
166
318
548
21
359
15
693
1763
421
154
949
199
70


In [28]:
# Sanity check: there must be exactly one label per held-out patch.
test_sizes = (len(test_patch), len(test_label))
print(*test_sizes)

2426 2426


**6. Oversample the train_data so that every class is balanced**

In [29]:
# Every class ends up with exactly COUNT training patches.
COUNT=200
for c in range(classes):
    samples = train_patch[c]
    if len(samples) == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("class %d has no training patches to oversample" % c)
    if len(samples) < COUNT:
        # Repeat the class often enough to reach at least COUNT entries
        # (same number of copies as before: COUNT // n + 1).
        samples = samples * (COUNT // len(samples) + 1)
    # A single shuffle of the pool followed by truncation draws the final COUNT
    # patches; the intermediate shuffles of the old loop did not change that.
    # NOTE: `shuffle` uses the global `random` state, which is not seeded in the
    # visible cells, so the selection differs between runs.
    shuffle(samples)
    train_patch[c] = samples[:COUNT]


# Flatten (classes, COUNT, ...) into one sample axis; use `band` rather than a
# hard-coded 220 so the cell follows the loaded cube.
train_patch=np.reshape(np.asarray(train_patch),(-1,band,patch_size,patch_size))

**7. 建立train_label**

In [30]:
# Training patches are stored class after class with COUNT samples each, so the
# labels are every class index repeated COUNT times. Building them in one call
# keeps the integer dtype (appending to the float `np.array([])` silently turned
# the labels into float64) and avoids reallocating the array on every iteration.
train_labels = np.repeat(np.arange(classes, dtype=int), COUNT)

print("train_data:",train_patch.shape,"train_label：",train_labels.shape)

train_data: (3200, 220, 11, 11) train_label： (3200,)


**8. 保存训练数据和test数据**


In [31]:
# Write both splits as numbered .mat files named "<split>_<patch_size>_<n>.mat",
# each holding COUNT*2 consecutive samples under "<split>_data" / "<split>_label".
# The test files used to be filled from train_patch / train_labels; they now
# receive the held-out test_patch / test_label.
# NOTE(review): only full chunks are written, as before; samples beyond the last
# full chunk are dropped. Confirm that this is intended for the test split.
samples_per_file = COUNT * 2
for split_name, split_data, split_labels in (
    ("train", train_patch, train_labels),
    ("test", test_patch, test_label),
):
    for chunk_index in range(len(split_data) // samples_per_file):
        start = chunk_index * samples_per_file
        end = start + samples_per_file
        filename = split_name + "_" + str(patch_size) + '_' + str(chunk_index + 1) + ".mat"
        data_dict = {
            # np.asarray turns the list slices of the test split into arrays.
            split_name + '_data': np.asarray(split_data[start:end]),
            split_name + '_label': np.asarray(split_labels[start:end]),
        }
        sio.savemat(os.path.join(data_path, filename), data_dict)

**9. 试加载训练数据**

In [32]:
# Reload the first chunk of each split to check what was written. The file names
# are built from patch_size, like the names used when saving, instead of a
# hard-coded "11".
train_data = sio.loadmat(os.path.join(data_path, "train_" + str(patch_size) + "_1.mat"))
test_data = sio.loadmat(os.path.join(data_path, "test_" + str(patch_size) + "_1.mat"))
print(train_data['train_data'].shape,train_data['train_label'].shape)
print(test_data['test_data'].shape)


(400, 220, 11, 11) (1, 400)
(400, 220, 11, 11)
