In [2]:
# dataset
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split

from torchaudio import datasets
import torchaudio.transforms

from torch.utils.data import DataLoader

from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F

# audio processing
import numpy as np
import librosa
import librosa.display
import matplotlib.pyplot as plt

# neural network
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.transforms as transforms
import torchvision.models as models

# choose the compute device: CUDA GPU first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [49]:
# convert audio from VCTK to STFT and stack real/imaginary parts into a 2-channel array
def stft_vctk(example, n_fft = 512):
    """Compute the STFT of one loaded clip and split it into real/imaginary planes.

    Parameters
    ----------
    example : sequence
        ``(waveform, sample_rate, ...)``. Only the first two entries are read;
        ``waveform`` must expose ``.numpy()`` (e.g. the tensor returned by
        ``torchaudio.load``).
    n_fft : int, optional
        FFT window length forwarded to ``librosa.stft``. Defaults to 512,
        the length recommended for windowed speech signals.

    Returns
    -------
    audio_stft_sep : np.ndarray
        ``np.stack([real, imag])`` of the complex STFT, i.e. a new leading axis
        of length 2 (index 0 = real part, index 1 = imaginary part). For the
        mono clips used in this notebook with the default ``n_fft`` this is
        ``(2, 257, n_frames)``.
    sample_rate
        ``example[1]``, passed through unchanged.
    """
    sample_rate = example[1]
    # drop singleton dimensions (e.g. a mono channel axis) before the transform
    audio = np.squeeze(example[0].numpy())
    audio_stft = librosa.stft(audio, n_fft = n_fft)

    # separate the complex spectrogram into two real-valued planes
    real = np.real(audio_stft)
    imag = np.imag(audio_stft)
    audio_stft_sep = np.stack([real, imag])

    return audio_stft_sep, sample_rate

# subclass that implements STFT conversion when getting items
class STFT_Dataset(datasets.VCTK_092):
    """VCTK 0.92 dataset whose items are ``(stft, speaker_label)`` pairs.

    The audio-loading hook is overridden so that the parent class hands back
    the stacked real/imaginary STFT produced by ``stft_vctk`` instead of the
    raw waveform. ``__len__`` is inherited unchanged from the parent class.
    """

    def __init__(self, path):
        """Open the dataset stored under ``path`` (passed to the parent as ``root``)."""
        super().__init__(root = path)

    def _load_audio(self, file_path):
        """Load one audio file and return ``(stft, sample_rate)`` for it."""
        s, sr = stft_vctk(torchaudio.load(file_path))
        return (s, sr)

    def __getitem__(self, idx):
        """Return ``(stft, speaker_label)`` for the example at ``idx``."""
        # Fetch the sample once: each parent __getitem__ call goes through
        # _load_audio again, i.e. re-reads the file and recomputes the STFT.
        sample = super().__getitem__(idx)
        # element 0 is what _load_audio produced, element 3 is the speaker label
        return sample[0], sample[3]

In [50]:
import pickle

# initialise dataset
stft_data = STFT_Dataset("VCTK")

# accumulator for the means: one value per (real/imag plane, frequency bin).
# The first item is fetched a single time, only to read those two leading dimensions.
first_stft = stft_data[0][0]
feature_mean_sums = np.zeros(first_stft.shape[:2])
print(feature_mean_sums.shape)

# collect mean sums from dataset
# NOTE: every utterance contributes its own time-average, so utterances are
# weighted equally regardless of how many frames they contain.
for x in range(len(stft_data)):
    if x % 1000 == 0:
        print(x)
    example = stft_data[x][0]
    # average over the time axis, then add to the running total
    feature_mean_sums += np.mean(example, axis = 2)

(2, 257)
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000


In [51]:
# turn the accumulated per-utterance means into a dataset mean and persist it
n_examples = len(stft_data)
dataset_feature_mean = feature_mean_sums / n_examples
print(dataset_feature_mean.shape)
np.save("mean_comp", dataset_feature_mean)

(2, 257)


In [60]:
# sanity check to ensure calculated means are correct:
# recompute the mean of frequency bin 0 (imaginary and real plane) independently
print(dataset_feature_mean[1][0])
print(dataset_feature_mean[0][0])

# plain Python floats keep the running totals in double precision;
# np.mean on float32 data would otherwise hand back float32 scalars
ex_mean_sum_real = 0.0
ex_mean_sum_imag = 0.0

n_examples = len(stft_data)
for x in range(n_examples):
    if x % 1000 == 0:
        print(x)
    # load the example once and reuse it for both planes
    example = stft_data[x][0]
    ex_mean_sum_imag += float(np.mean(example[1][0]))
    ex_mean_sum_real += float(np.mean(example[0][0]))

ex_mean_sum_imag /= n_examples
ex_mean_sum_real /= n_examples

print(ex_mean_sum_imag)
print(ex_mean_sum_real)

# make the check enforceable instead of relying on comparing printed numbers by eye
np.testing.assert_allclose(
    [ex_mean_sum_real, ex_mean_sum_imag],
    [dataset_feature_mean[0][0], dataset_feature_mean[1][0]],
    rtol = 1e-4,
    atol = 1e-7,
)

0.0
-0.00036830184406802334
0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000
0.0
-0.0003683018691618449


In [56]:
# mean of the imaginary plane at frequency bin 0
print(dataset_feature_mean[1, 0])

0.0


In [61]:
# get variances
# the accumulator has one entry per (real/imag plane, frequency bin), i.e. the
# same shape as the mean array, so no dataset item needs to be loaded for it
feature_variance_sums = np.zeros(dataset_feature_mean.shape)
print(feature_variance_sums.shape)

(2, 257)


In [92]:
# collect, per utterance, the time-averaged squared deviation from the dataset mean
# start from zero inside this cell so that re-running it cannot double-count
feature_variance_sums = np.zeros(dataset_feature_mean.shape)
# add a trailing time axis once so the mean broadcasts against every example
mean_per_bin = dataset_feature_mean[:, :, np.newaxis]
for x in range(len(stft_data)):
    if x % 1000 == 0:
        print(x)
    example = stft_data[x][0]
    centred = example - mean_per_bin
    feature_variance_sums += np.mean(centred ** 2, axis = 2)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000
10000
11000
12000
13000
14000
15000
16000
17000
18000
19000
20000
21000
22000
23000
24000
25000
26000
27000
28000
29000
30000
31000
32000
33000
34000
35000
36000
37000
38000
39000
40000
41000
42000
43000


In [93]:
# show the raw accumulated sums before they are averaged
print(feature_variance_sums)

# standard deviation: square root of the average accumulated squared deviation
n_examples = len(stft_data)
mean_sq_dev = feature_variance_sums / n_examples
std = np.sqrt(mean_sq_dev)

print(std.shape)
# persist for later use (np.save appends ".npy" to the name)
np.save("std_comp", std)

[[1.71428253e+05 1.68792491e+05 3.02366085e+05 2.85537368e+05
  2.97780171e+05 3.85076112e+05 2.94286523e+05 1.96082732e+05
  1.26277176e+05 7.26899237e+04 4.52176849e+04 3.13940961e+04
  2.26275774e+04 1.74638557e+04 1.49826227e+04 1.32858360e+04
  1.20802230e+04 1.09859464e+04 9.51494389e+03 7.68106787e+03
  6.07377546e+03 4.91634840e+03 3.95374400e+03 3.21929108e+03
  2.89182087e+03 2.85727930e+03 2.97896485e+03 3.02768163e+03
  2.89718858e+03 2.58469976e+03 2.22404497e+03 1.91389792e+03
  1.68100323e+03 1.53088541e+03 1.42357608e+03 1.32607743e+03
  1.20562471e+03 1.08949959e+03 9.86957372e+02 9.11297924e+02
  8.60258223e+02 8.12605890e+02 7.68884913e+02 7.37536779e+02
  7.06344297e+02 6.76524544e+02 6.46211403e+02 6.12511638e+02
  5.82742354e+02 5.60240969e+02 5.43957312e+02 5.26836035e+02
  5.13388391e+02 4.95168791e+02 4.72133878e+02 4.51139806e+02
  4.29672351e+02 4.10024140e+02 3.94095025e+02 3.78160871e+02
  3.65302270e+02 3.55281393e+02 3.46362722e+02 3.36058584e+02
  3.2783

In [109]:
# normalise then use pickle to save to folder
def normalise(stft, mean, std):
    """Standardise ``stft`` with per-(plane, bin) statistics.

    ``mean`` and ``std`` are 2-D arrays; each gets a trailing axis appended so
    it broadcasts along the last (time) axis of ``stft``. Entries of ``std``
    that are exactly zero are replaced by 1, which leaves the corresponding
    values centred but unscaled instead of dividing by zero. The inputs are
    not modified.
    """
    safe_std = np.where(std == 0, 1, std)
    centred = stft - mean[:, :, None]
    return centred / safe_std[:, :, None]

# write to folder
import os

# create the output directory up front; open(..., "wb") does not create it
os.makedirs("STFTcomp", exist_ok = True)

for x in range(len(stft_data)):
    item_stft, item_label = stft_data[x]
    item_norm_stft = normalise(item_stft, dataset_feature_mean, std)

    # one pickle per example, holding [normalised STFT, speaker label]
    with open("STFTcomp/stftc" + str(x), "wb") as file:
        pickle.dump([item_norm_stft, item_label], file)
    

In [107]:
# sanity check: reload the first pickled example and inspect what was stored
with open("STFTcomp/stftc0", "rb") as fh:
    test = pickle.load(fh)

stored_stft, stored_label = test[0], test[1]
print(len(test))
print(stored_stft.shape)
print(stored_label)

2
(2, 257, 770)
p225


In [110]:
# sanity check 2: the last pickled example must agree with the last dataset item
ind = len(stft_data) - 1
with open("STFTcomp/stftc" + str(ind), "rb") as file:
    test = pickle.load(file)

print(test[0].shape)
print(test[1])

test2 = stft_data[ind]
print(test2[0].shape)
# label taken from the dataset itself, so the two printed labels are independent
print(test2[1])

# enforce the comparison rather than relying on reading the printed values
assert test[0].shape == test2[0].shape
assert test[1] == test2[1]

(2, 257, 1838)
s5
(2, 257, 1838)
s5
