# RNN-T: evaluate saved weights


In [None]:
# Enable IPython's autoreload extension; mode 2 reloads all modules (e.g. edited
# myrtlespeech sources) before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import pathlib
import typing

import torch
from google.protobuf import text_format

from myrtlespeech.model.rnn_t import RNNTEncoder, RNNT
from myrtlespeech.run.callbacks.csv_logger import CSVLogger
from myrtlespeech.run.callbacks.callback import Callback, ModelCallback
from myrtlespeech.run.callbacks.clip_grad_norm import ClipGradNorm
from myrtlespeech.run.callbacks.report_mean_batch_loss import ReportMeanBatchLoss
from myrtlespeech.run.callbacks.stop_epoch_after import StopEpochAfter
from myrtlespeech.run.callbacks.mixed_precision import MixedPrecision
from myrtlespeech.builders.task_config import build
from myrtlespeech.run.train import fit
from myrtlespeech.protos import task_config_pb2


In [None]:
import os

# Hide every GPU from this process so the notebook runs on the CPU.
# CUDA reads this variable when it is initialised, so it must be set before
# any CUDA work happens in this kernel.
os.environ.update({"CUDA_VISIBLE_DEVICES": ""})

In [None]:
# Path to the saved RNN-T checkpoint to evaluate.
weights_fp = "/home/julian/models/rnnt-53.pt"

# GPUs may be hidden (CUDA_VISIBLE_DEVICES="" is set in an earlier cell), so
# remap any CUDA-saved tensors onto the CPU. Without map_location, torch.load
# raises a RuntimeError for CUDA tensors when CUDA is unavailable.
# torch.load unpickles the file: only load checkpoints from trusted sources.
# NOTE(review): `model` (a state_dict or a pickled module, depending on how the
# checkpoint was saved) is not applied to `seq_to_seq` anywhere in the visible
# notebook. The cells below therefore appear to run a freshly built model.
# Confirm whether e.g. `seq_to_seq.model.load_state_dict(model)` is intended.
model = torch.load(weights_fp, map_location="cpu")

In [None]:
# Parse the example RNN-T task configuration (protobuf text format) into a
# TaskConfig message. The path is relative to the notebook's working directory.
# Read explicitly as UTF-8 so parsing does not depend on the platform's default
# locale encoding.
with open(
    "../src/myrtlespeech/configs/rnn_t_en_ds_int.config", encoding="utf-8"
) as f:
    task_config = text_format.Merge(f.read(), task_config_pb2.TaskConfig())

task_config

### Build all components required for training:

In [None]:
# Build the model, epoch count and data loaders from the task config.
# With train-clean-100 and dev-clean this cell takes on the order of 60 s.
(
    seq_to_seq,
    epochs,
    train_loader,
    eval_loader,
) = build(task_config)


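`build` returns a tuple `(seq_to_seq, epochs, train_loader, eval_loader)`. The first element is a sequence-to-sequence model (a `SeqToSeq`), which inherits from `torch.nn.Module` and bundles the `model`, the `loss` and the `post_process` step (i.e. the decoder). For example, the model's parameters can be listed: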


In [None]:
# Print the name and shape of every parameter of the underlying model.
for param_name, param in seq_to_seq.model.named_parameters():
    print(param_name, param.shape)

In [None]:
# Show seq_to_seq's nn.Module repr (its registered submodules) as the cell output.
seq_to_seq

## Callbacks
* Use callbacks to inject features into the training loop.
* The `RNNTTraining()` callback is currently required; the others are optional.


In [None]:
from myrtlespeech.run.callbacks.rnn_t_training import RNNTTraining
from myrtlespeech.run.run import ReportRNNTDecoder

# Decoder-reporting callback built from the model's post-processing step and
# alphabet (presumably reports decoded transcriptions; confirm).
rnnt_decoder_cb = ReportRNNTDecoder(seq_to_seq.post_process, seq_to_seq.alphabet)

callbacks = [
    RNNTTraining(),  # currently required for RNN-T models
    ReportMeanBatchLoss(),
    # NOTE(review): presumably ends each epoch after one batch (a quick
    # smoke-test setting); confirm before using for a full evaluation.
    StopEpochAfter(epoch_batches=1),
    rnnt_decoder_cb,
]


In [None]:
# Run myrtlespeech's fit loop over the built loaders with the callbacks above.
# NOTE(review): fit is given train_loader as well as eval_loader, so this
# presumably also trains (updating the weights) rather than only evaluating.
# Confirm this is intended for a weight-evaluation notebook.
fit(
    seq_to_seq,
    callbacks=callbacks,
    epochs=epochs,
    train_loader=train_loader,
    eval_loader=eval_loader,
)
