In [28]:
cd /home/dmoi/projects/foldtree2/

/home/dmoi/projects/foldtree2


# Protein Structure Analysis and Prediction Workflow

This notebook demonstrates a workflow for protein structure analysis and prediction using deep learning models.

## Overview

The workflow includes the following steps:

- **Loading and encoding protein structures** into graph representations.
- **Using a trained encoder and decoder** to generate sequence and contact predictions.
- **Visualizing protein structures and contact maps** for qualitative assessment.
- **Evaluating model performance** with metrics such as ROC, precision-recall, RMSD, and lDDT.
- **Utilities for handling PDB files**, extracting features, and plotting results.

## Technologies Used

The workflow leverages **PyTorch**, **PyTorch Geometric**, and custom modules for advanced protein modeling tasks.

In [29]:
#use autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [30]:
import sys
sys.path.append('/home/dmoi/projects/foldtree2')
#read the afdb clusters file
import pandas as pd
import numpy as np
import glob
import os
#autoreload
import pickle
from src import AFDB_tools
import toytree
import tqdm
from matplotlib import pyplot as plt
import torch
from src import pdbgraph
converter = pdbgraph.PDB2PyG( aapropcsv = 'config/aaindex1.csv',)
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')


In [31]:
modelname = 'hetero_500diagplddt'
modeldir = './models/'
with open( modeldir + modelname + '.pkl', 'rb') as f:
	encoder, decoder = pickle.load(f)
print('Loaded model:', modelname)
print('Encoder:', encoder)
print('Decoder:', decoder)
encoder.eval()
decoder.eval()
encoder.to(device)
encoder.device = device
decoder.to(device)

Loaded model: hetero_500diagplddt
Encoder: mk1_Encoder(
  (convs): ModuleList(
    (0-1): 2 x ModuleDict(
      (res_contactPoints_res): TransformerConv(100, 100, heads=5)
      (res_hbond_res): TransformerConv(100, 100, heads=5)
    )
  )
  (norms): ModuleList(
    (0-1): 2 x GraphNorm(100)
  )
  (bn): BatchNorm1d(857, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.005, inplace=False)
  (jk): JumpingKnowledge(cat)
  (ffin): Sequential(
    (0): Linear(in_features=1017, out_features=200, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): GELU(approximate='none')
    (4): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
  )
  (lin): Sequential(
    (0): DynamicTanh(normalized_shape=200, alpha_init_value=0.5, channels_last=True)
    (1): Linear(in_features=200, out_features=100, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=100, 

MultiMonoDecoder(
  (decoders): ModuleDict(
    (sequence): HeteroGAE_AA_Decoder(
      (convs): ModuleList(
        (0-4): 5 x HeteroConv(num_relations=2)
      )
      (norms): ModuleList(
        (0-4): 5 x GraphNorm(100)
      )
      (dropout): Dropout(p=0.005, inplace=False)
      (jk): JumpingKnowledge(cat)
      (lin): Sequential(
        (0): Dropout(p=0.005, inplace=False)
        (1): DynamicTanh(normalized_shape=500, alpha_init_value=0.5, channels_last=True)
        (2): Linear(in_features=500, out_features=100, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=100, out_features=50, bias=True)
        (5): GELU(approximate='none')
        (6): Linear(in_features=50, out_features=50, bias=True)
        (7): GELU(approximate='none')
        (8): DynamicTanh(normalized_shape=50, alpha_init_value=0.5, channels_last=True)
      )
      (aadecoder): Sequential(
        (0): Dropout(p=0.005, inplace=False)
        (1): DynamicTanh(normalized_shape=32

In [32]:
from torch_geometric.data import DataLoader , HeteroData
struct_dat = pdbgraph.StructureDataset('structs_traininffttest.h5')
encoder_loader = DataLoader(struct_dat, batch_size=1, shuffle=False)
#encode the structures
encode_alns = True
if encode_alns == True:
	def databatch2list(loader , limit = 10):
		for i,data in enumerate(loader):
			if i > limit:
				break
			data = data.to_data_list()
			for d in data:
				d = d.to(device)
				yield d
	encoder_loader = databatch2list(encoder_loader)
	encoder.encode_structures_fasta(encoder_loader , modeldir + modelname+'_aln_encoded_test.fasta' , debug = True, verbose = True)
struct_dat = pdbgraph.StructureDataset('structs_traininffttest.h5')
encoder_loader = DataLoader(struct_dat, batch_size=1, shuffle=False)
#grab the length of the input vectors
datalen = {}
for i, data in enumerate(encoder_loader):
	datalen[data.identifier[0]] = data['res'].x.shape[0]
	if i > 10:
		break
print('Data length:', datalen)

4it [00:00,  7.81it/s]

res shape 633
len outstring 633
res shape 630
len outstring 630
res shape 152
len outstring 152
res shape 205
len outstring 205
res shape 395


11it [00:00, 18.04it/s]

len outstring 395
res shape 144
len outstring 144
res shape 85
len outstring 85
res shape 31
len outstring 31
res shape 415
len outstring 415
res shape 119
len outstring 119
res shape 198
len outstring 198


11it [00:00, 12.56it/s]


Data length: {'A0A010QA53': 633, 'A0A015IT50': 630, 'A0A016SEG2': 152, 'A0A016V570': 205, 'A0A016VTG2': 395, 'A0A017SRI7': 144, 'A0A017SZM7': 85, 'A0A022PJY8': 31, 'A0A022QJ04': 415, 'A0A022RZG1': 119, 'A0A022Y6E1': 198, 'A0A023BAQ9': 466}


In [78]:
encoded_fasta =  modeldir + modelname+'_aln_encoded_test.fasta' 
seqstr = ''
ID = ''
seqdict = {}
with open(encoded_fasta, 'r') as f:
	#read all chars of file into a string
	for line in tqdm.tqdm(f):
		if line[0] == '>':
			seqdict[ID] = seqstr[:-1]
			ID = line[1:].strip()
			seqstr = ''
		else:
			seqstr += line
	del seqdict['']
encoded_df = pd.DataFrame( seqdict.items() , columns=['protid', 'seq'] )
encoded_df['seqlen'] = encoded_df.seq.map( lambda x: len(x) )

#change index to protid
encoded_df.index = encoded_df.protid
encoded_df = encoded_df.drop( 'protid', axis=1 )
encoded_df['ord'] = encoded_df.seq.map( lambda x: [ ord(c) for c in x] )
#hex starts at 1
encoded_df['hex2'] = encoded_df.ord.map( lambda x: [ hex(c) for c in x] )
encoded_df['ordlength'] = encoded_df.ord.map( lambda x: len(x) )
encoded_df['datalength'] = encoded_df.index.map( lambda x: datalen[x] )

106it [00:00, 649045.58it/s]


In [79]:
print(encoded_df)

                                                          seq  seqlen  \
protid                                                                  
A0A010QA53    \t\n \t$%...     633   
A0A015IT50    %\n%$ %%%  !%  ...     630   
A0A016SEG2    !&"\n#"($\n%((\n\n\n$  ...     152   
A0A016V570     $\t !\n\n!\n! \n...     205   
A0A016VTG2    #! ''...     395   
A0A017SRI7    %!!\n\n#&\n\n...     144   
A0A017SZM7    (&!!!( ...      85   
A0A022PJY8                      ! "!      31   
A0A022QJ04     $!!#'%&&...     415   
A0A022RZG1    %%%\n$%\t...     119   

                                                          ord  \
protid                                                          
A0A010QA53 

In [80]:
charset = 249
#make fasta is shifted by 1 and goes from 1-248 included
#0x01 – 0xFF excluding > (0x3E), = (0x3D), < (0x3C), - (0x2D), Space (0x20), Carriage Return (0x0d) and Line Feed (0x0a)
#replace 0x22 or " which is necesary for nexus files and 0x23 or # which is also necesary
replace_dict = {chr(0):chr(246) , '"':chr(248) , '#':chr(247), '>' : chr(249), '=' : chr(250), '<' : chr(251), '-' : chr(252), ' ' : chr(253) , '\r' : chr(254), '\n' : chr(255) }
rev_replace_dict = { v:k for k,v in replace_dict.items() }
replace_dict_ord = { ord(k):ord(v) for k,v in replace_dict.items() }
rev_replace_dict_ord = { ord(v):ord(k) for k,v in replace_dict.items() }
print(replace_dict)
print(rev_replace_dict)
print(replace_dict_ord)
print(rev_replace_dict_ord)

{'\x00': 'ö', '"': 'ø', '#': '÷', '>': 'ù', '=': 'ú', '<': 'û', '-': 'ü', ' ': 'ý', '\r': 'þ', '\n': 'ÿ'}
{'ö': '\x00', 'ø': '"', '÷': '#', 'ù': '>', 'ú': '=', 'û': '<', 'ü': '-', 'ý': ' ', 'þ': '\r', 'ÿ': '\n'}
{0: 246, 34: 248, 35: 247, 62: 249, 61: 250, 60: 251, 45: 252, 32: 253, 13: 254, 10: 255}
{246: 0, 248: 34, 247: 35, 249: 62, 250: 61, 251: 60, 252: 45, 253: 32, 254: 13, 255: 10}


In [81]:
encoded_df.seq = encoded_df.seq.map(lambda x : ''.join([ c if c not in replace_dict else replace_dict[c] for c in x]))
encoded_df['delta_datalen'] = encoded_df.datalength - encoded_df.ord.map( lambda x: len(x) )
encoded_df['delta_seqlen'] = encoded_df.seq.map( lambda x: len(x) ) - encoded_df.ord.map( lambda x:len(x) )
print(encoded_df)
assert (encoded_df.delta_datalen == 0).all(), "There are sequences with different lengths after encoding!"
assert (encoded_df.delta_seqlen == 0).all(), "There are sequences with different lengths from the ords after encoding!"
print(encoded_df)

                                                          seq  seqlen  \
protid                                                                  
A0A010QA53  ýý\tÿý\t$%...     633   
A0A015IT50  ýý%ÿ%$ý%%%ýý!%ýý...     630   
A0A016SEG2  ýý!&øÿ÷ø($ÿ%((ÿÿÿ$ýý!...     152   
A0A016V570  ýýý$\tý!ÿÿ!ÿ!ýÿ...     205   
A0A016VTG2  ýý÷!ý''...     395   
A0A017SRI7  ýý%!!ÿÿ÷&ÿÿ...     144   
A0A017SZM7  ýý(&!!!(ý...      85   
A0A022PJY8                    ýý!ýø!      31   
A0A022QJ04  ýýý$!!÷'%&&...     415   
A0A022RZG1  ýý%%%ÿ$%\t...     119   

                                                          ord  \
protid                                                          
A0A010QA53 

In [88]:
from Bio.PDB import PDBParser
from Bio import PDB
def getCAatoms(pdb_file):
	parser = PDBParser(QUIET=True)
	# Parse the structure
	structure = parser.get_structure('structure', pdb_file)
	ca_atoms = {}
	for model in structure:
		for chain in model:
			if chain.id not in ca_atoms:
				ca_atoms[chain.id] = []
			for residue in chain :
				if 'CA' in residue and PDB.is_aa(residue) :
					ca_atoms[chain.id].append(residue['CA'])
	return ca_atoms

In [89]:
from src.AFDB_tools import grab_struct
for i in range(len(encoded_df)):
	identifier = encoded_df.index[i]
	pdb_file = 'tmp/' + identifier + '.pdb'
	if not os.path.exists(pdb_file):
		grab_struct(str(identifier), structfolder='tmp/')
	ca_atoms = getCAatoms(pdb_file)
	total_residues = sum(len(atoms) for atoms in ca_atoms.values())
	encoded_df.at[identifier, 'total_residues'] = total_residues

In [92]:
#check difference between total_residues and length
encoded_df['residue_diff'] = encoded_df['total_residues'] - encoded_df['seqlen']
print(encoded_df[['total_residues', 'seqlen', 'residue_diff']])
assert (encoded_df.residue_diff == 0).all(), "There are sequences with different lengths after encoding!"
print(encoded_df[['total_residues', 'seqlen', 'residue_diff']])


            total_residues  seqlen  residue_diff
protid                                          
A0A010QA53           633.0     633           0.0
A0A015IT50           630.0     630           0.0
A0A016SEG2           152.0     152           0.0
A0A016V570           205.0     205           0.0
A0A016VTG2           395.0     395           0.0
A0A017SRI7           144.0     144           0.0
A0A017SZM7            85.0      85           0.0
A0A022PJY8            31.0      31           0.0
A0A022QJ04           415.0     415           0.0
A0A022RZG1           119.0     119           0.0
            total_residues  seqlen  residue_diff
protid                                          
A0A010QA53           633.0     633           0.0
A0A015IT50           630.0     630           0.0
A0A016SEG2           152.0     152           0.0
A0A016V570           205.0     205           0.0
A0A016VTG2           395.0     395           0.0
A0A017SRI7           144.0     144           0.0
A0A017SZM7          

In [93]:
print(decoder.decoders.keys())

dict_keys(['sequence', 'contacts'])


In [94]:
import py3Dmol

def view_custom_pdb(pdb_file, chain='A'):
	"""
	Visualize a PDB structure from a file, highlight a specific residue with thicker sticks,
	show the backbone for the selected residue, and add an arrow-like indicator.
	
	Parameters:
	- pdb_file: str, path to the local PDB file.
	- residue_num: int, residue number to highlight.
	- chain: str, the chain identifier (default 'A').
	"""
	# Read the PDB file content
	with open(pdb_file, 'r') as f:
		pdb_content = f.read()
	# Initialize py3Dmol viewer
	viewer = py3Dmol.view(width=800, height=600)
	# Add the PDB structure to the viewer
	viewer.addModel(pdb_content, 'pdb')
	# Apply cartoon style for the overall structure
	viewer.setStyle({'cartoon': {'color': 'cyan'}})
	# Add molecular surface for the rest of the structure
	viewer.addSurface(py3Dmol.VDW, {'opacity': 0.5})
	# Set zoom and background color
	viewer.zoomTo()
	viewer.setBackgroundColor('white')
	# Show the structure
	return viewer.show()



In [None]:
from Bio import PDB
from Bio.PDB import PDBParser
from sklearn.metrics import roc_curve, auc , precision_recall_curve, average_precision_score
from scipy import sparse
import torch_geometric
import losses

def get_alpha_carbon_distance_matrix(pdb_file):
	"""
	Extracts the alpha carbon (Cα) atoms from a PDB structure
	and computes the distance matrix in numpy format.

	Parameters:
	pdb_file (str): Path to the PDB file.

	Returns:
	numpy.ndarray: A matrix of distances between all Cα atoms.
	"""
	# Initialize the PDB parser
	parser = PDBParser(QUIET=True)
	
	# Parse the structure
	structure = parser.get_structure('structure', pdb_file)
	
	# Extract alpha carbon (Cα) atoms
	ca_atoms = []
	for model in structure:
		for chain in model:
			for residue in chain :
				if 'CA' in residue and PDB.is_aa(residue) :
					ca_atoms.append(residue['CA'])
	
	# Get the number of Cα atoms
	num_atoms = len(ca_atoms)

	# Initialize the distance matrix
	dist_matrix = np.zeros((num_atoms, num_atoms))
	
	# Compute the distances between all pairs of Cα atoms
	for i, atom_i in enumerate(ca_atoms):
		for j, atom_j in enumerate(ca_atoms):
			if i < j:
				dist_matrix[i, j] = atom_i - atom_j
	dist_matrix += dist_matrix.T
	return dist_matrix

#get aa and contacts
def get_backbone(naa):
	backbone_mat = np.zeros((naa, naa))
	backbone_rev_mat = np.zeros((naa, naa))
	np.fill_diagonal(backbone_mat[1:], 1)
	np.fill_diagonal(backbone_rev_mat[:, 1:], 1)
	return backbone_mat, backbone_rev_mat

def sparse2pairs(sparsemat):
	sparsemat = sparse.find(sparsemat)
	return np.vstack([sparsemat[0],sparsemat[1]])

def decoder_reconstruction2aa( ords , device, verbose = False):
	print(len(ords))
	decoder.eval()
	z = encoder.vector_quantizer.embeddings( ords  ).to('cpu')
	edge_index = torch.tensor( [ [i,j] for i in range(z.shape[0]) for j in range(z.shape[0]) ]  , dtype = torch.long).T
	godnode_index = np.vstack([np.zeros(z.shape[0]), [ i for i in range(z.shape[0]) ] ])
	godnode_rev = np.vstack([ [ i for i in range(z.shape[0]) ] , np.zeros(z.shape[0]) ])
	#generate a backbone for the decoder
	data = HeteroData()
	data['res'].x = z
	backbone, backbone_rev = get_backbone( z.shape[0] )
	backbone = sparse.csr_matrix(backbone)
	backbone_rev = sparse.csr_matrix(backbone_rev)
	backbone = sparse2pairs(backbone)
	backbone_rev = sparse2pairs(backbone_rev)
	positional_encoding = converter.get_positional_encoding( z.shape[0] , 256 )
	print( 'positional encoding shape:', positional_encoding.shape )
	data['res'].batch = torch.tensor([0 for i in range(z.shape[0])], dtype=torch.long)
	data['positions'].x = torch.tensor( positional_encoding, dtype=torch.float32)
	data['res','backbone','res'].edge_index = torch.tensor(backbone,  dtype=torch.long )
	data[ 'res' , 'backbone_rev' , 'res'].edge_index = torch.tensor(backbone_rev, dtype=torch.long)
	print( data['res'].x.shape )
	#add the godnode
	data['godnode'].x = torch.tensor(np.ones((1,5)), dtype=torch.float32)
	data['godnode4decoder'].x = torch.tensor(np.ones((1,5)), dtype=torch.float32)
	data['godnode4decoder', 'informs', 'res'].edge_index = torch.tensor(godnode_index, dtype=torch.long)
	data['res', 'informs', 'godnode4decoder'].edge_index = torch.tensor(godnode_rev, dtype=torch.long)
	data['res', 'informs', 'godnode'].edge_index = torch.tensor(godnode_rev, dtype=torch.long)
	edge_index = edge_index.to( device )
	print( data )
	data = data.to( device )
	allpairs = torch.tensor( [ [i,j] for i in range(z.shape[0]) for j in range(z.shape[0]) ]  , dtype = torch.long).T
	out = decoder( data , allpairs ) 

	recon_x = out['aa'] if 'aa' in out else None
	edge_probs = out['edge_probs'] if 'edge_probs' in out else None

	print( edge_probs.shape)
	amino_map = decoder.decoders['sequence'].amino_acid_indices
	revmap_aa = { v:k for k,v in amino_map.items() }
	edge_probs = edge_probs.reshape((z.shape[0], z.shape[0]))
	if verbose == True:
		print( recon_x )
		print( edge_probs )
	aastr = ''.join(revmap_aa[int(idx.item())] for idx in recon_x.argmax(dim=1) )
	return aastr ,edge_probs

postives = []
predictions = []
for ex in range( encoded_df.shape[0] ):
	os.makedirs('tmp', exist_ok=True)
	example = encoded_df.iloc[ex]
	print( example )
	protid = encoded_df.index[ex]
	ords = example['ord']
	print( str(protid))
	AFDB_tools.grab_struct(str(protid) , structfolder='tmp/')
	
	#show struct
	view_custom_pdb('tmp/' + protid + '.pdb', chain='A')

	#get alpha carbon distmat with biopython
	s = get_alpha_carbon_distance_matrix( 'tmp/' + protid + '.pdb')
	
	ords = torch.tensor([ c-1 if c not in rev_replace_dict_ord else rev_replace_dict_ord[c]-1 for c in ords] , dtype=torch.long)
	print( 'ord length', len(ords))
	ords = ords.to(device)	
	aa, edgeprobs = decoder_reconstruction2aa( ords , device = device, verbose = True)
	fig, axs = plt.subplots(1, 2, figsize=(12, 5))
	im0 = axs[0].imshow(s)
	axs[0].set_title('Alpha Carbon Distance Matrix')
	plt.colorbar(im0, ax=axs[0])
	probs = 1 - edgeprobs.detach().cpu()
	probs = (probs + probs.T) / 2
	im1 = axs[1].imshow(probs)
	axs[1].set_title('Predicted Contact Probabilities')
	plt.colorbar(im1, ax=axs[1])
	plt.tight_layout()
	plt.show()

	#normalize the distance matrix
	s = s / np.max(s)
	plt.imshow( np.abs(s - np.array(probs)))
	plt.colorbar()
	plt.title('Normalized Distance Matrix - Predicted Contact Probabilities')
	plt.show()

	#these are the embedding indices. shift by 1 to get the hex indices
	
	print(aa)
	#change range of image to 0-1
	
	print( probs.shape)
	print( s.shape)
	#output ROC curve for contact prediction

	pos = np.zeros( s.shape )
	pos[ s < 10] = 1
	pos = pos[ 0:probs.shape[0], 0:probs.shape[1]]
	plt.imshow(pos)
	plt.show()

	postives.append(pos.flatten())
	predictions.append(probs.flatten())

	#flatten the matrices
	pos = pos.flatten()
	probs = probs.flatten()
	fpr, tpr, thresholds = roc_curve(pos, 1-probs)
	roc_auc = auc(fpr, tpr)
	plt.plot(fpr, tpr, lw=1, alpha=0.9, label='ROC foldtree2 (AUC = %0.2f)' % (roc_auc))

	#output precision recall curve
	precision, recall, thresholds = precision_recall_curve(pos, 1-probs)
	average_precision = average_precision_score(pos, 1-probs)
	plt.plot(recall, precision, lw=1, alpha=0.9, label='PR foldtree2 (AP = %0.2f)' % (average_precision))
	plt.legend()
	plt.show()

#output ROC curve for contact prediction
postives = np.concatenate(postives)
predictions = np.concatenate(predictions)
fpr, tpr, thresholds = roc_curve(postives, 1-predictions)
roc_auc = auc(fpr, tpr)
plt.plot(fpr, tpr, lw=1, alpha=0.9, label='ROC foldtree2 (AUC = %0.2f)' % (roc_auc))
plt.legend()
plt.show()

#output precision recall curve
precision, recall, thresholds = precision_recall_curve(postives, 1-predictions)
average_precision = average_precision_score(postives, 1-predictions)
plt.plot(recall, precision, lw=1, alpha=0.9, label='PR foldtree2 (AP = %0.2f)' % (average_precision))
plt.legend()
plt.show()


seq               ýý\tÿý\t$%...
seqlen                                                          633
ord               [32, 32, 9, 5, 6, 5, 5, 20, 5, 27, 10, 5, 16, ...
hex2              [0x20, 0x20, 0x9, 0x5, 0x6, 0x5, 0x5, 0x14, 0x...
ordlength                                                       633
datalength                                                      633
delta_datalen                                                     0
delta_seqlen                                                      0
total_residues                                                633.0
residue_diff                                                    0.0
Name: A0A010QA53, dtype: object
A0A010QA53


ord length 633
633


In [None]:
# Function to compute Root Mean Square Deviation (RMSD)
def compute_rmsd(coords1, coords2):
    """
    Compute the Root Mean Square Deviation (RMSD) between two sets of coordinates.
    
    Parameters:
    - coords1: Nx3 NumPy array of true coordinates
    - coords2: Nx3 NumPy array of predicted coordinates
    
    Returns:
    - RMSD value
    """
    assert coords1.shape == coords2.shape, "Coordinate arrays must have the same shape"
    
    diff = coords1 - coords2
    rmsd = np.sqrt(np.mean(np.sum(diff**2, axis=1)))
    return rmsd

# Function to compute Local Distance Difference Test (lDDT)
def compute_lddt(true_coords, pred_coords, cutoff=15.0):
    """
    Compute the Local Distance Difference Test (lDDT) score.
    
    Parameters:
    - true_coords: Nx3 NumPy array of true coordinates
    - pred_coords: Nx3 NumPy array of predicted coordinates
    - cutoff: Distance threshold for considering a pair of residues
    
    Returns:
    - lDDT score (0 to 1)
    """
    assert true_coords.shape == pred_coords.shape, "Coordinate arrays must have the same shape"

    num_residues = true_coords.shape[0]
    num_pairs = 0
    num_matching_pairs = 0

    for i in range(num_residues):
        for j in range(i + 1, num_residues):
            true_dist = np.linalg.norm(true_coords[i] - true_coords[j])
            pred_dist = np.linalg.norm(pred_coords[i] - pred_coords[j])

            if true_dist < cutoff:
                num_pairs += 1
                if abs(true_dist - pred_dist) < 0.5 * true_dist:
                    num_matching_pairs += 1

    lddt_score = num_matching_pairs / num_pairs if num_pairs > 0 else 0
    return lddt_score

# Compute RMSD and lDDT
rmsd_value = compute_rmsd(true_coords_from_rt, pred_coords_from_rt)
lddt_value = compute_lddt(true_coords_from_rt, pred_coords_from_rt)

rmsd_value, lddt_value


NameError: name 'true_coords_from_rt' is not defined

In [None]:

#plot the true and predicted structures
def transform_rt_to_coordinates(rotations, translations):
    """
    Given a list of rotation matrices (R) and translation vectors (t),
    generate the global 3D coordinates of the protein backbone.
    
    Parameters:
    - rotations: List of 3x3 rotation matrices
    - translations: List of 3x1 translation vectors
    
    Returns:
    - coords: Nx3 NumPy array representing the backbone in 3D space
    """
    num_residues = len(rotations)
    assert num_residues == len(translations), "Rotation and translation lists must be the same length"
    
    # Initialize the first coordinate at the origin
    coords = [np.array([0, 0, 0, 1])]  # Homogeneous coordinates

    # Apply transformations iteratively
    current_transform = np.eye(4)  # Identity matrix as starting point

    for R, t in zip(rotations, translations):
        # Construct the transformation matrix (4x4)
        T = np.eye(4)
        T[:3, :3] = R
        T[:3, 3] = t
        
        # Update the cumulative transformation
        current_transform = current_transform @ T
        
        # Transform the point and store the new coordinates
        new_point = current_transform @ np.array([0, 0, 0, 1])  # Homogeneous coordinates
        coords.append(new_point)

    return np.array(coords)[:, :3]  # Convert back to 3D coordinates (drop the homogeneous coordinate)


def plot_protein_structures_with_thicker_lines(true_coords, pred_coords):
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')

    # Plot true structure with thicker lines and transparency
    ax.plot(true_coords[:, 0], true_coords[:, 1], true_coords[:, 2], 
            'bo-', alpha=0.5, label='True Structure', linewidth=3, markersize=6)

    # Plot predicted structure with thicker lines and transparency
    ax.plot(pred_coords[:, 0], pred_coords[:, 1], pred_coords[:, 2], 
            'ro-', alpha=0.5, label='Predicted Structure', linewidth=3, markersize=6)

    # Adjusted arrow size
    arrow_size = 0.15  # Keep small arrows

    # Add arrowheads to indicate backbone direction
    for i in range(len(true_coords) - 1):
        ax.quiver(true_coords[i, 0], true_coords[i, 1], true_coords[i, 2],
                  true_coords[i+1, 0] - true_coords[i, 0], 
                  true_coords[i+1, 1] - true_coords[i, 1], 
                  true_coords[i+1, 2] - true_coords[i, 2],
                  color='blue', alpha=0.5, arrow_length_ratio=arrow_size)

    for i in range(len(pred_coords) - 1):
        ax.quiver(pred_coords[i, 0], pred_coords[i, 1], pred_coords[i, 2],
                  pred_coords[i+1, 0] - pred_coords[i, 0], 
                  pred_coords[i+1, 1] - pred_coords[i, 1], 
                  pred_coords[i+1, 2] - pred_coords[i, 2],
                  color='red', alpha=0.5, arrow_length_ratio=arrow_size)

    # Labels and legend
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    ax.set_zlabel('Z-axis')
    ax.set_title('True vs Predicted Protein Backbone Structure with Thicker Lines')
    ax.legend()

    plt.show()

# Plot with thicker lines
plot_protein_structures_with_thicker_lines(true_coords, noisy_pred_coords)


NameError: name 'true_coords' is not defined