In [2]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from config import Wav2Vec2Config
from model import Wav2Vec2ForPreTraining,Wav2Vec2FeatureEncoder,Wav2Vec2GumbelVectorQuantizer,_compute_mask_indices,Wav2Vec2Encoder,Wav2Vec2FeatureProjection

from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices

from tqdm import tqdm

# Prefer the GPU when a CUDA runtime is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

class AudioDataset(Dataset):
    """Map-style dataset that serves audio files as mono waveforms at one sample rate.

    Every item is loaded with ``torchaudio.load``, averaged down to a single
    channel, resampled to ``target_sample_rate`` and returned with the channel
    dimension removed.

    Args:
        file_list: Indexable collection of audio file paths.
        target_sample_rate: Sample rate (Hz) every clip is converted to.
    """

    def __init__(self, file_list, target_sample_rate=16000):
        self.file_list = file_list
        self.target_sample_rate = target_sample_rate
        # (orig_freq, new_freq) -> Resample transform. Filled lazily so the
        # resampling kernel is computed once instead of once per loaded clip.
        self._resamplers = {}

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load, downmix and resample the clip at position ``idx``.

        Returns:
            The waveform tensor only; the sample rate is always
            ``self.target_sample_rate``.
        """
        file_path = self.file_list[idx]
        audio, _ = self.resample_audio_torchaudio(file_path)
        return audio

    def _get_resampler(self, original_sample_rate):
        """Return a cached ``Resample`` transform for the given source rate."""
        # ``setdefault`` on ``__dict__`` keeps this working for instances that
        # were created without running ``__init__`` (e.g. via ``__new__``).
        cache = self.__dict__.setdefault("_resamplers", {})
        key = (original_sample_rate, self.target_sample_rate)
        resampler = cache.get(key)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(
                orig_freq=original_sample_rate,
                new_freq=self.target_sample_rate,
            )
            cache[key] = resampler
        return resampler

    def resample_audio_torchaudio(self, file_path, original_sample_rate=44100):
        """Load one file and convert it to a mono waveform at the target rate.

        Args:
            file_path: Path of the audio file to load.
            original_sample_rate: Sample rate (Hz) the file is required to have.

        Returns:
            Tuple ``(waveform, self.target_sample_rate)`` where ``waveform`` has
            its leading channel dimension removed.

        Raises:
            ValueError: If the file's sample rate differs from
                ``original_sample_rate``.
        """
        waveform, sample_rate = torchaudio.load(file_path)
        if sample_rate != original_sample_rate:
            raise ValueError(f"Expected sample rate to be {original_sample_rate}, but got {sample_rate}")

        # Downmix before resampling: resampling is a linear filter, so the
        # result matches resample-then-average up to float rounding, while only
        # one channel has to be filtered.
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)

        waveform = self._get_resampler(original_sample_rate)(waveform)

        # Drop only the channel dimension. A bare ``squeeze()`` would also
        # collapse a one-sample clip to a 0-d tensor, which breaks ``len()``.
        return waveform.squeeze(0), self.target_sample_rate

    def __repr__(self) -> str:
        """Summarise the dataset size and the duration of its first clip.

        Note that a non-empty dataset loads its first file to measure the
        duration, so this performs file I/O.
        """
        if len(self) == 0:
            return "AudioDataset with 0 samples"
        else:
            audio_length = len(self[0])
            duration = audio_length / self.target_sample_rate
            return f"AudioDataset with {len(self)} samples of {duration:.2f} seconds each"



# Training clips are numbered 1 through 500 (inclusive) inside the Gould folder.
file_list = [
    f'data/mp3_train_files/Gould/Gould - WTC_clip_{clip_number}.mp3'
    for clip_number in range(1, 500 + 1)
]
dataset = AudioDataset(file_list=file_list)

In [3]:
# Clips per optimisation step.
batch_size = 4
# Shuffled loader over the training clips; loading is spread over 32 worker processes.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=32,
)

In [4]:
# Build the pre-training model from the project's default configuration and
# move it to the selected device before the optimizer captures its parameters.
config = Wav2Vec2Config()
model = Wav2Vec2ForPreTraining(config).to(device)

# Adam over every model parameter with a fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [5]:
# Display the value returned by the model's `params()` helper; the recorded
# output of this cell is 95044608.
# NOTE(review): `params()` comes from the project's `model` module, which is not
# shown here - presumably it returns the total parameter count; confirm.
model.params()

95044608

In [9]:
# Self-supervised pre-training loop: for every batch, sample the masked time
# steps and the negative (distractor) indices, run the model, and take one
# Adam step on the returned loss.
epochs = 10

output_attentions = True
output_hidden_states = False
# Must be a plain boolean. `torch.BoolTensor(1)` allocates an *uninitialised*
# one-element tensor, so its truth value was arbitrary; whenever it read as
# False the model would return a tuple and `out.loss` below would fail.
return_dict = True

for epoch in range(epochs):
    torch.cuda.empty_cache()  # Clear CUDA cache
    model.train()
    batch_iterator = tqdm(data_loader, desc=f"Processing Epoch {epoch:02d}")
    total_loss = 0.0

    for batch_idx, batch in enumerate(batch_iterator):
        input_values = batch.to(device)
        num_clips = input_values.size(0)

        # Run the feature encoder without gradients purely to learn how many
        # frames it produces for inputs of this length.
        with torch.no_grad():
            extract_features = model.wav2vec2.feature_extractor(input_values).transpose(1, 2)
            seq_len = extract_features.size(1)  # Sequence length after feature extraction

        # Frame-resolution mask: tells the span sampler that every frame is valid.
        attention_mask = torch.ones((num_clips, seq_len), dtype=torch.long, device=device)

        # Sample-resolution mask for the model itself. The Hugging Face
        # Wav2Vec2 forward pass expects a mask over raw input samples and
        # down-samples it internally; handing it the frame-resolution mask
        # makes it treat almost every frame as padding.
        # NOTE(review): assumes the local `model` module mirrors the Hugging
        # Face implementation - confirm.
        input_attention_mask = torch.ones_like(input_values, dtype=torch.long)

        mask_time_indices = _compute_mask_indices(
            shape=(num_clips, seq_len),
            mask_prob=config.mask_time_prob,
            mask_length=config.mask_time_length,
            attention_mask=attention_mask,
            min_masks=config.mask_time_min_masks,
        )

        sampled_negative_indices = _sample_negative_indices(
            features_shape=(num_clips, seq_len),
            num_negatives=model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )

        # Both helpers return NumPy arrays; convert them for the model.
        mask_time_indices = torch.tensor(mask_time_indices, device=device, dtype=torch.bool)
        sampled_negative_indices = torch.tensor(sampled_negative_indices, device=device)

        out = model(
            input_values=input_values,
            attention_mask=input_attention_mask,
            mask_time_indices=mask_time_indices,
            sampled_negative_indices=sampled_negative_indices,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        loss = out.loss
        contrastive_loss = out.contrastive_loss.item()

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        total_loss += loss.item()

        # Running mean of the total loss, plus the latest contrastive term.
        batch_iterator.set_postfix(loss=total_loss / (batch_idx + 1), contrastive_loss=contrastive_loss)

    # Guard the division so an empty loader reports 0.0 instead of raising.
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / max(len(data_loader), 1)}")


Processing Epoch 00:  35%|███▌      | 44/125 [00:08<00:15,  5.16it/s, contrastive_loss=369, loss=363]


KeyboardInterrupt: 

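# Example pretraining
Reference implementation (Hugging Face Transformers, Wav2Vec2 pretraining without the Trainer API):
https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-pretraining/run_wav2vec2_pretraining_no_trainer.py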