In [1]:
import argparse
import os
import numpy as np
import math
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Audio

from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import torchaudio
import torchaudio.transforms as T

from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import speechbrain as sb
from speechbrain.pretrained import EncoderDecoderASR

from models import ResnetGenerator, Discriminator, UnetGenerator, ContextEncoder
from models import weights_init_normal
from datasets import SpecDataset
from utils import normalize_spec, denormalize_spec, db_to_power, power_to_db
from utils import plot_waveform, plot_spectrogram, save_spectrogram

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
class OPT:
    cuda_id = 0
    my_id = 99
    generator_type = 'resnet'
    pretrain_dir = './noise_enhance/ckpt_cgan_resnet'
    use_context_embed = False
    n_steps = 10
    batch_size = 32
    lr = 0.0001
    b1 = 0.5
    b2 = 0.999
    n_cpu = 1
    channels = 1
    img_height = 128
    img_width = 128
    sample_interval = 1000
    verbose_interval = 100
    num_residual_blocks = 6
    gamma = 0.1
opt = OPT()

dir = './noise_enhance'
sample_dir = os.path.join(dir, f'cgan_samples_{opt.generator_type}')
ckpt_dir = os.path.join(dir, f'cgan_ckpt_{opt.generator_type}')
if opt.my_id is not None:
    sample_dir = f'{sample_dir}_{opt.my_id}'
    ckpt_dir = f'{ckpt_dir}_{opt.my_id}'
clean_train_dir = os.path.join(dir, 'data/clean_trainset')
clean_test_dir = os.path.join(dir, 'data/clean_testset')
noisy_train_dir = os.path.join(dir, 'data/noisy_trainset')
noisy_test_dir = os.path.join(dir, 'data/noisy_testset')
train_sample = os.path.join(noisy_train_dir, os.listdir(noisy_train_dir)[0])
test_sample = os.path.join(noisy_test_dir, os.listdir(noisy_test_dir)[0])

# os.makedirs(sample_dir, exist_ok=True)
# os.makedirs(ckpt_dir, exist_ok=True)
print(f'{dir}\t{sample_dir}\t{ckpt_dir}')
print(f'{clean_train_dir}\t{clean_test_dir}\t{noisy_train_dir}\t{noisy_test_dir}')
print(f'{train_sample}\t{test_sample}')

img_shape = (opt.channels, opt.img_height, opt.img_width)

cuda = True if torch.cuda.is_available() else False
device = f'cuda:{opt.cuda_id}' if torch.cuda.is_available() else 'cpu'

# Loss functions
adversarial_loss = torch.nn.MSELoss().to(device)

./noise_enhance	./noise_enhance/cgan_samples_resnet_99	./noise_enhance/cgan_ckpt_resnet_99
./noise_enhance/data/clean_trainset	./noise_enhance/data/clean_testset	./noise_enhance/data/noisy_trainset	./noise_enhance/data/noisy_testset
./noise_enhance/data/noisy_trainset/p244_343.wav	./noise_enhance/data/noisy_testset/p254_213.wav


In [10]:
# Initialize generator and discriminator
if opt.use_context_embed:
    encoder = ContextEncoder(device).to(device)
    context_loss = nn.L1Loss().to(device)
if opt.generator_type == 'unet':
    generator = UnetGenerator(input_nc=1, output_nc=1, num_downs=7, use_dropout=True).to(device)
else:
    generator = ResnetGenerator(
        input_nc=13 if opt.use_context_embed else 1,
        output_nc=1,
        ngf=64,
        norm_layer=nn.BatchNorm2d,
        use_dropout=False,
        n_blocks=opt.num_residual_blocks,
        padding_type='reflect'
    ).to(device)
discriminator = Discriminator(img_shape).to(device)

generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)
print('init weight')
if opt.pretrain_dir is not  None:
    # generator.load_state_dict(torch.load(os.path.join(opt.pretrain_dir, 'G_model.pt')))
    # discriminator.load_state_dict(torch.load(os.path.join(opt.pretrain_dir, 'D_model.pt')))
    print('Loaded model WEIGHTS')

# Configure data loader
dataset = SpecDataset(noisy_dir=noisy_train_dir, gt_dir=clean_train_dir, device=device)
dataloader = DataLoader(dataset, batch_size=opt.batch_size, shuffle=True)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
milestones = [int(opt.n_steps * 0.2), int(opt.n_steps * 0.4), int(opt.n_steps * 0.6), int(opt.n_steps * 0.8)]
scheduler_G = torch.optim.lr_scheduler.MultiStepLR(optimizer_G, milestones=milestones, gamma=opt.gamma)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

init weight
Loaded model WEIGHTS




In [None]:
def sample_image(steps_done, gt=False):
    with torch.no_grad():
        for s, sample in zip(['test', 'train'], [test_sample, train_sample]):
            input_img, input_wav = dataset.getitem_helper(sample)
            if gt:
                save_spectrogram(input_img[0].detach().cpu().numpy(), os.path.join(sample_dir, os.path.basename(sample).replace('.wav', '.png')))
            else:
                gen_imgs = generator(input_img.unsqueeze(0).to(device))
                gen_imgs = gen_imgs.squeeze(0).data.detach().cpu().numpy()[0] # get 2D
                gen_imgs, _ = normalize_spec(gen_imgs)
                save_spectrogram(gen_imgs, os.path.join(sample_dir, f"{s}_step%d.png" % steps_done))

In [30]:
# from utils import plot_spectrogram, power_to_db
# max_len = 64000 + 64*40
# x = dataset[2]
# spec, w = x['noisy']
# print((203*1024) // (128*512))
# w = w[:max_len]
# print(spec.shape, w.shape)
# emb = encoder(torch.stack([w, w, w], dim=0))
# # o = generator(spec.unsqueeze(0))
# emb.shape

3
torch.Size([1, 128, 128]) torch.Size([65280])


torch.Size([3, 203, 1024])

In [28]:
n_epochs = (opt.n_steps // len(dataloader)) + 1
curr_step = 0
isContinue = True
G_loss = 0.0
D_loss = 0.0
for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):
        break
    break

In [8]:
# count step
curr_step += 1
isContinue = curr_step < opt.n_steps
# if not isContinue: break
#  train model
(input_img, input_wav), (gt_img, gt_wav) = batch['noisy'], batch['gt']
batch_size = input_img.shape[0]

# Adversarial ground truths
valid = Variable(Tensor(np.ones((batch_size, *discriminator.output_shape))), requires_grad=False).to(device)
fake = Variable(Tensor(np.zeros((batch_size, *discriminator.output_shape))), requires_grad=False).to(device)
            
# Configure input
input_img = Variable(input_img.type(FloatTensor))
gt_img = Variable(gt_img.type(FloatTensor))

In [8]:
# from torchaudioTrans import inverse_spectrogram, InverseSpectrogram, InverseMelScale

# my_inv = InverseMelScale(
#     device=device,
#     sample_rate=dataset.sample_rate,
#     n_stft=dataset.n_stft,
#     n_mels=dataset.n_mels,
#     mel_scale='htk',
#     max_iter=1000)

In [9]:
with torch.no_grad():
    emb = encoder(input_wav) # remove 1 extra dim
    print(emb.shape)
    batch_size, h, w = emb.shape
    xtimes = int((h*w) // (128**2))
    emb = emb.view(batch_size, -1)[..., :xtimes*(128**2)]
    emb = emb.view(emb.shape[0], -1, 128, 128)
    inp = torch.cat([input_img, emb], 1)
    print(inp.shape)
    gen_imgs = generator(inp)
    print(gen_imgs.shape, input_img.shape, emb.shape, input_wav.shape)

torch.Size([32, 203, 1024])
torch.Size([32, 13, 128, 128])
torch.Size([32, 13, 128, 128]) torch.Size([32, 1, 128, 128]) torch.Size([32, 12, 128, 128]) torch.Size([32, 65280])


In [15]:
x = torch.rand(8, 1, 32, 32)
y = torch.rand(8, 1, 32, 32)
torch.cat([x, y], 1).shape

torch.Size([8, 2, 32, 32])

In [8]:
# -----------------
#  Train Generator
# -----------------
optimizer_G.zero_grad()

# Generate a batch of images
gen_imgs = generator(input_img)
print('gen_imgs.shape', gen_imgs.shape)
print('gt_wav.shape', gt_wav.shape)
if opt.use_context_embed:
    # reconstruct wav
    # gen_wavs = dataset.spec_to_audio(gen_imgs)
    spec = gen_imgs
    spec[0, 0, 0, 0] = -1.
    spec = torch.clamp(spec, min=0.0, max=1.0)
    spec = denormalize_spec(spec, dataset.spec_min, dataset.spec_max)
    spec = db_to_power(spec)
    print(spec.device)
    # spec_ = inv_transform(spec)
    # gen_wavs = grifflim_transform(spec_)

    # lim_len = min(gen_wavs.shape[2], gt_wav.shape[2])
    # gen_wavs_embed = encoder(gen_wavs.squeeze(1)[..., :lim_len])
    # gt_wavs_embed = encoder(gt_wav.squeeze(1)[..., :lim_len])
    # print(gen_wavs_embed.shape, gt_wavs_embed.shape)
    # ct_loss = context_loss(gen_wavs_embed.view(-1), gt_wavs_embed.view(-1))

    # # Loss measures generator's ability to fool the discriminator
    # validity = discriminator(gen_imgs)
    # g_loss = adversarial_loss(validity, valid)

gen_imgs.shape torch.Size([32, 1, 128, 128])
gt_wav.shape torch.Size([32, 1, 65280])
cuda:0


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_mm)

In [63]:
ct_loss, g_loss

(tensor(0.1469, device='cuda:0'),
 tensor(2.7578, device='cuda:0', grad_fn=<MseLossBackward0>))

In [None]:
print(input_img.max(), input_img.min(), gen_imgs.max(), gen_imgs.min())
plot_spectrogram(input_img[0][0].detach().clone().cpu())
plot_spectrogram(gen_imgs[0][0].detach().clone().cpu())

## DRAFT FOR CODE TESTING

In [23]:
x, x1 = dataset[0], dataset[1]
(spec_n0, wav_n0), (spec_c0, wav_c0) = x['noisy'], x['gt']
(spec_n1, wav_n1), (spec_c1, wav_c1) = x1['noisy'], x1['gt']

wav_n0_ = dataset.spec_to_audio(torch.cat([spec_n0, spec_n1]))
print(wav_n0_.shape, wav_n0_.max(), wav_n0_.min())
# plot_waveform(wav_n0.cpu(), sr=dataset.sample_rate)
# plot_waveform(wav_n0_.cpu(), sr=dataset.sample_rate)
# Audio(wav_n0_.cpu(), rate=dataset.sample_rate, normalize=True)

torch.Size([2, 65024]) tensor(0.0132, device='cuda:0') tensor(-0.0112, device='cuda:0')


In [17]:
dataset.is_normal_spec = False
min_val, max_val = dataset.get_minmax_spec(True)
dataset.is_normal_spec = True
min_val, max_val

get_minmax_spectrogram:   0%|          | 0/11000 [00:00<?, ?it/s]

get_minmax_spectrogram: 100%|██████████| 11000/11000 [00:11<00:00, 995.01it/s] 


(tensor(-96.6628, device='cuda:0'), tensor(6.0352, device='cuda:0'))

In [None]:
# x_norm = dataset.normalize(x).detach().clone().cpu().numpy()[0]
x1 = dataset[4]
print(x1['noisy'][0].max(), x1['noisy'][0].min())
plot_spectrogram(x1['noisy'][0].detach().clone().cpu().numpy()[0])
plot_spectrogram(x1['gt'][0].detach().clone().cpu().numpy()[0])

In [19]:
for i, batch in enumerate(dataloader):
    break
input_img, gt_img = batch['noisy'][0], batch['gt'][0]
batch_size = input_img.shape[0]
batch_size, input_img.shape, gt_img.shape

(8, torch.Size([8, 1, 128, 128]), torch.Size([8, 1, 128, 128]))

In [16]:
valid = Variable(Tensor(np.ones((batch_size, *discriminator.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((batch_size, *discriminator.output_shape))), requires_grad=False)
print('valid.shape', valid.shape)
print('fake.shape', fake.shape)

valid.shape torch.Size([8, 1, 8, 8])
fake.shape torch.Size([8, 1, 8, 8])


In [17]:
# Configure input
input_img = Variable(input_img.type(FloatTensor))
gt_img = Variable(gt_img.type(FloatTensor))
print('input_img.shape', input_img.shape)
print('gt_img.shape', gt_img.shape)

input_img.shape torch.Size([8, 1, 128, 128])
gt_img.shape torch.Size([8, 1, 128, 128])


In [24]:
optimizer_G.zero_grad()
# Generate a batch of images
# gen_imgs = generator(input_img)
x = torch.rand(1, 1, 96, 128).to(device)
gen_imgs = generator(x)
print('gen_imgs.shape', gen_imgs.shape)

gen_imgs.shape torch.Size([1, 1, 96, 128])


In [19]:
# Loss measures generator's ability to fool the discriminator
validity = discriminator(gen_imgs)
print('validity.shape', validity.shape)

validity.shape torch.Size([1, 1, 8, 8])


In [20]:
g_loss = adversarial_loss(validity, valid)
g_loss.backward()
optimizer_G.step()

  return F.mse_loss(input, target, reduction=self.reduction)


In [21]:
# ---------------------
#  Train Discriminator
# ---------------------

optimizer_D.zero_grad()

# Loss for real images
validity_real = discriminator(gt_img)
d_real_loss = adversarial_loss(validity_real, valid)
print('validity_real.shape', validity_real.shape)

validity_real.shape torch.Size([8, 1, 8, 8])


In [None]:
# Loss for fake images
validity_fake = discriminator(gen_imgs.detach())
d_fake_loss = adversarial_loss(validity_fake, fake)

In [None]:
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2

d_loss.backward()
optimizer_D.step()

In [None]:
def sample_image(steps_done):
    # Sample from test set
    input_img = dataset.getitem_helper(test_sample)
    gen_imgs = generator(input_img.unsqueeze(0).to(device))
    # save_image(gen_imgs.squeeze(0).data, f"./{sample_dir}/test_step%d.png" % steps_done, normalize=True)
    # Sample from train set
    input_img = dataset.getitem_helper(train_sample)
    gen_imgs = generator(input_img.unsqueeze(0).to(device))
    # save_image(gen_imgs.squeeze(0).data, f"./{sample_dir}/train_step%d.png" % steps_done, normalize=True)
sample_image(0)

In [None]:
dataloader_mnist = torch.utils.data.DataLoader(
    datasets.MNIST(
        "./data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(128), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
for batch in dataloader_mnist:
    break

In [None]:
cuda