In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
torch.manual_seed(42)
import matplotlib.pyplot as plt
from PIL import Image
import skimage
np.random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x7fbfa8e50b10>

In [None]:
class relu_mlp(nn.Module):
  def __init__(self, neurons, h_layers):
    super(relu_mlp, self).__init__()
    self.neurons = neurons
    self.h_layers = h_layers

    self.layers = nn.ModuleList()
    self.layers.append(nn.Linear(2,neurons))
    for i in range (h_layers-1):
      self.layers.append(nn.Linear(neurons,neurons))
    self.layers.append(nn.Linear(neurons, 3))
  
  def forward (self,x):
    for layer in self.layers[:-1]:
      x = F.relu(layer(x))
    x = self.layers[-1](x)
    return x

In [None]:
class siren_mlp(nn.Module):
  def __init__(self, neurons, h_layers, omega = 30, first_omega = 30):
    super(siren_mlp, self).__init__()
    self.neurons = neurons
    self.h_layers = h_layers
    self.omega = omega
    self.first_omega = first_omega

    self.layers = nn.ModuleList()
    self.layers.append(nn.Linear(1, neurons))
    for i in range(h_layers):
      self.layers.append(nn.Linear(neurons, neurons))
    self.layers.append(nn.Linear(neurons,1))

    nn.init.uniform_(self.layers[0].weight,-1/1,1/1)
    self.layers[0].weight.data *= first_omega

    for layer in self.layers[1:]:
      nn.init.uniform_(layer.weight,-np.sqrt(6/neurons)/omega, np.sqrt(6/neurons)/omega)
      layer.weight.data *= omega
    
  
  def forward (self, x):
    for layer in self.layers[:-1]:
      x = torch.sin(layer(x))
    x = self.layers[-1](x)
    return x

In [None]:
import scipy.io.wavfile as wavfile
import io
from IPython.display import Audio

In [None]:
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/MyDrive

Mounted at /content/drive
/content/drive/MyDrive


In [None]:
data = wavfile.read('gt_bach.wav')
Audio(data[1], rate = data[0])

In [None]:
data[0]

44100

In [None]:
data = data[1].astype(np.float32)

In [None]:
data

array([0.        , 0.        , 0.        , ..., 0.04731218, 0.02968716,
       0.01679885], dtype=float32)

In [None]:
np.min(data)

-0.868308

In [None]:
len(data)

308207

In [None]:
X = np.arange(1,len(data)+1)
y = data

In [None]:
X

array([     1,      2,      3, ..., 308205, 308206, 308207])

In [None]:
y = y

In [None]:
np.min(y)

-0.868308

In [None]:
device = 'cuda'
torch.cuda.is_available()

True

In [None]:
siren_model = siren_mlp(256,3,30, 3000).to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(siren_model.parameters(), lr=5*10**(-5))

In [None]:
siren_psnr = []

In [None]:
from tqdm import tqdm

In [None]:
X = X.reshape(-1,1)

In [None]:
num_epochs = 5000
for e in tqdm(range(num_epochs)):
  y_pred_train = siren_model(torch.from_numpy(X.astype('float32')).to(device)).squeeze()
  loss = criterion(y_pred_train, torch.from_numpy(y.astype('float32')).to(device)) # loss is a tensor which stores the current value of train loss
  siren_psnr.append(20 * np.log10(1.0 / np.sqrt(loss.item())))

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

  if e % 100 == 0:
    print(f"Epoch [{e}/{num_epochs}], Loss: {loss.item():.4f}") # loss.item() gives the numerical value stored in loss
  if e == num_epochs-1:
    print(f"Epoch [{e}/{num_epochs}], Loss: {loss.item():.4f}") # loss.item() gives the numerical value stored in loss
     

  0%|          | 1/5000 [00:00<16:19,  5.10it/s]

Epoch [0/5000], Loss: 0.0107


  2%|▏         | 101/5000 [00:19<19:15,  4.24it/s]

Epoch [100/5000], Loss: 0.0107


  4%|▍         | 201/5000 [00:38<18:12,  4.39it/s]

Epoch [200/5000], Loss: 0.0100


  6%|▌         | 301/5000 [00:57<17:05,  4.58it/s]

Epoch [300/5000], Loss: 0.0095


  8%|▊         | 401/5000 [01:15<16:36,  4.62it/s]

Epoch [400/5000], Loss: 0.0092


 10%|█         | 501/5000 [01:33<16:33,  4.53it/s]

Epoch [500/5000], Loss: 0.0094


 12%|█▏        | 601/5000 [01:52<16:29,  4.44it/s]

Epoch [600/5000], Loss: 0.0103


 14%|█▍        | 701/5000 [02:10<16:04,  4.46it/s]

Epoch [700/5000], Loss: 0.0103


 16%|█▌        | 801/5000 [02:29<15:26,  4.53it/s]

Epoch [800/5000], Loss: 0.0096


 18%|█▊        | 901/5000 [02:47<14:59,  4.56it/s]

Epoch [900/5000], Loss: 0.0083


 20%|██        | 1001/5000 [03:06<14:41,  4.54it/s]

Epoch [1000/5000], Loss: 0.0084


 22%|██▏       | 1101/5000 [03:24<14:33,  4.47it/s]

Epoch [1100/5000], Loss: 0.0084


 24%|██▍       | 1201/5000 [03:43<14:09,  4.47it/s]

Epoch [1200/5000], Loss: 0.0093


 26%|██▌       | 1301/5000 [04:01<13:33,  4.55it/s]

Epoch [1300/5000], Loss: 0.0080


 28%|██▊       | 1401/5000 [04:20<13:05,  4.58it/s]

Epoch [1400/5000], Loss: 0.0080


 30%|███       | 1501/5000 [04:38<12:49,  4.54it/s]

Epoch [1500/5000], Loss: 0.0077


 32%|███▏      | 1601/5000 [04:57<12:33,  4.51it/s]

Epoch [1600/5000], Loss: 0.0072


 34%|███▍      | 1701/5000 [05:15<12:17,  4.47it/s]

Epoch [1700/5000], Loss: 0.0075


 36%|███▌      | 1801/5000 [05:34<11:54,  4.48it/s]

Epoch [1800/5000], Loss: 0.0068


 38%|███▊      | 1901/5000 [05:52<11:25,  4.52it/s]

Epoch [1900/5000], Loss: 0.0066


 40%|████      | 2001/5000 [06:11<11:01,  4.53it/s]

Epoch [2000/5000], Loss: 0.0068


 42%|████▏     | 2101/5000 [06:29<10:45,  4.49it/s]

Epoch [2100/5000], Loss: 0.0067


 44%|████▍     | 2201/5000 [06:48<10:26,  4.47it/s]

Epoch [2200/5000], Loss: 0.0119


 46%|████▌     | 2301/5000 [07:07<09:55,  4.53it/s]

Epoch [2300/5000], Loss: 0.0098


 48%|████▊     | 2401/5000 [07:25<09:32,  4.54it/s]

Epoch [2400/5000], Loss: 0.0074


 50%|█████     | 2501/5000 [07:43<09:06,  4.58it/s]

Epoch [2500/5000], Loss: 0.0074


 52%|█████▏    | 2601/5000 [08:02<08:50,  4.52it/s]

Epoch [2600/5000], Loss: 0.0068


 54%|█████▍    | 2701/5000 [08:20<08:29,  4.51it/s]

Epoch [2700/5000], Loss: 0.0074


 56%|█████▌    | 2801/5000 [08:39<08:08,  4.50it/s]

Epoch [2800/5000], Loss: 0.0065


 58%|█████▊    | 2901/5000 [08:57<07:45,  4.51it/s]

Epoch [2900/5000], Loss: 0.0072


 60%|██████    | 3001/5000 [09:16<07:20,  4.54it/s]

Epoch [3000/5000], Loss: 0.0061


 62%|██████▏   | 3101/5000 [09:34<07:02,  4.49it/s]

Epoch [3100/5000], Loss: 0.0063


 64%|██████▍   | 3201/5000 [09:53<06:38,  4.51it/s]

Epoch [3200/5000], Loss: 0.0059


 66%|██████▌   | 3301/5000 [10:12<06:18,  4.49it/s]

Epoch [3300/5000], Loss: 0.0077


 68%|██████▊   | 3401/5000 [10:30<05:54,  4.50it/s]

Epoch [3400/5000], Loss: 0.0062


 70%|███████   | 3501/5000 [10:49<05:33,  4.49it/s]

Epoch [3500/5000], Loss: 0.0058


 72%|███████▏  | 3601/5000 [11:07<05:11,  4.49it/s]

Epoch [3600/5000], Loss: 0.0075


 74%|███████▍  | 3701/5000 [11:26<04:48,  4.50it/s]

Epoch [3700/5000], Loss: 0.0057


 76%|███████▌  | 3801/5000 [11:44<04:27,  4.49it/s]

Epoch [3800/5000], Loss: 0.0056


 78%|███████▊  | 3901/5000 [12:03<04:02,  4.53it/s]

Epoch [3900/5000], Loss: 0.0056


 80%|████████  | 4001/5000 [12:21<03:41,  4.50it/s]

Epoch [4000/5000], Loss: 0.0063


 82%|████████▏ | 4101/5000 [12:40<03:19,  4.52it/s]

Epoch [4100/5000], Loss: 0.0055


 84%|████████▍ | 4201/5000 [12:58<02:57,  4.50it/s]

Epoch [4200/5000], Loss: 0.0060


 86%|████████▌ | 4301/5000 [13:17<02:34,  4.52it/s]

Epoch [4300/5000], Loss: 0.0051


 88%|████████▊ | 4401/5000 [13:35<02:12,  4.52it/s]

Epoch [4400/5000], Loss: 0.0052


 90%|█████████ | 4501/5000 [13:54<01:50,  4.53it/s]

Epoch [4500/5000], Loss: 0.0052


 92%|█████████▏| 4601/5000 [14:12<01:28,  4.52it/s]

Epoch [4600/5000], Loss: 0.0051


 94%|█████████▍| 4701/5000 [14:31<01:06,  4.52it/s]

Epoch [4700/5000], Loss: 0.0056


 96%|█████████▌| 4801/5000 [14:49<00:43,  4.56it/s]

Epoch [4800/5000], Loss: 0.0053


 98%|█████████▊| 4901/5000 [15:08<00:21,  4.53it/s]

Epoch [4900/5000], Loss: 0.0052


100%|██████████| 5000/5000 [15:26<00:00,  5.40it/s]

Epoch [4999/5000], Loss: 0.0052





In [None]:
audio_siren = siren_model(torch.from_numpy(X.astype('float32')).to(device))

In [None]:
audio_siren = audio_siren.cpu().detach().numpy().squeeze()

In [None]:
audio_siren

array([ 0.0397495 , -0.02673191, -0.06978832, ...,  0.08356562,
        0.02265902, -0.06054626], dtype=float32)

In [None]:
Audio(audio_siren, rate = 44100)

In [None]:
y

array([0.        , 0.        , 0.        , ..., 0.04731218, 0.02968716,
       0.01679885], dtype=float32)

In [None]:
y_pred_train = y_pred_train.cpu()

In [None]:
y_pred_train = y_pred_train.detach().numpy()

In [None]:
y_pred_train

array([ 0.04028275, -0.02608917, -0.0689289 , ...,  0.12037916,
        0.02758162, -0.10478705], dtype=float32)

In [None]:
Audio(y_pred_train, rate = 44100)

In [None]:
from scipy.io.wavfile import write
write('siren_audio.wav', 44100, y_pred_train)

In [None]:
X_extrapolated = X + 308207

In [None]:
audio_siren_extrapolated = siren_model(torch.from_numpy(X_extrapolated.astype('float32')).to(device))

In [None]:
audio_siren_extrapolated = audio_siren_extrapolated.cpu().detach().numpy().squeeze()

In [None]:
audio_siren_extrapolated

array([-0.07883598, -0.04245985, -0.07184346, ...,  0.1567958 ,
        0.23646842,  0.01340024], dtype=float32)

In [None]:
Audio(audio_siren_extrapolated, rate = 44100)

In [None]:
write('siren_audio_extrapolated', 44100, audio_siren_extrapolated)

In [None]:
torch.save(siren_model.state_dict(), 'MyDrive')