In [None]:
"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
import logging
import os
import sys
import torch
from pathlib import Path
import hydra
import wandb
import torchaudio
import IPython.display as ipd


# Locate the project root, two directory levels above the working directory,
# and append it to sys.path so the `src` package imported below can be found.
cwd = Path().resolve()
prj_dir = str(cwd.parent.parent)
print(f'prj_dir:{prj_dir}')
sys.path.append(prj_dir)

from src.data.datasets import LrHrSet
from src.ddp import distrib
from src.evaluate import evaluate
from src.models import modelFactory
from src.utils import bold
from src.wandb_logger import _init_wandb_run
from hydra import initialize, compose

logger = logging.getLogger(__name__)

# Keys used to index into the loaded checkpoint package (see _load_model).
SERIALIZE_KEY_MODELS = 'models'
SERIALIZE_KEY_BEST_STATES = 'best_states'
SERIALIZE_KEY_STATE = 'state'


# Hydra keeps a process-wide singleton, and calling initialize() a second time in
# the same kernel raises an error. Clearing it first makes this cell safe to re-run.
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()
# The configuration directory is given as a relative path ('../../conf'), not the
# current folder.
initialize(config_path='../../conf')
args = compose(config_name='main_config')

# load model

In [None]:
def _load_model(args):
    """Build the generator and restore its weights from ``args.checkpoint_file``.

    :param args: configuration object. Reads ``args.experiment.model`` (name used
        in the log messages), ``args.checkpoint_file`` (path of the checkpoint)
        and ``args.continue_best`` (truthy: restore the best state, otherwise the
        last state). It is also forwarded to ``modelFactory.get_model``.
    :return: the generator, switched to eval mode and placed on the GPU when CUDA
        is available, otherwise left on the CPU.
    """
    model_name = args.experiment.model
    checkpoint_file = Path(args.checkpoint_file)
    model = modelFactory.get_model(args)['generator']

    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    # Tensors are mapped to the CPU first; the model is moved to its final
    # device once the weights are in place.
    package = torch.load(checkpoint_file, map_location='cpu')

    if args.continue_best:
        logger.info(bold(f'Loading model {model_name} from best state.'))
        state = package[SERIALIZE_KEY_BEST_STATES][SERIALIZE_KEY_MODELS]['generator'][SERIALIZE_KEY_STATE]
    else:
        logger.info(bold(f'Loading model {model_name} from last state.'))
        state = package[SERIALIZE_KEY_MODELS]['generator'][SERIALIZE_KEY_STATE]
    model.load_state_dict(state)

    # Fall back to the CPU instead of crashing on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # This loader is used for inference, so disable training-only behaviour
    # (dropout, batch-norm statistics updates) before handing the model out.
    return model.to(device).eval()

# Checkpoint to restore. Set the AERO_CHECKPOINT_FILE environment variable to use
# a different file without editing this cell; the paths below are the defaults.
# 16 kHz --> 24 kHz
default_checkpoint_file="/mnt/cephfs/hjh/train_record/super_resolution/aero/train_outputs/16-24/aero-nfft=512-hl=256/checkpoint.th"
# 4 kHz --> 16 kHz
# default_checkpoint_file="/mnt/cephfs/hjh/train_record/super_resolution/aero/train_outputs/4-16/aero-nfft=512-hl=256/checkpoint.th"
args.checkpoint_file=os.environ.get("AERO_CHECKPOINT_FILE", default_checkpoint_file)

model=_load_model(args)

print("lond model done!")

In [None]:
# x = torch.randn(size=(1, 1, 16000)).float().cuda()

# 16 kHz inputs (only the uncommented assignment is used)
# file="/mnt/cephfs/hjh/train_record/super_resolution/aero/dataset/vctk/rs_wav16k/p314/p314_412.wav"
file="/mnt/cephfs/hjh/common_dataset/tts/english/microsoft/wavs_16k/v0/en-US-AshleyNeural_1624587085207.wav"

# 4 kHz inputs
# file="/mnt/cephfs/hjh/train_record/super_resolution/aero/dataset/vctk/rs_wav4k/p347/p347_001.wav"
# file="/tmp/tts.wav"

x, sr = torchaudio.load(file)  # this recording is sampled at 16000 Hz
print("原始音频:")
ipd.display(ipd.Audio(x, rate=sr))
# Add a batch dimension -> [1, channels, samples] and put the input on the same
# device as the model so the cell also works when the model lives on the CPU.
x=x.unsqueeze(0).to(next(model.parameters()).device)
# Inference only: skip autograd bookkeeping to save memory.
with torch.no_grad():
    y=model(x)

# The model changes the number of samples, so the output must be played at the
# correspondingly scaled rate rather than a fixed 16000 Hz (assumes the last
# axis of y is time, as it is for x).
out_sr = int(round(sr * y.shape[-1] / x.shape[-1]))

print("超分后音频:")
print(y.size())
ipd.display(ipd.Audio(y.squeeze().detach().cpu(), rate=out_sr))

# test inverse spec

In [None]:
import numpy as np

import torch
from torch.nn import functional as F

from src.models.spec import spectro, ispectro

# STFT settings shared by spec() and ispec() below.
hop_length = 256  # hop size handed to spectro(); ispec() multiplies it by `scale`
nfft = 512  # FFT size handed to spectro()
scale = 1  # hop multiplier applied in ispec(); 1 leaves the hop unchanged
win_length = 512  # window length handed to both spectro() and ispectro()


def spec(x):
    """Compute the spectrogram of ``x`` and drop its last frequency entry.

    :param x: signal whose last axis is padded on the right with zeros until
        its length is a multiple of ``hop_length``.
    :return: output of ``spectro`` without the last entry along the
        second-to-last (frequency) axis.
    """
    remainder = x.shape[-1] % hop_length
    if remainder:
        x = F.pad(x, (0, hop_length - remainder))
    stft = spectro(x, nfft, hop_length, win_length=win_length)
    return stft[..., :-1, :]


def ispec(z):
    """Turn a spectrogram back into a waveform with ``ispectro``.

    :param z: spectrogram; one zero entry is appended at the end of its
        second-to-last axis (the entry that spec() removes) before inversion.
    :return: output of ``ispectro`` using a hop of ``int(hop_length * scale)``.
    """
    padded = F.pad(z, (0, 0, 0, 1))
    return ispectro(padded, int(hop_length * scale), win_length=win_length)


def move_complex_to_channels_dim(z):
    """Fold the real and imaginary parts of ``z`` into the channel axis.

    :param z: complex tensor of shape [Batch, Channels, Freq, TimeFrames]
    :return: real tensor of shape [Batch, Channels * 2, Freq, TimeFrames] where
        each input channel contributes its real part followed by its imaginary part
    """
    batch, channels, freqs, frames = z.shape
    # [B, C, Fr, T, 2] -> [B, C, 2, Fr, T]
    real_imag = torch.view_as_real(z).movedim(-1, 2)
    return real_imag.reshape(batch, 2 * channels, freqs, frames)


def convert_to_complex(x):
    """Combine a real/imaginary axis into a single complex tensor.

    :param x: real signal of shape [Batch, Channels, 2, Freq, TimeFrames]; index 0
        of the third axis holds the real part and index 1 the imaginary part
    :return: complex signal of shape [Batch, Channels, Freq, TimeFrames]
    """
    # view_as_complex needs the (real, imag) pair on a contiguous last axis.
    real_imag_last = x.permute(0, 1, 3, 4, 2).contiguous()
    return torch.view_as_complex(real_imag_last)


file="/mnt/cephfs/hjh/train_record/super_resolution/aero/dataset/vctk/rs_wav16k/p314/p314_412.wav"
x, sr = torchaudio.load(file)  # this recording is sampled at 16000 Hz
print("原始音频:")
ipd.display(ipd.Audio(x, rate=sr))
x=x.unsqueeze(0)  # add a batch dimension -> [1, channels, samples]

# Original sample count, used at the end to trim the padding added by spec().
length = x.shape[-1]

#----------------
# Extract the pseudo-complex representation
# (real and imaginary parts stacked on the channel axis)
#----------------
z = spec(x)
x = move_complex_to_channels_dim(z)
print("伪复数维度:",x.size())

#----------------
# Convert back to a complex spectrogram and invert it
#----------------
# Split the channel axis back into (channels, 2). A plain unsqueeze(1) only gives
# the [B, C, 2, Fr, T] layout for mono audio; this reshape also handles
# multi-channel files and is identical for the mono case.
x = x.reshape(x.shape[0], -1, 2, *x.shape[2:])
print("经过encoder-decoder后:",x.size())
x_spec_complex = convert_to_complex(x)
x = ispec(x_spec_complex)
x = x[..., :length].squeeze()

ipd.display(ipd.Audio(x, rate=sr))
