In [None]:
import os

import pywt
import numpy as np
import scipy.signal.spectral as sss
import tensorflow as tf
from PIL import Image
# import seaborn as sns
import matplotlib.pyplot as plt


def batch2image(signal_batch, sampling_rate, pooling_size, process_type='wavelet'):
    """
    Convert a batch of 1-D signals into down-sampled time-frequency images.

    :param signal_batch: 2-D array-like; each row is one signal
    :param sampling_rate: sampling rate forwarded to the chosen transform
    :param pooling_size: pooling size forwarded to image_downsampling
    :param process_type: 'wavelet' (wavelet2image) or 'stft' (stft2image)
    :return: result of average-pooling the stacked per-row images
    :raises KeyError: if process_type is neither 'wavelet' nor 'stft'
    """
    # Unpacking the shape into two names enforces a 2-D batch.
    num, _ = np.shape(signal_batch)

    # Choose the per-signal transform once instead of repeating the loop.
    if process_type == 'wavelet':
        to_image = wavelet2image
    elif process_type == 'stft':
        to_image = stft2image
    else:
        raise KeyError("process_type must be wavelet of stft!")

    # Each transform returns (image, frequencies); only the image is kept.
    images = np.array([
        to_image(signal_batch[row, :], sampling_rate)[0]
        for row in range(num)
    ])

    return image_downsampling(images, pooling_size, form='avg_pooling')

def wavelet2image(signal, sampling_rate, freq_dim_scale=256, wavelet_name='morl'):
    """
    Build a wavelet time-frequency image with the continuous wavelet transform.

    :param signal: 1D temporal sequence
    :param sampling_rate: sampling rate of the sequence; its reciprocal is
                          passed to pywt.cwt as the sampling period
    :param freq_dim_scale: number of scales evaluated (rows of the image)
    :param wavelet_name: wavelet name for CWT, e.g. 'morl', 'gaus', 'cmor', ...
    :return: (magnitude of the CWT coefficient matrix, frequencies returned by pywt.cwt)
    """
    # Central frequency of the selected wavelet.
    central = pywt.central_frequency(wavelet_name)

    # Scales are (2 * central * freq_dim_scale) / k for k = 1 .. freq_dim_scale.
    divisors = np.arange(1, freq_dim_scale + 1, 1)
    scales = (2 * central * freq_dim_scale) / divisors

    sampling_period = 1.0 / sampling_rate
    coefficients, frequencies = pywt.cwt(signal, scales, wavelet_name, sampling_period)

    return abs(coefficients), frequencies

def stft2image(signal, sampling_rate, freq_dim_scale=256, window_name=('gaussian', 3.0)):

    """
    Build a short-time Fourier transform (STFT) magnitude image.

    :param signal: signal input for stft
    :param sampling_rate: sampling rate, passed to stft as ``fs``
    :param freq_dim_scale: segment length, passed to stft as ``nperseg``
    :param window_name: (gaussian,3), hann, hamming, etc.

    Notes
    -----
    Window types:

        `boxcar`, `triang`, `blackman`, `hamming`, `hann`, `bartlett`,
        `flattop`, `parzen`, `bohman`, `blackmanharris`, `nuttall`,
        `barthann`, `kaiser` (needs beta), `gaussian` (needs standard
        deviation), `general_gaussian` (needs power, width), `slepian`
        (needs width), `dpss` (needs normalized half-bandwidth),
        `chebwin` (needs attenuation), `exponential` (needs decay scale),
        `tukey` (needs taper fraction)

    :return: time-freq image (real, non-negative magnitude of the STFT)
             and its frequencies
    """

    # Imported from the public scipy.signal namespace; the alias is needed
    # because the parameter `signal` shadows the usual module name.
    from scipy.signal import stft as _stft

    f, _, Zxx = _stft(signal, fs=sampling_rate, window=window_name, nperseg=freq_dim_scale)

    # Zxx is complex. Return its magnitude so the image is real-valued, like
    # the one from wavelet2image; a later float32 cast would otherwise drop
    # the imaginary part.
    return np.abs(Zxx), f

def image_downsampling(image_set, pooling_size=2, form='max_pooling', axis=None):

    """
    Down-sample a batch of images with non-overlapping max or average pooling.

    Implemented with NumPy, so no TensorFlow graph or session is created.
    Windows do not overlap and trailing rows/columns that do not fill a whole
    window are dropped (the behaviour of 'VALID' padding).

    :param image_set: batch of images, shape (num, rows, cols) or
                      (num, rows, cols, 1); values are cast to float32
    :param pooling_size: down-sampling rate (integer >= 1)
    :param form: 'avg_pooling' selects average pooling; any other value
                 (including the default 'max_pooling') selects max pooling
    :param axis: None pools with a (pooling_size, 2*pooling_size) window;
                 axis=0 pools rows only with a (pooling_size, 1) window;
                 axis=1 pools columns only with a (1, pooling_size) window
    :return: float32 array of shape (num, rows // kh, cols // kw, 1)
    :raises ValueError: if pooling_size < 1 or an image is smaller than the window
    """

    if pooling_size < 1:
        raise ValueError("pooling_size must be >= 1, got %r" % (pooling_size,))

    images = np.asarray(image_set, dtype=np.float32)
    num, rows, cols = images.shape[0], images.shape[1], images.shape[2]
    # Collapses an optional trailing channel axis of size 1.
    images = images.reshape(num, rows, cols)

    if axis == 0:
        kernel_h, kernel_w = pooling_size, 1
    elif axis == 1:
        kernel_h, kernel_w = 1, pooling_size
    else:
        kernel_h, kernel_w = pooling_size, 2 * pooling_size

    out_h, out_w = rows // kernel_h, cols // kernel_w
    if out_h == 0 or out_w == 0:
        raise ValueError(
            "image of size (%d, %d) is smaller than the pooling window (%d, %d)"
            % (rows, cols, kernel_h, kernel_w)
        )

    # Trim the remainder, then expose each window as its own pair of axes.
    windows = images[:, :out_h * kernel_h, :out_w * kernel_w].reshape(
        num, out_h, kernel_h, out_w, kernel_w
    )

    if form == 'avg_pooling':
        pooled = windows.mean(axis=(2, 4))
    else:
        pooled = windows.max(axis=(2, 4))

    # Restore the trailing channel axis that callers receive.
    return pooled.reshape(num, out_h, out_w, 1)

def get_batch(filename, window_size=512, batch_size=1000, stride=180):
    """
    Load a signal from a text file and cut it into fixed-length windows.

    :param filename: path read with np.loadtxt
    :param window_size: number of samples per window
    :param batch_size: maximum number of windows returned
    :param stride: consecutive windows start ``stride + 1`` samples apart
    :return: array stacking the extracted windows (one window per row)
    """
    data = np.loadtxt(filename)
    print(data.shape)

    hop = stride + 1
    start = 0
    batch_data = []
    # `<=` keeps the last window when it ends exactly on the final sample.
    while start + window_size <= data.shape[0] and len(batch_data) < batch_size:
        batch_data.append(data[start: start + window_size])
        start += hop

    return np.array(batch_data)

# Convert every raw signal file directly inside ./datas/origin into a batch of
# wavelet images and cache it as ./datas/image/<label>.npy.
# The directory is listed explicitly rather than via os.walk, so that
# subdirectories (for example .ipynb_checkpoints) cannot change which files
# are processed.
root = "./datas/origin"
files = sorted(
    name for name in os.listdir(root)
    if os.path.isfile(os.path.join(root, name))
)
print(files)

output_path = './datas/image'
if not os.path.exists(output_path):
    os.makedirs(output_path)

for file in files:
    print("processing %s" % file)
    # Label: first '_'-separated field for files whose name contains "normal",
    # otherwise the fourth field; anything after the first '.' is stripped.
    c = file.split('_')[0] if "normal" in file else file.split('_')[3]
    label = c.split('.')[0]
    print(label)
    file_path = os.path.join(root, file)
    print(file_path)
    signal = get_batch(filename=file_path, batch_size=2000)
    batch_image = batch2image(signal, sampling_rate=1, pooling_size=4)
    print("saving %s/%s.npy, shape %s" % (output_path, label, batch_image.shape))
    np.save("%s/%s.npy" % (output_path, label), batch_image)

In [None]:
import numpy as np
import keras
import matplotlib.pyplot as plt
%matplotlib inline
from keras.utils import plot_model
from keras.utils import np_utils

In [None]:
def creat_dataset1(select):
    """
    Assemble images and labels from the cached IF / OF / BF image arrays.

    :param select: iterable of (name, label) pairs. The name is only printed.
                   Label 0 loads IF.npy, label 1 loads OF.npy, and any other
                   label loads BF.npy.
    :return: (x_merge, y_merge) - the loaded arrays stacked along the first
             axis, and a column of labels with one row per image
    """
    images = []
    labels = []
    for elem in select:
        print(elem[0])
        label = elem[1]
        path = (
            "./datas/image/IF.npy" if label == 0
            else "./datas/image/OF.npy" if label == 1
            else "./datas/image/BF.npy"
        )
        x = np.load(path)
        print(x.shape)
        images.append(x)
        # One [label] row for every image in this array.
        labels.append([[label]] * x.shape[0])
    # Stack along the first (vertical) axis.
    return np.vstack(images), np.vstack(labels)
def creat_dataset2(select):
    """
    Assemble images and labels from the cached IO / IB / OB / IOB image arrays.

    :param select: iterable of (name, label) pairs. The name is only printed.
                   Label 0 loads IO.npy, label 1 loads IB.npy, label 2 loads
                   OB.npy, and any other label loads IOB.npy.
    :return: (x_merge, y_merge) - the loaded arrays stacked along the first
             axis, and a column of labels with one row per image
    """
    images = []
    labels = []
    for elem in select:
        print(elem[0])
        label = elem[1]
        path = (
            "./datas/image/IO.npy" if label == 0
            else "./datas/image/IB.npy" if label == 1
            else "./datas/image/OB.npy" if label == 2
            else "./datas/image/IOB.npy"
        )
        x = np.load(path)
        print(x.shape)
        images.append(x)
        # One [label] row for every image in this array.
        labels.append([[label]] * x.shape[0])
    return np.vstack(images), np.vstack(labels)

In [None]:
# (name, label) pairs; the name is only printed, the label selects which
# cached image array creat_dataset1 / creat_dataset2 loads.
data_train = [('Inner',0), ('Outter', 1), ('Ball', 2)]
data_test  = [('IO',0), ('IB', 1), ('OB', 2),('IOB', 3)]

x_train, y_train = creat_dataset1(select=data_train)
x_test, y_test = creat_dataset2(select=data_test)

# Save under ./datas/... with the name x_train, which is the location the
# following cell loads from. np.save does not create directories.
for split_dir in ('./datas/train', './datas/test'):
    if not os.path.exists(split_dir):
        os.makedirs(split_dir)

np.save('./datas/train/x_train', x_train)
np.save('./datas/train/y_train', y_train)
np.save('./datas/test/x_test', x_test)
np.save('./datas/test/y_test', y_test)

In [None]:
# Draw one random index from each consecutive block of 2000 samples.
# NOTE(review): assumes every class contributes exactly 2000 images, in the
# order the datasets were stacked - confirm against the saved arrays.
idx1 = np.random.randint(0, 2000, 1)
idx2 = np.random.randint(2000, 4000, 1)
idx3 = np.random.randint(4000, 6000, 1)
idx4 = np.random.randint(6000, 8000, 1)

x_train = np.load("./datas/train/x_train.npy")
x_test = np.load("./datas/test/x_test.npy")

# Each idx* is a length-1 array, so the selection keeps a leading axis of
# size 1; reshape collapses it to a plain 64x64 image.
IF = x_train[idx1].reshape(64, 64)
OF = x_train[idx2].reshape(64, 64)
BF = x_train[idx3].reshape(64, 64)

IO = x_test[idx1].reshape(64, 64)
IB = x_test[idx2].reshape(64, 64)
OB = x_test[idx3].reshape(64, 64)
IOB = x_test[idx4].reshape(64, 64)

In [None]:
import matplotlib.pyplot as pyplot
# Show the sampled IF image in grayscale.
plt.imshow(IF, cmap='gray')

In [None]:
import matplotlib.pyplot as pyplot
# Show the sampled OF image in grayscale.
plt.imshow(OF, cmap='gray')

In [None]:
# Show the sampled BF image in grayscale.
plt.imshow(BF, cmap='gray')

In [None]:
# Show the sampled IO image in grayscale.
plt.imshow(IO, cmap='gray')

In [None]:
# Show the sampled IB image in grayscale.
plt.imshow(IB, cmap='gray')

In [None]:
# Show the sampled OB image in grayscale.
plt.imshow(OB, cmap='gray')

In [None]:
# Show the sampled IOB image in grayscale.
plt.imshow(IOB, cmap='gray')