In [1]:
import pandas as pd

import torch
from torch import optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, global_add_pool, global_max_pool
from torch_geometric.data import Data, Batch

from rdkit.Chem.rdchem import ChiralType
from rdkit.Chem.rdchem import BondType
from rdkit.Chem.rdchem import BondStereo

import numpy as np
from scipy.stats import pearsonr
import rdkit.Chem as Chem
from rdkit.Chem import rdFingerprintGenerator
from scipy.spatial.distance import pdist, squareform
from Bio.SeqUtils.ProtParam import ProteinAnalysis
from sklearn.cluster import KMeans

from typing import Tuple, List, Dict, Optional, Union

from lightning.pytorch.utilities.combined_loader import CombinedLoader

from tqdm import tqdm

conda install lightning -c conda-forge

## Our Dataset train-val set split

In [2]:
random_seed = 42
torch.manual_seed(random_seed)

device = torch.device("cpu") # "cuda")

In [3]:

def cluster_glycans(glycans, radius, fp_size, n_clusters):

    def get_morgan_count_fingerprint(smiles, radius, fp_size):
        
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return {f"mf_{i}": 0 for i in range(fp_size)} 


        #The useChirality parameter in Morgan fingerprints determines whether chirality is considered when encoding a molecule.
        #includeChirality=True = Differentiates between enantiomers (model will treat mirror-image molecules as different)
        #includeChirality=False = Ignores chirality (model will treat mirror-image molecules as the same)
        kid_named_morgan_finger = rdFingerprintGenerator.GetMorganGenerator(radius=radius,fpSize=fp_size, includeChirality=True)

        cfp = kid_named_morgan_finger.GetCountFingerprint(mol)  
        bit_counts = cfp.GetNonzeroElements()  

        # Convert to a full fp_size-length feature vector
        fingerprint_vector = {f"mf_{i}": bit_counts.get(i, 0) for i in range(fp_size)}
        return fingerprint_vector

    fingerprint_df = glycans['SMILES'].apply(lambda x: get_morgan_count_fingerprint(x, radius, fp_size)).apply(pd.Series)
    
    glycans = pd.concat([glycans, fingerprint_df], axis=1)
    
    # matrix version of fingerprint features. Each row is a glycan, each column is a fingerprint component shape: (611, 2048)
    finger_counts_matrix = fingerprint_df.values
    # pdist calculates the euclidean distance between the combination of each glycan with every other glycan. Then squareform() turns this into a matrix representation where each row is a glycan and each column is the same list of glycans so we can have a comparison matrix. Shape: (611, 611)
    dist_matrix = squareform(pdist(finger_counts_matrix, metric="euclidean"))
    

    kmeans = KMeans(n_clusters=n_clusters, random_state=0)
    labels = kmeans.fit_predict(dist_matrix)
    
    glycans['cluster_label'] = labels
    
    return glycans

def cluster_proteins(proteins, n_clusters):
    
    
    def compute_protein_features(seq):

        # Add reasoning for feature vectors
        
        # Protein Analysis is a Tool from Biopython
        analysis = ProteinAnalysis(seq)
        features = {}
        
        # The following are Basic Features
        features['length'] = len(seq)
        features['mw'] = analysis.molecular_weight()
        features['instability_index'] = analysis.instability_index()

        features['net_charge_pH7'] = analysis.charge_at_pH(7.0)

        aa_percent = analysis.get_amino_acids_percent()

        # Prompted ChatGPT to ask how to parse a
        # N, Q, S, T: Polar Amino Acids, often involved in hydrogen bonding with glycans
        # K, R: Basic Amino Acids, can form hydrogen bonds and electrostatic bonds
        # D, E: Acidic Amino Acids, can interact with positively charged groups of glycans
        for aa in ['N', 'Q', 'S', 'T', 'K', 'R', 'D', 'E']:
            features[f'frac_{aa}'] = aa_percent.get(aa, 0.0)

    
    # F, Y, W are aromatic amino acids which bind with glycans
        for aa in ['F', 'Y', 'W']:
            features[f'frac_{aa}'] = aa_percent.get(aa, 0.0)
            features['aromatic_binding_score'] = (
            aa_percent.get('F', 0.0) +
            aa_percent.get('Y', 0.0) +
            aa_percent.get('W', 0.0)
        )

        features['aromaticity'] = analysis.aromaticity()

        features['hydrophobicity'] = analysis.gravy()

        return features

    feature_dicts = proteins['Amino Acid Sequence'].apply(compute_protein_features)
    features_df = pd.DataFrame(list(feature_dicts))

    proteins = pd.concat([proteins, features_df], axis=1)
    
    # Select the feature columns (all columns from the feature extraction)
    feature_columns = features_df.columns.tolist()
    feature_data = proteins[feature_columns].values

    # apply k means clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    proteins['cluster_label'] = kmeans.fit_predict(feature_data)
    
    return proteins

def stratified_train_test_split(fractions_df, glycans_df, proteins_df, test_size, random_state, mode='AND'):
    """
    Create a stratified train-test split where:
    1. Test set has unique GlycanIDs and ProteinGroups not seen in training
    2. Distribution of cluster_labels for both glycans and proteins is maintained
    
    Parameters:
    -----------
    fractions_df : pandas.DataFrame
        DataFrame containing ['ObjId', 'ProteinGroup', 'Concentration', 'GlycanID', 'f']
    glycans_df : pandas.DataFrame
        DataFrame containing ['Name', 'cluster_label'] where Name maps to GlycanID
    proteins_df : pandas.DataFrame
        DataFrame containing ['ProteinGroup', 'cluster_label']
    test_size : float, default=0.1
        Proportion of data to include in the test set
    random_state : int, default=42
        Random seed for reproducibility
    
    Returns:
    --------
    train_indices : numpy.ndarray
        Indices of fractions_df that belong to the training set
    test_indices : numpy.ndarray
        Indices of fractions_df that belong to the test set
    """
    # Set random seed
    np.random.seed(random_state)
    
    # Merge cluster labels from glycans and proteins into fractions
    fractions_with_clusters = fractions_df.copy()
    
    # Map glycan cluster labels
    glycan_cluster_map = dict(zip(glycans_df['Name'], glycans_df['cluster_label']))
    fractions_with_clusters['glycan_cluster'] = fractions_with_clusters['GlycanID'].map(glycan_cluster_map)
    
    # Map protein cluster labels
    protein_cluster_map = dict(zip(proteins_df['ProteinGroup'], proteins_df['cluster_label']))
    fractions_with_clusters['protein_cluster'] = fractions_with_clusters['ProteinGroup'].map(protein_cluster_map)
    
    # Get unique glycans and proteins with their cluster labels
    unique_glycans = glycans_df[['Name', 'cluster_label']].drop_duplicates()
    unique_proteins = proteins_df[['ProteinGroup', 'cluster_label']].drop_duplicates()
    
    # Calculate target counts for each cluster in test set
    glycan_cluster_counts = unique_glycans['cluster_label'].value_counts().to_dict()
    protein_cluster_counts = unique_proteins['cluster_label'].value_counts().to_dict()
    
    glycan_test_counts = {cluster: max(1, int(np.ceil(count * test_size))) 
                         for cluster, count in glycan_cluster_counts.items()}
    protein_test_counts = {cluster: max(1, int(np.ceil(count * test_size))) 
                          for cluster, count in protein_cluster_counts.items()}
    
    # Select glycans and proteins for test set while respecting cluster distributions
    test_glycans = []
    for cluster, target_count in glycan_test_counts.items():
        cluster_glycans = unique_glycans[unique_glycans['cluster_label'] == cluster]['Name'].tolist()
        selected = np.random.choice(cluster_glycans, size=min(target_count, len(cluster_glycans)), replace=False)
        test_glycans.extend(selected)
    
    test_proteins = []
    for cluster, target_count in protein_test_counts.items():
        cluster_proteins = unique_proteins[unique_proteins['cluster_label'] == cluster]['ProteinGroup'].tolist()
        selected = np.random.choice(cluster_proteins, size=min(target_count, len(cluster_proteins)), replace=False)
        test_proteins.extend(selected)
        
        
    if mode == 'AND':
        
        is_test = ((fractions_df['GlycanID'].isin(test_glycans)) & 
                (fractions_df['ProteinGroup'].isin(test_proteins)))

        is_train = ((~fractions_df['GlycanID'].isin(test_glycans)) & 
                        (~fractions_df['ProteinGroup'].isin(test_proteins)))
                
        test_indices = fractions_df[is_test].index

        train_indices = fractions_df[is_train].index
        
        print(f'-------------Test size (% of glycans and proteins as combinations in test set): {test_size*100}% -------------')

        print(f'train size: {len(train_indices)}, test size: {len(test_indices)}, total: {len(fractions_df)}')
                
        print(f'train size: {round((len(train_indices)/len(fractions_df))*100, 2)}%, test size: {round((len(test_indices)/len(fractions_df))*100, 2)}%')
        
        print(f'test size % in terms of test/(training+test) size: {round((len(test_indices)/(len(train_indices)+len(test_indices)))*100, 2)}%')
        
        print(f'Total % of dataset used: {round(((len(train_indices)+len(test_indices))/len(fractions_df))*100, 2)}%\n')
    
    else:
    
        # Create train and test masks
        is_test = ((fractions_with_clusters['GlycanID'].isin(test_glycans)) | 
                (fractions_with_clusters['ProteinGroup'].isin(test_proteins)))
        
        test_indices = fractions_with_clusters[is_test].index
        train_indices = fractions_with_clusters[~is_test].index
    
    
    return train_indices, test_indices

def batch_encode(encoder, data_list, device, batch_size):
    """Process data in batches to avoid CUDA memory overflow"""
    all_encodings = []
    total_items = len(data_list)
    
    for i in range(0, total_items, batch_size):
        # Get current batch
        batch = data_list[i:min(i+batch_size, total_items)]
        
        # Encode batch
        batch_encodings = encoder.encode_batch(batch, device)
        all_encodings.append(batch_encodings)
        
        # Print progress
        print(f'Progress: {min(i+batch_size, total_items)}/{total_items}')
        
        # Optional: clear CUDA cache to prevent memory fragmentation
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    
    # Concatenate all batches
    return torch.cat(all_encodings, dim=0)

def prepare_train_val_datasets(
    fractions_df: pd.DataFrame,
    glycans_df: pd.DataFrame,
    proteins_df: pd.DataFrame,
    glycan_encoder,
    protein_encoder,
    glycan_type: str,
    random_state: int,
    split_mode: str,
    use_kfolds: bool,
    k_folds: float,
    val_split: float,
    device: torch.device
) -> Tuple[Dataset, Dataset]:
    """
    Prepare train and validation datasets
    
    Args:
        df: Full dataset DataFrame
        val_split: Fraction of data to use for validation
        glycan_encoder: Encoder for glycans
        protein_encoder: Encoder for proteins
    
    Returns:
        Tuple of train and validation datasets
    """
    
    # for each glycan create a glycan_encoding feature where we use glycan_encoder to encode the SMILES
    # for each protein create a protein_encoding feature where we use protein_encoder to encode the aminoacids
    #glycan_encodings = glycan_encoder.encode_batch(glycans_df[glycan_type].tolist(), device)
    #protein_encodings = protein_encoder.encode_batch(proteins_df['Amino Acid Sequence'].tolist(), device)

    # only do batch to not overload RAM of GPU
    if device.type == 'cuda':
        batch_size = 100  # Adjust based on your GPU memory

        # Encode glycans in batches
        glycan_encodings = batch_encode(
            glycan_encoder, 
            glycans_df[glycan_type].tolist(), 
            device, 
            batch_size=batch_size
        )

        # Encode proteins in batches
        protein_encodings = batch_encode(
            protein_encoder, 
            proteins_df['Amino Acid Sequence'].tolist(), 
            device, 
            batch_size=batch_size
        )
    else:
        glycan_encodings = glycan_encoder.encode_batch(glycans_df[glycan_type].tolist(), device)
        protein_encodings = protein_encoder.encode_batch(proteins_df['Amino Acid Sequence'].tolist(), device)
    
    
    # Might move to config but leave for now as our train and test are clusterd and stratified using these parameters
    radius = 3
    fp_size = 1024
    n_clusters = 3
    glycans_df = cluster_glycans(glycans_df, radius, fp_size, n_clusters)
    
    n_protein_clusters = 3
    proteins_df = cluster_proteins(proteins_df, n_protein_clusters)
    
    
    
    train_indicies, test_indicies = stratified_train_test_split(fractions_df, glycans_df, proteins_df, val_split, random_state, split_mode)
    # convert to kfold format so we can use the same code
    full_indicies = [(train_indicies, test_indicies)]
    
    return full_indicies, glycan_encodings, protein_encodings

## Glycan Encoder classes

In [4]:
class GNNGlycanEncoder(nn.Module):
    def __init__(self, embedding_dim: int = 256, hidden_channels: int = 128):
        super().__init__()
        
        # Node features (9-dimensional as shown in the figure)
        self.node_features = [
            'atomic_num',      # Atomic number
            'chirality',       # Chirality (important for glycans)
            'degree',          # Degree (number of bonds)
            'formal_charge',   # Formal charge
            'numH',            # Number of hydrogens
            'number_radical_e', # Number of radical electrons
            'hybridization',   # Hybridization type
            'is_aromatic',     # Is the atom aromatic (boolean)
            'is_in_ring'       # Is the atom in a ring (boolean)
        ]
        
        # Edge features (3-dimensional as shown in the figure)
        self.edge_features = [
            'bond_type',         # Type of bond (single, double, etc.)
            'stereo_configuration', # Stereo configuration
            'is_conjugated'      # Is the bond conjugated (boolean)
        ]
        
        # Define normalization parameters (to be populated during preprocessing)
        self.scalers = {}
        
        # Define GNN layers
        self.conv1 = GCNConv(9, hidden_channels//2)  # 9 is the expanded node features
        self.conv2 = GCNConv(hidden_channels//2, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels*2)
        
        # Output layer
        self.linear = torch.nn.Linear(hidden_channels*2, embedding_dim)
        
        self._embedding_dim = embedding_dim
    
    def _get_atom_features(self, atom) -> List[float]:
        """Extract atom features according to the predefined list"""
        features = []
        
        # atomic_num
        features.append(atom.GetAtomicNum())
        
        # chirality
        chirality_type = int(atom.GetChiralTag())
        features.append(chirality_type)
        
        # degree
        features.append(atom.GetDegree())
        
        # formal_charge
        features.append(atom.GetFormalCharge())
        
        # numH
        features.append(atom.GetTotalNumHs())
        
        # number_radical_e
        features.append(atom.GetNumRadicalElectrons())
        
        # hybridization
        hybridization_type = int(atom.GetHybridization())
        features.append(hybridization_type)
        
        # is_aromatic
        features.append(1 if atom.GetIsAromatic() else 0)
        
        # is_in_ring
        features.append(1 if atom.IsInRing() else 0)
        
        return features
    
    def _get_bond_features(self, bond) -> List[float]:
        """Extract bond features according to the predefined list"""
        features = []
        
        # bond_type
        bond_type = int(bond.GetBondType())
        features.append(bond_type)
        
        # stereo_configuration
        stereo = int(bond.GetStereo())
        features.append(stereo)
        
        # is_conjugated
        features.append(1 if bond.GetIsConjugated() else 0)
        
        return features
    
    def _mol_to_graph_data(self, mol) -> Data:
        """Convert an RDKit molecule to a PyTorch Geometric Data object"""
        # Get atom features
        node_features = []
        for atom in mol.GetAtoms():
            node_features.append(self._get_atom_features(atom))
        x = torch.tensor(node_features, dtype=torch.float)
        
        # Get edge indices and features
        edge_indices = []
        edge_features = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            
            edge_indices.append([i, j])
            edge_indices.append([j, i])  # Add reverse edge for undirected graph
            
            features = self._get_bond_features(bond)
            edge_features.append(features)
            edge_features.append(features)  # Duplicate for reverse edge
        
        if len(edge_indices) > 0:
            edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_features, dtype=torch.float)
        else:
            # Handle molecules with no bonds (rare case)
            edge_index = torch.zeros((2, 0), dtype=torch.long)
            edge_attr = torch.zeros((0, 3), dtype=torch.float)
        
        # Create and normalize features
        x_norm = self._normalize_node_features(x)
        edge_attr_norm = self._normalize_edge_features(edge_attr)
        
        # Create the PyTorch Geometric Data object
        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            x_norm=x_norm,
            edge_attr_norm=edge_attr_norm
        )
        
        return data
    
    def _normalize_node_features(self, x):
        """Apply normalization and one-hot encoding to node features"""
        # This is a placeholder - in production you'd use the scalers and encoding logic
        # from your preprocessing code
        # For simplicity, we're just returning the raw features
        return x
    
    def _normalize_edge_features(self, edge_attr):
        """Apply normalization and one-hot encoding to edge features"""
        # This is a placeholder - in production you'd use the scalers and encoding logic
        # from your preprocessing code
        # For simplicity, we're just returning the raw features
        return edge_attr
    
    def forward(self, data):
        """Process a batch of molecular graphs through the GNN"""
        x, edge_index, batch = data.x_norm, data.edge_index, data.batch
        
        # Apply GNN layers
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        
        # Global pooling to get a graph-level representation
        x = global_mean_pool(x, batch)
        
        x = F.dropout(x, p=0.5)
        
        # Final projection to embedding dimension
        x = self.linear(x)
        
        return x
    
    
    def encode_iupac(self, iupac_str: str, device: torch.device) -> torch.Tensor:
        """aaaaaa"""
        pass
    
    def encode_smiles(self, smiles: str, device: torch.device) -> torch.Tensor:
        """Convert a SMILES string to a graph embedding"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES: {smiles}")
        
        # Optionally add hydrogen atoms
        mol = Chem.AddHs(mol)
        
        # Convert to a graph data object
        data = self._mol_to_graph_data(mol)
        
        # Create a batch with just this single molecule
        data.batch = torch.zeros(data.x.size(0), dtype=torch.long)
        
        # Move to device
        data = data.to(device)
        
        # Get embedding
        with torch.no_grad():
            embedding = self.forward(data)
        
        return embedding
    
    def encode_batch(self, batch_data: List[str], device: torch.device) -> torch.Tensor:
        """Convert a batch of SMILES strings to graph embeddings"""
        # Process each molecule individually
        batch_embeddings = []
        count = 0
        for smiles in batch_data:
            count += 1
            print(f'glycan batch encoder progress: {count}/{len(batch_data)}')
            embedding = self.encode_smiles(smiles, device)
            batch_embeddings.append(embedding)
        #for iupac in batch_data:
            #embedding = self.encode_iupac(iupac, device)
            #batch_embeddings.append(embedding)
        
        # Stack all embeddings
        batch = torch.cat(batch_embeddings, dim=0)
        
        return batch
    
    def preprocess_dataset(self, smiles_list: List[str]):
        """Precompute normalization parameters for the dataset"""
        # Convert all molecules to graphs and collect statistics
        all_node_features = []
        all_edge_features = []
        
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            # Add hydrogens
            mol = Chem.AddHs(mol)
            
            # Collect node features
            for atom in mol.GetAtoms():
                features = self._get_atom_features(atom)
                all_node_features.append(features)
            
            # Collect edge features
            for bond in mol.GetBonds():
                features = self._get_bond_features(bond)
                all_edge_features.append(features)
        
        # Convert to tensors
        all_node_features = torch.tensor(all_node_features, dtype=torch.float)
        all_edge_features = torch.tensor(all_edge_features, dtype=torch.float)
        
        # Compute normalization parameters
        for i, feature_name in enumerate(self.node_features):
            feature_values = all_node_features[:, i]
            self.scalers[feature_name] = {
                'min': float(feature_values.min()),
                'max': float(feature_values.max()),
                'mean': float(feature_values.mean()),
                'std': float(feature_values.std())
            }
        
        for i, feature_name in enumerate(self.edge_features):
            feature_values = all_edge_features[:, i]
            self.scalers[feature_name] = {
                'min': float(feature_values.min()),
                'max': float(feature_values.max()),
                'mean': float(feature_values.mean()),
                'std': float(feature_values.std())
            }
        
        return self.scalers
    
    @property
    def embedding_dim(self) -> int:
        return self._embedding_dim

In [5]:
class MPNNGlycanEncoder(nn.Module):
    def __init__(self, embedding_dim: int = 128, pos_emb_dim: int = 2,
                 hidden_state_size: int = 128, n_layers: int = 3):
        super().__init__()

        # Node features (118 + 1 + 1 + 1 + 4 = 125 dimensions)
        self.node_features = [
            'atomic_num',
            'mass',
            'row',
            'column',
            'chirality',
        ]
        
        # Edge features (4 + 4 + 1 + 1 + 1 = 11 dimensions)
        self.edge_features = [
            'bond_type',
            'stero_configuration',
            'is_in_ring',
            'is_conjugated',
            'is_aromatic',
        ]
        
        self._embedding_dim = embedding_dim
        self.pos_emb_dim = pos_emb_dim
        self.hidden_state_size = hidden_state_size
        self.n_layers = n_layers

        # Assume base node features have 125 dimensions.
        # (For example: one-hot atomic number (118) + mass (1) + row (1) + column (1) + chirality one-hot (4))
        self.base_node_feature_dim = 125
        # After concatenating positional embeddings, node feature dim becomes:
        self.node_feature_dim = self.base_node_feature_dim + self.pos_emb_dim + 2

        # Edge features dimension (example): 11.
        # (For example: bond type one-hot (4) + stereo configuration one-hot (4) + is_in_ring (1) + is_conjugated (1) + is_aromatic (1))
        self.edge_feature_dim = 11

         # Initial projection to hidden state.
        self.initial_linear = nn.Linear(self.node_feature_dim, self.hidden_state_size)

        # Message passing function
        self.f_message = nn.Sequential(
            nn.Linear(self.hidden_state_size + self.edge_feature_dim, self.hidden_state_size),
            nn.ReLU(),
            nn.Linear(self.hidden_state_size, self.hidden_state_size)
        )
        
        # Update function
        self.f_update = nn.Sequential(
            nn.Linear(2 * self.hidden_state_size, self.hidden_state_size),
            nn.ReLU(),
            nn.Linear(self.hidden_state_size, self.hidden_state_size)
        )

        # Final readout projection.
        self.f_readout = nn.Linear(self.hidden_state_size, self._embedding_dim)
        
    def _get_random_walk_stats(self, adj: torch.Tensor, k_steps: int = 6) -> torch.Tensor:
        """
        Compute a k-step random walk bias matrix R = T^k (with T = D^-1 * A),
        and then derive per-node statistics (mean and std) for each node.
        Returns a tensor of shape (N, 2) where the two columns are mean and std.
        """
        deg = torch.sum(adj, dim=1, keepdim=True) + 1e-6
        T = adj / deg
        R = T.clone()
        for _ in range(k_steps - 1):
            R = R @ T
        # For each node, compute mean and standard deviation across the row.
        r_mean = R.mean(dim=1, keepdim=True)  # (N, 1)
        r_std = R.std(dim=1, keepdim=True)    # (N, 1)
        return torch.cat([r_mean, r_std], dim=1)  # (N, 2)


    def _get_positional_embeddings(self, adj: torch.Tensor, k: int) -> torch.Tensor:
        """
        Compute k-dimensional positional embeddings using the Laplacian eigenvectors.
        adj: (N, N) adjacency matrix.
        Returns: Tensor of shape (N, k)
        """
        # Compute degree vector and construct degree matrix D.
        deg = torch.sum(adj, dim=1)
        D = torch.diag(deg)
        # Compute Laplacian: L = D - A.
        L = D - adj
        # Compute eigen-decomposition (eigenvalues in ascending order).
        eigenvalues, eigenvectors = torch.linalg.eigh(L)
        # Skip the first eigenvector (trivial constant vector) if possible.
        if k < L.size(0):
            pos_emb = eigenvectors[:, 1:k+1]
        else:
            pos_emb = eigenvectors[:, :k]
        return pos_emb
    
    def _one_hot_atomic_number(self, atom):
        one_hot = [0] * 118
        atomic_num = atom.GetAtomicNum()
        one_hot[atomic_num - 1] = 1
        return one_hot
    
    def _one_hot_chirality(self, atom):
        chiral_tag = atom.GetChiralTag()
        possible_tags = [
            ChiralType.CHI_UNSPECIFIED, 
            ChiralType.CHI_TETRAHEDRAL_CW, 
            ChiralType.CHI_TETRAHEDRAL_CCW, 
            ChiralType.CHI_OTHER,
        ]
        one_hot = [1 if chiral_tag == tag else 0 for tag in possible_tags]
        return one_hot

    def _one_hot_bond_type(self, bond):
        bond_type = bond.GetBondType()
        possible_types = [
            BondType.SINGLE,
            BondType.DOUBLE,
            BondType.TRIPLE,
            BondType.AROMATIC,
        ]
        one_hot = [1 if bond_type == type else 0 for type in possible_types]
        return one_hot
    
    def _one_hot_stereo_configuration(self, bond):
        stereo = bond.GetStereo()
        possible_configurations = [
            BondStereo.STEREOANY,
            BondStereo.STEREOZ,
            BondStereo.STEREOE,
            BondStereo.STEREONONE,
        ]
        one_hot = [1 if stereo == config else 0 for config in possible_configurations]
        return one_hot
    
    def _get_atom_features(self, atom) -> List[float]:
        """Extract atom features according to the predefined list"""
        features = []

        # atomic number one hot encoding
        features += self._one_hot_atomic_number(atom)

        # atomic mass
        features.append(atom.GetMass())

        # row in periodic table / period
        features.append(element_row[atom.GetSymbol()])

        # column in periodic table / group
        features.append(element_col[atom.GetSymbol()])

        # chirality one hot encoding
        features += self._one_hot_chirality(atom)

        return features
    
    def _get_bond_features(self, bond) -> List[float]:
        """Extract bond features according to the predefined list"""
        features = []

        # bond type one hot encoding
        features += self._one_hot_bond_type(bond)

        # stereo configuration one hot encoding
        features += self._one_hot_stereo_configuration(bond)

        # is in ring
        features.append(bond.IsInRing())

        # is conjugated
        features.append(bond.GetIsConjugated())

        # is aromatic
        features.append(bond.GetIsAromatic())

        return features
    
    def _mol_to_graph_data(self, mol) -> dict:
        """
        Convert an RDKit molecule to a graph data dictionary containing:
          - x: node feature matrix (N x node_feature_dim)
          - adj: adjacency matrix (N x N)
          - edge_attr: edge feature tensor (N x N x edge_feature_dim)
          - batch: tensor indicating graph membership (for a single graph, all zeros)
        """
        # Build node features.
        node_features = [self._get_atom_features(atom) for atom in mol.GetAtoms()]
        x_raw = torch.tensor(node_features, dtype=torch.float)  # Shape: (N, node_feature_dim)
        N = x_raw.size(0)
        
        # Initialize dense adjacency and edge feature matrices.
        adj = torch.zeros((N, N), dtype=torch.float)
        edge_attr = torch.zeros((N, N, self.edge_feature_dim), dtype=torch.float)
        
        # Populate matrices for each bond.
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            adj[i, j] = 1.0
            adj[j, i] = 1.0
            bf = self._get_bond_features(bond)
            bf_tensor = torch.tensor(bf, dtype=torch.float)
            edge_attr[i, j] = bf_tensor
            edge_attr[j, i] = bf_tensor
        
        # Compute positional embeddings from the Laplacian.
        pos_emb = self._get_positional_embeddings(adj, self.pos_emb_dim)  # Shape: (N, pos_emb_dim)

        rw_stats = self._get_random_walk_stats(adj)  # Shape: (N, 2)
        # Concatenate positional embeddings to raw node features.
        x = torch.cat([x_raw, pos_emb, rw_stats], dim=1)  # Shape: (N, base_node_feature_dim + pos_emb_dim)
        
        # Optionally, apply normalization (here we pass features through).
        x_norm = self._normalize_node_features(x)
        edge_attr_norm = self._normalize_edge_features(edge_attr)
        
        # For a single graph, assign all nodes to batch 0.
        batch = torch.zeros(N, dtype=torch.long)
        
        data = {
            'x': x,
            'adj': adj,
            'edge_attr': edge_attr,
            'x_norm': x_norm,
            'edge_attr_norm': edge_attr_norm,
            'batch': batch
        }
        return data 
    
    def _normalize_node_features(self, x):
        """Placeholder for node feature normalization."""
        return x

    def _normalize_edge_features(self, edge_attr):
        """Placeholder for edge feature normalization."""
        return edge_attr
    
    def encode_smiles(self, smiles: str, device: torch.device) -> torch.Tensor:
        """Convert a SMILES string to a graph embedding"""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Could not parse SMILES: {smiles}")
        
        # Optionally add hydrogen atoms
        mol = Chem.AddHs(mol)
        
        # Convert to a graph data object
        data = self._mol_to_graph_data(mol)
        
        # Create a batch with just this single molecule
        data['batch'] = torch.zeros(data['x'].size(0), dtype=torch.long)

        # Move all tensor entries to device.
        for key in data:
            if isinstance(data[key], torch.Tensor):
                data[key] = data[key].to(device)
        with torch.no_grad():
            embedding = self.forward(data)
        return embedding
    
    def encode_batch(self, batch_data: List[str], device: torch.device) -> torch.Tensor:
        """Convert a batch of SMILES strings to graph embeddings"""
        # Process each molecule individually
        count = 0
        batch_embeddings = []
        for smiles in batch_data:
            count+=1
            print(f'glycan encoding progress: {count}/{len(batch_data)}')
            embedding = self.encode_smiles(smiles, device)
            batch_embeddings.append(embedding)
        
        # Stack all embeddings
        batch = torch.cat(batch_embeddings, dim=0)
        
        return batch

    def preprocess_dataset(self, smiles_list: List[str]):
        """Precompute normalization parameters for the dataset"""
        # Convert all molecules to graphs and collect statistics
        all_node_features = []
        all_edge_features = []
        
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                continue
                
            # Add hydrogens
            #mol = Chem.AddHs(mol)
            
            # Collect node features
            for atom in mol.GetAtoms():
                features = self._get_atom_features(atom)
                all_node_features.append(features)
            
            # Collect edge features
            for bond in mol.GetBonds():
                features = self._get_bond_features(bond)
                all_edge_features.append(features)
        
        # Convert to tensors
        all_node_features = torch.tensor(all_node_features, dtype=torch.float)
        all_edge_features = torch.tensor(all_edge_features, dtype=torch.float)
        
        # Compute normalization parameters
        for i, feature_name in enumerate(self.node_features):
            feature_values = all_node_features[:, i]
            self.scalers[feature_name] = {
                'min': float(feature_values.min()),
                'max': float(feature_values.max()),
                'mean': float(feature_values.mean()),
                'std': float(feature_values.std())
            }
        
        for i, feature_name in enumerate(self.edge_features):
            feature_values = all_edge_features[:, i]
            self.scalers[feature_name] = {
                'min': float(feature_values.min()),
                'max': float(feature_values.max()),
                'mean': float(feature_values.mean()),
                'std': float(feature_values.std())
            }
        
        return self.scalers
    
    @property
    def embedding_dim(self) -> int:
        return self._embedding_dim
    
    def forward(self, data: dict) -> torch.Tensor:
        """
        Perform message passing to produce a graph-level embedding.
        data: Dictionary containing keys 'x_norm', 'adj', 'edge_attr_norm', 'batch'.
        """
        # Extract inputs.
        x = data['x_norm']        # (N, node_feature_dim) where node_feature_dim = base (125) + pos_emb_dim
        adj = data['adj']         # (N, N)
        edge_attr = data['edge_attr_norm']  # (N, N, edge_feature_dim)
        N = x.size(0)

        # Initial projection to hidden state.
        h = F.relu(self.initial_linear(x))  # (N, hidden_state_size)

        # Message passing iterations.
        for t in range(self.n_layers):
            h_neighbors = h.unsqueeze(0).expand(N, N, self.hidden_state_size)  # (N, N, hidden_state_size)
            msg_input = torch.cat([h_neighbors, edge_attr], dim=2)  # (N, N, hidden_state_size + edge_feature_dim)
            msg_input_flat = msg_input.view(-1, self.hidden_state_size + self.edge_feature_dim)
            messages_flat = self.f_message(msg_input_flat)  # (N*N, hidden_state_size)
            messages = messages_flat.view(N, N, self.hidden_state_size)  # (N, N, hidden_state_size)
            messages = messages * adj.unsqueeze(2)  # mask non-existent edges
            m = messages.sum(dim=1)  # aggregate messages by mean: (N, hidden_state_size)
            h = F.relu(self.f_update(torch.cat([h, m], dim=1)))  # update node states: (N, hidden_state_size)

        # Global mean pooling.
        graph_repr = h.sum(dim=0, keepdim=True)  # (1, hidden_state_size)
        out = self.f_readout(graph_repr)  # (1, embedding_dim)
        return out
    
# Hashmap for periodic table row (period) (used RdKit pt.GetRow())
element_row = {
    'H': 1,
    'He': 1,
    'Li': 2,
    'Be': 2,
    'B': 2,
    'C': 2,
    'N': 2,
    'O': 2,
    'F': 2,
    'Ne': 2,
    'Na': 3,
    'Mg': 3,
    'Al': 3,
    'Si': 3,
    'P': 3,
    'S': 3,
    'Cl': 3,
    'Ar': 3,
    'K': 4,
    'Ca': 4,
    'Sc': 4,
    'Ti': 4,
    'V': 4,
    'Cr': 4,
    'Mn': 4,
    'Fe': 4,
    'Co': 4,
    'Ni': 4,
    'Cu': 4,
    'Zn': 4,
    'Ga': 4,
    'Ge': 4,
    'As': 4,
    'Se': 4,
    'Br': 4,
    'Kr': 4,
    'Rb': 5,
    'Sr': 5,
    'Y': 5,
    'Zr': 5,
    'Nb': 5,
    'Mo': 5,
    'Tc': 5,
    'Ru': 5,
    'Rh': 5,
    'Pd': 5,
    'Ag': 5,
    'Cd': 5,
    'In': 5,
    'Sn': 5,
    'Sb': 5,
    'Te': 5,
    'I': 5,
    'Xe': 5,
    'Cs': 6,
    'Ba': 6,
    'La': 6,
    'Ce': 6,
    'Pr': 6,
    'Nd': 6,
    'Pm': 6,
    'Sm': 6,
    'Eu': 6,
    'Gd': 6,
    'Tb': 6,
    'Dy': 6,
    'Ho': 6,
    'Er': 6,
    'Tm': 6,
    'Yb': 6,
    'Lu': 6,
    'Hf': 6,
    'Ta': 6,
    'W': 6,
    'Re': 6,
    'Os': 6,
    'Ir': 6,
    'Pt': 6,
    'Au': 6,
    'Hg': 6,
    'Tl': 6,
    'Pb': 6,
    'Bi': 6,
    'Po': 6,
    'At': 6,
    'Rn': 6,
    'Fr': 7,
    'Ra': 7,
    'Ac': 7,
    'Th': 7,
    'Pa': 7,
    'U': 7,
    'Np': 7,
    'Pu': 7,
    'Am': 7,
    'Cm': 7,
    'Bk': 7,
    'Cf': 7,
    'Es': 7,
    'Fm': 7,
    'Md': 7,
    'No': 7,
    'Lr': 7,
    'Rf': 7,
    'Db': 7,
    'Sg': 7,
    'Bh': 7,
    'Hs': 7,
    'Mt': 7,
    'Ds': 7,
    'Rg': 7,
    'Cn': 7,
    'Nh': 7,
    'Fl': 7,
    'Mc': 7,
    'Lv': 7,
    'Ts': 7,
    'Og': 7,
}

# Hashmap for periodic table column (group) (manually entered)
element_col = {
    'H': 1,
    'He': 18,
    'Li': 1,
    'Be': 2,
    'B': 13,
    'C': 14,
    'N': 15,
    'O': 16,
    'F': 17,
    'Ne': 18,
    'Na': 1,
    'Mg': 2,
    'Al': 13,
    'Si': 14,
    'P': 15,
    'S': 16,
    'Cl': 17,
    'Ar': 18,
    'K': 1,
    'Ca': 2,
    'Sc': 3,
    'Ti': 4,
    'V': 5,
    'Cr': 6,
    'Mn': 7,
    'Fe': 8,
    'Co': 9,
    'Ni': 10,
    'Cu': 11,
    'Zn': 12,
    'Ga': 13,
    'Ge': 14,
    'As': 15,
    'Se': 16,
    'Br': 17,
    'Kr': 18,
    'Rb': 1,
    'Sr': 2,
    'Y': 3,
    'Zr': 4,
    'Nb': 5,
    'Mo': 6,
    'Tc': 7,
    'Ru': 8,
    'Rh': 9,
    'Pd': 10,
    'Ag': 11,
    'Cd': 12,
    'In': 13,
    'Sn': 14,
    'Sb': 15,
    'Te': 16,
    'I': 17,
    'Xe': 18,
    'Cs': 1,
    'Ba': 2,
    'La': 0,
    'Ce': 0,
    'Pr': 0,
    'Nd': 0,
    'Pm': 0,
    'Sm': 0,
    'Eu': 0,
    'Gd': 0,
    'Tb': 0,
    'Dy': 0,
    'Ho': 0,
    'Er': 0,
    'Tm': 0,
    'Yb': 0,
    'Lu': 3,
    'Hf': 4,
    'Ta': 5,
    'W': 6,
    'Re': 7,
    'Os': 8,
    'Ir': 9,
    'Pt': 10,
    'Au': 11,
    'Hg': 12,
    'Tl': 13,
    'Pb': 14,
    'Bi': 15,
    'Po': 16,
    'At': 17,
    'Rn': 18,
    'Fr': 1,
    'Ra': 2,
    'Ac': 0,
    'Th': 0,
    'Pa': 0,
    'U': 0,
    'Np': 0,
    'Pu': 0,
    'Am': 0,
    'Cm': 0,
    'Bk': 0,
    'Cf': 0,
    'Es': 0,
    'Fm': 0,
    'Md': 0,
    'No': 0,
    'Lr': 3,
    'Rf': 4,
    'Db': 5,
    'Sg': 6,
    'Bh': 7,
    'Hs': 8,
    'Mt': 9,
    'Ds': 10,
    'Rg': 11,
    'Cn': 12,
    'Nh': 13,
    'Fl': 14,
    'Mc': 15,
    'Lv': 16,
    'Ts': 17,
    'Og': 18,
}

## Protein encoder classes

In [6]:
class AdvancedGNNProteinEncoder(nn.Module):
    """
    Advanced Graph Neural Network-based Protein Encoder that incorporates:
    - Rich amino acid feature representation
    - Flexible graph structures (sequential, predicted contacts)
    - Attention-based message passing
    - Multiple readout functions
    """
    def __init__(self, 
                 embedding_dim: int = 256, 
                 hidden_channels: int = 128,
                 num_layers: int = 3,
                 dropout: float = 0.2,
                 use_attention: bool = True,
                 readout_mode: str = 'mean'):
        """
        Initialize the advanced GNN protein encoder.
        
        Args:
            embedding_dim: Final embedding dimension
            hidden_channels: Size of hidden layers in GNN
            num_layers: Number of GNN layers
            dropout: Dropout probability
            use_attention: Whether to use attention-based message passing
            readout_mode: Method for graph-level pooling ('mean', 'sum', 'max', 'mean+max')
        """
        super().__init__()
        self._embedding_dim = embedding_dim
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_attention = use_attention
        self.readout_mode = readout_mode
        
        # Feature dimensions
        self.aa_embedding_dim = 20  # One-hot encoding of amino acids
        self.physicochemical_dim = 12  # Various amino acid properties
        self.position_embedding_dim = 16  # Positional encoding
        
        # Total node feature dimension
        node_feature_dim = self.aa_embedding_dim + self.physicochemical_dim + self.position_embedding_dim
        
        # Amino acid mappings
        self.aa_to_idx = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}
        self.default_idx = len(self.aa_to_idx)  # For unknown amino acids
        
        # Feature initialization layers
        self.position_embedding = nn.Embedding(1000, self.position_embedding_dim)  # Max sequence length of 1000
        
        # Physicochemical property mappings (pre-computed)
        self.aa_properties = self._initialize_aa_properties()
        
        # GNN layers
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        # First layer takes the combined node features
        if use_attention:
            self.convs.append(GATConv(node_feature_dim, hidden_channels, heads=4, concat=False))
        else:
            self.convs.append(GCNConv(node_feature_dim, hidden_channels))
        self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        # Additional layers
        for _ in range(num_layers - 1):
            if use_attention:
                self.convs.append(GATConv(hidden_channels, hidden_channels, heads=4, concat=False))
            else:
                self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        # Output projections depend on readout mode
        output_dim = hidden_channels if 'mean+max' not in readout_mode else hidden_channels * 2
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, embedding_dim)
        )
        
    def _initialize_aa_properties(self) -> Dict[str, torch.Tensor]:
        """Initialize physicochemical properties for each amino acid"""
        properties = {}
        
        # These values are based on common AA properties: 
        # hydrophobicity, charge, size, polarity, etc.
        
        # Define key properties for each amino acid (normalized)
        # Format: [hydrophobicity, charge, size, polarity, aromaticity, 
        #          h-bond donor, h-bond acceptor, pKa, pI, flexibility,
        #          reactivity, glycosylation_site]
        
        properties['A'] = torch.tensor([0.7, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.3, 0.1, 0.0])
        properties['C'] = torch.tensor([0.8, 0.0, 0.2, 0.1, 0.0, 0.5, 0.0, 0.9, 0.4, 0.2, 0.9, 0.0])
        properties['D'] = torch.tensor([0.3, -1.0, 0.3, 0.9, 0.0, 0.0, 1.0, 0.1, 0.3, 0.5, 0.4, 0.0])
        properties['E'] = torch.tensor([0.4, -1.0, 0.4, 0.8, 0.0, 0.0, 1.0, 0.2, 0.3, 0.5, 0.3, 0.0])
        properties['F'] = torch.tensor([0.9, 0.0, 0.6, 0.0, 1.0, 0.0, 0.0, 0.0, 0.5, 0.2, 0.2, 0.0])
        properties['G'] = torch.tensor([0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.2, 0.0])
        properties['H'] = torch.tensor([0.5, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.6, 0.7, 0.3, 0.6, 0.0])
        properties['I'] = torch.tensor([1.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.1, 0.1, 0.0])
        properties['K'] = torch.tensor([0.3, 1.0, 0.6, 0.8, 0.0, 0.5, 0.0, 1.0, 0.9, 0.5, 0.3, 0.0])
        properties['L'] = torch.tensor([0.9, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.2, 0.1, 0.0])
        properties['M'] = torch.tensor([0.7, 0.0, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.5, 0.3, 0.2, 0.0])
        properties['N'] = torch.tensor([0.3, 0.0, 0.3, 0.8, 0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.3, 1.0])
        properties['P'] = torch.tensor([0.5, 0.0, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.2, 0.0])
        properties['Q'] = torch.tensor([0.4, 0.0, 0.4, 0.7, 0.0, 0.5, 0.5, 0.0, 0.5, 0.4, 0.2, 0.0])
        properties['R'] = torch.tensor([0.2, 1.0, 0.7, 0.9, 0.0, 0.5, 0.0, 0.5, 1.0, 0.4, 0.3, 0.0])
        properties['S'] = torch.tensor([0.4, 0.0, 0.2, 0.6, 0.0, 0.5, 0.5, 0.0, 0.5, 0.6, 0.2, 0.5])
        properties['T'] = torch.tensor([0.5, 0.0, 0.3, 0.5, 0.0, 0.5, 0.5, 0.0, 0.5, 0.4, 0.2, 0.5])
        properties['V'] = torch.tensor([0.8, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.2, 0.1, 0.0])
        properties['W'] = torch.tensor([0.6, 0.0, 0.8, 0.1, 1.0, 0.5, 0.0, 0.0, 0.5, 0.1, 0.2, 0.0])
        properties['Y'] = torch.tensor([0.7, 0.0, 0.7, 0.4, 0.8, 0.5, 0.5, 0.3, 0.5, 0.2, 0.3, 0.0])
        
        # Default for unknown amino acids (average values)
        properties['X'] = torch.mean(torch.stack([prop for prop in properties.values()]), dim=0)
        
        return properties
        
    def _one_hot_encode_aa(self, aa: str) -> torch.Tensor:
        """One-hot encode an amino acid"""
        idx = self.aa_to_idx.get(aa, self.default_idx)
        one_hot = torch.zeros(self.aa_embedding_dim)
        if idx < self.aa_embedding_dim:
            one_hot[idx] = 1.0
        return one_hot
    
    def _get_aa_properties(self, aa: str) -> torch.Tensor:
        """Get physicochemical properties for an amino acid"""
        return self.aa_properties.get(aa, self.aa_properties['X'])
    
    def _sequence_to_graph(self, 
                          sequence: str, 
                          contact_map: Optional[Union[torch.Tensor, List, np.ndarray, None]] = None,
                          distance_threshold: float = 8.0) -> Data:
        """
        Convert a protein sequence to a graph representation.
        
        Args:
            sequence: Amino acid sequence
            contact_map: Optional tensor of pairwise distances/contacts
            distance_threshold: Threshold for considering residues in contact
            
        Returns:
            PyTorch Geometric Data object
        """
        # Node features: combine one-hot encoding, properties, and position
        x = []
        for i, aa in enumerate(sequence):
            if aa not in self.aa_to_idx and aa != 'X':
                aa = 'X'  # Use default for unknown amino acids
                
            # Combine features
            one_hot = self._one_hot_encode_aa(aa)
            properties = self._get_aa_properties(aa)
            position = self.position_embedding(torch.tensor([min(i, 999)]))
            
            # Concatenate all features
            features = torch.cat([one_hot, properties, position.squeeze(0)])
            x.append(features)
            
        # Create node features tensor
        x = torch.stack(x)
        
        # Create edge index
        edge_index = []
        
        # Add sequential connections (each AA connected to neighbors within window)
        window_size = 3  # Connect each AA to this many neighbors in each direction
        for i in range(len(sequence)):
            # Connect to previous AAs within window
            for w in range(1, window_size + 1):
                if i - w >= 0:
                    edge_index.append([i-w, i])
                    edge_index.append([i, i-w])  # Bidirectional
            
            # Connect to next AAs within window
            for w in range(1, window_size + 1):
                if i + w < len(sequence):
                    edge_index.append([i, i+w])
                    edge_index.append([i+w, i])  # Bidirectional
        
        # Add contacts from contact map if provided
        if contact_map is not None:
            try:
                # Convert to tensor if not already
                if not isinstance(contact_map, torch.Tensor):
                    if isinstance(contact_map, np.ndarray):
                        contact_map = torch.tensor(contact_map)
                    elif isinstance(contact_map, list):
                        contact_map = torch.tensor(contact_map)
                
                # Only use contact map if it's now a tensor with the right shape
                if isinstance(contact_map, torch.Tensor) and contact_map.dim() == 2:
                    for i in range(len(sequence)):
                        for j in range(i + window_size + 1, min(len(sequence), contact_map.shape[0])):
                            # Check dimensions to avoid index errors
                            if i < contact_map.shape[0] and j < contact_map.shape[1]:
                                if contact_map[i, j] <= distance_threshold:
                                    edge_index.append([i, j])
                                    edge_index.append([j, i])  # Bidirectional
            except Exception as e:
                # If we encounter any error with the contact map, just ignore it
                print(f"Warning: Could not use contact map: {e}")
        
        # Create edge index tensor
        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        else:
            # Handle case with no edges (very short sequence)
            edge_index = torch.zeros((2, 0), dtype=torch.long)
        
        # Create PyG Data object
        data = Data(x=x, edge_index=edge_index)
        return data
    
    def forward(self, data: Data) -> torch.Tensor:
        """
        Process protein graph through the GNN.
        
        Args:
            data: PyTorch Geometric Data object
            
        Returns:
            Protein embedding tensor
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Apply GNN layers with residual connections
        for i in range(self.num_layers):
            identity = x
            x = self.convs[i](x, edge_index)
            x = self.batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout)#, training=self.training)
            
            # Add residual connection if dimensions match
            if i > 0 and x.size(-1) == identity.size(-1):
                x = x + identity
        
        # Different pooling strategies
        if self.readout_mode == 'mean':
            x = global_mean_pool(x, batch)
        elif self.readout_mode == 'sum':
            x = global_add_pool(x, batch)
        elif self.readout_mode == 'max':
            # Manual implementation of max pooling
            x_max, _ = global_max_pool(x, batch, dim=0)
            x = x_max
        elif self.readout_mode == 'mean+max':
            x_mean = global_mean_pool(x, batch)
            # Manual implementation of max pooling
            x_max, _ = global_max_pool(x, batch, dim=0)
            x = torch.cat([x_mean, x_max], dim=1)
        
        # Final projection
        x = self.projection(x)
        
        return x
    
    def encode_sequence(self, 
                         sequence: str, 
                         device: Optional[torch.device] = None,
                         contact_map: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Encode a single protein sequence.
        
        Args:
            sequence: Amino acid sequence
            contact_map: Optional contact map for the protein
            
        Returns:
            Embedding tensor
        """
        # Convert sequence to graph
        data = self._sequence_to_graph(sequence, contact_map)
        
        # Add batch dimension for single sequence
        data.batch = torch.zeros(len(sequence), dtype=torch.long)
        
        # Move to device if specified
        if device is not None:
            data = data.to(device)
        
        # Forward pass
        with torch.no_grad():
            embedding = self.forward(data)
            
        return embedding
    
    def encode_batch(self, 
                     batch_data: List[str],
                     device: torch.device = None,
                     contact_maps: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Encode a batch of protein sequences.
        
        Args:
            batch_data: List of amino acid sequences
            device: Device to place tensors on
            contact_maps: Optional list of contact maps for each protein
            
        Returns:
            Batch of embedding tensors
        """
        # Create a list of Data objects
        data_list = []
        count = 0
        for sequence in batch_data:
            count += 1
            print(f'protein encoder progress: {count}/{len(batch_data)}')
            # Don't use contact maps for now to avoid the error
            data = self._sequence_to_graph(sequence, None)
            data_list.append(data)
            
        # Create a batch from the list
        batch = Batch.from_data_list(data_list)
        
        # Move to device if specified
        if device is not None:
            batch = batch.to(device)
        
        # Forward pass
        with torch.no_grad():
            embeddings = self.forward(batch)
            
        return embeddings
    
    def predict_secondary_structure(self, sequence: str) -> Dict[str, torch.Tensor]:
        """
        Predict secondary structure probabilities (helix, sheet, coil)
        
        Args:
            sequence: Amino acid sequence
            
        Returns:
            Dictionary of secondary structure probabilities
        """
        # This would require a separate prediction head
        # Here we use Biopython as a placeholder
        try:
            analysis = ProteinAnalysis(sequence)
            helix, turn, sheet = analysis.secondary_structure_fraction()
            
            # Convert to tensor format that could come from a model
            ss_pred = {
                'helix': torch.tensor([helix] * len(sequence)),
                'sheet': torch.tensor([sheet] * len(sequence)),
                'coil': torch.tensor([turn] * len(sequence))
            }
            return ss_pred
        except:
            # Default values if analysis fails
            return {
                'helix': torch.zeros(len(sequence)),
                'sheet': torch.zeros(len(sequence)),
                'coil': torch.ones(len(sequence))
            }
    
    def estimate_contact_map(self, sequence: str) -> torch.Tensor:
        """
        Estimate a contact map based on amino acid properties and sequential distance.
        This is a placeholder - ideally a dedicated contact prediction model would be used.
        
        Args:
            sequence: Amino acid sequence
            
        Returns:
            Estimated contact map (distances between residues)
        """
        seq_len = len(sequence)
        contact_map = torch.ones(seq_len, seq_len) * 100  # Initialize with large distances
        
        # Set sequential distances
        for i in range(seq_len):
            for j in range(seq_len):
                # Sequential distance penalty
                contact_map[i, j] = min(contact_map[i, j], abs(i - j) * 3.8)
                
                # Reduce distance for hydrophobic interactions
                aa_i = sequence[i] if sequence[i] in self.aa_to_idx else 'X'
                aa_j = sequence[j] if sequence[j] in self.aa_to_idx else 'X'
                hydrophobicity_i = self.aa_properties[aa_i][0]
                hydrophobicity_j = self.aa_properties[aa_j][0]
                
                # Hydrophobic residues tend to cluster
                if hydrophobicity_i > 0.7 and hydrophobicity_j > 0.7:
                    contact_map[i, j] = min(contact_map[i, j], 8.0 + abs(i - j) * 0.5)
                
                # Ionic interactions between charged residues
                charge_i = self.aa_properties[aa_i][1]
                charge_j = self.aa_properties[aa_j][1]
                if abs(i - j) > 4 and charge_i * charge_j < 0:  # Opposite charges attract
                    contact_map[i, j] = min(contact_map[i, j], 10.0)
                    
        return contact_map
    
    @property
    def embedding_dim(self) -> int:
        return self._embedding_dim

In [7]:
class GlycoProteinDataset(Dataset):
    def __init__(self, fractions_df, glycan_encodings, protein_encodings, glycan_mapping, protein_mapping, task_id):
        """
        Args:
            fractions_df: DataFrame with fraction data
            glycan_encodings: Tensor of shape [n_glycans, embedding_dim]
            protein_encodings: Tensor of shape [n_proteins, embedding_dim]
            glycan_mapping: Dict mapping glycan IDs to indices in glycan_encodings
            protein_mapping: Dict mapping protein IDs to indices in protein_encodings
        """
        self.fractions_df = fractions_df
        self.glycan_encodings = glycan_encodings
        self.protein_encodings = protein_encodings
        self.glycan_mapping = glycan_mapping
        self.protein_mapping = protein_mapping
        self.task_id = task_id
        
    def __len__(self):
        return len(self.fractions_df)
    
    def __getitem__(self, idx):
        row = self.fractions_df.iloc[idx]
        
        # Get the corresponding encodings using the mappings
        glycan_idx = self.glycan_mapping[row['GlycanID']]
        protein_idx = self.protein_mapping[row['ProteinGroup']]
        
        return {
            'glycan_encoding': self.glycan_encodings[glycan_idx],
            'protein_encoding': self.protein_encodings[protein_idx],
            'concentration': torch.tensor([row['Concentration']], dtype=torch.float32),
            'target': torch.tensor([row['f']], dtype=torch.float32),
            'task_id': self.task_id
        }

In [8]:
class AdvancedGNNProteinEncoder_GPU(nn.Module):
    """
    Advanced Graph Neural Network-based Protein Encoder that incorporates:
    - Rich amino acid feature representation
    - Flexible graph structures (sequential, predicted contacts)
    - Attention-based message passing
    - Multiple readout functions
    """
    def __init__(self, 
                 embedding_dim: int = 256, 
                 hidden_channels: int = 128,
                 num_layers: int = 3,
                 dropout: float = 0.2,
                 use_attention: bool = True,
                 readout_mode: str = 'mean'):
        """
        Initialize the advanced GNN protein encoder.
        
        Args:
            embedding_dim: Final embedding dimension
            hidden_channels: Size of hidden layers in GNN
            num_layers: Number of GNN layers
            dropout: Dropout probability
            use_attention: Whether to use attention-based message passing
            readout_mode: Method for graph-level pooling ('mean', 'sum', 'max', 'mean+max')
        """
        super().__init__()
        self._embedding_dim = embedding_dim
        self.hidden_channels = hidden_channels
        self.num_layers = num_layers
        self.dropout = dropout
        self.use_attention = use_attention
        self.readout_mode = readout_mode
        
        # Feature dimensions
        self.aa_embedding_dim = 20  # One-hot encoding of amino acids
        self.physicochemical_dim = 12  # Various amino acid properties
        self.position_embedding_dim = 16  # Positional encoding
        
        # Total node feature dimension
        node_feature_dim = self.aa_embedding_dim + self.physicochemical_dim + self.position_embedding_dim
        
        # Amino acid mappings
        self.aa_to_idx = {aa: i for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY")}
        self.default_idx = len(self.aa_to_idx)  # For unknown amino acids
        
        # Feature initialization layers
        self.position_embedding = nn.Embedding(1000, self.position_embedding_dim).cuda()  # Max sequence length of 1000
        
        # Physicochemical property mappings (pre-computed)
        self.aa_properties = self._initialize_aa_properties()
        
        # GNN layers
        self.convs = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        
        # First layer takes the combined node features
        if use_attention:
            self.convs.append(GATConv(node_feature_dim, hidden_channels, heads=4, concat=False))
        else:
            self.convs.append(GCNConv(node_feature_dim, hidden_channels))
        self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        # Additional layers
        for _ in range(num_layers - 1):
            if use_attention:
                self.convs.append(GATConv(hidden_channels, hidden_channels, heads=4, concat=False))
            else:
                self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.batch_norms.append(nn.BatchNorm1d(hidden_channels))
        
        # Output projections depend on readout mode
        output_dim = hidden_channels if 'mean+max' not in readout_mode else hidden_channels * 2
        self.projection = nn.Sequential(
            nn.Linear(output_dim, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, embedding_dim)
        )
        
    def _initialize_aa_properties(self) -> Dict[str, torch.Tensor]:
        """Initialize physicochemical properties for each amino acid"""
        properties = {}
        
        # These values are based on common AA properties: 
        # hydrophobicity, charge, size, polarity, etc.
        
        # Define key properties for each amino acid (normalized)
        # Format: [hydrophobicity, charge, size, polarity, aromaticity, 
        #          h-bond donor, h-bond acceptor, pKa, pI, flexibility,
        #          reactivity, glycosylation_site]
        
        properties['A'] = torch.tensor([0.7, 0.0, 0.1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.3, 0.1, 0.0])
        properties['C'] = torch.tensor([0.8, 0.0, 0.2, 0.1, 0.0, 0.5, 0.0, 0.9, 0.4, 0.2, 0.9, 0.0])
        properties['D'] = torch.tensor([0.3, -1.0, 0.3, 0.9, 0.0, 0.0, 1.0, 0.1, 0.3, 0.5, 0.4, 0.0])
        properties['E'] = torch.tensor([0.4, -1.0, 0.4, 0.8, 0.0, 0.0, 1.0, 0.2, 0.3, 0.5, 0.3, 0.0])
        properties['F'] = torch.tensor([0.9, 0.0, 0.6, 0.0, 1.0, 0.0, 0.0, 0.0, 0.5, 0.2, 0.2, 0.0])
        properties['G'] = torch.tensor([0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 1.0, 0.2, 0.0])
        properties['H'] = torch.tensor([0.5, 0.5, 0.5, 0.7, 0.5, 0.5, 0.5, 0.6, 0.7, 0.3, 0.6, 0.0])
        properties['I'] = torch.tensor([1.0, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.1, 0.1, 0.0])
        properties['K'] = torch.tensor([0.3, 1.0, 0.6, 0.8, 0.0, 0.5, 0.0, 1.0, 0.9, 0.5, 0.3, 0.0])
        properties['L'] = torch.tensor([0.9, 0.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.2, 0.1, 0.0])
        properties['M'] = torch.tensor([0.7, 0.0, 0.5, 0.1, 0.0, 0.0, 0.0, 0.0, 0.5, 0.3, 0.2, 0.0])
        properties['N'] = torch.tensor([0.3, 0.0, 0.3, 0.8, 0.0, 0.5, 0.5, 0.0, 0.5, 0.5, 0.3, 1.0])
        properties['P'] = torch.tensor([0.5, 0.0, 0.3, 0.3, 0.0, 0.0, 0.0, 0.0, 0.5, 0.0, 0.2, 0.0])
        properties['Q'] = torch.tensor([0.4, 0.0, 0.4, 0.7, 0.0, 0.5, 0.5, 0.0, 0.5, 0.4, 0.2, 0.0])
        properties['R'] = torch.tensor([0.2, 1.0, 0.7, 0.9, 0.0, 0.5, 0.0, 0.5, 1.0, 0.4, 0.3, 0.0])
        properties['S'] = torch.tensor([0.4, 0.0, 0.2, 0.6, 0.0, 0.5, 0.5, 0.0, 0.5, 0.6, 0.2, 0.5])
        properties['T'] = torch.tensor([0.5, 0.0, 0.3, 0.5, 0.0, 0.5, 0.5, 0.0, 0.5, 0.4, 0.2, 0.5])
        properties['V'] = torch.tensor([0.8, 0.0, 0.4, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5, 0.2, 0.1, 0.0])
        properties['W'] = torch.tensor([0.6, 0.0, 0.8, 0.1, 1.0, 0.5, 0.0, 0.0, 0.5, 0.1, 0.2, 0.0])
        properties['Y'] = torch.tensor([0.7, 0.0, 0.7, 0.4, 0.8, 0.5, 0.5, 0.3, 0.5, 0.2, 0.3, 0.0])
        
        # Default for unknown amino acids (average values)
        properties['X'] = torch.mean(torch.stack([prop for prop in properties.values()]), dim=0)
        
        return properties
        
    def _one_hot_encode_aa(self, aa: str) -> torch.Tensor:
        """One-hot encode an amino acid"""
        idx = self.aa_to_idx.get(aa, self.default_idx)
        one_hot = torch.zeros(self.aa_embedding_dim).cuda()
        if idx < self.aa_embedding_dim:
            one_hot[idx] = 1.0
        return one_hot.cuda()
    
    def _get_aa_properties(self, aa: str) -> torch.Tensor:
        """Get physicochemical properties for an amino acid"""
        return self.aa_properties.get(aa, self.aa_properties['X']).cuda()
    
    def _sequence_to_graph(self, 
                          sequence: str, 
                          contact_map: Optional[Union[torch.Tensor, List, np.ndarray, None]] = None,
                          distance_threshold: float = 8.0) -> Data:
        """
        Convert a protein sequence to a graph representation.
        
        Args:
            sequence: Amino acid sequence
            contact_map: Optional tensor of pairwise distances/contacts
            distance_threshold: Threshold for considering residues in contact
            
        Returns:
            PyTorch Geometric Data object
        """
        # Node features: combine one-hot encoding, properties, and position
        x = []
        for i, aa in enumerate(sequence):
            if aa not in self.aa_to_idx and aa != 'X':
                aa = 'X'  # Use default for unknown amino acids
                
            # Combine features
            #aa = aa.cuda()
            one_hot = self._one_hot_encode_aa(aa).cuda()
            properties = self._get_aa_properties(aa).cuda()
            position = self.position_embedding(torch.tensor([min(i, 999)]).cuda()).cuda()
            
            # Concatenate all features
            features = torch.cat([one_hot, properties, position.squeeze(0)]).cuda()
            x.append(features)
            
        # Create node features tensor
        x = torch.stack(x).cuda()
        
        # Create edge index
        edge_index = []
        
        # Add sequential connections (each AA connected to neighbors within window)
        window_size = 3  # Connect each AA to this many neighbors in each direction
        for i in range(len(sequence)):
            # Connect to previous AAs within window
            for w in range(1, window_size + 1):
                if i - w >= 0:
                    edge_index.append([i-w, i])
                    edge_index.append([i, i-w])  # Bidirectional
            
            # Connect to next AAs within window
            for w in range(1, window_size + 1):
                if i + w < len(sequence):
                    edge_index.append([i, i+w])
                    edge_index.append([i+w, i])  # Bidirectional
        
        # Add contacts from contact map if provided
        if contact_map is not None:
            try:
                # Convert to tensor if not already
                if not isinstance(contact_map, torch.Tensor):
                    if isinstance(contact_map, np.ndarray):
                        contact_map = torch.tensor(contact_map).cuda()
                    elif isinstance(contact_map, list):
                        contact_map = torch.tensor(contact_map).cuda()
                
                # Only use contact map if it's now a tensor with the right shape
                if isinstance(contact_map, torch.Tensor) and contact_map.dim() == 2:
                    for i in range(len(sequence)):
                        for j in range(i + window_size + 1, min(len(sequence), contact_map.shape[0])):
                            # Check dimensions to avoid index errors
                            if i < contact_map.shape[0] and j < contact_map.shape[1]:
                                if contact_map[i, j] <= distance_threshold:
                                    edge_index.append([i, j])
                                    edge_index.append([j, i])  # Bidirectional
            except Exception as e:
                # If we encounter any error with the contact map, just ignore it
                print(f"Warning: Could not use contact map: {e}")
        
        # Create edge index tensor
        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        else:
            # Handle case with no edges (very short sequence)
            edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_index = edge_index.cuda()
        # Create PyG Data object
        data = Data(x=x, edge_index=edge_index)
        return data.cuda()
    
    def forward(self, data: Data) -> torch.Tensor:
        """
        Process protein graph through the GNN.
        
        Args:
            data: PyTorch Geometric Data object
            
        Returns:
            Protein embedding tensor
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        # Apply GNN layers with residual connections
        for i in range(self.num_layers):
            identity = x.cuda()
            x = self.convs[i](x, edge_index.cuda())
            x = self.batch_norms[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout)#, training=self.training)
            
            # Add residual connection if dimensions match
            if i > 0 and x.size(-1) == identity.size(-1):
                x = x + identity
        
        # Different pooling strategies
        if self.readout_mode == 'mean':
            x = global_mean_pool(x, batch)
        elif self.readout_mode == 'sum':
            x = global_add_pool(x, batch)
        elif self.readout_mode == 'max':
            # Manual implementation of max pooling
            x_max, _ = global_max_pool(x, batch, dim=0)
            x = x_max
        elif self.readout_mode == 'mean+max':
            x_mean = global_mean_pool(x, batch)
            # Manual implementation of max pooling
            x_max, _ = global_max_pool(x, batch, dim=0)
            x = torch.cat([x_mean, x_max], dim=1)
        
        # Final projection
        x = self.projection(x).cuda()
        
        return x
    
    def encode_sequence(self, 
                         sequence: str, 
                         device: Optional[torch.device] = None,
                         contact_map: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Encode a single protein sequence.
        
        Args:
            sequence: Amino acid sequence
            contact_map: Optional contact map for the protein
            
        Returns:
            Embedding tensor
        """
        # Convert sequence to graph
        data = self._sequence_to_graph(sequence, contact_map).to(device)
        
        # Add batch dimension for single sequence
        data.batch = torch.zeros(len(sequence), dtype=torch.long)
        
        # Move to device if specified
        if device is not None:
            data = data.to(device)
        
        # Forward pass
        with torch.no_grad():
            embedding = self.forward(data)
            
        return embedding
    
    def encode_batch(self, 
                     batch_data: List[str],
                     device: torch.device = None,
                     contact_maps: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """
        Encode a batch of protein sequences.
        
        Args:
            batch_data: List of amino acid sequences
            device: Device to place tensors on
            contact_maps: Optional list of contact maps for each protein
            
        Returns:
            Batch of embedding tensors
        """
        print('deviceee,', device)
        #batch_data.to(device)
        # Create a list of Data objects
        data_list = []
        count = 0
        for sequence in batch_data:
            count += 1
            print(f'encode batch progress: {count}/{len(batch_data)}')
            # Don't use contact maps for now to avoid the error
            #sequence = sequence.cuda()
            data = self._sequence_to_graph(sequence, None).cuda()

            data_list.append(data)
            
        # Create a batch from the list
        batch = Batch.from_data_list(data_list)
        
        # Move to device if specified
        #if device is not None:
        batch = batch.cuda()
        
        # Forward pass
        with torch.no_grad():
            embeddings = self.forward(batch)
            
        return embeddings
    
    def predict_secondary_structure(self, sequence: str) -> Dict[str, torch.Tensor]:
        """
        Predict secondary structure probabilities (helix, sheet, coil)
        
        Args:
            sequence: Amino acid sequence
            
        Returns:
            Dictionary of secondary structure probabilities
        """
        # This would require a separate prediction head
        # Here we use Biopython as a placeholder
        try:
            analysis = ProteinAnalysis(sequence)
            helix, turn, sheet = analysis.secondary_structure_fraction()
            
            # Convert to tensor format that could come from a model
            ss_pred = {
                'helix': torch.tensor([helix] * len(sequence)),
                'sheet': torch.tensor([sheet] * len(sequence)),
                'coil': torch.tensor([turn] * len(sequence))
            }
            return ss_pred
        except:
            # Default values if analysis fails
            return {
                'helix': torch.zeros(len(sequence)),
                'sheet': torch.zeros(len(sequence)),
                'coil': torch.ones(len(sequence))
            }
    
    def estimate_contact_map(self, sequence: str) -> torch.Tensor:
        """
        Estimate a contact map based on amino acid properties and sequential distance.
        This is a placeholder - ideally a dedicated contact prediction model would be used.
        
        Args:
            sequence: Amino acid sequence
            
        Returns:
            Estimated contact map (distances between residues)
        """
        seq_len = len(sequence)
        contact_map = torch.ones(seq_len, seq_len) * 100  # Initialize with large distances
        
        # Set sequential distances
        for i in range(seq_len):
            for j in range(seq_len):
                # Sequential distance penalty
                contact_map[i, j] = min(contact_map[i, j], abs(i - j) * 3.8)
                
                # Reduce distance for hydrophobic interactions
                aa_i = sequence[i] if sequence[i] in self.aa_to_idx else 'X'
                aa_j = sequence[j] if sequence[j] in self.aa_to_idx else 'X'
                hydrophobicity_i = self.aa_properties[aa_i][0]
                hydrophobicity_j = self.aa_properties[aa_j][0]
                
                # Hydrophobic residues tend to cluster
                if hydrophobicity_i > 0.7 and hydrophobicity_j > 0.7:
                    contact_map[i, j] = min(contact_map[i, j], 8.0 + abs(i - j) * 0.5)
                
                # Ionic interactions between charged residues
                charge_i = self.aa_properties[aa_i][1]
                charge_j = self.aa_properties[aa_j][1]
                if abs(i - j) > 4 and charge_i * charge_j < 0:  # Opposite charges attract
                    contact_map[i, j] = min(contact_map[i, j], 10.0)
                    
        return contact_map
    
    @property
    def embedding_dim(self) -> int:
        return self._embedding_dim

In [9]:


fractions_df = pd.read_csv('../pipeline/data/Train_Fractions.csv', sep='\t')
glycans_df = pd.read_csv('../pipeline/data/Glycan-Structures-CFG611.txt', sep='\t')
proteins_df = pd.read_csv('../pipeline/data/Protein-Sequence-Table.txt', sep='\t')

glycan_encoder = MPNNGlycanEncoder().to(device) #GNNGlycanEncoder().to(device)
protein_encoder = AdvancedGNNProteinEncoder().to(device) #AdvancedGNNProteinEncoder().to(device)

glycan_type = 'SMILES'

random_state = 42

split_mode = 'AND'
use_kfolds = False
k_folds = 0
val_split = 0.5
#device = 'cpu'

full_indicies, glycan_encodings, protein_encodings = prepare_train_val_datasets(fractions_df, glycans_df, proteins_df, glycan_encoder, protein_encoder, glycan_type, random_state, split_mode, use_kfolds, k_folds, val_split, device)

glycan encoding progress: 1/611
glycan encoding progress: 2/611
glycan encoding progress: 3/611
glycan encoding progress: 4/611
glycan encoding progress: 5/611
glycan encoding progress: 6/611
glycan encoding progress: 7/611
glycan encoding progress: 8/611
glycan encoding progress: 9/611
glycan encoding progress: 10/611
glycan encoding progress: 11/611
glycan encoding progress: 12/611
glycan encoding progress: 13/611
glycan encoding progress: 14/611
glycan encoding progress: 15/611
glycan encoding progress: 16/611
glycan encoding progress: 17/611
glycan encoding progress: 18/611
glycan encoding progress: 19/611
glycan encoding progress: 20/611
glycan encoding progress: 21/611
glycan encoding progress: 22/611
glycan encoding progress: 23/611
glycan encoding progress: 24/611
glycan encoding progress: 25/611
glycan encoding progress: 26/611
glycan encoding progress: 27/611
glycan encoding progress: 28/611
glycan encoding progress: 29/611
glycan encoding progress: 30/611
glycan encoding pro



In [None]:
## first task OUR task
TASK_ID = 0

glycan_mapping = {name: idx for idx, name in enumerate(glycans_df['Name'])}
protein_mapping = {name: idx for idx, name in enumerate(proteins_df['ProteinGroup'])}

train_idx, val_idx = full_indicies[0]

train_data = fractions_df.loc[train_idx]
val_data = fractions_df.loc[val_idx]

train_pytorch_dataset = GlycoProteinDataset(
    train_data, glycan_encodings, protein_encodings, glycan_mapping, protein_mapping, TASK_ID
)
val_pytorch_dataset = GlycoProteinDataset(
    val_data, glycan_encodings, protein_encodings, glycan_mapping, protein_mapping, TASK_ID
)

batch_size = 32

task469_train_loader = DataLoader(
    train_pytorch_dataset,
    batch_size=batch_size,
    shuffle=True,
)
task469_val_loader = DataLoader(
    val_pytorch_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# GlycanML Data

In [13]:
# glycanML dataset

glyML_TASK_ID = 1

glyML_fractions_train_df = pd.read_csv('../pipeline/data/GlycanML/train_fractions.tsv', sep='\t')
glyML_fractions_test_df = pd.read_csv('../pipeline/data/GlycanML/test_fractions.tsv', sep='\t')
glyML_glycans_df = pd.read_csv('../pipeline/data/GlycanML/glycans.tsv', sep='\t')
glyML_proteins_df = pd.read_csv('../pipeline/data/GlycanML/proteins.tsv', nrows=600, sep='\t')

glyML_glycan_mapping = {name: idx for idx, name in enumerate(glyML_glycans_df['Name'])}
glyML_protein_mapping = {name: idx for idx, name in enumerate(glyML_proteins_df['ProteinGroup'])}

# dont need this I dont think as not doing train-val split
#train_data = fractions_df.loc[train_idx]
#val_data = fractions_df.loc[val_idx]
##@ filtyer for GPU
glyML_fractions_train_df = glyML_fractions_train_df[
    (glyML_fractions_train_df['GlycanID'].isin(glyML_glycans_df['Name'])) &
    (glyML_fractions_train_df['ProteinGroup'].isin(glyML_proteins_df['ProteinGroup']))
]

glyML_glycan_type = 'SMILES'

# use the same glycan and protein encoders as for MultiTask we share features and then have diff classif heads
glyML_glycan_encodings = glycan_encoder.encode_batch(glyML_glycans_df[glyML_glycan_type].tolist(), device)
glyML_protein_encodings = protein_encoder.encode_batch(glyML_proteins_df['Amino Acid Sequence'].tolist(), device)

glyML_train_pytorch_dataset = GlycoProteinDataset(
    glyML_fractions_train_df, glyML_glycan_encodings, glyML_protein_encodings, glyML_glycan_mapping, glyML_protein_mapping, glyML_TASK_ID
)
glyML_val_pytorch_dataset = GlycoProteinDataset(
    glyML_fractions_test_df, glyML_glycan_encodings, glyML_protein_encodings, glyML_glycan_mapping, glyML_protein_mapping, glyML_TASK_ID
)

glyML_batch_size = 32

glyML_train_loader = DataLoader(
    glyML_train_pytorch_dataset,
    batch_size=glyML_batch_size,
    shuffle=True,
)
glyML_val_loader = DataLoader(
    glyML_val_pytorch_dataset,
    batch_size=glyML_batch_size,
    shuffle=True,
)

glycan encoding progress: 1/417
glycan encoding progress: 2/417
glycan encoding progress: 3/417
glycan encoding progress: 4/417
glycan encoding progress: 5/417
glycan encoding progress: 6/417
glycan encoding progress: 7/417
glycan encoding progress: 8/417
glycan encoding progress: 9/417
glycan encoding progress: 10/417
glycan encoding progress: 11/417
glycan encoding progress: 12/417
glycan encoding progress: 13/417
glycan encoding progress: 14/417
glycan encoding progress: 15/417
glycan encoding progress: 16/417
glycan encoding progress: 17/417
glycan encoding progress: 18/417
glycan encoding progress: 19/417
glycan encoding progress: 20/417
glycan encoding progress: 21/417
glycan encoding progress: 22/417
glycan encoding progress: 23/417
glycan encoding progress: 24/417
glycan encoding progress: 25/417
glycan encoding progress: 26/417
glycan encoding progress: 27/417
glycan encoding progress: 28/417
glycan encoding progress: 29/417
glycan encoding progress: 30/417
glycan encoding pro

In [9]:
import pickle

In [11]:
# save so we dont have to spend 10 mins each time recalculating this
with open('glyML_train_dataset.pkl', 'wb') as f:
    pickle.dump(glyML_train_pytorch_dataset, f)

with open('glyML_val_dataset.pkl', 'wb') as f:
    pickle.dump(glyML_val_pytorch_dataset, f)

In [17]:
# run this to reload them
with open('glyML_train_dataset.pkl', 'rb') as f:
    glyML_train_pytorch_dataset = pickle.load(f)

with open('glyML_val_dataset.pkl', 'rb') as f:
    glyML_val_pytorch_dataset = pickle.load(f)

glyML_batch_size = 32

glyML_train_loader = DataLoader(
    glyML_train_pytorch_dataset,
    batch_size=glyML_batch_size,
    shuffle=True,
)

## Binding DB interaction prediction task

In [None]:
#def pytorch_dataset_n_dataloader(TASK_ID, train_df, glycans_df, proteins_df)

In [14]:
BDB_TASK_ID = 2

BDB_fractions_train_df = pd.read_csv('../pipeline/data/BindingDB/BDB_Train_Fractions.tsv', sep='\t')
#BDB_fractions_test_df = pd.read_csv('../pipeline/data/GlycanML/test_fractions.tsv', sep='\t')
BDB_glycans_df = pd.read_csv('../pipeline/data/BindingDB/BDB_Glycan-Structures-CFG611.txt', nrows=10_000, sep='\t')
BDB_proteins_df = pd.read_csv('../pipeline/data/BindingDB/BDB_Protein-Sequence-Table.txt', nrows=10_000, sep='\t')

BDB_glycan_mapping = {name: idx for idx, name in enumerate(BDB_glycans_df['Name'])}
BDB_protein_mapping = {name: idx for idx, name in enumerate(BDB_proteins_df['ProteinGroup'])}

# dont need this I dont think as not doing train-val split
#train_data = fractions_df.loc[train_idx]
#val_data = fractions_df.loc[val_idx]

BDB_fractions_train_df = BDB_fractions_train_df[
    (BDB_fractions_train_df['GlycanID'].isin(BDB_glycans_df['Name'])) &
    (BDB_fractions_train_df['ProteinGroup'].isin(BDB_proteins_df['ProteinGroup']))
]

BDB_glycan_type = 'SMILES'

# use the same glycan and protein encoders as for MultiTask we share features and then have diff classif heads
BDB_glycan_encodings = glycan_encoder.encode_batch(BDB_glycans_df[BDB_glycan_type].tolist(), device)
BDB_protein_encodings = protein_encoder.encode_batch(BDB_proteins_df['Amino Acid Sequence'].tolist(), device)

BDB_train_pytorch_dataset = GlycoProteinDataset(
    BDB_fractions_train_df, BDB_glycan_encodings, BDB_protein_encodings, BDB_glycan_mapping, BDB_protein_mapping, BDB_TASK_ID
)
#BDB_val_pytorch_dataset = GlycoProteinDataset(
    #BDB_fractions_test_df, BDB_glycan_encodings, BDB_protein_encodings, BDB_glycan_mapping, BDB_protein_mapping, BDB_TASK_ID
#)

BDB_batch_size = 32

BDB_train_loader = DataLoader(
    BDB_train_pytorch_dataset,
    batch_size=BDB_batch_size,
    shuffle=True,
)
#glyML_val_loader = DataLoader(
    #$glyML_val_pytorch_dataset,
    ##batch_size=glyML_batch_size,
    #shuffle=True,
#)

glycan encoding progress: 1/10000
glycan encoding progress: 2/10000
glycan encoding progress: 3/10000
glycan encoding progress: 4/10000
glycan encoding progress: 5/10000
glycan encoding progress: 6/10000
glycan encoding progress: 7/10000
glycan encoding progress: 8/10000
glycan encoding progress: 9/10000
glycan encoding progress: 10/10000
glycan encoding progress: 11/10000
glycan encoding progress: 12/10000
glycan encoding progress: 13/10000
glycan encoding progress: 14/10000
glycan encoding progress: 15/10000
glycan encoding progress: 16/10000
glycan encoding progress: 17/10000
glycan encoding progress: 18/10000
glycan encoding progress: 19/10000
glycan encoding progress: 20/10000
glycan encoding progress: 21/10000
glycan encoding progress: 22/10000
glycan encoding progress: 23/10000
glycan encoding progress: 24/10000
glycan encoding progress: 25/10000
glycan encoding progress: 26/10000
glycan encoding progress: 27/10000
glycan encoding progress: 28/10000
glycan encoding progress: 29/

: 

In [None]:
# save so we dont have to spend 10 mins each time recalculating this
with open('BDB_train_pytorch_dataset.pkl', 'wb') as f:
    pickle.dump(BDB_train_pytorch_dataset, f)


In [None]:
# run this to reload them
with open('BDB_train_pytorch_dataset.pkl', 'rb') as f:
    BDB_train_pytorch_dataset = pickle.load(f)


BDB_batch_size = 32

BDB_train_loader = DataLoader(
    BDB_train_pytorch_dataset,
    batch_size=BDB_batch_size,
    shuffle=True,
)

## MultiTask Classifier network

In [14]:


class MultiTask_Network(nn.Module):
    def __init__(self, 
                 input_dim: int,
                 task_output_dims: Dict[int, int],  # Key: task_id, Value: output_dim
                 hidden_dims: List[int] = [256, 128, 64]):  # DNN hidden layer sizes tried: #[256, 128, 128, 64, 32]
        super(MultiTask_Network, self).__init__()

        self.input_dim = input_dim
        self.task_output_dims = task_output_dims
        self.hidden_dims = hidden_dims

        # Shared DNN Layers (based on your DNNBindingPredictor)
        dnn_layers = []
        for i, hidden_dim in enumerate(hidden_dims):
            dnn_layers.append(nn.Linear(input_dim if i == 0 else hidden_dims[i - 1], hidden_dim))
            dnn_layers.append(nn.ReLU())
            dnn_layers.append(nn.BatchNorm1d(hidden_dim))
            dnn_layers.append(nn.Dropout(0.4))
        self.dnn = nn.Sequential(*dnn_layers)  # Store it as self.dnn

        # Task-Specific Output Layers (store in a dictionary)
        self.final_layers = nn.ModuleDict({
            str(task_id): nn.Linear(hidden_dims[-1], output_dim) # Use hidden_dims[-1] as input
            for task_id, output_dim in task_output_dims.items()
        })


    def forward(self, x: torch.Tensor, task_id: int):
        """
        Forward pass.
        Args:
            x: Input tensor (concatenated glycan/protein embeddings)
            task_id: Integer identifying the task.
        Returns:
            Output tensor for the specified task.
        """

        # Pass through shared DNN layers
        x = self.dnn(x) #  Pass input through the DNN

        # Task-specific output layer
        task_id_str = str(task_id) # crucial to make it a string
        if task_id_str in self.final_layers:
            x = self.final_layers[task_id_str](x) # Apply the output layer
        else:
            raise ValueError(f"Invalid task_id: {task_id}.  Available task_ids are {self.final_layers.keys()}")

        return x

In [28]:
def calculate_metrics(predictions: torch.Tensor, targets: torch.Tensor) -> Dict[str, float]:
    """
    Calculate training/validation metrics
    
    Args:
        predictions (torch.Tensor): Model predictions
        targets (torch.Tensor): True values
        
    Returns:
        Dict[str, float]: Dictionary of metric names and values
    """
    # convert values to numpy arrays
    preds_np = predictions.detach().cpu().numpy()
    targets_np = targets.detach().cpu().numpy()
    
    mse = np.mean((preds_np - targets_np) ** 2)
    pearson_corr, _ = pearsonr(preds_np.flatten(), targets_np.flatten())
    
    return {
        'mse': float(mse),
        'pearson': float(pearson_corr)
    }

## Setup training params and data

In [None]:


input_dim = glycan_encoder.embedding_dim + protein_encoder.embedding_dim + 1  # Input dimension (glycan + protein + concentration)
print('input dim:', input_dim)
# Define output dimensions for each task
task_output_dims = {0: 1, 1: 1, 2: 1}  # Task 0 (469): 1 class regression, Task 1 (glycanML): 1 class regression, Task 2: binding BD

# Create the multi-task network
model = MultiTask_Network(input_dim=input_dim,
                          task_output_dims=task_output_dims,
                          hidden_dims=[256, 128, 64]).to(device)

# Example forward pass for task 0
#input_tensor = torch.randn(32, input_dim)  # Example input
#task_id = 1
#output = model(input_tensor, task_id) # Pass the input and the task id to the model
#print(f"Output for task {task_id}: {output.shape}")

# Loss and activation functions for each task
loss_functions = {0: nn.MSELoss(), 1: nn.MSELoss(), 2: nn.MSELoss()}   

#nn.Sigmoid()
activation_functions = {0: None, 1: None, 2:None} # Task 1 & 2: No activation (regression)

loss_weights = {0: 1.0, 1: 1.0, 2: 1.0}

learning_rate = 0.001

optimizer = optim.Adam(
    list(glycan_encoder.parameters()) +
    list(protein_encoder.parameters()) +
    list(model.parameters()),
    lr=learning_rate
)


# handles batch interleaving to make sure we dont commit CATASTROPHIC FORGETTING
train_loader = CombinedLoader(
    {"glyML": glyML_train_loader, "task469": task469_train_loader, "BindingDB": BDB_train_loader},
    mode="max_size_cycle"  # Oversample the smaller dataset  # "min_size"
)

# Dont need batch interleaving for validation here as not training
#val_loader = CombinedLoader(
    #{"glyML": glyML_val_loader, "task469": task469_val_loader},
    #mode="max_size_cycle"  # Also use max_size_cycle for validation
#)
    

input dim: 513


## Define train and validation functions

In [24]:
def train_epoch(train_loader, model, glycan_encoder, protein_encoder, optimizer, 
                loss_functions, loss_weights, device):
    # Set to training mode
    glycan_encoder.train()
    protein_encoder.train()
    model.train()
    
    # Setup metrics tracking
    task_metrics = {
        "glyML": {"total_loss": 0, "predictions": [], "targets": []},
        "task469": {"total_loss": 0, "predictions": [], "targets": []},
        "BindingDB": {"total_loss": 0, "predictions": [], "targets": []}
    }
    
    
    
    # Initialize the iterator for the combined loader
    train_iterator = iter(train_loader)
    
    # Get the total number of batches (now that iterator is initialized)
    total_batches = len(train_loader)
    
    # Progress bar for the combined loader
    pbar = tqdm(range(total_batches), desc='Training')
    
    # Iterate through batches from the combined loader
    for _ in pbar:

        batch_dict, batch_idx, dataloader_idx = next(train_iterator)
        all_tasks_loss = 0
        batch_losses = {}
        
        # Process each task's batch
        for task_name, batch in batch_dict.items():
            # Get task ID (0 for task469, 1 for glyML)
            task_id = batch["task_id"]
            # convert that baby to int
            task_id = task_id[0].item()
            
            # Move data to device
            glycan_encoding = batch["glycan_encoding"].to(device)
            protein_encoding = batch["protein_encoding"].to(device)
            concentration = batch["concentration"].to(device)
            targets = batch["target"].to(device)
            
            ## Apply log transform if configured
            #if config.log_predict:
                #targets = torch.log(targets + 1e-6)
            
            
            # Concatenate inputs for the multi-task model
            combined_input = torch.cat([glycan_encoding, protein_encoding, concentration], dim=-1)
            
            # Forward pass through the multi-task model
            predictions = model(combined_input, task_id)
            
            # apply activation func if not None
            if activation_functions[task_id] is not None:
                predictions = activation_functions[task_id](predictions)
                
            # Calculate loss
            loss_fn = loss_functions[task_id]
            loss = loss_fn(predictions, targets) * loss_weights[task_id]
            
            
        
            all_tasks_loss += loss
            batch_losses[task_name] = loss.item()
            
            # Revert log transform for metrics calculation
            #if config.log_predict:
                #predictions = torch.exp(predictions) - 1e-6
                #targets = torch.exp(targets) - 1e-6
            
            
            # Store predictions and targets for metrics calculation
            task_metrics[task_name]["predictions"].append(predictions.detach())
            task_metrics[task_name]["targets"].append(targets.detach())
            task_metrics[task_name]["total_loss"] += loss.item()
        
        # Optimization step for all tasks together
        optimizer.zero_grad()
        all_tasks_loss.backward()
        optimizer.step()
        
        # Update progress bar with current losses
        pbar.set_postfix({
            f"{task}_loss": f"{loss:.4f}" for task, loss in batch_losses.items()
        })
    
    # Calculate final metrics for each task
    final_metrics = {}
    for task_name, data in task_metrics.items():
        if data["predictions"]:  # Check if we have any predictions
            epoch_predictions = torch.cat(data["predictions"])
            epoch_targets = torch.cat(data["targets"])
            task_result = calculate_metrics(epoch_predictions, epoch_targets)
            task_result["loss"] = data["total_loss"]
            final_metrics[task_name] = task_result
    
    return final_metrics

# Define validation epoch function
def validate_epoch(val_loader, model, glycan_encoder, protein_encoder, 
                  loss_functions, loss_weights, device):
    # Set to evaluation mode
    glycan_encoder.eval()
    protein_encoder.eval()
    model.eval()
    
    # Setup metrics tracking
    task_metrics = {
        "glyML": {"total_loss": 0, "predictions": [], "targets": []},
        "task469": {"total_loss": 0, "predictions": [], "targets": []}
    }
    
    # Initialize the iterator for the combined loader
    val_iterator = iter(val_loader)
    
    # Get the total number of batches (now that iterator is initialized)
    total_batches = len(val_loader)
    
    # Progress bar for the combined loader
    pbar = tqdm(range(total_batches), desc='Validating')
    
    
    #with torch.no_grad():
    # Iterate through batches from the combined loader
    for _ in pbar:
        batch_dict, batch_idx, dataloader_idx = next(val_iterator)
        batch_losses = {}
        
        # Process each task's batch
        for task_name, batch in batch_dict.items():
            # Get task ID
            task_id = batch["task_id"]
            
            task_id = task_id[0].item()
            
            # Move data to device
            glycan_encoding = batch["glycan_encoding"].to(device)
            protein_encoding = batch["protein_encoding"].to(device)
            concentration = batch["concentration"].to(device)
            targets = batch["target"].to(device)
            
            # Apply log transform if configured
            #if config.log_predict:
                #targets = torch.log(targets + 1e-6)
            

            
            # Concatenate inputs for the multi-task model
            combined_input = torch.cat([glycan_encoding, protein_encoding, concentration], dim=-1)
            
            # Forward pass through the multi-task model
            predictions = model(combined_input, task_id)
            
            # apply activation func if not None
            if activation_functions[task_id] is not None:
                predictions = activation_functions[task_id](predictions)
                
            # Calculate loss
            loss_fn = loss_functions[task_id]
            loss = loss_fn(predictions, targets) * loss_weights[task_id]
            
            # Revert log transform for metrics calculation
            #if config.log_predict:
                #predictions = torch.exp(predictions) - 1e-6
                #targets = torch.exp(targets) - 1e-6

            # Store predictions and targets for metrics calculation
            task_metrics[task_name]["predictions"].append(predictions.detach())
            task_metrics[task_name]["targets"].append(targets.detach())
            task_metrics[task_name]["total_loss"] += loss.item()
        
        # Update progress bar with current losses
        pbar.set_postfix({
            f"{task}_loss": f"{loss:.4f}" for task, loss in batch_losses.items()
        })
    
    # Calculate final metrics for each task
    final_metrics = {}
    for task_name, data in task_metrics.items():
        if data["predictions"]:  # Check if we have any predictions
            epoch_predictions = torch.cat(data["predictions"])
            epoch_targets = torch.cat(data["targets"])
            task_result = calculate_metrics(epoch_predictions, epoch_targets)
            task_result["loss"] = data["total_loss"]
            final_metrics[task_name] = task_result
    
    return final_metrics

In [20]:
def _validate(val_loader: DataLoader, task_id: int, device) -> Dict[str, float]:
        # Set to evaluation mode
    glycan_encoder.eval()
    protein_encoder.eval()
    model.eval()
    
    total_loss = 0
    all_predictions = []
    all_targets = []
    
    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validating')
        for batch in pbar:
            glycan_encoding = batch['glycan_encoding'].to(device)
            protein_encoding = batch['protein_encoding'].to(device)
            concentration = batch['concentration'].to(device)
            targets = batch['target'].to(device)
            
            #if self.config.log_predict:
                #targets = torch.log(targets + 1e-6) #torch.log1p(targets)
            
            # Concatenate inputs for the multi-task model
            combined_input = torch.cat([glycan_encoding, protein_encoding, concentration], dim=-1)
            
            # Forward pass through the multi-task model
            predictions = model(combined_input, task_id)
            
            # apply activation func if not None
            if activation_functions[task_id] is not None:
                predictions = activation_functions[task_id](predictions)
                
            # Calculate loss
            loss_fn = loss_functions[task_id]
            loss = loss_fn(predictions, targets) * loss_weights[task_id]
            
            # track totals
            total_loss += loss.item()
            all_predictions.append(predictions)
            all_targets.append(targets)
            
            # Update progress bar with current loss
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})
    
    # save metrics
    val_predictions = torch.cat(all_predictions)
    val_targets = torch.cat(all_targets)
    metrics = calculate_metrics(val_predictions, val_targets)
    metrics['loss'] = total_loss #/ len(val_loader)
    
    return metrics

## TRAIN

In [29]:
epochs = 3
for epoch in range(epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    
    # Training phase
    train_metrics = train_epoch(
        train_loader, model, glycan_encoder, protein_encoder, optimizer, 
        loss_functions, loss_weights, device
    )
    

    
    
    # Validation phase
    #val_metrics = validate_epoch(
        #val_loader, model, glycan_encoder, protein_encoder, 
        #loss_functions, loss_weights, device
    #)
    
    val_metr = _validate(task469_val_loader, task_id=0, device=device)
    
    
    print(f"Training Results:")
    for task_name, metrics in train_metrics.items():
        metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
        print(f"  {task_name}: {metrics_str}")
    print('\n')
    
    
    
    print(val_metr)
    
    
    #print(f"Validation Results:")
    #for task_name, metrics in val_metrics.items():
        #metrics_str = ", ".join([f"{k}: {v:.4f}" for k, v in metrics.items()])
        #print(f"  {task_name}: {metrics_str}")



Epoch 1/3


Training: 100%|██████████| 125/125 [00:03<00:00, 41.48it/s, glyML_loss=0.1116, task469_loss=0.0154, BindingDB_loss=0.2256]
Validating: 100%|██████████| 252/252 [00:00<00:00, 460.26it/s, loss=0.0083]


Training Results:
  glyML: mse: 0.9038, pearson: 0.0122, loss: 112.9788
  task469: mse: 0.0448, pearson: -0.0087, loss: 5.5969
  BindingDB: mse: 0.2516, pearson: 0.2931, loss: 31.4534


{'mse': 0.006934016942977905, 'pearson': 0.02703520655632019, 'loss': 1.7478027993347496}
Epoch 2/3


Training: 100%|██████████| 125/125 [00:02<00:00, 60.43it/s, glyML_loss=0.5667, task469_loss=0.0038, BindingDB_loss=0.1371]
Validating: 100%|██████████| 252/252 [00:01<00:00, 193.26it/s, loss=0.0002]


Training Results:
  glyML: mse: 0.8235, pearson: 0.0076, loss: 102.9425
  task469: mse: 0.0131, pearson: 0.0532, loss: 1.6408
  BindingDB: mse: 0.1823, pearson: 0.3373, loss: 22.7839


{'mse': 0.004997290670871735, 'pearson': 0.08880670368671417, 'loss': 1.2578277978755068}
Epoch 3/3


Training: 100%|██████████| 125/125 [00:02<00:00, 43.62it/s, glyML_loss=0.3614, task469_loss=0.0009, BindingDB_loss=0.1893]
Validating: 100%|██████████| 252/252 [00:00<00:00, 421.43it/s, loss=0.0100]


Training Results:
  glyML: mse: 0.6901, pearson: 0.0031, loss: 86.2608
  task469: mse: 0.0071, pearson: 0.0255, loss: 0.8864
  BindingDB: mse: 0.1577, pearson: 0.3626, loss: 19.7086


{'mse': 0.00487586110830307, 'pearson': 0.051806166768074036, 'loss': 1.2303182570758509}
