In [1]:
import os
%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import Dataset
import torch


class MockData(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return {'input_features': torch.randn(80, 3000), 'speaker_inventory': torch.randn(8, 384),
                'target_asr_ids': torch.randint(0, 60, size=(256,)),
                'target_diar_ids': torch.randint(0, 8, size=(256,))}


data = MockData()
loader = torch.utils.data.DataLoader(data, batch_size=4)

In [3]:
batch = next(iter(loader))

In [2]:
from datasets import load_dataset
from mango.utils.tokenization import retain_cyrillic

cv13 = load_dataset('mozilla-foundation/common_voice_13_0', 'uk', trust_remote_code=True)
urban = load_dataset('danavery/urbansound8K')

urban = urban['train'].rename_column('class', 'label')
cv13 = cv13['train'].rename_columns({'sentence': 'transcription', 'client_id': 'speaker_id'})
cv13 = retain_cyrillic(cv13, 'transcription')

cv13.set_format('pt')
urban.set_format('pt')

In [3]:
from mango.training.SpeakerAttributedMixer import SpeakerAttributeExample, SpeakerAttributedMixer, DatasetMixerConfig

data = SpeakerAttributedMixer(
    DatasetMixerConfig(max_speakers=3, utterances_count=250, beta=5, min_repetitions=3, max_repetitions=6), cv13,
    urban)
example = data.generate()

In [4]:
%load_ext autoreload 
%autoreload 2

In [5]:
from transformers import PreTrainedTokenizerFast

In [6]:
tokenizer = PreTrainedTokenizerFast.from_pretrained('anakib1/sa-asr-0.1')

In [7]:
from mango.training.collators import SpeakerAttributionCollator

In [9]:
from transformers import WhisperFeatureExtractor
import torch

feature_extractor = WhisperFeatureExtractor.from_pretrained('openai/whisper-small')
inventory = torch.randn(30, 384)
collator = SpeakerAttributionCollator(tokenizer, feature_extractor, inventory)

In [10]:
import torch
from mango.training.DatasetMixer import DatasetMixerWrapper

wrapped = DatasetMixerWrapper(data)
loader = torch.utils.data.DataLoader(wrapped, batch_size=4, collate_fn=collator)

In [15]:
tokenizer.pad_token_id

0

In [13]:
from mango.models.sa_asr import SAASR, SAASRConfig

model = SAASR(SAASRConfig(vocab_size=tokenizer.vocab_size))
model(**batch)

{'asr_outputs': tensor([[[ 477.6021,  536.6523,   48.9734,  ...,  392.5928, -523.9099,
             85.6945],
          [ 477.8172,  536.3663,   49.5025,  ...,  392.6454, -523.8537,
             85.5838],
          [ 477.8195,  536.3677,   49.4888,  ...,  392.6532, -523.8458,
             85.5778],
          ...,
          [ 477.8139,  536.3461,   49.4363,  ...,  392.7038, -523.8139,
             85.5769],
          [ 477.8138,  536.3460,   49.4363,  ...,  392.7037, -523.8139,
             85.5768],
          [ 477.8138,  536.3461,   49.4362,  ...,  392.7039, -523.8139,
             85.5769]],
 
         [[ 241.8062,  320.1090,  -11.6099,  ...,  321.1734, -461.2775,
            -35.9849],
          [ 242.0220,  319.7439,  -11.2700,  ...,  321.2599, -461.4782,
            -36.4647],
          [ 242.0175,  319.7428,  -11.2770,  ...,  321.2614, -461.4699,
            -36.4609],
          ...,
          [ 241.9625,  319.7487,  -11.2995,  ...,  321.2768, -461.4478,
            -36.4751],
  

In [None]:
from mango.training.MangoTrainer import TrainerConfig, MangoTrainer

config = TrainerConfig(model_name='sa-asr-0.1', logs_frequency_batches=16, save_strategy='epoch',
                       mixed_precision='fp16', early_stopping_patience=6, gradient_accumulation_steps=2)

optim = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CyclicLR(optim, base_lr=1e-4, max_lr=1e-3, mode='exp_range', gamma=0.99,
                                              cycle_momentum=False, step_size_up=int(len(loader) * 1.75),
                                              step_size_down=int(len(loader) * 1.25))

trainer = MangoTrainer(model=model, train_loader=loader, eval_loader=loader, config=config, optmizer=optim,
                       scheduler=scheduler)

trainer.train()

In [25]:
a = torch.randint(0, 100, (4, 128))
b = torch.randint(0, 100, (4, 128))
((a==b).masked_fill(b==0, False).sum(dim=1)/((b!=0).sum(dim=1))).mean()

tensor(0.0099)

In [19]:
import evaluate
acc = evaluate.load('accuracy')
acc.compute(references=a, predictions=b)

ValueError: Predictions and/or references don't match the expected format.
Expected format: {'predictions': Value(dtype='int32', id=None), 'references': Value(dtype='int32', id=None)},
Input predictions: tensor([[ 7, 82, 87, 41, 55,  3, 94, 45, 46, 54, 77, 27, 55, 33, 71, 93, 78, 85,
         49, 57, 15, 17,  5, 73, 72, 71, 47,  6, 19, 97,  2, 87, 33, 72, 35, 71,
         89, 43, 51, 28,  7, 80, 85, 70, 73, 20, 57, 60, 92,  3, 73, 81, 25, 92,
         94, 40,  4, 14, 45, 96, 37, 47, 14, 37, 66,  5, 26, 28, 95, 30, 34, 13,
         75, 71, 28, 22, 93, 15, 41, 10, 86, 57, 27, 45, 86, 27, 47, 66, 49, 54,
         61, 63, 96,  5, 65, 62, 10, 46, 92,  4, 45, 66, 20,  8,  4, 65,  9, 62,
         23,  2, 83, 46, 69, 16, 23,  6, 27, 98,  1, 86, 74, 53, 90, 91, 89, 13,
         72, 78],
        [97, 76, 21, 82, 15, 27, 32, 51, 32, 96, 64, 61, 56, 38,  7, 70, 33, 17,
         10, 19, 65, 63, 91, 48, 83, 13, 12, 65, 40, 72, 93, 22, 24, 71, 74, 36,
         80, 77, 47, 89, 22, 75, 44, 40, 74, 70, 15, 60, 90, 48, 68, 28, 42,  9,
         62, 99, 19, 88, 82, 71, 43, 21, 54, 83, 13, 62, 21, 46, 76, 10, 41, 59,
         76, 26, 83, 97, 76,  7,  2, 78, 85, 41, 50, 24, 42, 44, 28, 39, 31, 25,
         56, 33, 25, 62, 98, 27, 21,  6, 67, 54, 83, 37, 74, 22, 92, 87, 41,  1,
         44, 40, 64, 12,  8, 39, 43,  1, 34, 31, 49,  1, 56, 19, 30, 27, 82, 36,
         14, 90],
        [24, 60, 46,  1, 65, 11, 30,  6, 91, 40, 95, 10, 69, 56, 18, 69,  9, 33,
         45, 11, 15, 44, 56, 34, 36, 53, 79, 86,  3, 31, 32, 71, 19, 28, 18, 15,
         79, 28,  9, 87, 62, 25, 26,  2, 97, 11, 81, 81, 82, 17,  3, 21, 23,  3,
         76, 80, 73, 12, 51, 63, 79, 93, 50, 80, 37, 45, 90, 39, 67, 31, 83, 86,
         58, 91, 56, 36, 26, 44, 30, 68, 63, 54,  8, 50, 23, 52, 47, 48, 27, 10,
         29, 88, 74, 26, 39, 25, 57,  6,  1, 75, 74, 94, 52, 20, 29, 61, 13, 26,
         95, 49, 22, 50, 74, 86, 86, 52,  3, 99, 88, 92,  6, 94, 99, 37, 83, 49,
         12, 44],
        [50,  4, 91, 83, 21, 23,  3, 40, 12,  6, 13, 72, 92, 61, 80, 97, 94, 36,
         23, 67, 36, 38, 85, 11, 51, 28, 85, 19, 41,  9, 61, 52, 96, 26, 60, 86,
         32,  0, 89, 55, 28, 12, 64, 62, 79,  0, 42, 10, 81, 70, 40, 48, 30,  0,
         51, 86, 96, 71,  3, 40, 70, 73, 96, 24, 13, 62, 90, 97, 19, 55, 56,  5,
         56, 27, 23, 93, 26, 31, 98, 58,  2, 37, 25, 32, 62, 62, 21, 92,  0, 93,
          7, 78, 80, 63, 46, 47, 24, 54, 37, 69, 95, 10, 25, 41, 59, 64, 90, 51,
         39, 88, 91,  7, 37, 27, 17, 29, 10, 39, 62, 44, 41, 69, 97, 37, 92,  1,
         18, 83]]),
Input references: tensor([[68, 97, 38, 12, 72, 55, 23,  9, 88,  4,  1, 53, 87, 58, 82, 18, 74, 18,
         57, 23, 13, 44, 34, 55, 10, 67, 20, 43, 69, 61, 35, 64, 98, 42, 17, 71,
         17, 99, 95, 56, 43, 50, 68, 87, 34,  2, 99, 62, 27, 82,  5, 97, 31, 16,
          9, 62, 58, 13, 34, 64, 37, 96, 31, 67, 64,  5, 93, 80, 26, 44, 60, 20,
         67, 97, 59, 32, 51, 27, 95, 79, 44, 95, 92, 97, 27, 50, 51,  1,  7, 30,
         51, 97, 38, 74, 24, 59, 35, 42, 17, 97, 95, 32, 91, 34, 46, 67, 49, 83,
         12, 15, 55, 64, 85,  0, 97, 30, 68, 51, 26, 28, 16, 15, 56, 64,  1, 68,
         22, 24],
        [34, 25, 33, 18, 88, 53, 78, 87, 65, 12, 24, 31,  4,  0, 44, 82, 86, 79,
         56, 28, 75, 67, 66, 48, 17, 12, 52, 66, 55,  4, 72, 16, 39, 44, 34, 59,
         47, 68, 84, 24, 44, 65, 46, 34, 13, 15,  4, 53, 57, 45, 61, 82, 49, 23,
         84, 32,  2, 23, 37, 19, 81, 67, 70, 25, 54, 89, 31,  6, 25,  4, 58, 43,
         55, 78, 26, 72, 76, 30, 95, 90, 16,  1, 97, 26, 52, 38, 84, 29, 58, 51,
         36, 38, 67, 46, 10, 16, 53, 35, 71, 39, 47, 41, 23, 81, 43, 84, 13, 32,
         97, 45, 39, 70, 78, 21, 22, 98,  7, 50, 77, 29, 82, 48, 72, 83,  6, 64,
         38, 36],
        [22, 37, 27,  3,  5, 27, 99, 77, 31, 77, 79, 34, 67, 62, 64, 77, 84, 21,
          3, 58, 39, 71, 61, 66, 43, 81, 51, 10, 52, 38, 29, 47, 26, 16, 40, 40,
         35, 33, 69, 69, 92,  1, 54, 78, 78, 82, 18, 84, 68, 33, 39, 66, 58, 40,
         37,  1, 36, 69, 22, 55, 25, 84, 15,  0, 71, 51, 23, 91, 26, 39, 69, 28,
         27,  6, 54, 94,  2,  5,  8, 51,  9, 68, 48, 89, 61, 62, 99, 33, 55, 34,
         57, 88, 96, 45, 88, 64, 30, 21, 56,  5, 81, 72, 83, 10, 26, 33, 33, 63,
          5, 31, 48, 19, 81, 99, 93, 52, 20, 88, 89, 61, 34, 22, 53, 62, 57, 64,
         78, 96],
        [54, 32, 97, 39,  6, 12, 46, 56, 49, 92, 90,  5, 28, 77, 14,  3, 93, 49,
         75, 47,  5, 44, 96, 19,  4, 79, 64, 36, 46, 55, 88, 39, 70, 85, 44, 99,
         69, 58, 16, 86, 22, 56, 87, 91, 10,  8, 65, 66, 83,  0, 52, 53, 61, 73,
         50, 62, 13, 67, 61, 43, 77,  5, 25, 46, 10, 36,  0, 71, 41, 40, 89, 39,
         19, 11, 18, 14,  4, 86, 72,  9, 14, 31, 92, 93, 93, 52, 68, 65, 20, 64,
         96, 28, 42,  7, 26, 59, 36, 23, 29, 34, 22, 74, 55, 42, 45, 39, 31, 62,
         47, 90, 93, 76, 84, 24, 61, 88, 68, 47, 17, 86, 12, 40, 33, 57, 32,  3,
          5, 14]])