In [1]:
import torch
from transformers import AutoFeatureExtractor, Wav2Vec2ForPreTraining,get_scheduler
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices, _sample_negative_indices
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.optim import AdamW

from model import Wav2Vec2ForPreTraining,Wav2Vec2Config
from dataset import AudioDatasetSplits




device = 'cuda'

In [2]:
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-base")
config = Wav2Vec2Config()
model = Wav2Vec2ForPreTraining(config)
ds_train = AudioDatasetSplits().train
ds_val = AudioDatasetSplits().val



In [3]:
ips = feature_extractor(ds_train[0]['audio'],return_tensors='pt',sampling_rate=ds_train[0]['sample_rate']).input_values
batch_size, raw_sequence_length = ips.shape
sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()


In [4]:
mask_time_indices = _compute_mask_indices(
    shape=(batch_size, sequence_length), mask_prob=model.config.mask_time_prob, mask_length=model.config.mask_time_length
)

sampled_negative_indices = _sample_negative_indices(
    features_shape=(batch_size, sequence_length),
    num_negatives=model.config.num_negatives,
    mask_time_indices=mask_time_indices,
)

mask_time_indices = torch.tensor(data=mask_time_indices, device=ips.device, dtype=torch.long)
sampled_negative_indices = torch.tensor(
    data=sampled_negative_indices, device=ips.device, dtype=torch.long
)


In [7]:
model = model.train()
loss = model(
    ips, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices
)

In [8]:
loss

Wav2Vec2ForPreTrainingOutput(loss=tensor(nan, grad_fn=<AddBackward0>), projected_states=tensor([[[-0.2438,  0.5366, -0.6080,  ..., -0.7227, -0.4334,  0.8144],
         [-0.5922, -0.4866, -0.0443,  ..., -0.0500, -0.0972,  0.4619],
         [ 0.0490,  0.1360,  0.2218,  ..., -0.3461,  0.3827,  0.0350],
         ...,
         [-0.7289, -0.2255,  0.3768,  ..., -0.6308, -0.2128,  0.6229],
         [ 0.5319,  0.2309, -0.1708,  ..., -0.1450,  0.0904, -0.5288],
         [ 0.4541, -0.2181,  0.1276,  ...,  0.1025,  0.1758,  0.0971]]],
       grad_fn=<ViewBackward0>), projected_quantized_states=tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], grad_fn=<ViewBackward0>), codevector_perplexity=tensor(613.7838, grad_fn=<SumBackward0>), hidden_states=None, at

here we go

In [10]:
dl_train = DataLoader(ds_train,batch_size=8) # can put feature extractor in ds
dl_val = DataLoader(ds_val,batch_size=8)

In [11]:
# optimizer = AdamW(
#         list(model.parameters()),
#         lr=5e-4,
#         betas=[args.adam_beta1, args.adam_beta2],
#         eps=args.adam_epsilon,
#     )


# lr_scheduler = get_scheduler(
#         name=args.lr_scheduler_type,
#         optimizer=optimizer,
#         num_warmup_steps=args.num_warmup_steps,
#         num_training_steps=args.max_train_steps,
#     )

In [12]:
start_epochs = 0
total_epochs = 1

for epoch in range(start_epochs,total_epochs):
    model.train()
    batch_iterator = tqdm(dl_train,desc=f"Processing Epoch {epoch:02d}")
    for batch_idx,batch in enumerate(batch_iterator):

        #batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}


        
        ips = feature_extractor(batch['audio'].numpy(),sampling_rate=batch['sample_rate'][0],return_tensors='pt').input_values

        batch_size,raw_sequence_length = ips.shape

        sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item()


        mask_time_indices = _compute_mask_indices(
            shape=(batch_size, sequence_length), mask_prob=model.config.mask_time_prob, mask_length=model.config.mask_time_length
        )

        sampled_negative_indices = _sample_negative_indices(
            features_shape=(batch_size, sequence_length),
            num_negatives=model.config.num_negatives,
            mask_time_indices=mask_time_indices,
        )

        mask_time_indices = torch.tensor(data=mask_time_indices, device=ips.device, dtype=torch.long)
        sampled_negative_indices = torch.tensor(
            data=sampled_negative_indices, device=ips.device, dtype=torch.long
        )

        print(ips)
        print(mask_time_indices)
        print(sampled_negative_indices)

        loss = model(ips,mask_time_indices=mask_time_indices,sampled_negative_indices=sampled_negative_indices).loss

        print('loss',loss)

        break



                
            

















        batch_iterator.set_postfix(loss=running_loss /( 1 if batch_idx == 0 else batch_idx))

Processing Epoch 00:   0%|          | 0/1550 [00:00<?, ?it/s]

tensor([[ 2.2682e-04,  2.3256e-04,  2.2961e-04,  ...,  1.5293e+00,
          1.4560e+00,  1.5053e+00],
        [ 1.5130e-01,  2.7549e-01,  2.4681e-01,  ...,  3.8436e-01,
          5.6157e-01,  7.9311e-01],
        [-3.5496e-04, -3.5496e-04, -3.5496e-04,  ..., -5.8199e-01,
         -6.8625e-01, -7.9274e-01],
        ...,
        [ 9.8245e-04,  9.8096e-04,  9.8080e-04,  ..., -7.8956e-01,
         -1.1840e+00, -1.6731e+00],
        [-5.1873e-04, -5.1843e-04, -5.1778e-04,  ...,  1.8853e+00,
          1.5824e+00,  1.3594e+00],
        [ 2.7939e-04,  2.7939e-04,  2.7939e-04,  ...,  1.0347e-01,
          1.0546e-01,  1.1822e-01]])
tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 1, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])
tensor([[[   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0,    0],
         [   0,    0,    0,  ...,    0,    0

Processing Epoch 00:   0%|          | 0/1550 [00:00<?, ?it/s]

loss tensor(764.9269, grad_fn=<AddBackward0>)



