In [2]:
import math
import numpy as np

# This is prob_grad.py

def compute_log_p(X, y, W, T):
    """
    Computes log p(y | X; W, T) for a linear-chain CRF.

    The unnormalized score of a labelling y is
        sum_i W[y[i]] . X[i]  +  sum_{i>=1} T[y[i-1], y[i]]
    and the log probability is that score minus log Z, where Z is the sum of
    exp(score) over every possible labelling of the sequence.

    Parameters:
    X : 2D array (n_positions, n_features); row i is the feature vector at position i.
    y : 1D integer array of length n_positions; y[i] is the label at position i.
    W : 2D array (n_labels, n_features); row k holds the weights for label k.
    T : 2D array (n_labels, n_labels); T[i, j] is the transition weight from label i to label j.

    Returns:
    The log probability of the label sequence given the inputs and parameters.
    """
    y = np.asarray(y)
    n_positions = X.shape[0]

    # node_scores[i, k] = W[k] . X[i]
    node_scores = X @ W.T

    # Unnormalized score of the given labelling: node terms + transition terms.
    sum_num = node_scores[np.arange(n_positions), y].sum() + T[y[:-1], y[1:]].sum()

    # Forward recursion in log space. trellisfw[i, k] is the log-sum of scores of
    # all label prefixes that end in label k at position i, EXCLUDING the node
    # score of position i itself (so trellisfw[0] is all zeros).
    trellisfw = np.zeros((n_positions, W.shape[0]))
    for i in range(1, n_positions):
        # messages[j, k]: come from label j at position i-1 and move to label k.
        messages = T + (node_scores[i - 1] + trellisfw[i - 1])[:, np.newaxis]
        # Per-column max keeps exp() from overflowing/underflowing for each k.
        col_max = messages.max(axis=0)
        trellisfw[i] = col_max + np.log(np.exp(messages - col_max).sum(axis=0))

    # log Z: log-sum-exp over labels at the last position, now including that
    # position's own node scores.
    final = node_scores[-1] + trellisfw[-1]
    final_max = np.max(final)
    log_z = final_max + np.log(np.sum(np.exp(final - final_max)))

    return sum_num - log_z

def fb_prob(X, W, T):
    """
    Runs the forward-backward algorithm for a linear-chain CRF in log space.

    Both trellises hold log messages that EXCLUDE the node score W[k] . X[i]
    of the position they are indexed by. With node[i, k] = W[k] . X[i]:
        log p(y_i = k | X) = trellisfw[i, k] + node[i, k] + trellisbw[i, k] - log_z
        log p(y_i = j, y_{i+1} = k | X) = trellisfw[i, j] + node[i, j] + T[j, k]
                                          + node[i+1, k] + trellisbw[i+1, k] - log_z

    Parameters:
    X : 2D array (n_positions, n_features); row i is the feature vector at position i.
    W : 2D array (n_labels, n_features); row k holds the weights for label k.
    T : 2D array (n_labels, n_labels); T[i, j] is the transition weight from label i to label j.

    Returns:
    Tuple (trellisfw, trellisbw, log_z):
      trellisfw : (n_positions, n_labels) forward log messages; row 0 is zeros.
      trellisbw : (n_positions, n_labels) backward log messages; last row is zeros.
      log_z     : log partition function.
    """
    n_positions = X.shape[0]
    alpha_len = W.shape[0]  # number of labels
    node_scores = X @ W.T  # node_scores[i, k] = W[k] . X[i]
    trellisfw = np.zeros((n_positions, alpha_len))
    trellisbw = np.zeros((n_positions, alpha_len))

    # Forward pass:
    # trellisfw[i, k] = logsumexp_j(trellisfw[i-1, j] + node[i-1, j] + T[j, k])
    for i in range(1, n_positions):
        messages = T + (node_scores[i - 1] + trellisfw[i - 1])[:, np.newaxis]
        col_max = messages.max(axis=0)  # per-column max for numerical stability
        trellisfw[i] = col_max + np.log(np.exp(messages - col_max).sum(axis=0))

    # Backward pass (last row stays zero, i.e. log(1)):
    # trellisbw[i, j] = logsumexp_k(T[j, k] + node[i+1, k] + trellisbw[i+1, k])
    for i in range(n_positions - 2, -1, -1):
        messages = T + (node_scores[i + 1] + trellisbw[i + 1])[np.newaxis, :]
        row_max = messages.max(axis=1)  # per-row max for numerical stability
        trellisbw[i] = row_max + np.log(
            np.exp(messages - row_max[:, np.newaxis]).sum(axis=1)
        )

    # Log partition function from the forward trellis, including the last
    # position's node scores (which the trellis itself excludes).
    final = node_scores[-1] + trellisfw[-1]
    final_max = np.max(final)
    log_z = final_max + np.log(np.sum(np.exp(final - final_max)))

    return trellisfw, trellisbw, log_z

# The following functions compute gradients for the weight matrix W and transition matrix T respectively
# given a single example (X, y), where X is the feature matrix for the sequence and y is the corresponding label sequence.

def log_p_wgrad(W, X, y, T):
    """
    Computes the gradient of log p(y | X; W, T) with respect to the weight matrix W.

    For every position i and label k the gradient accumulates
        (1[y[i] == k] - p(y_i = k | X)) * X[i]
    i.e. observed feature counts minus their expectation under the model.

    Parameters:
    W : 2D array (n_labels, n_features); row k holds the weights for label k.
    X : 2D array (n_positions, n_features); row i is the feature vector at position i.
    y : 1D integer array of length n_positions; y[i] is the label at position i.
    T : 2D array (n_labels, n_labels); T[i, j] is the transition weight from label i to label j.

    Returns:
    Gradient with respect to W, same shape as W.
    """
    trellisfw, trellisbw, log_z = fb_prob(X, W, T)
    n_positions = X.shape[0]

    # The trellises exclude each position's own node score, so it must be
    # added back to obtain the node marginals p(y_i = k | X).
    node_scores = X @ W.T
    marginals = np.exp(trellisfw + node_scores + trellisbw - log_z)

    # One-hot encoding of the observed labels.
    observed = np.zeros_like(marginals)
    observed[np.arange(n_positions), np.asarray(y)] = 1.0

    # Sum over positions of (observed - expected) outer X[i].
    grad_W = (observed - marginals).T @ X
    return grad_W

def log_p_tgrad(T, X, y, W):
    """
    Computes the gradient of log p(y | X; W, T) with respect to the transition matrix T.

    For every adjacent pair of positions (i, i+1) and label pair (j, k) the
    gradient accumulates
        1[y[i] == j and y[i+1] == k] - p(y_i = j, y_{i+1} = k | X)
    i.e. observed transition counts minus their expectation under the model.

    Parameters:
    T : 2D array (n_labels, n_labels); T[i, j] is the transition weight from label i to label j.
    X : 2D array (n_positions, n_features); row i is the feature vector at position i.
    y : 1D integer array of length n_positions; y[i] is the label at position i.
    W : 2D array (n_labels, n_features); row k holds the weights for label k.

    Returns:
    Gradient with respect to T, same shape as T.
    """
    grad_T = np.zeros(T.shape)
    trellisfw, trellisbw, log_z = fb_prob(X, W, T)
    node_scores = X @ W.T  # node_scores[i, k] = W[k] . X[i]

    for i in range(X.shape[0] - 1):
        # Log joint marginal of (y_i = j, y_{i+1} = k). Scores combine
        # ADDITIVELY in log space: forward message + node score at i,
        # transition, node score at i+1 + backward message, minus log Z.
        log_pair = (
            (trellisfw[i] + node_scores[i])[:, np.newaxis]
            + T
            + (node_scores[i + 1] + trellisbw[i + 1])[np.newaxis, :]
            - log_z
        )
        grad_T -= np.exp(log_pair)  # expected transition counts
        grad_T[y[i], y[i + 1]] += 1.0  # observed transition
    return grad_T

# Example usage:
# Assuming X, y, W, and T are already loaded
# grad_W = log_p_wgrad(W, X, y, T)
# grad_T = log_p_tgrad(T, X, y, W)
# The gradients can be used in an optimization algorithm to update W and T


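In the backward pass, we step backwards through the positions of the sequence and compute the backward log-messages, which summarize the scores of all label suffixes following each position. Adding the forward log-messages and the node score W[k]·X[i] to them, then subtracting **log_z**, gives the log marginal probability of each label at each position. The **log_z** value is the log partition function: the log of the sum of exp(score) over all possible labelings. It is computed from the forward trellis and serves as the normalization constant that makes the probabilities sum to one.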

The functions **log_p_wgrad** and **log_p_tgrad** compute the gradients of the log probability with respect to the weight matrix W and the transition matrix T, respectively. Each gradient is the observed feature (or transition) counts minus their expected counts under the model. Both functions call **fb_prob**, which runs the forward-backward algorithm to obtain the marginal probabilities needed for those expectations. Before calling them, make sure the inputs have consistent shapes:
- X is (n_positions, n_features).
- y holds integer labels in range(n_labels).
- W is (n_labels, n_features).
- T is (n_labels, n_labels).
