# Based on Publication:

"LIGHTWEIGHTFEATUREENCODERFORWAKE-UPWORDDETECTION BASEDONSELF-SUPERVISEDSPEECHREPRESENTATION"

In [95]:
import torch
import torch.nn as nn
import torchaudio
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F


import tqdm as notebook_tqdm

torch.random.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [88]:
# LiteFEW Architecture
class LiteFEW(nn.Module):
    def __init__(self, alpha=0.5):
        super(LiteFEW, self).__init__()
        channels = [int(512 * alpha)] * 7
        strides = [5, 2, 2, 2, 2, 2, 2]
        kernel_widths = [10, 3, 3, 3, 3, 2, 2]
        
        # Adjusting the number of input channels for each layer
        in_channels = 1  # Initial number of channels is 1 (mono audio)
        for c, s, k in zip(channels, strides, kernel_widths):
            self.layers.append(nn.Conv1d(in_channels, c, k, stride=s))
            in_channels = c  # Update in_channels for the next iteration
            self.layers.append(nn.ReLU())
            
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


In [89]:
# Distillation Training
class AutoEncoder(nn.Module):
    def __init__(self, input_dim, bottleneck_dim):
        super(AutoEncoder, self).__init__()
        self.encoder = nn.Linear(input_dim, bottleneck_dim)
        self.decoder = nn.Linear(bottleneck_dim, input_dim)
        
    def forward(self, x):
        return self.decoder(self.encoder(x))

def distillation_loss(z_s, z_t, autoencoder, lambda_value=0.5):
    z_r = autoencoder(z_t)
    l_recon = nn.MSELoss()(z_t, z_r)
    l_distill = nn.MSELoss()(z_s, z_r)
    return lambda_value * l_recon + (1 - lambda_value) * l_distill


In [90]:
# Fine-tuning
def focal_loss(input, target, gamma=2):
    pt = torch.sigmoid(input)
    pt = pt if target == 1 else 1 - pt
    return - (1 - pt) ** gamma * torch.log(pt)

class WWD(nn.Module):
    def __init__(self, litefew, num_classes=2):
        super(WWD, self).__init__()
        self.litefew = litefew
        self.fc = nn.Linear(litefew.layers[-2].out_channels, num_classes)
        
    def forward(self, x):
        x = self.litefew(x)
        x = torch.flatten(x, 1)
        return self.fc(x)


# Preparing Dataset

In [91]:
import os

SAMPLE_RATE = 16000 
TARGET_LENGTH = 2 * SAMPLE_RATE  # 2 seconds at 16kHz

class AudioDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # Collect all audio paths and their corresponding labels
        self.audio_paths = []
        self.labels = []

        # Assuming two folders "other" and "Hey_FOOBY" representing classes 0 and 1 respectively
        for class_label, class_name in enumerate(["other", "Hey_FOOBY"]):
            class_dir = os.path.join(root_dir, class_name)
            for filename in os.listdir(class_dir):
                if filename.endswith(".wav"):
                    self.audio_paths.append(os.path.join(class_dir, filename))
                    self.labels.append(class_label)

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, index):
        waveform, _ = torchaudio.load(self.audio_paths[index])
        
        if waveform.size(1) < TARGET_LENGTH:
            # Zero padding for shorter clips
            padding_size = TARGET_LENGTH - waveform.size(1)
            waveform = torch.nn.functional.pad(waveform, (0, padding_size)).squeeze(0)
        elif waveform.size(1) > TARGET_LENGTH:
            # Trimming for longer clips
            waveform = waveform[:, :TARGET_LENGTH].squeeze(0)

        # if self.transform:
        #     waveform = self.transform(waveform)
        
        return waveform, self.labels[index]
        
    # def collate_fn(batch):
    #     data, targets = zip(*batch)
    #     min_length = min([waveform.size(2) for waveform in data])
    #     data = [waveform[..., :min_length] for waveform in data]
    #     return torch.stack(data), torch.tensor(targets)

# Placeholder transform
# transform = torchaudio.transforms.MelSpectrogram()

# Initialize dataset and loader using the root directory containing the "other" and "Hey_FOOBY" folders
root_dir = "/Users/ruben/Projects/ba-thesis-voicetrigger-in-mobileapps/data-wakeup-LSTM" 
dataset = AudioDataset(root_dir, transform=None)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True) # collate_fn=AudioDataset.collate_fn


## Preparing Teaching

In [92]:
#print(wav2vec2)
train_loader


<torch.utils.data.dataloader.DataLoader at 0x2b7ca9e50>

In [93]:
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
wav2vec2 = bundle.get_model() # teacher model

litefew = LiteFEW(alpha=0.5)
autoencoder = AutoEncoder(input_dim=wav2vec2.encoder.feature_projection.projection.in_features, 
                          bottleneck_dim=litefew.layers[-2].out_channels)


optimizer = optim.Adam(list(litefew.parameters()) + list(autoencoder.parameters()), lr=0.001)

# If you have a CUDA device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
litefew = litefew.to(device)
wav2vec2 = wav2vec2.to(device)
autoencoder = autoencoder.to(device)


AttributeError: 'LiteFEW' object has no attribute 'layers'

In [94]:
def train_distillation(epoch, teacher, student, autoencoder, optimizer, train_loader, lambda_value=0.5):
    teacher.eval()  # Set teacher to evaluation mode
    student.train() # Set student to training mode
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        
        optimizer.zero_grad()
        
        # Compute lengths (assuming your data has shape [batch, time])
        lengths = torch.full((data.size(0),), data.size(1), dtype=torch.long, device=device)
        
        with torch.no_grad():
            z_t = teacher.feature_extractor(data, lengths)[0].detach()  # Extract features using teacher
        
        z_s = student(data)
        
        loss = distillation_loss(z_s, z_t, autoencoder, lambda_value)
        
        loss.backward()
        optimizer.step()
        
        if batch_idx % 10 == 0:
            print(f"Train Epoch: {epoch} [{batch_idx*len(data)}/{len(train_loader.dataset)}] "
                  f"Loss: {loss.item():.6f}")

# Training Loop
num_epochs = 10
for epoch in range(1, num_epochs+1):
    train_distillation(epoch, wav2vec2, litefew, autoencoder, optimizer, train_loader)


RuntimeError: Given groups=1, weight of size [256, 1, 10], expected input[1, 32, 32000] to have 1 channels, but got 32 channels instead