In [1]:
import torch
from torch import nn

In [2]:
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T

# Record the library versions this notebook was executed with.
for library in (torch, torchaudio):
    print(library.__version__)

1.8.1+cu111
0.8.1


In [3]:
# Paths of the two recordings used throughout the notebook.
input_filepath = './diodeclip-input.wav'
target_filepath = './diodeclip-target.wav'

# Load each file separately so the second load cannot silently overwrite the
# sample rate returned by the first one.
input_waveform, sample_rate = torchaudio.load(input_filepath)
target_waveform, target_sample_rate = torchaudio.load(target_filepath)

# `sample_rate` is reused later when writing WAV files, so both recordings
# must agree on it.
assert sample_rate == target_sample_rate, (
    f'Sample rate mismatch: {input_filepath} is {sample_rate} Hz, '
    f'{target_filepath} is {target_sample_rate} Hz'
)

In [4]:
from audio_utils import play_audio, plot_waveform, print_stats, print_metadata, plot_specgram

# Print the file metadata and basic statistics of both recordings.
# The waveforms loaded in the previous cell are reused here instead of
# decoding both files from disk a second time.
for filepath, waveform in [(input_filepath, input_waveform), (target_filepath, target_waveform)]:
    metadata = torchaudio.info(filepath)
    print_metadata(metadata, src=filepath)
    print_stats(waveform)

----------
Source: ./diodeclip-input.wav
----------
 - sample_rate: 44100
 - num_channels: 1
 - num_frames: 21123903
 - bits_per_sample: 16
 - encoding: PCM_S

Shape: (1, 21123903)
Dtype: torch.float32
 - Max:      0.563
 - Min:     -0.566
 - Mean:    -0.000
 - Std Dev:  0.074

tensor([[-3.0518e-05, -3.0518e-05, -3.0518e-05,  ..., -1.2207e-04,
          9.1553e-05, -1.2207e-04]])

----------
Source: ./diodeclip-target.wav
----------
 - sample_rate: 44100
 - num_channels: 1
 - num_frames: 21123903
 - bits_per_sample: 16
 - encoding: PCM_S

Shape: (1, 21123903)
Dtype: torch.float32
 - Max:      1.000
 - Min:     -1.000
 - Mean:    -0.007
 - Std Dev:  0.719

tensor([[-0.0013, -0.0014, -0.0013,  ..., -0.0008, -0.0002, -0.0011]])



In [5]:
import numpy as np

# Both recordings must have the same (channels, frames) shape so they can be
# split at the same frame index.
assert input_waveform.shape == target_waveform.shape

# First 80 % of the frames are used for training, the rest for testing.
frames_count = input_waveform.shape[1]
train_frames_count = int(0.8 * frames_count)

# Only channel 0 is used; the resulting waveforms are 1-D.
first_input_channel = input_waveform[0]
first_target_channel = target_waveform[0]
train_input_waveform, test_input_waveform = (
    first_input_channel[:train_frames_count],
    first_input_channel[train_frames_count:],
)
train_target_waveform, test_target_waveform = (
    first_target_channel[:train_frames_count],
    first_target_channel[train_frames_count:],
)

# Write the test portions to disk (a channel dimension is added back for saving).
for test_filepath, test_waveform in (
    ('./test_target.wav', test_target_waveform),
    ('./test_input.wav', test_input_waveform),
):
    torchaudio.save(test_filepath, test_waveform.unsqueeze(0), sample_rate)

In [6]:
class StateTrajectoryNetworkFF(nn.Module):
    """Feed-forward network with a residual connection to the second input feature.

    The last dimension of the input holds two features. A bias-free
    2 -> 8 -> 8 -> 1 stack of linear layers with tanh activations produces a
    correction that is added to feature 1 of the input.
    """

    def __init__(self, is_trained=False):
        """Build the layers.

        Args:
            is_trained: Accepted for interface compatibility; it is not used
                by this constructor.
        """
        super().__init__()
        hidden_units = 8
        self.densely_connected_layers = nn.Sequential(
            nn.Linear(2, hidden_units, bias=False),
            nn.Tanh(),
            nn.Linear(hidden_units, hidden_units, bias=False),
            nn.Tanh(),
            nn.Linear(hidden_units, 1, bias=False),
        )

    def forward(self, x):
        """Return the dense-stack output plus input feature 1.

        Args:
            x: Tensor whose last dimension has size 2.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        correction = self.densely_connected_layers(x)
        # Slicing with 1: keeps the trailing dimension, so the sum already has
        # the (..., 1) output shape.
        carried_feature = x[..., 1:]
        return correction + carried_feature


# Model instance used by the rest of the notebook.
stn = StateTrajectoryNetworkFF(is_trained=True)

In [7]:
class StateTrajectoryNetwork(nn.Module):
    """Recurrent network with a residual connection to the second input feature.

    Structure: RNN (2 -> 8, tanh) -> Linear (8 -> 8) + tanh -> RNN (8 -> 1, tanh).
    The single output feature of the last RNN is added to feature 1 of the input.
    All recurrent layers use ``batch_first=True``.
    """

    def __init__(self):
        super().__init__()
        self.rnn1 = nn.RNN(input_size=2, hidden_size=8, num_layers=1, nonlinearity='tanh', batch_first=True)
        self.linear = nn.Linear(8, 8)
        self.rnn2 = nn.RNN(input_size=8, hidden_size=1, num_layers=1, nonlinearity='tanh', batch_first=True)

    def forward(self, x, hidden=None):
        """Run the network on a batch of sequences.

        Args:
            x: Tensor whose last dimension has size 2.
            hidden: Accepted for backward compatibility and ignored; both
                recurrent layers always start from their default initial
                state. It is optional so the model can be called with a
                single argument, as the training loop in this notebook does.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        output1, _ = self.rnn1(x)
        output_linear = torch.tanh(self.linear(output1))
        output2, _ = self.rnn2(output_linear)
        # Out-of-place unsqueeze: avoids mutating a tensor that is part of the
        # autograd graph.
        return (output2[..., 0] + x[..., 1]).unsqueeze(-1)

# stn = StateTrajectoryNetwork()

In [8]:
# Use the first CUDA device when one is available, otherwise the CPU.
device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f'Using {device} for training.')

# Move the model to the selected device (the cell displays the module).
stn.to(device)

Using cuda:0 for training.


StateTrajectoryNetworkFF(
  (densely_connected_layers): Sequential(
    (0): Linear(in_features=2, out_features=8, bias=False)
    (1): Tanh()
    (2): Linear(in_features=8, out_features=8, bias=False)
    (3): Tanh()
    (4): Linear(in_features=8, out_features=1, bias=False)
  )
)

In [9]:
import torch.optim as optim

def normalized_mse_loss(output, target):
    """Mean squared error normalised element-wise by the target energy.

    Each squared error is divided by ``max(target ** 2, 1e-5)``; the floor
    keeps the division finite where the target is zero or very small.

    Args:
        output: Predicted tensor.
        target: Reference tensor.

    Returns:
        Scalar tensor holding the mean of the normalised squared errors.
    """
    target_energy = target ** 2
    energy_floor = torch.ones_like(target) * 1e-5
    denominator = torch.maximum(target_energy, energy_floor)
    squared_error = (target - output) ** 2
    return (squared_error / denominator).mean()

# Adam over every parameter of the current model, with a learning rate of 1e-3.
optimizer = optim.Adam(stn.parameters(), lr=1e-3)
# Loss used for both training and the test evaluation.
criterion = normalized_mse_loss
# criterion = nn.MSELoss()

In [10]:
# Show the model structure, then the type and size of each parameter tensor.
print(stn)
for param in stn.parameters():
    param_type, param_size = type(param), param.size()
    print(param_type, param_size)

StateTrajectoryNetworkFF(
  (densely_connected_layers): Sequential(
    (0): Linear(in_features=2, out_features=8, bias=False)
    (1): Tanh()
    (2): Linear(in_features=8, out_features=8, bias=False)
    (3): Tanh()
    (4): Linear(in_features=8, out_features=1, bias=False)
  )
)
<class 'torch.nn.parameter.Parameter'> torch.Size([8, 2])
<class 'torch.nn.parameter.Parameter'> torch.Size([8, 8])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 8])


In [11]:
# Data pre-processing
# Cut the training waveforms into non-overlapping segments of `sequence_length`
# frames; trailing frames that do not fill a whole segment are dropped.
sequence_length = 2048
segments_count = train_frames_count // sequence_length
used_frames_count = segments_count * sequence_length

segmented_input = train_input_waveform[:used_frames_count].reshape(segments_count, sequence_length)
segmented_target = train_target_waveform[:used_frames_count].reshape(segments_count, sequence_length)

# Feature 1 of the network input is the target delayed by one frame within
# each segment; the first frame of every segment gets 0.
delayed_target = torch.zeros_like(segmented_target)
delayed_target[:, 1:] = segmented_target[:, :-1]

# Feature 0 is the input sample, feature 1 the delayed target.
input_batch = torch.stack((segmented_input, delayed_target), dim=-1)
target_batch = segmented_target.unsqueeze(-1)

print(f'1 input minibatch shape: {tuple(input_batch.shape)}')
print(f'1 target minibatch shape: {tuple(target_batch.shape)}')

input_batch = input_batch.to(device=device, dtype=torch.float)
# copy=True so the training targets never alias the loaded waveform's memory.
target_batch = target_batch.to(device=device, dtype=torch.float, copy=True)

1 input minibatch shape: (8251, 2048, 2)
1 target minibatch shape: (8251, 2048, 1)


In [12]:
# Training
epochs = 1000
print_loss_every = 200
segments_in_a_batch = 40
# Segments that do not fill a whole minibatch are not used.
batch_count = segments_count // segments_in_a_batch

for epoch in range(epochs):

    # Loss accumulated since the last report (reset at every epoch start).
    running_loss = 0.0

    for i in range(batch_count):
        minibatch_slice = slice(i * segments_in_a_batch, (i + 1) * segments_in_a_batch)
        input_minibatch = input_batch[minibatch_slice]
        target_minibatch = target_batch[minibatch_slice]

        # One optimisation step on the current minibatch.
        optimizer.zero_grad()
        output_minibatch = stn(input_minibatch)
        loss = criterion(output_minibatch, target_minibatch)
        loss.backward()
        optimizer.step()

        # Report the average loss over the last `print_loss_every` minibatches.
        running_loss += loss.item()
        if (i + 1) % print_loss_every == 0:
            print(f'[{epoch + 1:d}, {i + 1:5d}] loss: {running_loss / print_loss_every:.5f}')
            running_loss = 0.

print('Finished training.')

[261,   200] loss: 0.02270
[262,   200] loss: 0.02269
[263,   200] loss: 0.02268
[264,   200] loss: 0.02267
[265,   200] loss: 0.02266
[266,   200] loss: 0.02264
[267,   200] loss: 0.02263
[268,   200] loss: 0.02262
[269,   200] loss: 0.02261
[270,   200] loss: 0.02260
[271,   200] loss: 0.02259
[272,   200] loss: 0.02257
[273,   200] loss: 0.02256
[274,   200] loss: 0.02255
[275,   200] loss: 0.02254
[276,   200] loss: 0.02253
[277,   200] loss: 0.02252
[278,   200] loss: 0.02250
[279,   200] loss: 0.02249
[280,   200] loss: 0.02248
[281,   200] loss: 0.02247
[282,   200] loss: 0.02246
[283,   200] loss: 0.02245
[284,   200] loss: 0.02243
[285,   200] loss: 0.02242
[286,   200] loss: 0.02241
[287,   200] loss: 0.02240
[288,   200] loss: 0.02239
[289,   200] loss: 0.02237
[290,   200] loss: 0.02236
[291,   200] loss: 0.02235
[292,   200] loss: 0.02234
[293,   200] loss: 0.02233
[294,   200] loss: 0.02231
[295,   200] loss: 0.02230
[296,   200] loss: 0.02229
[297,   200] loss: 0.02228
[

In [13]:
# Checkpoint file used by the save and load cells of this notebook.
# NOTE(review): the file name mentions "RNN", but the model instantiated as
# `stn` in this notebook is StateTrajectoryNetworkFF — confirm the name is intended.
PATH = './diode_clipper_2x8tanhRNN.pth'

In [14]:
# Save only the model's state_dict (its parameters), not the module object itself.
torch.save(stn.state_dict(), PATH)

In [15]:
# Restore the saved parameters into the model. `map_location=device` remaps the
# stored tensors to the device selected earlier, so a checkpoint written from a
# GPU session can also be loaded on a CPU-only machine.
stn.load_state_dict(torch.load(PATH, map_location=device))

<All keys matched successfully>

In [16]:
from tqdm import tqdm

# Sample-by-sample inference on the CPU: the model's previous output is fed
# back as the second input feature of the next step.
# 1 batch, 1-element sequence, 2 variables (input and state, i.e., previous output)
stn = stn.cpu()
input_vector = torch.zeros((1, 1, 2), dtype=torch.float)
output_vector = torch.zeros((1, 1, 1), dtype=torch.float)
test_output = torch.zeros_like(test_input_waveform.to('cpu'))

print('Processing test data...')

with torch.no_grad():
    for i in tqdm(range(test_input_waveform.shape[0])):
        # Feature 0: current input sample; feature 1: previous model output
        # (zero for the very first sample).
        input_vector[0, 0, 0] = test_input_waveform[i]
        input_vector[0, 0, 1] = output_vector[0, 0, 0]

        output_vector = stn(input_vector)

        test_output[i] = output_vector[0, 0, 0]

    # Evaluate the whole generated waveform against the test target.
    test_loss = criterion(test_output, test_target_waveform)
    print(f'Test loss: {test_loss:.5f}')

Processing test data...
100%|██████████| 4224781/4224781 [19:18<00:00, 3647.78it/s]Test loss: 0.13651



In [17]:
# Statistics of the raw model output.
print_stats(test_output.unsqueeze(0))

# Peak-normalise in place so the largest absolute sample is 1. The division is
# skipped for an all-zero output, which would otherwise turn every sample into
# NaN before the file is written.
peak = torch.amax(torch.abs(test_output))
if peak > 0:
    test_output /= peak

# Statistics after normalisation, then write the result as a one-channel WAV file.
print_stats(test_output.unsqueeze(0))
torchaudio.save('./test_output.wav', test_output.unsqueeze(0), sample_rate)

Shape: (1, 4224781)
Dtype: torch.float32
 - Max:      1.026
 - Min:     -1.026
 - Mean:     0.020
 - Std Dev:  0.674

tensor([[ 0.5250,  0.7531,  0.8579,  ..., -0.0019,  0.0009, -0.0019]])

Shape: (1, 4224781)
Dtype: torch.float32
 - Max:      1.000
 - Min:     -1.000
 - Mean:     0.019
 - Std Dev:  0.657

tensor([[ 0.5118,  0.7342,  0.8363,  ..., -0.0018,  0.0009, -0.0018]])

