In [149]:
import os
import pandas as pd
import numpy as np
from Bio import SeqIO
import pickle
from multiprocessing import Pool

In [5]:
# IPython tab-completion settings: switch off the jedi-based completer and
# turn on greedy completion.
%config Completer.use_jedi = False
%config IPCompleter.greedy=True

In [4]:
# Root directory of the mapper outputs; the listing below shows one
# sub-directory per <celltype>_<comparison> (plus a helper shell script).
# NOTE(review): absolute, machine-specific path — consider making it configurable.
data_root = '/data1/APA/Paul_ALS_Data/bams_in/subscelltype_bamfiles/Mapper_outs/'
# Display the directory contents.
os.listdir(data_root)

['IN-SST_C9ALSvsCTRL',
 'L4_C9ALSvsCTRL',
 'Inhibitory_sALSvsCTRL',
 'L5-6-CC_C9ALSvsCTRL',
 'Excitatory_sALSvsCTRL',
 'L5-6_C9ALSvsCTRL',
 'Excitatory_C9ALSvsCTRL',
 'L2-3_sALSvsCTRL',
 'AST-FB_sALSvsCTRL',
 'Oligodendrocytes_C9ALSvsCTRL',
 'get_fasta_for_switches.sh',
 'OPC_C9ALSvsCTRL',
 'IN-VIP_C9ALSvsCTRL',
 'L4_sALSvsCTRL',
 'AST-PP_C9ALSvsCTRL',
 'IN-VIP_sALSvsCTRL',
 'AST-PP_sALSvsCTRL',
 'L5-6-CC_sALSvsCTRL',
 'IN-PV_C9ALSvsCTRL',
 'Oligodendrocytes_sALSvsCTRL',
 'IN-SST_sALSvsCTRL',
 'OPC_sALSvsCTRL',
 'L2-3_C9ALSvsCTRL',
 'Microglia_sALSvsCTRL',
 'Inhibitory_C9ALSvsCTRL',
 'Microglia_C9ALSvsCTRL',
 'Endothelial_ALSvsCTRL',
 'IN-PV_sALSvsCTRL',
 'AST-FB_C9ALSvsCTRL']

In [5]:
## Let's focus on C9ALS first: keep only the entries whose name contains 'C9ALS'.
ct = list(filter(lambda entry: 'C9ALS' in entry, os.listdir(data_root)))
ct

['IN-SST_C9ALSvsCTRL',
 'L4_C9ALSvsCTRL',
 'L5-6-CC_C9ALSvsCTRL',
 'L5-6_C9ALSvsCTRL',
 'Excitatory_C9ALSvsCTRL',
 'Oligodendrocytes_C9ALSvsCTRL',
 'OPC_C9ALSvsCTRL',
 'IN-VIP_C9ALSvsCTRL',
 'AST-PP_C9ALSvsCTRL',
 'IN-PV_C9ALSvsCTRL',
 'L2-3_C9ALSvsCTRL',
 'Inhibitory_C9ALSvsCTRL',
 'Microglia_C9ALSvsCTRL',
 'AST-FB_C9ALSvsCTRL']

In [6]:
# Collect the switch DNA sequences from every C9ALS comparison directory.
# A record ID that was already seen keeps the sequence from its first occurrence.
sequences_dict = {}
for comparison_dir in ct:
    fasta_path = data_root + "/{}/switch_DNA_sequence.fa".format(comparison_dir)
    for record in SeqIO.parse(fasta_path, "fasta"):
        sequences_dict.setdefault(record.id, str(record.seq))

In [7]:
# Number of unique switch IDs collected from the C9ALS FASTA files.
len(sequences_dict)

107427

In [8]:
# Peek at one key to see the ID format: colon-separated fields, the last of
# which is the strand sign.
list(sequences_dict.keys())[10]

'chr2:AAK1:69461526:69472042:-'

In [9]:
# Cell-type name = directory name up to the first underscore, in sorted order.
celltypes = sorted(dirname.split('_')[0] for dirname in ct)
celltypes

['AST-FB',
 'AST-PP',
 'Excitatory',
 'IN-PV',
 'IN-SST',
 'IN-VIP',
 'Inhibitory',
 'L2-3',
 'L4',
 'L5-6',
 'L5-6-CC',
 'Microglia',
 'OPC',
 'Oligodendrocytes']

In [76]:
# Map every cell type to its position in the sorted list; this position is
# used as the row index into each switch's label array.
celltypes_order = {name: position for position, name in enumerate(celltypes)}
celltypes_order

{'AST-FB': 0,
 'AST-PP': 1,
 'Excitatory': 2,
 'IN-PV': 3,
 'IN-SST': 4,
 'IN-VIP': 5,
 'Inhibitory': 6,
 'L2-3': 7,
 'L4': 8,
 'L5-6': 9,
 'L5-6-CC': 10,
 'Microglia': 11,
 'OPC': 12,
 'Oligodendrocytes': 13}

In [122]:
# One label array per switch: one row per cell type (row index taken from
# `celltypes_order`) and two columns, [LFC_PA_Usage, negative_logFDR].
# The row count is derived from `celltypes` rather than hard-coded, so it can
# never drift out of sync with the cell types that are actually present.
# Rows stay at the zero fill for cell types in which the switch is not reported.
n_celltypes = len(celltypes)
labels = {key: np.zeros((n_celltypes, 2), dtype=float) for key in sequences_dict}
len(labels)

107427

In [123]:
def get_lfc_logp(df, name):
    """Return ``[LFC_PA_Usage, negative_logFDR]`` for one switch, rounded to 4 decimals.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with the columns ``switch_name``, ``LFC_PA_Usage`` and
        ``negative_logFDR``.
    name : str
        Value to look up in the ``switch_name`` column.

    Returns
    -------
    list of float
        ``[lfc, nlogp]``, both rounded to four decimal places.

    Raises
    ------
    KeyError
        If no row has ``switch_name == name``.

    Notes
    -----
    The scalar is extracted explicitly from the first matching row.  Calling
    ``float()`` directly on a Series is deprecated in recent pandas and fails
    outright when more than one row matches.
    """
    matches = df.loc[df['switch_name'] == name, ['LFC_PA_Usage', 'negative_logFDR']]
    if matches.empty:
        raise KeyError("switch_name {!r} not found".format(name))
    first_row = matches.iloc[0]
    lfc = round(float(first_row['LFC_PA_Usage']), 4)
    nlogp = round(float(first_row['negative_logFDR']), 4)
    return [lfc, nlogp]


In [124]:
# Fill the label arrays from each cell type's C9ALS-vs-control result table.
# The loop variable is named `celltype` so that it does not overwrite the
# `ct` list of comparison directories defined earlier in the notebook.
for celltype in celltypes:
    df_name = data_root + celltype + "_C9ALSvsCTRL/APAlog_res_metadata_added.tsv"
    inp_df = pd.read_csv(df_name, sep='\t')
    ct_idx = celltypes_order[celltype]
    # One pass over the table instead of scanning the whole frame once per
    # switch. If a switch name is repeated, the first row is used.
    unique_rows = inp_df.drop_duplicates(subset='switch_name')
    for name, lfc, nlogp in zip(unique_rows['switch_name'],
                                unique_rows['LFC_PA_Usage'],
                                unique_rows['negative_logFDR']):
        # Only switches that have a sequence (and hence a label array) are kept.
        if name in labels:
            labels[name][ct_idx] = [round(float(lfc), 4), round(float(nlogp), 4)]

In [131]:
# Preview the label arrays of the first five switches.
dict(list(labels.items())[0:5])

{'chr12:AACS:125140928:125143316:+': array([[ 0.    ,  0.    ],
        [ 0.0197,  0.0595],
        [-0.7202,  7.8307],
        [ 0.    ,  0.    ],
        [ 0.7973,  4.3538],
        [ 0.    ,  0.    ],
        [ 0.4728,  3.7107],
        [-0.7088,  5.2996],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.4067,  1.5954],
        [ 0.5094,  8.9315]]),
 'chr4:AADAT:170060271:170060673:-': array([[ 0.    ,  0.    ],
        [-0.0885,  0.3384],
        [ 0.    ,  0.    ],
        [ 1.3189,  5.5945],
        [ 1.1698,  6.8452],
        [ 2.0254, 11.2648],
        [ 1.5136, 23.2604],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.    ,  0.    ],
        [ 0.1105,  0.3281],
        [ 0.7901, 25.9075]]),
 'chr2:AAK1:69457997:69461526:-': array([[ 0.    ,  0.    ],
        [ 0.6892, 13.9665],
        [ 0.1218,  0.8157],
        [ 0.9573, 15.66

## OK, we now have the sequence and label dictionaries.
## Next, let's transcribe the DNA sequences into sense-strand RNA,
## then build the combined data dictionary (RNA sequence -> label array).

In [141]:
def transcribe_positive_strand(seq):
    """Transcribe a 5'->3' coding-strand DNA sequence into RNA.

    The RNA has the same sequence as the input, with U in place of T.
    Lowercase (e.g. soft-masked) ``t`` is converted to ``u`` as well, so that
    no DNA base is left behind in the output; case is preserved and all other
    characters pass through unchanged.

    Parameters
    ----------
    seq : str
        DNA sequence of the coding strand, 5' to 3'.

    Returns
    -------
    str
        RNA sequence, 5' to 3'.
    """
    return seq.replace('T', 'U').replace('t', 'u')

def transcribe_negative_strand(seq):
    """Transcribe a 5'->3' template-strand DNA sequence into RNA.

    Each base is complemented (with U written instead of T) and the result is
    returned in reverse order, i.e. the reverse complement as RNA.
    Lowercase (e.g. soft-masked) bases are complemented too, preserving case;
    any other character (such as ``N``) passes through unchanged.

    Parameters
    ----------
    seq : str
        DNA sequence of the template strand, 5' to 3'.

    Returns
    -------
    str
        RNA sequence, 5' to 3'.
    """
    complement = {
        'A': 'U', 'C': 'G', 'G': 'C', 'T': 'A',
        'a': 'u', 'c': 'g', 'g': 'c', 't': 'a',
    }
    return "".join(complement.get(base, base) for base in reversed(seq))

In [None]:
# Transcribe every switch sequence. The strand is the last colon-separated
# field of the switch ID; anything other than '+' is handled as negative strand.
transcribed_sequences = {}
for switch_id, dna in sequences_dict.items():
    on_plus_strand = switch_id.rsplit(':', 1)[-1] == '+'
    transcribe = transcribe_positive_strand if on_plus_strand else transcribe_negative_strand
    transcribed_sequences[switch_id] = transcribe(dna)

In [145]:
# Build the final mapping: RNA sequence -> label array.
# Distinct switch IDs can share an identical sequence; in that case the later
# switch's labels replace the earlier ones (same result as before), but the
# number of such overwrites is now reported instead of being lost silently.
all_data_dict = {}
n_overwritten = 0
for switch_id, rna in transcribed_sequences.items():
    if rna in all_data_dict:
        n_overwritten += 1
    all_data_dict[rna] = labels[switch_id]
if n_overwritten:
    print("WARNING: {} switches share a sequence with an earlier switch; "
          "their labels replaced the earlier entry.".format(n_overwritten))

In [146]:
# Preview the first two (RNA sequence -> label array) entries.
dict(list(all_data_dict.items())[0:2])

{'GUGAGGCGGGACAAACUUGUCUUCCUCACACCCAUCUUACUUCCUCUUAUGAGGAAACCCAGAGAGAUGAGGGGUCUUGCCCAAGGAAGGGGUGUCCAUAGUCAGCUCUGCCUUCUGCUCACCCAGAAUAAAGACCUGGGGACCCCGCGAGGGUCAUGGCCAAGUGGAAUGGACUCCUGGCAUUUGAGGGCUUCCCGACUGCAGCCCUCAGGCAGCCAUGGCUGUCCCAAGUCCAGCGGGCCUUUGCUCGGGUCAUGGCUGGGAUGUCUGGCCCUUCCUGACAGGAGGCUGCUGGGCUCCUGUCUACUUGGGGACGCCUCAUGCAGGAGCUGGUGUGGGGGUGGGCAGGGGGGCGGUGGCUUCUUCCUUUCUCUUUCCCUUUCCUCUACCUUUUCCCCUCUCCCCAGAGGAAAUGGUAGCAGGAUUUCUUUUAAGAGGAUGCUGCUGUAUUUUGCCAGCGGGUGGAAGGUGGCGGUAUUAGCUCCCGUGAGCUGCACGUGGACCCCUGUGUGAAGCGUAGCAGGGCACAGAGCAGGCGAGACGUUUGCAUCUCACAGCGGGAGGGCCGGCGACAUCACAUGAAGUGACAGGCAGGCCCUUGGAAGCCGGUGCUUAGAUCCUUAAUUAGUUCACACGUCGACUGAAUUUUCAAGUGAAUGAAUUUUAAUUACAUCUCAGGUUAAAAAAAAAAAAAGGCGCCAGUGAUCGAGGACUCGUCACUGGGCUCUGUUGCUCCUGAAGUUUCCUAGCCCACAACACACCAACACUGCCAAGGGCUCUUCUGGAUUCAAGGUGAAACACAUGUGCCAUAAAUCUUGGAGCUCUGAAUGUUUGGAAAGGGCCCGACUGUGAGAAGAAGUAACACACCGUCCCGUGCAGAUGGCUGGCUCUGAGGAGGAGUUCAUGGGAGCUUGGGGACACUCUUGCCUCUAGUUCUAGGAAGCUGGGCCACUUCUGAAGUAAUGGCAAUAUCAAUAAAGUAAUGGUCUUUAUCAUAG

In [150]:
# Persist the C9ALS (RNA sequence -> label array) dictionary as a pickle
# inside the data root directory.
outname = data_root + 'C9ALS_ALL_training_test_data.pkl'
with open(outname, mode='wb') as pkl_out:
    pickle.dump(all_data_dict, pkl_out)

## OK, let's do the same thing for sALS, and then we are done with data processing :)

In [152]:
# Keep only the entries whose name contains 'sALS'.
ct = list(filter(lambda entry: 'sALS' in entry, os.listdir(data_root)))
ct

['Inhibitory_sALSvsCTRL',
 'Excitatory_sALSvsCTRL',
 'L2-3_sALSvsCTRL',
 'AST-FB_sALSvsCTRL',
 'L4_sALSvsCTRL',
 'IN-VIP_sALSvsCTRL',
 'AST-PP_sALSvsCTRL',
 'L5-6-CC_sALSvsCTRL',
 'Oligodendrocytes_sALSvsCTRL',
 'IN-SST_sALSvsCTRL',
 'OPC_sALSvsCTRL',
 'Microglia_sALSvsCTRL',
 'IN-PV_sALSvsCTRL']

In [153]:
# Number of sALS comparison directories found.
len(ct)

13

In [154]:
# Collect the switch DNA sequences from every sALS comparison directory.
# A record ID that was already seen keeps the sequence from its first occurrence.
sequences_dict = {}
for comparison_dir in ct:
    fasta_path = data_root + "/{}/switch_DNA_sequence.fa".format(comparison_dir)
    for record in SeqIO.parse(fasta_path, "fasta"):
        sequences_dict.setdefault(record.id, str(record.seq))

In [155]:
# Number of unique switch IDs collected from the sALS FASTA files.
len(sequences_dict)

107834

In [157]:
# Cell-type name = directory name up to the first underscore, in sorted order.
celltypes = sorted(dirname.split('_')[0] for dirname in ct)
celltypes

['AST-FB',
 'AST-PP',
 'Excitatory',
 'IN-PV',
 'IN-SST',
 'IN-VIP',
 'Inhibitory',
 'L2-3',
 'L4',
 'L5-6-CC',
 'Microglia',
 'OPC',
 'Oligodendrocytes']

In [158]:
# Map every cell type to its position in the sorted list; this position is
# used as the row index into each switch's label array.
celltypes_order = {name: position for position, name in enumerate(celltypes)}
celltypes_order

{'AST-FB': 0,
 'AST-PP': 1,
 'Excitatory': 2,
 'IN-PV': 3,
 'IN-SST': 4,
 'IN-VIP': 5,
 'Inhibitory': 6,
 'L2-3': 7,
 'L4': 8,
 'L5-6-CC': 9,
 'Microglia': 10,
 'OPC': 11,
 'Oligodendrocytes': 12}

In [159]:
# One label array per switch: one row per cell type (row index taken from
# `celltypes_order`) and two columns, [LFC_PA_Usage, negative_logFDR].
# The row count is derived from `celltypes` rather than hard-coded, so it can
# never drift out of sync with the cell types that are actually present.
# Rows stay at the zero fill for cell types in which the switch is not reported.
n_celltypes = len(celltypes)
labels = {key: np.zeros((n_celltypes, 2), dtype=float) for key in sequences_dict}
len(labels)

107834

In [160]:
# Fill the label arrays from each cell type's sALS-vs-control result table.
# The loop variable is named `celltype` so that it does not overwrite the
# `ct` list of comparison directories defined earlier in the notebook.
for celltype in celltypes:
    df_name = data_root + celltype + "_sALSvsCTRL/APAlog_res_metadata_added.tsv"
    inp_df = pd.read_csv(df_name, sep='\t')
    ct_idx = celltypes_order[celltype]
    # One pass over the table instead of scanning the whole frame once per
    # switch. If a switch name is repeated, the first row is used.
    unique_rows = inp_df.drop_duplicates(subset='switch_name')
    for name, lfc, nlogp in zip(unique_rows['switch_name'],
                                unique_rows['LFC_PA_Usage'],
                                unique_rows['negative_logFDR']):
        # Only switches that have a sequence (and hence a label array) are kept.
        if name in labels:
            labels[name][ct_idx] = [round(float(lfc), 4), round(float(nlogp), 4)]

In [161]:
# Transcribe every switch sequence. The strand is the last colon-separated
# field of the switch ID; anything other than '+' is handled as negative strand.
transcribed_sequences = {}
for key, value in sequences_dict.items():
    strand = key.split(':')[-1]
    if strand == '+':
        transcribed_sequences[key] = transcribe_positive_strand(value)
    else:
        transcribed_sequences[key] = transcribe_negative_strand(value)

# Build the final mapping: RNA sequence -> label array.
# Distinct switch IDs can share an identical sequence; in that case the later
# switch's labels replace the earlier ones (same result as before), but the
# number of such overwrites is now reported instead of being lost silently.
all_data_dict = {}
n_overwritten = 0
for key, value in transcribed_sequences.items():
    if value in all_data_dict:
        n_overwritten += 1
    all_data_dict[value] = labels[key]
if n_overwritten:
    print("WARNING: {} switches share a sequence with an earlier switch; "
          "their labels replaced the earlier entry.".format(n_overwritten))

In [162]:
# Preview the first two (RNA sequence -> label array) entries.
dict(list(all_data_dict.items())[0:2])

{'GUGAGGCGGGACAAACUUGUCUUCCUCACACCCAUCUUACUUCCUCUUAUGAGGAAACCCAGAGAGAUGAGGGGUCUUGCCCAAGGAAGGGGUGUCCAUAGUCAGCUCUGCCUUCUGCUCACCCAGAAUAAAGACCUGGGGACCCCGCGAGGGUCAUGGCCAAGUGGAAUGGACUCCUGGCAUUUGAGGGCUUCCCGACUGCAGCCCUCAGGCAGCCAUGGCUGUCCCAAGUCCAGCGGGCCUUUGCUCGGGUCAUGGCUGGGAUGUCUGGCCCUUCCUGACAGGAGGCUGCUGGGCUCCUGUCUACUUGGGGACGCCUCAUGCAGGAGCUGGUGUGGGGGUGGGCAGGGGGGCGGUGGCUUCUUCCUUUCUCUUUCCCUUUCCUCUACCUUUUCCCCUCUCCCCAGAGGAAAUGGUAGCAGGAUUUCUUUUAAGAGGAUGCUGCUGUAUUUUGCCAGCGGGUGGAAGGUGGCGGUAUUAGCUCCCGUGAGCUGCACGUGGACCCCUGUGUGAAGCGUAGCAGGGCACAGAGCAGGCGAGACGUUUGCAUCUCACAGCGGGAGGGCCGGCGACAUCACAUGAAGUGACAGGCAGGCCCUUGGAAGCCGGUGCUUAGAUCCUUAAUUAGUUCACACGUCGACUGAAUUUUCAAGUGAAUGAAUUUUAAUUACAUCUCAGGUUAAAAAAAAAAAAAGGCGCCAGUGAUCGAGGACUCGUCACUGGGCUCUGUUGCUCCUGAAGUUUCCUAGCCCACAACACACCAACACUGCCAAGGGCUCUUCUGGAUUCAAGGUGAAACACAUGUGCCAUAAAUCUUGGAGCUCUGAAUGUUUGGAAAGGGCCCGACUGUGAGAAGAAGUAACACACCGUCCCGUGCAGAUGGCUGGCUCUGAGGAGGAGUUCAUGGGAGCUUGGGGACACUCUUGCCUCUAGUUCUAGGAAGCUGGGCCACUUCUGAAGUAAUGGCAAUAUCAAUAAAGUAAUGGUCUUUAUCAUAG

In [164]:
# Persist the sALS (RNA sequence -> label array) dictionary as a pickle
# inside the data root directory.
outname = data_root + 'sALS_ALL_training_test_data.pkl'
with open(outname, mode='wb') as pkl_out:
    pickle.dump(all_data_dict, pkl_out)

In [163]:
# Number of entries (distinct RNA sequences) in the sALS dictionary.
len(all_data_dict)

107457

## Sanity checks :)

In [167]:
# Read both pickles back from disk.
# NOTE: pickle.load can execute arbitrary code, so it should only ever be
# pointed at trusted files.
c9als_pkl_path = data_root + 'C9ALS_ALL_training_test_data.pkl'
with open(c9als_pkl_path, 'rb') as pkl_in:
    C9ALS_data = pickle.load(pkl_in)

sals_pkl_path = data_root + 'sALS_ALL_training_test_data.pkl'
with open(sals_pkl_path, 'rb') as pkl_in:
    sALS_data = pickle.load(pkl_in)

In [171]:
# Number of entries in the reloaded C9ALS dictionary.
len(C9ALS_data)

107051

In [172]:
# Number of entries in the reloaded sALS dictionary.
len(sALS_data)

107457