In [21]:
from asteroid.models import DPRNNTasNet
import warnings
import torchaudio
import torch
import torchaudio.functional as F
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import IPython.display as play
warnings.filterwarnings("ignore")

In [2]:
model = DPRNNTasNet.from_pretrained("mpariente/DPRNNTasNet-ks2_WHAM_sepclean")

In [3]:
model.eval()
print("Model Loaded Successfully")

Model Loaded Successfully


In [22]:
waveform,sr = torchaudio.load("MixedSpeech_3.wav")
# if the sample rate is not 8000 then resample it
if sr != 8000:
    resampled = F.resample(waveform,sr,8000)
    torchaudio.save("MixedSpeech_8k.wav",resampled,8000)
    waveform = resampled
    sr = 8000
    print("Input Audio Successfully resampled to 8000 Hz.")

Input Audio Successfully resampled to 8000 Hz.


In [23]:
with torch.no_grad():
    waveform = waveform.unsqueeze(0) # Shape : (1,number of channels, number of samples)
    est_sources = model.separate(waveform) # Separating the sources
    est_sources = est_sources.squeeze(0) # Remove the batch dimension
print("Shape of the separated sources is:",est_sources.shape)

Shape of the separated sources is: torch.Size([2, 89088])


In [20]:
for i in range(est_sources.shape[0]):
    torchaudio.save(f"Source_{i+1}.wav",est_sources[i].unsqueeze(0),sample_rate=sr)
    print(f"Separated Source file: Source{i+1}")

Separated Source file: Source1
Separated Source file: Source2


In [13]:
play.Audio("MixedSpeech_3.wav")

In [12]:
play.Audio("Source_1.wav")

In [14]:
play.Audio("Source_2.wav")

In [15]:
import torch.quantization as quantization

# Apply dynamic quantization
quantized_model = quantization.quantize_dynamic(
    model,  # Original model
    {torch.nn.Linear, torch.nn.Conv1d},  # Layers to quantize
    dtype=torch.qint8  # Use 8-bit integers
)

# Save the quantized model
torch.save(quantized_model.state_dict(), "quantized_dprnn.pth")
print("Model Quantized Successfully.")

Model Quantized Successfully.


In [16]:
import os

# Path to the saved model file
model_path = "quantized_dprnn.pth"

# Get the size of the file in bytes
file_size_bytes = os.path.getsize(model_path)

# Convert to KB and MB
file_size_kb = file_size_bytes / 1024
file_size_mb = file_size_kb / 1024

print(f"Size of the model file: {file_size_bytes} bytes")
print(f"Size of the model file: {file_size_kb:.2f} KB")
print(f"Size of the model file: {file_size_mb:.2f} MB")

Size of the model file: 13487958 bytes
Size of the model file: 13171.83 KB
Size of the model file: 12.86 MB
