# RNAlight
Analyze the predictions file from the RNAlight repo: [Whole_genome_lncRNA_predict_df.tsv](https://github.com/YangLab/RNAlight/blob/main/Light_score_diverse_RNA/lncRNA_whole_genome/Whole_genome_lncRNA_predict_df.tsv)   
Compare the RNAlight predictions to labels derived from lncATLAS.    
Not all sequences in the predictions file are found in lncATLAS.   
We need a CNRCI threshold to separate nuclear from cytoplasmic.    
A threshold of -1 makes sense because it is the midpoint of the -2 to 0 range that RNAlight excluded, but we also tried other values.    
A threshold of -1 also gave the highest accuracy. 

This notebook measures accuracy against the mean CNRCI over 14 cell lines; H1.hESC is excluded, as in the RNAlight paper.

In [1]:
import sys, traceback
import sklearn
from sklearn.metrics import accuracy_score
import numpy as np

# Input files. Assume these files (or links to them) are in the current directory.
# RNAlight predictions in tsv format, from the RNAlight repo on GitHub.
RNALIGHT = 'Whole_genome_lncRNA_predict_df.tsv'
# GENCODE lncRNA transcripts in FASTA format; only the deflines are read,
# to map each transcript ID to its gene ID.
GENCODE = 'gencode.v44.lncRNA_transcripts.fa'
# RCI values in csv format, from the lncATLAS downloads.
LNCATLAS = 'lncATLAS_all_data_RCI.csv'

In [2]:
# Load the RNAlight file of predictions
def load_predictions(filename):
    '''
    Parse the predictions file from the RNAlight repository.

    The format is tsv with one header line, which is skipped.
    In each data line, column 0 is the transcript ID, column 3 is the
    predicted class (int) and column 4 is the predicted probability (float).
    The ID is split on period to remove the version number.
    Blank lines (for example a trailing empty line) are ignored.

    Returns a dict mapping transcript ID to (pred_class, pred_prob).
    If a line cannot be parsed, the line and a traceback are printed
    and the exception is re-raised.
    '''
    predictions = dict()
    with open(filename, 'r') as fin:
        header = None
        for line in fin:
            # Skip blank lines so they are neither mistaken for the header
            # nor parsed as a (too short) data line.
            if not line.strip():
                continue
            try:
                fields = line.strip().split('\t')
                if header is None:
                    header = fields
                    continue
                tid = fields[0].split('.')[0]
                pred_class = int(fields[3])
                pred_prob = float(fields[4])
                predictions[tid] = (pred_class, pred_prob)
            except Exception:
                # Show the offending line, then let the error propagate.
                print(line)
                traceback.print_exc()
                raise
    return predictions

In [3]:
# Load IDs from the GenCode FASTA file
def load_conversions(filename):
    '''
    Build a transcript-to-gene ID lookup from the GENCODE FASTA file.

    Only deflines (lines starting with '>') are used; all other lines are skipped.
    Each defline is split on vertical bar: field 0 is the transcript ID
    and field 1 is the gene ID.
    Both IDs are cut at the first period to remove the version numbers.

    Returns a dict mapping transcript ID to gene ID.
    If a defline cannot be parsed, the line and a traceback are printed
    and the exception is re-raised.
    '''
    tid_to_gid = {}
    with open(filename, 'r') as fasta:
        for defline in fasta:
            if not defline.startswith('>'):
                continue  # not a defline
            try:
                # Deflines look like...
                # >ENST00000456328.2|ENSG00000290825.1|-|OTTHUMT00000362751.1|DDX11L2-202|DDX11L2|1657|
                parts = defline[1:].rstrip().split('|')
                transcript_id = parts[0].split('.')[0]
                gene_id = parts[1].split('.')[0]
                tid_to_gid[transcript_id] = gene_id
            except Exception as e:
                print(defline)
                traceback.print_exc()
                raise e
    return tid_to_gid

In [4]:
# Load the lncATLAS file of CNRCI
def load_truth(filename,/,include=None,exclude=None):
    '''
    Parse the datafile from lncATLAS.

    The format is csv with one header line, which is skipped.
    In each data line, column 0 is the gene ID, column 1 the cell line,
    column 2 the RCI type, column 3 the RCI value and column 6 the gene type.
    Blank lines are ignored.

    Filter 'nc' to exclude coding genes.
    Filter 'CNRCI' to exclude RCIs besides cytoplasmic-to-nuclear.
    Filter 'NA' to exclude lines with value not available.

    Cell line filters are both optional:
      include -- if given, keep only rows for this cell line.
      exclude -- if given, drop rows for this cell line
                 (e.g. 'H1.hESC', as RNAlight did).
    With neither given, rows from all cell lines are used.
    If both are given, a row must pass both filters.

    Returns a dict mapping gene ID to the mean of its retained CNRCI values.
    If a line cannot be parsed, the line and a traceback are printed
    and the exception is re-raised.
    '''
    all_rci = dict()
    with open(filename, 'r') as fin:
        header = None
        for line in fin:
            # Skip blank lines so they are neither mistaken for the header
            # nor parsed as a (too short) data line.
            if not line.strip():
                continue
            try:
                fields = line.strip().split(',')
                if header is None:
                    header = fields
                    continue
                gid = fields[0]
                cell_type = fields[1]
                rci_type = fields[2]
                rci_value = fields[3]
                gene_type = fields[6]
                # Each filter applies only when it was actually requested,
                # so the default call keeps every cell line.
                if include is not None and cell_type != include:
                    continue
                if exclude is not None and cell_type == exclude:
                    continue
                if gene_type == 'nc' and rci_type == 'CNRCI' and rci_value != 'NA':
                    all_rci.setdefault(gid, []).append(float(rci_value))
            except Exception:
                # Show the offending line, then let the error propagate.
                print(line)
                traceback.print_exc()
                raise
    return {gid: np.mean(values) for gid, values in all_rci.items()}

In [5]:
# Column positions within each row returned by load_integrated_data().
GID = 0
TID = 1
PRED_CLASS = 2
PRED_PROB = 3
CNRCI = 4
def load_integrated_data(rnalight,gencode,lncatlas,/,include=None,exclude=None,verbose=True):
    '''
    Integrate the three databases: RNAlight, GENCODE, lncATLAS.
    Pass in the three database filenames, in that order.
    The include and exclude options are passed through to load_truth().
    Return a list with one tuple per comparable transcript:
    (gene ID, transcript ID, predicted class, predicted probability, mean CNRCI).
    A transcript is comparable if GENCODE maps it to a gene
    and lncATLAS provided an RCI for that gene.
    The input sizes are always printed; the other counts only if verbose.
    '''
    predictions = load_predictions(rnalight)
    tid_to_gid = load_conversions(gencode)
    cnrci = load_truth(lncatlas, include=include, exclude=exclude)
    print('Input size:',len(predictions),len(tid_to_gid),len(cnrci))
    missing_from_gencode = 0
    missing_from_lncatlas = 0
    rows = []
    for tid in predictions:
        if tid in tid_to_gid:
            gid = tid_to_gid[tid]
            if gid in cnrci:
                pred_class, pred_prob = predictions[tid]
                rows.append((gid, tid, pred_class, pred_prob, cnrci[gid]))
            else:
                missing_from_lncatlas += 1
        else:
            missing_from_gencode += 1
    if verbose:
        # Every comparable transcript contributed exactly one row.
        print('# Comparable values:', len(rows))
        print('# Transcripts not in GenCode:', missing_from_gencode)
        print('# Transcripts with no RCI in lncATLAS:', missing_from_lncatlas)
    return rows

In [6]:
# Integrate the three input files, leaving out lncATLAS rows for the H1.hESC cell line.
dataset = load_integrated_data(RNALIGHT,GENCODE,LNCATLAS,exclude='H1.hESC')
# Each record is one transcript having both a prediction and a mean CNRCI.
print('Records:', len(dataset))

Input size: 16153 58246 5760
# Comparable values: 5317
# Transcripts not in GenCode: 195
# Transcripts with no RCI in lncATLAS: 10641
Records: 5317


In [7]:
# This is a preliminary test to work out which RNAlight class (0 or 1)
# means cytoplasmic and which means nuclear.
# Only transcripts with extreme mean CNRCI (>= 4 or <= -4) are counted.
# Each list is indexed by predicted class: [count of 0, count of 1].
positive_rci_counts=[0,0]
negative_rci_counts=[0,0]
for record in dataset:
    mean_rci = record[CNRCI]
    if mean_rci >= 4:
        tally = positive_rci_counts
    elif mean_rci <= -4:
        tally = negative_rci_counts
    else:
        continue  # not extreme enough to count
    tally[record[PRED_CLASS]] += 1
print('RNAlight predicts 0 or 1. What does that mean?')
print('We assume RNAlight is correct most of the time, especially on extreme values.')
print('For postive CNRCI (cytosol), RNAlight mostly predicts 0:')
print(' Predictions for RNA with very positive RCI: [0,1]=', positive_rci_counts)
print('For negative CNRCI (nucleus), RNAlight mostly predicts 1:')
print(' Predictions for RNA with very negative RCI: [0,1]=', negative_rci_counts)
# Class labels used from here on: 0 is cytosol, 1 is nucleus.
CLASS_CYTOSOL, CLASS_NUCLEUS = 0, 1

RNAlight predicts 0 or 1. What does that mean?
We assume RNAlight is correct most of the time, especially on extreme values.
For postive CNRCI (cytosol), RNAlight mostly predicts 0:
 Predictions for RNA with very positive RCI: [0,1]= [2, 0]
For negative CNRCI (nucleus), RNAlight mostly predicts 1:
 Predictions for RNA with very negative RCI: [0,1]= [7, 428]


In [8]:
def make_two_arrays(data,threshold):
    '''
    This is a subroutine for show_accuracy().
    Pass in the rows of integrated data, one row per transcript.
    Return two parallel lists, in this order: ground truth, predictions.
    The predictions are straight from the RNAlight predictions file.
    The ground truth is CLASS_CYTOSOL where CNRCI >= threshold,
    and CLASS_NUCLEUS otherwise.
    '''
    # One pass over the data, collecting (truth, prediction) per row.
    pairs = [
        (CLASS_CYTOSOL if record[CNRCI] >= threshold else CLASS_NUCLEUS,
         record[PRED_CLASS])
        for record in data
    ]
    y_true = [truth for truth, _ in pairs]
    y_pred = [predicted for _, predicted in pairs]
    return y_true,y_pred

In [9]:
def show_accuracy(data,rci_threshold):
    '''
    Pass in the rows of integrated data, one row per transcript.
    Pass in some desired threshold T, where CNRCI>=T implies cytoplasmic, else nuclear.
    The RNAlight predictions are compared to the labels implied by T.
    Print one sentence showing accuracy of predictions at this threshold.
    With no data, the accuracy is reported as zero.
    Nothing is returned.
    '''
    truth, predicted = make_two_arrays(data, rci_threshold)
    has_data = len(truth) > 0 and len(predicted) > 0
    # Only call the scorer when there is something to score.
    percent = accuracy_score(truth, predicted) * 100 if has_data else 0.0
    report = 'At RCI threshold %3.1f, %d predictions have accuracy=%6.2f%%' \
        % (rci_threshold, len(predicted), percent)
    print(report)

### Overall accuracy, no middle exclusion
These statistics are sensitive to what portion of genes are in the middle.

In [10]:
# Report accuracy on the full dataset at each candidate CNRCI threshold.
for rci_cutoff in (-2, -1.5, -1, -0.5, 0, 0.5, 1, 2):
    show_accuracy(dataset, rci_cutoff)

At RCI threshold -2.0, 5317 predictions have accuracy= 78.58%
At RCI threshold -1.5, 5317 predictions have accuracy= 79.14%
At RCI threshold -1.0, 5317 predictions have accuracy= 79.88%
At RCI threshold -0.5, 5317 predictions have accuracy= 79.31%
At RCI threshold 0.0, 5317 predictions have accuracy= 76.92%
At RCI threshold 0.5, 5317 predictions have accuracy= 69.14%
At RCI threshold 1.0, 5317 predictions have accuracy= 62.69%
At RCI threshold 2.0, 5317 predictions have accuracy= 54.88%


In [11]:
def middle_exclusion(data,lower_bound,upper_bound,/,inverse=False):
    '''
    Pass in the rows of integrated data, one row per transcript.
    By default, return the rows whose CNRCI is below lower_bound
    or above upper_bound, i.e. the extremes.
    With inverse set, return all the other rows instead, i.e. the middle.
    Input order is preserved.
    This emulates the middle exclusion filter used in the RNAlight study.
    '''
    extremes, middle = [], []
    for record in data:
        value = record[CNRCI]
        is_extreme = value < lower_bound or value > upper_bound
        (extremes if is_extreme else middle).append(record)
    return middle if inverse else extremes

### Accuracy on the extremes, after middle exclusion

In [12]:
# Keep only the transcripts whose mean CNRCI is below -2 or above 0,
# then report accuracy at each candidate CNRCI threshold.
reduced_set = middle_exclusion(dataset, -2, 0)
print('Number of genes left after middle exclusion', len(reduced_set))
for rci_cutoff in (-2, -1.5, -1, -0.5, 0, 0.5, 1, 2):
    show_accuracy(reduced_set, rci_cutoff)

Number of genes left after middle exclusion 3226
At RCI threshold -2.0, 3226 predictions have accuracy= 95.94%
At RCI threshold -1.5, 3226 predictions have accuracy= 95.94%
At RCI threshold -1.0, 3226 predictions have accuracy= 95.94%
At RCI threshold -0.5, 3226 predictions have accuracy= 95.94%
At RCI threshold 0.0, 3226 predictions have accuracy= 95.94%
At RCI threshold 0.5, 3226 predictions have accuracy= 82.70%
At RCI threshold 1.0, 3226 predictions have accuracy= 72.07%
At RCI threshold 2.0, 3226 predictions have accuracy= 59.21%


### Accuracy on just the middle part, excluding the extremes

In [13]:
# Keep only the transcripts whose mean CNRCI is from -2 to 0 inclusive,
# then report accuracy at each candidate CNRCI threshold.
removed_set = middle_exclusion(dataset, -2, 0, inverse=True)
print('Number of genes in the middle', len(removed_set))
for rci_cutoff in (-2, -1.5, -1, -0.5, 0, 0.5, 1, 2):
    show_accuracy(removed_set, rci_cutoff)

Number of genes in the middle 2091
At RCI threshold -2.0, 2091 predictions have accuracy= 51.79%
At RCI threshold -1.5, 2091 predictions have accuracy= 53.23%
At RCI threshold -1.0, 2091 predictions have accuracy= 55.09%
At RCI threshold -0.5, 2091 predictions have accuracy= 53.66%
At RCI threshold 0.0, 2091 predictions have accuracy= 47.58%
At RCI threshold 0.5, 2091 predictions have accuracy= 48.21%
At RCI threshold 1.0, 2091 predictions have accuracy= 48.21%
At RCI threshold 2.0, 2091 predictions have accuracy= 48.21%
