In [1]:
import pickle
import torch
from torch import nn
import torchvision
import torch.optim as optim
from tqdm import tqdm_notebook as tqdm

In [2]:
from modules.utils import *

In [3]:
ensembl_trx = pickle.load(open('data/ensembl_trx.pkl', 'rb'))

In [4]:
trx_orfs = pickle.load(open('data/trx_orfs.pkl', 'rb'))

In [12]:
dataset = dict()
for trx, orfs in tqdm(trx_orfs.items()):
    biotype = ensembl_trx[trx]['biotype']
    seq, seq_len = ensembl_trx[trx]['sequence'], len(ensembl_trx[trx]['sequence'])
    for orf, attrs in orfs.items():
        if seq_len < 30000:
            if biotype == 'protein_coding' and orf.startswith('ENSP'):
                dataset[trx] = seq
            elif biotype == 'pseudogene':
                dataset[trx] = seq

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for trx, orfs in tqdm(trx_orfs.items()):


  0%|          | 0/227492 [00:00<?, ?it/s]

In [13]:
len([x for x in ensembl_trx.values() if x["biotype"] == "pseudogene"])

17081

In [14]:
len(dataset)

66296

In [15]:
dataset

{'ENST00000426406': 'AGCCCAGTTGGCTGGACCAATGGATGGAGAGAATCACTCAGTGGTATCTGAGTTTTTGTTTCTGGGACTCACTCATTCATGGGAGATCCAGCTCCTCCTCCTAGTGTTTTCCTCTGTGCTCTATGTGGCAAGCATTACTGGAAACATCCTCATTGTGTTTTCTGTGACCACTGACCCTCACTTACACTCCCCCATGTACTTTCTACTGGCCAGTCTCTCCTTCATTGACTTAGGAGCCTGCTCTGTCACTTCTCCCAAGATGATTTATGACCTGTTCAGAAAGCGCAAAGTCATCTCCTTTGGAGGCTGCATCGCTCAAATCTTCTTCATCCACGTCGTTGGTGGTGTGGAGATGGTGCTGCTCATAGCCATGGCCTTTGACAGATATGTGGCCCTATGTAAGCCCCTCCACTATCTGACCATTATGAGCCCAAGAATGTGCCTTTCATTTCTGGCTGTTGCCTGGACCCTTGGTGTCAGTCACTCCCTGTTCCAACTGGCATTTCTTGTTAATTTAGCCTTCTGTGGCCCTAATGTGTTGGACAGCTTCTACTGTGACCTTCCTCGGCTTCTCAGACTAGCCTGTACCGACACCTACAGATTGCAGTTCATGGTCACTGTTAACAGTGGGTTTATCTGTGTGGGTACTTTCTTCATACTTCTAATCTCCTACGTCTTCATCCTGTTTACTGTTTGGAAACATTCCTCAGGTGGTTCATCCAAGGCCCTTTCCACTCTTTCAGCTCACAGCACAGTGGTCCTTTTGTTCTTTGGTCCACCCATGTTTGTGTATACACGGCCACACCCTAATTCACAGATGGACAAGTTTCTGGCTATTTTTGATGCAGTTCTCACTCCTTTTCTGAATCCAGTTGTCTATACATTCAGGAATAAGGAGATGAAGGCAGCAATAAAGAGAGTATGCAAACAGCTAGTGATTTACAAGAGGATCTCATAAATGATATAATAAGCCCTTCTC

In [6]:
data = [(x,y) for x,y in dataset.items()]

In [9]:
def build_fasta(data, filename):
    fasta_sequences = []dataset
    for i, trx in enumerate(data):
        trx_id, seq = trx[0], trx[1]
        header = f'>{i} {trx_id}'
        fasta_sequence = header + '\n' + seq.upper()
        fasta_sequences.append(fasta_sequence)
    
    fasta_text = '\n'.join(fasta_sequences)
    
    with open(filename + '.fasta', 'w') as fasta_file:
        fasta_file.write(fasta_text)

In [10]:
build_fasta(data, 'mydb')

In [5]:
psicube_trx = dict()
with open('psicube_trx.txt', 'r') as f:
    reader = csv.reader(f, delimiter='\t')
    for n, row in enumerate(reader):
        if n == 0:
            cols = row
            continue
        line = dict(zip(cols, row))
        pseudo_trx = line["ID"].split(".")[0]
        if "Parent transcript" not in line:
            continue
        parent_trx = line["Parent transcript"].split(".")[0]
        psicube_trx[pseudo_trx] = parent_trx

In [6]:
len(psicube_trx)

10370

In [7]:
cluster_dict = {}
current_cluster = ''
with open('output.txt', 'r') as f:
    for line in f:
        if line.startswith('#'):
            current_cluster = line.strip()
            cluster_dict[current_cluster] = []
        else:
            _, enst = line.strip().split()
            cluster_dict[current_cluster].append(enst)

In [8]:
len(cluster_dict)

16353

In [16]:
count = 0
check = 0
for cluster, trxps in cluster_dict.items():
    for trx in trxps:
        if trx in psicube_trx and psicube_trx[trx] in dataset:
            check += 1
            parent = psicube_trx[trx]
            if parent in trxps:
                count += 1
        else:
            continue

In [17]:
count

4974

In [18]:
check

6599

In [20]:
strings = [x for x in cluster_dict.keys()]
random.seed(5)
random.shuffle(strings)
bins = np.array_split(strings, 5)
split_dict = dict()
for idx, bin_ in enumerate(bins):
    for cluster in bin_:
        for trx in cluster_dict[cluster]:
            split_dict[trx] = idx

In [21]:
from collections import Counter

In [22]:
Counter([x for x in split_dict.values()])

Counter({0: 13488, 1: 13018, 2: 12822, 3: 13678, 4: 13290})

In [66]:
pickle.dump(split_dict, open('data/split_dict.pkl', 'wb'))