In [None]:
import torch
from torch import nn

In [None]:
from CoreAudioML.dataset import DataSet

# Register two subsets on a DataSet rooted at the current directory, then load the
# 'diodeclip' recording into them with a 79% / 21% split. Only the training subset
# is given a frame_len (44100 samples); the test subset is created without one.
split_names = ['train', 'test']
split_fractions = [0.79, 0.21]

dataset = DataSet(data_dir='./')
dataset.create_subset(split_names[0], frame_len=44100)
dataset.create_subset(split_names[1])
dataset.load_file('diodeclip', set_names=split_names, splits=split_fractions)

In [None]:
# Tensors are laid out as (segment_length, segments_count, features_count); the
# training loop indexes segments along dim 1. features_count is 1 (one audio
# sample per time step).
for subset_name in ('train', 'test'):
    for signal_name in ('input', 'target'):
        print(dataset.subsets[subset_name].data[signal_name][0].shape)

In [None]:
class StateTrajectoryNetworkFF(nn.Module):
    """Feed-forward state-trajectory network with a residual state update.

    The last input axis holds two values: index 0 is the current input sample
    and index 1 is the state (the previous output). A bias-free MLP
    (2 -> 8 -> 8 -> 1, tanh activations) maps both to a single correction
    that is added to the state.

    Input shape: (..., 2). Output shape: (..., 1).
    """

    def __init__(self):
        super().__init__()
        # The attribute name and layer order determine the state_dict keys
        # ('densely_connected_layers.0.weight', ...), so keep them stable.
        layers = [
            nn.Linear(2, 8, bias=False),
            nn.Tanh(),
            nn.Linear(8, 8, bias=False),
            nn.Tanh(),
            nn.Linear(8, 1, bias=False),
        ]
        self.densely_connected_layers = nn.Sequential(*layers)

    def forward(self, x):
        correction = self.densely_connected_layers(x)
        # Slicing with 1:2 keeps the trailing axis, so the sum is already (..., 1).
        previous_state = x[..., 1:2]
        return correction + previous_state

# stn = StateTrajectoryNetworkFF()

In [None]:
# Stateful RNN model (no teacher forcing for now).
class StateTrajectoryNetwork(nn.Module):
    """Residual two-layer tanh RNN that keeps its hidden state between calls.

    ``forward`` returns ``rnn(x) + x``, so the RNN learns a correction on top
    of the identity map. The final hidden state of each call is stored
    (detached) in ``self.hidden`` and seeds the next call, so a long signal can
    be processed in consecutive chunks. Call ``reset_hidden()`` (or set
    ``self.hidden = None``) to start from a zero state.

    Args:
        is_trained: Accepted for backward compatibility; currently unused.
    """

    def __init__(self, is_trained=False):
        super().__init__()
        self.rnn = torch.nn.RNN(input_size=1, hidden_size=1, num_layers=2, nonlinearity='tanh', bias=False)
        # Plain attribute (not a registered buffer): it is not saved in
        # state_dict() and is not moved by .to()/.cpu().
        self.hidden = None

    def reset_hidden(self):
        """Forget the stored hidden state; the next call starts from zeros."""
        self.hidden = None

    def _stored_hidden_fits(self, x):
        """Return True if ``self.hidden`` can seed an RNN pass over ``x``."""
        if self.hidden is None:
            return False
        # The batch dims sit between the first and last axes of both tensors:
        # x is (seq, batch, features) and hidden is (layers, batch, hidden_size);
        # for unbatched input both are 2-D and the middle is empty.
        return (self.hidden.device == x.device
                and self.hidden.shape[1:-1] == x.shape[1:-1])

    def forward(self, x):
        """Run the RNN over ``x`` and add the input back (residual connection).

        Args:
            x: Sequence-first tensor, e.g. (seq_len, batch, 1), since the RNN uses
               the default batch_first=False and input_size=1.

        Returns:
            Tensor with the same shape as ``x`` (hidden_size equals input_size).
        """
        # A stored state from a different batch size or device cannot be reused
        # (e.g. the smaller last minibatch, or after moving the model to the CPU),
        # so fall back to a zero initial state in that case.
        initial_hidden = self.hidden if self._stored_hidden_fits(x) else None
        out, hidden = self.rnn(x, initial_hidden)
        # Detach so the next call does not backpropagate into this call's graph,
        # which would fail once backward() has freed it.
        self.hidden = hidden.detach()
        return out + x

    def initialize_state(self, batch_size, state_size):
        """Create a zero ``self.state`` tensor of shape (batch_size, state_size).

        NOTE(review): ``forward`` never reads ``self.state``; use
        ``reset_hidden()`` to clear the recurrent state.
        """
        self.state = torch.zeros((batch_size, state_size))


stn = StateTrajectoryNetwork()

In [None]:
# Use the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f'Using {device} for training.')

# Moves the registered parameters; returns the module, which the cell displays.
stn.to(device)

In [None]:
import torch.optim as optim

def normalized_mse_loss(output, target, eps=1e-5):
    """Mean squared error where each element is normalized by its own target energy.

    Computes ``mean((target - output)**2 / max(target**2, eps))``. The
    normalization is per element, not by the mean target energy. As a result,
    elements where ``target`` is close to zero are weighted by up to ``1 / eps``.

    Args:
        output: Model prediction; must broadcast against ``target``.
        target: Reference signal.
        eps: Lower bound of each denominator; prevents division by zero where
            ``target`` is 0. Defaults to the previously hard-coded 1e-5.

    Returns:
        0-dim tensor holding the loss.
    """
    # clamp_min applies the same floor as torch.maximum(target**2, eps * ones_like(target))
    # without allocating a constant tensor the size of the target.
    denominator = target.pow(2).clamp_min(eps)
    return ((target - output).pow(2) / denominator).mean()

# Training/evaluation loss; nn.MSELoss() is a drop-in alternative criterion.
criterion = normalized_mse_loss
# Adam over every model parameter with a learning rate of 1e-3.
optimizer = optim.Adam(params=stn.parameters(), lr=1e-3)

In [None]:
# Show the module tree, then the class and size of every trainable tensor.
print(stn)
parameter_summaries = [(type(p), p.size()) for p in stn.parameters()]
for param_type, param_size in parameter_summaries:
    print(param_type, param_size)

In [None]:
import numpy as np

# Training
epochs = 100
print_loss_every = 200
segments_in_a_batch = 40

loss_history = torch.zeros((epochs,), device=device)
gradient_norm_history = torch.zeros((epochs,), device=device)

input_data = dataset.subsets['train'].data['input'][0].to(device)
target_data = dataset.subsets['train'].data['target'][0].to(device)

# Segments live along dim 1: (segment_length, segments_count, features_count).
segments_count = input_data.shape[1]
assert segments_count > 0, 'training subset contains no segments'
batch_count = int(np.ceil(segments_count / segments_in_a_batch))

for epoch in range(epochs):
    # Reshuffle the *segment* indices every epoch. Each minibatch takes the next
    # `segments_in_a_batch` entries, so the permutation must span all segments
    # (a permutation of batch_count would leave later slices empty).
    segments_order = torch.randperm(segments_count)

    running_loss = 0.0
    epoch_loss = 0.0

    for i in range(batch_count):
        minibatch_segment_indices = segments_order[i*segments_in_a_batch:(i+1)*segments_in_a_batch]
        input_minibatch = input_data[:, minibatch_segment_indices, :]
        target_minibatch = target_data[:, minibatch_segment_indices, :]

        optimizer.zero_grad()

        # Shuffled segments are unrelated to the previous minibatch, so each one
        # starts from a zero hidden state. This also stops backward() from reaching
        # into the previous (already freed) graph, and avoids a hidden-state
        # batch-size mismatch on the smaller last minibatch.
        stn.hidden = None
        output_minibatch = stn(input_minibatch)

        loss = criterion(output_minibatch, target_minibatch)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        epoch_loss += loss_value
        if i % print_loss_every == print_loss_every - 1:
            print('[%d, %5d] loss: %.5f; Running loss: %.5f' % (epoch + 1, i + 1, loss_value, running_loss/print_loss_every))
            running_loss = 0.

    # Mean minibatch loss of the epoch (less noisy than the last minibatch alone).
    loss_history[epoch] = epoch_loss / batch_count
    # L2 norm of the gradients left by the epoch's last minibatch.
    gradient = torch.cat([param.grad.flatten() for param in stn.parameters()])
    gradient_norm_history[epoch] = torch.linalg.norm(gradient)

print('Finished training.')

In [None]:
import matplotlib.pyplot as plt

def plot_training_curve(history, ylabel):
    """Plot a per-epoch history tensor on a new figure and return its Axes."""
    _, ax = plt.subplots()
    ax.plot(history.cpu())
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    return ax


plot_training_curve(loss_history, 'Loss (Normalized MSE)')
plot_training_curve(gradient_norm_history, 'Gradient L2 norm');

In [None]:
# Checkpoint file for stn's state_dict (written and read by the next two cells).
# NOTE(review): "2x8tanh" matches the 2 -> 8 -> 8 -> 1 tanh StateTrajectoryNetworkFF,
# whereas stn is the 2-layer, hidden_size=1 RNN — confirm the filename is intended.
PATH = './diode_clipper_2x8tanhRNN.pth'

In [None]:
# Persist only the module's parameters; the plain `hidden` attribute is not included.
model_state = stn.state_dict()
torch.save(obj=model_state, f=PATH)

In [None]:
# map_location='cpu' lets a checkpoint saved during a CUDA run load on a CPU-only
# machine; load_state_dict then copies the tensors onto whatever device stn is on.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
stn.load_state_dict(torch.load(PATH, map_location='cpu'))

In [None]:
from tqdm import tqdm

# stn is the RNN (input_size=1), so the test signal is run through it as one
# sequence. A per-sample loop with a 2-feature (input, previous output) vector
# would only fit StateTrajectoryNetworkFF.
# The test subset was created without frame_len; its tensors are presumably one
# sequence shaped (length, 1, 1) — TODO confirm.
stn = stn.cpu()
test_input_sequence = dataset.subsets['test'].data['input'][0].cpu()
test_target_sequence = dataset.subsets['test'].data['target'][0].cpu()

print('Processing test data...')

with torch.no_grad():
    # Start from a zero state. The stored hidden tensor belongs to the last training
    # minibatch (different batch size). It is a plain attribute, so .cpu() did not
    # move it and it may still be on the GPU.
    stn.hidden = None
    test_output_sequence = stn(test_input_sequence)

    test_loss = criterion(test_output_sequence, test_target_sequence)
    print(f'Test loss: {test_loss:.5f}')

# 1-D waveforms (one value per sample) for the export cell below.
test_input_waveform = test_input_sequence.flatten()
test_target_waveform = test_target_sequence.flatten()
test_output = test_output_sequence.flatten()

In [None]:
# Inspect the model output, clip it to [-1, 1], inspect it again, and write it to
# './test_output.wav' with a leading channel axis added by unsqueeze(0).
# NOTE(review): print_stats, torchaudio and sample_rate are not defined or imported
# in the visible notebook, so this cell fails on a fresh kernel. Import torchaudio,
# define print_stats (presumably prints summary statistics), and set sample_rate
# to the rate of the 'diodeclip' recordings — TODO confirm.
print_stats(test_output.unsqueeze(0))
test_output = torch.clamp(test_output, -1., 1.)
print_stats(test_output.unsqueeze(0))
torchaudio.save('./test_output.wav', test_output.unsqueeze(0), sample_rate)