In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import glob
import librosa
import numpy as np

In [2]:
class AudioDataset(Dataset):
    def __init__(self, path, transform=None, sample_rate=16000):
        self.path = path
        self.data_list = glob.glob(self.path + '/*.wav')
        
        self.transform = transform
        self.sr = sample_rate
        self.frame_length = 0.025 # win_length, 자연어 처리 분야에서 25ms 크기를 기본으로 하고 있음 (16000Hz -> 400)
        self.frame_stride = 0.0126 # hop_length, 일반적으로 10ms의 크기를 기본으로 하고 있음 (16000Hz -> 160)
        
    def __len__(self):
        return len(self.data_list)
    
    def __getitem__(self, idx):
        data_path = self.data_list[idx]
        data = self.normalize(self.Mel_S(data_path))
        data = np.expand_dims(data, axis=0)
        
        if self.transform is not None:
            data = self.transform(data)
        
        return data # (1, 40, 80)
    
    def Mel_S(self, wav_file):
        y, sr = librosa.load(wav_file, sr=self.sr)
    
        input_nfft = int(round(sr*self.frame_length))
        input_stride = int(round(sr*self.frame_stride))

        s = librosa.feature.melspectrogram(y=y, n_mels=40, n_fft=input_nfft, hop_length=input_stride)
#         print(f"Wav lenght : {len(y)/sr}, Mes_S shape : {np.shape(S)}")
        
        return s
    
    def normalize(self, s):
        s = (s - s.mean()) / s.std() # standardization
        s = (s - s.min()) / (s.max() - s.min()) # min max normalization
        
        return s

In [3]:
dataset = AudioDataset(path='data/원천데이터/after')

In [4]:
dataset.__len__()
dataset_indices = list(range(dataset.__len__()))
np.random.shuffle(dataset_indices)
val_split_index = int(np.floor(0.2 * dataset.__len__()))
train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]

train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

In [5]:
train_loader = DataLoader(dataset=dataset, batch_size=100, sampler=train_sampler, shuffle=False)
val_loader = DataLoader(dataset=dataset, batch_size=100, sampler=val_sampler, shuffle=False)

In [6]:
class VAE(nn.Module):
    def __init__(self, input_channel=1, h_dim=128*5*10, z_dim=512):
        super(VAE, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten()
        )
        
        self.fc1 = nn.Linear(h_dim, z_dim)
        self.fc2 = nn.Linear(h_dim, z_dim)
        self.fc3 = nn.Linear(z_dim, h_dim)
        
        self.decoder = nn.Sequential(
            nn.Unflatten(1, (128, 5, 10)),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()
        )
        
    def reparameterization(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return mu + std*eps # return z sample
    
    def bottleneck(self, h):
        mu, log_var = self.fc1(h), self.fc2(h)
        z = self.reparameterization(mu, log_var)
        return z, mu, log_var
    
    def encode(self, x):
        h = self.encoder(x)
        z, mu, log_var = self.bottleneck(h)
        return z, mu, log_var
    
    def decode(self, z):
        z = self.fc3(z)
        return self.decoder(z)
    
    def forward(self, x):
        z, mu, log_var = self.encode(x)
        recon = self.decode(z)
        return recon, mu, log_var

vae = VAE()
if torch.cuda.is_available():
    vae.cuda()

In [7]:
vae

VAE(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Flatten(start_dim=1, end_dim=-1)
  )
  (fc1): Linear(in_features=6400, out_features=512, bias=True)
  (fc2): Linear(in_features=6400, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=6400, bias=True)
  (decoder): Sequential(
    (0): Unflatten(dim=1, unflattened_size=(128, 5, 10))
    (1): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (2): ReLU()
    (3): ConvTranspose2d(64

In [8]:
optimizer = optim.Adam(vae.parameters())

def loss_function(recon_x, x, mu, log_var):
    KLD = -0.5*torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
#     RMSE = torch.sqrt(F.mse_loss(recon_x, x))
    return KLD + BCE

In [9]:
def train(epoch):
    vae.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data.cuda()
        optimizer.zero_grad()

        recon_batch, mu, log_var = vae(data)
        loss = loss_function(recon_batch, data, mu, log_var)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 10 == 0:
            print('Train Epoch : {} [{}/{} ({:.0f}%)]\tLoss : {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item() / len(data)
            ))
    print('====> Epoch : {} Average loss : {:.4f}'.format(epoch, train_loss / len(train_loader.dataset)))

In [10]:
def val():
    vae.eval()
    val_loss = 0
    with torch.no_grad():
        for data in val_loader:
            data = data.cuda()
            recon, mu, log_var = vae(data)

            val_loss += loss_function(recon, data, mu, log_var).item()

    val_loss /= len(val_loader.dataset)
    print('====> Validation set loss : {:.4f}'.format(val_loss))

In [11]:
for epoch in range(1, 101):
    train(epoch)
    val()

====> Epoch : 1 Average loss : 657.9008
====> Validation set loss : 76.6534
====> Epoch : 2 Average loss : 281.8653
====> Validation set loss : 66.0030
====> Epoch : 3 Average loss : 259.3653
====> Validation set loss : 62.9675
====> Epoch : 4 Average loss : 251.4497
====> Validation set loss : 61.8803
====> Epoch : 5 Average loss : 245.9747
====> Validation set loss : 60.6391
====> Epoch : 6 Average loss : 243.6785
====> Validation set loss : 60.4416
====> Epoch : 7 Average loss : 241.1527
====> Validation set loss : 59.6593
====> Epoch : 8 Average loss : 238.8188
====> Validation set loss : 59.1496
====> Epoch : 9 Average loss : 236.8401
====> Validation set loss : 58.4667
====> Epoch : 10 Average loss : 234.4619
====> Validation set loss : 58.0401
====> Epoch : 11 Average loss : 232.6297
====> Validation set loss : 57.7525
====> Epoch : 12 Average loss : 230.8287
====> Validation set loss : 57.2666
====> Epoch : 13 Average loss : 228.8740
====> Validation set loss : 56.7122
====> Ep

====> Epoch : 36 Average loss : 213.1486
====> Validation set loss : 53.0986
====> Epoch : 37 Average loss : 212.5991
====> Validation set loss : 53.0336
====> Epoch : 38 Average loss : 212.4889
====> Validation set loss : 53.0268
====> Epoch : 39 Average loss : 212.1045
====> Validation set loss : 52.9475
====> Epoch : 40 Average loss : 211.9955
====> Validation set loss : 53.0098
====> Epoch : 41 Average loss : 211.7148
====> Validation set loss : 52.8033
====> Epoch : 42 Average loss : 211.8331
====> Validation set loss : 52.8782
====> Epoch : 43 Average loss : 211.4848
====> Validation set loss : 52.7659
====> Epoch : 44 Average loss : 211.3855
====> Validation set loss : 52.7497
====> Epoch : 45 Average loss : 211.0453
====> Validation set loss : 52.6394
====> Epoch : 46 Average loss : 211.0344
====> Validation set loss : 52.7297
====> Epoch : 47 Average loss : 210.9249
====> Validation set loss : 52.6491
====> Epoch : 48 Average loss : 210.7166
====> Validation set loss : 52.6554

====> Epoch : 72 Average loss : 208.6407
====> Validation set loss : 52.1840
====> Epoch : 73 Average loss : 208.6014
====> Validation set loss : 52.0833
====> Epoch : 74 Average loss : 208.4875
====> Validation set loss : 52.0858
====> Epoch : 75 Average loss : 208.4603
====> Validation set loss : 52.1224
====> Epoch : 76 Average loss : 208.5676
====> Validation set loss : 52.1942
====> Epoch : 77 Average loss : 208.4197
====> Validation set loss : 52.1395
====> Epoch : 78 Average loss : 208.2690
====> Validation set loss : 52.1288
====> Epoch : 79 Average loss : 208.5185
====> Validation set loss : 52.0219
====> Epoch : 80 Average loss : 208.2191
====> Validation set loss : 52.0385
====> Epoch : 81 Average loss : 208.0399
====> Validation set loss : 52.0869
====> Epoch : 82 Average loss : 208.0150
====> Validation set loss : 52.0288
====> Epoch : 83 Average loss : 207.9973
====> Validation set loss : 52.0621
====> Epoch : 84 Average loss : 208.1080
====> Validation set loss : 51.9373

In [12]:
for epoch in range(101, 201):
    train(epoch)
    val()

====> Epoch : 101 Average loss : 207.2832
====> Validation set loss : 51.9598
====> Epoch : 102 Average loss : 207.2567
====> Validation set loss : 51.8334
====> Epoch : 103 Average loss : 207.0853
====> Validation set loss : 51.8414
====> Epoch : 104 Average loss : 207.1214
====> Validation set loss : 51.8359
====> Epoch : 105 Average loss : 207.0699
====> Validation set loss : 51.8523
====> Epoch : 106 Average loss : 207.0765
====> Validation set loss : 51.8397
====> Epoch : 107 Average loss : 206.9614
====> Validation set loss : 51.7600
====> Epoch : 108 Average loss : 206.8117
====> Validation set loss : 51.7884
====> Epoch : 109 Average loss : 206.9434
====> Validation set loss : 51.7278
====> Epoch : 110 Average loss : 206.8136
====> Validation set loss : 51.7945
====> Epoch : 111 Average loss : 207.0770
====> Validation set loss : 51.7780
====> Epoch : 112 Average loss : 206.7364
====> Validation set loss : 51.8057
====> Epoch : 113 Average loss : 206.8839
====> Validation set l

====> Validation set loss : 51.6470
====> Epoch : 136 Average loss : 206.0463
====> Validation set loss : 51.6463
====> Epoch : 137 Average loss : 206.1201
====> Validation set loss : 51.6518
====> Epoch : 138 Average loss : 206.0579
====> Validation set loss : 51.6352
====> Epoch : 139 Average loss : 205.9351
====> Validation set loss : 51.6017
====> Epoch : 140 Average loss : 206.0922
====> Validation set loss : 51.6449
====> Epoch : 141 Average loss : 206.0331
====> Validation set loss : 51.5580
====> Epoch : 142 Average loss : 206.1020
====> Validation set loss : 51.5717
====> Epoch : 143 Average loss : 205.9491
====> Validation set loss : 51.6576
====> Epoch : 144 Average loss : 205.9642
====> Validation set loss : 51.6442
====> Epoch : 145 Average loss : 205.8519
====> Validation set loss : 51.6575
====> Epoch : 146 Average loss : 205.9078
====> Validation set loss : 51.6063
====> Epoch : 147 Average loss : 205.8508
====> Validation set loss : 51.5926
====> Epoch : 148 Average lo

====> Epoch : 170 Average loss : 205.5862
====> Validation set loss : 51.6117
====> Epoch : 171 Average loss : 205.6623
====> Validation set loss : 51.5970
====> Epoch : 172 Average loss : 205.3506
====> Validation set loss : 51.6094
====> Epoch : 173 Average loss : 205.3952
====> Validation set loss : 51.5280
====> Epoch : 174 Average loss : 205.4828
====> Validation set loss : 51.4914
====> Epoch : 175 Average loss : 205.3013
====> Validation set loss : 51.4275
====> Epoch : 176 Average loss : 205.3334
====> Validation set loss : 51.4124
====> Epoch : 177 Average loss : 205.2587
====> Validation set loss : 51.5940
====> Epoch : 178 Average loss : 205.3461
====> Validation set loss : 51.5103
====> Epoch : 179 Average loss : 205.2438
====> Validation set loss : 51.5028
====> Epoch : 180 Average loss : 205.2045
====> Validation set loss : 51.4648
====> Epoch : 181 Average loss : 205.2384
====> Validation set loss : 51.4714
====> Epoch : 182 Average loss : 205.1534
====> Validation set l

In [13]:
torch.save(vae.state_dict(), 'model/vae_mel_200.pt')

In [14]:
torch.save(vae, 'model/all_vae_mel_200.pt')

In [15]:
test_model = torch.load('model/all_vae_mel_200.pt')

In [19]:
test_model2 = VAE()

In [20]:
test_model2.load_state_dict(torch.load('model/vae_mel_200.pt'))

<All keys matched successfully>