In [1]:
from datasets import get_dataset
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import time
from configs.get_configs import get_config


class DotDict(dict):
    def __getattr__(self, key):
        return self[key]
    def __setattr__(self, key, val):
        if key in self.__dict__:
            self.__dict__[key] = val
        else:
            self[key] = val

In [2]:
import torch
from evaluation import evaluate
torch.cuda.device_count()

1

In [3]:
import torch
import gc
torch.cuda.empty_cache()
gc.collect()

0

In [4]:
%%capture
args = DotDict()
args.conf = "128_deep"
args.test = False
args.DDP = False
args.workdir = "pfgm_128_deep_v2"
args.eval_folder = "eval"
args.sampling = True
config = get_config(args)
args.config = config
#args.config.sampling.ode_solver = 'torchdiffeq'
args.config.sampling.ode_solver = 'rk45' # rk45, improved_euler
args.config.sampling.ckpt_number = 500000
#args.config.sampling.N = 100
args.config.sampling.z_max = 200
args.config.eval.num_samples = 32
args.config.eval.batch_size = 32

## Libs for Sampling

In [5]:
import datasets
from models import utils as mutils
from models.ema import ExponentialMovingAverage
from models.ema import ExponentialMovingAverage
import losses

from configs.default_audio_configs import get_mels_64, get_mels_128

from evaluation import sampling
from evaluation.utils.mel_to_wav import convert
import datasets
import methods
from utils.checkpoint import save_checkpoint, restore_checkpoint
import os, logging

In [6]:
config = args.config
workdir = args.workdir
eval_folder = args.eval_folder
eval_dir = os.path.join(workdir, eval_folder)
os.makedirs(eval_dir, exist_ok=True)

# setup logger
gfile_stream = open(os.path.join(args.workdir, 'stdout_eval.txt'), 'w')
handler = logging.StreamHandler(gfile_stream)
formatter = logging.Formatter('%(levelname)s - %(filename)s - %(asctime)s - %(message)s')
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel('INFO')


# Create data normalizer and its inverse
scaler = datasets.get_data_scaler(config)
inverse_scaler = datasets.get_data_inverse_scaler(config)

# Initialize model
net = mutils.create_model(args)
print("Created Model")
optimizer, scheduler = losses.get_optimizer(config, net.parameters())
ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)
state = dict(optimizer=optimizer, model=net, ema=ema, scheduler=scheduler, step=0)

checkpoint_dir = os.path.join(workdir, "checkpoints")

torch.cuda.empty_cache()
gc.collect()

# Setup methods
if config.training.sde.lower() == 'poisson':
    sde = methods.Poisson(args=args)
    sampling_eps = config.sampling.z_min
    print("--- sampling eps:", sampling_eps)
else:
    raise NotImplementedError(f"Method {config.training.sde} unknown.")


# Wait if the target checkpoint doesn't exist yet
torch.manual_seed(config.seed)
np.random.seed(config.seed)

if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config.seed)

if config.training.sde == 'poisson':
    if config.sampling.ckpt_number > 0:
        ckpt_filename = os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(config.sampling.ckpt_number))
        ckpt_path = os.path.join(checkpoint_dir, f'checkpoint_{config.sampling.ckpt_number}.pth')
    else:
        raise ValueError("Please provide a ckpt_number!")

if not os.path.exists(ckpt_filename):
    print(f"{ckpt_filename} does not exist! Loading from meta-checkpoint")
    ckpt_filename = os.path.join(checkpoint_dir, os.pardir, 'checkpoints-meta', 'checkpoint.pth')
    if not os.path.exists(ckpt_filename):
        print("No checkpoints-meta")

# Wait for 2 additional mins in case the file exists but is not ready for reading
print("Loading from ", ckpt_path)
try:
    state = restore_checkpoint(ckpt_path, state, map_location=config.device)
    print("State Loaded")
except Exception as e:
    print("Loading Failed!")
    print(e)
    time.sleep(60)
    try:
        state = restore_checkpoint(ckpt_path, state, map_location=config.device)
    except Exception as e:
        time.sleep(120)
        state = restore_checkpoint(ckpt_path, state, map_location=config.device)

        
ckpt = config.sampling.ckpt_number
ema.copy_to(net.parameters())

# Build the sampling function when sampling is enabled
if config.eval.enable_sampling:
    sampling_shape = (config.eval.batch_size,
                      config.data.num_channels,
                      config.data.image_height, config.data.image_width)
    sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps, net)
    
print("DONE!")

BUILDING MODEL...
MODEL BUILT!
Created Model
--- sampling eps: 0.001
Loading from  pfgm_128_deep_v2/checkpoints/checkpoint_500000.pth
State Loaded
DONE!


In [7]:
import gc
torch.cuda.empty_cache()
gc.collect()

12

In [8]:
num_sampling_rounds = config.eval.num_samples // config.eval.batch_size + 1
# Directory to save samples. Different for each host to avoid writing conflicts
this_sample_dir = os.path.join(eval_dir, f"ckpt_{ckpt}")
audio_dir = os.path.join(this_sample_dir,f"audio_{ckpt}")
os.makedirs(this_sample_dir, exist_ok=True)
os.makedirs(audio_dir, exist_ok=True)

torch.backends.cudnn.benchmark=True
net.eval()

print(f"Sampling for {num_sampling_rounds} rounds...")
start = time.time()
total_samples = 0
for r in range(num_sampling_rounds):
    samples, n = sampling_fn(net)
    total_samples += samples.shape[0]
    print(f"Round {r} nfe={n}")
    break
stop = time.time()

Sampling for 2 rounds...
 lsoda--  rwork length needed, lenrw (=i1), exceeds lrw (=i2)
      in above message,  i1 =  16777236   i2 =   9437206
Round 0 nfe=0




In [14]:
total = stop - start
per_sample = total / total_samples
print("Took: ", total, " seconds for ", total_samples, "\nAverage of ", per_sample)

Took:  140.98982501029968  seconds for  32 
Average of  4.405932031571865


In [15]:
from librosa.feature.inverse import db_to_power
from librosa.feature.inverse import mel_to_audio

In [16]:
from configs.default_audio_configs import get_mels_128

In [17]:
spec_conf = get_mels_128()
sample_rate = spec_conf.sample_rate
hop_length = spec_conf.hop_length
nfft = spec_conf.nfft

In [18]:
mel_dat = samples
mel_data = mel_dat.squeeze().cpu().numpy()


mel_data /= mel_data.max()
mel_data *= 80
mel_data -= 80

mel_data = db_to_power(mel_data)
audio = mel_to_audio(
    M=mel_data,
    sr=sample_rate,
    n_fft=nfft,
    hop_length=hop_length,
    win_length=hop_length * 4,
    center=True,
    power=1,
    n_iter=32,
    fmin=20,
    fmax=sample_rate / 2.0,
    pad_mode="reflect",
    norm='slaney',
    htk=True
)
audio /= max(audio.max(), -audio.min())

  return f(*args, **kwargs)


In [20]:
import IPython.display as ipd
ipd.Audio(audio[0], rate=sample_rate) # load a local WAV file

## Full evaluation pipeline

In [None]:
sr = 16000
mel_args = {
      'sample_rate': sr,
      'win_length': 256 * 4,
      'hop_length': 256,
      'n_fft': 1024,
      'f_min': 20.0,
      'f_max': sr / 2.0,
      'n_mels': 80,
      'power': 1.0,
      'normalized': True,
  }
spectrogram = torchaudio.transforms.MelSpectrogram(**mel_args)
audio, sr = torchaudio.load('SpeechCommands/speech_commands_v0.02/one/0a2b400e_nohash_0.wav')
mel = spectrogram(audio)

In [None]:
start_diff = time.time()
for _ in range(32):
    audio, sample_rate = diffwave_predict(spectrogram, model_dir, fast_sampling=True)
stop_diff = time.time()

In [None]:
print(stop_diff-start_diff)

In [None]:
import IPython.display as ipd
ipd.Audio(audio[0].cpu().numpy(), rate=sample_rate) # load a local WAV file

In [None]:
start = time.time()
evaluate(args)
stop = time.time()

In [None]:
print(stop-start)

In [None]:
for i, item in enumerate(trainds):
    if i>1: break
    print(item.max(), item.min())

In [None]:
for i, (item, path) in enumerate(trainds):
    if i>25: break
    spec = item[0][0].numpy()
    print(spec.min(), spec.max(), spec.mean())
    print(path[0])
    plt.figure()
    plt.imshow(spec)
    plt.show()
    plt.figure()
    plt.imshow(spec_demean)
    plt.show()

In [None]:
# Calculate norms on CIFAR10 for comparison
import torchvision
cf10 = torchvision.datasets.CIFAR10(root='.', download=True)

norms_2_cf = []
norms_cf = []
for item in tqdm(cf10):
    pic = np.array(item[0]) / 255.0
    norm_2 = np.sqrt(pic.ravel()**2)
    norms_2_cf.append(norm_2)
    norm = pic.ravel()**2
    norms_cf.append(norm)

In [None]:
norms_cf = np.array(norms_cf)
norms_2_cf = np.array(norms_2_cf)

print(norms_cf.sum())

In [None]:
norms_2 = []
norms = []
for batch in tqdm(trainds):
    for item in batch:
        #scale to 0-1
        item += 1
        item /= 2.0
        norm = item.numpy().ravel()**2
        norms.append(norm)
        norm_2 = np.sqrt(item.numpy().ravel()**2)
        norms_2.append(norm_2)

In [None]:
gc.collect()

In [None]:
norms = np.array(norms)

In [None]:
cfm = norms_cf.mean()
nm = norms.mean()

In [None]:
nm/cfm * 32

In [None]:
norms.sum()

In [None]:
156607710 / 44138412

In [None]:
print(np.array(norms).mean(), np.array(norms_2).mean())

In [None]:
0.007212365 / .2873594542857112 * 32

In [None]:
.032030698 / .4733630004850902