In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile as wav
from scipy.fftpack import fft, fftfreq
from pydub import AudioSegment
import matplotlib
from pydub.playback import play

In [4]:
# Raw pydub handle for clip 1_7 (Analyzer below loads its own copy of the file).
audio = AudioSegment.from_wav(
    'Sound_Files/All_C_1/1_7.wav'
)

In [5]:
class Analyzer:
    """Load a WAV file and compute/plot FFT magnitudes over equal-length splits.

    The file is read twice: with pydub into ``self.audio`` (used for playback
    and millisecond-based slicing) and with ``scipy.io.wavfile`` into
    ``self.data`` (the raw sample array that the FFTs are computed on).
    """

    def __init__(self, file_path, source='Sound_Files/All_C_1/'):
        """Load ``source + file_path``.

        Args:
            file_path: file name relative to ``source``; also used in plot titles.
            source: directory prefix, concatenated directly (keep the trailing '/').
        """
        path = source + file_path
        self.audio = AudioSegment.from_wav(path)
        self.file_name = file_path
        self.SAMPLE_RATE, self.data = wav.read(path)
        # pydub reports len() in milliseconds, so this is seconds.
        self.DURATION = len(self.audio) / 1000
        # Results of the most recent stft() call.
        self.xf, self.stft_out = None, None

    def play_file(self):
        """Return the pydub AudioSegment.

        NOTE(review): nothing is played here; presumably this relies on the
        notebook rendering the returned segment when it is a cell's last
        expression.
        """
        return self.audio

    def split_audio(self, k=1):
        """Split ``self.audio`` into ``k`` consecutive pieces.

        Every piece except the last spans ``round(len(audio) / k)`` milliseconds.
        The last piece runs to the end, so no trailing audio is lost when the
        rounding goes down.
        """
        step = int(np.round(len(self.audio) / k))
        pieces = [self.audio[i * step:(i + 1) * step] for i in range(k - 1)]
        pieces.append(self.audio[(k - 1) * step:])
        return pieces

    def get_data(self):
        """Return the raw sample array read by ``scipy.io.wavfile``."""
        return self.data

    def split_data(self, k=1):
        """Split the sample array into ``k`` equal-length chunks along axis 0.

        Trailing samples that do not fill a whole chunk (at most ``k - 1``) are
        dropped. A bare ``np.split`` would raise whenever ``len(data)`` is not
        divisible by ``k``.

        Raises:
            ValueError: if ``k`` is not a positive integer.
        """
        if k < 1:
            raise ValueError('k must be a positive integer')
        usable = (len(self.data) // k) * k
        return np.split(self.data[:usable], k)

    def stft(self, k=1):
        """Compute the FFT of each of ``k`` equal-length splits of the signal.

        Returns:
            (xf, spectra):
                ``xf`` is the FFT bin frequencies from ``fftfreq`` with sample
                spacing ``1 / SAMPLE_RATE``.
                ``spectra[i]`` is the complex FFT of split ``i`` taken along
                axis 0 (time), so multi-channel data is transformed per channel.
            Both values are also stored on ``self.xf`` / ``self.stft_out``.
        """
        chunks = self.split_data(k)
        # Use the real chunk length rather than DURATION (derived from pydub's
        # millisecond length) so that len(xf) always equals the FFT length.
        n_samples = len(chunks[0])
        self.xf = fftfreq(n_samples, 1 / self.SAMPLE_RATE)
        self.stft_out = np.array([fft(chunk, axis=0) for chunk in chunks])
        return self.xf, self.stft_out

    def plot(self, l, r, save_path, k=1, lim=1, rows=1, my_top=0.9, y_max=2e7):
        """Plot FFT magnitudes for the first ``lim`` of ``k`` splits, then save and show.

        Args:
            l, r: x-axis (frequency) limits.
            save_path: file name; the figure is saved to 'STFT_Graphs/' + save_path.
            k: number of splits passed to ``stft``.
            lim: number of splits actually drawn (capped at ``k``).
            rows: rows in the subplot grid; columns are ``ceil(k / rows)``.
            my_top: ``top`` passed to ``fig.subplots_adjust`` (room for the title).
            y_max: upper y-axis limit for every subplot.
        """
        # Ceiling division so every split has an axes, even if rows does not divide k.
        cols = -(-k // rows)
        # squeeze=False always returns a 2-D axes array, which removes the
        # special cases for k == 1, rows == 1 and cols == 1.
        fig, axs = plt.subplots(rows, cols, figsize=(20, 3 * rows), squeeze=False)
        freqs, spectra = self.stft(k)
        fig.suptitle('File ' + self.file_name)
        fig.subplots_adjust(top=my_top)
        lim = min(lim, len(spectra))
        for i in range(lim):
            ax = axs[i // cols, i % cols]
            ax.title.set_text('Split ' + str(i + 1))
            ax.plot(freqs, np.abs(spectra[i]))
            ax.set_ylim(0, y_max)
            ax.set_xlim(l, r)
        for ax in axs.flat:
            ax.set(xlabel='Frequencies', ylabel='Amplitudes')
            # Hide x labels/tick labels on inner rows and y ones on inner columns.
            ax.label_outer()
        fig.subplots_adjust(hspace=.2)
        fig.savefig('STFT_Graphs/' + save_path)
        plt.show()

In [6]:
test = Analyzer(file_path='1_7.wav')
# FFT over 10 equal splits of the clip: `x` holds the bin frequencies and
# `stft_arr[i]` the complex spectrum of split i.
x, stft_arr = test.stft(k=10)