# RNA-seq EM Implementation

### Imports for the Project


The `utils.py` file contains classes designed to facilitate the reading process of files:

1. **Isoform**: Represents an isoform, and each instance has its own unique ID.
2. **Read**: Represents a read, and it has an attribute, an array of isoforms, to determine the isoforms related to the read.
3. **DataLoader**: This class reads data and creates an array of reads mapped to their determined isoforms, given a `transcrip.fasta` file and a `read.bam` file.


In [79]:
import numpy as np
from collections import defaultdict
from utils import *
from tqdm import tqdm
import csv
import plotly.graph_objects as go

### RSEM class
This class is the core part of the project, as it contains functions that read files, estimate abundances using the EM algorithm, and save and plot the results. 

In [80]:
class RSEM:
    def __init__(self, bam_file, isoforms_file):
        dataLoader = DataLoader(bam_file, isoforms_file)
        # Load reads from BAM file
        self.reads = dataLoader.get_mapped_reads()
        # Load isoforms from FASTA file
        self.isoforms = dataLoader.isoforms
        # Initialize parameters: expression levels (theta), RSPD, and error model
        self.theta = self._initialize_parameters()
        # Initialize posterior probabilities
        self.posterior_probs = defaultdict(dict)
        # Fragment length distribution parameters
        self.mean_fragment_length = 200
        self.sd_fragment_length = 10
        # Store log-likelihood values
        self.log_likelihoods = []

        self.median_diff = []

    # Initialize parameters: expression levels (theta), RSPD, and error model
    def _initialize_parameters(self):
        num_isoforms = len(self.isoforms)
        theta = np.ones(num_isoforms) / num_isoforms  # Initial expression levels
        return theta

    # Calculate effective length using fragment length distribution
    def _effective_length(self, transcript_length):
        # Effective length calculation considering fragment length distribution
        effective_length = transcript_length - self.mean_fragment_length + 1
        if effective_length < 1:
            effective_length = 1  # Ensure effective length is at least 1
        return effective_length

    # Calculate log-likelihood of the observed data
    def _calculate_log_likelihood(self):
        log_likelihood = 0
        for read in self.reads:
            denom = sum(self.theta[isoform.id] for isoform in read.isoforms)
            log_likelihood += np.log(denom)
        return log_likelihood

    # EM Algorithm: Iterate the E-step and M-step until convergence
    def em_algorithm(self, max_iter=100, tol=1e-6, verbose=False):
        for iteration in tqdm(range(max_iter)):
            new_theta = np.zeros_like(self.theta)
            N = len(self.reads)
            
            # E-step: Calculate expected counts and posterior probabilities
            log_likelihood = 0
            for read in self.reads:
                denom = sum(self.theta[isoform.id] for isoform in read.isoforms)
                log_likelihood += np.log(denom)
                for isoform in read.isoforms:
                    posterior_prob = self.theta[isoform.id] / denom
                    new_theta[isoform.id] += posterior_prob
                    self.posterior_probs[read.query_name][isoform.name] = posterior_prob
            
            # Normalize new_theta
            new_theta /= N
            
            # Store the log-likelihood
            self.log_likelihoods.append(log_likelihood)
            
            # Calculate the median of the difference between the last iteration and the current one
            if iteration > 0:
                diff = np.abs(new_theta - self.theta)
                median_diff = np.median(diff)
                self.median_diff.append(median_diff)
                if verbose:
                    print(f"Iteration {iteration}, Median diff: {median_diff:.6f}, Log-Likelihood: {log_likelihood:.6f}")
                
                # Check for convergence
                if median_diff < tol:
                    break
            
            self.theta = new_theta

    # Output the estimated expression levels
    def output_results(self):
        for isoform in self.isoforms.values():
            transcript_length = len(isoform)
            effective_length = self._effective_length(transcript_length)
            expected_counts = self.theta[isoform.id] * len(self.reads)
            tpm = (expected_counts / effective_length) * 1e6 / sum(self.theta)
            print(f"Transcript {isoform.name}: Length {transcript_length}, Effective Length {effective_length}, Expected Counts {expected_counts:.2f}, TPM {tpm:.2f}")

    # Output the posterior probabilities of each read mapping to each transcript
    def output_posterior_probabilities(self):
        for read_name, probs in self.posterior_probs.items():
            for isoform_name, prob in probs.items():
                print(f"Read {read_name} maps to Transcript {isoform_name} with posterior probability {prob:.4f}")

    # Save the estimated expression levels to a CSV file
    def save_results_to_csv(self, filename):
        with open(filename, 'w', newline='') as csvfile:
            csvwriter = csv.writer(csvfile)
            csvwriter.writerow(['Transcript', 'Length', 'Effective Length', 'Expected Counts', 'TPM'])
            for isoform in self.isoforms.values():
                transcript_length = len(isoform)
                effective_length = self._effective_length(transcript_length)
                expected_counts = self.theta[isoform.id] * len(self.reads)
                tpm = (expected_counts / effective_length) * 1e6
                row = (isoform.name, transcript_length, effective_length, expected_counts, tpm)
                csvwriter.writerow(row)

    # Save the posterior probabilities to a CSV file
    def save_posterior_probabilities_to_csv(self, filename):
        with open(filename, 'w', newline='') as csvfile:
            csvwriter = csv.writer(csvfile)
            csvwriter.writerow(['Read', 'Isoform', 'Probability'])
            for read_name, probs in self.posterior_probs.items():
                for isoform_name, prob in probs.items():
                    csvwriter.writerow([read_name, isoform_name, prob])

    # Plot the log-likelihood over EM iterations
    def plot_log_likelihood(self):
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=list(range(len(self.log_likelihoods))),
            y=self.log_likelihoods,
            mode='lines+markers',
            marker=dict(symbol='circle', size=8),
            line=dict(shape='linear')
        ))
        
        fig.update_layout(
            title='Log-Likelihood over EM Iterations',
            xaxis_title='Iteration',
            yaxis_title='Log-Likelihood',
            template='plotly_white',
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            width=960,  
            height=540  
        )
        
        fig.show()

    # Plot the median difference between the last and current iteration
    def plot_median_diff(self):
        fig = go.Figure()
        fig.add_trace(go.Scatter(
            x=list(range(len(self.median_diff))),
            y=self.median_diff,
            mode='lines+markers',
            marker=dict(symbol='circle', size=8),
            line=dict(shape='linear'), 
        ))
        
        fig.update_layout(
            title='Median Difference between the Last and Current Iteration',
            xaxis_title='Iteration',
            yaxis_title='Median Difference',
            template='plotly_white',
            xaxis=dict(showgrid=True),
            yaxis=dict(showgrid=True),
            width=960,  
            height=540  
        )
        
        fig.show()

    # Plot the posterior probabilities of a read mapping to each transcript
    def plot_posterior_probabilities(self, read_name):
        probs = self.posterior_probs[read_name]
        read_name = read_name.split('/')[0]
        labels = list(probs.keys())
        values = list(probs.values())
        
        fig = go.Figure(data=[go.Pie(labels=labels, values=values)])
        
        fig.update_layout(
            title_text=f'Posterior Probabilities of {read_name.capitalize()} mapping to Each Transcript',
            showlegend=True,
            width=800,  
            height=800  
        )
        
        fig.show()


### Estimation of the abundances

The following cells estimate the abundances and save them to a file for future use. Additionally, they also plot some interesting data during the iteration process.

In [81]:
bam_file = "aligned.bam"
isoforms_file = "chr11_transcriptome.fasta"
rsem = RSEM(bam_file, isoforms_file)

In [82]:
rsem.em_algorithm(tol=1e-10)

 21%|██        | 21/100 [01:23<05:12,  3.96s/it]


In [83]:
rsem.output_results()

Transcript ENST00000410108: Length 637, Effective Length 438, Expected Counts 92.78, TPM 211835.70
Transcript ENST00000325147: Length 2916, Effective Length 2717, Expected Counts 624.66, TPM 229907.01
Transcript ENST00000382762: Length 2792, Effective Length 2593, Expected Counts 971.35, TPM 374606.06
Transcript ENST00000529614: Length 560, Effective Length 361, Expected Counts 0.00, TPM 0.00
Transcript ENST00000332865: Length 533, Effective Length 334, Expected Counts 0.00, TPM 0.00
Transcript ENST00000486280: Length 665, Effective Length 466, Expected Counts 1.21, TPM 2586.06
Transcript ENST00000342878: Length 435, Effective Length 236, Expected Counts 0.00, TPM 0.00
Transcript ENST00000325113: Length 1284, Effective Length 1085, Expected Counts 0.00, TPM 0.00
Transcript ENST00000525282: Length 764, Effective Length 565, Expected Counts 0.00, TPM 0.00
Transcript ENST00000526104: Length 3506, Effective Length 3307, Expected Counts 1877.10, TPM 567612.96
Transcript ENST00000325207: Len

In [84]:
filename = "output_result.csv"
rsem.save_results_to_csv(filename)

In [85]:
filename = "posterior_probabilities.csv"
rsem.save_posterior_probabilities_to_csv(filename)

In [86]:
rsem.plot_log_likelihood()

In [87]:
rsem.plot_median_diff()

In [90]:
import random

reads = rsem.reads

for i in range(10):
    read = random.choice(reads)
    rsem.plot_posterior_probabilities(read.query_name)
