In [1]:
import torch
import time
from modules.convtasnet_ext_nosb2 import MaskNet, Encoder, Decoder
from llm_cloud import write_prompt_to_file, send_file_to_server, fetch_embeddings_file, read_embeddings_from_pickle
import scipy.io.wavfile as wav
import numpy as np


# Initialize the components.
# NOTE(review): these hyper-parameters presumably have to match the ones used
# when the checkpoints under save/ were trained — confirm before changing any.
encoder = Encoder(kernel_size=16, out_channels=512)
masknet = MaskNet(N=512, B=128, H=512, P=3, X=8, R=3, C=1, norm_type='gLN',
                  causal=False, mask_nonlinear="relu", cond_dim=4096,
                  film_mode='block', film_n_layer=2, film_scale=True,
                  film_where='before1x1')
decoder = Decoder(in_channels=512,
                  out_channels=1,
                  kernel_size=16,
                  stride=8,
                  bias=False)

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA. `device` stays a plain string
# because later cells pass it to `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The notebook only performs inference: move every module to the target device
# and switch it to evaluation mode so any train-only layers behave deterministically.
encoder = encoder.to(device).eval()
masknet = masknet.to(device).eval()
decoder = decoder.to(device).eval()

# Load the trained weights. `map_location` remaps the stored tensors onto
# `device`, so checkpoints saved on a GPU can still be loaded on a CPU.
# NOTE: torch.load unpickles the file — only load checkpoints you trust
# (recent PyTorch versions offer `weights_only=True` to restrict this).
encoder.load_state_dict(torch.load('save/encoder_model_weights.pth', map_location=device))
masknet.load_state_dict(torch.load('save/masknet_model_weights.pth', map_location=device))
decoder.load_state_dict(torch.load('save/decoder_model_weights.pth', map_location=device))

Use FiLM at (every) block.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.


<All keys matched successfully>

In [2]:
def edit_sound(mix, text_embed):
    """Edit a mixture waveform according to a text embedding.

    The mixture is encoded, a mask is predicted by ``masknet`` conditioned on
    ``text_embed``, the masked representation is decoded, and the result is
    padded or cropped so its length along dimension 1 equals that of ``mix``.

    Args:
        mix: Input waveform tensor. Dimension 1 is treated as time
            (``mix.size(1)``), i.e. presumably shaped (batch, time) —
            TODO confirm against the encoder.
        text_embed: Conditioning embedding handed to ``masknet``. Tensors are
            moved to ``device``; any other type is passed through unchanged.

    Returns:
        The edited waveform tensor on ``device``, with the same size as
        ``mix`` along dimension 1.
    """
    with torch.no_grad():
        # Ensure both inputs live on the same device as the models; an
        # embedding left on another device would otherwise make the forward
        # pass fail with a device-mismatch error.
        mix = mix.to(device)
        if torch.is_tensor(text_embed):
            text_embed = text_embed.to(device)

        # Encoding speech
        mix_h = encoder(mix)

        # Extraction: drop the leading dimension of the predicted mask (only
        # if it has size 1) and apply the mask to the encoded mixture.
        est_mask = masknet(mix_h, text_embed).squeeze(0)
        est_tar_h = mix_h * est_mask  # (B, F, T)

        # Decoding
        est_tar = decoder(est_tar_h)

        # T changed after conv1d in encoder, fix it here: zero-pad the end of
        # the last dimension when the output is too short, crop it otherwise.
        T_origin = mix.size(1)
        T_ext = est_tar.size(1)

        if T_origin > T_ext:
            est_tar = torch.nn.functional.pad(est_tar, (0, T_origin - T_ext))
        else:
            est_tar = est_tar[:, :T_origin]

        return est_tar

# Testing the function
def float_to_int16_pcm(samples):
    """Convert floating-point samples to 16-bit PCM.

    Samples are scaled by 32768 and clipped to the int16 range before the
    cast, so values at or beyond +/-1.0 saturate instead of overflowing.

    Args:
        samples: Array-like of floating-point samples.

    Returns:
        A NumPy ``int16`` array with the same shape as ``samples``.
    """
    scaled = np.asarray(samples) * 32768
    return np.clip(scaled, -32768, 32767).astype('int16')


if __name__ == "__main__":
    # Simulating a single audio sample
    sample_rate, mix = wav.read('mix.wav')
    # Normalize the waveform. Dividing by 32768 assumes mix.wav holds 16-bit
    # PCM samples — TODO confirm / handle other sample formats.
    mix = torch.from_numpy(mix.astype('float32')).unsqueeze(0) / 32768.0
    mix = mix.to(device)

    # Prompt handling
    prompt = "Remove all people talking."
    local_input_path = write_prompt_to_file(prompt)
    server = "axon.rc.zi.columbia.edu"
    remote_input_path = "/home/ss6928/LCE_inference/input_file.txt"
    send_file_to_server(local_input_path, remote_input_path, server)
    print("Prompt has been sent to the server.")

    # NOTE(review): fixed 15 s wait. If the server needs longer, the fetch
    # below presumably fails or picks up an embedding left by an earlier
    # prompt — consider polling for a fresh file instead.
    print("Waiting for the server to process the prompt...")
    time.sleep(15)

    remote_output_path = "/home/ss6928/LCE_inference/embedding.pkl"
    local_output_path = "embedding.pkl"
    fetch_embeddings_file(server, remote_output_path, local_output_path)
    # NOTE(review): the helper presumably unpickles the fetched file;
    # unpickling can execute arbitrary code, so only use a trusted server.
    embeddings = read_embeddings_from_pickle(local_output_path)

    # Sound editing using the embeddings. The explicit None/length check also
    # works when the embeddings come back as an array or tensor, whose
    # truthiness is ambiguous for more than one element.
    if embeddings is not None and len(embeddings) > 0:
        est_tar = edit_sound(mix, embeddings[0])
        # Rescale to the int16 range, saturating out-of-range samples
        est_tar = float_to_int16_pcm(est_tar.squeeze().cpu().numpy())
        wav.write('edited_mix.wav', sample_rate, est_tar)
        print("Edited audio saved as 'edited_mix.wav'.")
    else:
        print("No embeddings found.")

Prompt has been sent to the server.
Waiting for the server to process the prompt...
Edited audio saved as 'edited_mix.wav'.
