In [8]:
import dataclasses as dataclass
import os

import jax
import jax.numpy as jnp
import pandas as pd

from jax import jit, vmap
from shapetune import _partition, _log_likelihood

# Bare references to two imported names; evaluating them has no effect.
# NOTE(review): presumably here to stop linters/formatters from flagging the
# imports as unused — confirm, or drop these lines and the unused imports.
_log_likelihood
jnp

############################################################
####### parse the m1psi SHAPE reactivity data   ############
############################################################

# NOTE(review): this flag is not read by any cell visible in this notebook —
# confirm whether it is still needed.
not_registered = True

def _load_shape_tables(tsv_files, names, prefix):
    """Read the TSV files whose sample name starts with ``prefix``.

    Args:
        tsv_files: File names (not paths) located in ``../data``.
        names: Sample names, parallel to ``tsv_files``.
        prefix: Leading characters a sample name must have to be loaded.

    Returns:
        Dict mapping sample name to the DataFrame parsed from its
        tab-separated file, in the order of ``tsv_files``.
    """
    return {
        name: pd.read_csv(f"../data/{tsv}", sep="\t")
        for tsv, name in zip(tsv_files, names)
        if name.startswith(prefix)
    }


def _rna_sequences(tables):
    """Join each table's "Sequence" column into one string with T replaced by U."""
    return {
        name: "".join(df["Sequence"].values).replace("T", "U")
        for name, df in tables.items()
    }


def _reactivities(tables):
    """Extract each table's "Normalized Reactivity" column as a plain list."""
    return {
        name: df["Normalized Reactivity"].values.tolist()
        for name, df in tables.items()
    }


# `dataclass` is the `dataclasses` module imported under that alias, hence
# the `dataclass.dataclass` spelling of the decorator.
@dataclass.dataclass
class SHAPE_DATA:
    """SHAPE reactivity tables read from ``../data/*.tsv``, grouped by name prefix.

    The class attributes below carry no type annotations, so they are plain
    class attributes rather than dataclass fields: the files are read once,
    when the class body is executed, and shared by every instance.
    Instantiating the class (no arguments) runs ``__post_init__``, which
    derives per-sample sequences and reactivities for each group.
    """

    # os.listdir returns entries in arbitrary order; sort so that the key
    # order of every dict below (and of anything stacked from them later)
    # is the same on every run and every machine.
    tsv_list = sorted(f for f in os.listdir("../data") if f.endswith(".tsv"))
    # Sample name = file name up to the first "."
    NAMES = [f.split(".")[0] for f in tsv_list]
    # Raw tables, split by the prefix of the sample name.
    RAW_MOU = _load_shape_tables(tsv_list, NAMES, "moU")
    RAW_MPU = _load_shape_tables(tsv_list, NAMES, "mpU")
    RAW_U = _load_shape_tables(tsv_list, NAMES, "U")

    def __post_init__(self) -> None:
        """Populate the ``SEQUENCES_*`` and ``REACTIVITY_*`` dicts for each group."""
        # Concatenate the "Sequence" column of each dataframe (T -> U).
        self.SEQUENCES_MOU = _rna_sequences(self.RAW_MOU)
        self.SEQUENCES_MPU = _rna_sequences(self.RAW_MPU)
        self.SEQUENCES_U = _rna_sequences(self.RAW_U)
        # Per-position normalized reactivities as Python lists.
        self.REACTIVITY_MOU = _reactivities(self.RAW_MOU)
        self.REACTIVITY_MPU = _reactivities(self.RAW_MPU)
        self.REACTIVITY_U = _reactivities(self.RAW_U)

In [None]:
# Sequences of the samples whose name starts with "U", keyed by sample name.
luc_seqs = SHAPE_DATA().SEQUENCES_U

# Byte-encode every sequence, map the bytes through the partition module's
# translation table, and view the result as a uint8 array.
luc_seqarrs = {
    sample: jnp.frombuffer(
        rna.encode().translate(_partition.CONSTANTS.TRANSLATER),
        dtype=jnp.uint8,
    )
    for sample, rna in luc_seqs.items()
}

# Stack the per-sample arrays row-wise, in dict order.
# NOTE(review): vstack presumably requires all sequences to share one length — confirm.
luc_seqstack = jnp.vstack(tuple(luc_seqarrs.values()))

@jax.jit
def shapetune_probvec(seqarr):
    """Collapse the bpp matrix of ``seqarr`` into a vector.

    Runs ``_partition._partition_arr`` on the encoded sequence, takes the
    matrix returned by its ``bpp()`` method, adds that matrix to its own
    transpose and sums the result along axis 0.

    NOTE(review): presumably the result is the per-position probability of
    being paired — confirm against the ``_partition`` implementation.
    """
    partition = _partition._partition_arr(
        seqarr, _partition.CONSTANTS.STACKING_PAIRS
    )
    pair_probs = partition.bpp()
    symmetric = pair_probs + pair_probs.T
    return symmetric.sum(axis=0)

: 

In [6]:
# Run the partition computation for the single sample 'U_L1' and show the
# indices of the non-zero entries of its bpp() result.
# NOTE(review): the saved output lists no non-zero entries at all — confirm
# that this is expected and not a stale or broken result.
wow = _partition._partition_arr(luc_seqarrs['U_L1'], _partition.CONSTANTS.STACKING_PAIRS)
bpp  = wow.bpp()
bpp.nonzero()

(Array([], shape=(0,), dtype=int32), Array([], shape=(0,), dtype=int32))

In [7]:
# Display the sequence stored for sample 'U_L1'.
# NOTE(review): the saved output still contains 'T', although SEQUENCES_U is
# built with every 'T' replaced by 'U' — the output looks stale; re-run to confirm.
luc_seqs['U_L1']

'GGGAAATAAGAGAGAAAAGAAGAGTAAGAAGAAATATAAGAGCCACCATGGAGGACGCGAAAAACATCAAAAAAGGGCCTGCGCCTTTTTACCCTCTGGAGGACGGGACCGCGGGGGAGCAACTGCACAAAGCGATGAAAAGGTACGCGCTGGTACCTGGGACCATCGCGTTTACCGACGCGCACATCGAGGTAGACATCACCTACGCGGAGTACTTTGAGATGAGCGTAAGGCTGGCGGAGGCGATGAAAAGGTACGGGCTGAACACCAACCACAGGATCGTAGTATGCAGCGAGAACAGCCTGCAATTTTTTATGCCTGTACTGGGGGCGCTGTTTATCGGGGTAGCGGTAGCGCCTGCGAACGACATCTACAACGAGAGGGAGCTGCTGAACAGCATGGGGATCAGCCAACCTACCGTAGTATTTGTAAGCAAAAAAGGGCTGCAAAAAATCCTGAACGTACAAAAAAAACTGCCTATCATCCAAAAAATCATCATCATGGACAGCAAAACCGACTACCAAGGGTTTCAAAGCATGTACACCTTTGTAACCAGCCACCTGCCTCCTGGGTTTAACGAGTACGACTTTGTACCTGAGAGCTTTGACAGGGACAAAACCATCGCGCTGATCATGAACAGCAGCGGGAGCACCGGGCTGCCTAAAGGGGTAGCGCTGCCTCACAGGACCGCGTGCGTAAGGTTTAGCCACGCGAGGGACCCTATCTTTGGGAACCAAATCATCCCTGACACCGCGATCCTGAGCGTAGTACCTTTTCACCACGGGTTTGGGATGTTTACCACCCTGGGGTACCTGATCTGCGGGTTTAGGGTAGTACTGATGTACAGGTTTGAGGAGGAGCTGTTTCTGAGGAGCCTGCAAGACTACAAAATCCAAAGCGCGCTGCTGGTACCTACCCTGTTTAGCTTTTTTGCGAAAAGCACCCTGATCGACAAATACGACCTGAGCAACCTGCACGAGATCGCGAGCGGGGGGGCGC

In [None]:
# Take the `all_ensemble` attribute of `wow` and display its element-wise exponential.
# NOTE(review): presumably `all_ensemble` is stored in log space — confirm.
all_ensemble = wow.all_ensemble
jnp.exp(all_ensemble)

In [None]:
canon_reactivity = SHAPE_DATA().REACTIVITY_U

# Clamp reactivities at zero: any value that is not strictly positive is
# replaced by 0, so every list keeps its original length (nothing is dropped).
valid_reactivity = {
    sample: [value if value > 0 else 0 for value in values]
    for sample, values in canon_reactivity.items()
}

In [1]:
# Toy sequence used by the structure enumeration below.
sequence = "GCGCGC"


def can_pair(i, j):
    """Return True when positions ``i`` and ``j`` of ``sequence`` can pair.

    Only G-C and C-G pairings are allowed.
    """
    return (sequence[i], sequence[j]) in (("G", "C"), ("C", "G"))


# Cache of already-enumerated intervals, keyed by (i, j).
memo = {}


def get_structures(i, j):
    """Enumerate every set of non-crossing pairs within positions ``i..j``.

    Args:
        i: First position of the interval (inclusive).
        j: Last position of the interval (inclusive).

    Returns:
        A set of frozensets; each frozenset holds the ``(i, k)`` index pairs
        of one structure. An empty interval (``i > j``) yields only the
        empty structure. Results for non-empty intervals are cached in
        ``memo``, and the cached set object itself is returned.
    """
    key = (i, j)
    try:
        return memo[key]
    except KeyError:
        pass

    if i > j:
        return {frozenset()}

    # Case 1: position i stays unpaired, so reuse every structure of i+1..j.
    found = set(get_structures(i + 1, j))

    # Case 2: position i pairs with some k; combine the structures enclosed
    # by the pair with the structures that follow it.
    for k in range(i + 1, j + 1):
        if not can_pair(i, k):
            continue
        enclosed_options = get_structures(i + 1, k - 1)
        trailing_options = get_structures(k + 1, j)
        for enclosed in enclosed_options:
            for trailing in trailing_options:
                found.add(enclosed | trailing | {(i, k)})

    memo[key] = found
    return found


all_structures = get_structures(0, len(sequence) - 1)

# Print every structure as a sorted list of index pairs
for struct in all_structures:
    struct_list = sorted(struct)
    print(struct_list)

[(1, 4)]
[(1, 4), (2, 3)]
[(0, 5), (1, 4), (2, 3)]
[(0, 5), (1, 4)]
[(0, 1), (2, 3), (4, 5)]
[(0, 1), (3, 4)]
[(0, 1), (2, 3)]
[(0, 1)]
[(1, 2), (3, 4)]
[(1, 2), (4, 5)]
[(0, 3)]
[(2, 3), (4, 5)]
[(0, 3), (1, 2)]
[(1, 2)]
[(0, 1), (2, 5), (3, 4)]
[(0, 5), (3, 4)]
[(0, 3), (4, 5)]
[(3, 4)]
[(0, 5), (1, 2)]
[(0, 5), (1, 2), (3, 4)]
[(0, 5)]
[(0, 3), (1, 2), (4, 5)]
[(0, 5), (2, 3)]
[(4, 5)]
[(2, 5)]
[(0, 1), (4, 5)]
[(0, 1), (2, 5)]
[(2, 5), (3, 4)]
[]
[(2, 3)]
