In [10]:
import os

import torch


from typing import Optional

from model.audio.vocoder.vocoder import Vocoder
from utils.audio_utils import SharedWindowBuffer


def load_model(
    model_cls,
    config_name: str,
    checkpoint_path: Optional[str] = None,
    device: str = "cuda",
    overrides: Optional[dict] = None,
    strict: bool = False
):
    """
    Build a model from a named config and optionally load checkpoint weights into it.

    Args:
        model_cls: Class exposing a ``from_config(config_name, **overrides)`` classmethod.
        config_name: Config name passed to ``model_cls.from_config`` (e.g., "small", "medium").
        checkpoint_path: Checkpoint directory containing ``model.safetensors`` or
            ``pytorch_model.bin``. If None, the freshly constructed model is returned.
        device: Device the model is moved to (also used as ``map_location`` for .bin files).
        overrides: Extra keyword arguments forwarded to ``from_config``. Defaults to none.
        strict: Forwarded to ``load_state_dict``. With ``strict=False`` key mismatches are
            not an error, so any missing/unexpected keys are printed as a warning instead.

    Returns:
        The model on ``device``. It is NOT switched to eval mode; call ``.eval()`` for
        inference.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` holds neither checkpoint file.
    """
    # `overrides or {}` avoids a shared mutable default argument.
    model = model_cls.from_config(config_name, **(overrides or {}))
    model = model.to(device)

    if checkpoint_path is None:
        return model

    # Try to load from safetensors first, then pytorch_model.bin
    safetensors_path = os.path.join(checkpoint_path, "model.safetensors")
    pytorch_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if os.path.exists(safetensors_path):
        from safetensors.torch import load_file
        state_dict = load_file(safetensors_path)
        loaded_from = safetensors_path
    elif os.path.exists(pytorch_path):
        print("loading pytorch bin")
        state_dict = torch.load(pytorch_path, map_location=device, weights_only=True)
        loaded_from = pytorch_path
    else:
        raise FileNotFoundError(
            f"No model checkpoint found at {checkpoint_path}. "
            f"Expected model.safetensors or pytorch_model.bin"
        )

    incompatible = model.load_state_dict(state_dict, strict=strict)
    # With strict=False a key-prefix mismatch (e.g. a wrapper module vs. a bare model)
    # loads nothing and leaves random weights; make that visible instead of silent.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            f"Warning: {len(incompatible.missing_keys)} missing / "
            f"{len(incompatible.unexpected_keys)} unexpected keys when loading {loaded_from}"
        )
    print(f"Loaded model from {loaded_from}")

    return model


def load_vocoder(shared_window_buffer, vocoder_checkpoint_path, vocoder_config, is_wrapped: bool = False, device: str = "cuda"):
    """
    Load a Vocoder from a checkpoint directory and return it in eval mode.

    Loading happens eagerly when this function is called.

    Args:
        shared_window_buffer: SharedWindowBuffer forwarded to ``Vocoder.from_config``.
        vocoder_checkpoint_path: Checkpoint directory understood by ``load_model``. If None,
            or if the path does not exist, nothing is loaded and None is returned.
        vocoder_config: Config name passed to ``Vocoder.from_config`` (e.g. "tiny").
        is_wrapped: True when the checkpoint's state-dict keys carry a ``vocoder.`` prefix.
            The local wrapper below stores the model as ``self.vocoder`` to match them.
        device: Device to load the vocoder on (default "cuda", as before).

    Returns:
        The Vocoder in eval mode, or None when no checkpoint is available.

    Raises:
        Any exception from building or loading the model, re-raised after being printed.
    """
    if vocoder_checkpoint_path is None:
        return None

    if not os.path.exists(vocoder_checkpoint_path):
        print(f"Vocoder checkpoint not found at {vocoder_checkpoint_path}")
        return None

    overrides = {"shared_window_buffer": shared_window_buffer}
    try:
        if is_wrapped:
            class VocoderWrapper(torch.nn.Module):
                # Holds the Vocoder under the attribute name `vocoder`, so this module's
                # state-dict keys are `vocoder.*`; presumably matches how the checkpoint
                # was saved during training — confirm against the training script.
                def __init__(self):
                    super().__init__()
                    self.vocoder: Optional[Vocoder] = None

                @classmethod
                def from_config(cls, config_name: str, shared_window_buffer: Optional[SharedWindowBuffer], **overrides) -> "VocoderWrapper":
                    wrapper = cls()
                    wrapper.vocoder = Vocoder.from_config(config_name, shared_window_buffer=shared_window_buffer, **overrides)
                    return wrapper

            vocoder = load_model(VocoderWrapper, vocoder_config, checkpoint_path=vocoder_checkpoint_path, device=device, overrides=overrides, strict=False).vocoder
        else:
            vocoder = load_model(Vocoder, vocoder_config, checkpoint_path=vocoder_checkpoint_path, device=device, overrides=overrides, strict=False)
    except Exception as e:
        print(f"Failed to load vocoder: {e}")
        raise  # bare raise keeps the original traceback intact

    vocoder.eval()

    print(f"Loaded vocoder from {vocoder_checkpoint_path}")
    print(f"Vocoder parameters: {sum(p.numel() for p in vocoder.parameters()):,}")
    print(f"Vocoder structure: {vocoder}")
    return vocoder
    
# Checkpoint under inspection; the relative path resolves against the kernel's working directory.
VOCODER_CHECKPOINT_PATH = "../../megatransformer/runs/vocoder/test_0_0/checkpoint-78000/"
VOCODER_CONFIG = "tiny"

shared_window_buffer = SharedWindowBuffer()
vocoder = load_vocoder(shared_window_buffer, VOCODER_CHECKPOINT_PATH, VOCODER_CONFIG, is_wrapped=True)
# load_vocoder returns None (after printing) when the checkpoint is missing; fail here
# rather than with a confusing AttributeError in a later cell.
assert vocoder is not None, f"No vocoder loaded from {VOCODER_CHECKPOINT_PATH}"

loading pytorch bin
Loaded model from ../../megatransformer/runs/vocoder/test_0_0/checkpoint-78000/pytorch_model.bin
Loaded vocoder from ../../megatransformer/runs/vocoder/test_0_0/checkpoint-78000/
Vocoder parameters: 2,283,266
Vocoder structure: Vocoder(
  (input_proj): Conv1d(80, 128, kernel_size=(7,), stride=(1,), padding=(3,))
  (backbone): ModuleList(
    (0-1): 2 x ConvNeXtBlock(
      (dwconv): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,), groups=128)
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (pwconv1): Linear(in_features=128, out_features=1024, bias=True)
      (act): GELU(approximate='none')
      (pwconv2): Linear(in_features=1024, out_features=128, bias=True)
    )
    (2): FrequencyAttentionBlock(
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (qkv): Linear(in_features=128, out_features=384, bias=True)
      (attn_dropout): Dropout(p=0.0, inplace=False)
      (proj): Linear(in_features=128, out_feat

In [11]:
# Display the loaded weight Parameter of the vocoder's `mag_head_low` submodule.
# torch abbreviates large tensors with "..."; `.shape` or `.abs().max()` give a compact summary.
vocoder.mag_head_low.weight

Parameter containing:
tensor([[[-1.8921e-02, -2.0605e-01,  4.4141e-01,  ..., -4.5312e-01,
           6.6406e-01, -2.3828e-01],
         [-2.6562e-01, -9.0942e-03, -1.8359e-01,  ..., -2.9883e-01,
          -5.7373e-02,  4.4189e-02],
         [-9.5215e-02, -1.1475e-02, -1.9238e-01,  ..., -1.2695e-01,
          -1.2451e-01, -1.4551e-01],
         ...,
         [ 3.4961e-01,  1.1670e-01, -1.5234e-01,  ..., -5.5664e-02,
           1.4941e-01, -1.2500e-01],
         [ 3.1836e-01,  6.1646e-03,  3.8605e-03,  ...,  9.9609e-02,
          -5.2979e-02,  3.6377e-02],
         [ 2.4902e-01,  2.1582e-01,  1.3672e-01,  ...,  1.2891e-01,
           2.2168e-01,  1.1426e-01]],

        [[ 5.6641e-01,  5.1953e-01,  7.3047e-01,  ..., -1.5234e-01,
          -4.2773e-01, -1.7090e-01],
         [ 1.6992e-01,  9.3384e-03,  1.0059e-01,  ..., -2.2754e-01,
          -1.6797e-01,  3.2959e-02],
         [-3.9258e-01, -4.6484e-01, -4.2578e-01,  ..., -4.0820e-01,
          -2.6953e-01, -3.3594e-01],
         ...,
   