In [1]:
import os
import sys
import torch
from torch.cuda.amp import autocast
import numpy as np

device = 'cuda'

# Seed both RNGs explicitly (this stands in for the seed the YAML config used to set).
torch.manual_seed(1234)
np.random.seed(1234)

# Checkpoints saved as whole pickled modules (not state dicts).
encoder_path = 'save/encoder_full_model.pth'
masknet_path = 'save/masknet_full_model.pth'
decoder_path = 'save/decoder_full_model.pth'

# NOTE(review): torch.load unpickles the full module objects, which can execute
# arbitrary code; only load checkpoints from trusted sources.
encoder, masknet, decoder = (
    torch.load(checkpoint).to(device).eval()
    for checkpoint in (encoder_path, masknet_path, decoder_path)
)

def edit_sound(mix, text_embed):
    """Extract the prompt-conditioned target signal from a mixture.

    Pipeline: encode ``mix``, predict a mask with ``masknet`` conditioned on
    ``text_embed``, apply the mask in the latent space, then decode. The
    encoder/decoder convolutions can change the number of samples, so the
    decoded signal is zero-padded or cropped along dim 1 to ``mix.size(1)``.

    Uses the module-level ``encoder``, ``masknet`` and ``decoder``. No gradient
    context is set here; wrap calls in ``torch.no_grad()`` for inference.

    Args:
        mix: waveform tensor whose dim 1 is treated as time
            (presumably (batch, samples) — TODO confirm).
        text_embed: conditioning embedding forwarded to ``masknet``.

    Returns:
        Decoded target with exactly ``mix.size(1)`` samples along dim 1.
    """
    latent = encoder(mix)
    # squeeze(0) drops a leading singleton dim of the mask output
    # (presumably the source axis when C=1 — TODO confirm).
    mask = masknet(latent, text_embed).squeeze(0)
    est_tar = decoder(latent * mask)

    wanted, produced = mix.size(1), est_tar.size(1)
    if produced < wanted:
        # Too short: append zeros at the end of the time axis.
        return torch.nn.functional.pad(est_tar, (0, wanted - produced))
    # Too long (or exact): crop to the original length.
    return est_tar[:, :wanted]

def dummy_read_prompt(prompt, device='cuda'):
    """Return random placeholder text embeddings, one row per prompt.

    Args:
        prompt: sequence of prompt strings; only its length is used.
        device: device on which the embedding tensor is created.

    Returns:
        Tensor of shape (len(prompt), 4096) drawn from ``torch.rand``.
    """
    n_prompts = len(prompt)
    embed_dim = 4096  # assumed txt_emb_dim of the real text encoder
    return torch.rand(n_prompts, embed_dim, device=device)

# Smoke test: run ten random mixtures through the pipeline and check that the
# output keeps the input length.
with torch.no_grad():
    prompt = ('This is a placeholder.',)  # same placeholder prompt every iteration
    for _ in range(10):
        mix = torch.rand(1, 80000).to(device)
        text_embed = dummy_read_prompt(prompt)
        est_tar = edit_sound(mix, text_embed)
        assert est_tar.shape == (1, 80000), "Output shape mismatch"
        print("Test successful, output shape:", est_tar.shape)


  from .autonotebook import tqdm as notebook_tqdm


Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])
Test successful, output shape: torch.Size([1, 80000])


In [1]:
import torch
from modules.convtasnet_ext_nosb import MaskNet

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build a MaskNet from the `convtasnet_ext_nosb` variant for inspection.
# NOTE(review): argument meanings are defined in modules.convtasnet_ext_nosb — confirm there.
model = MaskNet(
    N=512,
    B=128,
    H=512,
    P=1,
    X=8,
    R=3,
    C=4,
)

In [3]:
# Display the module tree of the freshly built MaskNet (rendered below).
model

MaskNet(
  (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (bottleneck_conv1x1): Conv1d(512, 128, kernel_size=(1,), stride=(1,))
  (temporal_conv_net): FilmTemporalBlocksSequential(
    (blocks): ModuleDict(
      (filmtemporalblock_0_0): FilmTemporalBlock(
        (layers): ModuleDict(
          (conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
          (act): PReLU(num_parameters=1)
          (norm): GlobalLayerNorm()
          (DSconv): DepthwiseSeparableConv(
            (depthwise): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,), groups=512, bias=False)
            (pointwise): Conv1d(512, 512, kernel_size=(1,), stride=(1,), bias=False)
            (activation): PReLU(num_parameters=1)
            (norm): GlobalLayerNorm()
          )
        )
      )
      (filmtemporalblock_0_1): FilmTemporalBlock(
        (layers): ModuleDict(
          (conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
          (act): PR

In [1]:
import torch
from modules.convtasnet_ext_nosb2 import MaskNet

In [2]:
# Build the FiLM-conditioned MaskNet from `convtasnet_ext_nosb2`
# (conditioning size cond_dim=4096); the constructor logs where FiLM layers go.
model = MaskNet(
    N=512, B=128, H=512, P=3, X=8, R=3, C=1,
    norm_type='gLN',
    causal=False,
    mask_nonlinear="relu",
    cond_dim=4096,
    film_mode='block',
    film_n_layer=2,
    film_scale=True,
    film_where='before1x1',
)

Use FiLM at (every) block.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.


In [3]:
# Display the module tree of the FiLM-conditioned MaskNet (rendered below).
model

MaskNet(
  (layer_norm): ChannelwiseLayerNorm()
  (bottleneck_conv1x1): Conv1d(
    (conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
  )
  (temporal_conv_net): FilmTemporalBlocksSequential(
    (filmtemporalblock_0_0): FilmTemporalBlock(
      (layers): Sequential(
        (conv): Conv1d(
          (conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
        )
        (act): PReLU(num_parameters=1)
        (norm): GlobalLayerNorm()
        (DSconv): DepthwiseSeparableConv(
          (conv_0): Conv1d(
            (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), groups=512, bias=False)
          )
          (act): PReLU(num_parameters=1)
          (act_0): GlobalLayerNorm()
          (conv_1): Conv1d(
            (conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
          )
        )
      )
      (film): FiLM(
        (scaler): Sequential(
          (0): Linear(in_features=4096, out_features=128, bias=True)
          (1): ReLU(i

In [6]:
weights_path = 'save/masknet_model_weights.pth'

# Map the checkpoint straight onto the device the model's parameters already
# live on, instead of hard-coding 'cuda': load_state_dict copies the tensors into
# the model anyway, so staging them on another device only costs memory and
# makes this cell require a GPU.
state_dict = torch.load(weights_path, map_location=next(model.parameters()).device)

# Copy the weights into the model (strict by default: missing or unexpected
# keys raise an error).
model.load_state_dict(state_dict)

# Switch to evaluation mode for inference; the returned module is displayed below.
model.eval()

MaskNet(
  (layer_norm): ChannelwiseLayerNorm()
  (bottleneck_conv1x1): Conv1d(
    (conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
  )
  (temporal_conv_net): FilmTemporalBlocksSequential(
    (filmtemporalblock_0_0): FilmTemporalBlock(
      (layers): Sequential(
        (conv): Conv1d(
          (conv): Conv1d(128, 512, kernel_size=(1,), stride=(1,), bias=False)
        )
        (act): PReLU(num_parameters=1)
        (norm): GlobalLayerNorm()
        (DSconv): DepthwiseSeparableConv(
          (conv_0): Conv1d(
            (conv): Conv1d(512, 512, kernel_size=(3,), stride=(1,), groups=512, bias=False)
          )
          (act): PReLU(num_parameters=1)
          (act_0): GlobalLayerNorm()
          (conv_1): Conv1d(
            (conv): Conv1d(512, 128, kernel_size=(1,), stride=(1,), bias=False)
          )
        )
      )
      (film): FiLM(
        (scaler): Sequential(
          (0): Linear(in_features=4096, out_features=128, bias=True)
          (1): ReLU(i

In [4]:
import torch
from modules.convtasnet_ext_nosb2 import MaskNet, Encoder, Decoder

# Build the encoder / MaskNet / decoder, move them to the GPU, restore the
# trained weights, and switch them to inference mode.
device = 'cuda'

encoder = Encoder(kernel_size=16, out_channels=512)
masknet = MaskNet(N=512, B=128, H=512, P=3, X=8, R=3, C=1, norm_type='gLN',
                  causal=False, mask_nonlinear="relu", cond_dim=4096,
                  film_mode='block', film_n_layer=2, film_scale=True,
                  film_where='before1x1')
# NOTE(review): the decoder uses stride=8 while the encoder relies on its default
# stride; confirm the two match so analysis and synthesis frames line up.
decoder = Decoder(in_channels=512,
                  out_channels=1,
                  kernel_size=16,
                  stride=8,
                  bias=False)

encoder = encoder.to(device)
masknet = masknet.to(device)
decoder = decoder.to(device)

# These checkpoints are state dicts; map their tensors directly onto `device`.
encoder.load_state_dict(torch.load('save/encoder_model_weights.pth', map_location=device))
masknet.load_state_dict(torch.load('save/masknet_model_weights.pth', map_location=device))
decoder.load_state_dict(torch.load('save/decoder_model_weights.pth', map_location=device))

# Inference only: put every module in eval mode so any training-specific layers
# (dropout, batch-norm, if present) behave deterministically.
encoder.eval()
masknet.eval()
decoder.eval()

# Stand-in text encoder: produces random embeddings instead of encoding the prompt.
def dummy_read_prompt(prompt, device='cuda'):
    """Return a (len(prompt), 4096) tensor of uniform random values on ``device``.

    Only ``len(prompt)`` is used; the prompt text itself is ignored.
    4096 is the assumed 'txt_emb_dim' of the real text encoder.
    """
    shape = (len(prompt), 4096)
    return torch.rand(shape, device=device)

def edit_sound(mix, text_embed):
    """Extract the prompt-conditioned target signal from a mixture waveform.

    Encodes ``mix``, predicts a mask with ``masknet`` conditioned on
    ``text_embed``, applies the mask in the latent space and decodes back to a
    waveform. Because the encoder/decoder convolutions can change the number of
    samples, the result is zero-padded or cropped along dim 1 to ``mix.size(1)``.

    Uses the module-level ``encoder``, ``masknet``, ``decoder`` and ``device``,
    and runs entirely under ``torch.no_grad()``.

    Args:
        mix: waveform tensor whose dim 1 is treated as time
            (presumably (batch, samples) — TODO confirm).
        text_embed: conditioning embedding forwarded to ``masknet``.

    Returns:
        Decoded target with exactly ``mix.size(1)`` samples along dim 1.
    """
    with torch.no_grad():
        # Move BOTH inputs to the models' device. Previously only `mix` was moved,
        # so a text embedding created elsewhere (e.g. on CPU) would likely fail
        # with a device mismatch inside masknet.
        mix = mix.to(device)
        text_embed = text_embed.to(device)

        # Encoding
        mix_h = encoder(mix)

        # Extraction: squeeze(0) drops a leading singleton dim of the mask
        # (presumably the source axis when C=1 — TODO confirm).
        est_mask = masknet(mix_h, text_embed).squeeze(0)
        est_tar_h = mix_h * est_mask

        # Decoding
        est_tar = decoder(est_tar_h)

        # Restore the original length: pad with zeros if too short, crop otherwise.
        T_origin = mix.size(1)
        T_ext = est_tar.size(1)

        if T_origin > T_ext:
            est_tar = torch.nn.functional.pad(est_tar, (0, T_origin - T_ext))
        else:
            est_tar = est_tar[:, :T_origin]

        return est_tar

# Smoke test: push one random mixture through the pipeline and check that the
# output length matches the input length.
if __name__ == "__main__":
    mix = torch.rand(1, 80000).to(device)  # Simulating a single audio sample
    prompt = ('This is a placeholder.',)
    # Build the text embedding here so this cell does not depend on a
    # `text_embed` left over from another cell or kernel session (it raised
    # NameError on a fresh kernel while this line was commented out).
    text_embed = dummy_read_prompt(prompt, device=device)

    est_tar = edit_sound(mix, text_embed)
    print(f"Output shape: {est_tar.shape}")  # Should print: Output shape: torch.Size([1, 80000])
    assert est_tar.shape == (1, 80000), "Output shape mismatch"

Use FiLM at (every) block.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Output shape: torch.Size([1, 80000])
