In [1]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from config import Wav2Vec2Config
from model import Wav2Vec2ForPreTraining,Wav2Vec2FeatureEncoder,Wav2Vec2GumbelVectorQuantizer,_compute_mask_indices,Wav2Vec2Encoder,Wav2Vec2FeatureProjection

from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
import random
from tqdm import tqdm
import os

# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'



from dataset import AudioDataset



# Gather every file below the training-data directory.
# os.walk gives no ordering guarantee (it depends on the filesystem), so the
# paths are sorted before shuffling; without this the seeded shuffle below
# could produce a different train/val/test split on another machine.
parent_dir = 'data/mp3_train_files'
file_list = sorted(
    os.path.join(root, file)
    for root, _, files in os.walk(parent_dir)
    for file in files
)

# Fixed seed so the shuffle, and therefore the split, is repeatable.
random.seed(42)
random.shuffle(file_list)

# 80 % train, 10 % validation, and whatever remains goes to test.
train_size = int(0.8 * len(file_list))
val_size = int(0.1 * len(file_list))
test_size = len(file_list) - train_size - val_size

train_files = file_list[:train_size]
val_files = file_list[train_size:train_size + val_size]
test_files = file_list[train_size + val_size:]

# One dataset object per split.
train_dataset = AudioDataset(train_files)
val_dataset = AudioDataset(val_files)
test_dataset = AudioDataset(test_files)

In [2]:
# Training loader: batches of 6, shuffling enabled, 32 loader workers.
batch_size = 6
data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=32,
)

In [3]:
# Build the pre-training model from the project's default configuration and
# move it to the selected device before the optimizer is created.
config = Wav2Vec2Config()
model = Wav2Vec2ForPreTraining(config).to(device)

# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)

from transformers import get_linear_schedule_with_warmup

# Linear warm-up then linear decay, sized for a 400k-step run with the first
# 8 % of the steps spent warming up.
# NOTE(review): the recorded run shows 2067 batches per epoch, so 20 epochs is
# roughly 41k optimizer steps; most of that falls inside the 32k-step warm-up.
# Confirm that this 400k-step horizon is intended for this loop.
num_training_steps = 400000
num_warmup_steps = int(0.08 * num_training_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


In [4]:
# Display the value returned by the model's params() helper.
# NOTE(review): presumably the parameter count; the helper lives in the
# project's model module, which is not visible here. Confirm.
model.params()

19780352

In [6]:
epochs = 20

output_attentions = True
output_hidden_states = False
# Plain boolean on purpose: torch.BoolTensor(1) allocates an *uninitialized*
# one-element tensor, so its truth value was arbitrary from run to run. The
# loop reads out.loss / out.contrastive_loss by attribute, which requires the
# structured (non-tuple) output.
return_dict = True

# Length of the feature sequence used for masking and negative sampling.
# NOTE(review): hard-coded. This assumes every clip has the same length and
# that the feature encoder emits exactly 249 frames for it; confirm, or
# derive it from model.wav2vec2.feature_extractor(input_values).size(-1).
seq_len = 249

for epoch in range(epochs):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(data_loader, desc=f"Processing Epoch {epoch:02d}")
    total_loss = 0.0

    for batch_idx, batch in enumerate(batch_iterator):
        input_values = batch[0].to(device)
        # The last batch of an epoch can be smaller than batch_size, so the
        # per-batch shapes are taken from the tensor itself.
        current_batch_size = input_values.size(0)

        # All-ones mask: every frame is treated as valid.
        attention_mask = torch.ones(
            (current_batch_size, seq_len), dtype=torch.long, device=device
        )

        # Choose which frames to mask for the contrastive task.
        mask_time_indices = _compute_mask_indices(
            shape=(current_batch_size, seq_len),
            mask_prob=config.mask_time_prob,
            mask_length=config.mask_time_length,
            attention_mask=attention_mask,
            min_masks=config.mask_time_min_masks,
        )

        # Draw the distractor indices for the masked frames.
        sampled_negative_indices = _sample_negative_indices(
            features_shape=(current_batch_size, seq_len),
            num_negatives=model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )

        # Convert both index arrays to tensors on the training device.
        mask_time_indices = torch.tensor(mask_time_indices, device=device, dtype=torch.bool)
        sampled_negative_indices = torch.tensor(sampled_negative_indices, device=device)

        out = model(
            input_values=input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            sampled_negative_indices=sampled_negative_indices,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        loss = out.loss
        # NOTE(review): in the recorded run this value sits at ~369 for almost
        # every epoch. 369 is about 80 * ln(101), so if num_negatives is 100
        # this may be chance level. Confirm that the model is learning.
        contrastive_loss = out.contrastive_loss.item()

        optimizer.zero_grad()
        loss.backward()

        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        # The learning-rate schedule advances once per optimizer step.
        scheduler.step()

        total_loss += loss.item()

        # Progress bar shows the running mean loss and the latest contrastive loss.
        batch_iterator.set_postfix(loss=total_loss / (batch_idx + 1), contrastive_loss=contrastive_loss)

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / len(data_loader)}")


Processing Epoch 00: 100%|██████████| 2067/2067 [04:30<00:00,  7.63it/s, contrastive_loss=369, loss=543]


Epoch 1/20, Loss: 543.2560056881105


Processing Epoch 01: 100%|██████████| 2067/2067 [04:30<00:00,  7.64it/s, contrastive_loss=346, loss=544]


Epoch 2/20, Loss: 543.8089731158296


Processing Epoch 02: 100%|██████████| 2067/2067 [04:18<00:00,  8.01it/s, contrastive_loss=369, loss=544]


Epoch 3/20, Loss: 543.8050731406099


Processing Epoch 03: 100%|██████████| 2067/2067 [04:14<00:00,  8.11it/s, contrastive_loss=369, loss=544]


Epoch 4/20, Loss: 543.5762296767989


Processing Epoch 04: 100%|██████████| 2067/2067 [04:14<00:00,  8.12it/s, contrastive_loss=369, loss=544]


Epoch 5/20, Loss: 543.6875189129257


Processing Epoch 05: 100%|██████████| 2067/2067 [04:14<00:00,  8.11it/s, contrastive_loss=369, loss=544]


Epoch 6/20, Loss: 543.7317823346364


Processing Epoch 06: 100%|██████████| 2067/2067 [04:15<00:00,  8.11it/s, contrastive_loss=369, loss=543]


Epoch 7/20, Loss: 543.3108848715729


Processing Epoch 07: 100%|██████████| 2067/2067 [04:14<00:00,  8.12it/s, contrastive_loss=369, loss=544]


Epoch 8/20, Loss: 543.7758012761456


Processing Epoch 08: 100%|██████████| 2067/2067 [04:14<00:00,  8.11it/s, contrastive_loss=369, loss=543]


Epoch 9/20, Loss: 543.4434387472787


Processing Epoch 09: 100%|██████████| 2067/2067 [04:15<00:00,  8.10it/s, contrastive_loss=369, loss=544]


Epoch 10/20, Loss: 544.0651899170749


Processing Epoch 10: 100%|██████████| 2067/2067 [04:14<00:00,  8.11it/s, contrastive_loss=369, loss=544]


Epoch 11/20, Loss: 544.3112778352203


Processing Epoch 11: 100%|██████████| 2067/2067 [04:14<00:00,  8.11it/s, contrastive_loss=337, loss=543]


Epoch 12/20, Loss: 543.3707131439785


Processing Epoch 12: 100%|██████████| 2067/2067 [04:14<00:00,  8.11it/s, contrastive_loss=355, loss=543]


Epoch 13/20, Loss: 543.4667995916108


Processing Epoch 13: 100%|██████████| 2067/2067 [04:15<00:00,  8.10it/s, contrastive_loss=369, loss=543]


Epoch 14/20, Loss: 543.3930338807422


Processing Epoch 14: 100%|██████████| 2067/2067 [04:23<00:00,  7.83it/s, contrastive_loss=369, loss=543]


Epoch 15/20, Loss: 543.1307846386766


Processing Epoch 15:   0%|          | 0/2067 [00:00<?, ?it/s]


KeyboardInterrupt: 

In [None]:
#model_out = model.wav2vec2(test_dataset[0][0].to(device).unsqueeze(0)).last_hidden_state

In [7]:
# Persist the model weights. The target directory is created first so the
# save does not fail on a fresh checkout where 'weights/' does not exist yet.
os.makedirs('weights', exist_ok=True)
torch.save(model.state_dict(), 'weights/pre_train-19.pt')

# Reference: example pre-training script (Hugging Face Transformers)
https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py