In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
from encodec import EncodecModel  # Install with: pip install encodec   

In [2]:
# Device setup (Check for GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Directory where clean audio files are stored
DATASET_DIR = "data/100_all/preprocessed_audio"
TARGET_SAMPLE_RATE = 24000  # Encodec model operates at 24kHz
BATCH_SIZE = 64  # Adjust based on GPU memory
NUM_EPOCHS = 5
LR = 1e-4  # Learning rate

Using device: cuda


In [3]:
class AudioDataset(Dataset):
    def __init__(self, dataset_dir, sample_rate=TARGET_SAMPLE_RATE):
        self.files = [os.path.join(dataset_dir, f) for f in os.listdir(dataset_dir) if f.endswith(".wav")]
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        waveform, sr = torchaudio.load(file_path)

        # Resample if necessary
        if sr != self.sample_rate:
            transform = T.Resample(orig_freq=sr, new_freq=self.sample_rate)
            waveform = transform(waveform)

        return waveform  # Keep in CPU RAM

# Create DataLoader for batch processing
dataset = AudioDataset(DATASET_DIR)
data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

print(f"Dataset size: {len(dataset)} audio files loaded.")

Dataset size: 360573 audio files loaded.


In [4]:
encodec_model = EncodecModel.encodec_model_24khz().to(device)
encoder = encodec_model.encoder  # Extract encoder part
encoder.requires_grad_(True)  # Enable fine-tuning

# Define Custom Decoder
class CustomDecoder(nn.Module):
    def __init__(self):
        super(CustomDecoder, self).__init__()
        self.fc1 = nn.Linear(75, 512)  # Adjust based on encoder output size
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, TARGET_SAMPLE_RATE)  # Output back to waveform

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))  # Tanh keeps output within [-1,1]
        return x

decoder = CustomDecoder().to(device)

  WeightNorm.apply(module, name, dim)


In [5]:
loss_fn = nn.MSELoss().to(device)  # Mean Squared Error loss
optimizer = optim.AdamW(
    list(encoder.parameters()) + list(decoder.parameters()),  
    lr=LR,  
    betas=(0.9, 0.999),  
    weight_decay=1e-5  
)

# Automatic Mixed Precision (AMP) for GPU training
scaler = torch.cuda.amp.GradScaler(enabled=(torch.cuda.is_available()))

  scaler = torch.cuda.amp.GradScaler(enabled=(torch.cuda.is_available()))


In [6]:
from tqdm import tqdm  # Import tqdm for progress bar

for epoch in range(NUM_EPOCHS):
    encoder.train()
    decoder.train()
    
    epoch_loss = 0.0

    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=True)

    for batch in progress_bar:
        batch = batch.to(device, non_blocking=True)

        optimizer.zero_grad()

        with torch.cuda.amp.autocast():  
            latent = encoder(batch)  
            reconstructed = decoder(latent)  
            loss = loss_fn(reconstructed, batch)

        # Backpropagation
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        # Update progress bar with loss
        progress_bar.set_postfix(loss=loss.item())

        # Free memory
        del batch, latent, reconstructed, loss
        torch.cuda.empty_cache()

    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Average Loss: {epoch_loss / len(data_loader):.6f}")

print("Training Complete! 🎉")

  with torch.cuda.amp.autocast():
  return F.mse_loss(input, target, reduction=self.reduction)
Epoch 1/5:   6%|▋         | 356/5634 [02:07<31:30,  2.79it/s, loss=0.00609]


KeyboardInterrupt: 