In [1]:
import argparse
import copy
import math
import os
import random
import time

import numpy as np
import torch

from common import helpers
from common.dataset import AudioDataset, get_data_loader
from common.features import BaseFeatures, FilterbankFeatures
from common.helpers import (Checkpointer, greedy_wer, num_weights, print_once,
                            process_evaluation_epoch)
from common.tb_dllogger import flush_log, init_log, log
from jasper import config
from transducer.model import Transducer



In [10]:
def parse_args(argv=()):
    """Build the training CLI and parse ``argv``.

    By default an empty token list is parsed, so every option takes its
    default. This keeps the function usable inside a notebook kernel, whose
    ``sys.argv`` typically carries the kernel's own launch arguments.

    The dataset root defaults to ``$LIBRISPEECH_DIR`` when that environment
    variable is set; otherwise it falls back to the original local path.

    Args:
        argv: sequence of command-line tokens, e.g. ``["--lr", "3e-4"]``.
            Pass ``None`` to let argparse read ``sys.argv[1:]`` instead.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    model_str = "./configs/transducer_asr.yaml"
    # Avoid baking a single machine's absolute path in: allow an override.
    dataset_str = os.environ.get(
        "LIBRISPEECH_DIR",
        "/Users/madhuhegde/work/berkeley/ASR/SpeechRecognition/datasets/LibriSpeech/")
    traindata_str = [os.path.join(dataset_str, "librispeech-train-clean-100-wav.json")]
    valdata_str = [os.path.join(dataset_str, "librispeech-dev-clean-wav.json")]
    out_str = "./results/"

    parser = argparse.ArgumentParser(description='Jasper')

    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', default=10, type=int,
                          help='Number of epochs for the entire training; influences the lr schedule')
    training.add_argument('--seed', default=42, type=int, help='Random seed')

    optim = parser.add_argument_group('optimization setup')
    optim.add_argument('--batch_size', default=32, type=int,
                       help='Global batch size')
    optim.add_argument('--lr', default=1e-4, type=float,
                       help='Peak learning rate')
    optim.add_argument("--lr_exp_gamma", default=0.99, type=float,
                       help='gamma factor for exponential lr scheduler')

    io = parser.add_argument_group('feature and checkpointing setup')
    io.add_argument('--model_config', type=str, default=model_str,
                    help='Path of the model configuration file')
    io.add_argument('--train_manifests', type=str, default=traindata_str, nargs='+',
                    help='Paths of the training dataset manifest file')
    io.add_argument('--val_manifests', type=str, default=valdata_str, nargs='+',
                    help='Paths of the evaluation datasets manifest files')
    io.add_argument('--max_duration', type=float,
                    help='Discard samples longer than max_duration')
    io.add_argument('--pad_to_max_duration', action='store_true', default=False,
                    help='Pad training sequences to max_duration')
    io.add_argument('--dataset_dir', default=dataset_str, type=str,
                    help='Root dir of dataset')
    io.add_argument('--output_dir', type=str, default=out_str,
                    help='Directory for logs and checkpoints')
    io.add_argument('--log_file', type=str, default=None,
                    help='Path to save the training logfile.')
    # An explicit (possibly empty) list bypasses sys.argv; None defers to it.
    return parser.parse_args(None if argv is None else list(argv))




In [11]:
# `args` is otherwise first created in a later cell; parse it here so this
# cell also runs on a fresh kernel (Restart & Run All).
args = parse_args()
print(args.lr, args.lr_exp_gamma)

0.0001 0.99


In [15]:
def data_generator(args, config, symbols):
    """Create the train and validation loaders plus their feature processors.

    Reloads the model YAML from ``args.model_config`` and applies the CLI
    duration overrides. Then, for each split in turn (train first), it builds
    an ``AudioDataset``, wraps it with ``get_data_loader`` (``multi_gpu=0``,
    ``num_workers=0``) and instantiates a ``FilterbankFeatures`` from that
    split's feature config.

    Args:
        args: parsed CLI namespace (see ``parse_args``).
        config: the config module exposing ``load``, ``apply_duration_flags``
            and ``input``.
        symbols: label vocabulary passed to ``AudioDataset``.

    Returns:
        Tuple ``(train_loader, train_feat_proc, val_loader, val_feat_proc)``.
    """
    print('Setting up datasets...')
    model_cfg = config.load(args.model_config)
    config.apply_duration_flags(model_cfg, args.max_duration, args.pad_to_max_duration)

    def build_split(split, manifests, **loader_opts):
        # One split, built in the order dataset -> loader -> feature processor.
        dataset_kw, features_kw = config.input(model_cfg, split)
        split_dataset = AudioDataset(args.dataset_dir, manifests, symbols, **dataset_kw)
        split_loader = get_data_loader(split_dataset, args.batch_size,
                                       multi_gpu=0, num_workers=0, **loader_opts)
        return split_loader, FilterbankFeatures(**features_kw)

    train_loader, train_feat_proc = build_split('train', args.train_manifests,
                                                shuffle=True)
    val_loader, val_feat_proc = build_split('val', args.val_manifests,
                                            shuffle=False, drop_last=False)
    return train_loader, train_feat_proc, val_loader, val_feat_proc
    
if __name__ == "__main__":  # always true in a notebook kernel; skipped on import
    # --- Setup: CLI defaults and reproducible seeds for every RNG in use ---
    args = parse_args()
    multi_gpu = 0
    args.amp = False
    torch.manual_seed(args.seed + 1)
    np.random.seed(args.seed + 2)
    random.seed(args.seed + 3)
    cfg = config.load(args.model_config)

    # Output vocabulary: the configured labels plus an explicit '<BLANK>' token.
    symbols = cfg['labels'] + ['<BLANK>']

    train_loader, train_feat_proc, val_loader, val_feat_proc = data_generator(args, config, symbols)

    # --- Model and optimizer ---
    # `train_features_kw` is a local of data_generator; the original reference
    # here only worked via a variable left over in the kernel. Rebuild it from a
    # copy of the config with the same duration flags applied, leaving `cfg`
    # itself unmodified for later cells.
    feat_cfg = copy.deepcopy(cfg)
    config.apply_duration_flags(feat_cfg, args.max_duration, args.pad_to_max_duration)
    _, train_features_kw = config.input(feat_cfg, 'train')
    num_inputs = train_features_kw['n_filt']

    # NOTE(review): the meaning of the second argument (32) depends on
    # Transducer's signature — confirm and lift it into a named constant.
    model = Transducer(num_inputs, 32)
    lr = args.lr
    lr_gamma = args.lr_exp_gamma
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=lr_gamma)
    batch_size = args.batch_size

    # --- Training loop ---
    # TODO: the feature extraction and forward/backward pass are still commented
    # out, so this loop only iterates the training loader (a data-pipeline smoke
    # test). Each batch contributes its own size as "loss", so train_loss ends
    # at 1.0 whenever the loader yields at least one sample.
    num_epochs = 1  # args.epochs
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        num_samples = 0
        test_loss = 0
        for idx, batch in enumerate(train_loader):
            audio, audio_lens, txt, txt_lens = batch
            # feat, feat_lens = train_feat_proc(audio, audio_lens, args.amp)
            feat, feat_lens = audio, audio_lens
            if idx % 100 == 0:
                print(idx)
            # feat = feat.transpose(1, 2).to(model.device)
            # txt = txt.to(model.device)
            batch_size = feat.shape[0]
            # loss = model.compute_loss(feat, txt, feat_lens, txt_lens)
            loss = feat.shape[0]
            num_samples += batch_size
            # optimizer.zero_grad(); loss.backward(); optimizer.step()
            train_loss += loss  # becomes loss.item() * batch_size with the real loss
        # Guard against an empty loader (the original divided by zero here).
        train_loss /= max(num_samples, 1)

def temp(max_batches=None, *, eval_model=None, loader=None, feat_proc=None,
         scheduler=None, amp=None):
    """Step the LR scheduler once, then return the mean validation loss.

    Meant to be called after each training epoch. The previous version could
    never complete: it decremented an unbound local ``num_epochs``,
    accumulated into an uninitialised ``test_loss``, averaged inside the loop
    and returned nothing.

    Every collaborator defaults to the notebook global of the same role
    (``model``, ``val_loader``, ``val_feat_proc``, ``lr_scheduler``,
    ``args.amp``). Each can be overridden by keyword for reuse or testing.

    Args:
        max_batches: evaluate at most this many batches (``None`` = whole
            loader). Pass ``1`` for a quick single-batch check.
        eval_model: object with ``eval()``, ``device`` and ``compute_loss``.
        loader: iterable of ``(audio, audio_lens, txt, txt_lens)`` batches.
        feat_proc: callable ``(audio, audio_lens, amp) -> (feat, feat_lens)``.
        scheduler: object whose ``step()`` is called once up front.
        amp: flag forwarded to ``feat_proc``.

    Returns:
        Sample-weighted mean of ``compute_loss`` over the evaluated batches,
        or ``0.0`` when no batch was evaluated.
    """
    from itertools import islice

    if eval_model is None:
        eval_model = model
    if loader is None:
        loader = val_loader
    if feat_proc is None:
        feat_proc = val_feat_proc
    if scheduler is None:
        scheduler = lr_scheduler
    if amp is None:
        amp = args.amp

    scheduler.step()
    eval_model.eval()
    total_loss = 0.0
    num_samples = 0
    # Evaluation only: no autograd graph is needed for the loss.
    with torch.no_grad():
        for audio, audio_lens, txt, txt_lens in islice(loader, max_batches):
            feat, feat_lens = feat_proc(audio, audio_lens, amp)
            # NOTE(review): assumes the model exposes `.device`, as the
            # original code did; lengths are left on their original device.
            feat = feat.transpose(1, 2).to(eval_model.device)
            txt = txt.to(eval_model.device)
            loss = eval_model.compute_loss(feat, txt, feat_lens, txt_lens)
            batch_size = feat.shape[0]
            total_loss += loss.item() * batch_size
            num_samples += batch_size
    return total_loss / num_samples if num_samples else 0.0
     
   

Setting up datasets...
0
100
200
300
400
500
600
700
800


In [14]:
 print(train_loss)

0.0011338270661171077


In [None]:
import matplotlib.pyplot as plt
# FIXME(review): `a` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel (it relied on leftover kernel state).
# The indexing a[31][:, 0] suggests `a` held at least 32 two-dimensional
# items (perhaps a batch of features) — TODO: recreate it explicitly above.
plt.plot(a[31][:,0])

In [None]:
# Show the encoder section derived from the loaded model config.
encoder_cfg = config.encoder(cfg)
print(encoder_cfg)

In [None]:
# Fit target y = sin(x) sampled on [-pi, pi], plus cubic polynomial features.
# (torch and math are imported in the top-of-notebook import cell.)
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)
p = torch.tensor([1, 2, 3])
# x.unsqueeze(-1) is (2000, 1); broadcasting against p (3,) gives (2000, 3)
# with columns x**1, x**2, x**3. The original assigned this matrix to `y`,
# silently discarding the sin target, and left `xx` as the bare column.
xx = x.unsqueeze(-1).pow(p)

In [None]:
# Inspect the feature tensor alongside the shape of the raw inputs.
print(f"{xx} {x.shape}")

In [None]:
shape = (2, 2)
x = torch.ones(shape)
# Insert a new leading axis: (2, 2) -> (1, 2, 2).
xx = x.unsqueeze(0)
print(x, xx, sep="\n")