In [None]:
# Import necessary libraries
import os
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from asr_model import ASRModel
from dataset import AISHELL1Dataset, PadCollate, PretrainedVGGExtractor

# Initialize Variables

In [None]:
# Demo configuration -- edit the placeholder paths below before running.

# Input data and model weights
WAV_PATH = "path/to/wav"  # Replace with your wav directory path
TRANSCRIPT_PATH = "path/to/transcript.txt"  # Replace with your transcript path
CHECKPOINT_PATH = "path/to/checkpoint.pth"  # Replace with your checkpoint path

# Model and tokenizer selection
TOKENIZER_NAME = "bert-base-chinese"  # Name handed to AutoTokenizer.from_pretrained
STRUCTURE = "A"  # Model structure (e.g., 'A', 'B', 'C'), passed to ASRModel as `mode`
RESHAPE_VGG_OUTPUT = True  # Forwarded to PadCollate as `reshape_features`

# DataLoader settings
BATCH_SIZE = 4  # Batch size for the demo
NUM_WORKERS = 2  # Number of workers for DataLoader

# Load Dataset and DataLoader

In [None]:
# Tokenizer: also supplies the padding token id used by the collate function
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
PAD_IDX = tokenizer.pad_token_id

# VGG feature extractor, created with freeze_features=True
vgg_model = PretrainedVGGExtractor(freeze_features=True)

# Collate function: receives the pad id, the VGG extractor and the tokenizer
collate_options = {
    "pad_idx": PAD_IDX,
    "vgg_model": vgg_model,
    "tokenizer": tokenizer,
    "reshape_features": RESHAPE_VGG_OUTPUT,
}
pad_collate_instance = PadCollate(**collate_options)

# Test split of the dataset, iterated in a fixed order (shuffle=False)
test_dataset = AISHELL1Dataset(TRANSCRIPT_PATH, WAV_PATH, split='test')
loader_options = {
    "batch_size": BATCH_SIZE,
    "shuffle": False,
    "collate_fn": pad_collate_instance,
    "num_workers": NUM_WORKERS,
}
test_dataloader = DataLoader(test_dataset, **loader_options)

# len() of the DataLoader is reported as a quick sanity check
print(f"Test Dataloader Length: {len(test_dataloader)}")

# Load Model

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model for the configured structure and place it on the device
model = ASRModel(model_dim=768, mode=STRUCTURE).to(device)

# Restore the trained weights. map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on a GPU can be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Checkpoint loaded from {CHECKPOINT_PATH}")

# This notebook only runs inference: switch to evaluation mode so that layers
# such as dropout and batch normalization use their inference behaviour.
model.eval()

# Show model architecture
print(model)

# Show Samples

In [None]:
# Pull a single batch from the test dataloader to inspect
batch = next(iter(test_dataloader))

# List the fields available in the collated batch
print(f"Batch Keys: {list(batch.keys())}")

# Outer double quotes let the subscript use single quotes; reusing the same
# quote character inside an f-string is a SyntaxError before Python 3.12.
print(f"Text: {batch['original_transcript']}")

# Forward Pass

In [None]:
# CER metric object, passed to the generation helpers in the cells below
from evaluate import load

metric_name = "cer"
cer = load(metric_name)

## Teacher Forcing Generation

In [None]:
# Perform forward pass using teacher_forcing_generate_tokens
from utils import teacher_forcing_generate_tokens

# Move every tensor in the batch to the model's device; non-tensor entries
# are passed through unchanged.
batch = {
    key: value.to(device) if isinstance(value, torch.Tensor) else value
    for key, value in batch.items()
}

# Compute predictions and CER score. No backward pass happens here, so
# autograd tracking is disabled to avoid building a graph and save memory.
with torch.no_grad():
    result_teacher_forcing = teacher_forcing_generate_tokens(tokenizer, model, batch, cer)

# Display results
print("Teacher Forcing Results:")
print(f"Decoded Predictions: {result_teacher_forcing['decoded_predictions']}")
print(f"Decoded References: {result_teacher_forcing['decoded_references']}")
print(f"CER Score: {result_teacher_forcing['cer_score']:.4f}")

## Normal Generation

In [None]:
# Perform forward pass using generate_tokens
from utils import generate_tokens

# Make sure the batch tensors are on the model's device. `.to(device)` returns
# the tensor unchanged when it is already there, so this is safe to repeat and
# lets the cell run even if the teacher-forcing cell was skipped.
batch = {
    key: value.to(device) if isinstance(value, torch.Tensor) else value
    for key, value in batch.items()
}

# Compute predictions and CER score. No backward pass happens here, so
# autograd tracking is disabled to avoid building a graph and save memory.
with torch.no_grad():
    result_generate_tokens = generate_tokens(tokenizer, model, batch, cer)

# Display results
print("Generate Tokens Results:")
print(f"Decoded Predictions: {result_generate_tokens['decoded_predictions']}")
print(f"Decoded References: {result_generate_tokens['decoded_references']}")
print(f"CER Score: {result_generate_tokens['cer_score']:.4f}")