In [28]:
import jax
import jax.numpy as jnp

def forced_align_impl_jnp(logProbs, targets, blank, paths):
    """
    This function performs forced alignment implementation using jax.numpy.
    
    Args:
        logProbs (jnp.ndarray): The log probabilities.
        targets (jnp.ndarray): The target values.
        blank (int): The blank label.
        paths (jnp.ndarray): The output paths.
    """
    kNegInfinity = -jnp.inf
    T = logProbs.shape[0]
    L = targets.shape[0]
    S = 2 * L + 1
    alphas = jnp.full((2, S), kNegInfinity)
    backPtr = jnp.full((T, S), -1, dtype=jnp.int8)
    R = 0
    for i in range(1, L):
        if targets[i] == targets[i - 1]:
            R += 1
    if T < L + R:
        raise ValueError(f"targets length is too long for CTC. Found targets length: {T}, log_probs length: {L}, and number of repeats: {R}")
    start = 0 if T - (L + R) > 0 else 1
    end = 1 if S == 1 else 2
    for i in range(start, end):
        labelIdx = blank if i % 2 == 0 else targets[i // 2]
        alphas = alphas.at[0, i].set(logProbs[0, labelIdx])
    for t in range(1, T):
        if T - t <= L + R:
            if start % 2 == 1 and targets[start // 2] != targets[start // 2 + 1]:
                start += 1
            start += 1
        if t <= L + R:
            if end % 2 == 0 and end < 2 * L and targets[end // 2 - 1] != targets[end // 2]:
                end += 1
            end += 1
        startloop = start
        curIdxOffset = t % 2
        prevIdxOffset = (t - 1) % 2
        alphas = alphas.at[curIdxOffset, :].set(kNegInfinity)
        if start == 0:
            alphas = alphas.at[curIdxOffset, 0].set(alphas[prevIdxOffset, 0] + logProbs[t, blank])
            backPtr = backPtr.at[t, 0].set(0)
            startloop += 1
        for i in range(startloop, end):
            x0 = alphas[prevIdxOffset, i]
            x1 = alphas[prevIdxOffset, i - 1]
            x2 = kNegInfinity
            labelIdx = blank if i % 2 == 0 else targets[i // 2]
            if i % 2 != 0 and i != 1 and targets[i // 2] != targets[i // 2 - 1]:
                x2 = alphas[prevIdxOffset, i - 2]
            result = 0.0
            if x2 > x1 and x2 > x0:
                result = x2
                backPtr = backPtr.at[t, i].set(2)
            elif x1 > x0 and x1 > x2:
                result = x1
                backPtr = backPtr.at[t, i].set(1)
            else:
                result = x0
                backPtr = backPtr.at[t, i].set(0)
            alphas = alphas.at[curIdxOffset, i].set(result + logProbs[t, labelIdx])
    idx1 = (T - 1) % 2
    ltrIdx = S - 1 if alphas[idx1, S - 1] > alphas[idx1, S - 2] else S - 2
    for t in range(T - 1, -1, -1):
        lbl_idx = blank if ltrIdx % 2 == 0 else targets[ltrIdx // 2]
        paths = paths.at[t].set(lbl_idx)
        ltrIdx -= backPtr[t, ltrIdx]
    return paths, alphas, backPtr

def compute_jnp(logProbs, targets, inputLengths, targetLengths, blank):
    """
    This function performs computation using jax.numpy.
    
    Args:
        logProbs (jnp.ndarray): The log probabilities.
        targets (jnp.ndarray): The target values.
        inputLengths (int): The input lengths.
        targetLengths (int): The target lengths.
        blank (int): The blank label.
        
    Returns:
        tuple: The output paths and the log probabilities for those paths.
    """
    if not isinstance(logProbs, jnp.ndarray):
        raise ValueError("log_probs must be a jax numpy array")
    if not isinstance(targets, jnp.ndarray):
        raise ValueError("targets must be a jax numpy array")
    if not jnp.issubdtype(logProbs.dtype, jnp.floating):
        raise ValueError("log_probs must be float64, float32 or float16 (half) type")
    if not jnp.issubdtype(targets.dtype, jnp.integer):
        raise ValueError("targets must be int32 or int64 type")
    if len(logProbs.shape) != 2:
        raise ValueError("log_probs must be 2-D (input length, num classes)")
    if len(targets.shape) != 1:
        raise ValueError("targets must be 1-D (target length,)")
    if jnp.ndim(inputLengths) != 0:
        raise ValueError("input_lengths must be 0-D")
    if jnp.ndim(targetLengths) != 0:
        raise ValueError("target_lengths must be 0-D")
    if blank < 0 or blank >= logProbs.shape[-1]:
        raise ValueError("blank must be within [0, num classes)")
    if logProbs.shape[0] != inputLengths:
        raise ValueError("input length mismatch")
    if targets.shape[0] != targetLengths:
        raise ValueError("target length mismatch")
    T = logProbs.shape[0]
    paths = jnp.zeros(T, dtype=targets.dtype)
    paths, alphas, backPtr = forced_align_impl_jnp(logProbs, targets, blank, paths)
    return paths, logProbs[jnp.arange(T), paths]


In [2]:
# Let's test this with some random data
logProbs = jnp.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]])
targets = jnp.array([0, 2, 3])
blank = 1

In [9]:
logProbs = jax.nn.log_softmax(logProbs, axis=-1)

In [11]:
input_lengths = jnp.array(logProbs.shape[0])
target_lengths = jnp.array(targets.shape[0])

In [12]:
compute_jnp(logProbs, targets, input_lengths, target_lengths, blank)

(Array([0, 2, 3], dtype=int32),
 Array([-1.5425494, -1.3425493, -1.2425493], dtype=float32))

## Numpy Implementation

In [13]:
import numpy as np

def forced_align_impl_np(logProbs, targets, blank, paths):
    kNegInfinity = -np.inf
    T = logProbs.shape[0]
    L = targets.shape[0]
    S = 2 * L + 1
    alphas = np.full((2, S), kNegInfinity)
    backPtr = np.full((T, S), -1, dtype=np.int8)
    R = 0
    for i in range(1, L):
        if targets[i] == targets[i - 1]:
            R += 1
    if T < L + R:
        raise ValueError(f"targets length is too long for CTC. Found targets length: {T}, log_probs length: {L}, and number of repeats: {R}")
    start = 0 if T - (L + R) > 0 else 1
    end = 1 if S == 1 else 2
    for i in range(start, end):
        labelIdx = blank if i % 2 == 0 else targets[i // 2]
        alphas[0, i] = logProbs[0, labelIdx]
    for t in range(1, T):
        if T - t <= L + R:
            if start % 2 == 1 and targets[start // 2] != targets[start // 2 + 1]:
                start += 1
            start += 1
        if t <= L + R:
            if end % 2 == 0 and end < 2 * L and targets[end // 2 - 1] != targets[end // 2]:
                end += 1
            end += 1
        startloop = start
        curIdxOffset = t % 2
        prevIdxOffset = (t - 1) % 2
        alphas[curIdxOffset, :] = kNegInfinity
        if start == 0:
            alphas[curIdxOffset, 0] = alphas[prevIdxOffset, 0] + logProbs[t, blank]
            backPtr[t, 0] = 0
            startloop += 1
        for i in range(startloop, end):
            x0 = alphas[prevIdxOffset, i]
            x1 = alphas[prevIdxOffset, i - 1]
            x2 = kNegInfinity
            labelIdx = blank if i % 2 == 0 else targets[i // 2]
            if i % 2 != 0 and i != 1 and targets[i // 2] != targets[i // 2 - 1]:
                x2 = alphas[prevIdxOffset, i - 2]
            result = 0.0
            if x2 > x1 and x2 > x0:
                result = x2
                backPtr[t, i] = 2
            elif x1 > x0 and x1 > x2:
                result = x1
                backPtr[t, i] = 1
            else:
                result = x0
                backPtr[t, i] = 0
            alphas[curIdxOffset, i] = result + logProbs[t, labelIdx]
    idx1 = (T - 1) % 2
    ltrIdx = S - 1 if alphas[idx1, S - 1] > alphas[idx1, S - 2] else S - 2
    for t in range(T - 1, -1, -1):
        lbl_idx = blank if ltrIdx % 2 == 0 else targets[ltrIdx // 2]
        paths[t] = lbl_idx
        ltrIdx -= backPtr[t, ltrIdx]

def compute_np(logProbs, targets, inputLengths, targetLengths, blank):
    if not isinstance(logProbs, np.ndarray):
        raise ValueError("log_probs must be a numpy array")
    if not isinstance(targets, np.ndarray):
        raise ValueError("targets must be a numpy array")
    if not np.issubdtype(logProbs.dtype, np.floating):
        raise ValueError("log_probs must be float64, float32 or float16 (half) type")
    if not np.issubdtype(targets.dtype, np.integer):
        raise ValueError("targets must be int32 or int64 type")
    if len(logProbs.shape) != 2:
        raise ValueError("log_probs must be 2-D (input length, num classes)")
    if len(targets.shape) != 1:
        raise ValueError("targets must be 1-D (target length,)")
    if np.ndim(inputLengths) != 0:
        raise ValueError("input_lengths must be 0-D")
    if np.ndim(targetLengths) != 0:
        raise ValueError("target_lengths must be 0-D")
    if blank < 0 or blank >= logProbs.shape[-1]:
        raise ValueError("blank must be within [0, num classes)")
    if logProbs.shape[0] != inputLengths:
        raise ValueError("input length mismatch")
    if targets.shape[0] != targetLengths:
        raise ValueError("target length mismatch")
    T = logProbs.shape[0]
    paths = np.zeros(T, dtype=targets.dtype)
    forced_align_impl_np(logProbs, targets, blank, paths)
    return paths, logProbs[np.arange(T), paths]

In [17]:
compute_np(np.asarray(logProbs), np.asarray(targets), np.asarray(input_lengths), np.asarray(target_lengths), blank)

(array([0, 2, 3], dtype=int32),
 array([-1.5425494, -1.3425493, -1.2425493], dtype=float32))

## Torch implementation

In [20]:
from torchaudio.functional import forced_align
import torch

In [22]:
forced_align(torch.from_numpy(np.asarray(logProbs)), torch.from_numpy(np.asarray(targets)), torch.from_numpy(np.asarray(input_lengths)), torch.from_numpy(np.asarray(target_lengths)), blank)

  forced_align(torch.from_numpy(np.asarray(logProbs)), torch.from_numpy(np.asarray(targets)), torch.from_numpy(np.asarray(input_lengths)), torch.from_numpy(np.asarray(target_lengths)), blank)


(tensor([0, 2, 3], dtype=torch.int32), tensor([-1.5425, -1.3425, -1.2425]))

## Parellized jax implementation

In [100]:
import jax
import jax.numpy as jnp

def inner_loop_fn(t, i, startloop, end, prevIdxOffset, curIdxOffset, alphas, backPtr, logProbs, targets, kNegInfinity, blank):
    """
    This function represents the logic inside the inner loop of the forced_align_impl_jnp function.
    
    Args:
        t (int): The outer loop variable.
        i (int): The inner loop variable.
        startloop (int): The starting index of the inner loop.
        end (int): The ending index of the inner loop.
        prevIdxOffset (int): The offset index of the previous timestep.
        curIdxOffset (int): The offset index of the current timestep.
        alphas (jnp.ndarray): The alpha values.
        backPtr (jnp.ndarray): The backpointer values.
        logProbs (jnp.ndarray): The log probabilities.
        targets (jnp.ndarray): The target values.
        kNegInfinity (float): A constant representing negative infinity.
        blank (int): The blank label.
        
    Returns:
        Tuple[jnp.ndarray, jnp.ndarray]: The updated alpha values and backpointer values.
    """
    print(f"Input Shape of backPtr: {backPtr.shape}")
    print(f"Input Shape of alphas: {alphas.shape}")
    x0 = alphas[prevIdxOffset, i]
    x1 = alphas[prevIdxOffset, i - 1]
    x2 = kNegInfinity
    labelIdx = jnp.where(i % 2 == 0, blank, targets[i // 2])
    condition = jnp.logical_and(jnp.logical_and(i % 2 != 0, i != 1), targets[i // 2] != targets[i // 2 - 1])
    x2 = jnp.where(condition, alphas[prevIdxOffset, i - 2], kNegInfinity)
    result = 0.0
    
    cond1 = jnp.logical_and(x2 > x1, x2 > x0)
    cond2 = jnp.logical_and(~cond1, jnp.logical_and(x1 > x0, x1 > x2))

    result = jnp.where(cond1, x2, jnp.where(cond2, x1, x0))
    backPtr_val = jnp.where(cond1, 2, jnp.where(cond2, 1, 0))

    print(f"Output Shape of backPtr: {backPtr.shape}")
    backPtr = backPtr.at[t, i].set(backPtr_val)
    
    print(f"Output Shape of alphas: {alphas.shape}")
    alphas = alphas.at[curIdxOffset, i].set(result + logProbs[t, labelIdx])
    
    return alphas, backPtr

inner_loop_vmap = jax.vmap(inner_loop_fn, in_axes=(None, 0, None, None, None, None, None, None, None, None, None, None), out_axes=(0, 0))


def forced_align_impl_jnp(logProbs, targets, blank, paths):
    """
    This function performs forced alignment implementation using jax.numpy.
    
    Args:
        logProbs (jnp.ndarray): The log probabilities.
        targets (jnp.ndarray): The target values.
        blank (int): The blank label.
        paths (jnp.ndarray): The output paths.
    """
    kNegInfinity = -jnp.inf
    T = logProbs.shape[0]
    L = targets.shape[0]
    S = 2 * L + 1
    alphas = jnp.full((2, S), kNegInfinity)
    backPtr = jnp.full((T, S), -1, dtype=jnp.int8)
    R = 0
    for i in range(1, L):
        if targets[i] == targets[i - 1]:
            R += 1
    if T < L + R:
        raise ValueError(f"targets length is too long for CTC. Found targets length: {T}, log_probs length: {L}, and number of repeats: {R}")
    start = 0 if T - (L + R) > 0 else 1
    end = 1 if S == 1 else 2
    for i in range(start, end):
        labelIdx = blank if i % 2 == 0 else targets[i // 2]
        alphas = alphas.at[0, i].set(logProbs[0, labelIdx])
    for t in range(1, T):
        if T - t <= L + R:
            if start % 2 == 1 and targets[start // 2] != targets[start // 2 + 1]:
                start += 1
            start += 1
        if t <= L + R:
            if end % 2 == 0 and end < 2 * L and targets[end // 2 - 1] != targets[end // 2]:
                end += 1
            end += 1
        startloop = start
        curIdxOffset = t % 2
        prevIdxOffset = (t - 1) % 2
        alphas = alphas.at[curIdxOffset, :].set(kNegInfinity)
        if start == 0:
            alphas = alphas.at[curIdxOffset, 0].set(alphas[prevIdxOffset, 0] + logProbs[t, blank])
            backPtr = backPtr.at[t, 0].set(0)
            startloop += 1
        alphas, backPtr = inner_loop_vmap(t, jnp.arange(startloop, end), startloop, end, prevIdxOffset, curIdxOffset, alphas, backPtr, logProbs, targets, kNegInfinity, blank)
        
        # find fix for these two lines
        alphas = jnp.squeeze(alphas)
        backPtr = jnp.squeeze(backPtr)

    idx1 = (T - 1) % 2
    ltrIdx = jnp.where(alphas[idx1, S - 1] > alphas[idx1, S - 2], S - 1, S - 2)
    for t in range(T - 1, -1, -1):
        lbl_idx = jnp.where(ltrIdx % 2 == 0, blank, targets[ltrIdx // 2])
        paths = paths.at[t].set(jnp.array(lbl_idx))
        ltrIdx -= backPtr[t, ltrIdx]
    return paths, alphas, backPtr

def compute_jnp(logProbs, targets, inputLengths, targetLengths, blank):
    """
    This function performs computation using jax.numpy.
    
    Args:
        logProbs (jnp.ndarray): The log probabilities.
        targets (jnp.ndarray): The target values.
        inputLengths (int): The input lengths.
        targetLengths (int): The target lengths.
        blank (int): The blank label.
        
    Returns:
        tuple: The output paths and the log probabilities for those paths.
    """
    if not isinstance(logProbs, jnp.ndarray):
        raise ValueError("log_probs must be a jax numpy array")
    if not isinstance(targets, jnp.ndarray):
        raise ValueError("targets must be a jax numpy array")
    if not jnp.issubdtype(logProbs.dtype, jnp.floating):
        raise ValueError("log_probs must be float64, float32 or float16 (half) type")
    if not jnp.issubdtype(targets.dtype, jnp.integer):
        raise ValueError("targets must be int32 or int64 type")
    if len(logProbs.shape) != 2:
        raise ValueError("log_probs must be 2-D (input length, num classes)")
    if len(targets.shape) != 1:
        raise ValueError("targets must be 1-D (target length,)")
    if jnp.ndim(inputLengths) != 0:
        raise ValueError("input_lengths must be 0-D")
    if jnp.ndim(targetLengths) != 0:
        raise ValueError("target_lengths must be 0-D")
    if blank < 0 or blank >= logProbs.shape[-1]:
        raise ValueError("blank must be within [0, num classes)")
    if logProbs.shape[0] != inputLengths:
        raise ValueError("input length mismatch")
    if targets.shape[0] != targetLengths:
        raise ValueError("target length mismatch")
    T = logProbs.shape[0]
    paths = jnp.zeros(T, dtype=targets.dtype)
    paths, alphas, backPtr = forced_align_impl_jnp(logProbs, targets, blank, paths)
    return paths, logProbs[jnp.arange(T), paths]


In [82]:
# Let's test this with some random data
logProbs = jnp.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], [0.9, 1.0, 1.1, 1.2]])
targets = jnp.array([0, 2, 3])
blank = 1

logProbs = jax.nn.log_softmax(logProbs, axis=-1)

input_lengths = jnp.array(logProbs.shape[0])
target_lengths = jnp.array(targets.shape[0])

compute_jnp(logProbs, targets, input_lengths, target_lengths, blank)

Input Shape of alphas: (2, 7), value: [[      -inf -1.5425494       -inf       -inf       -inf       -inf
        -inf]
 [      -inf       -inf       -inf       -inf       -inf       -inf
        -inf]]
Output Shape of alphas: (2, 7), value: Traced<ShapedArray(float32[2,7], weak_type=True)>with<BatchTrace(level=1/0)> with
  val = Array([[[      -inf, -1.5425494,       -inf,       -inf,       -inf,
               -inf,       -inf],
        [      -inf,       -inf,       -inf, -2.8850987,       -inf,
               -inf,       -inf]]], dtype=float32, weak_type=True)
  batch_dim = 0
Input Shape of alphas: (2, 7), value: [[      -inf       -inf       -inf       -inf       -inf       -inf
        -inf]
 [      -inf       -inf       -inf -2.8850987       -inf       -inf
        -inf]]
Output Shape of alphas: (2, 7), value: Traced<ShapedArray(float32[2,7], weak_type=True)>with<BatchTrace(level=1/0)> with
  val = Array([[[      -inf,       -inf,       -inf,       -inf,       -inf,
         -4.

(Array([0, 2, 3], dtype=int32),
 Array([-1.5425494, -1.3425493, -1.2425493], dtype=float32))

In [44]:
targets[0].item()

0

In [83]:
import numpy as np
import zipfile
import os
import torch

# Unzip the provided file
with zipfile.ZipFile("forced_align_input_outputs.zip", 'r') as zip_ref:
    zip_ref.extractall("forced_align_input_outputs")

# Check the contents of the directory
os.listdir("forced_align_input_outputs")


['targets.pt',
 'target_lengths.pt',
 'paths.pt',
 'input_lengths.pt',
 'emissions.pt']

In [84]:
# Load the provided tensors
logProbs = torch.load("forced_align_input_outputs/emissions.pt")
targets = torch.load("forced_align_input_outputs/targets.pt")
inputLengths = torch.load("forced_align_input_outputs/input_lengths.pt")
targetLengths = torch.load("forced_align_input_outputs/target_lengths.pt")
paths = torch.load("forced_align_input_outputs/paths.pt")

# Convert to numpy arrays
logProbs_np = logProbs.numpy()
targets_np = targets.numpy()
inputLengths_np = inputLengths.numpy()
targetLengths_np = targetLengths.numpy()
paths_np = paths.numpy()

blank = 0

logProbs_np, targets_np, inputLengths_np, targetLengths_np, paths_np


(array([[-7.3520811e-03, -2.7093369e+01, -2.7509064e+01, ...,
         -8.7403002e+00, -1.1601903e+01, -1.1159496e+01],
        [-1.3696853e-02, -2.7963503e+01, -2.8412813e+01, ...,
         -9.0036783e+00, -1.1766139e+01, -1.0991915e+01],
        [-3.1517904e-02, -2.9013901e+01, -2.9537155e+01, ...,
         -9.1752405e+00, -1.2215663e+01, -1.1296188e+01],
        ...,
        [-8.4026694e-02, -2.5163204e+01, -2.5194981e+01, ...,
         -7.7668262e+00, -8.7488461e+00, -8.5354090e+00],
        [-8.3044618e-02, -2.5117912e+01, -2.5149914e+01, ...,
         -7.6841059e+00, -8.7871141e+00, -8.6122017e+00],
        [-9.2383891e-02, -2.5157015e+01, -2.5237972e+01, ...,
         -7.7255855e+00, -8.7317905e+00, -8.4821301e+00]], dtype=float32),
 array([12,  4,  4, ..., 19,  4,  4], dtype=int32),
 array(10658),
 array(1648),
 array([0, 0, 0, ..., 0, 0, 0], dtype=int32))

In [97]:
targets_np.shape[0] * 2

3296

In [87]:
paths[415:515]

tensor([ 0,  0,  0,  0,  0,  0,  0,  4,  0,  0,  0,  0,  0, 13,  0,  0,  5,  0,
         0,  5,  0,  0,  7,  0,  0,  0,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0, 16,  0,  0,  0,  0,  0,  0,  0,  0, 16,  0,  0,  4,  0,  0,
         9,  0,  0, 15,  0,  0,  0,  4,  0,  0,  4,  0,  0,  0,  0, 21,  0,  0,
         0,  0,  4,  0,  0,  0,  0,  4,  0,  0, 12,  0,  0,  0,  0,  0, 14,  0,
         0,  4,  0,  0,  0,  0,  0,  0,  0,  0], dtype=torch.int32)

In [101]:
paths_result, _ = compute_jnp(jnp.asarray(logProbs_np), jnp.asarray(targets_np), jnp.asarray(inputLengths_np), jnp.asarray(targetLengths_np), blank)

Input Shape of backPtr: (10658, 3297)
Input Shape of alphas: (2, 3297)
Output Shape of backPtr: (10658, 3297)
Output Shape of alphas: (2, 3297)
Input Shape of backPtr: (3, 10658, 3297)
Input Shape of alphas: (3, 2, 3297)
Output Shape of backPtr: (3, 10658, 3297)
Output Shape of alphas: (3, 2, 3297)
Input Shape of backPtr: (4, 3, 10658, 3297)
Input Shape of alphas: (4, 3, 2, 3297)
Output Shape of backPtr: (4, 3, 10658, 3297)


ValueError: Incompatible shapes for broadcasting: (2, 3297) and requested shape (10658, 3297)

In [None]:
jnp.sum()