# FoldTree2 Model Training and Analysis

This notebook trains a protein structure prediction model using FoldTree2's encoder-decoder architecture. The model learns to encode protein structures into discrete embeddings and decode them back to predict amino acid sequences and structural contacts.

## Training Process
The notebook demonstrates:
- **Vector Quantized Encoding**: Proteins are encoded into discrete embedding sequences using a transformer-based encoder
- **Multi-task Decoding**: The decoder predicts amino acid sequences, contact maps, and geometric properties
- **Progressive Learning**: Training occurs over multiple epochs with various loss components (reconstruction, contact prediction, VQ regularization)

## Training Visualizations
During training, the notebook generates comprehensive analysis plots showing:
- **Contact Prediction**: Predicted vs. true contact maps for protein residue interactions
- **Distance Analysis**: True distance matrices and binary contact classifications
- **Performance Metrics**: ROC curves and precision-recall analysis for contact prediction accuracy
- **Sequence Embedding**: Color-coded visualization of the discrete embedding alphabet learned by the model
- **3D Structure**: Interactive molecular visualization colored by embedding states
- **Bond Angles**: Comparison of predicted vs. true backbone bond angles

This provides real-time feedback on model performance across sequence, contact, and geometric prediction tasks.

In [1]:
#use autoreload
%load_ext autoreload
%autoreload 2

In [2]:
cd /home/dmoi/projects/foldtree2/

/home/dmoi/projects/foldtree2


In [3]:
# Imports
import torch
from torch_geometric.data import DataLoader
import numpy as np
from foldtree2.src import pdbgraph
from foldtree2.src import foldtree2_ecddcd as ft2
from foldtree2.src.mono_decoders import MultiMonoDecoder
from foldtree2.src.losses.losses import recon_loss_diag , aa_reconstruction_loss

import os
import tqdm
import random
import torch.nn.functional as F

  Jd = torch.load(str(path))


In [4]:
# Set seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data setup
datadir = '../../datasets/foldtree2/'
dataset_path = 'structs_traininffttest.h5'
converter = pdbgraph.PDB2PyG(aapropcsv='./foldtree2/config/aaindex1.csv')
struct_dat = pdbgraph.StructureDataset(dataset_path)
train_loader = DataLoader(struct_dat, batch_size=5, shuffle=True, num_workers=4)

data_sample = next(iter(train_loader))



In [5]:
# Model setup

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
ndim = data_sample['res'].x.shape[1]
ndim_godnode = data_sample['godnode'].x.shape[1]
ndim_fft2i = data_sample['fourier2di'].x.shape[1]
ndim_fft2r = data_sample['fourier2dr'].x.shape[1]
print(data_sample)

HeteroDataBatch(
  identifier=[5],
  AA={
    x=[1323, 20],
    batch=[1323],
    ptr=[6],
  },
  R_true={
    x=[1323, 3, 3],
    batch=[1323],
    ptr=[6],
  },
  bondangles={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  coords={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  fourier1di={
    x=[1323, 80],
    batch=[1323],
    ptr=[6],
  },
  fourier1dr={
    x=[1323, 80],
    batch=[1323],
    ptr=[6],
  },
  fourier2di={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  fourier2dr={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  godnode={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  godnode4decoder={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  plddt={
    x=[1323, 1],
    batch=[1323],
    ptr=[6],
  },
  positions={
    x=[1323, 256],
    batch=[1323],
    ptr=[6],
  },
  res={
    x=[1323, 857],
    batch=[1323],
    ptr=[6],
  },
  t_true={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  (godnode4decoder, informs, res)={ edge_index=[2, 1

In [6]:
#alphabet params
num_embeddings = 40
embedding_dim = 128

#net params
hidden_size = 256
batch_size = 20

#loss weights
edgeweight = 0.05
logitweight = 0.08
xweight = 0.1
fft2weight = 0.01
vqweight = 0.001
angles_weight = 0.001
clip_grad = True

num_epochs = 100  # For demonstration, keep small



In [7]:

encoder = ft2.mk1_Encoder(
	in_channels=ndim,
	hidden_channels=[hidden_size, hidden_size],
	out_channels=embedding_dim,
	metadata={'edge_types': [('res','contactPoints','res') ]},
	num_embeddings=num_embeddings,
	commitment_cost=0.9,
	edge_dim=1,
	encoder_hidden=hidden_size,
	EMA=True,
	nheads=8,
	dropout_p=0.01,
	reset_codes=False,
	flavor='transformer',
	fftin=True
)


print(encoder)
encoder = encoder.to(device)




Seed set to 42


mk1_Encoder(
  (convs): ModuleList(
    (0): ModuleDict(
      (res_contactPoints_res): TransformerConv(256, 256, heads=8)
    )
  )
  (norms): ModuleList(
    (0): GraphNorm(256)
  )
  (bn): BatchNorm1d(857, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.01, inplace=False)
  (jk): JumpingKnowledge(cat)
  (ffin): Sequential(
    (0): Linear(in_features=1017, out_features=512, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): GELU(approximate='none')
  )
  (lin): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): GELU(approximate='none')
  )
  (out_dense): Sequential(
    (0): Linear(in_features=276, out_features=256, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): GELU(approximate='none'

In [17]:
use_monodecoder = True  # Set to True to use MultiMonoDecoder, False for Single Decoder
if use_monodecoder == True:
	# MultiMonoDecoder for sequence and geometry
	mono_configs = {
		'sequence_transformer': {
			'in_channels': {'res': embedding_dim},
			'xdim': 20,
			'concat_positions': True,
			'hidden_channels': {('res','backbone','res'): [hidden_size*2] , ('res','backbonerev','res'): [hidden_size*2]},
			'layers': 2,
			'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
			'amino_mapper': converter.aaindex,
			'flavor': 'sage',
			'nheads': 4,
			'dropout': 0.005,
			'normalize': False,
			'residual': False
		},
		
		'contacts': {
			'in_channels': {'res': embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23 ,  'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i},
			'concat_positions': True,
			'hidden_channels': {('res','backbone','res'): [hidden_size]*8, ('res','backbonerev','res'): [hidden_size]*8, ('res','informs','godnode4decoder'): [hidden_size]*8 , ('godnode4decoder','informs','res'): [hidden_size]*8 },
			'layers': 4,
			'FFT2decoder_hidden': [hidden_size, hidden_size, hidden_size],
			'contactdecoder_hidden': [hidden_size//4, hidden_size//8],
			'anglesdecoder_hidden': [hidden_size//2, hidden_size//2],
			'nheads': 1,
			'Xdecoder_hidden': [hidden_size, hidden_size,  hidden_size ],
			'metadata': converter.metadata,
			#'flavor': 'sage',
			'flavor': 'cheb',
			'dropout': 0.005,
			'output_fft': False,
			'output_rt':False,
			'output_angles': False,
			'normalize': True,
			'residual': False,
			'contact_mlp': False,
			'ncat': 16,
			'output_edge_logits': True
		},
	}
	decoder = MultiMonoDecoder( configs=mono_configs)
else:
	# Single decoder 
	decoder = ft2.HeteroGAE_Decoder(
			in_channels={'res': embedding_dim , 'godnode4decoder': ndim_godnode, 'foldx': 23},
			concat_positions=True,
			hidden_channels={('res','backbone','res'): [hidden_size]*5, ('res','backbonerev','res'): [hidden_size]*5, ('res','informs','godnode4decoder'): [hidden_size]*5 , ('godnode4decoder','informs','res'): [hidden_size]*5},
			layers=3,
			AAdecoder_hidden=[hidden_size, hidden_size, hidden_size//2],
			Xdecoder_hidden=[hidden_size, hidden_size, hidden_size],
			contactdecoder_hidden=[hidden_size//2, hidden_size//2],
			anglesdecoder_hidden=[hidden_size//2, hidden_size//4],
			nheads=5,
			amino_mapper=converter.aaindex,
			flavor='sage',
			dropout=0.005,
			normalize=True,
			residual=False,
			contact_mlp=False
		)
decoder = decoder.to(device)
print(decoder)

Seed set to 42
Seed set to 42


Initializing decoder for task: sequence_transformer
False True False False False
512 4 2 0.005
Initializing decoder for task: contacts
False False True False False
MultiMonoDecoder(
  (decoders): ModuleDict(
    (sequence_transformer): Transformer_AA_Decoder(
      (input_proj): Sequential(
        (0): Dropout(p=0.005, inplace=False)
        (1): Linear(in_features=384, out_features=512, bias=True)
        (2): GELU(approximate='none')
        (3): Linear(in_features=512, out_features=512, bias=True)
        (4): Tanh()
      )
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-1): 2 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
            )
            (linear1): Linear(in_features=512, out_features=2048, bias=True)
            (dropout): Dropout(p=0.005, inplace=False)
            (linear2): Linear(in_features=204

In [18]:
# Training loop (demo, similar to learn.py)
import time
from collections import defaultdict

encoder.device = device
encoder.train()
decoder.train()
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)


In [19]:
#get one sample from the dataloader
train_loader = DataLoader(struct_dat, batch_size=1, shuffle=True, num_workers=4)
import random

randint = random.randint(0, len(train_loader) - 1)
print(f"Randomly selected batch index: {randint}")
data_sample = struct_dat[randint]
print(data_sample)
data = data_sample.to(device)
optimizer.zero_grad()
z, vqloss = encoder(data , debug = True)
print('Encoded z shape:', z.shape)

Randomly selected batch index: 912
HeteroData(
  identifier='A0A1I8D3W7',
  AA={ x=[624, 20] },
  R_true={ x=[624, 3, 3] },
  bondangles={ x=[624, 3] },
  coords={ x=[624, 3] },
  fourier1di={ x=[624, 80] },
  fourier1dr={ x=[624, 80] },
  fourier2di={ x=[1, 1300] },
  fourier2dr={ x=[1, 1300] },
  godnode={ x=[1, 5] },
  godnode4decoder={ x=[1, 5] },
  plddt={ x=[624, 1] },
  positions={ x=[624, 256] },
  res={ x=[624, 857] },
  t_true={ x=[624, 3] },
  (godnode4decoder, informs, res)={ edge_index=[2, 624] },
  (godnode, informs, res)={ edge_index=[2, 624] },
  (res, backbone, res)={
    edge_index=[2, 1247],
    edge_attr=[623],
  },
  (res, backbonerev, res)={
    edge_index=[2, 1247],
    edge_attr=[623],
  },
  (res, contactPoints, res)={
    edge_index=[2, 4552],
    edge_attr=[4552],
  },
  (res, hbond, res)={
    edge_index=[2, 632],
    edge_attr=[632],
  },
  (res, informs, godnode)={ edge_index=[2, 624] },
  (res, informs, godnode4decoder)={ edge_index=[2, 624] },
  (res, wi

In [20]:
from Bio import PDB
from Bio.PDB import PDBParser
from foldtree2.src.AFDB_tools import grab_struct

def getCAatoms(pdb_file):
	parser = PDBParser(QUIET=True)
	# Parse the structure
	structure = parser.get_structure('structure', pdb_file)
	ca_atoms = []
	for model in structure:
		for chain in model:
			for residue in chain :
				if 'CA' in residue and PDB.is_aa(residue) :
					ca_atoms.append(residue['CA'])
	return ca_atoms

In [21]:
#get aa and contacts

from torch_geometric.data import DataLoader , HeteroData
from scipy import sparse
from matplotlib import pyplot as plt
from sklearn.metrics import roc_curve, auc
#add precision and recall metrics
from sklearn.metrics import precision_recall_curve, average_precision_score

def get_backbone(naa):
	backbone_mat = np.zeros((naa, naa))
	backbone_rev_mat = np.zeros((naa, naa))
	np.fill_diagonal(backbone_mat[1:], 1)
	np.fill_diagonal(backbone_rev_mat[:, 1:], 1)
	return backbone_mat, backbone_rev_mat

def sparse2pairs(sparsemat):
	sparsemat = sparse.find(sparsemat)
	return np.vstack([sparsemat[0],sparsemat[1]])

def decoder_reconstruction2aa( ords , device, verbose = False):
	decoder.eval()
	print(ords)
	z = encoder.vector_quantizer.embeddings( ords  ).to('cpu')
	
	edge_index = torch.tensor( [ [i,j] for i in range(z.shape[0]) for j in range(z.shape[0]) ]  , dtype = torch.long).T
	godnode_index = np.vstack([np.zeros(z.shape[0]), [ i for i in range(z.shape[0]) ] ])
	godnode_rev = np.vstack([ [ i for i in range(z.shape[0]) ] , np.zeros(z.shape[0]) ])
	#generate a backbone for the decoder
	data = HeteroData()
	
	data['res'].x = z
	backbone, backbone_rev = get_backbone( z.shape[0] )
	backbone = sparse.csr_matrix(backbone)
	backbone_rev = sparse.csr_matrix(backbone_rev)
	backbone = sparse2pairs(backbone)
	backbone_rev = sparse2pairs(backbone_rev)
	positional_encoding = converter.get_positional_encoding( z.shape[0] , 256 )
	print( 'positional encoding shape:', positional_encoding.shape )
	data['res'].batch = torch.tensor([0 for i in range(z.shape[0])], dtype=torch.long)
	data['positions'].x = torch.tensor( positional_encoding, dtype=torch.float32)
	data['res','backbone','res'].edge_index = torch.tensor(backbone,  dtype=torch.long )
	data[ 'res' , 'backbone_rev' , 'res'].edge_index = torch.tensor(backbone_rev, dtype=torch.long)
	print( data['res'].x.shape )
	#add the godnode
	data['godnode'].x = torch.tensor(np.ones((1,5)), dtype=torch.float32)
	data['godnode4decoder'].x = torch.tensor(np.ones((1,5)), dtype=torch.float32)
	data['godnode4decoder', 'informs', 'res'].edge_index = torch.tensor(godnode_index, dtype=torch.long)
	data['res', 'informs', 'godnode4decoder'].edge_index = torch.tensor(godnode_rev, dtype=torch.long)
	data['res', 'informs', 'godnode'].edge_index = torch.tensor(godnode_rev, dtype=torch.long)
	edge_index = edge_index.to( device )
	print( data )
	data = data.to( device )
	allpairs = torch.tensor( [ [i,j] for i in range(z.shape[0]) for j in range(z.shape[0]) ]  , dtype = torch.long).T
	out = decoder( data , allpairs )
	recon_x = out['aa'] if 'aa' in out else None
	edge_probs = out['edge_probs'] if 'edge_probs' in out else None
	logits = out['edge_logits'] if 'edge_logits' in out else None

	print( edge_probs.shape)
	"""
	try:
		amino_map = decoder.decoders['sequence'].amino_acid_indices
	except:
		amino_map = decoder.decoders['sequence_transformer'].amino_acid_indices
		print('Using amino_acid_indices_dict instead of amino_acid_indices')
	revmap_aa = { v:k for k,v in amino_map.items() }
	aastr = ''.join(revmap_aa[int(idx.item())] for idx in recon_x.argmax(dim=1) )
	
	"""
	aastr = None

	edge_probs = edge_probs.reshape((z.shape[0], z.shape[0]))
	logits = torch.sum( logits , dim=1).squeeze()
	logits = logits.reshape((z.shape[0], z.shape[0]))
	if verbose == True:
		print( recon_x )
		print( edge_probs )
	return aastr ,edge_probs, logits , out

In [22]:
import numpy as np
import matplotlib.pyplot as plt
import colour

def plot_logits_sequence_on_ax(selected_indices, num_embeddings, ax, max_width=64, line_spacing=1, show_title=True, show_colorbar=False):
	"""
	Plots a wrapped embedding sequence using argmax from logits on a given matplotlib Axes.
	
	Parameters:
		logits (np.ndarray): Array of shape (sequence_length, num_embeddings).
		num_embeddings (int): Number of possible embeddings.
		ax (matplotlib.axes.Axes): Target axis to draw the plot.
		max_width (int): Max tokens per line.
		line_spacing (int): White lines between sequences.
		show_title (bool): If True, show a title on the subplot.
		show_colorbar (bool): If True, attach a colorbar to the plot.
	"""
	################ Process inputs
	selected_indices = np.array(selected_indices)
	# Create color mapping
	ord_colors = colour.Color("red").range_to(colour.Color("blue"), num_embeddings)
	ord_colors = np.array([c.get_rgb() for c in ord_colors])
	sequence_colors = ord_colors[selected_indices]

	# Compute rows and canvas
	total_len = len(selected_indices)
	rows = int(np.ceil(total_len / max_width))
	height = rows * (1 + line_spacing) - line_spacing
	canvas = np.ones((height, max_width, 3))  # White background

	for i in range(rows):
		start = i * max_width
		end = min((i + 1) * max_width, total_len)
		row_colors = sequence_colors[start:end]
		row_y = i * (1 + line_spacing)
		canvas[row_y, :len(row_colors), :] = row_colors

	# Plot on given axis
	ax.imshow(canvas, aspect='auto')
	ax.axis('off')
	if show_title:
		ax.set_title("Embedding Selection (argmax)")

	# Optional colorbar
	if show_colorbar:
		from matplotlib.colors import ListedColormap
		cmap = ListedColormap(ord_colors)
		norm = plt.Normalize(vmin=0, vmax=num_embeddings - 1)
		sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
		sm.set_array([])
		plt.colorbar(sm, ax=ax, orientation='vertical', fraction=0.02, pad=0.01)

In [23]:
def adaptive_clip_grad(model, clip_factor=0.01, eps=1e-3, exclude_bias_norm=True):
    """
    AGC (Adaptive Gradient Clipping).
    Scales gradients if grad_norm > clip_factor * param_norm.
    """
    for p in model.parameters():
        if p.grad is None:
            continue
        if exclude_bias_norm and p.ndim == 1:
            # skip bias/LayerNorm weights
            continue
        param_norm = torch.norm(p.detach())
        grad_norm = torch.norm(p.grad.detach())
        max_norm = clip_factor * (param_norm + eps)
        if grad_norm > max_norm:
            scale = max_norm / (grad_norm + eps)
            p.grad.detach().mul_(scale)


In [26]:
train_loader = DataLoader(struct_dat, batch_size=batch_size, shuffle=True, num_workers=4)
encoder.train()
decoder.train()
hammingdistances = []
figurestack = []

reload = False
import glob
if reload == True:
	models = glob.glob('models/test_encoder_epoch_*.pkl')
	if len(models) > 0:
		latest_model = max(models, key=os.path.getctime)
		print(f"Loading model from {latest_model}")
		import pickle
		with open(latest_model, 'rb') as f:
			encoder = pickle.load(f)
		#load the corresponding decoder
		decoder_model = latest_model.replace('encoder', 'decoder')
		with open(decoder_model, 'rb') as f:
			decoder = pickle.load(f)
		
		encoder = encoder.to(device)
		decoder = decoder.to(device)
		encoder.train()
		decoder.train()


In [28]:
encoder = encoder.to(device)
decoder = decoder.to(device)
for epoch in range(num_epochs):
	total_loss_x = 0
	total_loss_edge = 0
	total_vq = 0
	total_angles_loss = 0
	total_loss_fft2 = 0
	total_logit_loss = 0

	
	for data in tqdm.tqdm(train_loader, desc=f"Epoch {epoch}"):
		data = data.to(device)
		optimizer.zero_grad()
		z, vqloss = encoder(data)
		data['res'].x = z

		

		# For demonstration, only sequence and contacts tasks
		out = decoder(data, None)
		recon_x = out['aa'] if isinstance(out, dict) and 'aa' in out else out[0] if isinstance(out, (list, tuple)) else None
		fft2_x = out['fft2pred'] if isinstance(out, dict) and 'fft2pred' in out else out[1] if isinstance(out, (list, tuple)) else None
		# Edge loss: use contactPoints if available
		edge_index = data.edge_index_dict['res', 'contactPoints', 'res'] if hasattr(data, 'edge_index_dict') and ('res', 'contactPoints', 'res') in data.edge_index_dict else None
		logitloss = torch.tensor(0.0, device=device)
		if edge_index is not None:
			edgeloss , logitloss = recon_loss_diag(data, edge_index, decoder, plddt=False, offdiag=False , key = 'edge_probs')
		else:
			edgeloss = torch.tensor(0.0, device=device)
		
		xloss = aa_reconstruction_loss(data['AA'].x, recon_x)
		if fft2_x is not None:
			fft2loss = F.smooth_l1_loss(torch.cat( [ data['fourier2dr'].x ,data['fourier2di'].x ] ,axis = 1 ) , fft2_x )
		else:
			fft2loss = torch.tensor(0.0, device=device)
		
		angles_loss = torch.tensor(0.0, device=device)
		if out['angles'] is not None:
			angles = out['angles']
			angles_loss = F.smooth_l1_loss(angles, data['bondangles'].x)

		loss = xweight * xloss + edgeweight * edgeloss + vqweight * vqloss + fft2loss* fft2weight + angles_loss * angles_weight + logitloss * logitweight
		
		loss.backward()
		
		if clip_grad:
			torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
			torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
			#adaptive_clip_grad(encoder, clip_factor=0.025)
			#adaptive_clip_grad(decoder, clip_factor=0.025)
		optimizer.step()
		total_loss_x += xloss.item()
		total_logit_loss += logitloss.item()
		total_loss_edge += edgeloss.item()
		total_loss_fft2 += fft2loss.item()
		total_angles_loss += angles_loss.item()
		total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
	scheduler.step(total_loss_x)
	# save the model every 10 epochs
	if epoch % 10 == 0:
		torch.save(encoder.state_dict(), f'models/big_encoder_epoch_{epoch}.pt')
		torch.save(decoder.state_dict(), f'models/big_decoder_epoch_{epoch}.pt')
		#pickle the encoder and decoder
		#import pickle
		#with open(f'models/test_encoder_epoch_{epoch}.pkl', 'wb') as f:
		#	pickle.dump(encoder, f)
		#with open(f'models/test_decoder_epoch_{epoch}.pkl', 'wb') as f:
		#	pickle.dump(decoder, f)
		#save model checkpoint
		torch.save(encoder, f'models/big_encoder_epoch_{epoch}.pt' )
		torch.save(decoder, f'models/big_decoder_epoch_{epoch}.pt' )

	print(f"Epoch {epoch}: AA Loss: {total_loss_x/len(train_loader):.4f}, Edge Loss: {total_loss_edge/len(train_loader):.4f}, VQ Loss: {total_vq/len(train_loader):.4f} , FFT2 Loss: {total_loss_fft2/len(train_loader):.4f} , Angles Loss: {total_angles_loss/len(train_loader):.4f} , Logit Loss: {total_logit_loss/len(train_loader):.4f}")

	encoder.eval()
	decoder.eval()
	
	# predict all vs all contacts for the last sample
	data_sample = data_sample.to(device)
	z, vqloss = encoder(data_sample)
	print('Encoded z shape:', z.shape)
	ords = encoder.vector_quantizer.discretize_z(z.detach())
	zdiscrete = ords[0].detach()
	print('Encoded zdiscrete shape:', zdiscrete.shape)
	aastr, edge_probs , logits ,sample_out = decoder_reconstruction2aa( zdiscrete , device, verbose=True)
	#show the distance matrix
	grab_struct(str(data_sample.identifier) , structfolder='tmp/')
	#find the total number of residues
	ca_atoms = getCAatoms( 'tmp/' + str(data_sample.identifier) + '.pdb')
	dist_mat = np.zeros((len(ca_atoms), len(ca_atoms)))
	
	for i, res1 in enumerate(ca_atoms):
		for j, res2 in enumerate(ca_atoms):
			if i < j:
				dist_mat[i, j] = np.linalg.norm(res1.coord - res2.coord)
	dist_mat += dist_mat.T  # Make it symmetric
	np.fill_diagonal(dist_mat, 0)
	
	ndistmat = dist_mat.copy()
	ndistmat[dist_mat>10 ] = 0
	ndistmat[dist_mat<=10 ] = 1

	fig, axs = plt.subplots(3, 3, figsize=(20, 10))

	# Predicted Contacts
	im0 = axs[0, 0].imshow( 1- edge_probs.detach().cpu().numpy(), cmap='hot', interpolation='nearest')
	axs[0, 0].set_title(f"Epoch {epoch} - Predicted Contacts for {data.identifier[0]}")
	fig.colorbar(im0, ax=axs[0, 0], fraction=0.046, pad=0.04)

	# Distance Matrix
	im1 = axs[0, 1].imshow(dist_mat, cmap='hot', interpolation='nearest')
	axs[0, 1].set_title(f"Epoch {epoch} - Distance Matrix for {data.identifier[0]}")
	fig.colorbar(im1, ax=axs[0, 1], fraction=0.046, pad=0.04)

	# Distance Diff
	im2 = axs[0, 2].imshow( ndistmat , cmap='hot', interpolation='nearest')
	axs[0, 2].set_title(f"Epoch {epoch} - Distance Diff (Predicted vs True) for {data.identifier[0]}")
	fig.colorbar(im2, ax=axs[0, 2], fraction=0.046, pad=0.04)

	# exclude diagonal from the distance matrix in the ROC and Precision-Recall curves
	npdistmat = ndistmat.copy()
	edge_probs = edge_probs.detach().cpu().numpy()
	edge_probs = edge_probs.reshape((z.shape[0], z.shape[0]))
	#flatten both matrices for ROC and Precision-Recall curves
	ndistmat_flat = npdistmat.flatten()
	edge_probs_flat = edge_probs.flatten()
	# Remove NaN values from both arrays
	
	# ROC Curve
	fpr, tpr, _ = roc_curve(ndistmat_flat, edge_probs_flat)
	roc_auc = auc(fpr, tpr)
	axs[1, 0].plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
	axs[1, 0].plot([0, 1], [0, 1], color='red', lw=2, linestyle='--')
	axs[1, 0].set_xlim([0.0, 1.0])
	axs[1, 0].set_ylim([0.0, 1.05])

	# Precision-Recall Curve
	y_true = ndistmat_flat
	y_scores = edge_probs_flat
	precision, recall, _ = precision_recall_curve(y_true, y_scores)
	ap_score = average_precision_score(y_true, y_scores)
	axs[1, 1].plot(recall, precision, color='green', lw=2, label=f'AP = {ap_score:.2f}')
	
	# Mask out entries not separated by at least 10 residues
	mask = np.fromfunction(lambda i, j: np.abs(i - j) >= 10, dist_mat.shape)
	masked_ndistmat = np.where(mask, ndistmat, np.nan)
	masked_edge_probs = np.where(mask, edge_probs, np.nan)

	# Flatten and filter out nan values
	ndistmat_flat_masked = masked_ndistmat.flatten()
	edge_probs_flat_masked = masked_edge_probs.flatten()
	valid_mask = ~np.isnan(ndistmat_flat_masked) & ~np.isnan(edge_probs_flat_masked)
	ndistmat_flat_masked = ndistmat_flat_masked[valid_mask]
	edge_probs_flat_masked = edge_probs_flat_masked[valid_mask]

	y_true = ndistmat_flat
	y_scores = edge_probs_flat
	# ROC Curve for off-diagonal only
	fpr, tpr, _ = roc_curve(ndistmat_flat_masked, edge_probs_flat_masked)
	roc_auc = auc(fpr, tpr)
	axs[1, 0].plot(fpr, tpr, color='red', lw=2, label=f'ROC curve offdiag (area = {roc_auc:.2f})')
	
	axs[1, 0].set_xlabel('False Positive Rate')
	axs[1, 0].set_ylabel('True Positive Rate')
	axs[1, 0].set_title('Receiver Operating Characteristic')
	axs[1, 0].legend(loc='lower right')

	axs[1, 1].set_xlabel('Recall')
	axs[1, 1].set_ylabel('Precision')
	axs[1, 1].set_title('Precision-Recall Curve')
	axs[1, 1].legend(loc='lower left')

	dist_mat_flat = dist_mat.flatten()  # Ensure we use the same valid mask for distance matrix
	dist_mat_flat = dist_mat_flat/dist_mat_flat.max()  # Normalize the distance matrix for better visualization
	# Correlation
	corr = np.corrcoef( y_scores, dist_mat_flat)[0, 1]
	axs[1, 2].scatter( y_scores, dist_mat_flat, alpha=0.05 , s=.1  )
	axs[1, 2].set_xlabel('Predicted Contacts')
	axs[1, 2].set_ylabel('True Distance Matrix')
	axs[1, 2].set_title(f'Correlation: {corr:.2f}')


	if 'angles' in sample_out and sample_out['angles'] is not None:
		true_angles = data_sample['bondangles'].x.detach().cpu().numpy()
		pred_angles = sample_out['angles'].detach().cpu().numpy()
		angle_names = ['N-Ca-C', 'Ca-C-N', 'C-N-Ca']
		angle_colors = ['r', 'g', 'b']
		#plot the angles
		for i in range(3):
			axs[2,0].plot(true_angles[:,i], label='True ' + angle_names[i], color=angle_colors[i] , alpha = 0.5)
			axs[2,0].plot(pred_angles[:,i], label='Predicted ' + angle_names[i], color=angle_colors[i], linestyle='--'  , alpha = 0.5)		
		axs[2,0].legend()
		axs[2,0].set_title('Bond Angles')
	
	axs[2,1].imshow( logits.detach().cpu().numpy(), cmap='hot', interpolation='nearest' )
	axs[2,1].set_xlabel('Residue Index')
	axs[2,1].set_ylabel('Residue Index')
	#colorbar
	fig.colorbar(im0, ax=axs[2,1], fraction=0.046, pad=0.04)
	axs[2,1].set_title('Predicted Edge Logits')

	#show foldtree2 sequence on the last subplot
	plot_logits_sequence_on_ax( zdiscrete.detach().cpu().numpy() , num_embeddings , axs[2,2] , max_width=64 , show_colorbar=True)

	figurestack.append(fig)
	plt.tight_layout()
	#save each subplot within the last figure as a separate svg
	if not os.path.exists('figures'):
		os.makedirs('figures')
	for i, ax in enumerate(fig.axes):
		# Create a new figure with just this subplot
		new_fig, new_ax = plt.subplots(1, 1, figsize=(6, 6))
		
		# Copy the content from the original subplot to the new figure
		for artist in ax.get_children():
			if hasattr(artist, 'get_array') and hasattr(artist, 'get_extent'):  # Images
				new_ax.imshow(artist.get_array(), extent=artist.get_extent(), 
							 cmap=artist.get_cmap(), interpolation='nearest')
			elif hasattr(artist, 'get_xdata') and hasattr(artist, 'get_ydata'):  # Lines/plots
				new_ax.plot(artist.get_xdata(), artist.get_ydata(), 
						   color=artist.get_color(), linestyle=artist.get_linestyle(),
						   label=artist.get_label() if artist.get_label() and not artist.get_label().startswith('_') else None)
			elif hasattr(artist, 'get_offsets'):  # Scatter plots (PathCollection)
				continue  # Skip scatter plots for simplicity
		
		# Copy titles and labels
		new_ax.set_title(ax.get_title())
		new_ax.set_xlabel(ax.get_xlabel())
		new_ax.set_ylabel(ax.get_ylabel())
		
		# Copy legend if it exists
		if ax.get_legend():
			new_ax.legend()
		
		# Save the individual subplot
		new_fig.savefig(f"figures/subplot_{i}_ep{epoch}.svg", bbox_inches='tight')
		plt.close(new_fig)  # Close to free memory

	plt.show()
	encoder.train()
	decoder.train()

Epoch 0:   0%|                                                                                                                        | 0/250 [00:01<?, ?it/s]


AttributeError: 'tuple' object has no attribute 'size'

In [None]:
def load_and_visualize_pdb(pdb_path, encoder, decoder, converter, device, num_embeddings, 
                          save_path=None, figsize=(20, 12)):
    """
    Load any PDB file, process it through the encoder-decoder, and visualize results.
    
    Args:
        pdb_path: Path to PDB file
        encoder: Trained encoder model
        decoder: Trained decoder model
        converter: PDB2PyG converter
        device: PyTorch device
        num_embeddings: Number of embeddings in alphabet
        save_path: Optional path to save figure
        figsize: Figure size tuple
        
    Returns:
        fig: matplotlib figure
        metrics_dict: computed metrics
        zdiscrete: discrete embeddings
        data_sample: processed data
    """
    import os
    from foldtree2.src import pdbgraph
    
    print(f"Loading PDB file: {pdb_path}")
    
    # Check if file exists
    if not os.path.exists(pdb_path):
        raise FileNotFoundError(f"PDB file not found: {pdb_path}")
    
    # Convert PDB to PyG data
    try:
        # Use converter to process the PDB file
        data_sample = converter.pdb2pyg(pdb_path)
        print(f"Successfully converted PDB to PyG format")
        print(f"Protein has {data_sample['res'].x.shape[0]} residues")
        
    except Exception as e:
        print(f"Error converting PDB file: {e}")
        return None, None, None, None
    
    # Set identifier from filename if not present
    if not hasattr(data_sample, 'identifier'):
        filename = os.path.basename(pdb_path).replace('.pdb', '')
        data_sample.identifier = filename
    
    # Move to device and encode
    data_sample = data_sample.to(device)
    encoder.eval()
    decoder.eval()
    
    with torch.no_grad():
        try:
            z, vqloss = encoder(data_sample)
            ords = encoder.vector_quantizer.discretize_z(z.detach())
            zdiscrete = ords[0].detach()
            print(f"Encoded to {zdiscrete.shape[0]} discrete tokens")
            
        except Exception as e:
            print(f"Error during encoding: {e}")
            return None, None, None, None
    
    # Visualize using the existing function
    fig, metrics = visualize_decoder_reconstruction(
        encoder, decoder, data_sample, device, num_embeddings,
        converter, epoch=None, save_path=save_path, figsize=figsize
    )
    
    return fig, metrics, zdiscrete, data_sample

def load_pdb_from_identifier(identifier, encoder, decoder, converter, device, num_embeddings,
                            structfolder='tmp/', save_path=None, figsize=(20, 12)):
    """
    Load PDB from AlphaFold DB using identifier and visualize.
    
    Args:
        identifier: Protein identifier (e.g., 'AF-P12345-F1')
        encoder, decoder, converter, device, num_embeddings: Model components
        structfolder: Folder to download PDB to
        save_path, figsize: Visualization options
        
    Returns:
        Same as load_and_visualize_pdb
    """
    from foldtree2.src.AFDB_tools import grab_struct
    
    # Download structure
    grab_struct(identifier, structfolder=structfolder)
    pdb_path = os.path.join(structfolder, f"{identifier}.pdb")
    
    return load_and_visualize_pdb(pdb_path, encoder, decoder, converter, device, 
                                 num_embeddings, save_path, figsize)

def visualize_pdb_with_embeddings(pdb_path, zdiscrete, num_embeddings, width=800, height=400):
    """
    Visualize PDB structure colored by embedding states using py3Dmol.
    
    Args:
        pdb_path: Path to PDB file
        zdiscrete: Discrete embedding tensor
        num_embeddings: Number of embeddings
        width, height: Viewer dimensions
        
    Returns:
        py3Dmol view object
    """
    import py3Dmol
    import colour
    import numpy as np
    
    # Read PDB file
    with open(pdb_path, 'r') as f:
        pdb_data = f.read()
    
    # Create color mapping
    ord_colors = colour.Color("red").range_to(colour.Color("blue"), num_embeddings)
    ord_colors = np.array([c.get_rgb() for c in ord_colors])
    sequence_colors = ord_colors[zdiscrete.cpu().numpy()]
    sequence_colors_hex = [colour.Color(rgb=tuple(c)).hex for c in sequence_colors]
    
    # Create 3D viewer
    view = py3Dmol.view(width=width, height=height)
    view.addModel(pdb_data, 'pdb')
    
    # Color by embedding state
    for i, color in enumerate(sequence_colors_hex):
        view.setStyle({'chain': 'A', 'resi': i+1}, {'cartoon': {'color': color}})
    
    view.setBackgroundColor('0xeeeeee')
    view.zoomTo()
    
    return view

def batch_process_pdbs(pdb_folder, encoder, decoder, converter, device, num_embeddings,
                      output_folder='batch_results/', file_pattern='*.pdb'):
    """
    Process multiple PDB files in a folder.
    
    Args:
        pdb_folder: Folder containing PDB files
        encoder, decoder, converter, device, num_embeddings: Model components
        output_folder: Where to save results
        file_pattern: File pattern to match (e.g., '*.pdb')
        
    Returns:
        results_dict: Dictionary with results for each file
    """
    import glob
    import os
    
    # Create output folder
    os.makedirs(output_folder, exist_ok=True)
    
    # Find PDB files
    pdb_files = glob.glob(os.path.join(pdb_folder, file_pattern))
    print(f"Found {len(pdb_files)} PDB files to process")
    
    results = {}
    
    for pdb_file in pdb_files:
        filename = os.path.basename(pdb_file).replace('.pdb', '')
        print(f"\nProcessing: {filename}")
        
        try:
            # Process PDB
            save_path = os.path.join(output_folder, f"{filename}_reconstruction.png")
            fig, metrics, zdiscrete, data_sample = load_and_visualize_pdb(
                pdb_file, encoder, decoder, converter, device, num_embeddings,
                save_path=save_path
            )
            
            if fig is not None:
                results[filename] = {
                    'metrics': metrics,
                    'zdiscrete': zdiscrete,
                    'data_sample': data_sample,
                    'pdb_path': pdb_file,
                    'figure_path': save_path
                }
                plt.close(fig)  # Free memory
                
        except Exception as e:
            print(f"Error processing {filename}: {e}")
            results[filename] = {'error': str(e)}
    
    return results

# Usage examples:
"""
# Load a specific PDB file
fig, metrics, zdiscrete, data = load_and_visualize_pdb(
    '/path/to/protein.pdb', encoder, decoder, converter, device, num_embeddings
)

# Load from AlphaFold DB
fig, metrics, zdiscrete, data = load_pdb_from_identifier(
    'AF-P00520-F1', encoder, decoder, converter, device, num_embeddings
)

# Visualize with 3D coloring
view = visualize_pdb_with_embeddings('/path/to/protein.pdb', zdiscrete, num_embeddings)
view.show()

# Process multiple PDBs
results = batch_process_pdbs(
    '/path/to/pdb_folder/', encoder, decoder, converter, device, num_embeddings
)
"""

In [None]:
# Add this cell to test loading arbitrary PDB files
def interactive_pdb_loader():
    """Interactive widget to load and visualize PDB files"""
    import ipywidgets as widgets
    from IPython.display import display
    
    # File upload widget
    file_upload = widgets.FileUpload(
        accept='.pdb',
        multiple=False,
        description='Upload PDB'
    )
    
    # Identifier input
    identifier_input = widgets.Text(
        value='',
        placeholder='Enter AlphaFold ID (e.g., AF-P00520-F1)',
        description='AF ID:'
    )
    
    # Buttons
    upload_btn = widgets.Button(description="Process Uploaded PDB")
    download_btn = widgets.Button(description="Download and Process")
    
    # Output
    output = widgets.Output()
    
    def process_uploaded(b):
        with output:
            output.clear_output()
            if file_upload.value:
                # Save uploaded file
                uploaded_file = list(file_upload.value.values())[0]
                temp_path = f"tmp/{uploaded_file['metadata']['name']}"
                with open(temp_path, 'wb') as f:
                    f.write(uploaded_file['content'])
                
                # Process
                fig, metrics, zdiscrete, data = load_and_visualize_pdb(
                    temp_path, encoder, decoder, converter, device, num_embeddings
                )
                
                if fig:
                    plt.show()
                    print("Metrics:", metrics)
                    
                    # Show 3D structure
                    view = visualize_pdb_with_embeddings(temp_path, zdiscrete, num_embeddings)
                    view.show()
    
    def process_identifier(b):
        with output:
            output.clear_output()
            if identifier_input.value:
                fig, metrics, zdiscrete, data = load_pdb_from_identifier(
                    identifier_input.value, encoder, decoder, converter, device, num_embeddings
                )
                
                if fig:
                    plt.show()
                    print("Metrics:", metrics)
                    
                    # Show 3D structure  
                    pdb_path = f"tmp/{identifier_input.value}.pdb"
                    view = visualize_pdb_with_embeddings(pdb_path, zdiscrete, num_embeddings)
                    view.show()
    
    upload_btn.on_click(process_uploaded)
    download_btn.on_click(process_identifier)
    
    display(widgets.VBox([
        widgets.HTML("<h3>Load PDB File</h3>"),
        file_upload,
        upload_btn,
        widgets.HTML("<h3>Or Download from AlphaFold DB</h3>"),
        identifier_input,
        download_btn,
        output
    ]))

# Run the interactive loader
# interactive_pdb_loader()

In [None]:
print( figurestack)

[<Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure size 2000x1000 with 14 Axes>, <Figure siz

In [None]:
#save last figure in the figurestack as svg
figurestack[-1].savefig('test_monodecoder_last_figure.svg', format='svg')
#save each subplot within the last figure as a separate svg
for i, ax in enumerate(figurestack[-1].axes):
    ax.figure.savefig(f'test_monodecoder_last_figure_subplot_{i}.svg', format='svg')


In [None]:
#show the structure for the data sample
grab_struct(str(data_sample.identifier) , structfolder='tmp/')
import py3Dmol
with open('tmp/' + str(data_sample.identifier) + '.pdb', 'r') as f:
    pdb_data = f.read()
view = py3Dmol.view(width=800, height=400)
view.addModel(pdb_data, 'pdb')
view.setStyle({'cartoon': {'color': 'spectrum'}})
view.setBackgroundColor('0xeeeeee')
view.zoomTo()
view.show()

In [None]:
#color the structure by state of ft2 alphabet
#map the zdiscrete to colors


num_embeddings = encoder.num_embeddings
ord_colors = colour.Color("red").range_to(colour.Color("blue"), num_embeddings)
ord_colors = np.array([c.get_rgb() for c in ord_colors])
sequence_colors = ord_colors[zdiscrete.cpu().numpy()]
#convert to hex
sequence_colors_hex = [colour.Color(rgb=tuple(c)).hex for c in sequence_colors]
print(sequence_colors_hex)
#add the colors to the structure
for i, color in enumerate(sequence_colors_hex):
	view.setStyle({'chain': 'A', 'resi': i+1}, {'cartoon': {'color': color}})
view.zoomTo()
view.show()

['#ff3400', '#ff3400', '#ff1a00', '#ff4e00', '#ff4e00', '#5cff00', '#ff4e00', '#0fa', '#00ff0d', '#ffb700', '#c4ff00', '#ffb700', '#0fa', '#ff1a00', '#00ff27', '#ff4e00', '#ffb700', '#ffb700', '#ff9d00', '#f00', '#ff1a00', '#ffb700', '#00ff27', '#00ff5c', '#f00', '#00fff8', '#41ff00', '#5cff00', '#5cff00', '#ff1a00', '#0fa', '#ff1a00', '#00fff8', '#ff9d00', '#00fff8', '#00ff27', '#00ff76', '#f00', '#f00', '#00ff5c', '#5cff00', '#00fff8', '#f00', '#0fa', '#41ff00', '#c4ff00', '#f00', '#00ff5c', '#00ff41', '#00ffc4', '#41ff00', '#5cff00', '#00ff5c', '#ffb700', '#ff8300', '#41ff00', '#deff00', '#ff8300', '#00d1ff', '#ffb700', '#00ff41', '#00d1ff', '#ff8300', '#deff00', '#0fa', '#00ff41', '#41ff00', '#00ebff', '#90ff00', '#5cff00', '#5cff00', '#00ff27', '#00fff8', '#f00', '#41ff00', '#0fa', '#ff1a00', '#f00', '#f00', '#00ff5c', '#41ff00', '#ffb700', '#41ff00', '#deff00', '#5cff00', '#0034ff', '#00ff5c', '#deff00', '#00ff27', '#00ff5c', '#00ff27', '#ff8300', '#00ff41', '#00d1ff', '#c4ff00',