In [None]:
import torch
import pytorch_lightning as pl
from omegaconf import OmegaConf
import sys

sys.path.append(r"../../")
from circe.models.LightningClassifier import LightningClassifier

# Model configuration for the classifier, read from the training YAML file.
cfg_model = OmegaConf.load('../training/conf/model/hf-gpt.yaml')

# Checkpoint whose "state_dict" entry holds the trained weights.
ckpt_path = "../../../models-hfGPT-specvqgan/lightning_logs/version_28/checkpoints/epoch=94-step=2375.ckpt"
model = LightningClassifier(cfg=cfg_model)
# Called before the weights are loaded; presumably this instantiates the
# submodules the state dict is copied into -- TODO confirm in LightningClassifier.
model.configure_sharded_model()
# map_location="cpu": the model is never moved off the CPU in this notebook, so
# the saved tensors are remapped to host memory. This lets a checkpoint written
# on a GPU load on a CPU-only machine and avoids staging the weights on the GPU.
# NOTE(review): torch.load unpickles the file -- only open trusted checkpoints.
model.load_state_dict(torch.load(ckpt_path, map_location="cpu")["state_dict"])
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

In [None]:
# Sampling configuration: everything a reader may want to tune lives here.
out_dir = 'out'  # output directory name (no cell visible in this section reads it)
num_samples = 10  # number of samples to draw
max_new_tokens = 500  # number of tokens generated in each sample
temperature = 1.0  # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions
top_k = 40  # retain only the top_k most likely tokens, clamp others to have 0 probability
seed = 1337
# Ask for CUDA only when it is actually present, so that `device_type` does not
# request a CUDA autocast context on a CPU-only host.
device = "cuda" if torch.cuda.is_available() else "cpu"
device_type = 'cuda' if 'cuda' in device else 'cpu'  # for later use in torch.autocast

In [None]:
# Fix the random state so the sampled sequences are repeatable between runs.
# `seed` comes from the configuration cell.
torch.manual_seed(seed)
# Also seed the CUDA generator explicitly. This call stays last so the cell
# shows no output (it evaluates to None).
torch.cuda.manual_seed(seed)

In [None]:
from contextlib import nullcontext

# Prompt token ids. The list is empty in the committed notebook; fill it with
# the conditioning sequence before running.
start_ids = []
# Keep only whole groups of 5 ids (presumably 5 codes per time step, matching
# the f=5 regrouping in the decoding cell -- TODO confirm).
start_ids = start_ids[:(len(start_ids) // 5) * 5]
# Add a leading batch axis and create the prompt on the same device as the
# model weights.
x = torch.tensor(start_ids, dtype=torch.long, device=next(model.parameters()).device)[None, ...]

# run generation
# fp16 autocast unless device_type is 'cpu', in which case the context is a no-op.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=torch.float16)
with torch.no_grad(), ctx:
    for sample_idx in range(num_samples):
        # Use the values from the configuration cell instead of literals, so
        # editing `temperature` / `top_k` there actually changes the sampling.
        y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
        print(y[0].tolist())
        print('---------------')


## SpecVQGAN inference

In [None]:
import os
import sys
from pathlib import Path
import soundfile
import torch
import IPython
import matplotlib.pyplot as plt
from einops import rearrange
from importlib import reload

sys.path.append(r"../../")
from circe.specvqgan.feature_extraction.demo_utils import (calculate_codebook_bitrate,
                                           extract_melspectrogram,
                                           get_audio_file_bitrate,
                                           get_duration,
                                           load_neural_audio_codec)
from circe.specvqgan.sample_visualization import tensor_to_plt
from torch.utils.data.dataloader import default_collate

# Device used for the SpecVQGAN codec and vocoder. The default keeps everything
# on the CPU; set USE_GPU_FOR_CODEC = True to use CUDA when it is available.
# A single assignment replaces an auto-detected value that was immediately
# overwritten and therefore never took effect.
USE_GPU_FOR_CODEC = False
device = "cuda" if USE_GPU_FOR_CODEC and torch.cuda.is_available() else "cpu"

In [None]:
# Where the pretrained codec lives: the log folder and the run name inside it.
log_dir = '../../../Circe/vggsound'
model_name = '2021-05-19T22-16-54_vggsound_codebook'

# loading the models might take a few minutes
# NOTE(review): `model` is rebound here -- from this cell on it refers to the
# codec, no longer to the classifier loaded at the top of the notebook.
config, model, vocoder = load_neural_audio_codec(model_name, log_dir, device)

In [None]:
N_FREQ_CODES = 5  # size of the "f" axis in the rearrange pattern below
CODE_DIM = 256    # last entry of the shape handed to decode_code


def _decode_token_ids(codec, token_ids, device, n_freq=N_FREQ_CODES, code_dim=CODE_DIM):
    """Regroup a flat list of ids and decode it with the codec.

    Args:
        codec: object exposing ``decode_code(codes, shape=...)``; here the
            model returned by ``load_neural_audio_codec``.
        token_ids: flat sequence of ids laid out as ``(t f)``. Its length is
            expected to be a multiple of ``n_freq``.
        device: device the id tensor is moved to.
        n_freq: size of the ``f`` axis used when regrouping the ids.
        code_dim: last entry of the ``shape`` argument passed to the codec.

    Returns:
        Tuple ``(codes, decoded)``: the ids rearranged to shape ``(f * t, 1)``
        and whatever ``codec.decode_code`` returns for them.
    """
    codes = torch.tensor(token_ids).to(device)
    # Reorder from time-major to frequency-major and add a trailing axis of size 1.
    codes = rearrange(codes, "(t f) -> (f t) 1", f=n_freq)
    with torch.no_grad():
        decoded = codec.decode_code(codes.squeeze(), shape=(codes.shape[1], n_freq, -1, code_dim))
    return codes, decoded


# Ids to decode. The list is empty in the committed notebook; presumably one of
# the sequences printed by the generation cell is pasted here -- TODO confirm.
generated_ids = []
info, xrec = _decode_token_ids(model, generated_ids, device)
# The prompt ids are decoded the same way so both can be compared afterwards.
orig_audio, x = _decode_token_ids(model, start_ids, device)

print('Compressed representation (it is all you need to recover the audio):')
print(info.reshape(N_FREQ_CODES, -1).shape)

In [None]:
# Save and Display
# Bind the squeezed spectrograms to new names instead of overwriting x / xrec:
# rebinding would strip one more axis every time the cell is re-run whenever
# the new leading axis also has size 1 (shapes not visible here -- TODO confirm).
spec_cond = x.squeeze(0)
spec_rec = xrec.squeeze(0)
# specs are in [-1, 1], making them in [0, 1]
wav_x = vocoder((spec_cond + 1) / 2).squeeze().detach().cpu().numpy()
wav_xrec = vocoder((spec_rec + 1) / 2).squeeze().detach().cpu().numpy()
print(wav_xrec.shape)
# Creating a temp folder which will hold the results
tmp_dir = '/tmp/neural_audio_codec'
os.makedirs(tmp_dir, exist_ok=True)
# Save paths
x_save_path = Path(tmp_dir) / 'specvqgan_cond.wav'
xrec_save_path = Path(tmp_dir) / 'specvqgan_generated.wav'
# Sample rate the codec was configured with; used when writing the wav files.
model_sr = config.data.params.sample_rate

# Same save-and-display sequence for the conditioning and the reconstruction.
for label, wav, save_path, spec in (
    ('Conditioning', wav_x, x_save_path, spec_cond),
    ('Reconstructed', wav_xrec, xrec_save_path, spec_rec),
):
    # Save as 16-bit PCM
    soundfile.write(save_path, wav, model_sr, 'PCM_16')
    # Display
    print(f'{label} audio generated:')
    IPython.display.display(IPython.display.Audio(str(save_path)))
    print(f'{label} Spectrogram:')
    IPython.display.display(tensor_to_plt(spec, flip_dims=(2,)))
    # Close the figure after it has been displayed so figures do not pile up.
    plt.close()