In [4]:
import torch
import time
from modules.convtasnet_ext_nosb2 import MaskNet, Encoder, Decoder
from llm_cloud import write_prompt_to_file, send_file_to_server, fetch_embeddings_file, read_embeddings_from_pickle
import scipy.io.wavfile as wav
import numpy as np
import os
import pandas as pd

# Build the three stages of the separation model: encoder -> mask network -> decoder.
# The encoder's 512 output channels feed MaskNet (N=512) and the decoder (in_channels=512).
encoder = Encoder(kernel_size=16, out_channels=512)
# cond_dim=4096 is the width of the conditioning vector passed as `text_embed`
# (the profiling cells feed a random (1, 4096) tensor).
masknet = MaskNet(N=512, B=128, H=512, P=3, X=8, R=3, C=1, norm_type='gLN',
                  causal=False, mask_nonlinear="relu", cond_dim=4096,
                  film_mode='block', film_n_layer=2, film_scale=True,
                  film_where='before1x1')
decoder = Decoder(in_channels=512,
                  out_channels=1,
                  kernel_size=16,
                  stride=8,
                  bias=False)

# Device shared by every later cell (they read this module-level name).
# The memory-profiling cell calls torch.cuda.* on it, so it must be a CUDA device.
device = 'cuda'

# Move each stage to the device and switch it to inference mode: the notebook
# only profiles forward passes, so train-time behaviour (e.g. dropout, if any
# of these modules use it) must be disabled.
encoder = encoder.to(device).eval()
masknet = masknet.to(device).eval()
decoder = decoder.to(device).eval()

# Restore the trained weights. load_state_dict reports missing/unexpected keys;
# the last call is left as the cell's final expression so that report is displayed.
encoder.load_state_dict(torch.load('save/encoder_model_weights.pth', map_location=device))
masknet.load_state_dict(torch.load('save/masknet_model_weights.pth', map_location=device))
decoder.load_state_dict(torch.load('save/decoder_model_weights.pth', map_location=device))

Use FiLM at (every) block.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.
Initialized a FiLM before1x1.


<All keys matched successfully>

In [3]:
#### time profiling


import torch
import time
import scipy.io.wavfile as wav
import numpy as np
import pandas as pd
from modules.convtasnet_ext_nosb2 import MaskNet, Encoder, Decoder

def dummy_read_prompt(prompt, device='cuda'):
    """Stand-in for real prompt embedding, used only for profiling.

    Args:
        prompt: accepted for interface compatibility; its content is ignored.
        device: device on which the random embedding is created.

    Returns:
        A tensor of shape (1, 4096) filled with uniform random values in [0, 1).
    """
    # The shape is fixed; it does not depend on the prompt.
    return torch.rand((1, 4096), device=device)

def edit_sound(mix, text_embed):
    """Run encoder -> masknet -> decoder on `mix` and time each stage.

    Uses the module-level `encoder`, `masknet` and `decoder`.

    Args:
        mix: input waveform tensor; its last axis is treated as time
            (assumed (batch, time) as built by the caller -- TODO confirm).
        text_embed: conditioning tensor forwarded to `masknet`.

    Returns:
        (est_tar, timings) where `est_tar` is the decoder output padded or
        truncated along the last axis to the length of `mix`, and `timings`
        is a dict with wall-clock seconds under the keys 'encoder',
        'masknet', 'decoder' (mask application + decoding) and 'total'
        (sum of the three).
    """
    def _timestamp():
        # CUDA kernels are launched asynchronously; without a synchronize the
        # host clock would only capture launch overhead, not execution time.
        if mix.is_cuda:
            torch.cuda.synchronize(mix.device)
        return time.perf_counter()

    with torch.no_grad():
        timings = {}

        # Encoding speech
        t_start = _timestamp()
        mix_h = encoder(mix)
        t_encoded = _timestamp()
        timings['encoder'] = t_encoded - t_start

        # MaskNet processing
        est_mask = masknet(mix_h, text_embed).squeeze(0)
        t_masked = _timestamp()
        timings['masknet'] = t_masked - t_encoded

        # Mask application and decoding
        est_tar_h = mix_h * est_mask
        est_tar = decoder(est_tar_h)
        t_decoded = _timestamp()
        timings['decoder'] = t_decoded - t_masked

        # Match the output length to the input along the time (last) axis.
        # Not included in any timing bucket.
        T_origin = mix.size(-1)
        T_ext = est_tar.size(-1)
        if T_origin > T_ext:
            # F.pad with a 2-tuple pads the last dimension (left, right).
            est_tar = torch.nn.functional.pad(est_tar, (0, T_origin - T_ext))
        else:
            est_tar = est_tar[..., :T_origin]

        timings['total'] = timings['decoder'] + timings['masknet'] + timings['encoder']
        return est_tar, timings

# Measurements
N_WARMUP = 5     # untimed passes: absorb one-off CUDA/cuDNN initialisation cost
N_RUNS = 100     # timed passes that enter the statistics
WAV_PATH = 'mix.wav'
prompt = "your prompt goes here"

results = []
for run_idx in range(N_WARMUP + N_RUNS):
    # Loading = read from disk + convert to float + copy to the device.
    start_time = time.perf_counter()
    sample_rate, mix = wav.read(WAV_PATH)
    # Scaling by 32768 assumes 16-bit PCM samples -- TODO confirm for other files.
    mix = torch.from_numpy(mix.astype('float32')).unsqueeze(0) / 32768.0
    mix = mix.to(device)
    if mix.is_cuda:
        # Make sure the host-to-device copy has finished before stopping the clock.
        torch.cuda.synchronize(mix.device)
    loading_time = time.perf_counter() - start_time

    text_embed = dummy_read_prompt(prompt, device=device)

    _, timings = edit_sound(mix, text_embed)
    timings['loading'] = loading_time

    # Warm-up passes are executed but excluded from the statistics.
    if run_idx >= N_WARMUP:
        results.append(timings)

    # Release cached allocator blocks between runs (kept from the original protocol).
    torch.cuda.empty_cache()

# Mean and standard deviation per stage: one row per stage, columns 'mean' and 'std'.
df = pd.DataFrame(results)
summary = df.agg(['mean', 'std']).T

# Save to CSV
summary.to_csv('profiling.csv')

print("Profiling completed. Results saved to 'profiling.csv'.")


Profiling completed. Results saved to 'profiling.csv'.


In [5]:
### gpu memory profiling
import torch
import time
import scipy.io.wavfile as wav
import numpy as np
import pandas as pd
from modules.convtasnet_ext_nosb2 import MaskNet, Encoder, Decoder

def dummy_read_prompt(prompt, device='cuda'):
    """Return a random placeholder embedding instead of encoding `prompt`.

    Args:
        prompt: kept for signature compatibility; never inspected.
        device: device on which the tensor is allocated.

    Returns:
        A (1, 4096) tensor of uniform random values in [0, 1).
    """
    return torch.rand((1, 4096), device=device)


def edit_sound(mix, text_embed):
    """Run encoder -> masknet -> decoder on `mix` without gradient tracking.

    Uses the module-level `encoder`, `masknet` and `decoder`. This definition
    returns only the output tensor (no timing information).

    Args:
        mix: input waveform tensor; its last axis is treated as time.
        text_embed: conditioning tensor forwarded to `masknet`.

    Returns:
        The decoder output, zero-padded or truncated along the last axis so
        that its length equals that of `mix`.
    """
    with torch.no_grad():
        # Encode, estimate the mask, apply it in the encoded domain, decode.
        mix_h = encoder(mix)
        est_mask = masknet(mix_h, text_embed).squeeze(0)
        est_tar = decoder(mix_h * est_mask)

        # Length is measured and corrected on the same (last) axis.
        T_origin = mix.size(-1)
        T_ext = est_tar.size(-1)
        if T_origin > T_ext:
            # F.pad with a 2-tuple pads the last dimension (left, right).
            est_tar = torch.nn.functional.pad(est_tar, (0, T_origin - T_ext))
        else:
            est_tar = est_tar[..., :T_origin]

        return est_tar

# Measurements
memory_usage_results = []
lengths = range(16000, 320001, 16000)  # 16000 to 320000 samples inclusive, step 16000
prompt = "your prompt goes here"

for length in lengths:
    # Synthetic random waveform of the requested length (batch of 1); no file
    # is read, so the measurement depends only on `length`.
    mix = torch.rand((1, length), device=device)
    text_embed = dummy_read_prompt(prompt, device=device)

    # Start peak tracking from the current allocation level (weights + inputs)
    # so the peak reflects this forward pass rather than an earlier, longer one.
    torch.cuda.reset_peak_memory_stats(device)

    _ = edit_sound(mix, text_embed)
    torch.cuda.synchronize(device)  # make sure the pass has actually run

    # Bytes still held by live tensors after the pass (weights, inputs, output).
    end_mem = torch.cuda.memory_allocated(device)
    # Highest allocation reached during the pass, including intermediates
    # that were already freed by the time end_mem is read.
    peak_mem = torch.cuda.max_memory_allocated(device)

    memory_usage_results.append({
        'length': length,
        'memory_usage_bytes': end_mem,
        'peak_memory_bytes': peak_mem
    })

    # Hand cached-but-unused allocator blocks back to the driver. This does not
    # change memory_allocated(); it only shrinks the allocator's reserved pool.
    torch.cuda.empty_cache()

# Save results to CSV
df_memory = pd.DataFrame(memory_usage_results)
df_memory.to_csv('memory_profiling_by_length.csv', index=False)

print("Memory usage profiling completed. Results saved to 'memory_profiling_by_length.csv'.")

Memory usage profiling completed. Results saved to 'memory_profiling_by_length.csv'.
