In [None]:
# Inspect the architecture of the bare pretrained wav2vec2 model.
from transformers import Wav2Vec2Model

checkpoint_name = "facebook/wav2vec2-base-960h"
model = Wav2Vec2Model.from_pretrained(checkpoint_name)

# Textual dump of the module tree.
print(model)

In [None]:
# Show each direct submodule of the model together with its attribute name.
for child_name, child_module in model.named_children():
    print(child_name, "->", child_module)

In [None]:
from torchinfo import summary

# Wav2Vec2Model takes raw waveforms shaped (batch_size, sequence_length); here
# a single 1-second clip at 16 kHz. A 3-D (batch, channel, samples) input is
# NOT accepted -- the model's feature encoder adds the channel axis itself and
# its conv layers then receive a 4-D tensor.
summary(model, input_size=(1, 16000))

In [None]:
import torch
import torch.nn as nn
from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model, Wav2Vec2Processor
import torchaudio
from transformers.modeling_outputs import SequenceClassifierOutput
from datasets import Dataset
import numpy as np
from transformers import Trainer, TrainingArguments

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class Wav2Vec2ForAudioClassification(Wav2Vec2PreTrainedModel):
    """wav2vec2 encoder plus a small MLP head that emits one logit (binary task).

    The encoder's last hidden state is mean-pooled over time -- ignoring padded
    frames when an ``attention_mask`` is supplied -- then passed through
    dropout and the classifier. With ``labels``, the loss is
    ``BCEWithLogitsLoss`` against 0/1 targets.
    """

    def __init__(self, config):
        super().__init__(config)

        # Base wav2vec2 model (weights are loaded by ``from_pretrained``).
        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.final_dropout)

        # Classification head: hidden_size -> 128 -> 1 logit.
        hidden_size = config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 128),  # reduce dimensionality
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1),  # single logit for binary classification
        )

        # Initialize weights of the newly created modules.
        self.post_init()

    def freeze_feature_encoder(self):
        """Freeze the feature encoder to prevent updates during training."""
        self.wav2vec2.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """Freeze every wav2vec2 parameter so only the classifier head trains."""
        for param in self.wav2vec2.parameters():
            param.requires_grad = False

    def _pool_hidden_states(self, hidden_states, attention_mask):
        """Mean-pool ``hidden_states`` over time, skipping padded frames.

        Args:
            hidden_states: ``(batch_size, num_frames, hidden_size)`` encoder output.
            attention_mask: sample-level mask ``(batch_size, num_samples)`` or ``None``.

        Returns:
            ``(batch_size, hidden_size)`` pooled representation.
        """
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        # The mask is per audio sample; convert it to a per-frame mask matching
        # the encoder's downsampled sequence length (HF helper on the base class).
        frame_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
        frame_mask = frame_mask.unsqueeze(-1).to(hidden_states.dtype)
        summed = (hidden_states * frame_mask).sum(dim=1)
        # clamp avoids division by zero for an all-padding row.
        counts = frame_mask.sum(dim=1).clamp(min=1.0)
        return summed / counts

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor = None,
        labels: torch.Tensor = None,  # Binary labels (0 or 1)
        output_attentions: bool = None,
        output_hidden_states: bool = None,
        return_dict: bool = None,
    ):
        """Classify a batch of waveforms.

        Args:
            input_values: raw waveforms passed straight to ``Wav2Vec2Model``.
            attention_mask: optional sample-level mask (1 = real audio, 0 = padding).
            labels: optional 0/1 targets; reshaped to ``(batch_size, 1)``.
            output_attentions / output_hidden_states / return_dict: forwarded
                to the base model.

        Returns:
            ``SequenceClassifierOutput`` (or a tuple when ``return_dict`` is
            false) whose logits have shape ``(batch_size, 1)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Extract features from wav2vec2
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]  # (batch_size, num_frames, hidden_size)
        pooled_output = self._pool_hidden_states(hidden_states, attention_mask)
        pooled_output = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)  # (batch_size, 1)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss needs float targets shaped like the logits.
            labels = labels.view(-1, 1).float()
            loss_fn = nn.BCEWithLogitsLoss()
            loss = loss_fn(logits, labels)

        if not return_dict:
            # outputs[2:] are the optional hidden_states / attentions.
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [None]:
def _load_mono_waveform(audio_path, sampling_rate):
    """Load one audio file as a 1-D float tensor on ``device``.

    The signal is resampled to ``sampling_rate`` when needed and multi-channel
    audio is averaged down to mono.
    """
    waveform, source_rate = torchaudio.load(audio_path)
    if source_rate != sampling_rate:
        waveform = torchaudio.transforms.Resample(source_rate, sampling_rate)(waveform)
    if waveform.shape[0] > 1:  # average channels -> mono
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    return waveform.squeeze().to(device).float()


def _normalize_like_processor(audio, processor):
    """Differentiable torch counterpart of the processor's waveform normalization.

    Running the real processor goes through NumPy and cuts the autograd graph,
    so the attack loop normalizes in torch instead.
    """
    feature_extractor = getattr(processor, "feature_extractor", processor)
    if not getattr(feature_extractor, "do_normalize", True):
        return audio
    # NOTE(review): population variance and the 1e-7 term are meant to mirror
    # Wav2Vec2FeatureExtractor's zero-mean/unit-variance normalization -- confirm
    # if a different processor is used.
    return (audio - audio.mean()) / torch.sqrt(audio.var(unbiased=False) + 1e-7)


def _fsat_perturb(audio, label, model, processor, perturbation_steps, epsilon,
                  n_fft=512, hop_length=128):
    """Return ``audio`` with a bounded, loss-increasing perturbation of its STFT magnitude.

    Simplified F-SAT-style attack: ``perturbation_steps`` steps of normalized
    gradient ascent on an additive magnitude perturbation, clamped
    element-wise to ``[-epsilon, epsilon]``; the phase is kept fixed.
    """
    window = torch.hann_window(n_fft, device=audio.device)
    batched = audio.unsqueeze(0)  # add batch dimension: (1, num_samples)
    num_samples = batched.shape[-1]
    spectrum = torch.stft(batched, n_fft=n_fft, hop_length=hop_length,
                          window=window, return_complex=True)
    magnitude, phase = spectrum.abs(), spectrum.angle()
    target = torch.tensor([float(label)], device=audio.device)

    def reconstruct(delta):
        """Waveform for magnitude ``magnitude + delta`` and the original phase."""
        perturbed = (magnitude + delta) * torch.exp(1j * phase)
        return torch.istft(perturbed, n_fft=n_fft, hop_length=hop_length,
                           window=window, length=num_samples)

    was_training = model.training
    # NOTE(review): the attack runs in eval mode -- dropout off for a stable
    # gradient, and HF's wav2vec2 feature encoder appears to set requires_grad on
    # its (non-leaf) input in training mode, which PyTorch rejects.
    model.eval()
    delta = torch.zeros_like(magnitude)
    try:
        with torch.enable_grad():
            for _ in range(perturbation_steps):
                delta.requires_grad_(True)
                model_input = _normalize_like_processor(reconstruct(delta).squeeze(0), processor)
                loss = model(model_input.unsqueeze(0), labels=target).loss
                # Differentiate w.r.t. the perturbation only, so the model's own
                # .grad buffers are left untouched.
                (grad,) = torch.autograd.grad(loss, delta)
                with torch.no_grad():
                    step = epsilon * grad / (grad.norm() + 1e-8)
                    delta = (delta + step).clamp(-epsilon, epsilon)
    finally:
        model.train(was_training)

    with torch.no_grad():
        return reconstruct(delta).squeeze(0)


def preprocess_audio(audio_paths, processor, model=None, sampling_rate=16000,
                     perturbation_steps=3, epsilon=0.01, labels=None):
    """Load audio files, optionally apply F-SAT-style perturbations, and batch them.

    Args:
        audio_paths: sequence of audio file paths.
        processor: wav2vec2 processor turning waveforms into ``input_values``.
        model: when given AND in training mode, each waveform is adversarially
            perturbed before processing; ``None`` disables perturbation.
        sampling_rate: target sampling rate in Hz.
        perturbation_steps: gradient-ascent steps per file.
        epsilon: step size and element-wise bound of the magnitude perturbation.
        labels: 0/1 labels aligned with ``audio_paths``; required whenever
            perturbations are applied because the attack maximizes the loss.

    Returns:
        dict with ``input_values`` (float, ``(batch, max_len)``, zero-padded)
        and ``attention_mask`` (long, 1 for real samples, 0 for padding).

    Raises:
        ValueError: if perturbation is requested but ``labels`` is ``None``.
    """
    apply_perturbation = model is not None and model.training
    if apply_perturbation and labels is None:
        raise ValueError(
            "labels are required to compute F-SAT perturbations when the model is in training mode"
        )

    features = []
    for index, audio_path in enumerate(audio_paths):
        audio = _load_mono_waveform(audio_path, sampling_rate)
        if apply_perturbation:
            audio = _fsat_perturb(audio, labels[index], model, processor,
                                  perturbation_steps, epsilon)
        inputs = processor(audio.cpu().numpy(), sampling_rate=sampling_rate,
                           return_tensors="pt", padding=True)
        features.append(inputs["input_values"].squeeze(0))

    if not features:
        return {
            "input_values": torch.zeros((0, 0)),
            "attention_mask": torch.zeros((0, 0), dtype=torch.long),
        }

    # Build the mask from true lengths: testing values != 0 would also mask
    # genuine zero-valued samples inside the audio.
    lengths = torch.tensor([feature.shape[0] for feature in features])
    input_values = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
    attention_mask = (torch.arange(input_values.shape[1])[None, :] < lengths[:, None]).long()
    return {"input_values": input_values, "attention_mask": attention_mask}

In [None]:
# Toy example data -- replace with real audio paths and 0/1 labels.
data = {
    "audio": ["path/to/audio1.wav", "path/to/audio2.wav", "path/to/audio3.wav"],
    "labels": [0, 1, 0],
}
dataset = Dataset.from_dict(data)

# Processor and classification model, both initialised from the same checkpoint.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForAudioClassification.from_pretrained("facebook/wav2vec2-base-960h")
model.to(device)
model.freeze_base_model()  # optional: train only the classification head


def preprocess_batch(examples):
    """Turn a batch of {"audio", "labels"} rows into padded model inputs plus labels."""
    batch_features = preprocess_audio(examples["audio"], processor, model=model)
    batch_features["labels"] = examples["labels"]
    return batch_features


# NOTE(review): preprocess_audio only perturbs when model.training is True;
# transformers documents that from_pretrained returns the model in eval mode,
# so as written this map most likely applies NO perturbation -- confirm intent.
processed_dataset = dataset.map(preprocess_batch, batched=True, batch_size=2)

In [None]:
# Trainer configuration: evaluate and checkpoint every epoch, and reload the
# checkpoint with the best "accuracy" (from compute_metrics) when training ends.
# NOTE(review): newer transformers releases renamed `evaluation_strategy` to
# `eval_strategy` and eventually removed the old name -- check the installed version.
training_args = TrainingArguments(
    output_dir="./wav2vec2_classification",
    # optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # evaluation / checkpoint selection
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # logging
    logging_dir="./logs",
    logging_steps=10,
)

# Compute metrics function
def compute_metrics(eval_pred):
    """Accuracy of the single-logit binary classifier.

    Args:
        eval_pred: ``(logits, labels)`` pair as passed by ``Trainer``. Logits may
            be shaped ``(N, 1)`` or ``(N,)`` (or be the first element of a
            tuple of model outputs); labels are ``(N,)`` 0/1 values.

    Returns:
        ``{"accuracy": float}`` -- fraction of examples whose prediction
        (sigmoid(logit) > 0.5) equals the label.
    """
    logits, labels = eval_pred
    # Models returning several outputs yield a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Flatten both sides: comparing (N, 1) predictions with (N,) labels would
    # broadcast to an (N, N) matrix and produce a meaningless accuracy.
    logits = torch.as_tensor(logits).reshape(-1)
    labels = torch.as_tensor(labels).reshape(-1)
    predictions = (torch.sigmoid(logits) > 0.5).long()
    accuracy = (predictions == labels.long()).float().mean().item()
    return {"accuracy": accuracy}

def pad_collate(features):
    """Pad a list of examples to a common length for one training/eval batch.

    ``dataset.map`` pads only within each map batch, so rows from different
    map batches have different lengths and cannot be stacked directly.

    Args:
        features: list of dicts with ``input_values``, optional ``attention_mask``
            and optional ``labels``.

    Returns:
        dict of tensors: zero-padded ``input_values`` (float32) and
        ``attention_mask`` (long), plus ``labels`` when every example has one.
    """
    input_values = [torch.as_tensor(f["input_values"], dtype=torch.float32) for f in features]
    if all("attention_mask" in f for f in features):
        masks = [torch.as_tensor(f["attention_mask"], dtype=torch.long) for f in features]
    else:
        masks = [torch.ones(len(values), dtype=torch.long) for values in input_values]
    batch = {
        "input_values": torch.nn.utils.rnn.pad_sequence(input_values, batch_first=True),
        "attention_mask": torch.nn.utils.rnn.pad_sequence(masks, batch_first=True),
    }
    if all("labels" in f for f in features):
        batch["labels"] = torch.tensor([f["labels"] for f in features])
    return batch


# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_dataset,
    eval_dataset=processed_dataset,  # Use a separate validation set in practice
    data_collator=pad_collate,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; the object returned by trainer.train() is shown as the cell output.
train_result = trainer.train()
train_result

In [None]:
# Example inference on one or more files (model=None -> no perturbation).
test_audio = ["path/to/test_audio.wav"]
processed_test = preprocess_audio(test_audio, processor, model=None)
input_values = processed_test["input_values"].to(device)
attention_mask = processed_test["attention_mask"].to(device)

model.eval()
with torch.no_grad():
    outputs = model(input_values, attention_mask=attention_mask)
    logits = outputs.logits  # (num_files, 1)
    # Flatten to one 0/1 prediction per file; .item() would only work for a single file.
    predictions = (torch.sigmoid(logits) > 0.5).int().view(-1).tolist()

for audio_path, prediction in zip(test_audio, predictions):
    print(f"{audio_path}: Predicted class: {prediction}")