# Test MonoDecoders: Sequence and Geometry
This notebook replicates the training logic from `learn.py` using the decoder in `mono_decoders.py` for amino acid and geometry prediction.

In [3]:
#use autoreload
%load_ext autoreload
%autoreload 2

In [4]:
cd /home/dmoi/projects/foldtree2/

/home/dmoi/projects/foldtree2


In [8]:
# Imports
import torch
from torch_geometric.data import DataLoader
import numpy as np
from src import pdbgraph
from src import foldtree2_ecddcd as ft2
from src.mono_decoders import MultiMonoDecoder
import os
import tqdm
import random
import torch.nn.functional as F

In [9]:
# Set seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)
random.seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Data setup
datadir = '../../datasets/foldtree2/'
dataset_path = 'structs_traininffttest.h5'
converter = pdbgraph.PDB2PyG(aapropcsv='config/aaindex1.csv')
struct_dat = pdbgraph.StructureDataset(dataset_path)
train_loader = DataLoader(struct_dat, batch_size=4, shuffle=True, num_workers=0)
data_sample = next(iter(train_loader))



In [10]:
print('Data sample:', data_sample)

Data sample: HeteroDataBatch(
  identifier=[4],
  AA={
    x=[1185, 20],
    batch=[1185],
    ptr=[5],
  },
  R_true={
    x=[1185, 3, 3],
    batch=[1185],
    ptr=[5],
  },
  bondangles={
    x=[1185, 3],
    batch=[1185],
    ptr=[5],
  },
  coords={
    x=[1185, 3],
    batch=[1185],
    ptr=[5],
  },
  fourier1di={
    x=[1185, 80],
    batch=[1185],
    ptr=[5],
  },
  fourier1dr={
    x=[1185, 80],
    batch=[1185],
    ptr=[5],
  },
  fourier2di={
    x=[4, 1300],
    batch=[4],
    ptr=[5],
  },
  fourier2dr={
    x=[4, 1300],
    batch=[4],
    ptr=[5],
  },
  godnode={
    x=[4, 5],
    batch=[4],
    ptr=[5],
  },
  godnode4decoder={
    x=[4, 5],
    batch=[4],
    ptr=[5],
  },
  plddt={
    x=[1185, 1],
    batch=[1185],
    ptr=[5],
  },
  positions={
    x=[1185, 256],
    batch=[1185],
    ptr=[5],
  },
  res={
    x=[1185, 857],
    batch=[1185],
    ptr=[5],
  },
  t_true={
    x=[1185, 3],
    batch=[1185],
    ptr=[5],
  },
  (godnode4decoder, informs, res)={ edg

In [11]:
# Model setup
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
ndim = data_sample['res'].x.shape[1]
ndim_godnode = data_sample['godnode'].x.shape[1]
ndim_fft2i = data_sample['fourier2di'].x.shape[1]
ndim_fft2r = data_sample['fourier2dr'].x.shape[1]

num_embeddings = 40
embedding_dim = 20
hidden_size = 100

encoder = ft2.mk1_Encoder(
	in_channels=ndim,
	hidden_channels=[hidden_size, hidden_size],
	out_channels=embedding_dim,
	metadata={'edge_types': [('res','contactPoints','res'), ('res','hbond','res')]},
	num_embeddings=num_embeddings,
	commitment_cost=0.9,
	edge_dim=1,
	encoder_hidden=hidden_size,
	EMA=False,
	nheads=5,
	dropout_p=0.005,
	reset_codes=False,
	flavor='transformer',
	fftin=False
)

# MultiMonoDecoder for sequence and geometry
mono_configs = {
	'sequence_transformer': {
		'in_channels': {'res': embedding_dim},
		'xdim': 20,
		'concat_positions': True,
		'hidden_channels': {('res','backbone','res'): [hidden_size]*3 , ('res','backbonerev','res'): [hidden_size]*3},
		'layers': 3,
		'AAdecoder_hidden': [hidden_size, hidden_size//2, hidden_size//2],
		'amino_mapper': converter.aaindex,
		'flavor': 'sage',
		'dropout': 0.005,
		'normalize': True,
		'residual': True
	},
	
	'contacts': {
		'in_channels': {'res': embedding_dim, 'godnode4decoder': ndim_godnode, 'foldx': 23 ,  'fft2r': ndim_fft2r, 'fft2i': ndim_fft2i},
		'concat_positions': True,
		'hidden_channels': {('res','backbone','res'): [hidden_size]*3, ('res','backbonerev','res'): [hidden_size]*3, ('res','informs','godnode4decoder'): [hidden_size]*3},
		'layers': 2,
		'FFT2decoder_hidden': [hidden_size//2, hidden_size//2, hidden_size//2],
		'contactdecoder_hidden': [hidden_size//2, hidden_size//2],
		'nheads': 2,
		'Xdecoder_hidden': [hidden_size, hidden_size//2,  hidden_size//2 ],
		'metadata': converter.metadata,
		'flavor': 'sage',
		'dropout': 0.005,
		'output_fft': True,
        'output_rt':False,
		'normalize': True,
		'residual': False,
		'contact_mlp': True
	}
}
decoder = MultiMonoDecoder(tasks=['sequence_transformer', 'contacts'], configs=mono_configs)
encoder = encoder.to(device)
decoder = decoder.to(device)

print(decoder)

Seed set to 42
Seed set to 42
Seed set to 42


100 4 3 0.005
MultiMonoDecoder(
  (decoders): ModuleDict(
    (sequence_transformer): Transformer_AA_Decoder(
      (input_proj): Sequential(
        (0): Linear(in_features=276, out_features=100, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.005, inplace=False)
        (3): Linear(in_features=100, out_features=100, bias=True)
        (4): DynamicTanh(normalized_shape=100, alpha_init_value=0.5, channels_last=True)
      )
      (transformer_encoder): TransformerEncoder(
        (layers): ModuleList(
          (0-2): 3 x TransformerEncoderLayer(
            (self_attn): MultiheadAttention(
              (out_proj): NonDynamicallyQuantizableLinear(in_features=100, out_features=100, bias=True)
            )
            (linear1): Linear(in_features=100, out_features=2048, bias=True)
            (dropout): Dropout(p=0.005, inplace=False)
            (linear2): Linear(in_features=2048, out_features=100, bias=True)
            (norm1): LayerNorm((100,), eps=1e-05,

In [None]:
# Training loop (demo, similar to learn.py)
import time
from collections import defaultdict

num_epochs = 5  # For demonstration, keep small
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

edgeweight = 0.01
xweight = 1
fft2weight = 0.01
vqweight = 0.0001
clip_grad = True

encoder.train()
decoder.train()

for epoch in range(num_epochs):
	total_loss_x = 0
	total_loss_edge = 0
	total_vq = 0
	for data in tqdm.tqdm(train_loader, desc=f"Epoch {epoch}"):
		data = data.to(device)
		optimizer.zero_grad()
		z, vqloss = encoder(data)
		data['res'].x = z
		# For demonstration, only sequence and contacts tasks
		out = decoder(data, None)
		recon_x = out['aa'] if isinstance(out, dict) and 'aa' in out else out[0] if isinstance(out, (list, tuple)) else None
		fft2_x = out['fft2pred'] if isinstance(out, dict) and 'fft2pred' in out else out[1] if isinstance(out, (list, tuple)) else None
		# Edge loss: use contactPoints if available
		edge_index = data.edge_index_dict['res', 'contactPoints', 'res'] if hasattr(data, 'edge_index_dict') and ('res', 'contactPoints', 'res') in data.edge_index_dict else None
		if edge_index is not None:
			edgeloss, _ = ft2.recon_loss(data, edge_index, decoder, plddt=False, offdiag=False , key = 'edge_probs')
		else:
			edgeloss = torch.tensor(0.0, device=device)
		xloss = ft2.aa_reconstruction_loss(data['AA'].x, recon_x)
		fft2loss = F.smooth_l1_loss(torch.cat( [ data['fourier2dr'].x ,data['fourier2di'].x ] ,axis = 1 ) , fft2_x )
		loss = xweight * xloss + edgeweight * edgeloss + vqweight * vqloss + fft2loss* fft2weight

		loss.backward()
		if clip_grad:
			torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
			torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
		optimizer.step()
		total_loss_x += xloss.item()
		total_loss_edge += edgeloss.item()
		total_vq += vqloss.item() if isinstance(vqloss, torch.Tensor) else float(vqloss)
	scheduler.step(total_loss_x)
	
	print(f"Epoch {epoch}: AA Loss: {total_loss_x/len(train_loader):.4f}, Edge Loss: {total_loss_edge/len(train_loader):.4f}, VQ Loss: {total_vq/len(train_loader):.4f} , FFT2 Loss: {fft2loss.item()/len(train_loader):.4f}")

Epoch 0: 100%|████████████████████████████████| 1250/1250 [04:01<00:00,  5.18it/s]


Epoch 0: AA Loss: 2.3597, Edge Loss: 1.1617, VQ Loss: 0.2810 , FFT2 Loss: 30.1284


Epoch 1:  32%|██████████▋                      | 405/1250 [01:20<03:16,  4.30it/s]