In [1]:
"""
This code is based on Facebook's HDemucs code: https://github.com/facebookresearch/demucs
"""
import logging
import os
import sys
import torch
from pathlib import Path
import hydra
import wandb
import torchaudio
import IPython.display as ipd


# Locate the directory two levels above the current working directory and put it on
# sys.path so that the project-local `src` package imported below can be found.
cwd = Path().resolve()
prj_dir = str(cwd.parent.parent)  # kept as str: sys.path entries are plain strings
print(f'prj_dir:{prj_dir}')
# Guard against duplicates so re-running this cell does not keep growing sys.path.
if prj_dir not in sys.path:
    sys.path.append(prj_dir)

from src.data.datasets import LrHrSet
from src.ddp import distrib
from src.evaluate import evaluate
from src.models import modelFactory
from src.utils import bold
from src.wandb_logger import _init_wandb_run
from hydra import initialize, compose

logger = logging.getLogger(__name__)

# Keys used by _load_model to navigate the checkpoint dictionary returned by torch.load.
SERIALIZE_KEY_MODELS = 'models'
SERIALIZE_KEY_BEST_STATES = 'best_states'
SERIALIZE_KEY_STATE = 'state'


# Compose the configuration from the `conf` directory two levels above this notebook.
# `initialize` is used as a context manager so the global Hydra state is restored on exit;
# a bare `initialize(...)` call leaves Hydra initialised and fails when the cell is re-run.
with initialize(config_path='../../conf'):
    args = compose(config_name='main_config')

prj_dir:/data1/hjh/pycharm_projects/othergithub_prjs/super_resolution/aero




# load model

In [2]:
def _load_model(args):
    """Build the generator described by ``args`` and restore its weights from a checkpoint.

    :param args: configuration object. This function reads ``args.experiment.model``
        (used for logging only), ``args.checkpoint_file`` (path of the checkpoint) and
        ``args.continue_best`` (truthy: load the best recorded state, falsy: the last one).
        The whole object is also forwarded to ``modelFactory.get_model``.
    :return: the generator with the checkpoint weights loaded. It is moved to the GPU
        when CUDA is available and left on the CPU otherwise.
    """
    model_name = args.experiment.model
    checkpoint_file = Path(args.checkpoint_file)
    model = modelFactory.get_model(args)['generator']

    # Always deserialize onto the CPU first; the model is moved to its final device below.
    # NOTE(review): torch.load unpickles the file - only use it on trusted checkpoints.
    package = torch.load(checkpoint_file, map_location='cpu')

    if args.continue_best:
        logger.info(bold(f'Loading model {model_name} from best state.'))
        generator_entry = package[SERIALIZE_KEY_BEST_STATES][SERIALIZE_KEY_MODELS]['generator']
    else:
        logger.info(bold(f'Loading model {model_name} from last state.'))
        generator_entry = package[SERIALIZE_KEY_MODELS]['generator']
    model.load_state_dict(generator_entry[SERIALIZE_KEY_STATE])

    # Fall back to the CPU instead of crashing on hosts without CUDA.
    return model.cuda() if torch.cuda.is_available() else model

# Point the composed config at the checkpoint to restore, then build and load the generator.
# NOTE(review): absolute, machine-specific path - adjust it when running elsewhere.
args.checkpoint_file = "/mnt/cephfs/hjh/train_record/super_resolution/aero/train_outputs/16-24/aero-nfft=512-hl=256/checkpoint.th"

model = _load_model(args)

print("load model done!")

lond model done!


In [3]:
# Run the loaded generator on one low-resolution recording and play input and output.
file="/mnt/cephfs/hjh/train_record/super_resolution/aero/dataset/vctk/rs_wav16k/p314/p314_412.wav"
x, sr = torchaudio.load(file)  # audio sr = 16000
print("原始音频:")
ipd.display(ipd.Audio(x, rate=sr))

# Follow the model's device instead of assuming CUDA.
device = next(model.parameters()).device
x = x.unsqueeze(0).to(device)  # add a batch dim: [1, 1, number] for a mono file

# Inference only: eval mode puts any dropout / batch-norm layers the generator may
# contain into inference behaviour, and no_grad avoids building the autograd graph.
model.eval()
with torch.no_grad():
    y = model(x)

print("超分后音频:")
print(y.size())
# NOTE(review): the output rate is hard-coded; presumably it matches the "16-24"
# checkpoint selected above - confirm against the experiment config.
ipd.display(ipd.Audio(y.squeeze().detach().cpu(), rate=24000))

原始音频:


------encoder_x:torch.Size([1, 2, 256, 267])
------decoder_x:torch.Size([1, 1, 2, 256, 267])
超分后音频:
torch.Size([1, 1, 67666])


# Test inverse spec (STFT → inverse STFT round trip)

In [10]:
import numpy as np

import torch
from torch.nn import functional as F

from src.models.spec import spectro, ispectro

# STFT settings shared by spec() and ispec() below.
nfft = 512          # FFT size handed to spectro()
win_length = 512    # window length handed to spectro() / ispectro()
hop_length = 256    # hop between frames; spec() also pads inputs to a multiple of it
scale = 1           # multiplier applied to hop_length inside ispec()


def spec(x):
    """Return the spectrogram of ``x`` computed by ``spectro`` without its last bin.

    The last axis of ``x`` is first zero-padded on the right up to the next multiple of
    the module-level ``hop_length``. ``spectro`` is then called with the module-level
    ``nfft``, ``hop_length`` and ``win_length``, and the final index along the
    second-to-last axis of its result (presumably the frequency axis) is dropped.

    :param x: input tensor whose last axis is padded and transformed.
    :return: the trimmed result of ``spectro``.
    """
    leftover = x.shape[-1] % hop_length
    if leftover != 0:
        # Right-pad so the number of samples is an exact multiple of the hop.
        x = F.pad(x, (0, hop_length - leftover))

    full = spectro(x, nfft, hop_length, win_length=win_length)
    return full[..., :-1, :]


def ispec(z):
    """Invert a spectrogram produced by ``spec`` using ``ispectro``.

    One zero entry is appended at the end of the second-to-last axis of ``z`` (restoring
    the bin that ``spec`` removes) before calling ``ispectro`` with a hop of
    ``int(hop_length * scale)`` and the module-level ``win_length``.

    :param z: spectrogram tensor, laid out like the output of ``spec``.
    :return: whatever ``ispectro`` returns for the padded spectrogram.
    """
    # Pad spec is (last-axis left, last-axis right, 2nd-to-last before, 2nd-to-last after).
    restored = F.pad(z, (0, 0, 0, 1))
    return ispectro(restored, int(hop_length * scale), win_length=win_length)


def move_complex_to_channels_dim(z):
    """Fold the real/imaginary parts of a complex tensor into its channel axis.

    :param z: complex tensor of shape [Batch, Channels, Freq, TimeFrames]
    :return: real tensor of shape [Batch, Channels * 2, Freq, TimeFrames]; output
        channel ``2 * c`` holds the real part of input channel ``c`` and channel
        ``2 * c + 1`` holds its imaginary part.
    """
    batch, channels, freqs, frames = z.shape
    # view_as_real appends a trailing (real, imag) axis; bring it right after the channels.
    parts = torch.view_as_real(z).movedim(-1, 2)
    return parts.reshape(batch, 2 * channels, freqs, frames)


def convert_to_complex(x):
    """Rebuild a complex tensor from separately stored real and imaginary parts.

    :param x: real tensor of shape [Batch, Channels, 2, Freq, TimeFrames], where index 0
        of the third axis is the real part and index 1 the imaginary part
    :return: complex tensor of shape [Batch, Channels, Freq, TimeFrames]
    """
    # view_as_complex needs the (real, imag) pair as a contiguous trailing axis.
    pairs_last = x.movedim(2, -1).contiguous()
    return torch.view_as_complex(pairs_last)


# Round trip: waveform -> spec() -> real/imag channels -> complex -> ispec() -> waveform.
file="/mnt/cephfs/hjh/train_record/super_resolution/aero/dataset/vctk/rs_wav16k/p314/p314_412.wav"
x, sr = torchaudio.load(file)  # audio sr = 16000
print("原始音频:")
ipd.display(ipd.Audio(x, rate=sr))
x = x.unsqueeze(0)  # add a batch dim in front of the loaded tensor

# Remember the sample count so the padding added by spec() can be trimmed at the end.
length = x.shape[-1]

#----------------
# Extract the pseudo-complex representation (real/imag stacked in the channel axis)
#----------------
z = spec(x)
x = move_complex_to_channels_dim(z)
print("伪复数维度:",x.size())

#----------------
# Invert: back to complex, then back to a waveform
#----------------
# Split the channel axis into (channels, real/imag), the layout convert_to_complex expects.
# For a mono file this equals the former x.unsqueeze(1); it also handles more channels.
batch_size, _, freq_bins, frames = x.shape
x = x.reshape(batch_size, -1, 2, freq_bins, frames)
print("经过encoder-decoder后:",x.size())
x_spec_complex = convert_to_complex(x)
x = ispec(x_spec_complex)
x = x[..., :length].squeeze()

ipd.display(ipd.Audio(x, rate=sr))


原始音频:


tensor([[[[ 5.0072e-02+0.0000e+00j,  1.3812e-02+0.0000e+00j,
            1.8839e-03+0.0000e+00j,  ...,
            6.5798e-03+0.0000e+00j, -3.7140e-02+0.0000e+00j,
           -8.4395e-04+0.0000e+00j],
          [-2.2688e-02-2.6537e-09j, -2.4543e-03-1.5308e-02j,
           -2.4720e-03+6.8833e-03j,  ...,
           -7.8291e-03-2.3686e-02j,  2.4458e-02+1.3579e-02j,
           -7.3232e-04-2.8316e-05j],
          [-1.0894e-02+6.5002e-10j, -6.8549e-03+1.8804e-03j,
            2.0978e-03-6.1653e-03j,  ...,
            7.6426e-03+7.9455e-03j, -7.4757e-03-6.7603e-03j,
           -4.3279e-04-4.4258e-05j],
          ...,
          [ 8.7262e-05-1.3069e-10j, -6.1203e-06+4.8922e-06j,
            4.4965e-06-5.7286e-06j,  ...,
           -4.3473e-06+3.1215e-07j, -4.9355e-05+8.6515e-05j,
            1.2060e-05+2.9224e-07j],
          [-8.2787e-05+1.8666e-10j, -2.8908e-06-6.9438e-06j,
           -1.5537e-06-4.9435e-07j,  ...,
            1.3389e-06+1.6072e-06j, -1.1492e-05-9.6468e-05j,
           -4.609