In [1]:
try:
    import Bio
except:
    #for rhofold+ #####################
    !pip install biopython
    !pip install ml-collections
    !pip install python-box
    !pip install dm-tree
    !pip install openmm[cuda12]





from copy import deepcopy

import pandas as pd
from Bio.PDB import Atom, Model, Chain, Residue, Structure, PDBParser
from Bio import SeqIO
import os, sys
import re
import numpy as np

import matplotlib
import matplotlib.pyplot as plt

import subprocess

print('IMPORT OK !!!!')

IMPORT OK !!!!


In [2]:
PYTHON = sys.executable
print('PYTHON',PYTHON)

USALIGN = \
'kaggle/working//USalign'
#'<your us align path>/USalign'

os.system('cp kaggle/input/usalign/USalign /kaggle/working/')
os.system('sudo chmod u+x kaggle/working//USalign')

RHONET_DIR=\
'kaggle/input/data-for-demo-for-rhofold-plus-with-kaggle-msa/RhoFold-main'

DATA_KAGGLE_DIR = 'kaggle/input/rnafold'
SEQ_DF = pd.read_csv(f'{DATA_KAGGLE_DIR}/train_sequences.csv')
LABEL_DF = pd.read_csv(f'{DATA_KAGGLE_DIR}/train_labels.csv')
LABEL_DF['target_id'] = LABEL_DF['ID'].apply(lambda x: '_'.join(x.split('_')[:-1]))


# helper ----
class dotdict(dict):
	__setattr__ = dict.__setitem__
	__delattr__ = dict.__delitem__

	def __getattr__(self, name):
		try:
			return self[name]
		except KeyError:
			raise AttributeError(name)

# visualisation helper ----
def set_aspect_equal(ax):
	x_limits = ax.get_xlim()
	y_limits = ax.get_ylim()
	z_limits = ax.get_zlim()

	# Compute the mean of each axis
	x_middle = np.mean(x_limits)
	y_middle = np.mean(y_limits)
	z_middle = np.mean(z_limits)

	# Compute the max range across all axes
	max_range = max(x_limits[1] - x_limits[0],
					y_limits[1] - y_limits[0],
					z_limits[1] - z_limits[0]) / 2.0

	# Set the new limits to ensure equal scaling
	ax.set_xlim(x_middle - max_range, x_middle + max_range)
	ax.set_ylim(y_middle - max_range, y_middle + max_range)
	ax.set_zlim(z_middle - max_range, z_middle + max_range)




# xyz df helper --------------------
def get_truth_df(target_id):
    truth_df = LABEL_DF[LABEL_DF['target_id'] == target_id]
    truth_df = truth_df.reset_index(drop=True)
    return truth_df

def parse_pdb_to_df(pdb_file, target_id):
    parser = PDBParser()
    structure = parser.get_structure('', pdb_file)

    df = []
    for model in structure:
        for chain in model:
            print(chain)
            chain_data = []
            for residue in chain:
                # print(residue)
                if residue.get_resname() in ['A', 'U', 'G', 'C']:
                    # Check if the residue has a C1' atom
                    if 'C1\'' in residue:
                        atom = residue['C1\'']
                        xyz = atom.get_coord()
                        resname = residue.get_resname()
                        resid = residue.get_id()[1]

                        #todo detect discontinous: resid = prev_resid+1
                        #ID	resname	resid	x_1	y_1	z_1
                        chain_data.append(dict(
                            ID = target_id+'_'+str(resid),
                            resname=resname,
                            resid=resid,
                            x_1=xyz[0],
                            y_1=xyz[1],
                            z_1=xyz[2],
                        ))
                        ##print(f"Residue {resname} {resid}, Atom: {atom.get_name()}, xyz: {xyz}")

            if len(chain_data)!=0:
                chain_df = pd.DataFrame(chain_data)
                df.append(chain_df)
                ##print(chain_df)
    return df

# usalign helper --------------------
def write_target_line(
    atom_name, atom_serial, residue_name, chain_id, residue_num, x_coord, y_coord, z_coord, occupancy=1.0, b_factor=0.0, atom_type='P'
):
    """
    Writes a single line of PDB format based on provided atom information.

    Args:
        atom_name (str): Name of the atom (e.g., "N", "CA").
        atom_serial (int): Atom serial number.
        residue_name (str): Residue name (e.g., "ALA").
        chain_id (str): Chain identifier.
        residue_num (int): Residue number.
        x_coord (float): X coordinate.
        y_coord (float): Y coordinate.
        z_coord (float): Z coordinate.
        occupancy (float, optional): Occupancy value (default: 1.0).
        b_factor (float, optional): B-factor value (default: 0.0).

    Returns:
        str: A single line of PDB string.
    """
    return f'ATOM  {atom_serial:>5d}  {atom_name:<5s} {residue_name:<3s} {residue_num:>3d}    {x_coord:>8.3f}{y_coord:>8.3f}{z_coord:>8.3f}{occupancy:>6.2f}{b_factor:>6.2f}           {atom_type}\n'

def write_xyz_to_pdb(df, pdb_file, xyz_id = 1):
    resolved_cnt = 0
    with open(pdb_file, 'w') as target_file:
        for _, row in df.iterrows():
            x_coord = row[f'x_{xyz_id}']
            y_coord = row[f'y_{xyz_id}']
            z_coord = row[f'z_{xyz_id}']

            if x_coord > -1e17 and y_coord > -1e17 and z_coord > -1e17:
                resolved_cnt += 1
                target_line = write_target_line(
                    atom_name="C1'",
                    atom_serial=int(row['resid']),
                    residue_name=row['resname'],
                    chain_id='0',
                    residue_num=int(row['resid']),
                    x_coord=x_coord,
                    y_coord=y_coord,
                    z_coord=z_coord,
                    atom_type='C',
                )
                target_file.write(target_line)
    return resolved_cnt

def parse_usalign_for_tm_score(output):
    # Extract TM-score based on length of reference structure (second)
    tm_score_match = re.findall(r'TM-score=\s+([\d.]+)', output)
    if len(tm_score_match) < 2:
        raise ValueError(f"Expected at least 2 TM-score matches, but found {len(tm_score_match)}. Output:\n{output}")
    if not tm_score_match:
        raise ValueError('No TM score found')
    return float(tm_score_match)

def parse_usalign_for_transform(output):
    # Locate the rotation matrix section
    matrix_lines = []
    found_matrix = False

    for line in output.splitlines():
        if "The rotation matrix to rotate Structure_1 to Structure_2" in line:
            found_matrix = True
        elif found_matrix and re.match(r'^\d+\s+[-\d.]+\s+[-\d.]+\s+[-\d.]+\s+[-\d.]+$', line):
            matrix_lines.append(line)
        elif found_matrix and not line.strip():
            break  # Stop parsing if an empty line is encountered after the matrix

    # Parse the rotation matrix values
    rotation_matrix = []
    for line in matrix_lines:
        parts = line.split()
        row_values = list(map(float, parts[1:]))  # Skip the first column (index)
        rotation_matrix.append(row_values)

    return np.array(rotation_matrix)

def call_usalign(predict_df, truth_df, verbose=1):
    truth_pdb = '~truth.pdb'
    predict_pdb = '~predict.pdb'
    write_xyz_to_pdb(predict_df, predict_pdb, xyz_id=1)
    write_xyz_to_pdb(truth_df, truth_pdb, xyz_id=1)

    command = f'{USALIGN} {predict_pdb} {truth_pdb} -atom " C1\'" -m -'
    output = os.popen(command).read()
    if verbose==1:
        print(output)
    tm_score = parse_usalign_for_tm_score(output)
    transform = parse_usalign_for_transform(output)
    return tm_score, transform


# msa helper --------------------
def read_msa(msa_file):
    f = open(msa_file, 'r')
    line = f.readlines()

    msa = []
    for i in range(0, len(line),2):
        m = dotdict(
            comment =line[i],
            seqence =line[i+1],
        )
        assert(m.comment[0]=='>')
        msa.append(m)
    return msa


def write_msa(msa_file, msa):
    line=[]
    for m in msa:
        line .append(m.comment)
        line .append(m.seqence)

    f = open(msa_file, 'wt')
    f.writelines(line)
    return msa
 
def msa_to_rhonet_file(msa_file, num_msa=5, out_dir='',target_id='xxx'):
    msa = read_msa(msa_file)
    msa0 = deepcopy(msa[0])
    msa0.comment =f'>{target_id}\n'
    msa0 = [msa0]

    a3m_file = f'{out_dir}/{target_id}.a3m'
    fasta_file = f'{out_dir}/{target_id}.fasta'
    os.makedirs(out_dir, exist_ok=True)

    write_msa(fasta_file, msa0)
    write_msa(a3m_file, msa[:num_msa])

print('HELPER OK!!!')

PYTHON c:\Users\admin\AppData\Local\Programs\Python\Python310\python.exe
HELPER OK!!!


In [3]:
out_dir   ='kaggle/working/'
target_id = '1EIY_C'.upper()
sequence  = 'GCCGAGGUAGCUCAGUUGGUAGAGCAUGCGACUGAAAAUCGCAGUGUCCGCGGUUCGAUUCCGCGCCUCGGCACCA'
print('len(sequence):',len(sequence))


#1. prepare input
msa_file = f'{DATA_KAGGLE_DIR}/MSA/1EIY_C.MSA.fasta'
msa_to_rhonet_file(msa_file, num_msa=5, out_dir=out_dir,target_id=target_id)

# 2. Use absolute paths to avoid cwd confusion
fasta_path = os.path.abspath(os.path.join(out_dir, f'{target_id}.fasta'))
a3m_path   = os.path.abspath(os.path.join(out_dir, f'{target_id}.a3m'))
out_path   = os.path.abspath(out_dir)

# Debug check
print("FASTA:", fasta_path, os.path.exists(fasta_path))
print("A3M:", a3m_path, os.path.exists(a3m_path))

abs_fasta = os.path.abspath(os.path.join(out_dir, f'{target_id}.fasta'))
abs_a3m = os.path.abspath(os.path.join(out_dir, f'{target_id}.a3m'))

cmd = [
    PYTHON,
    'inference.py',
    '--input_fas', abs_fasta,
    '--input_a3m', abs_a3m,
    '--output_dir', os.path.abspath(out_dir),
    '--ckpt', './pretrained/model_20221010_params.pt'
]

# 3. Run inference inside RHONET_DIR
result = subprocess.run(
    cmd,
    cwd=RHONET_DIR,
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True
)

# 4. Show output
print("STDOUT:\n", result.stdout)
print("STDERR:\n", result.stderr)

len(sequence): 76


FileNotFoundError: [Errno 2] No such file or directory: 'kaggle/input/rnafold/MSA/1EIY_C.MSA.fasta'

In [None]:
#visualise prediction and compute tm score

predict_relax_df = parse_pdb_to_df(f'{out_dir}/relaxed_1000_model.pdb', target_id)
predict_unrelax_df = parse_pdb_to_df(f'{out_dir}/unrelaxed_model.pdb', target_id)

assert len(predict_relax_df)==1
assert len(predict_unrelax_df)==1
predict_relax_df = predict_relax_df[0]
predict_unrelax_df = predict_unrelax_df[0]

print(predict_relax_df)
print(predict_unrelax_df)

truth_df = get_truth_df(target_id)
print(truth_df)

tm_score_relax, transform_relax = call_usalign(predict_relax_df, truth_df, verbose=1)
tm_score_unrelax, transform_unrelax= call_usalign(predict_unrelax_df, truth_df, verbose=0)

print('tm_score_relax', tm_score_relax)
print('tm_score_unrelax', tm_score_unrelax)
print('transform_relax\n', transform_relax)
print('transform_unrelax\n', transform_unrelax)
zz=0

if 1:
    COLOR = ['red', 'blue', 'green', 'black', 'yellow', 'cyan', 'magenta']
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')
    # ax.clear()

    #unrelax
    coord = predict_unrelax_df[['x_1', 'y_1', 'z_1']].to_numpy().astype('float32')
    coord = coord@transform_unrelax[:,1:].T + transform_unrelax[:,[0]].T
    x, y, z = coord[:, 0], coord[:, 1], coord[:, 2]
    ax.scatter(x, y, z, c='red', s=30, alpha=1)
    ax.plot(x, y, z, color='red', linewidth=1, alpha=1, label=f'unrelax (tm:{tm_score_unrelax:0.3f})')


    #relax
    coord = predict_relax_df[['x_1', 'y_1', 'z_1']].to_numpy().astype('float32')
    coord = coord@transform_relax[:,1:].T + transform_relax[:,[0]].T
    x, y, z = coord[:, 0], coord[:, 1], coord[:, 2]
    ax.scatter(x, y, z, c='orange', s=30, alpha=1)
    ax.plot(x, y, z, color='orange', linewidth=1, alpha=1, label=f'relax (tm:{tm_score_relax:0.3f})')

    # truth
    truth = truth_df[['x_1', 'y_1', 'z_1']].to_numpy().astype('float32')
    x, y, z = truth[:, 0], truth[:, 1], truth[:, 2]
    ax.scatter(x, y, z, c='black', s=30, alpha=1)
    ax.plot(x, y, z, color='black', linewidth=1, alpha=1, label=f'truth')

    set_aspect_equal(ax)
    plt.legend()
    plt.show()
    # plt.waitforbuttonpress()
    plt.close()

<Chain id=A>
<Chain id=0>
Relaxed model exists? True
Unrelaxed model exists? True
           ID resname  resid        x_1     y_1     z_1
0    1A1T_B_1       G      1  -0.018000  12.493   3.507
1    1A1T_B_2       G      2  -3.082000   8.008   4.823
2    1A1T_B_3       A      3  -6.739000   3.337   4.565
3    1A1T_B_4       C      4  -7.467000  -1.309   1.838
4    1A1T_B_5       U      5  -5.626000  -5.275  -1.976
5    1A1T_B_6       A      6  -2.183000  -7.754  -4.991
6    1A1T_B_7       G      7   2.981000  -9.737  -6.649
7    1A1T_B_8       C      8   8.807000  -9.341  -5.615
8    1A1T_B_9       G      9  11.720000  -9.878  -1.140
9   1A1T_B_10       G     10  19.273001  -8.522   1.211
10  1A1T_B_11       A     11  10.909000  -6.152   4.310
11  1A1T_B_12       G     12  10.971000   1.215   8.197
12  1A1T_B_13       G     13  10.225000  -0.906   1.053
13  1A1T_B_14       C     14   9.389000  -1.175  -4.991
14  1A1T_B_15       U     15   5.456000  -1.349  -8.445
15  1A1T_B_16       A 

ValueError: Expected at least 2 TM-score matches, but found 0. Output:
