<a href="https://colab.research.google.com/github/BtissamBalmane/FM_FunPDBe/blob/master/EnzymeGNN_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install torch-geometric
!pip install matplotlib numpy pandas scikit-learn
!pip install biopython
!pip install networkx


Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [5]:
# Mount Google Drive to store data and results
from google.colab import drive
drive.mount('/content/drive')

# Create directories for the project
!mkdir -p /content/drive/MyDrive/EnzymeGNN/data
!mkdir -p /content/drive/MyDrive/EnzymeGNN/models
!mkdir -p /content/drive/MyDrive/EnzymeGNN/results
!mkdir -p /content/drive/MyDrive/EnzymeGNN/visualizations


Mounted at /content/drive


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeometricMessagePassing(nn.Module):
    """
    Geometric message passing layer that explicitly encodes distances and angles.
    """
    def __init__(self, in_channels, out_channels, edge_dim=None):
        super(GeometricMessagePassing, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Message function
        self.message_fn = nn.Sequential(
            nn.Linear(in_channels * 2 + (edge_dim or 0), out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )

        # Update function
        self.update_fn = nn.Sequential(
            nn.Linear(in_channels + out_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels)
        )

        # Distance encoding
        self.distance_encoder = nn.Sequential(
            nn.Linear(1, 16),
            nn.ReLU(),
            nn.Linear(16, 32)
        )

        # Angle encoding (if available)
        self.angle_encoder = nn.Sequential(
            nn.Linear(2, 16),  # sin and cos of angle
            nn.ReLU(),
            nn.Linear(16, 32)
        )

    def forward(self, x, edge_index, edge_attr=None, angles=None):
        """
        Forward pass of the geometric message passing layer.
        """
        # Extract source and target nodes
        src, dst = edge_index

        # Compute messages
        src_features = x[src]
        dst_features = x[dst]

        # Concatenate source and destination features
        if edge_attr is not None:
            message_inputs = torch.cat([src_features, dst_features, edge_attr], dim=1)
        else:
            message_inputs = torch.cat([src_features, dst_features], dim=1)

        # Apply message function
        messages = self.message_fn(message_inputs)

        # Aggregate messages (sum)
        aggr_messages = torch.zeros_like(x[:, :self.out_channels])
        aggr_messages.index_add_(0, dst, messages)

        # Update node features
        update_inputs = torch.cat([x, aggr_messages], dim=1)
        updated_features = self.update_fn(update_inputs)

        return updated_features


class ActiveSiteAttention(nn.Module):
    """
    Attention mechanism that focuses on functionally important regions (active sites).
    """
    def __init__(self, in_channels, hidden_dim=64):
        super(ActiveSiteAttention, self).__init__()

        # Active site score prediction
        self.active_site_predictor = nn.Sequential(
            nn.Linear(in_channels, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Attention network
        self.attention_network = nn.Sequential(
            nn.Linear(in_channels * 2 + 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x, edge_index):
        """
        Forward pass of the active site attention mechanism.
        """
        # Predict active site scores
        active_site_scores = self.active_site_predictor(x)

        # Extract source and target nodes
        src, dst = edge_index

        # Compute attention inputs
        src_features = x[src]
        dst_features = x[dst]
        src_scores = active_site_scores[src]
        dst_scores = active_site_scores[dst]

        # Concatenate features and scores
        attention_inputs = torch.cat([src_features, dst_features, src_scores, dst_scores], dim=1)

        # Compute attention weights
        attention_logits = self.attention_network(attention_inputs)

        # Apply softmax over destination nodes
        attention_weights = []
        for i in range(x.size(0)):
            mask = (dst == i)
            if mask.sum() > 0:
                node_logits = attention_logits[mask]
                node_weights = F.softmax(node_logits, dim=0)
                attention_weights.append(node_weights)

        return torch.cat(attention_weights, dim=0), active_site_scores


class SequenceStructureFusion(nn.Module):
    """
    Bidirectional fusion module for integrating sequence and structure information.
    """
    def __init__(self, seq_dim, struct_dim, fusion_dim):
        super(SequenceStructureFusion, self).__init__()

        # Sequence to structure attention
        self.seq_to_struct_attn = nn.MultiheadAttention(
            embed_dim=seq_dim,
            num_heads=4,
            batch_first=True
        )

        # Structure to sequence attention
        self.struct_to_seq_attn = nn.MultiheadAttention(
            embed_dim=struct_dim,
            num_heads=4,
            batch_first=True
        )

        # Gating network
        self.gate = nn.Sequential(
            nn.Linear(seq_dim + struct_dim, fusion_dim),
            nn.Sigmoid()
        )

        # Output projection
        self.output_proj = nn.Linear(fusion_dim, fusion_dim)

    def forward(self, seq_features, struct_features):
        """
        Forward pass of the sequence-structure fusion module.
        """
        # Sequence to structure attention
        seq_to_struct, _ = self.seq_to_struct_attn(
            query=seq_features,
            key=struct_features,
            value=struct_features
        )

        # Structure to sequence attention
        struct_to_seq, _ = self.struct_to_seq_attn(
            query=struct_features,
            key=seq_features,
            value=seq_features
        )

        # Global pooling
        seq_to_struct_global = seq_to_struct.mean(dim=1)
        struct_to_seq_global = struct_to_seq.mean(dim=1)

        # Compute gating values
        gate_input = torch.cat([seq_to_struct_global, struct_to_seq_global], dim=1)
        gate_values = self.gate(gate_input)

        # Weighted combination
        fused_features = gate_values * seq_to_struct_global + (1 - gate_values) * struct_to_seq_global

        # Final projection
        output = self.output_proj(fused_features)

        return output


class HierarchicalPrediction(nn.Module):
    """
    Hierarchical prediction framework for EC number classification.
    """
    def __init__(self, input_dim, hidden_dim=128, num_ec_level1=6, num_ec_level2=65,
                 num_ec_level3=50, num_ec_level4=207):
        super(HierarchicalPrediction, self).__init__()

        # Level 1 prediction
        self.level1_predictor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_ec_level1)
        )

        # Level 2 prediction (conditioned on level 1)
        self.level2_predictor = nn.Sequential(
            nn.Linear(input_dim + num_ec_level1, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_ec_level2)
        )

        # Level 3 prediction (conditioned on levels 1 and 2)
        self.level3_predictor = nn.Sequential(
            nn.Linear(input_dim + num_ec_level1 + num_ec_level2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_ec_level3)
        )

        # Level 4 prediction (conditioned on levels 1, 2, and 3)
        self.level4_predictor = nn.Sequential(
            nn.Linear(input_dim + num_ec_level1 + num_ec_level2 + num_ec_level3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_ec_level4)
        )

        # Thermostability prediction
        self.thermo_predictor = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1)
        )

    def forward(self, x):
        """
        Forward pass of the hierarchical prediction framework.
        """
        # Level 1 prediction
        level1_logits = self.level1_predictor(x)
        level1_probs = F.softmax(level1_logits, dim=1)

        # Level 2 prediction (conditioned on level 1)
        level2_input = torch.cat([x, level1_probs], dim=1)
        level2_logits = self.level2_predictor(level2_input)
        level2_probs = F.softmax(level2_logits, dim=1)

        # Level 3 prediction (conditioned on levels 1 and 2)
        level3_input = torch.cat([x, level1_probs, level2_probs], dim=1)
        level3_logits = self.level3_predictor(level3_input)
        level3_probs = F.softmax(level3_logits, dim=1)

        # Level 4 prediction (conditioned on levels 1, 2, and 3)
        level4_input = torch.cat([x, level1_probs, level2_probs, level3_probs], dim=1)
        level4_logits = self.level4_predictor(level4_input)
        level4_probs = F.softmax(level4_logits, dim=1)

        # Thermostability prediction
        thermo_pred = self.thermo_predictor(x)

        return {
            'level1': level1_logits,
            'level2': level2_logits,
            'level3': level3_logits,
            'level4': level4_logits,
            'thermo': thermo_pred
        }

class EnzymeGNN(nn.Module):
    """
    EnzymeGNN: A multi-scale geometric graph neural network for enzyme function prediction.
    """
    def __init__(self, atom_feature_dim=74, residue_feature_dim=54, hidden_dim=128,
                 num_layers=6, fusion_dim=256):
        super(EnzymeGNN, self).__init__()

        # Atom-level embedding
        self.atom_embedding = nn.Linear(atom_feature_dim, hidden_dim)

        # Residue-level embedding
        self.residue_embedding = nn.Linear(residue_feature_dim, hidden_dim)

        # Atom-level message passing layers
        self.atom_layers = nn.ModuleList([
            GeometricMessagePassing(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                edge_dim=32  # Distance + angle features
            ) for _ in range(num_layers)
        ])

        # Residue-level message passing layers
        self.residue_layers = nn.ModuleList([
            GeometricMessagePassing(
                in_channels=hidden_dim,
                out_channels=hidden_dim,
                edge_dim=32  # Distance + angle features
            ) for _ in range(num_layers)
        ])

        # Active site attention mechanism
        self.active_site_attention = ActiveSiteAttention(
            in_channels=hidden_dim
        )

        # Atom to residue pooling
        self.atom_to_residue_pool = nn.Linear(hidden_dim, hidden_dim)

        # Residue to atom unpooling
        self.residue_to_atom_unpool = nn.Linear(hidden_dim, hidden_dim)

        # Sequence-structure fusion
        self.seq_struct_fusion = SequenceStructureFusion(
            seq_dim=hidden_dim,
            struct_dim=hidden_dim,
            fusion_dim=fusion_dim
        )

        # Hierarchical prediction
        self.hierarchical_prediction = HierarchicalPrediction(
            input_dim=fusion_dim
        )

        # Readout function (global pooling)
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, data):
        """
        Forward pass of EnzymeGNN.
        """
        # Extract input data
        atom_features = data['atom_features']
        atom_edge_index = data['atom_edge_index']
        atom_edge_attr = data.get('atom_edge_attr')

        residue_features = data['residue_features']
        residue_edge_index = data['residue_edge_index']
        residue_edge_attr = data.get('residue_edge_attr')

        atom_to_residue_mapping = data['atom_to_residue_mapping']

        # Handle sequence_features which might be a list
        sequence_features = data['sequence_features']
        if isinstance(sequence_features, list):
            # Convert list to tensor if needed
            if len(sequence_features) > 0 and isinstance(sequence_features[0], torch.Tensor):
                sequence_features = torch.stack(sequence_features)
            else:
                # Create a dummy tensor if sequence_features can't be converted
                sequence_features = torch.zeros((1, residue_features.size(0), 20), device=residue_features.device)

        # Debug prints
        print("Inside model - atom_features shape:", atom_features.shape)
        print("Inside model - residue_features shape:", residue_features.shape)
        print("Inside model - sequence_features shape:",
              sequence_features.shape if isinstance(sequence_features, torch.Tensor) else "not a tensor")

        # Initial embeddings
        atom_h = self.atom_embedding(atom_features)
        residue_h = self.residue_embedding(residue_features)

        # Atom-level message passing
        for layer in self.atom_layers:
            atom_h = layer(atom_h, atom_edge_index, atom_edge_attr)

        # Apply active site attention at atom level
        atom_attention_weights, atom_active_site_scores = self.active_site_attention(
            atom_h, atom_edge_index
        )

        # Pool atom features to residue level
        pooled_atom_h = torch.zeros_like(residue_h)
        for atom_idx, residue_idx in enumerate(atom_to_residue_mapping):
            pooled_atom_h[residue_idx] += atom_h[atom_idx]

        # Combine pooled atom features with residue features
        residue_h = residue_h + self.atom_to_residue_pool(pooled_atom_h)

        # Residue-level message passing
        for layer in self.residue_layers:
            residue_h = layer(residue_h, residue_edge_index, residue_edge_attr)

        # Apply active site attention at residue level
        residue_attention_weights, residue_active_site_scores = self.active_site_attention(
            residue_h, residue_edge_index
        )

        # Unpool residue features to atom level
        unpooled_residue_h = self.residue_to_atom_unpool(residue_h[atom_to_residue_mapping])

        # Combine unpooled residue features with atom features
        atom_h = atom_h + unpooled_residue_h

        # Global pooling for structure representation - simplify for debugging
        structure_features = torch.mean(residue_h, dim=0, keepdim=True).unsqueeze(0)  # [1, 1, hidden_dim]

        # For debugging, use a simplified approach
        # Skip the sequence-structure fusion and use only structure features
        fused_features = structure_features.squeeze(0)  # [1, hidden_dim]

        # Debug prints for fusion
        print("Structure features shape:", structure_features.shape)
        print("Fused features shape:", fused_features.shape)

        # Hierarchical prediction
        predictions = self.hierarchical_prediction(fused_features)

        # Add attention weights and active site scores to predictions
        predictions['atom_attention'] = atom_attention_weights
        predictions['residue_attention'] = residue_attention_weights
        predictions['atom_active_sites'] = atom_active_site_scores
        predictions['residue_active_sites'] = residue_active_site_scores

        return predictions



# Loss function for multi-task learning
class EnzymeGNNLoss(nn.Module):
    """
    Multi-task loss function for EnzymeGNN.
    """
    def __init__(self, ec_weights=[1.0, 1.0, 1.0, 1.0], thermo_weight=1.0, active_site_weight=0.5):
        super(EnzymeGNNLoss, self).__init__()
        self.ec_weights = ec_weights
        self.thermo_weight = thermo_weight
        self.active_site_weight = active_site_weight

        # Loss functions
        self.ec_loss_fn = nn.CrossEntropyLoss()
        self.thermo_loss_fn = nn.MSELoss()
        self.active_site_loss_fn = nn.BCELoss()

    def forward(self, predictions, targets):
        """
        Compute the multi-task loss.
        """
        # EC number classification losses
        ec_level1_loss = self.ec_loss_fn(predictions['level1'], targets['level1'])
        ec_level2_loss = self.ec_loss_fn(predictions['level2'], targets['level2'])
        ec_level3_loss = self.ec_loss_fn(predictions['level3'], targets['level3'])
        ec_level4_loss = self.ec_loss_fn(predictions['level4'], targets['level4'])

        # Thermostability regression loss
        thermo_loss = self.thermo_loss_fn(predictions['thermo'], targets['thermo'])

        # Active site prediction loss (if ground truth is available)
        active_site_loss = 0.0
        if 'active_sites' in targets:
            atom_active_site_loss = self.active_site_loss_fn(
                predictions['atom_active_sites'],
                targets['active_sites'][targets['atom_to_residue_mapping']]
            )
            residue_active_site_loss = self.active_site_loss_fn(
                predictions['residue_active_sites'],
                targets['active_sites']
            )
            active_site_loss = atom_active_site_loss + residue_active_site_loss

        # Weighted sum of losses
        total_loss = (
            self.ec_weights[0] * ec_level1_loss +
            self.ec_weights[1] * ec_level2_loss +
            self.ec_weights[2] * ec_level3_loss +
            self.ec_weights[3] * ec_level4_loss +
            self.thermo_weight * thermo_loss +
            self.active_site_weight * active_site_loss
        )

        # Return total loss and individual components
        return {
            'total': total_loss,
            'ec_level1': ec_level1_loss,
            'ec_level2': ec_level2_loss,
            'ec_level3': ec_level3_loss,
            'ec_level4': ec_level4_loss,
            'thermo': thermo_loss,
            'active_site': active_site_loss
        }


In [16]:
import torch
import numpy as np
from Bio.PDB import PDBParser, Selection
import os
import pandas as pd

class ProteinGraphDataset:
    """
    Dataset for processing protein structures into multi-scale graphs.
    """
    def __init__(self, pdb_dir, csv_file=None, transform=None):
        """
        Initialize the dataset.
        """
        self.pdb_dir = pdb_dir
        self.transform = transform

        # Load metadata if available
        if csv_file is not None and os.path.exists(csv_file):
            self.metadata = pd.read_csv(csv_file)
            self.pdb_files = [f"{id}.pdb" for id in self.metadata['protein_id']]
        else:
            self.metadata = None
            self.pdb_files = [f for f in os.listdir(pdb_dir) if f.endswith('.pdb')]

        # PDB parser
        self.parser = PDBParser(QUIET=True)

    def __len__(self):
        return len(self.pdb_files)

    def __getitem__(self, idx):
        # Get PDB file path
        pdb_file = os.path.join(self.pdb_dir, self.pdb_files[idx])

        # Get metadata if available
        if self.metadata is not None:
            protein_id = self.metadata.iloc[idx]['protein_id']
            ec_numbers = self.metadata.iloc[idx]['ec_numbers'].split('.')
            thermostability = self.metadata.iloc[idx]['thermostability']
            sequence = self.metadata.iloc[idx]['sequence']
        else:
            protein_id = self.pdb_files[idx].split('.')[0]
            ec_numbers = None
            thermostability = None
            sequence = None

        # Process PDB file
        data = self.process_pdb(pdb_file, protein_id, sequence)

        # Add targets if available
        if ec_numbers is not None:
            data['targets'] = {
                'level1': torch.tensor(int(ec_numbers[0])),
                'level2': torch.tensor(int(ec_numbers[1])),
                'level3': torch.tensor(int(ec_numbers[2])),
                'level4': torch.tensor(int(ec_numbers[3])),
                'thermo': torch.tensor(thermostability, dtype=torch.float)
            }

        # Apply transform if available
        if self.transform is not None:
            data = self.transform(data)

        return data

    def process_pdb(self, pdb_file, protein_id, sequence=None):
        """
        Process a PDB file into a multi-scale graph.
        """
        # Parse PDB file
        structure = self.parser.get_structure(protein_id, pdb_file)
        model = structure[0]  # First model

        # Extract atoms and residues
        atoms = list(model.get_atoms())
        residues = list(model.get_residues())

        # Compute atom features
        atom_features = []
        atom_positions = []
        atom_to_residue_mapping = []

        for atom_idx, atom in enumerate(atoms):
            # Get atom features
            atom_feat = self.get_atom_features(atom)
            atom_features.append(atom_feat)

            # Get atom position
            atom_pos = atom.get_coord()
            atom_positions.append(atom_pos)

            # Map atom to residue
            residue_idx = residues.index(atom.get_parent())
            atom_to_residue_mapping.append(residue_idx)

        # Convert to tensors
        atom_features = torch.tensor(np.array(atom_features), dtype=torch.float)
        atom_positions = torch.tensor(np.array(atom_positions), dtype=torch.float)
        atom_to_residue_mapping = torch.tensor(atom_to_residue_mapping, dtype=torch.long)

        # Compute residue features
        residue_features = []
        residue_positions = []

        for residue_idx, residue in enumerate(residues):
            # Get residue features
            residue_feat = self.get_residue_features(residue)
            residue_features.append(residue_feat)

            # Get residue position (CA atom)
            ca_atom = residue['CA'] if 'CA' in residue else next(residue.get_atoms())
            residue_pos = ca_atom.get_coord()
            residue_positions.append(residue_pos)

        # Convert to tensors
        residue_features = torch.tensor(np.array(residue_features), dtype=torch.float)
        residue_positions = torch.tensor(np.array(residue_positions), dtype=torch.float)

        # Compute atom-level graph connectivity
        atom_edge_index, atom_edge_attr = self.compute_graph_connectivity(
            atom_positions, cutoff=4.5
        )

        # Compute residue-level graph connectivity
        residue_edge_index, residue_edge_attr = self.compute_graph_connectivity(
            residue_positions, cutoff=10.0
        )

        # Get sequence features if available
        if sequence is not None:
            sequence_features = self.get_sequence_features(sequence)
        else:
            # Extract sequence from residues
            sequence = ''.join([self.get_residue_code(r) for r in residues])
            sequence_features = self.get_sequence_features(sequence)

        # Create data dictionary
        data = {
            'protein_id': protein_id,
            'atom_features': atom_features,
            'atom_positions': atom_positions,
            'atom_edge_index': atom_edge_index,
            'atom_edge_attr': atom_edge_attr,
            'residue_features': residue_features,
            'residue_positions': residue_positions,
            'residue_edge_index': residue_edge_index,
            'residue_edge_attr': residue_edge_attr,
            'atom_to_residue_mapping': atom_to_residue_mapping,
            'sequence': sequence,
            'sequence_features': sequence_features
        }

        return data

    def get_atom_features(self, atom):
        """
        Extract features for an atom.
        """
        # Atom type (one-hot encoding)
        atom_types = ['C', 'N', 'O', 'S', 'P', 'H', 'F', 'Cl', 'Br', 'I']
        atom_type_onehot = [1 if atom.element == t else 0 for t in atom_types]

        # Hybridization (estimated from atom name and residue)
        hybridization = self.estimate_hybridization(atom)

        # Partial charge (estimated)
        partial_charge = self.estimate_partial_charge(atom)

        # Is in aromatic ring
        is_aromatic = self.is_in_aromatic_ring(atom)

        # Is backbone atom
        is_backbone = atom.name in ['N', 'CA', 'C', 'O']

        # Combine features
        features = atom_type_onehot + [hybridization, partial_charge, is_aromatic, is_backbone]

        return features

    def estimate_hybridization(self, atom):
        """Estimate atom hybridization (simplified)"""
        if atom.element == 'C':
            # Count bonds (simplified)
            parent = atom.get_parent()
            neighbors = [a for a in parent.get_atoms() if a != atom]
            num_neighbors = len([n for n in neighbors if self.is_bonded(atom, n)])

            if num_neighbors == 4:
                return 0  # sp3
            elif num_neighbors == 3:
                return 1  # sp2
            elif num_neighbors == 2:
                return 2  # sp
            else:
                return 0  # default to sp3
        elif atom.element == 'N':
            return 1  # default to sp2 for nitrogen
        elif atom.element == 'O':
            return 0  # default to sp3 for oxygen
        else:
            return 0  # default

    def estimate_partial_charge(self, atom):
        """Estimate atom partial charge (simplified)"""
        if atom.element == 'O':
            return -0.5
        elif atom.element == 'N':
            return -0.3
        elif atom.element == 'S':
            return -0.2
        elif atom.element == 'C':
            if atom.name == 'C':  # Carbonyl carbon
                return 0.5
            else:
                return 0.1
        else:
            return 0.0

    def is_in_aromatic_ring(self, atom):
        """Check if atom is in an aromatic ring (simplified)"""
        residue = atom.get_parent()
        residue_name = residue.get_resname()

        # Check if in aromatic residue and is part of ring
        if residue_name in ['PHE', 'TYR', 'TRP', 'HIS']:
            if atom.name in ['CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'CZ2', 'CZ3', 'CH2']:
                return 1

        return 0

    def is_bonded(self, atom1, atom2):
        """Check if two atoms are bonded (simplified)"""
        # Distance-based criterion
        distance = np.linalg.norm(atom1.get_coord() - atom2.get_coord())
        return distance < 2.0  # Typical bond length threshold

    def get_residue_features(self, residue):
        """
        Extract features for a residue.
        """
        # Residue type (one-hot encoding)
        residue_types = [
            'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
            'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL'
        ]
        residue_type_onehot = [1 if residue.get_resname() == t else 0 for t in residue_types]

        # Physicochemical properties
        properties = self.get_residue_properties(residue)

        # Secondary structure (simplified)
        ss_features = self.get_secondary_structure_features(residue)

        # Combine features
        features = residue_type_onehot + properties + ss_features

        return features

    def get_residue_properties(self, residue):
        """Get physicochemical properties of residue"""
        # Properties: hydrophobicity, charge, polarity, size
        residue_name = residue.get_resname()

        # Hydrophobicity scale (Kyte-Doolittle)
        hydrophobicity = {
            'ALA': 1.8, 'ARG': -4.5, 'ASN': -3.5, 'ASP': -3.5, 'CYS': 2.5,
            'GLN': -3.5, 'GLU': -3.5, 'GLY': -0.4, 'HIS': -3.2, 'ILE': 4.5,
            'LEU': 3.8, 'LYS': -3.9, 'MET': 1.9, 'PHE': 2.8, 'PRO': -1.6,
            'SER': -0.8, 'THR': -0.7, 'TRP': -0.9, 'TYR': -1.3, 'VAL': 4.2
        }

        # Charge at pH 7
        charge = {
            'ARG': 1, 'LYS': 1, 'ASP': -1, 'GLU': -1, 'HIS': 0.1,
            'ALA': 0, 'ASN': 0, 'CYS': 0, 'GLN': 0, 'GLY': 0,
            'ILE': 0, 'LEU': 0, 'MET': 0, 'PHE': 0, 'PRO': 0,
            'SER': 0, 'THR': 0, 'TRP': 0, 'TYR': 0, 'VAL': 0
        }

        # Polarity
        polarity = {
            'ARG': 1, 'ASN': 1, 'ASP': 1, 'GLN': 1, 'GLU': 1,
            'HIS': 1, 'LYS': 1, 'SER': 1, 'THR': 1, 'TYR': 1,
            'ALA': 0, 'CYS': 0, 'GLY': 0, 'ILE': 0, 'LEU': 0,
            'MET': 0, 'PHE': 0, 'PRO': 0, 'TRP': 0, 'VAL': 0
        }

        # Size (volume in Å³)
        size = {
            'ALA': 88.6, 'ARG': 173.4, 'ASN': 114.1, 'ASP': 111.1, 'CYS': 108.5,
            'GLN': 143.8, 'GLU': 138.4, 'GLY': 60.1, 'HIS': 153.2, 'ILE': 166.7,
            'LEU': 166.7, 'LYS': 168.6, 'MET': 162.9, 'PHE': 189.9, 'PRO': 112.7,
            'SER': 89.0, 'THR': 116.1, 'TRP': 227.8, 'TYR': 193.6, 'VAL': 140.0
        }

        # Normalize values
        hydrophobicity_norm = (hydrophobicity.get(residue_name, 0) + 4.5) / 9.0
        size_norm = size.get(residue_name, 0) / 227.8

        return [
            hydrophobicity_norm,
            charge.get(residue_name, 0),
            polarity.get(residue_name, 0),
            size_norm
        ]

    def get_secondary_structure_features(self, residue):
        """Get secondary structure features (simplified)"""
        # This is a simplified version; ideally, use DSSP
        # Here we use a heuristic based on residue type

        residue_name = residue.get_resname()

        # Propensities for different secondary structures
        helix_propensity = {
            'ALA': 1.0, 'ARG': 0.7, 'ASN': 0.5, 'ASP': 0.5, 'CYS': 0.5,
            'GLN': 0.8, 'GLU': 0.8, 'GLY': 0.0, 'HIS': 0.5, 'ILE': 0.8,
            'LEU': 0.9, 'LYS': 0.7, 'MET': 0.8, 'PHE': 0.6, 'PRO': 0.0,
            'SER': 0.4, 'THR': 0.4, 'TRP': 0.6, 'TYR': 0.6, 'VAL': 0.7
        }

        sheet_propensity = {
            'ALA': 0.5, 'ARG': 0.4, 'ASN': 0.3, 'ASP': 0.3, 'CYS': 0.6,
            'GLN': 0.4, 'GLU': 0.4, 'GLY': 0.0, 'HIS': 0.5, 'ILE': 1.0,
            'LEU': 0.7, 'LYS': 0.4, 'MET': 0.6, 'PHE': 0.8, 'PRO': 0.0,
            'SER': 0.3, 'THR': 0.5, 'TRP': 0.8, 'TYR': 0.8, 'VAL': 1.0
        }

        turn_propensity = {
            'ALA': 0.3, 'ARG': 0.6, 'ASN': 0.8, 'ASP': 0.8, 'CYS': 0.4,
            'GLN': 0.6, 'GLU': 0.6, 'GLY': 1.0, 'HIS': 0.5, 'ILE': 0.2,
            'LEU': 0.3, 'LYS': 0.6, 'MET': 0.3, 'PHE': 0.3, 'PRO': 1.0,
            'SER': 0.7, 'THR': 0.6, 'TRP': 0.3, 'TYR': 0.3, 'VAL': 0.2
        }

        return [
            helix_propensity.get(residue_name, 0.5),
            sheet_propensity.get(residue_name, 0.5),
            turn_propensity.get(residue_name, 0.5)
        ]

    def get_residue_code(self, residue):
        """Convert residue to one-letter code"""
        residue_name = residue.get_resname()
        code_map = {
            'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
            'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
            'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
            'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V'
        }
        return code_map.get(residue_name, 'X')

    def get_sequence_features(self, sequence):
        """
        Extract features from protein sequence.
        """
        # One-hot encoding of amino acids
        aa_types = 'ACDEFGHIKLMNPQRSTVWY'

        # Initialize features
        seq_features = []

        for aa in sequence:
            # One-hot encoding
            aa_onehot = [1 if aa == t else 0 for t in aa_types]
            seq_features.append(aa_onehot)

        # Convert to tensor
        seq_features = torch.tensor(np.array(seq_features), dtype=torch.float)

        return seq_features

    def compute_graph_connectivity(self, positions, cutoff):
        """
        Compute graph connectivity based on distance cutoff.
        """
        # Convert to numpy for efficient computation
        positions_np = positions.numpy()

        # Compute pairwise distances
        num_nodes = positions_np.shape[0]
        edges = []
        edge_features = []

        for i in range(num_nodes):
            for j in range(num_nodes):
                if i != j:
                    # Compute distance
                    distance = np.linalg.norm(positions_np[i] - positions_np[j])

                    # Add edge if within cutoff
                    if distance < cutoff:
                        edges.append([i, j])

                        # Compute edge features
                        direction = positions_np[j] - positions_np[i]
                        unit_vec = direction / distance

                        # Edge features: distance and unit vector
                        edge_feat = [distance] + unit_vec.tolist()
                        edge_features.append(edge_feat)

        # Convert to tensors
        edge_index = torch.tensor(edges, dtype=torch.long).t()
        edge_attr = torch.tensor(edge_features, dtype=torch.float)

        return edge_index, edge_attr


# Example usage
def create_dataloader(pdb_dir, csv_file=None, batch_size=32, shuffle=True):
    """
    Create a dataloader for protein graphs.
    """
    # Create dataset
    dataset = ProteinGraphDataset(pdb_dir, csv_file)

    # Create dataloader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_protein_graphs
    )

    return dataloader
def collate_protein_graphs(batch):
    """
    Collate function for protein graphs.
    """
    # Initialize batched data
    batched_data = {
        'protein_ids': [],
        'atom_features': [],
        'atom_positions': [],
        'atom_edge_index': [],
        'atom_edge_attr': [],
        'residue_features': [],
        'residue_positions': [],
        'residue_edge_index': [],
        'residue_edge_attr': [],
        'atom_to_residue_mapping': [],
        'sequence_features': [],
        'targets': {
            'level1': [],
            'level2': [],
            'level3': [],
            'level4': [],
            'thermo': []
        }
    }

    # Cumulative number of atoms and residues
    cum_atoms = 0
    cum_residues = 0

    # Process each sample
    for data in batch:
        # Add protein ID
        batched_data['protein_ids'].append(data['protein_id'])

        # Add atom features and positions
        batched_data['atom_features'].append(data['atom_features'])
        batched_data['atom_positions'].append(data['atom_positions'])

        # Add residue features and positions
        batched_data['residue_features'].append(data['residue_features'])
        batched_data['residue_positions'].append(data['residue_positions'])

        # Add sequence features - store as is, will handle in model
        if 'sequence_features' in data:
            batched_data['sequence_features'].append(data['sequence_features'])

        # Add atom edge index and attributes with offset
        if cum_atoms > 0:
            atom_edge_index = data['atom_edge_index'] + cum_atoms
        else:
            atom_edge_index = data['atom_edge_index']
        batched_data['atom_edge_index'].append(atom_edge_index)
        batched_data['atom_edge_attr'].append(data['atom_edge_attr'])

        # Add residue edge index and attributes with offset
        if cum_residues > 0:
            residue_edge_index = data['residue_edge_index'] + cum_residues
        else:
            residue_edge_index = data['residue_edge_index']
        batched_data['residue_edge_index'].append(residue_edge_index)
        batched_data['residue_edge_attr'].append(data['residue_edge_attr'])

        # Add atom to residue mapping with offset
        atom_to_residue_mapping = data['atom_to_residue_mapping'] + cum_residues
        batched_data['atom_to_residue_mapping'].append(atom_to_residue_mapping)

        # Add targets if available
        if 'targets' in data:
            for level in ['level1', 'level2', 'level3', 'level4', 'thermo']:
                batched_data['targets'][level].append(data['targets'][level])

        # Update cumulative counts
        cum_atoms += data['atom_features'].shape[0]
        cum_residues += data['residue_features'].shape[0]

    # Concatenate tensors
    batched_data['atom_features'] = torch.cat(batched_data['atom_features'], dim=0)
    batched_data['atom_positions'] = torch.cat(batched_data['atom_positions'], dim=0)
    batched_data['atom_edge_index'] = torch.cat(batched_data['atom_edge_index'], dim=1)
    batched_data['atom_edge_attr'] = torch.cat(batched_data['atom_edge_attr'], dim=0)

    batched_data['residue_features'] = torch.cat(batched_data['residue_features'], dim=0)
    batched_data['residue_positions'] = torch.cat(batched_data['residue_positions'], dim=0)
    batched_data['residue_edge_index'] = torch.cat(batched_data['residue_edge_index'], dim=1)
    batched_data['residue_edge_attr'] = torch.cat(batched_data['residue_edge_attr'], dim=0)

    batched_data['atom_to_residue_mapping'] = torch.cat(batched_data['atom_to_residue_mapping'], dim=0)

    # Keep sequence_features as a list - will handle in model

    # Stack targets if available
    if batched_data['targets']['level1']:
        for level in ['level1', 'level2', 'level3', 'level4']:
            batched_data['targets'][level] = torch.stack(batched_data['targets'][level])
        batched_data['targets']['thermo'] = torch.stack(batched_data['targets']['thermo'])
    else:
        batched_data.pop('targets')

    return batched_data



In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from sklearn.metrics import accuracy_score, f1_score, mean_absolute_error, mean_squared_error

def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs=50, device='cpu', checkpoint_dir='/content/drive/MyDrive/EnzymeGNN/models'):
    """
    Train the EnzymeGNN model.
    """
    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Move model to device
    model = model.to(device)

    # Initialize training history
    history = {
        'train_loss': [],
        'val_loss': [],
        'ec_level1_acc': [],
        'ec_level4_acc': [],
        'thermo_mae': []
    }

    # Best validation loss
    best_val_loss = float('inf')

    # Training loop
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            # Move data to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            targets = {k: v.to(device) for k, v in batch['targets'].items()}

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            predictions = model(batch)

            # Compute loss
            loss_dict = criterion(predictions, targets)
            loss = loss_dict['total']

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update training loss
            train_loss += loss.item()

            # Print progress
            if (batch_idx + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

        # Average training loss
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)

        # Validation phase
        model.eval()
        val_loss = 0.0
        ec_level1_preds = []
        ec_level1_targets = []
        ec_level4_preds = []
        ec_level4_targets = []
        thermo_preds = []
        thermo_targets = []

        with torch.no_grad():
            for batch in val_loader:
                # Move data to device
                batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                targets = {k: v.to(device) for k, v in batch['targets'].items()}

                # Forward pass
                predictions = model(batch)

                # Compute loss
                loss_dict = criterion(predictions, targets)
                val_loss += loss_dict['total'].item()

                # Collect predictions and targets for metrics
                ec_level1_preds.append(predictions['level1'].argmax(dim=1).cpu().numpy())
                ec_level1_targets.append(targets['level1'].cpu().numpy())

                ec_level4_preds.append(predictions['level4'].argmax(dim=1).cpu().numpy())
                ec_level4_targets.append(targets['level4'].cpu().numpy())

                thermo_preds.append(predictions['thermo'].cpu().numpy())
                thermo_targets.append(targets['thermo'].cpu().numpy())

        # Average validation loss
        val_loss /= len(val_loader)
        history['val_loss'].append(val_loss)

        # Compute metrics
        ec_level1_preds = np.concatenate(ec_level1_preds)
        ec_level1_targets = np.concatenate(ec_level1_targets)
        ec_level1_acc = accuracy_score(ec_level1_targets, ec_level1_preds)
        history['ec_level1_acc'].append(ec_level1_acc)

        ec_level4_preds = np.concatenate(ec_level4_preds)
        ec_level4_targets = np.concatenate(ec_level4_targets)
        ec_level4_acc = accuracy_score(ec_level4_targets, ec_level4_preds)
        history['ec_level4_acc'].append(ec_level4_acc)

        thermo_preds = np.concatenate(thermo_preds)
        thermo_targets = np.concatenate(thermo_targets)
        thermo_mae = mean_absolute_error(thermo_targets, thermo_preds)
        history['thermo_mae'].append(thermo_mae)

        # Print epoch results
        print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')
        print(f'EC Level 1 Acc: {ec_level1_acc:.4f}, EC Level 4 Acc: {ec_level4_acc:.4f}, Thermo MAE: {thermo_mae:.4f}')

        # Save checkpoint if validation loss improved
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            checkpoint_path = os.path.join(checkpoint_dir, f'enzyme_gnn_epoch_{epoch+1}.pt')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'ec_level1_acc': ec_level1_acc,
                'ec_level4_acc': ec_level4_acc,
                'thermo_mae': thermo_mae
            }, checkpoint_path)
            print(f'Checkpoint saved to {checkpoint_path}')

    return model, history


def plot_training_history(history, save_dir='/content/drive/MyDrive/EnzymeGNN/visualizations'):
    """
    Plot training history.
    """
    # Create save directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # Plot training and validation loss
    plt.figure(figsize=(10, 6))
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'loss_curve.png'), dpi=300)
    plt.close()

    # Plot EC level 1 and 4 accuracy
    plt.figure(figsize=(10, 6))
    plt.plot(history['ec_level1_acc'], label='EC Level 1 Accuracy')
    plt.plot(history['ec_level4_acc'], label='EC Level 4 Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('EC Number Prediction Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'ec_accuracy.png'), dpi=300)
    plt.close()

    # Plot thermostability MAE
    plt.figure(figsize=(10, 6))
    plt.plot(history['thermo_mae'])
    plt.xlabel('Epoch')
    plt.ylabel('MAE (°C)')
    plt.title('Thermostability Prediction MAE')
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, 'thermo_mae.png'), dpi=300)
    plt.close()


In [9]:
from google.colab import files
files.upload()  # Upload your kaggle.json file

Saving kaggle (1).json to kaggle (1).json


{'kaggle (1).json': b'{"username":"btissam12","key":"3774017372de4e6b1b0f11569844a11b"}'}

In [10]:
!chmod 600 /root/.kaggle/kaggle.json

chmod: cannot access '/root/.kaggle/kaggle.json': No such file or directory


In [11]:
!pip install kaggle



In [12]:
!kaggle datasets download -d roberthatch/nesp-kvigly-test-mutation-pdbs


Traceback (most recent call last):
  File "/usr/local/bin/kaggle", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/kaggle/cli.py", line 68, in main
    out = args.func(**command_args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/kaggle/api/kaggle_api_extended.py", line 1734, in dataset_download_cli
    with self.build_kaggle_client() as kaggle:
         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/kaggle/api/kaggle_api_extended.py", line 688, in build_kaggle_client
    username=self.config_values['username'],
             ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^
KeyError: 'username'


In [13]:
!unzip nesp-kvigly-test-mutation-pdbs.zip -d /content/drive/MyDrive/EnzymeGNN/data/

unzip:  cannot find or open nesp-kvigly-test-mutation-pdbs.zip, nesp-kvigly-test-mutation-pdbs.zip.zip or nesp-kvigly-test-mutation-pdbs.zip.ZIP.


In [14]:
# Alternatively, you can use a sample dataset for testing
# Create a sample dataset with a few PDB files
import os
import urllib.request
from Bio.PDB import PDBParser, PDBIO

# Create directories
os.makedirs('/content/drive/MyDrive/EnzymeGNN/data/sample', exist_ok=True)

# Download a few sample PDB files
sample_pdbs = [
    ('1a0j', 'https://files.rcsb.org/download/1A0J.pdb') ,  # Alcohol dehydrogenase (EC 1.1.1.1)
    ('1a27', 'https://files.rcsb.org/download/1A27.pdb') ,  # Serine protease (EC 3.4.21.-)
    ('1a3w', 'https://files.rcsb.org/download/1A3W.pdb')    # Transferase (EC 2.7.1.-)
]

# Download and save PDB files
for pdb_id, url in sample_pdbs:
    pdb_file = f'/content/drive/MyDrive/EnzymeGNN/data/sample/{pdb_id}.pdb'
    if not os.path.exists(pdb_file):
        print(f'Downloading {pdb_id}...')
        urllib.request.urlretrieve(url, pdb_file)
        print(f'Saved to {pdb_file}')

# Create a sample metadata CSV file
import pandas as pd

# Sample metadata
metadata = {
    'protein_id': ['1a0j', '1a27', '1a3w'],
    'ec_numbers': ['1.1.1.1', '3.4.21.0', '2.7.1.0'],
    'thermostability': [45.2, 52.8, 38.6],
    'sequence': [
        'MKGFAMLSIGKVGWIEKEKPAPGPFDAIVRPLAVAPCTSDIHTVFEGAIGERHNMILGHEAVGEVVEVGSEVKDFKPGDRVIVPCTTPDWRSLEVQAGFQQHSNGMLAGWKFSNFKDGVFGEYFHVNDADMNLAILPKDMPLENAVMITDMMTTGFHGAELADIQMGSSVVVIGIGAVGLMGIAGAKLRGAGRIIGVGSRPICVEAAKFYGATDILNYKNGHIVDQVMKLTNGKGVDRVIMAGGGSETLSQAVSMVKPGGIISNINYHGSGDALLIPRVEWGCGMAHKTIKGGLCPGGRLRAEMLRDMVVYNRVDLSKLVTHVYHGFDHIEEALLLMKDKPKDLIKAVVIL',
        'IVGGYTCGANTVPYQVSLNSGYHFCGGSLINSQWVVSAAHCYKSGIQVRLGEDNINVVEGNEQFISASKSIVHPSYNSNTLNNDIMLIKLKSAASLNSRVASISLPTSCASAGTQCLISGWGNTKSSGTSYPDVLKCLKAPILSDSSCKSAYPGQITSNMFCAGYLEGGKDSCQGDSGGPVVCSGKLQGIVSWGSGCAQKNKPGVYTKVCNYVSWIKQTIASN',
        'MKLKGLDVVVGYSTDYLAGCNHLPWTEKLKTILRDIGFHSSRWVTQVDDGIDGLAQYIFENQLSEGLDSLKLVSVIHKDVEIVSQETVGQTLPWPKIDKFSDTPFYQRWMNFYISDEDNYLIGSNVYIGTDVGNNELTIIHTDNQERTPVVYLKGDLVWEGGNLDSLEGKQVIEHSYLDGAFYRLNSPWCNDTLFSRQRYKVNLKFPGGEHVYIAKQFVEGGGLDVVKFQNPNENFVGAAVLAKGDWQKIIDNPNVVLVDTVSGFGKDYYKVNPNELRVWDYTDVNKTPVVYLKGDLVWDGGNLDTLEGKQVIEHSYLDGAFYRLNSPWCNDTLFSRQRYKVNLKFPGGEHVYIAKQFVEGGGLDVVKFQNPNENFVGAAVLAKGDWQKIIDNPNVVLVDTVSGFGKDYYKV'
    ]
}

# Create DataFrame and save to CSV
metadata_df = pd.DataFrame(metadata)
metadata_df.to_csv('/content/drive/MyDrive/EnzymeGNN/data/sample_metadata.csv', index=False)
print('Sample metadata saved to /content/drive/MyDrive/EnzymeGNN/data/sample_metadata.csv')

Sample metadata saved to /content/drive/MyDrive/EnzymeGNN/data/sample_metadata.csv


In [None]:
def main():
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Create dataloaders
    train_loader = create_dataloader(
        pdb_dir='/content/drive/MyDrive/EnzymeGNN/data/sample',
        csv_file='/content/drive/MyDrive/EnzymeGNN/data/sample_metadata.csv',
        batch_size=2,
        shuffle=True
    )

    val_loader = create_dataloader(
        pdb_dir='/content/drive/MyDrive/EnzymeGNN/data/sample',
        csv_file='/content/drive/MyDrive/EnzymeGNN/data/sample_metadata.csv',
        batch_size=2,
        shuffle=False
    )

    # Get a sample batch to check dimensions
    for batch in train_loader:
        print("Atom features shape:", batch['atom_features'].shape)
        print("Residue features shape:", batch['residue_features'].shape)

        # Check if sequence_features is a list or tensor
        if isinstance(batch['sequence_features'], list):
            print("Sequence features is a list with length:", len(batch['sequence_features']))
            # Convert list to tensor if needed
            if len(batch['sequence_features']) > 0 and isinstance(batch['sequence_features'][0], torch.Tensor):
                # Stack tensors in the list
                sequence_tensor = torch.stack(batch['sequence_features'])
                print("Converted sequence features shape:", sequence_tensor.shape)
                # Update the batch with the tensor
                batch['sequence_features'] = sequence_tensor
            else:
                print("Sequence features list contains non-tensor elements")
        else:
            print("Sequence features shape:", batch['sequence_features'].shape)
        break

    # Create model with correct dimensions
    atom_feature_dim = batch['atom_features'].shape[1]  # Use actual dimension
    residue_feature_dim = batch['residue_features'].shape[1]  # Use actual dimension

    print(f"Using atom_feature_dim={atom_feature_dim}, residue_feature_dim={residue_feature_dim}")

    # Create model with the correct dimensions
    model = EnzymeGNN(
        atom_feature_dim=atom_feature_dim,
        residue_feature_dim=residue_feature_dim,
        hidden_dim=64,
        num_layers=3,
        fusion_dim=128
    )

    # Create loss function
    criterion = EnzymeGNNLoss(
        ec_weights=[1.0, 1.0, 1.0, 1.0],
        thermo_weight=1.0,
        active_site_weight=0.5
    )

    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train model
    model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        num_epochs=10,  # Reduced for sample
        device=device,
        checkpoint_dir='/content/drive/MyDrive/EnzymeGNN/models'
    )

    # Plot training history
    plot_training_history(history, save_dir='/content/drive/MyDrive/EnzymeGNN/visualizations')

    print('Training completed successfully!')

# Run the main function
if __name__ == '__main__':
    main()


Using device: cpu
