In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # Enable CPU fallback for MPS
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"  # Disable high watermark for MPS
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import random
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader
import logging
import matplotlib.pyplot as plt

from tqdm import tqdm

from config import Configurator
from data import NeuralDataset, collate_transducer, download_data, get_dataframes
from training import Trainer, EarlyStopping
from losses import RNNTLoss
from models import MyEncoder, MyPredictor, MyJoiner, RNNT, predict_on_sample

Config = Configurator()
path = Config.DATA_PATH
device = Config.DEVICE

speechbrain is not installed. FastRNNTLoss will not be available.


In [3]:
train_df, val_df = get_dataframes(path, debug=False)

Loading 45 train files in parallel...


Processing Train: 100%|██████████| 45/45 [00:23<00:00,  1.94it/s]


Loading 41 val files in parallel...


Processing Val: 100%|██████████| 41/41 [00:04<00:00,  8.97it/s]


In [7]:
train_df.head()

Unnamed: 0,neural_features,n_time_steps,seq_class_ids,seq_len,transcriptions,sentence_label,session,block_num,trial_num
0,"[[-0.23488846, 0.5014211, -0.75813776, -0.5213...",544,"[6, 40, 36, 17, 21, 40, 15, 25, 40, 12, 5, 23,...",14,"[73, 32, 119, 105, 108, 108, 32, 103, 111, 32,...",I will go around.,t15.2023.11.03,1,0
1,"[[-0.23703846, 0.3616628, -0.7747983, 0.937466...",641,"[6, 40, 2, 22, 40, 31, 4, 20, 17, 24, 40, 31, ...",24,"[73, 32, 97, 109, 32, 116, 97, 108, 107, 105, ...",I am talking to my family.,t15.2023.11.03,1,1
2,"[[-0.23900536, -0.8310198, -0.75461406, -0.489...",694,"[17, 31, 40, 17, 38, 40, 21, 33, 20, 17, 24, 4...",22,"[73, 116, 32, 105, 115, 32, 108, 111, 111, 107...",It is looking quite hard.,t15.2023.11.03,1,2
3,"[[-0.24192314, 0.2851043, -0.7366848, -0.48644...",691,"[36, 6, 40, 9, 25, 23, 31, 40, 37, 34, 40, 20,...",19,"[87, 104, 121, 32, 100, 111, 110, 39, 116, 32,...",Why don't you come here.,t15.2023.11.03,1,3
4,"[[-0.24300373, -0.8231351, -0.71545774, -0.482...",918,"[6, 40, 37, 34, 39, 3, 36, 3, 21, 18, 40, 15, ...",29,"[73, 32, 117, 115, 117, 97, 108, 108, 121, 32,...",I usually go home by this time.,t15.2023.11.03,1,4


In [12]:
# From train_df, find all unique id in column 'seq_class_ids'
unique_ids = set()
for seq in train_df['transcriptions']:
    unique_ids.update(seq)
unique_ids = sorted(list(unique_ids))

In [13]:
Config.set_vocabulary(unique_ids)

In [27]:
vocab_size = Config.VOCAB_SIZE
vocabulary = Config.VOCABULARY
id_to_char = Config.ID_TO_CHAR
char_to_id = Config.CHAR_TO_ID

print("Vocabulary Size:", vocab_size)
print("Vocabulary:", vocabulary)

Vocabulary Size: 63
Vocabulary: ['<blank>', ' ', '!', "'", ',', '-', '.', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '’']


In [31]:
sample_ids = train_df.iloc[0]['transcriptions']
# print("Sample IDs:", sample_ids)

# Map the ids to char using the id_to_token dictionary
sample_chars = [id_to_char[id_] for id_ in sample_ids]
print("Sample Characters:", sample_chars)

Sample Characters: ['I', ' ', 'w', 'i', 'l', 'l', ' ', 'g', 'o', ' ', 'a', 'r', 'o', 'u', 'n', 'd', '.', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<blank>', '<bl