In [None]:
import os
import pickle
import subprocess
import sys
import warnings
from collections import Counter

import pandas as pd
import torch
from torch.utils.data import DataLoader

from config import *


# Make the project root importable. ROOT is presumably provided by `from config import *`.
# The membership check keeps re-runs of this cell from appending duplicate entries.
_root = f"{ROOT}"
if _root not in sys.path:
    sys.path.append(_root)
# Mutation data directory. Set the MUTATION_DATA environment variable to point elsewhere
# instead of editing the notebook; the default is the original hard-coded location.
PATH = os.environ.get("MUTATION_DATA", "/home/malbranke/mutation_data")
# NOTE(review): this silences *all* warnings (including pandas' SettingWithCopyWarning),
# which can hide real problems in later cells.
warnings.filterwarnings("ignore")

In [None]:
from data_extraction import from_fasta_to_df, from_df_to_fasta, build_profiles
from ss_inference.data import SecondaryStructureRawDataset, collate_sequences
from ss_inference.model import NetSurfP2, ConvNet

In [3]:
# One row per dataset, indexed by dataset name (first column of the sheet).
meta_df = pd.read_excel(
    f"{PATH}/meta.xlsx",
    index_col=0,
)

In [4]:
import re

name = "BLAT_ECOLX_Ostermeier2014"
metadata = meta_df.loc[name]

# Fields of this dataset's row in meta.xlsx
family, name_dataset, uniprotid = (metadata[key] for key in ("family", "dataset", "uniprot"))
# Each run of word characters in the exp_columns cell is taken as one column name.
exp_columns = re.findall(r"\w\w*", metadata["exp_columns"])

In [6]:
# Convert the family FASTA into a table inside the family directory (chunks of 5000 records).
from_fasta_to_df(
    f"{PATH}/{family}",
    f"{PATH}/{family}/{family}.fasta",
    chunksize=5000,
)

Processing 8402 sequences ...

## Uniformize data: align mutant positions with the natural sequence

In [7]:
import numpy as np

def lcs(X, Y):
    """Longest common substring of ``X`` and ``Y``, where ``"-"`` in ``Y`` is a wildcard.

    Parameters
    ----------
    X, Y : str
        Sequences to compare. A ``"-"`` in ``Y`` matches any character of ``X``.

    Returns
    -------
    length : int
        Number of aligned positions in the longest run (wildcard positions included).
    bounds : tuple of int
        ``(x_start, x_end, y_start, y_end)``: 0-based, inclusive bounds of the run in
        ``X`` and ``Y``.
    L : numpy.ndarray, shape (len(X) + 1, len(Y) + 1)
        DP table; ``L[i, j]`` is the length of the run ending at ``X[i-1]`` / ``Y[j-1]``.

    Raises
    ------
    ValueError
        If no position of ``X`` matches any position of ``Y``.
    """
    m, n = len(X), len(Y)
    L = np.zeros((m + 1, n + 1))

    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if X[i - 1] == Y[j - 1] or Y[j - 1] == "-":
                L[i, j] = L[i - 1, j - 1] + 1
            # otherwise L[i, j] stays 0: the run is broken

    # argmax returns the first maximum in row-major order: ties go to the smallest i, then j.
    i, j = np.unravel_index(L.argmax(), L.shape)
    length = int(L[i, j])
    if length == 0:
        raise ValueError("lcs: X and Y have no common substring")
    # A run of `length` cells ending at (i, j) covers X[i-length:i] and Y[j-length:j].
    # Reading the bounds off directly replaces a backtrack that only stepped on exact
    # matches and therefore looped forever on runs containing a "-" wildcard.
    return length, (int(i - length), int(i - 1), int(j - length), int(j - 1)), L


def replace(seq, idx, end):
    """Return ``seq`` with the character at position ``idx`` replaced by ``end``.

    ``idx`` is passed through ``int`` first, so float or NumPy integer positions taken
    from a DataFrame row work. ``end`` is inserted as-is.
    """
    pos = int(idx)
    head, tail = seq[:pos], seq[pos + 1:]
    return "".join((head, end, tail))

In [9]:
nat_df = pd.read_csv(f"{PATH}/{family}/sequences.csv")
mut_df = pd.read_csv(f"{PATH}/{family}/{family}_{name_dataset}.csv", comment = "#", sep=";")

# Parse single-point mutant codes such as "H24Y": wild-type residue, 1-based position,
# substituted residue. Codes of length <= 2 (e.g. "WT") get origin/end None and idx -1.
mut_df["origin"] = mut_df.mutant.apply(lambda x : x[0] if len(x) > 2 else None)
mut_df["end"] = mut_df.mutant.apply(lambda x : x[-1] if len(x) > 2 else None)
mut_df["idx"] = mut_df.mutant.apply(lambda x : int(x[1:-1])-1 if len(x) > 2 else -1)


def reference_from_mutants(df, gap="-"):
    """Rebuild the reference sequence implied by the mutants' wild-type residues.

    Position ``i`` (0 .. ``df.idx.max()``) holds the ``origin`` residue of the first row
    with ``idx == i``. Positions that no mutant covers are filled with ``gap``. Rows
    without an origin (e.g. "WT") are ignored rather than silently turning into a gap.
    """
    has_origin = df.origin.notna() & (df.idx >= 0)
    first_origin = df[has_origin].groupby("idx")["origin"].first()
    return "".join(first_origin.get(i, gap) for i in range(int(df["idx"].max()) + 1))


seq_nat  = nat_df.loc[0].seq
seq_mut = reference_from_mutants(mut_df)

# Stretch of the assay sequence matching the natural sequence; "-" placeholders in
# seq_mut act as wildcards inside lcs.
_, (m_nat, M_nat, m_mut, M_mut), _ = lcs(seq_nat, seq_mut)

# Keep mutations inside the matched window (plus WT). Copy so the column writes below
# act on an independent frame, not on a slice of the original.
is_wt = mut_df.mutant == "WT"
mut_df = mut_df[mut_df.idx.between(m_mut, M_mut) | is_wt].copy()
# Re-number mutant positions from the assay frame into seq_nat's frame. WT keeps idx -1:
# shifting it could move it onto a real position.
is_wt = mut_df.mutant == "WT"
mut_df.loc[~is_wt, "idx"] += m_nat - m_mut

# seq_mut2 spans positions 0..M_nat, so drop its first m_nat characters before comparing
# against the matched window of seq_nat (otherwise the check only holds when m_nat == 0).
seq_mut2 = reference_from_mutants(mut_df)
print(seq_mut2[m_nat:] == seq_nat[m_nat:M_nat+1])

mut_df["name"] = mut_df.apply(lambda x : (x.origin + str(x.idx) + x.end) if x.mutant != "WT" else "WT", axis = 1)
mut_df["seq"] = seq_nat
mut_df["seq"] = mut_df.apply(lambda x : replace(x.seq, x.idx, x.end) if x.mutant != "WT" else x.seq, axis = 1)
mut_df["aligned_seq"] = mut_df["seq"]
mut_df = mut_df.set_index("name")

mut_df.to_csv(f"{PATH}/{family}/{name_dataset}_mutation_sequences.csv")

True


In [9]:
# Prefix with the dataset *name*, matching the prefix passed to build_profiles and the
# `{name_dataset}_hmm.pkl` file loaded below. `dataset` is not defined at this point on a
# fresh kernel (it is assigned later, to a SecondaryStructureRawDataset object).
from_df_to_fasta(f"{PATH}/{family}", f"{PATH}/{family}/{name_dataset}_mutation_sequences.csv", prefix = f"{name_dataset}_")

Processing 4783 sequences ...

In [10]:
# Export the natural sequences table as FASTA into the family directory.
from_df_to_fasta(
    f"{PATH}/{family}",
    f"{PATH}/{family}/sequences.csv",
)

Processing 8353 sequences ...

In [17]:
print("make HMM profile")
# Pass an argument list (no shell parsing of the interpolated path) and use check=True so
# the notebook stops if hhmake fails instead of carrying on with a stale profile.
subprocess.run(["hhmake", "-i", f"{PATH}/{family}/aligned.fasta", "-M", "100"], check=True)

make HMM profile


CompletedProcess(args='hhmake -i /home/malbranke/mutation_data/BLAT_ECOLX/aligned.fasta -M 100', returncode=0)

In [11]:
# Build per-sequence profiles for the mutant FASTA records (prefixed by dataset name).
build_profiles(
    f"{PATH}/{family}",
    prefix=f"{name_dataset}_",
)

100%|██████████| 4784/4784 [00:24<00:00, 192.84it/s]


In [10]:
dataset = SecondaryStructureRawDataset(f"{PATH}/{family}/{name_dataset}_hmm.pkl")

print("Secondary structure")
loader = DataLoader(dataset, batch_size = 1,
                        shuffle=False, drop_last=False, collate_fn = collate_sequences)

# Fall back to CPU when no GPU is available instead of failing on .to("cuda").
# NOTE(review): later cells (Matching/PatternMatching) may still assume CUDA — confirm.
device = "cuda" if torch.cuda.is_available() else "cpu"

model_ss3 = NetSurfP2(50, "")
model_ss3 = model_ss3.to(device)
# map_location lets a checkpoint saved on GPU load onto the device chosen above.
model_ss3.load_state_dict(torch.load(f"{DATA}/secondary_structure/lstm_50feats.h5", map_location=device))
print(model_ss3)

# Disabled: SS3 prediction for all sequences. The output path uses the dataset name
# (`name_dataset`), not the Dataset object bound to `dataset` above.
#_, _, ss3 = model_ss3.predict(loader)
#pickle.dump(ss3, open(f"{PATH}/{family}/{name_dataset}_ss3.pkl", "wb"))

Secondary structure
Model -50


In [11]:
from pattern_matching.utils import *

import biotite
import biotite.structure as struc

import biotite.database.rcsb as rcsb
import biotite.structure.io.mmtf as mmtf

In [12]:
# UniProt -> PDB cross-reference table; display the PDB entries for this protein.
pdb_uniprot = pd.read_csv(f"{DATA}/cross/uniprot_pdb.csv", index_col=0)
pdb_uniprot.loc[pdb_uniprot["uni"] == uniprotid, "pdb"].values

array(['1AXB', '1BT5', '1BTL', '1CK3', '1ERM', '1ERO', '1ERQ', '1ESU',
       '1FQG', '1JTD', '1JTG', '1JVJ', '1JWP', '1JWV', '1JWZ', '1LHY',
       '1LI0', '1LI9', '1M40', '1NXY', '1NY0', '1NYM', '1NYY', '1PZO',
       '1PZP', '1S0W', '1TEM', '1XPB', '1XXM', '1YT4', '1ZG4', '1ZG6',
       '2B5R', '2V1Z', '2V20', '3C7U', '3C7V', '3CMZ', '3DTM', '3JYI',
       '3TOI', '4DXB', '4DXC', '4GKU', '4IBR', '4IBX', '4ID4', '4MEZ',
       '4QY5', '4QY6', '4R4R', '4R4S', '4RVA', '4RX2', '4RX3', '4ZJ1',
       '4ZJ2', '4ZJ3', '5HVI', '5HW1', '5HW5', '5I52', '5I63', '5IQ8',
       '5KKF', '5KPU', '5NPO', '6APA', '6AYK', '6B2N'], dtype=object)

In [13]:
# For each PDB entry cross-referenced to this UniProt ID, align chain `chain` to seq_nat
# and collect the secondary-structure pattern of the entries with the longest alignment
# (all tied entries are kept).
longest, patterns = 0, []
chain = 'A'
for pdb in pdb_uniprot[pdb_uniprot.uni == uniprotid].pdb.values:
    file_name = rcsb.fetch(pdb, "mmtf", biotite.temp_dir())
    mmtf_file = mmtf.MMTFFile()
    mmtf_file.read(file_name)
    array = mmtf.get_structure(mmtf_file, model=1)
    # Amino-acid atoms of the entry, all chains
    protein = array[struc.filter_amino_acids(array)]

    # Chain ID of each amino-acid residue. The residue starts are indices into `protein`,
    # so they must index `protein.chain_id`; indexing `array.chain_id` misaligns as soon as
    # any non-amino-acid atom precedes the protein atoms.
    chain_id_per_res = protein.chain_id[struc.get_residue_starts(protein)]
    chain_idx = np.where(chain_id_per_res == chain)[0]
    # NOTE(review): indexes entity 0's full sequence by observed-residue position; assumes
    # the chain belongs to entity 0 and has no unobserved residues before those indexed.
    ss_seq = np.array(list(mmtf_file["entityList"][0]["sequence"]))[chain_idx]
    # Separate names so the seq_nat/assay offsets (m_nat, M_nat, ...) from the
    # uniformization cell are not overwritten.
    length, (_, _, m_pdb, M_pdb), _ = lcs(seq_nat, "".join(ss_seq))
    if length < longest:
        continue
    if length > longest:
        longest = length
        patterns = []
    sse = mmtf_file["secStructList"]
    # NOTE(review): assumes the first len(chain_id_per_res) entries of secStructList are the
    # amino-acid residues, in the same order as `protein` — confirm per entry.
    sse = sse[:chain_id_per_res.shape[0]][chain_id_per_res == chain]
    sse = np.array(sse[m_pdb: M_pdb + 1])
    sse = np.array([sec_struct_codes[code] for code in sse], dtype="U1")
    sse = np.array([dssp_to_abc[e] for e in sse], dtype="U1")
    sse = to_onehot([abc_codes[x] for x in sse], (None, 3))
    # Assuming to_onehot yields an (n, 3) one-hot matrix: at each class change, the row of
    # `dss` holds +1 for the class entered and -1 for the class left.
    dss = (sse[1:] - sse[:-1])
    cls = to_onehot(np.where(dss == -1)[1], (None, 3)).T
    bbox = np.array([np.where(dss == 1)[0], np.where(dss == -1)[0], *cls]).T
    # NOTE(review): `pat` lists the class left at each transition, so the final segment of
    # the window is not included — confirm this is intended.
    pat = np.argmax(bbox[:, 2:], 1)

    patterns.append(pat)


In [14]:
# Keep the most frequent non-empty secondary-structure pattern among the best-matching
# PDB entries, both as a string over "abc" and as a list of class indices.
c_patterns, n_patterns = [], []
for pat in patterns:
    char_pat = "".join(["abc"[x] for x in pat])
    if len(char_pat):
        c_patterns.append(char_pat)
        n_patterns.append(list(pat))
if not c_patterns:
    # Without this, n_pattern would be None and the regex construction below would fail.
    raise ValueError("No non-empty secondary-structure pattern was extracted from the PDB entries")
# most_common breaks ties by first occurrence, so the earliest most frequent pattern wins.
c_pattern, max_occ = Counter(c_patterns).most_common(1)[0]
n_pattern = n_patterns[c_patterns.index(c_pattern)]
c_patterns = [c_pattern]
n_patterns = [n_pattern]

In [15]:
from pattern_matching.utils import *
from pattern_matching.pattern import *
from pattern_matching.loss import *

from tqdm import tqdm

In [22]:
# NOTE(review): 263 is hard-coded; presumably the length of the sequences being matched — confirm.
size = 263
# Q[0, k, i, j] is 0 for every class k and every j > i, and -inf on/below the diagonal.
upper_tri = np.triu(np.ones((size + 1, size + 1), dtype=bool), k=1)
Q = np.where(upper_tri, 0.0, -np.inf)[np.newaxis, np.newaxis].repeat(3, axis=1)
e = size  # not read elsewhere in the visible cells

# One (class, None, None) element per segment of the selected pattern.
regex = [(state, None, None) for state in n_pattern]

seq_hmm = torch.load(f"{PATH}/{family}/hmm.pt")

matcher = PatternMatching(
    model_ss3,
    pattern=regex,
    Q=Q,
    seq_hmm=seq_hmm,
    ss_hmm=None,
    size=size,
    name=c_patterns,
)

In [17]:
hmm_pickle = f"{PATH}/{family}/{name_dataset}_hmm.pkl"
dataset = SecondaryStructureRawDataset(hmm_pickle)

print("Secondary structure")
# Mutant profiles in fixed order, batched by 200 for pattern matching.
loader = DataLoader(
    dataset,
    batch_size=200,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_sequences,
)

Secondary structure


In [None]:
ls, M, L = [],[],[]
len_pat = len(matcher.pattern)
# Iterate the loader directly: tqdm then reads len(loader) and shows a real progress bar
# with ETA (enumerate() hid the length, and the batch index was never used).
# NOTE(review): the collected tensors stay on their original device for the whole run;
# consider .cpu() if memory becomes a problem — confirm later cells accept CPU tensors.
for data in tqdm(loader):
    x = data[0].permute(0,2,1).float()
    torch.cuda.empty_cache()
    m = Matching(x)
    matcher(m)
    L.append(m.L)
    ls.append(m.ls)
    M.append(m.M)
    del m
ls= torch.cat(ls,0)
M = torch.cat(M,0)
L = torch.cat(L,0)

8it [32:49, 246.13s/it]