Preprocessing
- Decapitalise everything. Either:
    - lowercase every token after a '.' token or
    - lowercase everything
    - I go with 2nd option because there are other punctuation marks (some of which are quotes) and I don't want to do SBD
    - I need to distinguish between UNKNOWNs and NAMEs which have tag Propernoun

The plan:

- Write script to list out all tokens with less than 0.0001N counts
- Write script to replace all tokens with less than 0.0001N counts with <UNKNOWN>
- Separate the sentences
- Get the counts (according to slide 14/39). Sum over each sentence
    - Transition count of each state pair. Emission count of each state-token pair.
- Estimate original values using slide 14/39
** Separate the sentences into 90-10 split for training and evaluation
- Online EM learn
    - Stepwise EM or whatever the fuck it's called
    - μ=(1−ηk​)μ+ηk​μ′
    - ηk is simply a step statistic: 1/(k+1)^a for iteration k
    - I guess we set alpha to 0.7
** Viterbi EM may not even be necessary? But we can implement it afterwards if it takes too long

In [None]:
import pandas as pd

# Path to your CoNLL-U file
file_path = 'ptb-train.conllu'

# Initialize list to store rows, sentence counter, and token counter
data = []
sentence_id = 0
token_id = 0

with open(file_path, 'r', encoding='utf-8') as file:
    for line in file:
        line = line.strip()
        if not line:  # Empty line indicates new sentence
            sentence_id += 1
            token_id = 0  # Reset token counter for the new sentence
        else:
            token_id += 1  # Increment token id for each token in a sentence
            parts = line.split('\t')
            if len(parts) == 10:
                data.append([sentence_id] + parts)

# Create a DataFrame
df = pd.DataFrame(data, columns=['sentence_id', 'id', 'form', 'blank', 'upos', 'xpos', 'blank', 'head', 'deprel', 'deps', 'blank'])

# Save the DataFrame to a CSV file
csv_path = 'ptb-train.csv'
df.to_csv(csv_path, index=False)

print(f"Data saved to {csv_path}")


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv('ptb-train.csv')  
column_name = 'form'
forms = df[column_name]

# Calculate frequency distribution of words
frequency = forms.value_counts()

# Plot the frequency distribution
plt.figure(figsize=(10, 6))
frequency.plot(kind='line', logy=True)  # log scale for better visibility
plt.title('Frequency Distribution of Words')
plt.xlabel('Words')
plt.ylabel('Frequency (log scale)')
plt.grid(True)
plt.show()

# Output the frequency distribution
print(frequency)

# Check the result or save the updated DataFrame
print(df.head())  # Prints the first few rows of the updated DataFrame
# df.to_csv('updated_file.csv', index=False)  # Uncomment to save the updated DataFrame


In [None]:
import pandas as pd

# Load the CSV file
df = pd.read_csv('ptb-train.csv', keep_default_na=False, na_values=[''])
column_name = 'form'
upos_column = 'upos'

# Lowercase everything
df[column_name] = df[column_name].str.lower()

# Calculate frequency distribution of words
counts = df[column_name].value_counts()

# Apply 'UNKNOWN' or 'NAME' based on frequency and whether the word is a proper noun
threshold = 4
df[column_name] = df.apply(
    lambda row: 'NUM' if row[upos_column] == 'NUM' else
                ('NAME' if row[upos_column] == 'PROPN' else 
                ('UNKNOWN' if counts.get(row[column_name], 0) < threshold else row[column_name])),
    axis=1
)


output = f'ptb-train-{threshold}-all-lower.csv'
# Save to new csv
df.to_csv(output, index=False)
print(f"Updated DataFrame (all lowercase) saved to {output}")


In [2]:
import pandas as pd

def process_sentences(df: pd.DataFrame, id_col = 'id', forms_col = 'form', pos_tags_col = 'upos'):
    output = []
    current_sentence = []
    current_pos = []

    for index, row in df.iterrows():
        if row[id_col] == 1 and current_sentence:
            output.append((current_sentence, current_pos))
            current_sentence, current_pos = [row[forms_col]], [row[pos_tags_col]]
        else:
            current_sentence.append(row[forms_col])
            current_pos.append(row[pos_tags_col])

    # Append the last sentence if it's not empty
    if current_sentence:
        output.append((current_sentence, current_pos))

    return output


def write_sentences_to_file(sentences, filename):
    with open(filename, 'w', encoding='utf-8') as file:
        for sentence, pos_tags in sentences:
            # Writing words and POS tags on separate lines
            file.write('Words: ' + ' '.join(sentence) + '\n')
            file.write('Tags: ' + ' '.join(pos_tags) + '\n')
            file.write('\n')  # Adding a blank line between sentences for clarity

df = pd.read_csv('ptb-train-4-all-lower.csv')
processed = process_sentences(df)
write_sentences_to_file(processed, 'output_sentences.txt')



In [3]:
import numpy as np
import pandas as pd

df = pd.read_csv('ptb-train-4-all-lower.csv')
processed_sentences, processed_tags = [x[0] for x in processed], [x[1] for x in processed]

unique_words = sorted(df['form'].unique())
unique_words_dict = {w: i for (i, w) in enumerate(unique_words)}
unique_upos = sorted(df['upos'].unique())
unique_upos_dict = {t: i for (i, t) in enumerate(unique_upos)}

M = len(unique_words)
N = len(unique_upos)

print(unique_words)
print(unique_upos)


def make_transition_matrix():

    transition_matrix = np.zeros((N, N))

    transition_counts = {key: {k: 0 for k in unique_upos_dict.keys()} for key in unique_upos_dict.keys()}

    for sequence in processed_tags:
        for i in range(len(sequence) - 1):
            current_tag = sequence[i]
            next_tag = sequence[i + 1]

            if current_tag in unique_upos_dict and next_tag in unique_upos_dict:
                transition_counts[current_tag][next_tag] += 1

    for a in unique_upos_dict:
        total_transitions_from_a = sum(transition_counts[a].values())

        if total_transitions_from_a > 0:
            for b in unique_upos_dict:
                transition_matrix[unique_upos_dict[a], unique_upos_dict[b]] = transition_counts[a][b] / total_transitions_from_a

    print(transition_matrix)
    return transition_matrix

def make_emission_matrix():

    emission_matrix = np.zeros((N, M))

    emission_counts = {key: {tag: 0 for tag in unique_words_dict.keys()} for key in unique_upos_dict.keys()}

    for j in range(len(processed_tags)):
        sentence = processed_sentences[j]
        sequence = processed_tags[j]
        for i in range(len(sequence) - 1):
            current_tag = sequence[i]
            current_emission = sentence[i]

            if current_tag in unique_upos_dict and current_emission in unique_words_dict:
                emission_counts[current_tag][current_emission] += 1

    for a in unique_upos_dict:
        total_emissions_from_a = sum(emission_counts[a].values())

        if total_emissions_from_a > 0:
            for b in unique_words_dict:
                emission_matrix[unique_upos_dict[a], unique_words_dict[b]] = emission_counts[a][b] / total_emissions_from_a

    print(emission_matrix)
    return emission_matrix

def make_initial():
    initial_probabilities = np.zeros(len(unique_upos_dict))
    initial_counts = {tag: 0 for tag in unique_upos_dict.keys()}

    # Total number of sentences
    total_sentences = len(processed_tags)

    # Count each tag appearing as the first element in the tag sequences
    for tags in processed_tags:
        if tags:  # Check it exists
            initial_tag = tags[0]
            if initial_tag in initial_counts:
                initial_counts[initial_tag] += 1

    # Convert counts to probabilities
    for tag, index in unique_upos_dict.items():
        initial_probabilities[index] = initial_counts[tag] / total_sentences if total_sentences > 0 else 0

    return initial_probabilities

t_matrix, e_matrix, initial = make_transition_matrix(), make_emission_matrix(), make_initial()


['ADJ', 'ADP', 'ADV', 'AUX', 'CONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']
[[7.50636324e-02 7.03194421e-02 4.92492131e-03 1.64164044e-03
  1.71393286e-02 4.05138786e-03 1.50609214e-05 6.65557179e-01
  1.66573791e-02 1.76965827e-02 2.42480835e-03 3.71552932e-02
  6.51987289e-02 1.08438634e-02 3.02724521e-03 8.16301941e-03
  1.20487371e-04]
 [1.03085662e-01 2.52069652e-02 1.50520119e-02 3.81455097e-04
  1.58767797e-03 3.31226739e-01 2.06191944e-05 1.70376403e-01
  7.79921028e-02 1.10312690e-03 5.68058806e-02 1.52293370e-01
  1.24024454e-02 9.17554151e-04 4.01765003e-02 1.11446746e-02
  2.26811138e-04]
 [1.42092442e-01 1.24975960e-01 7.67677415e-02 3.60279505e-02
  1.06737611e-02 6.29207000e-02 0.00000000e+00 2.16680556e-02
  3.90409642e-02 1.68921085e-02 3.36560036e-02 1.40393615e-02
  1.64561831e-01 2.79505096e-02 1.25649080e-02 2.16103596e-01
  6.41066735e-05]
 [2.23435948e-03 3.01461200e-03 7.15704355e-02 8.64661654e-02
  3.901262

In [4]:
print(t_matrix[unique_upos_dict['NOUN']][unique_upos_dict['ADV']])

0.027215504655354057


In [7]:
# No end probabilities, so I just don't calculate them
def forward(tags, word_sequence, initial, words_dict, transition, emission):
    # node values stored during forward algorithm
    node_values_fwd = np.zeros((len(tags), len(word_sequence)))

    # i is index of observed sequence, k is across hidden states
    for i, word in enumerate(word_sequence):
        for j in range(len(tags)):
            # if first sequence value
            if (i == 0):
                node_values_fwd[j, i] = initial[j] * emission[j, words_dict[word]]
            else:
                values = [node_values_fwd[k, i - 1] * emission[j, words_dict[word]] 
                          * transition[k, j] for k in range(len(tags))]
                node_values_fwd[j, i] = sum(values)

    forward_val = sum(node_values_fwd[:, -1])
    return node_values_fwd, forward_val


# No end probabilities, so I assume the prob of landing on any hidden state last is 1
def backward(tags, word_sequence, initial, words_dict, transition, emission):
    # node values stored during forward algorithm
    backward_vals = np.zeros((len(tags), len(word_sequence)))

    #for i, sequence_val in enumerate(test_sequence):
    for i in range(1,len(word_sequence)+1):
        for j in range(len(tags)):
            # if first sequence value then do this
            if (-i == -1):
                backward_vals[j, -i] = 1
            else:
                values = [backward_vals[k, -i+1] * emission[k, words_dict[word_sequence[-i+1]]] * transition[j, k] for k in range(len(tags))]
                backward_vals[j, -i] = sum(values)

    start_state = [backward_vals[m,0] * emission[m, words_dict[word_sequence[0]]] for m in range(len(tags))]
    start_state = np.multiply(start_state, initial)
    backward_val = sum(start_state)
    return backward_vals, backward_val


#function to find si probabilities
def si_probs(states, word_sequence, forward, backward, forward_val, words_dict, transition, emission):

    si_probabilities = np.zeros((len(states), len(word_sequence)-1, len(states)))

    # i is observed state index
    # We are going from hidden state indexes j to k, at times i to i+1
    for i in range(len(word_sequence)-1):
        for j in range(len(states)):
            for k in range(len(states)):
                si_probabilities[j,i,k] = ( forward[j,i] * backward[k,i+1] * transition[j,k] * emission[k,words_dict[word_sequence[i+1]]] ) / forward_val
    return si_probabilities

#function to find gamma probabilities
# forward_val 
def gamma_probs(tags, test_sequence, forward, backward, forward_val):

    gamma_probabilities = np.zeros((len(tags), len(test_sequence)))

    for i in range(len(test_sequence)):
        for j in range(len(tags)):
            gamma_probabilities[j, i] = (forward[j, i] * backward[j, i]) / forward_val

    return gamma_probabilities


def baum(target_sequence, transition, emission, initial, tags, tags_dict, words, words_dict):

    fwd_probs, fwd_val = forward(tags, target_sequence, initial, words_dict, transition, emission)
    bwd_probs, bwd_val = backward(tags, target_sequence, initial, words_dict, transition, emission)
    si_probabilities = si_probs(tags, target_sequence, fwd_probs, bwd_probs, fwd_val, words_dict, transition, emission)
    gamma_probabilities = gamma_probs(tags, target_sequence, fwd_probs, bwd_probs, fwd_val)
    word_indices = np.array([words_dict[word] for word in target_sequence])

    #caclculating 'a' and 'b' matrices
    a = np.zeros((len(tags), len(tags)))
    b = np.zeros((len(tags), len(words_dict)))

    #'a' matrix
    for j in range(len(tags)):
        for i in range(len(tags)):
            for t in range(len(target_sequence)-1):
                a[j,i] = a[j,i] + si_probabilities[j,t,i]

            denom_a = [si_probabilities[j, t_x, i_x] for t_x in range(len(target_sequence) - 1) for i_x in range(len(tags))]
            denom_a = sum(denom_a)

            if (denom_a == 0):
                a[j,i] = 0
            else:
                a[j,i] = a[j,i]/denom_a


    #'b' matrix
    for j in range(len(tags)):
        for i in range(len(words)): 
            indices = [idx for idx, val in enumerate(target_sequence) if val == words[i]]
            numerator_b = sum( gamma_probabilities[j,indices] )
            denomenator_b = sum( gamma_probabilities[j,:] )

            if (denomenator_b == 0):
                b[j,i] = 0
            else:
                b[j, i] = numerator_b / denomenator_b


    print('\nMatrix a:\n')
    print(np.matrix(a.round(decimals=4)))
    print('\nMatrix b:\n')
    print(np.matrix(b.round(decimals=4)))

    return a, b

In [None]:
# import numpy as np

# def baum_welch_stepwise(sentences, initial_transition, initial_emission, initial_prob, tags, tags_dict, words, words_dict, alpha=0.7, max_iterations=100, convergence_threshold=0.01):
#     transition = np.copy(initial_transition)
#     emission = np.copy(initial_emission)
    
#     iteration = 0
#     converged = False
    
#     while iteration < max_iterations and not converged:
#         total_trans_update = np.zeros_like(transition)
#         total_emiss_update = np.zeros_like(emission)
#         learning_rate = 1 / ((iteration + 1) ** alpha)
        
#         for sentence in sentences:
#             a, b = baum(sentence, transition, emission, initial_prob, tags, tags_dict, words, words_dict)
            
#             # Stepwise EM update
#             total_trans_update += a
#             total_emiss_update += b
        
#         # Averaging updates over all sentences
#         avg_trans_update = total_trans_update / len(sentences)
#         avg_emiss_update = total_emiss_update / len(sentences)
        
#         # Update matrices
#         new_transition = (1 - learning_rate) * transition + learning_rate * avg_trans_update
#         new_emission = (1 - learning_rate) * emission + learning_rate * avg_emiss_update
        
#         # Check for convergence (Frobenius norm of difference)
#         if np.linalg.norm(new_transition - transition) < convergence_threshold and \
#            np.linalg.norm(new_emission - emission) < convergence_threshold:
#             converged = True
        
#         transition = new_transition
#         emission = new_emission
#         iteration += 1

#     return transition, emission


# def baum_welch_stepwise(sentences, initial_transition, initial_emission, initial_prob, tags, tags_dict, words, words_dict, alpha=0.7, max_iterations=100, convergence_threshold=0.01):
#     transition = np.copy(initial_transition)
#     emission = np.copy(initial_emission)
    
#     iteration = 0
#     converged = False
    
#     while iteration < max_iterations and not converged:
#         prev_transition = np.copy(transition)
#         prev_emission = np.copy(emission)

#         for sentence in sentences:
#             a, b = baum(sentence, transition, emission, initial_prob, tags, tags_dict, words, words_dict)
            
#             learning_rate = 1 / ((iteration + 1) ** alpha)
#             # Immediate EM update per sentence
#             transition = (1 - learning_rate) * transition + learning_rate * a
#             emission = (1 - learning_rate) * emission + learning_rate * b
        
#         # Check for convergence (Frobenius norm of difference)
#         if np.linalg.norm(transition - prev_transition) < convergence_threshold and \
#            np.linalg.norm(emission - prev_emission) < convergence_threshold:
#             converged = True

#         iteration += 1

#     return transition, emission

# log_transition, log_emission = baum_welch_stepwise(processed_sentences, t_matrix, e_matrix, log_initial, unique_upos, unique_upos_dict, unique_words, unique_words_dict)

In [12]:
# Transition and Emission are already log
# def forward_log(tags, word_sequence, initial, words_dict, transition, emission):
#     node_values_fwd = np.full((len(tags), len(word_sequence)), -np.inf)  # Use -inf for log(0)

#     for i, word in enumerate(word_sequence):
#         for j in range(len(tags)):
#             if i == 0:
#                 node_values_fwd[j, i] = initial[j] + emission[j, words_dict[word]]
#             else:
#                 log_values = [node_values_fwd[k, i - 1] + emission[j, words_dict[word]] + transition[k, j]
#                               for k in range(len(tags))]
#                 node_values_fwd[j, i] = log_sum_exp(np.array(log_values))

#     forward_val = log_sum_exp(node_values_fwd[:, -1])
#     return node_values_fwd, forward_val

# def backward_log(tags, word_sequence, initial, words_dict, transition, emission):
#     backward_vals = np.full((len(tags), len(word_sequence)), -np.inf)  # Use -inf for log(0)

#     for i in range(1, len(word_sequence) + 1):
#         for j in range(len(tags)):
#             if i == 1:
#                 backward_vals[j, -i] = 0  # log(1) is 0
#             else:
#                 log_values = [backward_vals[k, -i + 1] + emission[k, words_dict[word_sequence[-i + 1]]] + transition[j, k]
#                               for k in range(len(tags))]
#                 backward_vals[j, -i] = log_sum_exp(np.array(log_values))

#     start_state = np.array([backward_vals[m, 0] + emission[m, words_dict[word_sequence[0]]] + initial[m]
#                             for m in range(len(tags))])
#     backward_val = log_sum_exp(start_state)
#     return backward_vals, backward_val


# def log_si_probs_vec(tags, word_sequence, log_forward, log_backward, log_forward_val, words_dict, transition, emission):
#     si_probabilities = np.full((len(tags), len(word_sequence)-1, len(tags)), -np.inf)
#     log_transition = np.log(transition)
#     log_emission = np.log(emission)

#     for i in range(len(word_sequence)-1):
#         for j in range(len(tags)):
#             for k in range(len(tags)):
#                 si_probabilities[j, i, k] = (
#                     log_forward[j, i] +
#                     log_backward[k, i+1] +
#                     log_transition[j, k] +
#                     log_emission[k, words_dict[word_sequence[i+1]]] -
#                     log_forward_val
#                 )
#     return si_probabilities



# def log_gamma_probs(tags, word_sequence, log_forward, log_backward, log_forward_val):
#     gamma_probabilities = np.full((len(tags), len(word_sequence)), -np.inf)

#     for i in range(len(word_sequence)):
#         for j in range(len(tags)):
#             gamma_probabilities[j, i] = log_forward[j, i] + log_backward[j, i] - log_forward_val

#     return gamma_probabilities





# def baum_log(sentence, log_transition, log_emission, log_initial_prob, tags, tags_dict, words, words_dict):
#     # Calculate forward and backward probabilities in log space
#     log_fwd_probs, log_fwd_val = forward_log(tags, sentence, log_initial_prob, words_dict, log_transition, log_emission)
#     log_bwd_probs, log_bwd_val = backward_log(tags, sentence, log_initial_prob, words_dict, log_transition, log_emission)

#     # Calculate Si and Gamma probabilities in log space
#     log_si_probabilities = log_si_probs_vec(tags, sentence, log_fwd_probs, log_bwd_probs, log_fwd_val, words_dict, np.exp(log_transition), np.exp(log_emission))
#     log_gamma_probabilities = log_gamma_probs_vec(tags, sentence, log_fwd_probs, log_bwd_probs, log_fwd_val)

#     # Calculate updated transition and emission matrices in log space
#     log_a = np.full((len(tags), len(tags)), -np.inf)
#     log_b = np.full((len(tags), len(words_dict)), -np.inf)

#     # Update 'a' matrix in log space
#     for j in range(len(tags)):
#         for k in range(len(tags)):
#             sum_si = -np.inf  # Log of zero for initialization
#             for t in range(len(sentence) - 1):
#                 sum_si = np.logaddexp(sum_si, log_si_probabilities[j, t, k])
#             denom_a = np.logaddexp.reduce(log_gamma_probabilities[j, :-1])  # Skip last observation
#             log_a[j, k] = sum_si - denom_a if denom_a > -np.inf else -np.inf

    # Update 'b' matrix in log space
    for j in range(len(tags)):
        for word, idx in words_dict.items():
            log_numerator = -np.inf  # Log of zero for initialization
            indices = [i for i, x in enumerate(sentence) if x == word]
            if indices:
                log_numerator = np.logaddexp.reduce([log_gamma_probabilities[j, i] for i in indices])
            log_denominator = np.logaddexp.reduce(log_gamma_probabilities[j, :])
            log_b[j, idx] = log_numerator - log_denominator if log_denominator > -np.inf else -np.inf

#     return log_a, log_b

# def safe_log(x):
#     return np.log(x + 1e-10)  # Adding a small constant to avoid log(0)

# def baum_welch_stepwise_loggers(sentences, initial_transition, initial_emission, initial_prob, tags, tags_dict, words, words_dict, alpha=0.7, max_iterations=100, convergence_threshold=0.01):
#     log_transition = safe_log(initial_transition)
#     log_emission = safe_log(initial_emission)
#     log_initial_prob = safe_log(initial_prob)

#     iteration = 0
#     converged = False

#     while iteration < max_iterations and not converged:
#         # print(log_transition)
#         # print(log_emission)
#         prev_transition = np.copy(log_transition)
#         prev_emission = np.copy(log_emission)

#         for sentence in sentences:
#             # print(sentence)
#             log_a, log_b = baum_log(sentence, log_transition, log_emission, log_initial_prob, tags, tags_dict, words, words_dict)
#             learning_rate = 1 / ((iteration + 1) ** alpha)
#             # Immediate EM update per sentence using weighted average in log space
#             log_transition = np.logaddexp(log_transition, np.log(learning_rate) + log_a + np.log(1 - learning_rate))
#             log_emission = np.logaddexp(log_emission, np.log(learning_rate) + log_b + np.log(1 - learning_rate))
        
#         # Check for convergence (Frobenius norm of difference)
#         if np.linalg.norm(np.exp(log_transition) - np.exp(prev_transition)) < convergence_threshold and \
#            np.linalg.norm(np.exp(log_emission) - np.exp(prev_emission)) < convergence_threshold:
#             converged = True

#         iteration += 1

#     # Convert back to probabilities if needed for interpretation
#     transition = np.exp(log_transition)
#     emission = np.exp(log_emission)

#     return transition, emission

# log_transition, log_emission = baum_welch_stepwise_loggers(processed_sentences, t_matrix, e_matrix, log_initial, unique_upos, unique_upos_dict, unique_words, unique_words_dict)

In [5]:
from scipy.special import logsumexp


# def log_sum_exp(log_probs):
#     max_log_prob = np.max(log_probs)
#     if max_log_prob == -np.inf:
#         return -np.inf  # Return -inf if all values were -inf
    
#     sum_exp = np.sum(np.exp(log_probs - max_log_prob))
#     if sum_exp == 0:
#         return -np.inf  # Return -inf if the sum of exponentials is zero (should not happen unless inputs are incorrect)
    
#     return max_log_prob + np.log(sum_exp)

# Transition and Emission are already log
def forward_log_vec(num_tags, word_indices, log_initial, log_transition, log_emission_sentence):
    num_words = len(word_indices)
    log_probs = np.full((num_tags, num_words), -np.inf)  # log(0) = -inf
    
    log_probs[:, 0] = log_initial + log_emission_sentence[0]
    for t in range(1, num_words):
        log_probs[:, t] = logsumexp(log_probs[:, t-1].reshape(-1, 1) + log_transition, axis=0) + log_emission_sentence[t]

    forward_val = logsumexp(log_probs[:, -1])
    return log_probs, forward_val


def backward_log_vec(num_tags, word_indices, log_transition, log_emission_sentence):
    num_words = len(word_indices)
    backward_vals = np.full((num_tags, num_words), -np.inf)
    
    backward_vals[:, -1] = 0  
    for t in range(num_words-2, -1, -1):
        # m = backward_vals[:, t + 1] + log_emission_sentence[t + 1]
        # for j in range(num_tags):
        #     # log_values = (
        #     #     backward_vals[:, t + 1] +
        #     #     emission[:, word_indices[t + 1]] +
        #     #     transition[j, :]
        #     # )
        #     log_values = m + log_transition[j, :]
        #     backward_vals[j, t] = logsumexp(log_values)
        backward_vals[:, t] = logsumexp(
        backward_vals[:, t + 1] + log_emission_sentence[t + 1] + log_transition.T,
        axis=1
        )
    return backward_vals


def log_si_probs_vec(word_indices, log_forward, log_backward, log_forward_val, log_transition, log_emission):
    next_word_indices = word_indices[1:]

    return (
        log_forward[:, :-1, np.newaxis] +
        log_backward[:, 1:].T[np.newaxis, :] +
        log_transition[:, np.newaxis, :] + 
        log_emission[:, next_word_indices].T[np.newaxis, :] -
        log_forward_val
    )



def baum_log(target_sentence, log_transition, log_emission, log_initial, tags, words_dict):

    word_indices = np.array([words_dict[word] for word in target_sentence])
    num_tags = len(tags)
    log_emission_sentence = log_emission[:, word_indices]
    
    log_fwd_probs, log_fwd_val = forward_log_vec(num_tags, word_indices, log_initial, log_transition, log_emission_sentence)
    log_bwd_probs = backward_log_vec(num_tags, word_indices, log_transition, log_emission_sentence)

    log_si_probabilities = log_si_probs_vec(word_indices, log_fwd_probs, log_bwd_probs, log_fwd_val, log_transition, log_emission)
    # No more function for gamma, too simple
    log_gamma_probabilities = log_fwd_probs + log_bwd_probs - log_fwd_val

    # Calculate updated transition and emission matrices in log space
    log_a = np.full((num_tags, num_tags), -np.inf)
    log_b = np.full((num_tags, len(word_indices)), -np.inf)

    # # Update 'a' matrix in log space
    # for j in range(num_tags):
    #     for k in range(num_tags):
    #         sum_si = -np.inf  # Log of zero for initialization
    #         for t in range(len(word_indices) - 1):
    #             sum_si = np.logaddexp(sum_si, log_si_probabilities[j, t, k])
    #         denom_a = np.logaddexp.reduce(log_gamma_probabilities[j, :-1])  # Skip last observation
    #         log_a[j, k] = sum_si - denom_a if denom_a > -np.inf else -np.inf

    # # Update 'b' matrix in log space
    # for j in range(num_tags):
    #     for word, idx in words_dict.items():
    #         log_numerator = -np.inf  # Log of zero for initialization
    #         indices = [i for i, x in enumerate(word_indices) if x == word]
    #         if indices:
    #             log_numerator = np.logaddexp.reduce([log_gamma_probabilities[j, i] for i in indices])
    #         log_denominator = np.logaddexp.reduce(log_gamma_probabilities[j, :])
    #         log_b[j, idx] = log_numerator - log_denominator if log_denominator > -np.inf else -np.inf

    # a matrix
    sum_si_matrix = np.apply_along_axis(logsumexp, 1, log_si_probabilities)
    denom_a_vector = np.apply_along_axis(logsumexp, 1, log_gamma_probabilities[:, :-1])
    log_a = sum_si_matrix - denom_a_vector[:, np.newaxis]  # Broadcast denom_a across axis for each k
    log_a[denom_a_vector <= -np.inf, :] = -np.inf  # Handle cases where denom_a is -inf

    # b matrix
    word_indices_array = np.array(word_indices)
    log_b = np.full((num_tags, len(words_dict)), -np.inf)  # Initialize log_b with -inf

    
    for word, idx in words_dict.items():
        # Find indices where the word matches in the sentence
        indices = np.where(np.array(target_sentence) == word)[0]
        if indices.size > 0:
            log_numerator = np.array([
                logsumexp(row[indices]) for row in log_gamma_probabilities
            ])
        else:
            log_numerator = np.full(len(tags), -np.inf)  # Log of zero if no indices match
        
        log_denominator = np.array([logsumexp(row) for row in log_gamma_probabilities])
        
        # Compute log_b for the current word across all tags
        log_b[:, idx] = np.where(
            log_denominator > -np.inf,  # Only compute where denominator is valid
            log_numerator - log_denominator,
            -np.inf  # Handle invalid cases where denominator is -inf
        )
    
    # for word, idx in words_dict.items():
    #     # print(word, idx)
    #     indices = np.where(word_indices_array == idx)[0]
    #     if indices.size > 0:
    #         log_numerator = log_sum_exp(log_gamma_probabilities[:, indices])
    #         log_denominator = log_sum_exp(log_gamma_probabilities[:, :])
    #         print(log_numerator)
    #         print(log_denominator)
    #         print(log_b.shape)
    #         log_b[:, idx] = log_numerator - log_denominator
    #         log_b[log_denominator <= -np.inf, idx] = -np.inf  # Handle cases where log_denominator is -inf
    
    return log_a, log_b

# def safe_log(x):
#     return np.log(x + 1e-10)  # Adding a small constant to avoid log(0)

def baum_welch_stepwise_loggers(sentences, initial_transition, initial_emission, initial_prob, tags, words_dict, alpha=0.7, max_iterations=100, convergence_threshold=0.01):
    log_transition = np.log(initial_transition)
    log_emission = np.log(initial_emission)
    log_initial = np.log(initial_prob)

    iteration = 0
    converged = False

    while iteration < max_iterations and not converged:
        # print(log_transition)
        # print(log_emission)
        prev_transition = np.copy(log_transition)
        prev_emission = np.copy(log_emission)

        for sentence in sentences:
            print(sentence)
            log_a, log_b = baum_log(sentence, log_transition, log_emission, log_initial, tags, words_dict)
            learning_rate = 1 / ((iteration + 1) ** alpha)
            # Immediate EM update per sentence using weighted average in log space
            log_transition = np.logaddexp(log_transition, np.log(learning_rate) + log_a + np.log(1 - learning_rate))
            log_emission = np.logaddexp(log_emission, np.log(learning_rate) + log_b + np.log(1 - learning_rate))
        
        # Check for convergence (Frobenius norm of difference)
        if np.linalg.norm(np.exp(log_transition) - np.exp(prev_transition)) < convergence_threshold and \
           np.linalg.norm(np.exp(log_emission) - np.exp(prev_emission)) < convergence_threshold:
            converged = True

        iteration += 1

    # Convert back to probabilities if needed for interpretation
    transition = np.exp(log_transition)
    emission = np.exp(log_emission)

    return transition, emission

transition, emission = baum_welch_stepwise_loggers(processed_sentences[:4000], t_matrix, e_matrix, initial, unique_upos, unique_words_dict)

  log_transition = np.log(initial_transition)
  log_emission = np.log(initial_emission)


['in', 'an', 'NAME', 'NUM', 'review', 'of', '``', 'the', 'UNKNOWN', "''", 'at', 'NAME', "'s", 'NAME', 'NAME', '-lrb-', '``', 'UNKNOWN', 'UNKNOWN', 'take', 'the', 'stage', 'in', 'NAME', 'NAME', ',', "''", 'leisure', '&', 'arts', '-rrb-', ',', 'the', 'role', 'of', 'NAME', ',', 'played', 'by', 'NAME', 'NAME', ',', 'was', 'mistakenly', 'attributed', 'to', 'NAME', 'NAME', '.']


  log_transition = np.logaddexp(log_transition, np.log(learning_rate) + log_a + np.log(1 - learning_rate))
  log_emission = np.logaddexp(log_emission, np.log(learning_rate) + log_b + np.log(1 - learning_rate))


['NAME', 'NAME', 'plays', 'NAME', '.']
['NAME', 'NAME', 'NAME', 'NAME', 'said', 'it', 'expects', 'its', 'NAME', 'sales', 'to', 'remain', 'steady', 'at', 'about', 'NUM', 'cars', 'in', 'NUM', '.']
['the', 'luxury', 'auto', 'maker', 'last', 'year', 'sold', 'NUM', 'cars', 'in', 'the', 'NAME']
['NAME', 'NAME', ',', 'president', 'and', 'chief', 'executive', 'officer', ',', 'said', 'he', 'anticipates', 'growth', 'for', 'the', 'luxury', 'auto', 'maker', 'in', 'NAME', 'and', 'NAME', ',', 'and', 'in', 'far', 'eastern', 'markets', '.']
['NAME', 'NAME', 'NAME', 'increased', 'its', 'quarterly', 'to', 'NUM', 'cents', 'from', 'NUM', 'cents', 'a', 'share', '.']
['the', 'new', 'rate', 'will', 'be', 'payable', 'NAME', 'NUM', '.']
['a', 'record', 'date', 'has', "n't", 'been', 'set', '.']


KeyboardInterrupt: 