In [1]:
import os

import torch


from typing import Optional

from model.audio.sive.sive import SpeakerInvariantVoiceEncoder


def _load_model(
    model_cls,
    config_name: str,
    checkpoint_path: Optional[str] = None,
    device: str = "cuda",
    overrides: dict = {},
    strict: bool = False
):
    """
    Load a model from a checkpoint.

    Args:
        checkpoint_path: Path to checkpoint directory (containing model.safetensors or pytorch_model.bin)
        config_name: Config name from the model class (e.g., "small", "medium")
        device: Device to load the model on

    Returns:
        model in eval mode
    """
    # Create model with same config
    model = model_cls.from_config(config_name, **overrides)
    model = model.to(device)

    if checkpoint_path is None:
        return model

    # Try to load from safetensors first, then pytorch_model.bin
    safetensors_path = os.path.join(checkpoint_path, "model.safetensors")
    pytorch_path = os.path.join(checkpoint_path, "pytorch_model.bin")
    if os.path.exists(safetensors_path):
        from safetensors.torch import load_file
        state_dict = load_file(safetensors_path)
        model.load_state_dict(state_dict, strict=strict)
        print(f"Loaded model from {safetensors_path}")
    elif os.path.exists(pytorch_path):
        print(f"loading pytorch bin")
        state_dict = torch.load(pytorch_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict, strict=strict)
        print(f"Loaded model from {pytorch_path}")
    else:
        raise FileNotFoundError(
            f"No model checkpoint found at {checkpoint_path}. "
            f"Expected model.safetensors or pytorch_model.bin"
        )

    return model


def load_model(checkpoint_path, config, device: str = "cuda", print_structure: bool = True):
    """
    Load a SpeakerInvariantVoiceEncoder checkpoint, switch it to eval mode and print a summary.

    Args:
        checkpoint_path: Checkpoint directory holding ``model.safetensors`` or
            ``pytorch_model.bin``. If None, nothing is loaded.
        config: Config name forwarded to ``SpeakerInvariantVoiceEncoder.from_config``.
        device: Device to place the model on (default "cuda", as before).
        print_structure: If True (the default), print the full module repr. Pass False
            to keep notebook output short.

    Returns:
        The model in eval mode. Returns None if ``checkpoint_path`` is None or the
        directory does not exist, so callers should check for None.

    Raises:
        Re-raises any exception from ``_load_model`` (e.g. FileNotFoundError when the
        directory has no checkpoint file) after printing it.
    """
    if checkpoint_path is None:
        return None

    if not os.path.exists(checkpoint_path):
        print(f"Model checkpoint not found at {checkpoint_path}")
        return None

    try:
        model = _load_model(
            SpeakerInvariantVoiceEncoder,
            config,
            checkpoint_path=checkpoint_path,
            device=device,
            strict=False,
        )
        model.eval()
    except Exception as e:
        print(f"Failed to load model : {e}")
        raise  # bare raise re-raises the active exception unchanged

    # `_load_model` already logs which checkpoint file was loaded, so no second "Loaded" line here.
    print(f"model  parameters: {sum(p.numel() for p in model.parameters()):,}")
    if print_structure:
        print(f"model  structure: {model}")
    return model
    
# Checkpoint to inspect; the path is relative to this notebook's working directory.
CHECKPOINT_DIR = "../../megatransformer/runs/gubert/tiny_deep_0_4/checkpoint-60000/"
CONFIG_NAME = "tiny_deep"

model = load_model(CHECKPOINT_DIR, CONFIG_NAME)
# load_model only prints and returns None when the directory is missing;
# fail here instead of with an AttributeError in a later cell.
assert model is not None, f"could not load checkpoint from {CHECKPOINT_DIR}"

loading pytorch bin
Loaded model from ../../megatransformer/runs/gubert/tiny_deep_0_4/checkpoint-60000/pytorch_model.bin
Loaded model from ../../megatransformer/runs/gubert/tiny_deep_0_4/checkpoint-60000/
model  parameters: 7,037,822
model  structure: SpeakerInvariantVoiceEncoder(
  (conv_subsample): ConvSubsampling(
    (conv): Sequential(
      (0): Conv1d(80, 128, kernel_size=(7,), stride=(2,), padding=(3,))
      (1): GroupNorm(32, 128, eps=1e-05, affine=True)
      (2): GELU(approximate='none')
      (3): Dropout1d(p=0.1, inplace=False)
      (4): Conv1d(128, 128, kernel_size=(3,), stride=(2,), padding=(1,))
      (5): GroupNorm(32, 128, eps=1e-05, affine=True)
      (6): GELU(approximate='none')
      (7): Dropout1d(p=0.1, inplace=False)
      (8): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      (9): GroupNorm(32, 128, eps=1e-05, affine=True)
      (10): GELU(approximate='none')
      (11): Dropout1d(p=0.1, inplace=False)
    )
  )
  (encoder_blocks): ModuleLi

In [4]:
# Inspect one feed-forward weight matrix in a late encoder block.
# NOTE(review): assumes `ff1` is indexable (e.g. nn.Sequential) and that entry
# FF1_LAYER_IDX has a `.weight` (a Linear, judging by the 2-D output) — confirm
# against the SpeakerInvariantVoiceEncoder definition.
BLOCK_IDX = -5      # 5th encoder block from the end
FF1_LAYER_IDX = -2  # second-to-last sub-module of that block's ff1
model.encoder_blocks[BLOCK_IDX].ff1[FF1_LAYER_IDX].weight

Parameter containing:
tensor([[-0.0372,  0.1138,  0.1117,  ...,  0.0039,  0.0294,  0.1461],
        [ 0.0781, -0.0722, -0.0191,  ...,  0.1481,  0.0462, -0.1144],
        [-0.0114,  0.0291, -0.0140,  ...,  0.0737, -0.1048,  0.1055],
        ...,
        [-0.1931, -0.0239, -0.1384,  ..., -0.0805, -0.0057, -0.1333],
        [-0.0216,  0.0785,  0.1074,  ...,  0.0863, -0.0153,  0.0059],
        [-0.0015,  0.0775, -0.0013,  ..., -0.0610,  0.0254,  0.0690]],
       device='cuda:0', requires_grad=True)