In [8]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from config import Wav2Vec2Config
from model import Wav2Vec2ForPreTraining,Wav2Vec2FeatureEncoder,Wav2Vec2GumbelVectorQuantizer,_compute_mask_indices,Wav2Vec2Encoder,Wav2Vec2FeatureProjection

from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
import random
from tqdm import tqdm
import os

device = 'cuda' if torch.cuda.is_available() else 'cpu'



from dataset import AudioDataset



parent_dir = 'data/mp3_train_files'
file_list = [os.path.join(root, file) 
             for root, _, files in os.walk(parent_dir) 
             for file in files]

random.seed(42)
random.shuffle(file_list)

train_size = int(0.8 * len(file_list))
val_size = int(0.1 * len(file_list))
test_size = len(file_list) - train_size - val_size

train_files = file_list[:train_size]
val_files = file_list[train_size:train_size + val_size]
test_files = file_list[train_size + val_size:]

train_dataset = AudioDataset(train_files)
val_dataset = AudioDataset(val_files)
test_dataset = AudioDataset(test_files)

In [2]:
batch_size = 6
data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=32)

In [3]:
config = Wav2Vec2Config()
model = Wav2Vec2ForPreTraining(config)

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [4]:
model.params()

95044608

In [5]:
epochs = 10

output_attentions = True
output_hidden_states = False
return_dict = torch.BoolTensor(1)

for epoch in range(epochs):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(data_loader, desc=f"Processing Epoch {epoch:02d}")
    total_loss = 0.0

    for batch_idx, batch in enumerate(batch_iterator):
        input_values = batch[0].to(device)

        with torch.no_grad():
            extract_features = model.wav2vec2.feature_extractor(input_values).transpose(1,2)
            seq_len = extract_features.size(1)  # Sequence length after feature extraction

        attention_mask = torch.ones((input_values.size(0), seq_len), dtype=torch.long, device=device)

        mask_time_indices = _compute_mask_indices(
            shape=(input_values.size(0), seq_len),
            mask_prob=config.mask_time_prob,
            mask_length=config.mask_time_length,
            attention_mask=attention_mask,
            min_masks=config.mask_time_min_masks
        )

        sampled_negative_indices = _sample_negative_indices(
            features_shape=(input_values.size(0), seq_len),
            num_negatives=model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )

        mask_time_indices = torch.tensor(mask_time_indices, device=device, dtype=torch.bool)
        sampled_negative_indices = torch.tensor(sampled_negative_indices, device=device)

    

        out = model(input_values=input_values,
        attention_mask=attention_mask,
        mask_time_indices=mask_time_indices,
        sampled_negative_indices=sampled_negative_indices,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict)

        loss = out.loss
        contrastive_loss = out.contrastive_loss.item()

        optimizer.zero_grad()

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        total_loss += loss.item()

        batch_iterator.set_postfix(loss=total_loss / (batch_idx + 1),contrastive_loss=contrastive_loss)

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(data_loader)}")


Processing Epoch 00: 100%|██████████| 2584/2584 [11:17<00:00,  3.81it/s, contrastive_loss=277, loss=544]


Epoch 1/10, Loss: 543.8335500309711


Processing Epoch 01:   0%|          | 2/2584 [00:01<29:08,  1.48it/s, contrastive_loss=554, loss=549]


KeyboardInterrupt: 

In [25]:
model_out = model.wav2vec2(test_dataset[0][0].to(device).unsqueeze(0)).last_hidden_state

In [42]:
torch.save(model.state_dict(),'weights/pre_train-01.pt')

# example pretraining 
https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py