In [1]:
import numpy as np
from scipy.stats import dirichlet, poisson
import pickle
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
from scipy.special import logsumexp
from scipy.optimize import linear_sum_assignment
from scipy.ndimage import gaussian_filter1d

In [2]:
#ensembl_trx = data.ensembl_trx()
#pickle.dump(ensembl_trx, open('data/ensembl_trx.pkl', 'wb'))
ensembl_trx = pickle.load(open('ensembl_trx.pkl', 'rb'))

In [3]:
#trx_orfs = data.trx_orfs(ensembl_trx)
#pickle.dump(trx_orfs, open('data/trx_orfs.pkl', 'wb'))
trx_orfs = pickle.load(open('trx_orfs.pkl', 'rb'))

In [4]:
for trx, attrs in tqdm(ensembl_trx.items()):
    for orf, info in trx_orfs[trx].items():
        if orf.startswith('ENSP'):
            zero_arr = np.zeros(len(ensembl_trx[trx]['sequence']))
            start, stop = info['start'], info['stop']
            zero_arr[start:stop] = 1
            ensembl_trx[trx]['target'] = zero_arr

100%|████████████████████████████████████████████████████████████████████████| 206497/206497 [00:11<00:00, 18433.94it/s]


In [5]:
# Helper functions
def base_to_index(base):
    return {'A': 0, 'C': 1, 'G': 2, 'T': 3}[base]

def sequence_to_indices(sequence):
    return np.array([base_to_index(b) for b in sequence])

# Prior distributions
def sample_r_prior():
    return max(1, poisson.rvs(3, size=1)[0] % (MAX_STATES + 1))

def sample_q_prior():
    return poisson.rvs(1, size=1)[0] % (MAX_ORDER + 1)

# Initialize model parameters
def initialize_parameters(n_states, n_bases, q):
    Lambda = dirichlet.rvs(np.ones(n_states), size=n_states)
    P = dirichlet.rvs(np.ones(n_bases**(q+1)), size=n_states)
    return Lambda, P

def forward_multi(sequences, Lambda, P, q):
    n_states, n_emissions = P.shape
    total_log_likelihood = 0
    all_forward = []
    all_scale = []
    
    for sequence in sequences:
        n_obs = len(sequence)
        F = np.zeros((n_states, n_obs))
        scale = np.zeros(n_obs)
        
        # Initialize
        initial_state = tuple(sequence[:q])
        for i in range(n_states):
            emission_index = np.ravel_multi_index(initial_state + (sequence[q],), (4,)*(q+1))
            F[i, q-1] = P[i, emission_index] / n_states
        scale[q-1] = np.sum(F[:, q-1])
        F[:, q-1] /= scale[q-1]
        
        # Recurse
        for t in range(q, n_obs):
            state = tuple(sequence[t-q:t+1])
            for j in range(n_states):
                emission_index = np.ravel_multi_index(state, (4,)*(q+1))
                F[j, t] = P[j, emission_index] * np.sum(F[:, t-1] * Lambda[:, j])
            scale[t] = np.sum(F[:, t])
            F[:, t] /= scale[t]
        
        total_log_likelihood += np.sum(np.log(scale))
        all_forward.append(F)
        all_scale.append(scale)
    
    return total_log_likelihood, all_forward, all_scale

def backward_multi(sequences, Lambda, P, q, all_scale):
    n_states, n_emissions = P.shape
    all_backward = []
    
    for sequence, scale in zip(sequences, all_scale):
        n_obs = len(sequence)
        B = np.zeros((n_states, n_obs))
        
        # Initialize
        B[:, -1] = 1 / scale[-1]
        
        # Recurse
        for t in range(n_obs-2, q-2, -1):
            state = tuple(sequence[t-q+1:t+2])
            for i in range(n_states):
                emission_index = np.ravel_multi_index(state, (4,)*(q+1))
                B[i, t] = np.sum(B[:, t+1] * Lambda[i, :] * P[:, emission_index])
            B[:, t] /= scale[t]
        
        all_backward.append(B)
    
    return all_backward

def calculate_posterior_probs(all_forward, all_backward):
    all_posterior_probs = []
    for F, B in zip(all_forward, all_backward):
        posterior_probs = F * B
        posterior_probs /= np.sum(posterior_probs, axis=0)
        all_posterior_probs.append(posterior_probs)
    return all_posterior_probs

def update_q_multi(sequences, r, q, Lambda, P):
    log_probs = np.zeros(MAX_ORDER + 1)
    for new_q in range(MAX_ORDER + 1):
        new_P = dirichlet.rvs(np.ones(4**(new_q+1)), size=r)
        log_probs[new_q], _, _ = forward_multi(sequences, Lambda, new_P, new_q)
    
    new_q = np.random.choice(MAX_ORDER + 1, p=np.exp(log_probs - logsumexp(log_probs)))
    if new_q != q:
        P = dirichlet.rvs(np.ones(4**(new_q+1)), size=r)
    
    return new_q, P

def rjmcmc_r_multi(sequences, r, q, Lambda, P):
    if np.random.random() < 0.5:  # Attempt birth
        if r < MAX_STATES:
            new_r = r + 1
            new_Lambda = np.zeros((new_r, new_r))
            new_Lambda[:r, :r] = Lambda
            new_Lambda[r, :] = dirichlet.rvs(np.ones(new_r))
            new_Lambda[:, r] = dirichlet.rvs(np.ones(new_r))
            new_P = np.vstack((P, dirichlet.rvs(np.ones(4**(q+1)))))
            
            # Calculate acceptance probability
            log_A, _, _ = forward_multi(sequences, new_Lambda, new_P, q)
            log_A -= forward_multi(sequences, Lambda, P, q)[0]
            log_A += np.log(new_r) - np.log(r + 1)  # Prior ratio and proposal ratio
            
            if np.log(np.random.random()) < log_A:
                return new_r, new_Lambda, new_P
    else:  # Attempt death
        if r > 1:
            new_r = r - 1
            j = np.random.randint(r)
            new_Lambda = np.delete(np.delete(Lambda, j, axis=0), j, axis=1)
            new_P = np.delete(P, j, axis=0)
            
            # Calculate acceptance probability
            log_A, _, _ = forward_multi(sequences, new_Lambda, new_P, q)
            log_A -= forward_multi(sequences, Lambda, P, q)[0]
            log_A += np.log(r) - np.log(new_r + 1)  # Prior ratio and proposal ratio
            
            if np.log(np.random.random()) < log_A:
                return new_r, new_Lambda, new_P
    
    return r, Lambda, P

def run_mcmc_multi(sequences, n_iter, n_burn_in, thin=1):
    r = sample_r_prior()
    q = sample_q_prior()
    Lambda, P = initialize_parameters(r, 4, q)
    
    samples = []
    for i in tqdm(range(n_iter + n_burn_in)):
        # Update q and P
        q, P = update_q_multi(sequences, r, q, Lambda, P)
        
        # Update Lambda
        for j in range(r):
            Lambda[j, :] = dirichlet.rvs(Lambda[j, :] + 1)
        
        # Reversible jump for r
        r, Lambda, P = rjmcmc_r_multi(sequences, r, q, Lambda, P)
        
        # Store samples after burn-in period and apply thinning
        if i >= n_burn_in and (i - n_burn_in) % thin == 0:
            log_likelihood, all_forward, all_scale = forward_multi(sequences, Lambda, P, q)
            all_backward = backward_multi(sequences, Lambda, P, q, all_scale)
            all_posterior_probs = calculate_posterior_probs(all_forward, all_backward)
            viterbi_paths = viterbi_multi(sequences, Lambda, P, q)
            samples.append((r, q, Lambda.copy(), P.copy(), all_posterior_probs, viterbi_paths))
    
    return samples

def viterbi_multi(sequences, Lambda, P, q):
    segmentations = []
    for sequence in sequences:
        n_states, n_emissions = P.shape
        n_obs = len(sequence)
        V = np.zeros((n_states, n_obs))
        path = np.zeros((n_states, n_obs), dtype=int)
        
        # Initialize
        initial_state = tuple(sequence[:q])
        for i in range(n_states):
            emission_index = np.ravel_multi_index(initial_state + (sequence[q],), (4,)*(q+1))
            V[i, q-1] = np.log(P[i, emission_index] / n_states)
        
        # Recurse
        for t in range(q, n_obs):
            state = tuple(sequence[t-q:t+1])
            for j in range(n_states):
                emission_index = np.ravel_multi_index(state, (4,)*(q+1))
                prob = V[:, t-1] + np.log(Lambda[:, j]) + np.log(P[j, emission_index])
                V[j, t] = np.max(prob)
                path[j, t] = np.argmax(prob)
        
        # Backtrack
        best_path = np.zeros(n_obs, dtype=int)
        best_path[-1] = np.argmax(V[:, -1])
        for t in range(n_obs-2, q-2, -1):
            best_path[t] = path[best_path[t+1], t+1]
        
        segmentations.append(best_path)
    
    return segmentations


In [6]:
def plot_state_probabilities_multi(samples, sequences, original_sequences, window_size=5, sigma=2):
    r_max = max(s[0] for s in samples)
    n_sequences = len(sequences)
    
    fig, axes = plt.subplots(n_sequences, 1, figsize=(15, 4*n_sequences), sharex=True)
    fig.suptitle('Posterior State Probabilities', fontsize=16)
    
    for seq_idx, (sequence, original_sequence) in enumerate(zip(sequences, original_sequences)):
        ax = axes[seq_idx] if n_sequences > 1 else axes
        n_obs = len(sequence)
        avg_probs = np.zeros((r_max, n_obs))
        
        for r, q, Lambda, P, all_posterior_probs, _ in samples:
            posterior_probs = all_posterior_probs[seq_idx]
            avg_probs[:r, :] += posterior_probs
        
        avg_probs /= len(samples)
        
        # Apply Gaussian smoothing
        smoothed_probs = np.zeros_like(avg_probs)
        for i in range(r_max):
            smoothed_probs[i, :] = gaussian_filter1d(avg_probs[i, :], sigma=sigma, mode='nearest')
        
        # Plot smoothed probabilities
        x = np.arange(n_obs)
        for i in range(r_max):
            ax.plot(x, smoothed_probs[i, :], '-', label=f'State {i+1}', alpha=0.7)
            ax.fill_between(x, 0, smoothed_probs[i, :], alpha=0.2)
        
        ax.set_ylim(0, 1)
        ax.set_ylabel(f'Sequence {seq_idx + 1}')
        ax.set_yticks([0, 0.5, 1])
        
        # Add sequence at the bottom of each subplot
        for j, base in enumerate(original_sequence):
            ax.text(j, -0.15, base, ha='center', va='center', fontsize=8)
        
        if seq_idx == 0:
            ax.legend(loc='upper right', bbox_to_anchor=(1.15, 1))
    
    axes[-1].set_xlabel('Position in Sequence')
    plt.tight_layout()
    plt.show()

In [7]:
N = 50

trxps = [x for x,y in ensembl_trx.items() if y['biotype'] == 'protein_coding'][:N]

sequences = [x['sequence'] for x in ensembl_trx.values() if x['biotype'] == 'protein_coding'][:N]
#sequences = ''.join(sequences)

targets = [x['target'] for x in ensembl_trx.values() if x['biotype'] == 'protein_coding'][:N]
#targets = np.concatenate(targets)

# Constants
MAX_STATES = 2  # r_max
MAX_ORDER = 2    # q_max

n_iter = 500  # Number of iterations after burn-in
n_burn_in = 100  # Number of burn-in iterations
thin = 2  # Store every 10th sample

In [9]:
# Example usage
sequences_indices = [sequence_to_indices(seq) for seq in sequences]
samples = run_mcmc_multi(sequences_indices, n_iter, n_burn_in, thin)

# Plot the results
#plot_state_probabilities_multi(samples, sequences_indices, sequences)

# Analyze results
r_posterior = [s[0] for s in samples]
q_posterior = [s[1] for s in samples]
print(f"Posterior mode for r: {max(set(r_posterior), key=r_posterior.count)}")
print(f"Posterior mode for q: {max(set(q_posterior), key=q_posterior.count)}")

  total_log_likelihood += np.sum(np.log(scale))
  2%|█▍                                                                              | 11/600 [01:24<1:15:32,  7.69s/it]


KeyboardInterrupt: 