In [2]:
import logging
import math
import os

logging.basicConfig(level=logging.INFO)
logging.info("library loading")
logging.info("DEBUG")

import torch
import librosa

torch.set_grad_enabled(True)

import cached_conv as cc
import gin
import nn_tilde
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, Optional
from absl import flags, app

import sys, os
try:
    import raveish
except:
    import sys, os 
    sys.path.append(os.path.abspath('.'))
    import raveish

import raveish.core
import raveish.dataset
from raveish.transforms import get_augmentations, add_augmentation
import raveish.blocks
import raveish.resampler
import IPython.display as ipd
import pickle

from raveish.cached_a_weight import Cached_A_Weight as cached_a_weight
from raveish.pitch_enc import PitchEncoderV2

INFO:root:library loading
INFO:root:DEBUG


Using device: mps
  [INFO]: device is not None, use mps
  [INFO]    > call by:torchfcpe.tools.spawn_infer_cf_naive_mel_pe_from_pt
  [WARN] args.model.use_harmonic_emb is None; use default False
  [WARN]    > call by:torchfcpe.tools.spawn_cf_naive_mel_pe


In [3]:
cc.use_cached_conv(True)

run = "pretrained/causal"
ema_weights = False
prior_flag = False
channel_flag = None
sr_flag = 44100
fidelity_flag =.95

logging.info("building UNIFIED_TT")

gin.parse_config_file(os.path.join(run, "config.gin"))
checkpoint = raveish.core.search_for_run(run)
print("loading checkpoint:", checkpoint)

pretrained = raveish.UNIFIED_TT()
if run is not None:
    logging.info('model found : %s'%run)
    checkpoint = torch.load(checkpoint, map_location='cpu')
    if ema_weights and "EMA" in checkpoint["callbacks"]:
        pretrained.load_state_dict(
            checkpoint["callbacks"]["EMA"],
            strict=False,
        )
    else:
        pretrained.load_state_dict(
            checkpoint["state_dict"],
            strict=False,
        )
else:
    logging.error("No checkpoint found")
    exit()
    
pretrained.eval()

pitch_enc = PitchEncoderV2(data_size=6,
                           capacity=16,
                           ratios=[4,4,4,2],
                           latent_size=1440,
                           n_out=1,
                           kernel_size=3,
                           dilations=[[1, 3, 9], [1, 3, 9], [1, 3, 9], [1, 3]])

pitch_enc.load_state_dict(torch.load(f"raveish/utils/caus2048_mb6.pth", weights_only=True))
pretrained.pitch_encoder = pitch_enc

INFO:root:building UNIFIED_TT
  WeightNorm.apply(module, name, dim)
INFO:root:model found : pretrained/causal


loading checkpoint: pretrained/causal/latest.ckpt


In [4]:
print(pretrained.decoder.net[2].cumulative_delay)

1


In [5]:
for m in pretrained.modules():
    if hasattr(m, "weight_g"):
        nn.utils.remove_weight_norm(m)

t = torch.rand(1, 1, 131072)
out, _, _, _ = pretrained(t, ['Violin'])
print(out.shape)

torch.Size([1, 1, 131072])


  return _VF.stft(  # type: ignore[attr-defined]


In [6]:
with open('raveish/utils/loudness_stats.pkl', 'rb') as file:
    loudness_dict = pickle.load(file)

print(loudness_dict)

# Extract the mean and std values
means = [v['mean'] for v in loudness_dict.values()]
stds = [v['std'] for v in loudness_dict.values()]

# Compute the averages
global_mean = (sum(means) / len(means))
global_std = sum(stds) / len(stds)

print("Mean of means:", global_mean)
print("Mean of stds:", global_std)

{'Cello': {'mean': -1.0976669, 'std': 0.99742645, 'mean_f0': 150.07826, 'n_files': 11}, 'Violin': {'mean': -1.1541308, 'std': 0.95938534, 'mean_f0': 506.10043, 'n_files': 34}, 'Trombone': {'mean': -1.228859, 'std': 1.178475, 'mean_f0': 179.20522, 'n_files': 8}, 'Bassoon': {'mean': -1.0357323, 'std': 1.1757089, 'mean_f0': 200.41727, 'n_files': 3}, 'Clarinet': {'mean': -1.1593133, 'std': 1.1082362, 'mean_f0': 356.0085, 'n_files': 10}, 'Oboe': {'mean': -1.2284337, 'std': 0.859931, 'mean_f0': 533.29974, 'n_files': 6}, 'Trumpet': {'mean': -1.4467067, 'std': 1.2905153, 'mean_f0': 384.7168, 'n_files': 22}, 'Saxophone': {'mean': -1.29092, 'std': 1.3220173, 'mean_f0': 343.31354, 'n_files': 11}, 'Horn': {'mean': -1.367948, 'std': 1.3526137, 'mean_f0': 257.88876, 'n_files': 5}, 'Flute': {'mean': -1.1600183, 'std': 1.0587157, 'mean_f0': 631.1199, 'n_files': 18}}
Mean of means: -1.2169728994369506
Mean of stds: 1.1303024888038635


In [7]:
class Scripted(nn_tilde.Module):

    def __init__(self,
                 pretrained: raveish.UNIFIED_TT,
                 channels: Optional[int] = None,
                 fidelity: float = .95,
                 target_sr: bool = None) -> None:

        super().__init__()
        self.pqmf = pretrained.pqmf
        self.sr = pretrained.sr
        self.spectrogram = pretrained.spectrogram
        self.resampler = None
        self.input_mode = pretrained.input_mode
        self.output_mode = pretrained.output_mode
        self.n_channels = pretrained.n_channels
        self.target_channels = channels or self.n_channels
        self.stereo_mode = False

        if target_sr is not None:
            if target_sr != self.sr:
                assert not target_sr % self.sr, "Incompatible target sampling rate"
                self.resampler = raveish.resampler.Resampler(target_sr, self.sr)
                self.sr = target_sr

        self.full_latent_size = pretrained.latent_size

        self.register_attribute("learn_target", False)
        self.register_attribute("reset_target", False)
        self.register_attribute("learn_source", False)
        self.register_attribute("reset_source", False)

        self.register_buffer("latent_pca", pretrained.latent_pca)
        self.register_buffer("latent_mean", pretrained.latent_mean)
        self.register_buffer("fidelity", pretrained.fidelity)

        self.register_buffer("global_mean", torch.tensor(global_mean))
        self.register_buffer("global_std", torch.tensor(global_std))

        self.latent_size = 2

        # have to init cached conv before graphing
        self.decoder = pretrained.decoder
        self.amp_block = pretrained.amp_block
        self.pitch_encoder = pretrained.pitch_encoder
        
        x_len = 2**14
        x = torch.zeros(1, self.n_channels, x_len)
        
        # configure encoder
        if (pretrained.input_mode == "pqmf") or (pretrained.output_mode == "pqmf"):
            # scripting fails if cached conv is not initialized
            self.pqmf(torch.zeros(1, 1, x_len))

        encode_shape = (pretrained.n_channels, 2**14) 

        self.register_method(
            "forward",
            in_channels=1,
            in_ratio=1,
            out_channels=self.target_channels,
            out_ratio=1,
            input_labels=['(signal) Channel %d'%d for d in range(1, 1 + 1)],
            output_labels=['(signal) Channel %d'%d for d in range(1, self.target_channels+1)],
            test_method=False
        )

    def post_process_latent(self, z):
        raise NotImplementedError

    def pre_process_latent(self, z):
        raise NotImplementedError

    def update_adain(self):
        for m in self.modules():
            if isinstance(m, raveish.blocks.AdaptiveInstanceNormalization):
                m.learn_x.zero_()
                m.learn_y.zero_()

                if self.learn_target[0]:
                    m.learn_y.add_(1)
                if self.learn_source[0]:
                    m.learn_x.add_(1)

                if self.reset_target[0]:
                    m.reset_y()
                if self.reset_source[0]:
                    m.reset_x()

        self.reset_source = False,
        self.reset_target = False,

    @torch.jit.export
    def set_stereo_mode(self, stereo):
        self.stereo_mode = bool(stereo)


    @torch.jit.export
    def forward(self, x, emb_x, emb_y, p_mult, n_mult, loudness, loudness_linear):
        batch_size = x.shape[:-2]

        x_m = self.pqmf(x)
        x_m = x_m.reshape(batch_size + (-1, x_m.shape[-1]))
        
        pitch_logits = self.pitch_encoder(x_m[:, :6, :])
        pitch = torch.argmax(pitch_logits, dim=1)
        pitch = raveish.core.bins_to_frequency(pitch).unsqueeze(-1)
        periodicity = raveish.core.entropy(pitch_logits)

        pitch = pitch * p_mult

        emb = torch.zeros(1, 2)
        emb[:, 0] = emb_x
        emb[:, 1] = emb_y

        emb = emb.unsqueeze(-1)
        amplitudes = self.amp_block(emb.transpose(2, 1))
        emb = emb.repeat(1, 1, periodicity.shape[-1])

        y, _, _, = self.decoder(emb,
                                pitch, 
                                amplitudes,
                                loudness.transpose(2,1),
                                loudness_linear.unsqueeze(-1),
                                periodicity.unsqueeze(-1))

        batch_size = emb.shape[:-2]
        if self.pqmf is not None:
            y = y.reshape(y.shape[0] * self.n_channels, -1, y.shape[-1])
            y = self.pqmf.inverse(y)
            y = y.reshape(batch_size+(self.n_channels, -1))

        if self.resampler is not None:
            y = self.resampler.from_model_sampling_rate(y)
                
        return y

    @torch.jit.export
    def get_learn_target(self) -> bool:
        return self.learn_target[0]

    @torch.jit.export
    def set_learn_target(self, learn_target: bool) -> int:
        self.learn_target = (learn_target, )
        return 0

    @torch.jit.export
    def get_learn_source(self) -> bool:
        return self.learn_source[0]

    @torch.jit.export
    def set_learn_source(self, learn_source: bool) -> int:
        self.learn_source = (learn_source, )
        return 0

    @torch.jit.export
    def get_reset_target(self) -> bool:
        return self.reset_target[0]

    @torch.jit.export
    def set_reset_target(self, reset_target: bool) -> int:
        self.reset_target = (reset_target, )
        return 0

    @torch.jit.export
    def get_reset_source(self) -> bool:
        return self.reset_source[0]

    @torch.jit.export
    def set_reset_source(self, reset_source: bool) -> int:
        self.reset_source = (reset_source, )
        return 0

In [8]:
script_class = Scripted

In [9]:
prior_scripted=None

logging.info("script model")
scripted = script_class(
    pretrained=pretrained,
    channels = channel_flag,
    fidelity=fidelity_flag,
    target_sr=sr_flag)

INFO:root:script model
INFO:root:Registering method "forward"


In [10]:
x = torch.zeros(1, pretrained.n_channels, 2**14)
x = torch.zeros(1, 1, 2048)

emb_x = torch.zeros(1)
emb_y = torch.zeros(1)
mult = torch.ones(1)
l = torch.ones(1, 1, 1)
ln = torch.ones(1, 1)

y = scripted.forward(x, emb_x, emb_y, mult, mult, l, ln)
print(y.shape)

logging.info("save model")
output = os.path.dirname(run)
model_name = run.split(os.sep)[-1]

model_name += "_streaming.ts"

output = os.path.abspath(output)
if not os.path.isdir(output):
    os.makedirs(output)
scripted.export_to_ts(os.path.join(output, model_name))
try:
    if pretrained.n_channels <= 2:
        # test stereo mode for VST export
        scripted.set_stereo_mode(True)
        z_vst_input = torch.zeros(2, scripted.full_latent_size, z.shape[-1])
        out = scripted.decode(z_vst_input)
        assert out.shape[1] == 2, "model output is not stereo"
        logging.info(f"this model seems compatible with the RAVE vst.")
except Exception as e:
    logging.warning(f"this model will not work with the RAVE VST. \n Caught error : %s"%e)

logging.info(f"all good ! model exported to {os.path.join(output, model_name)}")

INFO:root:save model


torch.Size([1, 1, 2048])


 Caught error : name 'z' is not defined
INFO:root:all good ! model exported to /Users/andersbargum/Documents/unified-timbre-transfer/pretrained/causal_streaming.ts
