In [1]:
import pprint
pp = pprint.PrettyPrinter(indent=4)
from pytorch_lightning import LightningDataModule
from torchaudio.datasets.librispeech import LIBRISPEECH

from examples.self_supervised_learning.data_modules._utils import BucketizeBatchSampler, CommonTransform
from examples.self_supervised_learning.data_modules._wav2vec2_datamodule import *
from torchaudio.prototype.models._conformer_wav2vec2 import * 
import torch.nn.functional as F
import pickle
from pathlib import Path
from tqdm import tqdm


# Build the LibriSpeech dataset object for the "train-clean-100" subset, looked up under root="..".
librispeech_cls = LIBRISPEECH
dataset = librispeech_cls(root="..", url="train-clean-100", download=False)  # point `root` at the directory holding the corpus; use download=True only if the data still has to be fetched
# NOTE(review): `self` is not defined at notebook top level, so on a fresh kernel this line raises
# NameError. It looks copied from a data-module method. TODO: pass a concrete transform object here.
# NOTE(review): TransformDataset is not imported by name; presumably it arrives via one of the
# star imports above — confirm.
dataset = TransformDataset(dataset, self.train_transform)
# Pretty-print the shape of the first element of the first dataset item as a quick sanity check.
pp.pprint(dataset[0][0].shape)


torch.Size([1, 225360])


In [2]:
# The batch sampler needs one length per dataset item (size of dimension 1 of the item's first
# element). Scanning the whole dataset takes a couple of minutes, so the list is cached on disk
# and reloaded when the cache file already exists.
len_list_path = Path('len_list.obj')
len_list = []
if not len_list_path.is_file():
    for d in tqdm(dataset):
        len_list.append(d[0].size(1))
    with len_list_path.open('wb') as fp:
        pickle.dump(len_list, fp)
else:
    with len_list_path.open('rb') as fp:
        len_list = pickle.load(fp)

# Keyword settings for the length-bucketing batch sampler.
sampler_options = dict(
    num_buckets=10000,
    max_token_count=30 * 16000,
    min_len=32000,
    max_len=250000,
    shuffle=True,
)
sampler = BucketizeBatchSampler(len_list, **sampler_options)

# The sampler yields whole batches of indices, hence `batch_sampler` rather than `batch_size`.
dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)

# Peek at a single batch (if there is one) and print X[0][0].
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(X[0][0])

tensor([[0.0016, 0.0019, 0.0025,  ..., 0.0007, 0.0010, 0.0001],
        [0.0103, 0.0117, 0.0148,  ..., 0.0006, 0.0007, 0.0002],
        [0.0002, 0.0013, 0.0042,  ..., 0.0003, 0.0007, 0.0003],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]])


In [3]:
# Model init: build the network by calling the factory with no arguments.
# NOTE(review): conformer_wav2vec2_pretrain_base is not imported by name; presumably it comes from
# the star import of torchaudio.prototype.models._conformer_wav2vec2 — confirm.
model = conformer_wav2vec2_pretrain_base()

In [4]:
# Shapes of the model output: run one batch through the model and print the shape of the
# input features and of each of the six returned tensors.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    X, y = first_batch
    x, lengths, mask_indices, targets, negatives, neg_idxs = model(*X)
    labelled_tensors = [
        ('X =', X[0]),
        ('x =', x),
        ('lengths =', lengths),
        ('mask_indices =', mask_indices),
        ('targets =', targets),
        ('negatives =', negatives),
        ('neg_idxs =', neg_idxs),
    ]
    for label, tensor in labelled_tensors:
        print(label, tensor.shape)

X = torch.Size([14, 205, 64])
x = torch.Size([14, 51, 256])
lengths = torch.Size([14])
mask_indices = torch.Size([14, 51])
targets = torch.Size([14, 9, 256])
negatives = torch.Size([100, 14, 9, 256])
neg_idxs = torch.Size([14, 900])


In [8]:
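# FIXME: THIS PART DOESN'T WORK YET. Judging by the output below, the accumulated loss fluctuates from epoch to epoch instead of decreasing steadily.
# Source: https://www.internalfb.com/code/fbsource/[64489570a965cb67396c309615f727f62a9462cd]/fbcode/deeplearning/projects/pyspeech/pyspeech/criterions/wav2vec2_loss.py?lines=24%2C39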

def compute_contrastive_loss(x, mask_indices, targets, neg_is_pos, reduce, logit_temp = 0.1):
    """Contrastive (cross-entropy over cosine similarities) loss for masked positions.

    Args:
        x: Model output; it is indexed with ``mask_indices`` and reshaped to
            ``(x.size(0), -1, x.size(-1))``, so every row must have the same number of
            masked positions.
        mask_indices: Boolean mask used to select the entries of ``x``.
        targets: Candidates stacked along dim 0, with the positive at index 0 and the
            negatives after it. The selected ``x`` is broadcast to this shape.
        neg_is_pos: Boolean mask over ``logits[1:]`` marking negatives that coincide with
            the positive; their logits are set to ``-inf``.
        reduce: If truthy the loss is summed, otherwise it is returned per row.
        logit_temp: Temperature the similarities are divided by.

    Returns:
        Tuple ``(loss, sample_size, logits)``: the cross-entropy against class 0, the
        number of scored rows, and the logits flattened to ``(rows, candidates)``.
    """
    batch_size, feature_dim = x.size(0), x.size(-1)

    # Keep only the masked positions, then repeat them once per candidate.
    masked = x[mask_indices].view(batch_size, -1, feature_dim)
    masked = masked.unsqueeze(0).expand_as(targets)

    logits = torch.cosine_similarity(masked.float(), targets.float(), dim=-1).float()
    logits.div_(logit_temp)

    # A negative identical to the positive must not compete with it.
    if neg_is_pos.any():
        logits[1:][neg_is_pos] = float("-inf")

    num_candidates = logits.size(0)
    num_rows = logits.size(1) * logits.size(2)

    # The positive always sits at candidate index 0, so every row's class label is 0.
    target = torch.zeros(num_rows, dtype=torch.long, device=logits.device)

    # Move the candidate axis last and flatten everything else into rows.
    logits = logits.transpose(0, 2).reshape(-1, num_candidates)

    reduction = "sum" if reduce else "none"
    loss = F.cross_entropy(logits, target, reduction=reduction)
    return loss, target.numel(), logits

def wav2vec_loss(net_output, reduce=True):
    """Compute the contrastive loss from the six-tuple returned by the model.

    Args:
        net_output: ``(x, lengths, mask_indices, y, negatives, neg_idxs)``. Only ``x``,
            ``mask_indices``, ``y`` and ``negatives`` are used here.
        reduce: Forwarded to ``compute_contrastive_loss`` (truthy means sum the loss).

    Returns:
        Tuple ``(loss, sample_size)`` with ``loss`` cast to float.

    Note:
        Reads the notebook-level ``model`` to decide whether ``y`` must require grad.
    """
    x, lengths, mask_indices, y, negatives, neg_idxs = net_output

    # Sanity checks on the positives and the mask.
    assert y is not None
    assert mask_indices is not None
    if model.training:
        assert y.requires_grad
    # The number of masked positions must match the first two dims of y.
    assert mask_indices.sum() == y.shape[0] * y.shape[1]

    # Flag negatives that are identical to their positive (compared over the last dim).
    neg_is_pos = (y == negatives).all(-1)

    # Stack the candidates along dim 0: positive first, negatives after it.
    targets = torch.cat([y.unsqueeze(0), negatives], dim=0)

    loss, sample_size, _ = compute_contrastive_loss(x, mask_indices, targets, neg_is_pos, reduce)
    return loss.float(), sample_size


# Short smoke-test training run: a few batches per epoch with plain SGD.
epoches = 10
max_utterances_per_epoch = 100  # stop an epoch once more than this many utterances were seen
opt = torch.optim.SGD(model.parameters(), lr= 0.01)

# Make the training mode explicit; wav2vec_loss checks model.training.
model.train()

for epoch in range(epoches):
    acc_loss = 0.0  # plain Python float, see the accumulation note below
    ds_size = 0     # utterances processed so far in this epoch
    for X, _ in tqdm(dataloader):
        ds_size += X[0].size(0)
        opt.zero_grad()
        y_h = model(*X)
        loss = wav2vec_loss(y_h)
        batch_loss, sample_size = loss
        print(sample_size)
        batch_loss.backward()
        opt.step()
        # Accumulate the detached scalar value: adding the tensor itself would keep
        # autograd history alive across iterations.
        acc_loss += batch_loss.item()
        if ds_size > max_utterances_per_epoch:
            break
    print(f'ep = {epoch}, acc_loss={acc_loss}')
       

0it [00:00, ?it/s]

154


1it [2:45:24, 9924.92s/it]

154


2it [2:45:28, 4088.55s/it]

168


3it [2:45:31, 2223.14s/it]

169


4it [2:45:35, 1347.31s/it]

182


5it [2:45:38, 862.47s/it] 

169


6it [2:45:41, 570.07s/it]

143


7it [2:45:43, 384.53s/it]

169


7it [2:45:46, 1420.92s/it]


ep = 0, acc_loss=6093.26220703125


0it [00:00, ?it/s]

168


1it [00:02,  2.67s/it]

154


2it [00:05,  3.02s/it]

182


3it [00:08,  3.03s/it]

143


4it [00:11,  2.90s/it]

143


5it [00:14,  2.71s/it]

169


6it [00:16,  2.75s/it]

156


7it [00:19,  2.84s/it]

169


7it [00:22,  3.18s/it]


ep = 1, acc_loss=5770.744140625


0it [00:00, ?it/s]

140


1it [00:02,  2.28s/it]

182


2it [00:05,  2.58s/it]

154


3it [00:08,  3.06s/it]

143


4it [00:10,  2.73s/it]

143


5it [00:13,  2.54s/it]

143


6it [00:15,  2.45s/it]

156


7it [00:17,  2.46s/it]

169


7it [00:20,  2.90s/it]


ep = 2, acc_loss=5463.7158203125


0it [00:00, ?it/s]

140


1it [00:02,  2.44s/it]

168


2it [00:04,  2.34s/it]

168


3it [00:07,  2.57s/it]

156


4it [00:10,  2.78s/it]

169


5it [00:13,  2.70s/it]

143


6it [00:15,  2.57s/it]

169


7it [00:17,  2.45s/it]

156


7it [00:20,  2.89s/it]


ep = 3, acc_loss=5690.49365234375


0it [00:00, ?it/s]

154


1it [00:03,  3.70s/it]

168


2it [00:06,  3.12s/it]

168


3it [00:08,  2.86s/it]

169


4it [00:11,  2.65s/it]

156


5it [00:13,  2.41s/it]

169


6it [00:16,  2.81s/it]

182


7it [00:19,  2.82s/it]

143


7it [00:21,  3.10s/it]


ep = 4, acc_loss=5862.97119140625


0it [00:00, ?it/s]

154


1it [00:02,  2.45s/it]

182


2it [00:04,  2.27s/it]

168


3it [00:08,  2.81s/it]

169


4it [00:11,  2.92s/it]

169


5it [00:13,  2.87s/it]

156


6it [00:16,  2.74s/it]

156


7it [00:18,  2.55s/it]

169


7it [00:20,  3.00s/it]


ep = 5, acc_loss=5985.4873046875


0it [00:00, ?it/s]

154


1it [00:03,  3.50s/it]

168


2it [00:06,  2.94s/it]

154


3it [00:08,  2.60s/it]

169


4it [00:11,  2.88s/it]

169


5it [00:14,  2.93s/it]

156


6it [00:17,  3.04s/it]

156


7it [00:22,  3.58s/it]

169


7it [00:27,  3.95s/it]


ep = 6, acc_loss=5737.80908203125


0it [00:00, ?it/s]

154


1it [00:04,  4.39s/it]

154


2it [00:08,  4.03s/it]

154


3it [00:10,  3.24s/it]

130


4it [00:13,  3.07s/it]

169


5it [00:16,  3.01s/it]

169


6it [00:19,  3.08s/it]

182


7it [00:22,  2.95s/it]

169


7it [00:24,  3.48s/it]


ep = 7, acc_loss=5715.21337890625


0it [00:00, ?it/s]

154


1it [00:02,  2.62s/it]

154


2it [00:05,  2.70s/it]

168


3it [00:11,  4.22s/it]

169


4it [00:18,  5.47s/it]

143


5it [00:21,  4.39s/it]

143


6it [00:24,  4.04s/it]

169


7it [00:27,  3.56s/it]

156


7it [00:30,  4.29s/it]


ep = 8, acc_loss=5575.4873046875


0it [00:00, ?it/s]

140


1it [00:03,  3.18s/it]

154


2it [00:06,  2.99s/it]

168


3it [00:08,  2.80s/it]

156


4it [00:11,  2.70s/it]

143


5it [00:13,  2.64s/it]

169


6it [00:16,  2.86s/it]

169


7it [00:19,  2.81s/it]

117


7it [00:22,  3.23s/it]

ep = 9, acc_loss=5420.0048828125



