# Test MonoDecoders: Sequence and Geometry
This notebook replicates the training logic from `learn.py` using the decoder in `mono_decoders.py` for amino acid and geometry prediction.

In [1]:
#use autoreload
%load_ext autoreload
%autoreload 2

In [2]:
cd /home/dmoi/projects/foldtree2/

/home/dmoi/projects/foldtree2


In [3]:
# Imports
import torch
from torch_geometric.data import DataLoader
import numpy as np
from src import pdbgraph
from src import foldtree2_ecddcd as ft2
from src.mono_decoders import MultiMonoDecoder
import os
import tqdm
import random
import torch.nn.functional as F

In [48]:
# Set seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data setup
datadir = '../../datasets/foldtree2/'
dataset_path = 'structs_traininffttest.h5'
converter = pdbgraph.PDB2PyG(aapropcsv='config/aaindex1.csv')
struct_dat = pdbgraph.StructureDataset(dataset_path)
train_loader = DataLoader(struct_dat, batch_size=5, shuffle=True, num_workers=4)
data_sample = next(iter(train_loader))



In [49]:
print('Data sample:', data_sample)

Data sample: HeteroDataBatch(
  identifier=[5],
  AA={
    x=[1323, 20],
    batch=[1323],
    ptr=[6],
  },
  R_true={
    x=[1323, 3, 3],
    batch=[1323],
    ptr=[6],
  },
  bondangles={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  coords={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  fourier1di={
    x=[1323, 80],
    batch=[1323],
    ptr=[6],
  },
  fourier1dr={
    x=[1323, 80],
    batch=[1323],
    ptr=[6],
  },
  fourier2di={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  fourier2dr={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  godnode={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  godnode4decoder={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  plddt={
    x=[1323, 1],
    batch=[1323],
    ptr=[6],
  },
  positions={
    x=[1323, 256],
    batch=[1323],
    ptr=[6],
  },
  res={
    x=[1323, 857],
    batch=[1323],
    ptr=[6],
  },
  t_true={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  (godnode4decoder, informs, res)={ edg

In [50]:
# Model setup

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
ndim = data_sample['res'].x.shape[1]
ndim_godnode = data_sample['godnode'].x.shape[1]
ndim_fft2i = data_sample['fourier2di'].x.shape[1]
ndim_fft2r = data_sample['fourier2dr'].x.shape[1]

num_embeddings = 40
embedding_dim = 20
hidden_size = 100
se3transfomer = False  # Set to True for SE3Transformer, False for GNN

if se3transfomer == True:

	encoder = se3e.se3_Encoder(
		in_channels=ndim,
		hidden_channels=[hidden_size//2, hidden_size//2],
		out_channels=embedding_dim,
		metadata={'edge_types': [('res','contactPoints','res'), ('res','hbond','res')]},
		num_embeddings=num_embeddings,
		commitment_cost=0.9,
		edge_dim=1,
		encoder_hidden=hidden_size,
		EMA=True,
		nheads=5,
		dropout_p=0.005,
		reset_codes=False,
		flavor='transformer',
		fftin=True
	)			
else:
	encoder = ft2.mk1_Encoder(
		in_channels=ndim,
		hidden_channels=[hidden_size, hidden_size],
		out_channels=embedding_dim,
		metadata={'edge_types': [('res','contactPoints','res'), ('res','hbond','res')]},
		num_embeddings=num_embeddings,
		commitment_cost=0.9,
		edge_dim=1,
		encoder_hidden=hidden_size,
		EMA=True,
		nheads=5,
		dropout_p=0.005,
		reset_codes=False,
		flavor='transformer',
		fftin=True
	)


print(encoder)
encoder = encoder.to(device)

# MultiMonoDecoder for sequence and geometry
mono_configs = {
	'sequence_transformer': {
		'in_channels': {'res': embedding_dim},
		'xdim': 20,
		'concat_positions': True,
		'hidden_channels': {('res','backbone','res'): [hidden_size]*3 , ('res','backbonerev','res'): [hidden_size]*3},
		'layers': 3,
		'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
		'amino_mapper': converter.aaindex,
		'flavor': 'sage',
		'dropout': 0.005,
		'normalize': True,
		'residual': False
	},
	
	'contacts': {
		'in_channels': {'res': embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23 ,  'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i},
		'concat_positions': True,
		'hidden_channels': {('res','backbone','res'): [hidden_size]*3, ('res','backbonerev','res'): [hidden_size]*3, ('res','informs','godnode4decoder'): [hidden_size]*3 , ('godnode4decoder','informs','res'): [hidden_size]*3},
		'layers': 3,
		'FFT2decoder_hidden': [hidden_size, hidden_size, hidden_size],
		'contactdecoder_hidden': [hidden_size//2, hidden_size//2],
		'nheads': 2,
		'Xdecoder_hidden': [hidden_size, hidden_size,  hidden_size ],
		'metadata': converter.metadata,
		'flavor': 'sage',
		'dropout': 0.005,
		'output_fft': True,
        'output_rt':False,
		'normalize': True,
		'residual': False,
		'contact_mlp': True
	}
}
decoder = MultiMonoDecoder( configs=mono_configs)
encoder = encoder.to(device)
decoder = decoder.to(device)

print(decoder)

Seed set to 42
Seed set to 42
Seed set to 42


mk1_Encoder(
  (convs): ModuleList(
    (0): ModuleDict(
      (res_contactPoints_res): TransformerConv(100, 100, heads=5)
      (res_hbond_res): TransformerConv(100, 100, heads=5)
    )
  )
  (norms): ModuleList(
    (0): GraphNorm(100)
  )
  (bn): BatchNorm1d(857, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.005, inplace=False)
  (jk): JumpingKnowledge(cat)
  (ffin): Sequential(
    (0): Linear(in_features=1017, out_features=200, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): GELU(approximate='none')
    (4): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
  )
  (lin): Sequential(
    (0): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
    (1): Linear(in_features=100, out_features=100, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): GELU(approximate=

In [51]:
# Training loop (demo, similar to learn.py)
import time
from collections import defaultdict

num_epochs = 20  # For demonstration, keep small
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

edgeweight = 0.01
xweight = 1
fft2weight = 0.01
vqweight = 0.0001
clip_grad = True
encoder.device = device
encoder.train()
decoder.train()


MultiMonoDecoder(
  (decoders): ModuleDict(
    (sequence_transformer): Transformer_AA_Decoder(
      (input_proj): Sequential(
        (0): DynamicTanh(normalized_shape=276, alpha_init_value=0.5, channels_last=True)
        (1): Linear(in_features=276, out_features=100, bias=True)
        (2): GELU(approximate='none')
        (3): Dropout(p=0.005, inplace=False)
        (4): Linear(in_features=100, out_features=100, bias=True)
      )
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-2): 3 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
            )
            (linear1): Linear(in_features=100, out_features=2048, bias=True)
            (dropout): Dropout(p=0.005, inplace=False)
            (linear2): Linear(in_features=2048, out_features=100, bias=True)
            (norm1): LayerNorm((100,), eps=1e-05, elementwise_a

In [52]:
#get one sample from the dataloader
train_loader = DataLoader(struct_dat, batch_size=20, shuffle=True, num_workers=4)
data_sample = next(iter(train_loader))
data = data_sample.to(device)
optimizer.zero_grad()
z, vqloss = encoder(data , debug = True)
print('Encoded z shape:', z.shape)
print( data_sample.identifier )

torch.Size([6329, 857]) x_dict[res] shape
z_quantized shape: torch.Size([6329, 20])
vq_loss: tensor(0.7659, device='cuda:1', grad_fn=<SubBackward0>)
x shape: torch.Size([6329, 20])
x_dict keys: dict_keys(['AA', 'R_true', 'bondangles', 'coords', 'fourier1di', 'fourier1dr', 'fourier2di', 'fourier2dr', 'godnode', 'godnode4decoder', 'plddt', 'positions', 'res', 't_true'])
edge_index_dict keys: dict_keys([('godnode4decoder', 'informs', 'res'), ('godnode', 'informs', 'res'), ('res', 'backbone', 'res'), ('res', 'backbonerev', 'res'), ('res', 'contactPoints', 'res'), ('res', 'hbond', 'res'), ('res', 'informs', 'godnode'), ('res', 'informs', 'godnode4decoder'), ('res', 'window', 'res'), ('res', 'windowrev', 'res')])
Encoded z shape: torch.Size([6329, 20])
['R5D7F0', 'A0A7S4UUN0', 'A0A0F8XUB3', 'A0A3E4NY07', 'A0A1D1W005', 'A0A4R8G4N7', 'A0A1E5XW89', 'A0A7V2CVU6', 'A0A182QEX7', 'A0A0A9S6W1', 'A0A6J8AXJ0', 'A0A6V7WLY8', 'A0A5N5KV62', 'A0A7J8DMC4', 'J5TFK9', 'A0A853HWF1', 'A0A7Y6IRL5', 'K1ISW4', 'Q

In [53]:
from Bio import PDB
from Bio.PDB import PDBParser
def getCAatoms(pdb_file):
	parser = PDBParser(QUIET=True)
	# Parse the structure
	structure = parser.get_structure('structure', pdb_file)
	ca_atoms = {}
	for model in structure:
		for chain in model:
			if chain.id not in ca_atoms:
				ca_atoms[chain.id] = []
			for residue in chain :
				if 'CA' in residue and PDB.is_aa(residue) :
					ca_atoms[chain.id].append(residue['CA'])
	return ca_atoms

from src.AFDB_tools import 	grab_struct
totals = []
for identifier in  list(data.identifier):
	print('Processing identifier:', identifier)
	grab_struct(str(identifier) , structfolder='tmp/')
	#find the total number of residues
	ca_atoms = getCAatoms( 'tmp/' + str(identifier) + '.pdb')
	total_residues = sum(len(atoms) for atoms in ca_atoms.values())
	print('Total number of CA atoms:', total_residues)
	totals.append(total_residues)
print('Total residues for all identifiers:', sum(totals))


Processing identifier: R5D7F0
Total number of CA atoms: 153
Processing identifier: A0A7S4UUN0
Total number of CA atoms: 563
Processing identifier: A0A0F8XUB3
Total number of CA atoms: 82
Processing identifier: A0A3E4NY07
Total number of CA atoms: 81
Processing identifier: A0A1D1W005
Total number of CA atoms: 758
Processing identifier: A0A4R8G4N7
Total number of CA atoms: 73
Processing identifier: A0A1E5XW89
Total number of CA atoms: 1096
Processing identifier: A0A7V2CVU6
Total number of CA atoms: 626
Processing identifier: A0A182QEX7
Total number of CA atoms: 281
Processing identifier: A0A0A9S6W1
Total number of CA atoms: 326
Processing identifier: A0A6J8AXJ0
Total number of CA atoms: 329
Processing identifier: A0A6V7WLY8
Total number of CA atoms: 74
Processing identifier: A0A5N5KV62
Total number of CA atoms: 106
Processing identifier: A0A7J8DMC4
Total number of CA atoms: 338
Processing identifier: J5TFK9
Total number of CA atoms: 475
Processing identifier: A0A853HWF1
Total number of C

In [54]:
import pandas as pd
train_loader = DataLoader(struct_dat, batch_size=1, shuffle=True, num_workers=4)
def databatch2list(loader , limit = 10):
	for i,data in enumerate(loader):
		if i > limit:
			break
		data = data.to_data_list()
		for d in data:
			d = d.to(device)
			yield d
encoder_loader = databatch2list(train_loader, limit=20)
encoder.encode_structures_fasta(encoder_loader , './aln_encoded_test.fasta' )
#read the test fasta file
encoded_fasta =  './aln_encoded_test.fasta' 
seqstr = ''
ID = ''
seqdict = {}
with open(encoded_fasta, 'r') as f:
	#read all chars of file into a string
	for line in tqdm.tqdm(f):
		if line[0] == '>' and line[-1] == '\n':
			seqdict[ID] = seqstr
			ID = line[1:].strip()
			seqstr = ''
		else:
			seqstr += line.strip()
del seqdict['']
encoded_df = pd.DataFrame( seqdict.items() , columns=['protid', 'seq'] )
#change index to protid
encoded_df.index = encoded_df.protid
encoded_df = encoded_df.drop( 'protid', axis=1 )
encoded_df['ord'] = encoded_df.seq.map( lambda x: [ ord(c) for c in x] )
#hex starts at 1
encoded_df['hex2'] = encoded_df.ord.map( lambda x: [ hex(c) for c in x] )
encoded_df['length_ord'] = encoded_df.ord.map( lambda x: len(x) )
encoded_df['length_hex'] = encoded_df.hex2.map( lambda x: len(x) )
encoded_df['length_seq'] = encoded_df.seq.map( lambda x: len(x) )
encoded_df = encoded_df.sort_values('length_ord', ascending=False)
print(encoded_df.head(10))
print(encoded_df.tail(10))

21it [00:01, 11.07it/s]
42it [00:00, 123534.90it/s]

                                                          seq  \
protid                                                          
A0A813AF04  ...   
A0A1I0BAB0  ...   
A0A6C0JB65  ...   
A0A1H8XK78  ...   
A0A849FTN9  ...   
A0A812B6T8  ...   
A0A1Y4RGK8  ...   
A0A815TLP4  ...   
A0A850LSD1  ...   
G3N7K7      ...   

                                                          ord  \
protid                                                          
A0A813AF04  [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 2...   
A0A1I0BAB0  [27, 27, 27,




In [55]:
print(encoded_df.head())
#grab the pdbs and get the total number of residues
#ensure encoder produces fastas with the same residues as the pdbs
pdbs = []
for identifier in encoded_df.index:
	pdbs.append( str(identifier) + '.pdb' )
encoded_df['pdb'] = pdbs
for i in range(len(encoded_df)):
	identifier = encoded_df.index[i]
	pdb_file = 'tmp/' + identifier + '.pdb'
	if not os.path.exists(pdb_file):
		grab_struct(str(identifier), structfolder='tmp/')
	ca_atoms = getCAatoms(pdb_file)
	total_residues = sum(len(atoms) for atoms in ca_atoms.values())
	encoded_df.at[identifier, 'total_residues'] = total_residues
encoded_df['delta'] = encoded_df['length_ord'] - encoded_df['total_residues']
assert (encoded_df['delta'] == 0).all(), "Mismatch between sequence length and total residues in PDB files"
assert (encoded_df['length_ord'] == encoded_df['length_hex']).all(), "Mismatch between ord and hex lengths"
assert (encoded_df['length_ord'] == encoded_df['length_seq']).all(), "Mismatch between ord and sequence lengths"
assert (encoded_df['length_hex'] == encoded_df['length_seq']).all(), "Mismatch between hex and sequence lengths"
print("All checks passed. Sequence lengths match PDB total residues and ord/hex lengths.")
print("Encoded sequences DataFrame:")
encoded_df.to_csv('encoded_sequences.csv')

                                                          seq  \
protid                                                          
A0A813AF04  ...   
A0A1I0BAB0  ...   
A0A6C0JB65  ...   
A0A1H8XK78  ...   
A0A849FTN9  ...   

                                                          ord  \
protid                                                          
A0A813AF04  [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 2...   
A0A1I0BAB0  [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 2...   
A0A6C0JB65  [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 2...   
A0A1H8XK78  [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 2...   
A0A849FTN9  [27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 2...   

                                                         hex2  length_ord  \
protid     

In [None]:

from losses.losses import recon_loss_diag , aa_reconstruction_loss
for epoch in range(num_epochs):
	total_loss_x = 0
	total_loss_edge = 0
	total_vq = 0
	total_loss_fft2 = 0
	for data in tqdm.tqdm(train_loader, desc=f"Epoch {epoch}"):
		data = data.to(device)
		optimizer.zero_grad()
		z, vqloss = encoder(data)
		data['res'].x = z
		# For demonstration, only sequence and contacts tasks
		out = decoder(data, None)
		recon_x = out['aa'] if isinstance(out, dict) and 'aa' in out else out[0] if isinstance(out, (list, tuple)) else None
		fft2_x = out['fft2pred'] if isinstance(out, dict) and 'fft2pred' in out else out[1] if isinstance(out, (list, tuple)) else None
		# Edge loss: use contactPoints if available
		edge_index = data.edge_index_dict['res', 'contactPoints', 'res'] if hasattr(data, 'edge_index_dict') and ('res', 'contactPoints', 'res') in data.edge_index_dict else None
		if edge_index is not None:
			edgeloss, _ = recon_loss_diag(data, edge_index, decoder, plddt=False, offdiag=False , key = 'edge_probs')
		else:
			edgeloss = torch.tensor(0.0, device=device)
		xloss = aa_reconstruction_loss(data['AA'].x, recon_x)
		fft2loss = F.smooth_l1_loss(torch.cat( [ data['fourier2dr'].x ,data['fourier2di'].x ] ,axis = 1 ) , fft2_x )
		loss = xweight * xloss + edgeweight * edgeloss + vqweight * vqloss + fft2loss* fft2weight

		loss.backward()
		if clip_grad:
			torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
			torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
		optimizer.step()
		total_loss_x += xloss.item()
		total_loss_edge += edgeloss.item()
		total_loss_fft2 += fft2loss.item()
		total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
	scheduler.step(total_loss_x)
	print(f"Epoch {epoch}: AA Loss: {total_loss_x/len(train_loader):.4f}, Edge Loss: {total_loss_edge/len(train_loader):.4f}, VQ Loss: {total_vq/len(train_loader):.4f} , FFT2 Loss: {total_loss_fft2/len(train_loader):.4f}")

Epoch 0:   3%|▍              | 31/1000 [00:17<08:55,  1.81it/s]


KeyboardInterrupt: 