In [1]:
from src.utils import *

In [2]:
# delete 75% of files in the directory
import os
import random

def delete_files(directory, percentage=0.75):
    files = os.listdir(directory)
    files = random.sample(files, int(len(files) * percentage))
    for file in files:
        os.remove(os.path.join(directory, file))

# delete_files('data/train/audio_yes_no/no', 0.5)
# delete_files('data/train/audio_yes_no/yes', 0.5)

In [3]:
import torch
from transformers import AutoFeatureExtractor, WhisperForAudioClassification

model_name = "sanchit-gandhi/whisper-medium-fleurs-lang-id"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

In [4]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import os

class CustomAudioDataset(Dataset):
    def __init__(self, data_dir, transform=None, fixed_length=None):
        self.data_dir = data_dir
        self.file_list, self.labels = self._get_file_list_and_labels()
        self.transform = transform
        self.fixed_length = fixed_length

    def _get_file_list_and_labels(self):
        file_list = []
        labels = []
        for root, dirs, files in os.walk(self.data_dir):
            for file in files:
                if file.endswith(".wav"):  # Adjust file extension if needed
                    file_list.append(root + "/" + file)
                    labels.append(os.path.basename(root))  # Extract label from directory name
        return file_list, labels

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        waveform, sample_rate = torchaudio.load(file_path)
        
        
        
        if self.fixed_length:
            waveform = self._pad_waveform(waveform, self.fixed_length)

        label = self.labels[idx]

        

        if self.transform:
            waveform = self.transform(waveform, sampling_rate=sample_rate).input_features

        return waveform, sample_rate, label

    def _pad_waveform(self, waveform, target_length):
        length_diff = target_length - waveform.size(1)
        if length_diff > 0:
            padding = torch.zeros((1, length_diff))
            waveform = torch.cat([waveform, padding], dim=1)
        return waveform.squeeze(0)

# Example usage
data_dir = "data/train/audio_small/"
# data_dir = "data/train/audio_yes_no/"
transform = feature_extractor  # You can define transformations if needed
fixed_length = 16000  # Assuming you want to fix the length to 16000 samples
sampling_rate = fixed_length  # Assuming you want to fix the sampling rate to 16000 Hz



batch_size = 4

In [5]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def freeze_layers_except_last_n(model, n):
    # Get all parameters
    parameters = list(model.parameters())
    total_layers = len(parameters)

    # Freeze all layers except the last n
    for i, param in enumerate(parameters):
        if i < total_layers - n:
            param.requires_grad = False




In [6]:
num_epochs = 5
perc = 0.05

In [7]:
# set number of output classes
num_classes = 11


In [8]:
import random
import numpy as np

for i in tqdm(np.arange (0, 5, 1), desc='Training loop (5 times)'):
   
    random.seed(int(i))
    torch.manual_seed(i)
    torch.cuda.manual_seed(i)
    torch.cuda.manual_seed_all(i)

    
   
    model = WhisperForAudioClassification.from_pretrained(model_name)

    only_name = model_name.split("/")[-1]   


    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    device = torch.device('cuda')
    # Assuming your model is named ast_model
    freeze_layers_except_last_n(model.encoder.layers, 4)

    model.classifier.dense = torch.nn.Linear(256, num_classes, bias=True)

    model.to(device)

    train_dataset = CustomAudioDataset(data_dir, fixed_length=16000, transform=feature_extractor)
                                    


    n_train = len(train_dataset)
    n_val = int(perc * n_train)
    n_test = n_val//2
    n_train = n_train - n_val

    train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))
    val_dataset, test_dataset = torch.utils.data.random_split(val_dataset, [n_val-n_test, n_test], generator=torch.Generator().manual_seed(42))

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    

    labels = set(train_dataset.dataset.labels)
    label_to_index = dict((label, i) for i, label in enumerate(sorted(labels)))
    name = data_dir.split("/")[-2]
    # print(name)
    log_dir = train(model, train_loader, val_loader, num_epochs, optimizer, criterion, device, label_to_index, only_name, log=True, description=f"test_{name}_{i}")
    test(model, test_loader, criterion, device, label_to_index, only_name, log_dir)



Training loop (5 times):   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1/5:   0%|          | 0/6187 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch 1/5, Train Loss: 0.7314735478685197


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 1/5, Validation Loss: 0.41234261760428964, Validation Accuracy: 0.8617511520737328


Epoch 2/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 2/5, Train Loss: 0.39274853548712824


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 2/5, Validation Loss: 0.2759318984087727, Validation Accuracy: 0.9170506912442397


Epoch 3/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 3/5, Train Loss: 0.29955985410027525


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 3/5, Validation Loss: 0.2792816430711394, Validation Accuracy: 0.9093701996927803


Epoch 4/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 4/5, Train Loss: 0.24352654900701617


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 4/5, Validation Loss: 0.24922679444964238, Validation Accuracy: 0.9170506912442397


Epoch 5/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 5/5, Train Loss: 0.21195026092185848


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 5/5, Validation Loss: 0.24092325689417504, Validation Accuracy: 0.9308755760368663


Testing:   0%|          | 0/163 [00:00<?, ?it/s]

{'val_acc': [0.8617511520737328, 0.9170506912442397, 0.9093701996927803, 0.9170506912442397, 0.9308755760368663], 'val_loss': [0.41234261760428964, 0.2759318984087727, 0.2792816430711394, 0.24922679444964238, 0.24092325689417504], 'train_loss': [0.7314735478685197, 0.39274853548712824, 0.29955985410027525, 0.24352654900701617, 0.21195026092185848], 'test_correct_in_batch': [4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 3.0, 4.0, 3.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 2.0, 3.0, 3.0, 3.0, 3.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 2.0, 4.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 2.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 2.0, 4.0, 3.0, 3.0

Epoch 1/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 1/5, Train Loss: 0.7360471295512307


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 1/5, Validation Loss: 0.36569705311308237, Validation Accuracy: 0.8955453149001537


Epoch 2/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 2/5, Train Loss: 0.4049851952418189


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 2/5, Validation Loss: 0.39416229297472855, Validation Accuracy: 0.8632872503840245


Epoch 3/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 3/5, Train Loss: 0.2997509396424148


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 3/5, Validation Loss: 0.2696320814943786, Validation Accuracy: 0.9155145929339478


Epoch 4/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 4/5, Train Loss: 0.24855262920615298


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 4/5, Validation Loss: 0.22045003644274558, Validation Accuracy: 0.9385560675883257


Epoch 5/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 5/5, Train Loss: 0.21074882147890717


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 5/5, Validation Loss: 0.2713854363063619, Validation Accuracy: 0.9247311827956989


Testing:   0%|          | 0/163 [00:00<?, ?it/s]

{'val_acc': [0.8955453149001537, 0.8632872503840245, 0.9155145929339478, 0.9385560675883257, 0.9247311827956989], 'val_loss': [0.36569705311308237, 0.39416229297472855, 0.2696320814943786, 0.22045003644274558, 0.2713854363063619], 'train_loss': [0.7360471295512307, 0.4049851952418189, 0.2997509396424148, 0.24855262920615298, 0.21074882147890717], 'test_correct_in_batch': [4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 2.0, 4.0, 3.0, 3.0, 

Epoch 1/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 1/5, Train Loss: 0.7666274698163548


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 1/5, Validation Loss: 0.35125709243584424, Validation Accuracy: 0.9001536098310292


Epoch 2/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 2/5, Train Loss: 0.39867589043179963


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 2/5, Validation Loss: 0.2878647295727462, Validation Accuracy: 0.9047619047619048


Epoch 3/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 3/5, Train Loss: 0.3112487620377029


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 3/5, Validation Loss: 0.31905774711174145, Validation Accuracy: 0.901689708141321


Epoch 4/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 4/5, Train Loss: 0.25725051252448955


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 4/5, Validation Loss: 0.24536551739622456, Validation Accuracy: 0.9155145929339478


Epoch 5/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 5/5, Train Loss: 0.22300258267918788


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 5/5, Validation Loss: 0.23637979775258533, Validation Accuracy: 0.9201228878648233


Testing:   0%|          | 0/163 [00:00<?, ?it/s]

{'val_acc': [0.9001536098310292, 0.9047619047619048, 0.901689708141321, 0.9155145929339478, 0.9201228878648233], 'val_loss': [0.35125709243584424, 0.2878647295727462, 0.31905774711174145, 0.24536551739622456, 0.23637979775258533], 'train_loss': [0.7666274698163548, 0.39867589043179963, 0.3112487620377029, 0.25725051252448955, 0.22300258267918788], 'test_correct_in_batch': [4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 1.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 3.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0,

Epoch 1/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 1/5, Train Loss: 1.1050467141801672


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 1/5, Validation Loss: 0.5205692339722822, Validation Accuracy: 0.8341013824884793


Epoch 2/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 2/5, Train Loss: 0.48280126454600736


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 2/5, Validation Loss: 0.5204714477984264, Validation Accuracy: 0.8494623655913979


Epoch 3/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 3/5, Train Loss: 0.3785735684734355


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 3/5, Validation Loss: 0.28950207544443557, Validation Accuracy: 0.9109062980030722


Epoch 4/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 4/5, Train Loss: 0.3096301896126029


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 4/5, Validation Loss: 0.287115418625727, Validation Accuracy: 0.9247311827956989


Epoch 5/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 5/5, Train Loss: 0.2614665496264596


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 5/5, Validation Loss: 0.24758548250755014, Validation Accuracy: 0.9324116743471582


Testing:   0%|          | 0/163 [00:00<?, ?it/s]

{'val_acc': [0.8341013824884793, 0.8494623655913979, 0.9109062980030722, 0.9247311827956989, 0.9324116743471582], 'val_loss': [0.5205692339722822, 0.5204714477984264, 0.28950207544443557, 0.287115418625727, 0.24758548250755014], 'train_loss': [1.1050467141801672, 0.48280126454600736, 0.3785735684734355, 0.3096301896126029, 0.2614665496264596], 'test_correct_in_batch': [4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 1.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 2.0, 4.0, 4.0, 3.0, 4.0

Epoch 1/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 1/5, Train Loss: 0.7926140035742842


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 1/5, Validation Loss: 0.5045262830452623, Validation Accuracy: 0.8540706605222734


Epoch 2/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 2/5, Train Loss: 0.4125488789480276


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 2/5, Validation Loss: 0.3070395333352001, Validation Accuracy: 0.9109062980030722


Epoch 3/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 3/5, Train Loss: 0.3152773565509881


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 3/5, Validation Loss: 0.33283837619083156, Validation Accuracy: 0.8847926267281107


Epoch 4/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 4/5, Train Loss: 0.25748935558131547


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 4/5, Validation Loss: 0.36088686593886754, Validation Accuracy: 0.890937019969278


Epoch 5/5:   0%|          | 0/6187 [00:00<?, ?it/s]

Epoch 5/5, Train Loss: 0.22312406889814715


Validation:   0%|          | 0/163 [00:00<?, ?it/s]

Epoch 5/5, Validation Loss: 0.25033088452397473, Validation Accuracy: 0.9170506912442397


Testing:   0%|          | 0/163 [00:00<?, ?it/s]

{'val_acc': [0.8540706605222734, 0.9109062980030722, 0.8847926267281107, 0.890937019969278, 0.9170506912442397], 'val_loss': [0.5045262830452623, 0.3070395333352001, 0.33283837619083156, 0.36088686593886754, 0.25033088452397473], 'train_loss': [0.7926140035742842, 0.4125488789480276, 0.3152773565509881, 0.25748935558131547, 0.22312406889814715], 'test_correct_in_batch': [4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 3.0, 4.0, 2.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 3.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 2.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 1.0, 4.0, 4.0, 4.0, 3.0, 2.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 4.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4.0, 4.0, 2.0, 4.0, 2.0, 4.0, 4.0, 3.0, 4.0, 3.0, 4.0, 4.0, 3.0, 4