In [3]:
from collections import ChainMap

import yaml
import torch



In [4]:
import fairseq_mod
#import fairseq
#from wav2vec2_inference_pipeline import inference_pipeline

In [5]:
import sys
sys.path.append("../..")

#from wav2vec2_inference_pipeline import inference_pipeline
from data_loader import LibriSpeechDataLoader
from knowledge_distillation.kd_training import KnowledgeDistillationTraining


In [6]:
from fairseq_mod.models.wav2vec.teacher_wav2vec2 import TeacherWav2Vec2Model
from fairseq_mod.models.wav2vec.student_wav2vec2 import StudentWav2Vec2Model

In [7]:
from fairseq_mod.data.dictionary import Dictionary as D

### Load configurations and create letter dictionary

In [8]:
# Read the demo configuration inside a context manager so the file handle is
# closed as soon as parsing finishes (the bare open() call leaked it).
# NOTE(review): yaml.FullLoader is intended for trusted input; prefer
# yaml.safe_load if this config can ever come from an untrusted source.
with open('demo_config.yaml', 'r') as config_file:
    config = yaml.load(config_file, Loader=yaml.FullLoader)
# Letter dictionary that is passed to both the student and the teacher model below.
target_dict = D.load('ltr_dict.txt')

In [9]:
import torch
import torchaudio
from torch.utils.data import DataLoader
import torch.nn.functional as F

def postprocess_features(feats, sample_rate):
    """
    Reduce a waveform tensor to one dimension and layer-normalise it.

    Arguments:
        feats (torch.Tensor): 1-D tensor, or 2-D tensor that is averaged over its last dimension
        sample_rate: not used by this function

    Returns:
        torch.Tensor: 1-D tensor normalised with F.layer_norm over its full length
    """
    signal = feats.mean(-1) if feats.dim() == 2 else feats
    assert signal.dim() == 1, signal.dim()
    # Normalisation is a fixed preprocessing step, so no gradient is tracked for it.
    with torch.no_grad():
        normalised = F.layer_norm(signal, signal.shape)
    return normalised

def get_feature(batch_sample):
    """
    Return the post-processed features of a single dataset sample.

    batch_sample[0][0] is used as the signal and batch_sample[1] as the sample rate.
    """
    signal_container, sample_rate = batch_sample[0], batch_sample[1]
    return postprocess_features(signal_container[0], sample_rate)

def get_padding_mask(batch_sample):
    """
    Return an all-False boolean mask with one entry per position along
    dimension 1 of batch_sample[0].
    """
    num_positions = batch_sample[0].size(1)
    return torch.zeros(num_positions, dtype=torch.bool)

def get_batch_encoder_input(batch_samples):
    """
    Collate dataset samples into padded encoder inputs.

    Arguments:
        batch_samples (list): samples handed over by the data loader

    Returns:
        tuple: (features, padding_masks, mask, features_only). features is padded with 0,
        padding_masks is padded with True, mask is always False and features_only is always True.
    """
    pad_sequence = torch.nn.utils.rnn.pad_sequence
    features = pad_sequence(list(map(get_feature, batch_samples)),
                            batch_first=True, padding_value=0)
    # Padded positions are filled with True, real positions stay False.
    padding_masks = pad_sequence(list(map(get_padding_mask, batch_samples)),
                                 batch_first=True, padding_value=True)
    return features, padding_masks, False, True

class LibriSpeechDataLoader:

    """
    Data loaders for the LibriSpeech dataset.

    Arguments:
        download (bool): passed as `download` to torchaudio when the training datasets are created
        train_batch_size (int): batch size for the training data loader
        val_batch_size (int): batch size for the validation data loader
        num_workers (int): number of workers for training and validation data loaders
        train_data_path (str): Path to training data
        val_data_path (str): Path to validation data
        train_on_dev_clean (bool): Set to True if you want to train on parts of the dev-clean dataset and validate on the other part. This is useful when testing ideas
        use_train_clean_100 (bool): Set to True if using LibriSpeech's train-clean-100 dataset during training
        use_train_clean_360 (bool): Set to True if using LibriSpeech's train-clean-360 dataset during training
        use_train_other_500 (bool): Set to True if using LibriSpeech's train-other-500 dataset during training
    """

    def __init__(self,
                 download,
                 train_batch_size,
                 val_batch_size,
                 num_workers,
                 train_data_path,
                 val_data_path,
                 train_on_dev_clean,
                 use_train_clean_100,
                 use_train_clean_360,
                 use_train_other_500,
                 ):

        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.num_workers = num_workers
        self.download = download

        # Both validation sets are opened with download=False, independent of `download`.
        dev_clean_dataset = torchaudio.datasets.LIBRISPEECH(val_data_path, url='dev-clean', download=False)
        dev_other_dataset = torchaudio.datasets.LIBRISPEECH(val_data_path, url='dev-other', download=False)
        # Built but not exposed: the "dev_other" entry of val_data_loaders is disabled below.
        dev_other_data_loader = DataLoader(dev_other_dataset,
                                           batch_size=val_batch_size,
                                           shuffle=False,
                                           num_workers=num_workers)

        if train_on_dev_clean:
            # Train and validate on disjoint parts of dev-clean.
            train_data_loader, dev_train_data_loader, dev_clean_data_loader = self.create_data_loaders_from_dev_clean(
                dev_clean_dataset,
                train_batch_size,
                val_batch_size,
                num_workers)
        else:
            train_data_loader, dev_train_data_loader = self.create_data_loaders_from_train_dataset(
                train_data_path,
                train_batch_size,
                val_batch_size,
                num_workers,
                use_train_clean_100,
                use_train_clean_360,
                use_train_other_500)
            dev_clean_data_loader = DataLoader(dev_clean_dataset,
                                               batch_size=val_batch_size,
                                               shuffle=False,
                                               num_workers=num_workers)

        self.train_data_loader = train_data_loader
        self.val_data_loaders = {
                                 #"dev_train": dev_train_data_loader,
                                 "dev_clean": dev_clean_data_loader,
                                 #"dev_other": dev_other_data_loader
                                }

    def create_data_loaders_from_dev_clean(self,
                                           dev_clean_dataset,
                                           train_batch_size,
                                           val_batch_size,
                                           num_workers):

        """
        Create train_data_loader, dev_train_data_loader and dev_clean_data_loader from dev_clean_dataset.
        Parts of dev_clean_dataset will be used for training, and the other part will be used for validating.

        Arguments:
            dev_clean_dataset (torchaudio.datasets.LIBRISPEECH): dev-clean data set from LibriSpeech
            train_batch_size (int): batch size for the training data loader
            val_batch_size (int): batch size for the validation data loaders
            num_workers (int): number of workers for the data loaders

        Returns:
            train_data_loader (torch.utils.data.DataLoader): data loader for training created from the training part of the split
            dev_train_data_loader (torch.utils.data.DataLoader): data loader over 500 randomly drawn indices of the training part
            dev_clean_data_loader (torch.utils.data.DataLoader): data loader for validating created from the held-out part of the split
        """

        # Fixed 2203/500 split with a seeded generator so the partition is reproducible.
        # random_split raises if the two sizes do not add up to len(dev_clean_dataset).
        train_dataset, val_dataset = torch.utils.data.random_split(dev_clean_dataset,
                                                                   [2203, 500],
                                                                   generator=torch.Generator().manual_seed(42))

        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=train_batch_size,
                                                        shuffle=False,
                                                        num_workers=num_workers,
                                                        collate_fn=get_batch_encoder_input)
        # Indices are drawn with replacement from the whole training part (len == 2203).
        dev_train_indices = torch.randint(high=len(train_dataset), size=(500,))
        dev_train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                            batch_size=val_batch_size,
                                                            shuffle=False,
                                                            num_workers=num_workers,
                                                            sampler=torch.utils.data.sampler.SubsetRandomSampler(dev_train_indices))
        dev_clean_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                            batch_size=val_batch_size,
                                                            shuffle=False,
                                                            num_workers=num_workers)
        return train_data_loader, dev_train_data_loader, dev_clean_data_loader


    def create_data_loaders_from_train_dataset(self,
                                               train_data_path,
                                               train_batch_size,
                                               val_batch_size,
                                               num_workers,
                                               use_train_clean_100,
                                               use_train_clean_360,
                                               use_train_other_500):
        """
        Create train_data_loader and dev_train_data_loader from training datasets of LibriSpeech.
        Create the joint training dataset based on user's selections.

        Arguments:
            train_data_path (str): path to LibriSpeech training data
            train_batch_size (int): batch size for train_data_loader
            val_batch_size (int): batch size for dev_train_data_loader
            num_workers (int): number of workers for data loaders
            use_train_clean_100 (bool): Set to True if using LibriSpeech's train-clean-100 dataset during training
            use_train_clean_360 (bool): Set to True if using LibriSpeech's train-clean-360 dataset during training
            use_train_other_500 (bool): Set to True if using LibriSpeech's train-other-500 dataset during training

        Returns:
            train_data_loader (torch.utils.data.DataLoader): data loader for training created from LibriSpeech training datasets
            dev_train_data_loader (torch.utils.data.DataLoader): data loader over 2000 randomly drawn indices of the joint training dataset
        """
        subset_flags = (('train-clean-100', use_train_clean_100),
                        ('train-clean-360', use_train_clean_360),
                        ('train-other-500', use_train_other_500))
        selected_datasets = [torchaudio.datasets.LIBRISPEECH(train_data_path, url=url, download=self.download)
                             for url, wanted in subset_flags if wanted]
        # ConcatDataset rejects an empty list, so at least one subset has to be enabled.
        train_dataset = torch.utils.data.ConcatDataset(selected_datasets)
        train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                        batch_size=train_batch_size,
                                                        shuffle=True,
                                                        num_workers=num_workers,
                                                        collate_fn=get_batch_encoder_input)
        # Draw sample indices over the whole dataset. len(train_data_loader) is the number of
        # batches, so using it as the upper bound restricted the subset to the first
        # len(dataset) / batch_size samples whenever the batch size was larger than 1.
        dev_train_indices = torch.randint(high=len(train_dataset), size=(2000,))
        dev_train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                            batch_size=val_batch_size,
                                                            shuffle=False,
                                                            num_workers=num_workers,
                                                            sampler=torch.utils.data.sampler.SubsetRandomSampler(dev_train_indices))
        return train_data_loader, dev_train_data_loader

    def get_train_data_loader(self):
        """Return the training data loader built in __init__."""
        return self.train_data_loader

    def get_val_data_loaders(self):
        """Return the dict of validation data loaders built in __init__ (key "dev_clean")."""
        return self.val_data_loaders


In [12]:
config["knowledge_distillation"]["student_model"]

{'num_trans_layer_student_model': 4,
 'num_trans_layer_student_init_model': 12,
 'student_init_model_path': '/common/home/vk405/Projects/Data/Data_wav2vec/wav2vec_big_960h.pt',
 'student_init_model_type': 'fairseq_pretrained',
 'student_trans_layer_init_method': 'every_k',
 'change_conv_layers': True,
 'conv_groups': 2}

### Create data loaders for training and validation

In [8]:
# The "data_loader" section of the config is expanded into keyword arguments,
# so its keys have to match the parameter names of LibriSpeechDataLoader.__init__.
libriSpeech_data_loader = LibriSpeechDataLoader(**config["data_loader"])


In [9]:
# NOTE(review): assuming the LibriSpeechDataLoader class defined in this notebook is the one
# in use, train_data_loader is a single DataLoader and val_data_loaders is a dict whose only
# populated key is "dev_clean".
train_data_loader = libriSpeech_data_loader.get_train_data_loader()
val_data_loaders = libriSpeech_data_loader.get_val_data_loaders()

In [14]:
import numpy as np

In [13]:
# Pull a single batch for inspection. NOTE(review): if the collate function defined in this
# notebook is in effect, the batch is the tuple (features, padding_masks, mask, features_only).
batch = next(iter(train_data_loader))

In [15]:
batch[0].shape


torch.Size([1, 241440])

In [19]:
batch[-1]

True

### Create inference pipeline for validating the student model

In [15]:
#inference_pipeline_example = inference_pipeline(target_dict, use_cuda=True, input_half=False)

In [13]:
# Create the student model: the checkpoint path comes from the "general" config section and
# every key of the "student_model" section is forwarded as a keyword argument.
student_model = StudentWav2Vec2Model.create_student_model(target_dict=target_dict,
                                                          fairseq_pretrained_model_path=config["knowledge_distillation"]["general"]["fairseq_pretrained_model_path"],
                                                          **config["knowledge_distillation"]["student_model"])


2022-04-24 15:01:01 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX


Error: Canceled future for execute_request message before replies were done

In [11]:
# NOTE(review): str.split returns strings, so the element compared here is the string '0'.
# Comparing it with the integer 0 is always True; compare against '0' if the intent is to
# test the layer index.
'feature_extractor.conv_layers.0.0.weight'.split('.')[2] != 0

True

In [47]:
student_model.state_dict().keys()

odict_keys(['mask_emb', 'feature_extractor.conv_layers.0.0.weight', 'feature_extractor.conv_layers.0.2.weight', 'feature_extractor.conv_layers.0.2.bias', 'feature_extractor.conv_layers.1.0.weight', 'feature_extractor.conv_layers.2.0.weight', 'feature_extractor.conv_layers.3.0.weight', 'feature_extractor.conv_layers.4.0.weight', 'feature_extractor.conv_layers.5.0.weight', 'feature_extractor.conv_layers.6.0.weight', 'post_extract_proj.weight', 'post_extract_proj.bias', 'quantizer.vars', 'quantizer.weight_proj.weight', 'quantizer.weight_proj.bias', 'project_q.weight', 'project_q.bias', 'encoder.pos_conv.0.bias', 'encoder.pos_conv.0.weight_g', 'encoder.pos_conv.0.weight_v', 'encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias', 'encoder.layers.0.self_attn.v_proj.weight', 'encoder.layers.0.self_attn.v_proj.bias', 'encoder.layers.0.self_attn.q_proj.weight', 'encoder.layers.0.self_attn.q_proj.bias', 'encoder.layers.0.self_attn.out_proj.weight', 'encoder.layers.0

In [29]:
# `children` is not assigned anywhere in this notebook, so the cell failed on a fresh kernel.
# Take the direct sub-modules from the student model; the first one is the feature
# extractor that the following cells inspect.
out = list(student_model.children())

In [34]:
feature_extrac_mod = out[0]

In [36]:
feat_mod = list(feature_extrac_mod.children())

In [41]:
list(feat_mod[0].children())

[Sequential(
   (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
   (1): Dropout(p=0.0, inplace=False)
   (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
   (3): GELU()
 ),
 Sequential(
   (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
   (1): Dropout(p=0.0, inplace=False)
   (2): GELU()
 ),
 Sequential(
   (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
   (1): Dropout(p=0.0, inplace=False)
   (2): GELU()
 ),
 Sequential(
   (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
   (1): Dropout(p=0.0, inplace=False)
   (2): GELU()
 ),
 Sequential(
   (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
   (1): Dropout(p=0.0, inplace=False)
   (2): GELU()
 ),
 Sequential(
   (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
   (1): Dropout(p=0.0, inplace=False)
   (2): GELU()
 ),
 Sequential(
   (0): Conv1d(512, 512, kernel_size=(2,), stride=(2,), bias=False)
   (1): Dropout(p=0.0, inplace=False

### Create student and teacher models

In [1]:
# Create the teacher model from the same pre-trained checkpoint path that the student uses.
# NOTE(review): the recorded NameError shows this cell was executed before the cell that
# imports TeacherWav2Vec2Model; run the import cells first.
teacher_model = TeacherWav2Vec2Model.create_teacher_model(target_dict=target_dict,
                                                          fairseq_pretrained_model_path=config["knowledge_distillation"]["general"]["fairseq_pretrained_model_path"])

NameError: name 'TeacherWav2Vec2Model' is not defined

In [62]:
# Only the first item yielded by modules() is printed before the loop stops; per the output
# below, that item is the top-level StudentWav2Vec2Model itself.
for module in student_model.modules():
    print(module)
    break

StudentWav2Vec2Model(
  (feature_extractor): ConvFeatureExtractionModel(
    (conv_layers): ModuleList(
      (0): Sequential(
        (0): Conv1d(1, 512, kernel_size=(10,), stride=(5,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): Fp32GroupNorm(512, 512, eps=1e-05, affine=True)
        (3): GELU()
      )
      (1): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (2): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (3): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
      (4): Sequential(
        (0): Conv1d(512, 512, kernel_size=(3,), stride=(2,), bias=False)
        (1): Dropout(p=0.0, inplace=False)
        (2): GELU()
      )
 

In [27]:
import torch.nn as nn

# Scratch stack of six Conv1d layers with kernel size 5: one 1 -> 512 layer
# followed by five 512 -> 512 layers.
l = nn.Sequential(
    nn.Conv1d(1, 512, 5),
    *[nn.Conv1d(512, 512, 5) for _ in range(5)]
)

In [28]:
from prettytable import PrettyTable

def count_parameters(model):
    """
    Print a table of trainable parameter counts and return their total.

    Arguments:
        model: object exposing named_parameters(), e.g. a torch.nn.Module

    Returns:
        int: summed numel() of all parameters that have requires_grad set
    """
    table = PrettyTable(["Modules", "Parameters"])
    # Frozen parameters (requires_grad == False) are left out of the table and the total.
    trainable = [(name, tensor.numel())
                 for name, tensor in model.named_parameters()
                 if tensor.requires_grad]
    for row_name, row_count in trainable:
        table.add_row([row_name, row_count])
    total_params = sum(row_count for _, row_count in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

In [40]:
# Two Conv1d layers with the same shape arguments, one plain and one with groups=2,
# to compare their parameter counts.
l = nn.Conv1d(in_channels=20, out_channels=10, kernel_size=5, bias=False)
dl = nn.Conv1d(in_channels=20, out_channels=10, kernel_size=5, groups=2, bias=False)
count_parameters(l)

+---------+------------+
| Modules | Parameters |
+---------+------------+
|  weight |    1000    |
+---------+------------+
Total Trainable Params: 1000


1000

In [46]:
l.state_dict()['weight'].shape

torch.Size([10, 20, 5])

In [48]:
dl.state_dict()['weight'].shape

torch.Size([10, 10, 5])

In [55]:
# Every second index from 0 up to (but excluding) 20, as a plain Python list.
lyr_ids = list(range(0, 20, 2))

In [56]:
l.state_dict()['weight'][:,lyr_ids,:]

torch.Size([10, 10, 5])

In [59]:
#dl.state_dict()['weight'].copy_(l.state_dict()['weight'][:,lyr_ids,:])

In [60]:
dl.state_dict()['weight']

torch.Size([10, 10, 5])

In [35]:
count_parameters(dl)

+---------+------------+
| Modules | Parameters |
+---------+------------+
|  weight |    250     |
+---------+------------+
Total Trainable Params: 250


250

In [31]:
count_parameters(student_model)

+----------------------------------------------+------------+
|                   Modules                    | Parameters |
+----------------------------------------------+------------+
|                   mask_emb                   |    1024    |
|   feature_extractor.conv_layers.0.0.weight   |    5120    |
|   feature_extractor.conv_layers.0.2.weight   |    512     |
|    feature_extractor.conv_layers.0.2.bias    |    512     |
|   feature_extractor.conv_layers.1.0.weight   |   786432   |
|   feature_extractor.conv_layers.2.0.weight   |   786432   |
|   feature_extractor.conv_layers.3.0.weight   |   786432   |
|   feature_extractor.conv_layers.4.0.weight   |   786432   |
|   feature_extractor.conv_layers.5.0.weight   |   524288   |
|   feature_extractor.conv_layers.6.0.weight   |   524288   |
|           post_extract_proj.weight           |   524288   |
|            post_extract_proj.bias            |    1024    |
|                quantizer.vars                |   245760   |
|       

65456384

### Set the projection layer (which outputs probability distributions over tokens) for the student and teacher models

In [19]:
def get_proj_layer(fairseq_pretrained_model_path, map_location="cpu"):
    """
    Get projection layer's weights and biases of wav2vec 2.0 pre-trained model

    Arguments:
        fairseq_pretrained_model_path (str): path to the checkpoint file
        map_location: forwarded to torch.load. Defaults to "cpu" so the checkpoint can be read
            on a machine that lacks the device its tensors were saved from.

    Returns:
        tuple: the "w2v_encoder.proj.weight" and "w2v_encoder.proj.bias" entries of the
        checkpoint's "model" dict

    Raises:
        KeyError: if the checkpoint has no "model" entry or lacks one of the two keys
    """
    # NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
    checkpoint = torch.load(fairseq_pretrained_model_path, map_location=map_location)
    model_state = checkpoint["model"]
    return model_state["w2v_encoder.proj.weight"], model_state["w2v_encoder.proj.bias"]

In [20]:
# Read the projection weights and bias once from the pre-trained checkpoint and hand the
# same tensors to both the student and the teacher model.
proj_layer_weight, proj_layer_bias = get_proj_layer(fairseq_pretrained_model_path=config["knowledge_distillation"]["general"]["fairseq_pretrained_model_path"])
student_model.init_proj_layer_to_decoder(proj_layer_weight, proj_layer_bias)
teacher_model.init_proj_layer_to_decoder(proj_layer_weight, proj_layer_bias)

### Train a student model with knowledge distillation and get its performance on dev set

In [21]:
# inference_pipeline is None because the cell creating inference_pipeline_example is commented out.
# ChainMap resolves a key from the first mapping that contains it, so for duplicate keys the
# config section listed first wins, both in logging_param and in the expanded keyword arguments.
# NOTE(review): the recorded MisconfigurationException reports that the 'ddp' strategy is not
# compatible with an interactive environment; run this as a script or select one of the
# backends listed in the error message.
KD_wav2vec2 = KnowledgeDistillationTraining(train_data_loader = train_data_loader,
                                            val_data_loaders = val_data_loaders,
                                            inference_pipeline = None,
                                            student_model = student_model,
                                            teacher_model = teacher_model,
                                            num_gpu_used = config["knowledge_distillation"]["general"]["num_gpu_used"],
                                            temperature = config["knowledge_distillation"]["general"]["temperature"],
                                            final_loss_coeff_dict = config["knowledge_distillation"]["final_loss_coeff"],
                                            logging_param = ChainMap(config["knowledge_distillation"]["general"], config["knowledge_distillation"]["optimization"],
                                                                     config["knowledge_distillation"]["final_loss_coeff"], config["knowledge_distillation"]["student_model"],
                                                                     config["knowledge_distillation"]["pytorch_lightning_trainer"]),
                                            **ChainMap(config["knowledge_distillation"]["optimization"],
                                                       config["knowledge_distillation"]["pytorch_lightning_trainer"],
                                                       config["knowledge_distillation"]["comet_info"])
                                            )
KD_wav2vec2.start_kd_training()

Global seed set to 42
  rank_zero_deprecation(


MisconfigurationException: `Trainer(strategy='ddp')` or `Trainer(accelerator='ddp')` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible backends: dp, ddp_spawn, ddp_sharded_spawn, tpu_spawn. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.

In [9]:
# Fetch the trained student and evaluate it on the "dev_clean" validation loader.
# NOTE(review): inference_pipeline_example is only created in a commented-out cell (and its
# import is commented out as well), so this cell raises NameError on a fresh kernel unless
# those lines are restored.
student_model = KD_wav2vec2.get_student_model()
val_result = inference_pipeline_example.run_inference_pipeline(student_model.cuda(), val_data_loaders["dev_clean"])

In [10]:
# "inference_result" is scaled by 100 and shown with two decimals.
print(f'final WER is {val_result["inference_result"]*100:.2f}')

final WER is 48.40


#### As the output above shows, the WER on dev-clean has decreased from about 99% to about 48% after 5 epochs of training.