In [61]:
%%capture
!pip install soundfile torchaudio transformers datasets accelerate

In [62]:
import torch.nn as nn
import torch
import torch.nn.functional as F
import torchaudio
import os
import random
from datasets import Dataset as HFDataset
from transformers import Trainer, TrainingArguments
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from safetensors.torch import load_file

In [63]:
def split_into_chunks(waveform: torch.Tensor, chunk_size: int = 16000):
    """
    Splits 1D waveform into list of fixed-size chunks (with zero-padding if needed).
    """
    total_len = waveform.shape[-1]
    pad_len = (chunk_size - total_len % chunk_size) % chunk_size
    padded = F.pad(waveform, (0, pad_len))

    chunks = padded.unfold(dimension=-1, size=chunk_size, step=chunk_size)
    return chunks 



class WaveformAutoencoder(nn.Module):
    def __init__(self, size='small'):
        super().__init__()
        if size == 'small':
            self.encoder = nn.Sequential(
                nn.Conv1d(1, 16, 4, stride=2, padding=1),  # 8000
                nn.LeakyReLU(),
                nn.Conv1d(16, 32, 4, stride=2, padding=1), # 4000
                nn.LeakyReLU(),
                nn.Conv1d(32, 64, 4, stride=2, padding=1), # 2000
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose1d(64, 32, 4, stride=2, padding=1),  # 4000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(32, 16, 4, stride=2, padding=1),  # 8000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(16, 1, 4, stride=2, padding=1),   # 16000
                nn.Tanh()
            )
        elif size == 'medium':
            self.encoder = nn.Sequential(
                nn.Conv1d(1, 32, 4, stride=2, padding=1),  # 8000
                nn.LeakyReLU(),
                nn.Conv1d(32, 64, 4, stride=2, padding=1), # 4000
                nn.LeakyReLU(),
                nn.Conv1d(64, 128, 4, stride=2, padding=1),# 2000
                nn.LeakyReLU(),
                nn.Conv1d(128, 256, 4, stride=2, padding=1), # 1000
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose1d(256, 128, 4, stride=2, padding=1), # 2000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(128, 64, 4, stride=2, padding=1),  # 4000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(64, 32, 4, stride=2, padding=1),   # 8000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(32, 1, 4, stride=2, padding=1),    # 16000
                nn.Tanh()
            )

        elif size == 'large':
            self.encoder = nn.Sequential(
                nn.Conv1d(1, 64, 4, stride=2, padding=1),  # 8000
                nn.LeakyReLU(),
                nn.Conv1d(64, 128, 4, stride=2, padding=1), # 4000
                nn.LeakyReLU(),
                nn.Conv1d(128, 256, 4, stride=2, padding=1),# 2000
                nn.LeakyReLU(),
                nn.Conv1d(256, 512, 4, stride=2, padding=1), # 1000
            )
            self.decoder = nn.Sequential(
                nn.ConvTranspose1d(512, 256, 4, stride=2, padding=1), # 2000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(256, 128, 4, stride=2, padding=1), # 4000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(128, 64, 4, stride=2, padding=1),   # 8000
                nn.LeakyReLU(),
                nn.ConvTranspose1d(64, 1, 4, stride=2, padding=1),     # 16000
                nn.Tanh()
            )
        else:
            raise ValueError("Invalid size. Choose from 'small', 'medium', or 'large'.")
    

    def forward(self, **kwargs):
        x = kwargs.get("input_values")
        x = self.encoder(x)
        x = self.decoder(x)
        return {
            "logits": x,
            "masks": kwargs.get("masks")
        }
    
    @torch.no_grad()
    def recover(self, waveform, type="noise", mask=None, chunk_size=16000, device='cuda'):
        if type not in ["noise", "mask"]:
            raise ValueError("type must be either 'noise' or 'mask'")
        if type == "mask" and mask is None:
            raise ValueError("mask must be provided for type 'mask'")
        
        self.eval()
        self.to(device)
        waveform = torch.tensor(waveform)
        print(f"Waveform shape: {waveform.shape}")

        # Split into chunks
        chunks = split_into_chunks(waveform, chunk_size=chunk_size)
        print(f"Chunks shape: {chunks.shape}")
        chunks = chunks.unsqueeze(1).to(device)  # add channel dimension: (N, 1, chunk_size)
        print(f"Chunks shape after unsqueeze: {chunks.shape}")
        outputs = self(input_values=chunks)["logits"]
        print(f"Outputs shape: {outputs.shape}")
        # Reconstruct:
        reconstructed = outputs.squeeze(1).reshape(-1)[:waveform.shape[-1]]
        print(f"Reconstructed shape: {reconstructed.shape}")
        
        if type == "noise":
            return reconstructed.cpu()
        else:
            mask = mask[-1].to(device)
            # reconstucted if mask == 1 else waveform
            wave_clone = waveform.clone().to(device)
            print("Returning masked_reconstructed.shape", reconstructed.shape, mask.shape, wave_clone.shape)
            reconstructed = torch.where(mask == 1, reconstructed, wave_clone)
            return reconstructed.cpu()
    

    def model_size_params(self):
        """
        Returns the number of parameters in the model.
        """
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [64]:
import torch
import torchaudio
import os
import random

def add_noise(audio, noise_prob_range, noise_var=None):
    """
    Add noise to the audio tensor by randomly injecting impulses (spikes or dips).
    
    Args:
        audio (torch.Tensor): The input audio tensor.
        noise_prob_range (tuple): A tuple specifying the range of noise probabilities.
        noise_var (float, optional): The variance of the noise to be added. Default is None.
    Returns:
        torch.Tensor: The noisy audio tensor."""
    noisy = audio.clone()
    noise_prob = random.uniform(*noise_prob_range)  # Randomly choose noise probability
    mask = torch.rand_like(audio) < noise_prob # binary mask: where to inject impulses

    # Randomly choose 0, 1, or -1
    impulses = torch.randint(0, 3, audio.shape).float()
    impulses[impulses == 1] = 1.0   # spike
    impulses[impulses == 2] = -1.0  # dip
    impulses[impulses == 0] = 0.0   # zero-out

    if noise_var:
        noise_var = torch.randn_like(audio) * noise_var
        noisy[mask] = impulses[mask] + noise_var[mask]
    else:
        noisy[mask] = impulses[mask]
    return noisy

def mask_audio(audio, mask_prob_range, mask_value=-1):
    """
    Apply a mask to the audio tensor.
    
    Args:
        audio (torch.Tensor): The input audio tensor.
        mask_prob (float): The probability of masking each element.
        mask_value (int, optional): The value to use for masking. Default is -1
        
    Returns:
        torch.Tensor: The masked audio tensor.
    """
    mask_prob = random.uniform(*mask_prob_range)  # Randomly choose mask probability
    mask = torch.rand_like(audio) < mask_prob
    masked_audio = audio.clone()
    masked_audio[mask] = mask_value
    return masked_audio, mask
    
class WaveformDataset(torch.utils.data.Dataset):
    def __init__(
            self, 
            folder, 
            chunk_length=16000, 
            sample_rate=16000, 
            type="noise",
            noise_prob_range=(0.05, 0.1), 
            noise_var=0.1,
            mask_prob_range=(0.1, 0.2),    
            mask_value=-100   
        ):
        if type not in ["noise", "mask"]:
            raise ValueError("type must be either 'noise' or 'mask'")
        
        self.type = type
        self.chunk_length = chunk_length
        self.sample_rate = sample_rate
        self.noise_prob_range = noise_prob_range
        self.noise_var = noise_var
        self.audio_chunks = []
        self.mask_prob_range = mask_prob_range
        self.mask_value = mask_value
        self._preload_chunks([
            os.path.join(folder, f) for f in os.listdir(folder) if f.endswith('.wav')
        ])

    def _preload_chunks(self, file_paths):
        for path in file_paths:
            waveform, sr = torchaudio.load(path)
            if sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=self.sample_rate)
                waveform = resampler(waveform)

            # Convert to mono
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            # Chunking
            total_len = waveform.shape[1]
            for i in range(0, total_len - self.chunk_length + 1, self.chunk_length):
                chunk = waveform[:, i:i + self.chunk_length]
                self.audio_chunks.append(chunk)


    def __len__(self):
        return len(self.audio_chunks)

    def __getitem__(self, idx):
        clean = self.audio_chunks[idx]
        if self.type == "noise":
            noisy = add_noise(clean, self.noise_prob_range, self.noise_var)
            return {
                "input_values": noisy,   # model input: noised audio
                "labels": clean.clone()  # ground truth: clean audio
            }
        elif self.type == "mask":
            masked, mask = mask_audio(clean, self.mask_prob_range, self.mask_value)
            return {
                "input_values": masked,   # model input: masked audio
                "labels": clean.clone(),  # ground truth: clean audio
                "masks": mask              # binary mask
            }

In [65]:
def train(
    model,
    data_dir="audio/",
    chunk_length=16000,
    sample_rate=16000,
    type="noise",  # "noise" or "mask"
    noise_prob_range=(0.05, 0.25),
    noise_var=0.1,
    mask_prob_range=(0.05, 0.25),
    batch_size=16,
    num_epochs=1,
    learning_rate=1e-3,
    logging_steps=100,
    output_dir="./ae_ckpt",
):
    if type not in ["noise", "mask"]:
        raise ValueError("type must be either 'noise' or 'mask'")
    
    print("Loading train dataset...")
    audio_ds = WaveformDataset(
        folder=data_dir, chunk_length=chunk_length, sample_rate=sample_rate, 
        noise_prob_range=noise_prob_range, noise_var=noise_var,
        mask_prob_range=mask_prob_range, type=type
    )
    print(f"Train dataset size: {len(audio_ds)}")


    args = TrainingArguments(
        report_to="tensorboard",
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        logging_steps=logging_steps,
        save_strategy="epoch",
        save_total_limit=10,
        eval_strategy="no",
        remove_unused_columns=False,
    )

    def data_collator_for_noise(batch):
        input_values = torch.stack([item["input_values"] for item in batch])
        labels = torch.stack([item["labels"] for item in batch])
        return {
            "input_values": input_values,
            "labels": labels
        }

    def compute_loss_for_noise(outputs, labels, num_items_in_batch=batch_size, return_outputs=False):
        loss = F.mse_loss(outputs["logits"],labels)
        return (loss, outputs) if return_outputs else loss
    
    def data_collator_for_mask(batch):
        input_values = torch.stack([item["input_values"] for item in batch])
        labels = torch.stack([item["labels"] for item in batch])
        masks = torch.stack([item["masks"] for item in batch])
        return {
            "input_values": input_values,
            "labels": labels,
            "masks": masks
        }
    
    def compute_loss_for_mask(outputs, labels, num_items_in_batch=batch_size, return_outputs=False):
        msk = outputs["masks"]
        logits = outputs["logits"]
        loss = F.mse_loss(logits * msk, labels * msk)
        return (loss, outputs) if return_outputs else loss
    

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=audio_ds,
        data_collator=data_collator_for_noise if type == "noise" else data_collator_for_mask,
        compute_loss_func=compute_loss_for_noise if type == "noise" else compute_loss_for_mask
    )

    print("Starting training...")
    trainer.train(resume_from_checkpoint='/kaggle/working/autoencoder2/checkpoint-1725')

In [66]:
def compute_snr(clean_signal, reconstructed_signal):
    signal_power = np.sum(clean_signal ** 2)
    noise_power = np.sum((clean_signal - reconstructed_signal) ** 2)
    
    if noise_power == 0:
        return np.inf  # Perfect reconstruction
    snr = 10 * np.log10(signal_power / noise_power)
    return snr


def evaluate(
    model,
    data_dir="audio/",
    chunk_length=16000,
    sample_rate=16000,
    type="noise",  # "noise" or "mask"
    noise_prob_range=(0.05, 0.25),
    noise_var=0.1,
    mask_prob_range=(0.05, 0.5),
    batch_size=16,
    device="cuda"
):
    """
    Evaluate the model on the evaluation dataset.
    
    Args:
        model: The model to evaluate.
        data_dir (str): Directory containing the evaluation dataset.
        chunk_length (int): Length of each audio chunk.
        sample_rate (int): Sample rate of the audio.
        type (str): Type of evaluation ("noise" or "mask").
        noise_prob_range (tuple): Range of noise probabilities.
        noise_var (float): Variance of the noise.
        mask_prob (float): Probability of masking.
        batch_size (int): Batch size for evaluation.
        device (str): Device to use for evaluation ("cuda" or "cpu").
    
    Returns:
        A dictionary containing the evaluation results.
    """
    if type not in ["noise", "mask"]:
        raise ValueError("type must be either 'noise' or 'mask'")
    
    print("Loading evaluation dataset...")
    model.to(device)
    model.eval()

    eval_dataset = WaveformDataset(
        data_dir, chunk_length=chunk_length, sample_rate=sample_rate, 
        noise_prob_range=noise_prob_range, noise_var=noise_var,
        mask_prob_range=mask_prob_range, type=type)
    
    print("Loading evaluation dataset...")

    def data_collator(batch):
        input_values = torch.stack([item["input_values"] for item in batch])
        labels = torch.stack([item["labels"] for item in batch])
        if type == "mask":
            masks = torch.stack([item["masks"] for item in batch])
        else:
            masks = None
        return {
            "input_values": input_values,
            "labels": labels,
            "masks": masks
        }
    
        
    dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, 
                            collate_fn=data_collator, num_workers=4)

    all_clean = []
    all_noisy = []
    all_masks = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_values = batch['input_values'].to(device)
            labels = batch['labels']
            masks = batch['masks']

            outputs = model(input_values=input_values)['logits']

            all_clean.append(labels.cpu())
            all_noisy.append(outputs.cpu())
            if masks is not None:
                all_masks.append(masks.cpu())

    if type == "mask":
        all_masks = torch.cat(all_masks, dim=0)

    all_clean = torch.cat(all_clean, dim=0)
    all_noisy = torch.cat(all_noisy, dim=0)

    print(all_clean.shape)
    print(all_noisy.shape)
    print(all_masks.shape if type == "mask" else "No masks")

    # Calculate metrics
    mse_loss = nn.MSELoss()(all_clean * all_masks, all_noisy * all_masks).item()

    print(mse_loss)

In [67]:
def recover_test(type="noise",resample=False, new_sr=16000):
    if type not in ["noise", "mask"]:
        raise ValueError("type must be either 'noise' or 'mask'")
        
    torch.cuda.empty_cache()
    model = WaveformAutoencoder(size='medium')
    state_dict = load_file("/kaggle/working/autoencoder/checkpoint-1725/model.safetensors")
    model.load_state_dict(state_dict)
    
    waveform, sr = torchaudio.load('/kaggle/input/audio-data/test/test_4.wav')
    print('Raw waveform shape:', waveform.shape)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    if resample:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_sr)
        waveform = resampler(waveform)
        
    print(waveform)
    if type == "noise":
        waveform = add_noise(waveform, (0.05, 0.3))
        torchaudio.save('noised6.wav', waveform, new_sr if resample else sr)
        waveform = waveform.squeeze(0)
        rc = model.recover(waveform)
        rc = rc.unsqueeze(0).to('cpu')
        
        print('Done')
    else:
        waveform, mask = mask_audio(waveform, (0.01, 0.5))
        waveform = waveform.squeeze(0)
        rc = model.recover(waveform, type="mask", mask=mask)
        rc = rc.unsqueeze(0).to('cpu')
        
    print("Saving recovered, shape is", rc.shape)
    torchaudio.save('rcv-2.wav', rc, new_sr if resample else sr)


In [68]:
def cal_mse():
    clean, sr = torchaudio.load('/kaggle/input/audio-data/test/test_4.wav')
    clean = torch.mean(clean, dim=0, keepdim=True)
    noised1, _ = torchaudio.load('/kaggle/working/rcv-2.wav')
    noised2, _ = torchaudio.load('/kaggle/working/rcv-1.wav')

    print(nn.MSELoss()(clean, noised1).item())
    print(nn.MSELoss()(clean, noised2).item())

In [69]:
def main():
    model = WaveformAutoencoder(size='medium')
    train(
        model=model,
        data_dir="/kaggle/input/audio-data/train",
        batch_size=256,
        logging_steps=100,
        type="mask",
        num_epochs=20,
        mask_prob_range=(0.01, 0.3),
        output_dir="./autoencoder3"
    )

    evaluate(
        model=model,
        data_dir="/kaggle/input/audio-data/test",
        type="mask"
    )

In [70]:
main()
# recover_test(type="mask")
# cal_mse()

Loading train dataset...
Train dataset size: 29374
Starting training...


Step,Training Loss
1800,0.0021
1900,0.0013
2000,0.0013
2100,0.0013
2200,0.0013
2300,0.0012


Loading evaluation dataset...
Loading evaluation dataset...


100%|██████████| 25/25 [00:00<00:00, 35.39it/s]


torch.Size([390, 1, 16000])
torch.Size([390, 1, 16000])
torch.Size([390, 1, 16000])
0.0052595557644963264
