<a href="https://colab.research.google.com/github/HernanDL/Noise-Cancellation-Using-GenAI/blob/main/Noise_Cancellation_with_Hugging_Face_Wav2Vec2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning Wav2Vec2 for Noise Cancellation (Waveform Cancellation)

This notebook fine-tunes Hugging Face's **Wav2Vec2** model for a specialized task: predicting the **inverse** of an input waveform, so that adding the prediction to the input yields silence (destructive interference).

## Goal
Train the model to output a polarity-inverted copy of its input: for input samples `x[n]` the target is `-x[n]`, so the sample-wise sum of input and output is zero everywhere (a flat, silent waveform).
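
The cancellation idea can be checked numerically before any model is involved. The sketch below uses a synthetic 440 Hz tone (an assumption for illustration, not data from this notebook) and shows that an exact inverse sums to zero, while the same inverse arriving one sample late leaves a residual.

```python
import numpy as np

sr = 16000                                  # sample rate in Hz (same rate used later in the notebook)
t = np.arange(sr) / sr                      # one second of sample times
x = 0.5 * np.sin(2 * np.pi * 440.0 * t)     # synthetic 440 Hz tone (illustrative input)

exact_inverse = -x                          # ideal anti-signal
late_inverse = -np.roll(x, 1)               # same anti-signal, delayed by one sample

exact_residual = x + exact_inverse
late_residual = x + late_inverse

print("max |residual|, exact inverse:", np.abs(exact_residual).max())
print("max |residual|, one-sample delay:", np.abs(late_residual).max())
```

The second number is nonzero, which suggests that the prediction has to match the input in timing as well as in amplitude for the sum to approach silence.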

In [None]:
# Step 1: Install Required Libraries
!pip install transformers datasets librosa soundfile torch torchaudio

## Step 2: Import Libraries
We'll import necessary libraries like Hugging Face's `transformers`, `datasets`, PyTorch, and Librosa for audio processing.

In [None]:
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from datasets import Dataset
import librosa
import numpy as np
import matplotlib.pyplot as plt
from google.colab import files

## Step 3: Load Pre-trained Wav2Vec2 Model
We'll load the pre-trained Wav2Vec2 model from Hugging Face's model hub, which will be fine-tuned for the task of waveform inversion (phase shift).
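
It helps to know what this checkpoint emits before using it for waveform prediction. `Wav2Vec2ForCTC` returns one vector of vocabulary logits per encoder frame, not one value per audio sample. The sketch below estimates the frame count from the convolution kernel sizes and strides of the base architecture; `num_output_frames` is a helper name introduced here for illustration.

```python
# Kernel sizes and strides of the seven 1-D convolutions in the wav2vec2-base
# feature encoder (values from the published base configuration).
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)
VOCAB_SIZE = 32  # size of the CTC vocabulary for facebook/wav2vec2-base-960h


def num_output_frames(num_samples: int) -> int:
    """Return how many frames the feature encoder emits for `num_samples` input samples."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1  # output length of an unpadded convolution
    return length


one_second = 16000
frames = num_output_frames(one_second)
print("frames for 1 s of audio:", frames)
print("CTC logits shape for a batch of one:", (1, frames, VOCAB_SIZE))
```

With roughly one frame per 320 input samples, the logits cannot be added to the input sample by sample; a waveform target needs an output layer that maps frames back to samples.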

In [None]:
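# Load the pre-trained checkpoint and its audio preprocessor.
# Wav2Vec2Tokenizer is deprecated for raw audio; Wav2Vec2FeatureExtractor is its replacement.
# The variable name `tokenizer` is kept so the later cells run unchanged.
from transformers import Wav2Vec2FeatureExtractor
model_name = 'facebook/wav2vec2-base-960h'
tokenizer = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)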

## Step 4: Upload Noisy Audio (Input Signal)
We will upload a noisy input signal (in `.wav` format), which the model will learn to cancel by generating an inverse waveform.

In [None]:
# Upload your noisy input audio file
uploaded = files.upload()
input_audio_file = list(uploaded.keys())[0]  # Get uploaded file name

# Load the input audio file
input_audio, sr = librosa.load(input_audio_file, sr=16000)  # Resample to 16kHz (Wav2Vec2's input requirement)

# Plot the input waveform
plt.figure(figsize=(10, 4))
plt.plot(input_audio)
plt.title('Input Audio Waveform')
plt.show()

## Step 5: Generate the Target Inverse Waveform
The training target is the **inverse** of the input signal: every sample is negated, which is equivalent to a 180-degree phase shift at every frequency.

In [None]:
# Create the inverse (phase-shifted) waveform
inverse_audio = -input_audio  # Simply invert the waveform (180-degree phase shift)

# Plot the inverse waveform
plt.figure(figsize=(10, 4))
plt.plot(inverse_audio)
plt.title('Inverse (Phase-Shifted) Audio Waveform')
plt.show()

# Check that adding input_audio and inverse_audio results in silence
combined_audio = input_audio + inverse_audio
plt.figure(figsize=(10, 4))
plt.plot(combined_audio)
plt.title('Combined Waveform (Input + Inverse)')  # This should be a flat line (silence)
plt.show()

## Step 6: Preprocess Data
We now preprocess the data: the `tokenizer` call converts the raw (noisy) input audio into the padded `input_values` tensor the model expects, and the inverse audio becomes the training target.
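
Wav2Vec2 preprocessing for this checkpoint typically normalizes each waveform to zero mean and unit variance, so the values the model sees can be on a different scale from the raw samples used as the target. The sketch below reproduces that kind of normalization in NumPy on a synthetic signal; `zero_mean_unit_var` and its `eps` value are assumptions for illustration, not the library's API.

```python
import numpy as np


def zero_mean_unit_var(audio: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Scale a waveform to zero mean and unit variance (hypothetical helper)."""
    return (audio - audio.mean()) / np.sqrt(audio.var() + eps)


rng = np.random.default_rng(0)
raw = 0.1 * rng.standard_normal(16000).astype(np.float32)  # stand-in for input_audio

normalized = zero_mean_unit_var(raw)
raw_target = -raw                 # inverse of the raw waveform (what Step 5 builds)
normalized_target = -normalized   # inverse on the same scale as the model input

print("std of raw input:        ", raw.std())
print("std of normalized input: ", normalized.std())
print("normalized input + normalized target is silent:",
      np.allclose(normalized + normalized_target, 0.0))
```

One option, under these assumptions, is to build the target from the normalized input so that input and target share a scale.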

In [None]:
# Tokenize input (noisy) audio
input_values = tokenizer(input_audio, return_tensors='pt', padding='longest').input_values

# Target is the inverse (phase-shifted) audio, shaped (1, num_samples) like a batch of one
labels = torch.from_numpy(inverse_audio).float().unsqueeze(0)

# Dataset construction
dataset = Dataset.from_dict({
    'input_values': input_values.numpy(),
    'labels': labels.numpy()
})

## Step 7: Fine-Tuning Wav2Vec2 for Waveform Cancellation
We configure Hugging Face's `Trainer` to fine-tune the model to predict the inverse waveform.
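
`Wav2Vec2ForCTC` computes a CTC loss when it receives `labels`, and that loss expects integer token ids rather than floating-point samples. Predicting a waveform therefore calls for a regression-style output and loss. The following is a minimal NumPy sketch of one possible design, not the method used by this notebook: each encoder frame is projected to a chunk of 320 samples, the chunks are concatenated, and the result is compared with the target using mean squared error. `frames_to_waveform`, `mse_loss`, and the random stand-in arrays are all hypothetical.

```python
import numpy as np

HOP = 320        # input samples per encoder frame (product of the conv strides)
HIDDEN = 768     # hidden size of the base encoder


def frames_to_waveform(frames: np.ndarray, weight: np.ndarray, num_samples: int) -> np.ndarray:
    """Project (num_frames, HIDDEN) features to a 1-D waveform of length num_samples."""
    chunks = frames @ weight                 # (num_frames, HOP): one chunk of samples per frame
    wave = chunks.reshape(-1)                # concatenate the chunks into a single waveform
    if wave.shape[0] < num_samples:          # zero-pad when the frames cover fewer samples
        wave = np.pad(wave, (0, num_samples - wave.shape[0]))
    return wave[:num_samples]


def mse_loss(prediction: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error between predicted and target waveforms."""
    return float(np.mean((prediction - target) ** 2))


rng = np.random.default_rng(0)
num_samples, num_frames = 16000, 49                     # one second of audio, 49 frames
features = rng.standard_normal((num_frames, HIDDEN))    # stand-in for encoder output
weight = 0.01 * rng.standard_normal((HIDDEN, HOP))      # untrained projection weights
target = -rng.standard_normal(num_samples)              # stand-in for the inverse waveform

prediction = frames_to_waveform(features, weight, num_samples)
print("prediction shape:", prediction.shape)
print("MSE against target:", mse_loss(prediction, target))
```

In PyTorch the same idea would roughly correspond to an `nn.Linear(768, 320)` layer on the encoder's hidden states trained with `nn.MSELoss`, though whether this architecture can learn a sample-accurate inverse is an open question.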

In [None]:
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=2,
    eval_strategy='steps',  # named `evaluation_strategy` in transformers < 4.41
    num_train_epochs=3,
    save_steps=500,
    save_total_limit=2,
    logging_dir='./logs',
    logging_steps=100,
)

# Define trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset  # In practice, you should use a separate validation set
)

# Train the model
trainer.train()

## Step 8: Inference and Visualization
After fine-tuning, we'll test the model on the same noisy input and plot the resulting inverse waveform.
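
A plot that looks flat is hard to judge by eye, so a single number for the residual energy can help. The sketch below reports attenuation in decibels, comparing the power of the original signal with the power of what remains after adding a predicted inverse. `attenuation_db` is a helper introduced here, and the signals are synthetic stand-ins rather than model output.

```python
import numpy as np


def attenuation_db(original: np.ndarray, residual: np.ndarray, eps: float = 1e-12) -> float:
    """Return how far the residual power sits below the original power, in dB."""
    original_power = np.mean(original ** 2)
    residual_power = np.mean(residual ** 2)
    return float(10.0 * np.log10((original_power + eps) / (residual_power + eps)))


rng = np.random.default_rng(0)
signal = rng.standard_normal(16000)        # stand-in for input_audio
imperfect_inverse = -0.9 * signal          # stand-in for a prediction that is 10% too quiet

residual = signal + imperfect_inverse      # what remains after "cancellation"
print("attenuation (dB):", attenuation_db(signal, residual))
```

A residual at one tenth of the original amplitude corresponds to about 20 dB of attenuation; 0 dB means no cancellation at all.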

In [None]:
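# Apply the fine-tuned model to generate an inverse waveform
model.eval()  # disable dropout for inference
input_values = tokenizer(input_audio, return_tensors='pt', padding='longest').input_values
input_values = input_values.to(model.device)  # Trainer may have moved the model to a GPU
with torch.no_grad():
    predicted_inverse = model(input_values).logits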

# Plot the predicted inverse waveform
plt.figure(figsize=(10, 4))
plt.plot(predicted_inverse[0].cpu().numpy())
plt.title('Predicted Inverse Waveform')
plt.show()

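# Combine input_audio and predicted_inverse to check cancellation.
# The sum is only defined when the prediction has one value per input sample.
predicted_wave = predicted_inverse[0].cpu().numpy()
assert predicted_wave.shape == input_audio.shape, f'expected {input_audio.shape}, got {predicted_wave.shape}'
combined_audio = input_audio + predicted_wave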
plt.figure(figsize=(10, 4))
plt.plot(combined_audio)
plt.title('Combined Waveform (Input + Predicted Inverse)')  # Should approach silence
plt.show()