<a href="https://colab.research.google.com/github/andanil/Z-RNA-prediction-tool/blob/colab/Z_RNA_prediction_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install dependencies and necessary files

In [None]:
# Install pysam (used below to open the reference FASTA) and condacolab
# (used in the next cell to make conda available in this runtime).
! pip install pysam
! pip install -q condacolab

In [None]:
import condacolab
# Sets up conda so that the `conda install` cell that follows can run.
# NOTE(review): condacolab presumably restarts the runtime at this point —
# confirm, and continue from the next cell once it is back.
condacolab.install()

In [None]:
# Install ViennaRNA from the bioconda channel (-y: no confirmation prompt).
# NOTE(review): presumably the folding backend behind get_rna_struct — confirm.
! conda install -y -c bioconda viennarna

In [None]:
# Download the hg38 reference genome from UCSC and decompress it in place;
# the resulting hg38.fa is the file opened with pysam further down.
! wget http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz
! gzip -d hg38.fa.gz

In [None]:
# Clone the pipeline repository. The imports below expect the checkout at
# /content/Z-RNA-prediction-tool (that path is added to sys.path).
! git clone https://github.com/andanil/Z-RNA-prediction-tool.git

In [None]:
from google.colab import files

import pysam
import numpy as np
import pandas as pd
from tqdm import tqdm

import sys
sys.path.insert(0, "/content/Z-RNA-prediction-tool")

In [None]:
from src.config import ZHUNT_FILE

# Change the permissions of the file at ZHUNT_FILE so the pipeline can run it
# (presumably the Z-Hunt executable — confirm).
# NOTE(review): mode 777 also makes the file world-writable; `chmod +x` would
# likely be sufficient.
! chmod 777 {ZHUNT_FILE}

In [None]:
# Open the decompressed reference genome. `genome` is read as a module-level
# global by run_search and passed down to the sequence-extraction helpers.
genome = pysam.FastaFile('hg38.fa')

In [None]:
from zrna_pipeline.rna.rna_struct import RNA_structure
from zrna_pipeline.utils import from_chrom_to_seq_coord, get_seq_around_pos, reverse_region
from zrna_pipeline.rna_struct_analyser import get_rna_struct

## Read data

In [None]:
# upload file with Z-DNA regions
uploaded = files.upload()
if not uploaded:
    # Nothing came back from the upload dialog (e.g. it was cancelled); fail
    # with a clear message instead of an IndexError on the next line.
    raise ValueError('No file was uploaded; re-run this cell and select the file with Z-DNA regions.')
# Only the first uploaded file is used.
filename = next(iter(uploaded))

In [None]:
# Read the headerless, tab-separated region file, keeping only its first three
# columns. Without `usecols`, a file with more than three columns would make
# pandas treat the leading column(s) as the index and shift the
# 'chr'/'start'/'end' names onto the wrong fields.
zdna_predicted_regions = pd.read_csv(
    filename,
    sep='\t',
    header=None,
    usecols=[0, 1, 2],
    names=['chr', 'start', 'end'],
)

## Find regions that fall in stems

### Define some useful functions

In [None]:
def overlap(region_1, region_2):
    """Return how many positions two closed intervals have in common.

    Each region is indexed as ``(start, end)`` with both ends inclusive.
    Regions that do not intersect give 0.
    """
    left = max(region_1[0], region_2[0])
    right = min(region_1[1], region_2[1])
    shared = right - left + 1
    return shared if shared > 0 else 0

def is_overlap_n(zdna_region, stem_coord, n_bp):
    """Measure how much of ``zdna_region`` lies on one side of a stem.

    Args:
        zdna_region: ``(start, end)`` of the region, inclusive.
        stem_coord: sequence whose items 0-1 and 2-3 are the inclusive bounds
            of the two sides of the stem; an optional fifth item is an
            iterable of bulge objects exposing ``unpaired``, ``opposite`` and
            ``len_unpaired()``.
        n_bp: minimum overlap required for the first returned flag to be True.

    Returns:
        ``(overlap >= n_bp, overlap, n_unpaired)``. Without bulge data
        ``n_unpaired`` is 0.
    """
    # A region starting past the end of the first side is compared against
    # the second side; otherwise against the first.
    if zdna_region[0] > stem_coord[1]:
        side = [stem_coord[2], stem_coord[3]]
    else:
        side = [stem_coord[0], stem_coord[1]]
    paired = overlap(zdna_region, side)

    unpaired = 0
    if len(stem_coord) != 4:
        for bulge in stem_coord[4]:
            inside_bulge = overlap(zdna_region, bulge.unpaired)
            if inside_bulge > 0:
                # Positions inside the bulge are moved from the paired count
                # to the unpaired count.
                unpaired += inside_bulge
                paired -= inside_bulge
            elif overlap(zdna_region, bulge.opposite) == 2:
                # NOTE(review): an overlap of exactly 2 with `opposite`
                # presumably means the region covers both positions facing
                # the bulge — confirm against the bulge definition.
                unpaired += bulge.len_unpaired()
    return (paired >= n_bp, paired, unpaired)

def is_in_stem(stems_coord, zdna_region):
    """Return the best ``is_overlap_n`` result over all stems touching the region.

    Stems are skipped when the region lies entirely before them, entirely
    after them, or entirely between their two sides. The remaining stems are
    scored with a minimum overlap of 6 and the result with the largest
    overlap wins. ``(False, -1, -1)`` is returned when no stem qualifies.
    """
    best = (False, -1, -1)
    for stem in stems_coord:
        # Region ends before the stem starts or begins after it ends.
        if zdna_region[1] < stem[0] or zdna_region[0] > stem[3]:
            continue
        # Region sits strictly between the two sides of the stem.
        if zdna_region[0] > stem[1] and zdna_region[1] < stem[2]:
            continue
        candidate = is_overlap_n(zdna_region, stem, 6)
        if candidate[1] > best[1]:
            best = candidate
    return best

def analyse_struct(genome, window, chrom, zdna_region, union=False):
    """Fold the sequence around a region and test whether the region is in a stem.

    Args:
        genome: genome handle passed through to ``get_seq_around_pos``.
        window: window description passed to ``get_seq_around_pos``; its
            first item is also used for the coordinate conversion.
        chrom: chromosome name.
        zdna_region: ``(start, end)`` in chromosome coordinates.
        union: use ``get_stems_union()`` instead of ``get_stems_coord()``.

    Returns:
        ``((hit, n_overlap, n_unpaired, strand), rna_struct)`` where strand
        is ``'+'`` when the first check succeeds and ``'-'`` when the result
        comes from the second check on the reversed region.
    """
    def stems_of(struct):
        return struct.get_stems_union() if union else struct.get_stems_coord()

    middle = int((zdna_region[0] + zdna_region[1]) / 2)
    seq = get_seq_around_pos(genome, chrom, middle, window).upper()
    rna_struct = get_rna_struct(seq, -1)
    region_in_struct = from_chrom_to_seq_coord(zdna_region, middle, window[0], len(rna_struct))

    forward = is_in_stem(stems_of(rna_struct), region_in_struct)
    if forward[0]:
        return (*forward, '+'), rna_struct

    # No hit yet: switch the structure's graph type and retry with the region
    # mirrored along the sequence.
    # NOTE(review): presumably this is the opposite-strand fold — confirm
    # what reset_graph_type(False) selects.
    rna_struct.reset_graph_type(False)
    backward_stems = stems_of(rna_struct)
    reversed_region = reverse_region(region_in_struct, len(rna_struct))
    return (*is_in_stem(backward_stems, reversed_region), '-'), rna_struct

def iterate_through_windows(genome, windows, chrom, zdna_region, union=False):
    """Run ``analyse_struct`` for every window and return the best hit.

    Args:
        genome: genome handle forwarded to ``analyse_struct``.
        windows: iterable of window descriptions to try.
        chrom: chromosome name.
        zdna_region: ``(start, end)`` in chromosome coordinates.
        union: forwarded to ``analyse_struct``.

    Returns:
        On success, ``(hit, n_overlap, n_unpaired, strand, z_score,
        struct_len, n_hits)`` for the greatest structure among the windows
        that produced a hit. Otherwise the 4-tuple result of the last window
        analysed, or ``(False, -1, -1, '-')`` when ``windows`` is empty.
    """
    # Fallback for an empty ``windows`` (previously an UnboundLocalError).
    # It has the same shape as a failed analyse_struct result.
    result = (False, -1, -1, '-')
    rna_structs = []
    results = []
    for window in windows:
        result, rna_struct = analyse_struct(genome, window, chrom, zdna_region, union)
        if result[0]:
            rna_structs.append(rna_struct)
            results.append(result)
    if not rna_structs:
        return result
    # Relies on the ordering and equality defined by the structure class.
    # NOTE(review): presumably ordered by z-score — confirm.
    index = rna_structs.index(max(rna_structs))
    best = rna_structs[index]
    return (*results[index], best.z_score, len(best), len(rna_structs))

def is_in_any_stem(genome, windows, chrom, zdna_region):
    """Search all windows for a stem hit, retrying with stem unions if needed.

    The first pass calls ``iterate_through_windows`` with its default
    ``union`` setting; only when that finds nothing is a second pass made
    with ``union=True``. Returns whatever the deciding pass returned.
    """
    plain = iterate_through_windows(genome, windows, chrom, zdna_region)
    if plain[0]:
        return plain
    return iterate_through_windows(genome, windows, chrom, zdna_region, union=True)

In [None]:
def run_search(data, k):
    """Search one batch of regions for stem hits and write them to a file.

    Every row of ``data`` (columns ``chr``, ``start``, ``end``) is checked
    with ``is_in_any_stem`` using the module-level ``genome`` and three fixed
    windows. Rows with a hit are written, tab-separated, to
    ``in_stems_{k}.bed``.

    Args:
        data: DataFrame holding the batch of regions.
        k: batch number used in the output file name.

    Output columns: ``chr``, ``start``, ``end``, ``overlap``, ``zh-score``
    (the ``z_score`` of the selected structure), ``not paired``, ``window``
    (length of the selected structure), ``strand``, ``n_structs`` (number of
    windows that produced a hit).
    """
    # NOTE(review): `genome` is a handle opened in the parent process; when
    # this runs in forked workers they may share the underlying file
    # descriptor — confirm this is safe or open a handle per worker.
    windows = [[150, 150], [250, 250], [500, 500]]
    rows = []
    # The progress total is the real batch size; the last batch can be
    # shorter than the nominal 5000 rows.
    for _, row in tqdm(data.iterrows(), total=len(data)):
        result = is_in_any_stem(genome, windows, row['chr'], (row['start'], row['end']))
        if not result[0]:
            continue
        rows.append({'chr': row['chr'],
                     'start': row['start'],
                     'end': row['end'],
                     'overlap': result[1],
                     'zh-score': result[4],
                     'not paired': result[2],
                     'window': result[5],
                     'strand': result[3],
                     'n_structs': result[6]})
    # Build the frame once instead of concatenating per hit, which was
    # quadratic in the number of hits.
    zrna_in_stems = pd.DataFrame(rows)
    zrna_in_stems.to_csv(f'in_stems_{k}.bed', sep='\t', index=False)

### Searching

In [None]:
from multiprocessing import Process

def run_parallel(function, data_batches, n):
    """Run ``function`` on each batch in its own process and wait for all.

    The batch at position ``i`` is handled by calling
    ``function(batch, n + i)`` in a separate ``Process``. The call returns
    once every process has been joined.
    """
    workers = [
        Process(target=function, args=(batch, n + offset))
        for offset, batch in enumerate(data_batches)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

In [None]:
# Number of regions handed to a single worker process.
N = 5000

# Split the regions read above into consecutive batches of at most N rows.
# The frame is named `zdna_predicted_regions`; the previous `zdna_pred` was
# never defined and raised a NameError.
data_list = [
    zdna_predicted_regions.iloc[start:start + N]
    for start in range(0, len(zdna_predicted_regions), N)
]

In [None]:
# Process the batches in groups of `n` parallel workers. run_parallel joins
# its processes, so one group finishes before the next one starts. The third
# argument is the index of the group's first batch, which keeps the output
# file numbers unique across groups.
n = 10
for first in range(0, len(data_list), n):
    run_parallel(run_search, data_list[first:first + n], first)