In [1]:
import os

import torch
import wespeaker


def load_model_from_path(model, checkpoint_path, map_location="cpu"):
    """
    Load model weights from a specific checkpoint path.

    Args:
        model: The model to load weights into (updated in place via
            ``load_state_dict``).
        checkpoint_path: Path to the checkpoint file.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU machine can still be read on a
            CPU-only host. ``load_state_dict`` copies the values into the
            model's existing parameters, so the model stays on whatever
            device it already lives on.

    Returns:
        dict: Checkpoint metadata with the keys ``'epoch'`` (default 0),
        ``'min_eer'`` (default ``inf``) and ``'model_state'`` (default ``{}``),
        each read from the checkpoint when present.

    Raises:
        FileNotFoundError: If nothing exists at ``checkpoint_path``.
        KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
    """
    # Guard clause: fail fast so the happy path below stays unindented.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"No checkpoint found at {checkpoint_path}")

    print(f"Loading model from: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state_dict'])

    # Optional metadata: fall back to neutral defaults when a key is absent.
    checkpoint_info = {
        'epoch': checkpoint.get('epoch', 0),
        'min_eer': checkpoint.get('min_eer', float('inf')),
        'model_state': checkpoint.get('model_state', {})
    }

    print(f"Model loaded successfully from epoch {checkpoint_info['epoch']}")
    return checkpoint_info
    

# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the local wespeaker model directory and move its `.model` attribute
# onto the chosen device.
ecapa_tdnn = wespeaker.load_model_local("models/voxceleb_ECAPA1024")
model = ecapa_tdnn.model.to(device)

# Overwrite the weights with the latest checkpoint; `model_loaded` holds the
# metadata dict returned by load_model_from_path, not the model itself.
model_loaded = load_model_from_path(
    model,
    'checkpoints/ecapa_tdnn_musan/latest_checkpoint.pth',
)

  from .autonotebook import tqdm as notebook_tqdm


Loading model from: checkpoints/ecapa_tdnn_musan/latest_checkpoint.pth
Model loaded successfully from epoch 50


In [2]:
# Bare expression: the notebook renders the model's repr (its nested module tree).
model

ECAPA_TDNN(
  (layer1): Conv1dReluBn(
    (conv): Conv1d(80, 1024, kernel_size=(5,), stride=(1,), padding=(2,))
    (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (layer2): SE_Res2Block(
    (se_res2block): Sequential(
      (0): Conv1dReluBn(
        (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))
        (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): Res2Conv1dReluBn(
        (convs): ModuleList(
          (0-6): 7 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
        )
        (bns): ModuleList(
          (0-6): 7 x BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): Conv1dReluBn(
        (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))
        (bn): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (3): SE_Connect(
        

In [3]:
# Freeze the lower blocks: none of their parameters receive gradients.
for frozen_name in ("layer1", "layer2", "layer3"):
    for weight in getattr(model, frozen_name).parameters():
        weight.requires_grad = False

# layer4 is only partially trainable: parameters whose name contains
# "se_res2block.3" (the SE_Connect module) stay trainable, the rest are frozen.
for param_name, weight in model.layer4.named_parameters():
    weight.requires_grad = "se_res2block.3" in param_name

# Upper layers: trainable by default, but set explicitly for clarity.
for trainable_name in ("conv", "pool", "bn", "linear"):
    for weight in getattr(model, trainable_name).parameters():
        weight.requires_grad = True

# Report how many parameter elements are trainable versus the model total.
total_params = 0
trainable_params = 0
for weight in model.parameters():
    count = weight.numel()
    total_params += count
    if weight.requires_grad:
        trainable_params += count

print(
    f"Trainable parameters: {trainable_params:,} / {total_params:,} ({trainable_params/total_params:.2%})"
)

Trainable parameters: 6,367,680 / 14,657,088 (43.44%)
