In [1]:
from pydub import AudioSegment
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.utils import Sequence
import os
from tqdm.auto import tqdm

In [2]:
# An empty CUDA_VISIBLE_DEVICES exposes no CUDA device to this process.
# NOTE(review): this assignment runs after the TensorFlow import in the first
# cell -- confirm it still takes effect before any GPU is initialised.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

# LENGTH: stereo frames per model input window.
# OUT_LENGTH: stereo frames predicted per window.
LENGTH, OUT_LENGTH = 22100, 100

In [3]:
# Window-to-vector regressor: a 300-unit LSTM consumes a (LENGTH, 2) window and
# a linear Dense head emits 2 * OUT_LENGTH values (two per predicted frame).
model = Sequential(
    [
        LSTM(300, batch_input_shape=(None, LENGTH, 2)),
        Dense(2 * OUT_LENGTH, activation="linear"),
    ]
)

In [4]:
# Regression objective: mean squared error, optimised with RMSprop.
model.compile(optimizer="rmsprop", loss="mean_squared_error")

In [5]:
class DataLoader(Sequence):
    """Keras ``Sequence`` yielding sliding-window batches from a stereo MP3.

    Window ``s`` pairs the ``length`` frames starting at frame ``s`` (the model
    input) with the ``out_length`` frames that immediately follow them (the
    target, flattened row-major to ``[l0, r0, l1, r1, ...]``). Consecutive
    windows are shifted by a single frame.

    The decoded audio is assumed to be 16-bit stereo -- TODO confirm for other
    inputs. Samples are rescaled from ``[-32768, 32767]`` to ``[0, 1]``.

    Args:
        filename: Path of the MP3 file to decode.
        length: Number of frames in each input window.
        out_length: Number of frames in each target.
        batch_size: Maximum number of windows per batch; the final batch may
            contain fewer.

    Attributes:
        sound: float32 array of shape ``(n_frames, 2)`` holding the
            normalised (left, right) samples.
    """

    def __init__(self, filename, length, out_length, batch_size):
        super().__init__()
        segment = AudioSegment.from_mp3(filename)
        samples = np.array(segment.get_array_of_samples())
        # Convert to float32 *before* shifting: adding 32768 to an int16 array
        # raises OverflowError on NumPy >= 2 because the constant does not fit
        # in int16 (older NumPy silently upcast instead).
        samples = (samples.astype(np.float32) + 32768) / 65535
        # pydub serialises stereo audio as L, R, L, R, ...; fold the stream
        # into one (left, right) row per frame.
        self.sound = samples.reshape((-1, 2))
        self.batch_size = batch_size
        self.length = length
        self.out_length = out_length

    def _num_windows(self):
        """Return how many windows have a complete input AND a complete target."""
        usable = self.sound.shape[0] - self.length - self.out_length + 1
        return max(0, usable)

    def __len__(self):
        """Return the number of batches (ceiling of windows / batch_size)."""
        return -(-self._num_windows() // self.batch_size)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``(X, y)``.

        Returns:
            X: array of shape ``(n, length, 2)``.
            y: array of shape ``(n, out_length * 2)``.
            ``n`` equals ``batch_size`` except possibly for the last batch.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(
                "batch index {} out of range for {} batches".format(idx, len(self))
            )

        length = self.length
        out_length = self.out_length
        first = idx * self.batch_size
        # Truncate the final batch instead of fabricating an all-zero batch:
        # only windows that lie entirely inside the audio are served.
        last = min(first + self.batch_size, self._num_windows())
        starts = range(first, last)

        X = np.stack([self.sound[s:s + length] for s in starts])
        # The target begins at the frame directly after the input window
        # (s + length); starting one frame later would skip a sample.
        y = np.stack(
            [
                self.sound[s + length:s + length + out_length].reshape(-1)
                for s in starts
            ]
        )
        return X, y


In [6]:
# Batch generator over "ear.mp3": LENGTH-frame inputs, OUT_LENGTH-frame
# targets, 32 windows per batch.
data_gen = DataLoader(
    filename="ear.mp3",
    length=LENGTH,
    out_length=OUT_LENGTH,
    batch_size=32,
)

In [None]:
# Train on every batch produced by the generator, with shuffling enabled.
# NOTE(review): the captured output below reports millions of steps and an ETA
# of tens of thousands of hours -- this cell is not practical to run as-is.
model.fit(x=data_gen, shuffle=True)

  ...
    to  
  ['...']
Train for 3320622 steps
      9/3320622 [..............................] - ETA: 65130:38:09 - loss: 0.1061

In [None]:
# Autoregressive generation: start from LENGTH frames held at 0.5 (mid-scale of
# the [0, 1] sample range) and repeatedly append the model's prediction.
now = np.full((1, LENGTH, 2), 0.5)
for _ in tqdm(range(1000)):
    # Condition only on the most recent LENGTH frames.
    pred = model.predict(now[:, -LENGTH:, :], verbose=0)
    # The network emits 2 * OUT_LENGTH values per window, laid out as
    # [l0, r0, l1, r1, ...]; restore them to (1, OUT_LENGTH, 2) frames.
    # (Reshaping to (1, 1, 2) cannot work: the sizes do not match.)
    now = np.concatenate([now, pred.reshape(1, OUT_LENGTH, 2)], axis=1)

In [None]:
# Plot the generated waveform: one curve per stereo channel against the frame
# index.
frame_index = np.arange(now.shape[1])
first_channel, second_channel = now[0, :, 0], now[0, :, 1]
plt.plot(frame_index, first_channel)
plt.plot(frame_index, second_channel)