In [1]:
import numpy  as np
import torch

In [28]:
class Masked_dataset():

    """
    生成 masked dataset,
    我们输入的原始光谱是 (number_data, seq_len), 其中seq len就是光谱向量的长度
    我们目前考虑在完整的长度为3321个sampling points组成的2950~3150的光谱范围内
    随机选择两个端点， 截取端点内的光谱，作为训练数据，
    
    """

    def __init__(self, if_fixed_window_length):

        self.if_fixed_window_length = if_fixed_window_length

    def generate_mask(self, selected_window_length, total_sequence_length):

        assert selected_window_length < total_sequence_length, "the selected window length should be lower than the total sequence length"
        mask = np.zeros(total_sequence_length)
        
        if self.if_fixed_window_length:
            
            start_point = np.random.randint(0, total_sequence_length-selected_window_length)
            end_point = start_point + selected_window_length

            mask[: selected_window_length] = 1
            """
            这里的mask是一个长度为seq len的向量，因为我们目前是截取之后从头放在完整的光谱内，所以从0到截止位置，都是真实的光谱，
            截止位置往后开始，是padding补的0，mask从这里开始全部是0，因此将来在使用mask时候，不会因为光谱本身的0而被分配一个-inf的
            数值在计算attention的时候
            """
        else:
            start_point = 0
            end_point = 0
            # 使用循环确保起点和终点的差大于等于 200
            # 这里200是一个保险，太小了光谱范围太短了，而且通常也不会这么小
            while abs(start_point - end_point) < 200:
                start_point = np.random.randint(0, total_sequence_length)
                end_point = np.random.randint(0, total_sequence_length)
                
            mask[: end_point-start_point] = 1
            
        return mask, start_point, end_point

    def apply_mask(self, dataset, selected_window_length):

        number_data = dataset.shape[0]
        total_data_length = dataset.shape[1]
        checkpoints = np.zeros((number_data, 2))
        mask_list = np.zeros((number_data, total_data_length))
        masked_dataset = np.zeros((number_data, total_data_length))
        
        for i in range(number_data):
            mask, start, end = self.generate_mask(selected_window_length, total_data_length)
            mask_list[i] = mask
            masked_dataset[i, :end-start] = dataset[i, start: end]
            checkpoints[i, 0] = start
            checkpoints[i, 1] = end

        """
        mask list就是所有的mask
        chekcpoint记录了在原施光谱中截取的端点的索引，有了这个，将来可以通过checkpoints对应到Nu上了
        self.dataset 就是保存截取之后，并且被padding 0 了的新的数据集，每一行是一个截取的光谱，并且label与原施数据集的label对应

        """
        return mask_list, checkpoints, masked_dataset

In [29]:
a = np.random.randn(10, 10)
print(a.shape)

(10, 10)


In [30]:
mask_dataset = Masked_dataset(False)
mask_list, checkpoints, dataset = mask_dataset.apply_mask(a, 5)

In [31]:
print(a)

[[-0.10131752 -1.80570545 -0.94228085 -1.89822749 -0.03223352 -0.92105166
  -0.18840933  0.01064409  0.43751709  0.94069841]
 [ 0.55009984 -0.56001683  0.60601304 -1.05222666 -0.1087534  -1.75418838
   0.54089968 -0.63365475  0.07446843  0.06098322]
 [ 0.16445849 -0.62098471 -0.86967252 -0.60899879 -0.01606063 -0.35421297
  -1.23207985  1.67941567 -2.35239399  1.77483578]
 [ 0.15136193 -0.13105176 -1.13044838 -0.41982815 -1.4110316   1.69085568
  -1.11750465  0.04438985  0.55317002 -0.34173056]
 [-0.37253958 -1.69588698 -0.40981865 -1.56280347 -0.83977588 -0.99886077
  -0.14713719  3.06756589  0.51910039  0.20905589]
 [-0.63015401 -1.38024705 -1.45493027 -1.27426766  0.23388429  1.00399153
   0.40752747 -0.03053307 -1.82687426  0.15567778]
 [ 0.12074594 -2.59121186  0.47140079 -1.10862341 -1.76158511 -2.16326841
  -0.47158371  0.39573743  0.76062327 -0.01834627]
 [ 2.17181355 -0.14243467 -0.2308977   1.95159368 -0.38842306 -0.84359483
   0.45483994  0.22661657 -0.01153107  0.76791286]


In [32]:
print(checkpoints)

[[8. 8.]
 [2. 7.]
 [0. 7.]
 [4. 5.]
 [0. 6.]
 [1. 2.]
 [1. 9.]
 [3. 6.]
 [0. 1.]
 [3. 6.]]


In [33]:
print(mask_list)

[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 1. 1. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0. 0. 0. 0. 0.]]
