# Test MonoDecoders: Sequence and Geometry
This notebook replicates the training logic from `learn.py`, using the `MultiMonoDecoder` from `mono_decoders.py` to train amino-acid (sequence) and geometry (contact and FFT2) prediction on top of the encoder's embeddings.

In [8]:
#use autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
cd /home/dmoi/projects/foldtree2/

/home/dmoi/projects/foldtree2


In [10]:
# Imports
import torch
from torch_geometric.data import DataLoader
import numpy as np
from src import pdbgraph
from src import foldtree2_ecddcd as ft2
from src.mono_decoders import MultiMonoDecoder
import os
import tqdm
import random
import torch.nn.functional as F

In [11]:
# Reproducibility: seed the Python, NumPy and PyTorch generators and ask
# cuDNN for deterministic behaviour with benchmarking disabled.
# NOTE(review): later cell output shows "Seed set to 42", so the project
# modules appear to re-seed on their own - confirm which seed is in effect.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Data: the converter (its `aaindex` and `metadata` feed the decoder config)
# and the structure dataset wrapped in a shuffling loader.
# NOTE(review): `datadir` is not used by anything visible in this notebook;
# the dataset is addressed by its bare relative filename - confirm intent.
datadir = '../../datasets/foldtree2/'
dataset_path = 'structs_traininffttest.h5'
converter = pdbgraph.PDB2PyG(aapropcsv='config/aaindex1.csv')
struct_dat = pdbgraph.StructureDataset(dataset_path)
train_loader = DataLoader(
	struct_dat,
	batch_size=5,
	shuffle=True,
	num_workers=4,
)

# Pull one batch; its feature widths are used to size the model inputs.
data_sample = next(iter(train_loader))



In [12]:
# Show the layout of one collated batch (its stores and tensor shapes).
print(f'Data sample: {data_sample}')

Data sample: HeteroDataBatch(
  identifier=[5],
  AA={
    x=[1323, 20],
    batch=[1323],
    ptr=[6],
  },
  R_true={
    x=[1323, 3, 3],
    batch=[1323],
    ptr=[6],
  },
  bondangles={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  coords={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  fourier1di={
    x=[1323, 80],
    batch=[1323],
    ptr=[6],
  },
  fourier1dr={
    x=[1323, 80],
    batch=[1323],
    ptr=[6],
  },
  fourier2di={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  fourier2dr={
    x=[5, 1300],
    batch=[5],
    ptr=[6],
  },
  godnode={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  godnode4decoder={
    x=[5, 5],
    batch=[5],
    ptr=[6],
  },
  plddt={
    x=[1323, 1],
    batch=[1323],
    ptr=[6],
  },
  positions={
    x=[1323, 256],
    batch=[1323],
    ptr=[6],
  },
  res={
    x=[1323, 857],
    batch=[1323],
    ptr=[6],
  },
  t_true={
    x=[1323, 3],
    batch=[1323],
    ptr=[6],
  },
  (godnode4decoder, informs, res)={ edg

In [15]:
# Model setup: build the encoder and the multi-task decoder and move both to
# the training device.
import se3encoder as se3e

# Prefer the second GPU when the machine has one, otherwise the first GPU,
# otherwise the CPU. A bare 'cuda:1' is invalid on single-GPU machines even
# though torch.cuda.is_available() is True there.
if torch.cuda.is_available():
	gpu_index = 1 if torch.cuda.device_count() > 1 else 0
	device = torch.device(f'cuda:{gpu_index}')
else:
	device = torch.device('cpu')

# Input feature widths are read off the sample batch.
ndim = data_sample['res'].x.shape[1]
ndim_godnode = data_sample['godnode'].x.shape[1]
ndim_fft2i = data_sample['fourier2di'].x.shape[1]
ndim_fft2r = data_sample['fourier2dr'].x.shape[1]

num_embeddings = 40
embedding_dim = 20
hidden_size = 100
se3transfomer = False  # Set to True for SE3Transformer, False for GNN

# Keyword arguments shared by both encoder variants; only the class and
# `hidden_channels` differ, so they are passed separately below.
encoder_kwargs = dict(
	in_channels=ndim,
	out_channels=embedding_dim,
	metadata={'edge_types': [('res','contactPoints','res'), ('res','hbond','res')]},
	num_embeddings=num_embeddings,
	commitment_cost=0.9,
	edge_dim=1,
	encoder_hidden=hidden_size,
	EMA=True,
	nheads=5,
	dropout_p=0.005,
	reset_codes=False,
	flavor='transformer',
	fftin=True,
)

if se3transfomer:
	encoder = se3e.se3_Encoder(
		hidden_channels=[hidden_size//2, hidden_size//2],
		**encoder_kwargs
	)
else:
	encoder = ft2.mk1_Encoder(
		hidden_channels=[hidden_size, hidden_size],
		**encoder_kwargs
	)

print(encoder)
encoder = encoder.to(device)

# MultiMonoDecoder for sequence and geometry: one config dict per task.
mono_configs = {
	'sequence_transformer': {
		'in_channels': {'res': embedding_dim},
		'xdim': 20,
		'concat_positions': True,
		'hidden_channels': {('res','backbone','res'): [hidden_size]*3 , ('res','backbonerev','res'): [hidden_size]*3},
		'layers': 3,
		'AAdecoder_hidden': [hidden_size, hidden_size, hidden_size//2],
		'amino_mapper': converter.aaindex,
		'flavor': 'sage',
		'dropout': 0.005,
		'normalize': True,
		'residual': False
	},

	'contacts': {
		# NOTE(review): 23 is a hard-coded width for the 'foldx' input - confirm against the dataset.
		'in_channels': {'res': embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23 ,  'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i},
		'concat_positions': True,
		'hidden_channels': {('res','backbone','res'): [hidden_size]*3, ('res','backbonerev','res'): [hidden_size]*3, ('res','informs','godnode4decoder'): [hidden_size]*3 , ('godnode4decoder','informs','res'): [hidden_size]*3},
		'layers': 3,
		'FFT2decoder_hidden': [hidden_size, hidden_size, hidden_size],
		'contactdecoder_hidden': [hidden_size//2, hidden_size//2],
		'nheads': 2,
		'Xdecoder_hidden': [hidden_size, hidden_size,  hidden_size ],
		'metadata': converter.metadata,
		'flavor': 'sage',
		'dropout': 0.005,
		'output_fft': True,
		'output_rt': False,
		'normalize': True,
		'residual': False,
		'contact_mlp': True
	}
}
decoder = MultiMonoDecoder(tasks=['sequence_transformer', 'contacts'], configs=mono_configs)
decoder = decoder.to(device)

print(decoder)

Seed set to 42
Seed set to 42


mk1_Encoder(
  (convs): ModuleList(
    (0): ModuleDict(
      (res_contactPoints_res): TransformerConv(100, 100, heads=5)
      (res_hbond_res): TransformerConv(100, 100, heads=5)
    )
  )
  (norms): ModuleList(
    (0): GraphNorm(100)
  )
  (bn): BatchNorm1d(857, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.005, inplace=False)
  (jk): JumpingKnowledge(cat)
  (ffin): Sequential(
    (0): Linear(in_features=1017, out_features=200, bias=True)
    (1): GELU(approximate='none')
    (2): Linear(in_features=200, out_features=100, bias=True)
    (3): GELU(approximate='none')
    (4): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
  )
  (lin): Sequential(
    (0): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
    (1): Linear(in_features=100, out_features=100, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=100, out_features=100, bias=True)
    (4): GELU(approximate=

Seed set to 42


MultiMonoDecoder(
  (decoders): ModuleDict(
    (sequence_transformer): Transformer_AA_Decoder(
      (input_proj): Sequential(
        (0): Linear(in_features=276, out_features=100, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.005, inplace=False)
        (3): Linear(in_features=100, out_features=100, bias=True)
        (4): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
      )
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-2): 3 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
            )
            (linear1): Linear(in_features=100, out_features=2048, bias=True)
            (dropout): Dropout(p=0.005, inplace=False)
            (linear2): Linear(in_features=2048, out_features=100, bias=True)
            (norm1): LayerNorm((100,), eps=1e-05, elementwise_a

In [None]:
# Training setup (demo, similar to learn.py): hyperparameters, optimizer,
# learning-rate scheduler, and switching both models to training mode.
import time
from collections import defaultdict

# Weights applied to the individual loss terms when they are summed in the
# training loop, plus the gradient-clipping switch.
xweight = 1
edgeweight = 0.01
fft2weight = 0.01
vqweight = 0.0001
clip_grad = True

num_epochs = 20  # For demonstration, keep small

# A single AdamW optimizer covers the encoder and decoder parameters; the
# scheduler halves the learning rate after 2 epochs without improvement in
# the metric passed to scheduler.step().
optimizer = torch.optim.AdamW(
	[*encoder.parameters(), *decoder.parameters()],
	lr=1e-4,
)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
	optimizer,
	mode='min',
	factor=0.5,
	patience=2,
)

decoder.train()
encoder.train()


Epoch 0: 100%|████████████████████████████████| 1000/1000 [06:15<00:00,  2.66it/s]


Epoch 0: AA Loss: 2.4984, Edge Loss: 1.1541, VQ Loss: -0.6879 , FFT2 Loss: 101.5385


Epoch 1: 100%|████████████████████████████████| 1000/1000 [06:29<00:00,  2.57it/s]


Epoch 1: AA Loss: 1.5485, Edge Loss: 0.8114, VQ Loss: -1.6587 , FFT2 Loss: 37.0629


Epoch 2: 100%|████████████████████████████████| 1000/1000 [06:47<00:00,  2.45it/s]


Epoch 2: AA Loss: 1.1679, Edge Loss: 0.6661, VQ Loss: -2.0465 , FFT2 Loss: 5.5672


Epoch 3: 100%|████████████████████████████████| 1000/1000 [06:42<00:00,  2.48it/s]


Epoch 3: AA Loss: 0.7867, Edge Loss: 0.5999, VQ Loss: -2.4561 , FFT2 Loss: 1.7800


Epoch 4: 100%|████████████████████████████████| 1000/1000 [06:46<00:00,  2.46it/s]


Epoch 4: AA Loss: 0.5590, Edge Loss: 0.5556, VQ Loss: -2.5729 , FFT2 Loss: 51.1382


Epoch 5: 100%|████████████████████████████████| 1000/1000 [06:15<00:00,  2.66it/s]


Epoch 5: AA Loss: 0.3943, Edge Loss: 0.5271, VQ Loss: -2.6148 , FFT2 Loss: 16.6045


Epoch 6: 100%|████████████████████████████████| 1000/1000 [05:45<00:00,  2.89it/s]


Epoch 6: AA Loss: 0.4700, Edge Loss: 0.5066, VQ Loss: -2.5625 , FFT2 Loss: 8.4572


Epoch 7: 100%|████████████████████████████████| 1000/1000 [05:48<00:00,  2.87it/s]


Epoch 7: AA Loss: 0.3841, Edge Loss: 0.5044, VQ Loss: -2.5833 , FFT2 Loss: 2.0891


Epoch 8: 100%|████████████████████████████████| 1000/1000 [03:16<00:00,  5.08it/s]


Epoch 8: AA Loss: 0.4011, Edge Loss: 0.4917, VQ Loss: -2.5842 , FFT2 Loss: 7.4210


Epoch 9: 100%|████████████████████████████████| 1000/1000 [03:28<00:00,  4.79it/s]


Epoch 9: AA Loss: 0.4184, Edge Loss: 0.4864, VQ Loss: -2.5857 , FFT2 Loss: 3.8743


Epoch 10: 100%|███████████████████████████████| 1000/1000 [03:16<00:00,  5.08it/s]


Epoch 10: AA Loss: 0.2971, Edge Loss: 0.4804, VQ Loss: -2.6449 , FFT2 Loss: 2.6341


Epoch 11: 100%|███████████████████████████████| 1000/1000 [03:20<00:00,  4.99it/s]


Epoch 11: AA Loss: 0.3500, Edge Loss: 0.4755, VQ Loss: -2.7058 , FFT2 Loss: 6.7716


Epoch 12: 100%|███████████████████████████████| 1000/1000 [03:18<00:00,  5.04it/s]


Epoch 12: AA Loss: 0.2674, Edge Loss: 0.4691, VQ Loss: -2.7613 , FFT2 Loss: 115.1710


Epoch 13: 100%|███████████████████████████████| 1000/1000 [03:20<00:00,  4.99it/s]


Epoch 13: AA Loss: 0.3381, Edge Loss: 0.4594, VQ Loss: -2.7169 , FFT2 Loss: 33.2982


Epoch 14: 100%|███████████████████████████████| 1000/1000 [03:19<00:00,  5.00it/s]


Epoch 14: AA Loss: 0.2588, Edge Loss: 0.4633, VQ Loss: -2.7325 , FFT2 Loss: 6.7632


Epoch 15: 100%|███████████████████████████████| 1000/1000 [03:23<00:00,  4.92it/s]


Epoch 15: AA Loss: 0.2028, Edge Loss: 0.4541, VQ Loss: -2.7619 , FFT2 Loss: 27.1181


Epoch 16: 100%|███████████████████████████████| 1000/1000 [03:19<00:00,  5.01it/s]


Epoch 16: AA Loss: 0.3053, Edge Loss: 0.4369, VQ Loss: -2.7444 , FFT2 Loss: 22.4614


Epoch 17: 100%|███████████████████████████████| 1000/1000 [03:25<00:00,  4.88it/s]


Epoch 17: AA Loss: 0.2896, Edge Loss: 0.4330, VQ Loss: -2.7370 , FFT2 Loss: 8.0427


Epoch 18: 100%|███████████████████████████████| 1000/1000 [03:23<00:00,  4.92it/s]


Epoch 18: AA Loss: 0.2819, Edge Loss: 0.4256, VQ Loss: -2.7175 , FFT2 Loss: 127.2558


Epoch 19: 100%|███████████████████████████████| 1000/1000 [03:26<00:00,  4.84it/s]

Epoch 19: AA Loss: 0.1433, Edge Loss: 0.4171, VQ Loss: -2.7424 , FFT2 Loss: 20.7313





In [None]:

def _pick_output(out, key, position):
	"""Select one prediction from the decoder output.

	Args:
		out: decoder output; either a dict keyed by prediction name or a
			list/tuple of predictions.
		key: name looked up when `out` is a dict.
		position: index used when `out` is a list or tuple.

	Returns:
		The selected prediction, or None when `out` does not provide it.
	"""
	if isinstance(out, dict) and key in out:
		return out[key]
	if isinstance(out, (list, tuple)):
		return out[position]
	return None


contact_edge_type = ('res', 'contactPoints', 'res')
# Guard the per-epoch averages against an empty loader.
n_batches = max(len(train_loader), 1)

for epoch in range(num_epochs):
	total_loss_x = 0
	total_loss_edge = 0
	total_vq = 0
	total_loss_fft2 = 0
	for data in tqdm.tqdm(train_loader, desc=f"Epoch {epoch}"):
		data = data.to(device)
		optimizer.zero_grad()
		z, vqloss = encoder(data)
		# Replace the residue features with the encoder output before decoding.
		data['res'].x = z
		# For demonstration, only sequence and contacts tasks
		out = decoder(data, None)
		recon_x = _pick_output(out, 'aa', 0)
		fft2_x = _pick_output(out, 'fft2pred', 1)
		if recon_x is None:
			# Fail with a clear message instead of passing None into the loss.
			raise RuntimeError("decoder output has no amino-acid prediction ('aa'); cannot compute the sequence loss")

		# Edge loss: use contactPoints if available
		edge_index_dict = getattr(data, 'edge_index_dict', {})
		if contact_edge_type in edge_index_dict:
			edge_index = edge_index_dict[contact_edge_type]
			edgeloss, _ = ft2.recon_loss(data, edge_index, decoder, plddt=False, offdiag=False , key = 'edge_probs')
		else:
			edgeloss = torch.tensor(0.0, device=device)

		xloss = ft2.aa_reconstruction_loss(data['AA'].x, recon_x)

		# FFT2 loss: real and imaginary parts concatenated along the feature
		# axis; skipped (zero) when the decoder gives no FFT2 prediction,
		# mirroring the edge-loss guard above.
		if fft2_x is not None:
			fft2_target = torch.cat([data['fourier2dr'].x, data['fourier2di'].x], dim=1)
			fft2loss = F.smooth_l1_loss(fft2_x, fft2_target)
		else:
			fft2loss = torch.tensor(0.0, device=device)

		loss = xweight * xloss + edgeweight * edgeloss + vqweight * vqloss + fft2loss * fft2weight

		loss.backward()
		if clip_grad:
			# Encoder and decoder gradients are clipped independently.
			torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
			torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
		optimizer.step()
		total_loss_x += xloss.item()
		total_loss_edge += edgeloss.item()
		total_loss_fft2 += fft2loss.item()
		total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
	# The plateau scheduler tracks the summed amino-acid loss of the epoch.
	scheduler.step(total_loss_x)
	print(f"Epoch {epoch}: AA Loss: {total_loss_x/n_batches:.4f}, Edge Loss: {total_loss_edge/n_batches:.4f}, VQ Loss: {total_vq/n_batches:.4f} , FFT2 Loss: {total_loss_fft2/n_batches:.4f}")

Epoch 0: 100%|████████████████████████████████| 1000/1000 [03:19<00:00,  5.01it/s]


Epoch 0: AA Loss: 0.2636, Edge Loss: 0.4102, VQ Loss: -2.7085 , FFT2 Loss: 8.2469


Epoch 1: 100%|████████████████████████████████| 1000/1000 [03:26<00:00,  4.85it/s]


Epoch 1: AA Loss: 0.2948, Edge Loss: 0.4082, VQ Loss: -2.6526 , FFT2 Loss: 13.7172


Epoch 2: 100%|████████████████████████████████| 1000/1000 [03:19<00:00,  5.01it/s]


Epoch 2: AA Loss: 0.2488, Edge Loss: 0.4113, VQ Loss: -2.6605 , FFT2 Loss: 15.4840


Epoch 3: 100%|████████████████████████████████| 1000/1000 [03:15<00:00,  5.10it/s]


Epoch 3: AA Loss: 0.2524, Edge Loss: 0.4061, VQ Loss: -2.6550 , FFT2 Loss: 5.5670


Epoch 4: 100%|████████████████████████████████| 1000/1000 [03:17<00:00,  5.07it/s]


Epoch 4: AA Loss: 0.2171, Edge Loss: 0.4042, VQ Loss: -2.6700 , FFT2 Loss: 19.8529


Epoch 5: 100%|████████████████████████████████| 1000/1000 [03:22<00:00,  4.94it/s]


Epoch 5: AA Loss: 0.2403, Edge Loss: 0.4021, VQ Loss: -2.6580 , FFT2 Loss: 129.2435


Epoch 6: 100%|████████████████████████████████| 1000/1000 [03:19<00:00,  5.01it/s]


Epoch 6: AA Loss: 0.2684, Edge Loss: 0.4057, VQ Loss: -2.6343 , FFT2 Loss: 49.8008


Epoch 7: 100%|████████████████████████████████| 1000/1000 [03:24<00:00,  4.88it/s]


Epoch 7: AA Loss: 0.1928, Edge Loss: 0.4034, VQ Loss: -2.6721 , FFT2 Loss: 17.4037


Epoch 8: 100%|████████████████████████████████| 1000/1000 [03:25<00:00,  4.88it/s]


Epoch 8: AA Loss: 0.2023, Edge Loss: 0.4028, VQ Loss: -2.6709 , FFT2 Loss: 12.7929


Epoch 9: 100%|████████████████████████████████| 1000/1000 [03:23<00:00,  4.91it/s]


Epoch 9: AA Loss: 0.1792, Edge Loss: 0.4003, VQ Loss: -2.6781 , FFT2 Loss: 1.3679


Epoch 10: 100%|███████████████████████████████| 1000/1000 [03:20<00:00,  4.99it/s]


Epoch 10: AA Loss: 0.1892, Edge Loss: 0.4017, VQ Loss: -2.6693 , FFT2 Loss: 5.1052


Epoch 11: 100%|███████████████████████████████| 1000/1000 [03:24<00:00,  4.88it/s]


Epoch 11: AA Loss: 0.2264, Edge Loss: 0.3995, VQ Loss: -2.6547 , FFT2 Loss: 69.2878


Epoch 12: 100%|███████████████████████████████| 1000/1000 [03:23<00:00,  4.92it/s]


Epoch 12: AA Loss: 0.2670, Edge Loss: 0.4003, VQ Loss: -2.6204 , FFT2 Loss: 106.0918


Epoch 13: 100%|███████████████████████████████| 1000/1000 [03:15<00:00,  5.11it/s]


Epoch 13: AA Loss: 0.2108, Edge Loss: 0.4033, VQ Loss: -2.6581 , FFT2 Loss: 14.0107


Epoch 14:  98%|███████████████████████████████▏| 975/1000 [03:11<00:05,  4.84it/s]