In [40]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [41]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # Enable CPU fallback for MPS
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"  # Disable high watermark for MPS
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

In [42]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from config import Configurator
from data import NeuralDataset, collate_transducer, load_h5py_file, ids_to_text

Config = Configurator(phoneme=False)
Config.Train = False
path = Config.DATA_PATH
device = Config.DEVICE

In [43]:
def get_dataframes(path):
    folders = os.listdir(path)
    train_files = []
    test_files = []
    val_files = []
    for i, files in enumerate(folders):
        if files.startswith("."):
            continue
        files = os.listdir(os.path.join(path, files))
        for file in files:
            if file.endswith("train.hdf5"):
                train_files.append(os.path.join(path, folders[i], file))
            elif file.endswith("val.hdf5"):
                val_files.append(os.path.join(path, folders[i], file))
            elif file.endswith("test.hdf5"):
                test_files.append(os.path.join(path, folders[i], file))

    train_df = pd.DataFrame()
    i = 0
    for file in tqdm(train_files, desc="Loading train files"):
        data = load_h5py_file(file)
        temp_df = pd.DataFrame(data)
        train_df = pd.concat([train_df, temp_df], ignore_index=True)
        if Config.DEBUG:
            i += 1
            if i >= 1:  # load only 4 files in debug mode
                break

    val_df = pd.DataFrame()
    i = 0
    for file in tqdm(val_files, desc="Loading val files"):
        data = load_h5py_file(file)
        temp_df = pd.DataFrame(data)
        val_df = pd.concat([val_df, temp_df], ignore_index=True)
        if Config.DEBUG:
            i += 1
            if i >= 1:  # load only 2 files in debug mode
                break

    return train_df, val_df, test_files

train_df, val_df, test_files = get_dataframes(path)

Loading train files:   0%|          | 0/45 [00:00<?, ?it/s]
Loading val files:   0%|          | 0/41 [00:00<?, ?it/s]


In [44]:
# ------------------------ Dataset and Dataloader ------------------------
train_dataset = NeuralDataset(train_df, augment=True, smoothing=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda b: collate_transducer(b, pad_id=Config.PAD_ID, batch_first=Config.BATCH_FIRST)
)

val_dataset = NeuralDataset(val_df, augment=True, smoothing=True)
val_loader = DataLoader(
    val_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda b: collate_transducer(b, pad_id=Config.PAD_ID, batch_first=Config.BATCH_FIRST)
)

In [45]:
from models import ConformerEncoderDecoder
from training import Trainer, EarlyStopping

In [46]:
# 1. Define your Simpler Model
model = ConformerEncoderDecoder(
    vocab_size=Config.VOCAB_SIZE,
).to(device)

In [47]:
num_epochs = 10
lr = 3e-3
model_checkpoint_path = "./best_model.pth"

In [48]:
loss_fn = nn.CrossEntropyLoss(ignore_index=Configurator.PAD_ID)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    weight_decay=1e-2
)

lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    steps_per_epoch=len(train_loader),
    epochs=num_epochs,
    anneal_strategy='cos'
)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    scheduler=lr_scheduler,
    loss_fn=loss_fn,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=num_epochs,
    early_stop=EarlyStopping(patience=10, min_delta=1e-3, path=model_checkpoint_path),
    batch_first=Config.BATCH_FIRST,
)

In [49]:
# trainer.run()

In [50]:
def run_single_prediction(model, dataset, index=0):
    device = Config.DEVICE
    model.eval()

    # 1. Get one sample
    x, y_ids = dataset[index] # x is (T, 512), y_ids is (L,)

    # 2. Prepare inputs
    x = x.unsqueeze(0).to(device) # Add batch dim: (1, T, 512)
    x_len = torch.tensor([x.size(1)], dtype=torch.long).to(device)

    # 3. Predict
    print(f"Input Shape: {x.shape}")
    print("Generating text...")

    # Returns tensor of IDs including SOS and EOS
    predicted_ids = model.predict(
        x,
        x_len,
        max_len=200,
        sos_id=Config.SOS_ID,
        eos_id=Config.EOS_ID
    )

    # 4. Decode to String
    # Remove batch dim and convert to list
    pred_id_list = predicted_ids[0].cpu().tolist()

    # Convert IDs to Text (using your new helper in data.py)
    predicted_text = ids_to_text(pred_id_list)
    ground_truth = ids_to_text(y_ids.tolist())

    print("-" * 30)
    print(f"Ground Truth: {ground_truth}")
    print(f"Prediction:   {predicted_text}")
    print("-" * 30)

In [51]:
run_single_prediction(model, val_dataset, index=0)

Input Shape: torch.Size([1, 978, 512])
Generating text...
------------------------------
Ground Truth: Qebcolpqv^fom^ppbaqeolrdeqeb`l^q+
Prediction:   *i*F+*F&ti*F=7*	x*	x*	x*	x*	x*	x*	x*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*	x*i*	x*	x*i*
------------------------------
