In [1]:
import pandas as pd
import yaml

In [2]:
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)

In [3]:
dti_pd = pd.read_parquet(config["dataPath"]["dti_data"])
adr_pd = pd.read_parquet(config["dataPath"]["adr_data"])
print("DTI Data info: ")
print(dti_pd.info(),"\n")
print("ADR Data info: ")
print(adr_pd.info())

DTI Data info: 
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 34741 entries, 0 to 34740
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   drug_chembl_id     34741 non-null  object
 1   target_uniprot_id  34741 non-null  object
 2   label              34741 non-null  int64 
 3   smiles             34741 non-null  object
 4   sequence           34741 non-null  object
 5   molfile_3d         34741 non-null  object
 6   rxcui              34741 non-null  object
dtypes: int64(1), object(6)
memory usage: 1.9+ MB
None 

ADR Data info: 
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 69474 entries, 0 to 69473
Data columns (total 3 columns):
 #   Column                Non-Null Count  Dtype 
---  ------                --------------  ----- 
 0   rxnorm_ingredient_id  69474 non-null  object
 1   meddra_id             69474 non-null  int64 
 2   meddra_name           69474 non-null  object
dtypes: int64(1), obje

In [4]:
import os
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import Data, Dataset
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
from rdkit.Chem import AllChem
import warnings
warnings.filterwarnings('ignore')


In [5]:
# Create cache directories if they don't exist
os.makedirs(config["cachePath"]["prot_model"], exist_ok=True)
os.makedirs(config["cachePath"]["drug_model"], exist_ok=True)

# Basic configuration - UPDATED
class Config:
    def __init__(self, config_dict):
        self.data_paths = config_dict["dataPath"]
        self.cache_paths = config_dict["cachePath"]
        
        # Model parameters - WILL BE UPDATED AFTER DATA ANALYSIS
        self.protein_feat_dim = None  # Will be set after building first protein graph
        self.drug_scalar_dim = None   # Will be set after building first drug graph
        self.drug_vector_neighbor_dim = None  # Will be set after building first drug graph  
        self.drug_vector_coord_dim = None     # Will be set after building first drug graph
        
        # Fixed architecture parameters
        self.hidden_dim = 128
        
        # Training parameters
        self.batch_size = 32
        self.learning_rate = 0.001
        
    def update_from_data(self, protein_feat_dim, drug_scalar_dim, drug_vector_neighbor_dim, drug_vector_coord_dim):
        """Update dimensions after analyzing actual data"""
        self.protein_feat_dim = protein_feat_dim
        self.drug_scalar_dim = drug_scalar_dim
        self.drug_vector_neighbor_dim = drug_vector_neighbor_dim
        self.drug_vector_coord_dim = drug_vector_coord_dim
        
        print(f"Configuration updated from data:")
        print(f"  protein_feat_dim: {self.protein_feat_dim}")
        print(f"  drug_scalar_dim: {self.drug_scalar_dim}")
        print(f"  drug_vector_neighbor_dim: {self.drug_vector_neighbor_dim}")
        print(f"  drug_vector_coord_dim: {self.drug_vector_coord_dim}")

# Initialize config
cfg = Config(config)



In [6]:
class ProteinGraphBuilder:
    def __init__(self, prot_3d_dir):
        self.prot_3d_dir = prot_3d_dir
        self.amino_acid_dict = self._create_aa_embedding_dict()
    
    def _create_aa_embedding_dict(self):
        """Create amino acid feature dictionary"""
        # Simple one-hot encoding for starters
        aa_list = 'ACDEFGHIKLMNPQRSTVWY'
        aa_dict = {}
        for i, aa in enumerate(aa_list):
            # One-hot + some basic properties
            features = np.zeros(20)
            features[i] = 1.0
            aa_dict[aa] = features
        return aa_dict
    
    def get_alphafold_path(self, uniprot_id):
        """Get AlphaFold PDB file path for a given UniProt ID"""
        pdb_path = os.path.join(self.prot_3d_dir, f"{uniprot_id}.pdb")
        if os.path.exists(pdb_path):
            return pdb_path
        else:
            print(f"Warning: AlphaFold file not found for {uniprot_id}")
            return None
    
    def parse_alphafold_pdb(self, pdb_path):
        """Parse AlphaFold PDB file to extract coordinates and pLDDT scores"""
        try:
            coords = []
            plddt_scores = []
            
            with open(pdb_path, 'r') as f:
                for line in f:
                    if line.startswith('ATOM') and 'CA' in line:
                        # Extract coordinates
                        x = float(line[30:38])
                        y = float(line[38:46])
                        z = float(line[46:54])
                        coords.append([x, y, z])
                        
                        # Extract pLDDT score (B-factor column)
                        plddt = float(line[60:66])
                        plddt_scores.append(plddt)
            
            return np.array(coords), np.array(plddt_scores)
        except Exception as e:
            print(f"Error parsing PDB file {pdb_path}: {e}")
            return None, None
    
    def build_spatial_edges(self, coords, cutoff=8.0):
        """Build edges based on spatial proximity"""
        n_nodes = len(coords)
        edges = []
        
        for i in range(n_nodes):
            for j in range(i + 1, n_nodes):
                distance = np.linalg.norm(coords[i] - coords[j])
                if distance < cutoff:
                    edges.append([i, j])
                    edges.append([j, i])  # Undirected graph
        
        if len(edges) == 0:
            # Fallback: connect sequential residues
            edges = [[i, i+1] for i in range(n_nodes-1)]
            edges += [[i+1, i] for i in range(n_nodes-1)]
        
        return torch.tensor(edges, dtype=torch.long).t().contiguous()
    
    def build_protein_graph(self, uniprot_id, sequence):
        """Build protein graph from sequence and AlphaFold data"""
        # Try to get AlphaFold structure
        pdb_path = self.get_alphafold_path(uniprot_id)
        coords, plddt_scores = None, None
        
        if pdb_path:
            coords, plddt_scores = self.parse_alphafold_pdb(pdb_path)
        
        # Build node features
        node_features = []
        for i, aa in enumerate(sequence):
            if aa in self.amino_acid_dict:
                aa_features = self.amino_acid_dict[aa]
            else:
                # Unknown amino acid - use average
                aa_features = np.zeros(20)
            
            # Add structural features if available
            structural_features = []
            if plddt_scores is not None and i < len(plddt_scores):
                structural_features.append(plddt_scores[i] / 100.0)  # Normalize pLDDT
            else:
                structural_features.append(0.5)  # Default confidence
            
            if coords is not None and i < len(coords):
                # Add some basic structural features
                if i > 0 and i < len(coords) - 1:
                    # Simple pseudo-dihedral (simplified)
                    vec1 = coords[i] - coords[i-1]
                    vec2 = coords[i+1] - coords[i]
                    if np.linalg.norm(vec1) > 0 and np.linalg.norm(vec2) > 0:
                        cos_angle = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
                        structural_features.append(cos_angle)
                    else:
                        structural_features.append(0.0)
                else:
                    structural_features.append(0.0)
            else:
                structural_features.append(0.0)
            
            # Combine all features
            features = np.concatenate([aa_features, structural_features])
            node_features.append(features)
        
        node_features = np.array(node_features)
        
        # Build edges
        if coords is not None:
            edge_index = self.build_spatial_edges(coords, cutoff=8.0)
        else:
            # Fallback: sequence-based edges
            n_residues = len(sequence)
            edges = []
            for i in range(n_residues - 1):
                edges.append([i, i + 1])
                edges.append([i + 1, i])
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        
        # Update feature dimension in config
        cfg.protein_feat_dim = node_features.shape[1]
        
        return Data(
            x=torch.tensor(node_features, dtype=torch.float),
            edge_index=edge_index,
            uniprot_id=uniprot_id,
            sequence=sequence
        )

# Initialize protein graph builder
protein_builder = ProteinGraphBuilder(cfg.data_paths["prot_3d_dir"])

In [7]:
class DrugGVPBuilder:
    def __init__(self):
        self.atom_features_dict = self._create_atom_feature_dict()
    
    def _create_atom_feature_dict(self):
        """Create atom feature dictionary"""
        # Basic atom types
        atom_types = ['C', 'N', 'O', 'S', 'F', 'Cl', 'Br', 'I', 'P']
        feature_dict = {}
        
        for atom in atom_types:
            # One-hot for common atoms + zeros for others
            features = [1.0 if a == atom else 0.0 for a in atom_types]
            feature_dict[atom] = features
        
        # Default for other atoms
        feature_dict['OTHER'] = [0.0] * len(atom_types)
        
        return feature_dict
    
    def get_atom_features(self, atom):
        """Get features for an atom"""
        symbol = atom.GetSymbol()
        if symbol in self.atom_features_dict:
            base_features = self.atom_features_dict[symbol]
        else:
            base_features = self.atom_features_dict['OTHER']
        
        # Additional features
        additional_features = [
            atom.GetDegree() / 4.0,  # Normalized degree
            atom.GetFormalCharge() / 2.0,  # Normalized charge
            float(atom.GetIsAromatic()),
            atom.GetNumImplicitHs() / 4.0,  # Normalized H count
            atom.GetMass() / 100.0  # Normalized mass
        ]
        
        return base_features + additional_features
    
    def get_neighbor_vectors(self, mol, atom_idx, conformer):
        """Get direction vectors to neighbors for geometric features"""
        atom = mol.GetAtomWithIdx(atom_idx)
        neighbors = atom.GetNeighbors()
        
        vectors = []
        for neighbor in neighbors:
            nbr_idx = neighbor.GetIdx()
            if nbr_idx != atom_idx:
                atom_pos = conformer.GetAtomPosition(atom_idx)
                nbr_pos = conformer.GetAtomPosition(nbr_idx)
                vec = [nbr_pos.x - atom_pos.x, nbr_pos.y - atom_pos.y, nbr_pos.z - atom_pos.z]
                vectors.append(vec)
        
        # Pad to fixed dimension
        max_neighbors = 4
        while len(vectors) < max_neighbors:
            vectors.append([0.0, 0.0, 0.0])
        
        return vectors[:max_neighbors]
    
    def smiles_to_gvp_data(self, smiles, drug_id):
        """Convert SMILES to GVP-ready data"""
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                raise ValueError(f"Invalid SMILES: {smiles}")
            
            # Add hydrogens and generate 3D coordinates
            mol = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol, AllChem.ETKDG())
            
            conformer = mol.GetConformers()[0]
            
            scalar_feats = []
            vector_feats = []
            positions = []
            
            for atom in mol.GetAtoms():
                idx = atom.GetIdx()
                
                # Scalar features
                scalar_feat = self.get_atom_features(atom)
                scalar_feats.append(scalar_feat)
                
                # Vector features (neighbor directions)
                vector_feat = self.get_neighbor_vectors(mol, idx, conformer)
                vector_feats.append(vector_feat)
                
                # Atom positions
                pos = conformer.GetAtomPosition(idx)
                positions.append([pos.x, pos.y, pos.z])
            
            scalar_feats = np.array(scalar_feats)
            vector_feats = np.array(vector_feats)
            positions = np.array(positions)
            
            # Update drug feature dimensions in config
            cfg.drug_scalar_dim = scalar_feats.shape[1]
            cfg.drug_vector_dim = vector_feats.shape[2]  # Should be 3
            
            return {
                'scalar_feats': torch.tensor(scalar_feats, dtype=torch.float),
                'vector_feats': torch.tensor(vector_feats, dtype=torch.float),
                'positions': torch.tensor(positions, dtype=torch.float),
                'drug_id': drug_id,
                'smiles': smiles
            }
            
        except Exception as e:
            print(f"Error processing drug {drug_id}: {e}")
            return None

# Initialize drug graph builder
drug_builder = DrugGVPBuilder()

In [8]:
# Test with a small sample
sample_protein = dti_pd.iloc[0]
sample_drug = dti_pd.iloc[0]

print("Testing protein graph builder...")
protein_graph = protein_builder.build_protein_graph(
    sample_protein['target_uniprot_id'], 
    sample_protein['sequence']
)
print(f"Protein graph: {protein_graph}")
print(f"Protein node features shape: {protein_graph.x.shape}")
print(f"Protein edges: {protein_graph.edge_index.shape}")

print("\nTesting drug graph builder...")
drug_data = drug_builder.smiles_to_gvp_data(
    sample_drug['smiles'],
    sample_drug['drug_chembl_id']
)
if drug_data:
    print(f"Drug scalar features shape: {drug_data['scalar_feats'].shape}")
    print(f"Drug vector features shape: {drug_data['vector_feats'].shape}")

print(f"\nFinal feature dimensions:")
print(f"Protein feature dim: {cfg.protein_feat_dim}")
print(f"Drug scalar dim: {cfg.drug_scalar_dim}")
print(f"Drug vector dim: {cfg.drug_vector_dim}")

Testing protein graph builder...
Protein graph: Data(x=[554, 22], edge_index=[2, 5418], uniprot_id='O15245', sequence='MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTPDHHCQSPGVAELSQRCGWSPAEELNYTVPGLGPAGEAFLGQCRRYEVDWNQSALSCVDPLASLATNRSHLPLGPCQDGWVYDTPGSSIVTEFNLVCADSWKLDLFQSCLNAGFLFGSLGVGYFADRFGRKLCLLGTVLVNAVSGVLMAFSPNYMSMLLFRLLQGLVSKGNWMAGYTLITEFVGSGSRRTVAIMYQMAFTVGLVALTGLAYALPHWRWLQLAVSLPTFLFLLYYWCVPESPRWLLSQKRNTEAIKIMDHIAQKNGKLPPADLKMLSLEEDVTEKLSPSFADLFRTPRLRKRTFILMYLWFTDSVLYQGLILHMGATSGNLYLDFLYSALVEIPGAFIALITIDRVGRIYPMAMSNLLAGAACLVMIFISPDLHWLNIIIMCVGRMGITIAIQMICLVNAELYPTFVRNLGVMVCSSLCDIGGIITPFIVFRLREVWQALPLILFAVLGLLAAGVTLLLPETKGVALPETMKDAENLGRKAKPKENTIYLKVQTSEPSGT')
Protein node features shape: torch.Size([554, 22])
Protein edges: torch.Size([2, 5418])

Testing drug graph builder...
Drug scalar features shape: torch.Size([52, 14])
Drug vector features shape: torch.Size([52, 4, 3])

Final feature dimensions:
Protein feature dim: 22
Drug scalar dim: 14
Drug vector dim: 3


In [10]:
# =============================================================================
# CACHED DTI DATASET WITH PROGRESS BARS FOR GRAPH BUILDING
# =============================================================================
import pickle
import hashlib
import concurrent.futures
import threading
from tqdm import tqdm

class ThreadedCachedDTIDataset(Dataset):
    def __init__(self, dti_df, protein_builder, drug_builder, adr_df=None, 
                 cache_dir="./Cache", force_rebuild=False, show_progress=True,
                 max_workers=None):
        self.dti_df = dti_df.reset_index(drop=True)
        self.protein_builder = protein_builder
        self.drug_builder = drug_builder
        self.adr_df = adr_df
        self.cache_dir = cache_dir
        self.force_rebuild = force_rebuild
        self.show_progress = show_progress
        self.max_workers = max_workers or min(32, (os.cpu_count() or 1) + 4)
        
        # Create cache directories
        self.protein_cache_dir = os.path.join(cache_dir, "proteins")
        self.drug_cache_dir = os.path.join(cache_dir, "drugs")
        os.makedirs(self.protein_cache_dir, exist_ok=True)
        os.makedirs(self.drug_cache_dir, exist_ok=True)
        
        # Thread-safe data structures
        self.cache_stats = {
            'protein_hits': 0,
            'protein_misses': 0,
            'drug_hits': 0,
            'drug_misses': 0
        }
        self._stats_lock = threading.Lock()
        
        # Build side effect mapping
        self.side_effect_mapping = self._build_side_effect_mapping()
        
        print(f"üß¨ Threaded Cached Dataset Initialized:")
        print(f"   üìä Total samples: {len(self.dti_df):,}")
        print(f"   üß™ Unique proteins: {self.dti_df['target_uniprot_id'].nunique():,}")
        print(f"   üíä Unique drugs: {self.dti_df['drug_chembl_id'].nunique():,}")
        print(f"   üßµ Max workers: {self.max_workers}")
        print(f"   üìÅ Protein cache: {self.protein_cache_dir}")
        print(f"   üìÅ Drug cache: {self.drug_cache_dir}")
        print(f"   ü©∫ Number of side effects: {self.side_effect_mapping['num_side_effects']:,}")
        
        # Precompute all missing graphs with threading
        self.precompute_all_graphs()
    
    def _build_side_effect_mapping(self):
        """Build mapping from drug to side effects"""
        if self.adr_df is None:
            return {'drug_to_se': {}, 'se_id_to_idx': {}, 'num_side_effects': 0}
        
        side_effect_dict = {}
        for _, row in self.adr_df.iterrows():
            drug_id = str(row['rxnorm_ingredient_id']).strip()
            se_id = row['meddra_id']
            
            if drug_id not in side_effect_dict:
                side_effect_dict[drug_id] = set()
            side_effect_dict[drug_id].add(se_id)
        
        all_se_ids = sorted(set(self.adr_df['meddra_id']))
        se_id_to_idx = {se_id: idx for idx, se_id in enumerate(all_se_ids)}
        
        return {
            'drug_to_se': side_effect_dict,
            'se_id_to_idx': se_id_to_idx,
            'num_side_effects': len(all_se_ids)
        }
    
    def _get_protein_cache_path(self, uniprot_id, sequence):
        """Generate cache file path for protein"""
        content = f"{uniprot_id}_{sequence}"
        hash_obj = hashlib.md5(content.encode())
        return os.path.join(self.protein_cache_dir, f"{hash_obj.hexdigest()}.pkl")
    
    def _get_drug_cache_path(self, drug_id, smiles):
        """Generate cache file path for drug"""
        content = f"{drug_id}_{smiles}"
        hash_obj = hashlib.md5(content.encode())
        return os.path.join(self.drug_cache_dir, f"{hash_obj.hexdigest()}.pkl")
    
    def _process_protein(self, uniprot_id, sequence):
        """Process a single protein (thread-safe)"""
        cache_path = self._get_protein_cache_path(uniprot_id, sequence)
        
        # Try to load from cache
        if not self.force_rebuild and os.path.exists(cache_path):
            try:
                with open(cache_path, 'rb') as f:
                    with self._stats_lock:
                        self.cache_stats['protein_hits'] += 1
                    return pickle.load(f)
            except Exception as e:
                if self.show_progress:
                    print(f"Error loading cached protein {uniprot_id}: {e}")
        
        # Build and cache
        with self._stats_lock:
            self.cache_stats['protein_misses'] += 1
        
        protein_graph = self.protein_builder.build_protein_graph(uniprot_id, sequence)
        
        try:
            with open(cache_path, 'wb') as f:
                pickle.dump(protein_graph, f)
        except Exception as e:
            if self.show_progress:
                print(f"Error caching protein {uniprot_id}: {e}")
        
        return protein_graph
    
    def _process_drug(self, drug_id, smiles):
        """Process a single drug (thread-safe)"""
        cache_path = self._get_drug_cache_path(drug_id, smiles)
        
        # Try to load from cache
        if not self.force_rebuild and os.path.exists(cache_path):
            try:
                with open(cache_path, 'rb') as f:
                    with self._stats_lock:
                        self.cache_stats['drug_hits'] += 1
                    return pickle.load(f)
            except Exception as e:
                if self.show_progress:
                    print(f"Error loading cached drug {drug_id}: {e}")
        
        # Build and cache
        with self._stats_lock:
            self.cache_stats['drug_misses'] += 1
        
        drug_data = self.drug_builder.smiles_to_gvp_data(smiles, drug_id)
        
        if drug_data is not None:
            try:
                with open(cache_path, 'wb') as f:
                    pickle.dump(drug_data, f)
            except Exception as e:
                if self.show_progress:
                    print(f"Error caching drug {drug_id}: {e}")
        
        return drug_data
    
    def precompute_all_graphs(self):
        """Precompute all missing graphs with threading"""
        print("üî® Precomputing missing graphs with threading...")
        
        # Get unique proteins and drugs
        unique_proteins = self.dti_df[['target_uniprot_id', 'sequence']].drop_duplicates()
        unique_drugs = self.dti_df[['drug_chembl_id', 'smiles']].drop_duplicates()
        
        print(f"üìã Found {len(unique_proteins):,} unique proteins and {len(unique_drugs):,} unique drugs")
        
        # Process proteins with threading
        protein_tasks = []
        for _, row in unique_proteins.iterrows():
            uniprot_id, sequence = row['target_uniprot_id'], row['sequence']
            cache_path = self._get_protein_cache_path(uniprot_id, sequence)
            if self.force_rebuild or not os.path.exists(cache_path):
                protein_tasks.append((uniprot_id, sequence))
        
        if protein_tasks:
            print(f"üß™ Processing {len(protein_tasks):,} proteins with {self.max_workers} workers...")
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                # Submit all protein tasks
                future_to_protein = {
                    executor.submit(self._process_protein, uniprot_id, sequence): (uniprot_id, sequence)
                    for uniprot_id, sequence in protein_tasks
                }
                
                # Track progress with tqdm
                if self.show_progress:
                    with tqdm(total=len(future_to_protein), desc="üß™ Building proteins") as pbar:
                        for future in concurrent.futures.as_completed(future_to_protein):
                            uniprot_id, sequence = future_to_protein[future]
                            try:
                                future.result()
                                pbar.update(1)
                                pbar.set_postfix({'current': uniprot_id[:10] + "..."})
                            except Exception as e:
                                print(f"‚ùå Error processing protein {uniprot_id}: {e}")
                                pbar.update(1)
                else:
                    # Without progress bar, just wait for completion
                    for future in concurrent.futures.as_completed(future_to_protein):
                        uniprot_id, sequence = future_to_protein[future]
                        try:
                            future.result()
                        except Exception as e:
                            print(f"‚ùå Error processing protein {uniprot_id}: {e}")
        
        # Process drugs with threading (this is where threading helps most!)
        drug_tasks = []
        for _, row in unique_drugs.iterrows():
            drug_id, smiles = row['drug_chembl_id'], row['smiles']
            cache_path = self._get_drug_cache_path(drug_id, smiles)
            if self.force_rebuild or not os.path.exists(cache_path):
                drug_tasks.append((drug_id, smiles))
        
        if drug_tasks:
            print(f"üíä Processing {len(drug_tasks):,} drugs with {self.max_workers} workers...")
            successful_drugs = 0
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                # Submit all drug tasks
                future_to_drug = {
                    executor.submit(self._process_drug, drug_id, smiles): (drug_id, smiles)
                    for drug_id, smiles in drug_tasks
                }
                
                # Track progress with tqdm
                if self.show_progress:
                    with tqdm(total=len(future_to_drug), desc="üíä Building drugs") as pbar:
                        for future in concurrent.futures.as_completed(future_to_drug):
                            drug_id, smiles = future_to_drug[future]
                            try:
                                result = future.result()
                                if result is not None:
                                    successful_drugs += 1
                                pbar.update(1)
                                pbar.set_postfix({
                                    'success': successful_drugs,
                                    'current': drug_id[:10] + "..."
                                })
                            except Exception as e:
                                print(f"‚ùå Error processing drug {drug_id}: {e}")
                                pbar.update(1)
                else:
                    # Without progress bar
                    for future in concurrent.futures.as_completed(future_to_drug):
                        drug_id, smiles = future_to_drug[future]
                        try:
                            result = future.result()
                            if result is not None:
                                successful_drugs += 1
                        except Exception as e:
                            print(f"‚ùå Error processing drug {drug_id}: {e}")
        
        print(f"‚úÖ Precomputation complete!")
        self._show_cache_sizes()
    
    def _show_cache_sizes(self):
        """Show the size of cache directories"""
        def get_dir_size(path):
            total = 0
            for dirpath, dirnames, filenames in os.walk(path):
                for f in filenames:
                    fp = os.path.join(dirpath, f)
                    total += os.path.getsize(fp)
            return total
        
        protein_size = get_dir_size(self.protein_cache_dir) / (1024 * 1024)  # MB
        drug_size = get_dir_size(self.drug_cache_dir) / (1024 * 1024)  # MB
        
        print(f"üíæ Cache sizes:")
        print(f"   üß™ Protein cache: {protein_size:.1f} MB")
        print(f"   üíä Drug cache: {drug_size:.1f} MB")
        print(f"   üíΩ Total cache: {protein_size + drug_size:.1f} MB")
    
    def get_protein_graph(self, uniprot_id, sequence):
        """Get protein graph (single-threaded for training)"""
        return self._process_protein(uniprot_id, sequence)
    
    def get_drug_data(self, drug_id, smiles):
        """Get drug data (single-threaded for training)"""
        return self._process_drug(drug_id, smiles)
    
    def get_side_effects(self, drug_id):
        """Get side effect vector for a drug"""
        if not self.side_effect_mapping or self.side_effect_mapping['num_side_effects'] == 0:
            return torch.tensor([], dtype=torch.float)
        
        se_dict = self.side_effect_mapping['drug_to_se']
        se_id_to_idx = self.side_effect_mapping['se_id_to_idx']
        num_se = self.side_effect_mapping['num_side_effects']
        
        se_vector = torch.zeros(num_se, dtype=torch.float)
        
        drug_id_str = str(drug_id).strip()
        if drug_id_str in se_dict:
            for se_id in se_dict[drug_id_str]:
                if se_id in se_id_to_idx:
                    se_vector[se_id_to_idx[se_id]] = 1.0
        
        return se_vector
    
    def get_cache_stats(self):
        """Get cache hit/miss statistics"""
        with self._stats_lock:
            total_protein = self.cache_stats['protein_hits'] + self.cache_stats['protein_misses']
            total_drug = self.cache_stats['drug_hits'] + self.cache_stats['drug_misses']
            
            stats = {
                'protein_hit_rate': self.cache_stats['protein_hits'] / total_protein if total_protein > 0 else 0,
                'drug_hit_rate': self.cache_stats['drug_hits'] / total_drug if total_drug > 0 else 0,
                'total_requests': total_protein + total_drug,
                'protein_hits': self.cache_stats['protein_hits'],
                'protein_misses': self.cache_stats['protein_misses'],
                'drug_hits': self.cache_stats['drug_hits'],
                'drug_misses': self.cache_stats['drug_misses']
            }
        return stats
    
    def __len__(self):
        return len(self.dti_df)
    
    def __getitem__(self, idx):
        """Fast data loading using persistent cache"""
        row = self.dti_df.iloc[idx]
        
        # These are now fast cache lookups
        protein_graph = self.get_protein_graph(row['target_uniprot_id'], row['sequence'])
        drug_data = self.get_drug_data(row['drug_chembl_id'], row['smiles'])
        
        if drug_data is None:
            # Skip invalid drugs
            return self.__getitem__((idx + 1) % len(self))
        
        binding_label = torch.tensor(row['label'], dtype=torch.float)
        side_effects = self.get_side_effects(row['rxcui'])
        
        return {
            'protein_data': protein_graph,
            'drug_data': drug_data,
            'binding_label': binding_label,
            'side_effects': side_effects,
            'drug_id': row['drug_chembl_id'],
            'uniprot_id': row['target_uniprot_id'],
            'idx': idx
        }

# =============================================================================
# CREATE THE THREADED CACHED DATASET
# =============================================================================

print("üöÄ Creating threaded cached dataset...")
dataset = ThreadedCachedDTIDataset(
    dti_df=dti_pd,
    protein_builder=protein_builder,
    drug_builder=drug_builder,
    adr_df=adr_pd,
    cache_dir="./Cache",
    force_rebuild=False,
    show_progress=True,
    max_workers=2  # Adjust based on your CPU. 4-8 is usually good.
)

# Test the dataset
print("\nüß™ Testing threaded cached dataset...")
sample = dataset[0]
print(f"‚úÖ Sample loaded successfully!")

# Show detailed cache statistics
cache_stats = dataset.get_cache_stats()
print(f"\nüìä Cache Statistics:")
print(f"   üß™ Protein cache: {cache_stats['protein_hits']} hits, {cache_stats['protein_misses']} misses ({cache_stats['protein_hit_rate']:.2%})")
print(f"   üíä Drug cache: {cache_stats['drug_hits']} hits, {cache_stats['drug_misses']} misses ({cache_stats['drug_hit_rate']:.2%})")
print(f"   üìà Total requests: {cache_stats['total_requests']}")

print("\nüéâ Threaded cached dataset ready for training!")

üöÄ Creating threaded cached dataset...
üß¨ Threaded Cached Dataset Initialized:
   üìä Total samples: 34,741
   üß™ Unique proteins: 2,385
   üíä Unique drugs: 1,028
   üßµ Max workers: 2
   üìÅ Protein cache: ./Cache\proteins
   üìÅ Drug cache: ./Cache\drugs
   ü©∫ Number of side effects: 4,817
üî® Precomputing missing graphs with threading...
üìã Found 2,385 unique proteins and 1,028 unique drugs
üíä Processing 1 drugs with 2 workers...


üíä Building drugs: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:00<00:00,  1.05it/s, success=1, current=CHEMBL2110...]


‚úÖ Precomputation complete!
üíæ Cache sizes:
   üß™ Protein cache: 310.5 MB
   üíä Drug cache: 7.0 MB
   üíΩ Total cache: 317.6 MB

üß™ Testing threaded cached dataset...
‚úÖ Sample loaded successfully!

üìä Cache Statistics:
   üß™ Protein cache: 1 hits, 0 misses (100.00%)
   üíä Drug cache: 1 hits, 1 misses (50.00%)
   üìà Total requests: 3

üéâ Threaded cached dataset ready for training!


In [11]:
# Test again
print("\nTesting fixed dataset...")
sample = dataset[3400]
print(f"Sample keys: {sample.keys()}")
print(f"Binding label: {sample['binding_label']}")
print(f"Side effects shape: {sample['side_effects'].shape}")
print(f"Number of side effects in mapping: {dataset.side_effect_mapping['num_side_effects']}")
print(f"Actual side effects sum: {sample['side_effects'].sum().item()}")
print(f"Side effects vector: {sample['side_effects'][:10]}")  # Show first 10 elements


Testing fixed dataset...
Sample keys: dict_keys(['protein_data', 'drug_data', 'binding_label', 'side_effects', 'drug_id', 'uniprot_id', 'idx'])
Binding label: 1.0
Side effects shape: torch.Size([4817])
Number of side effects in mapping: 4817
Actual side effects sum: 77.0
Side effects vector: tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])


In [13]:

protein_feat_dim = sample['protein_data'].x.shape[1]  # Should be 22
drug_scalar_dim = sample['drug_data']['scalar_feats'].shape[1]  # Should be 14
drug_vector_neighbor_dim = sample['drug_data']['vector_feats'].shape[1]  # Should be 4

print(f"üìê Feature dimensions from data:")
print(f"   üß™ Protein feature dim: {protein_feat_dim}")
print(f"   üíä Drug scalar dim: {drug_scalar_dim}")
print(f"   üíä Drug vector neighbor dim: {drug_vector_neighbor_dim}")

üìê Feature dimensions from data:
   üß™ Protein feature dim: 22
   üíä Drug scalar dim: 14
   üíä Drug vector neighbor dim: 4


In [14]:
# =============================================================================
# STEP 2: SIMPLIFIED MODEL THAT USES PRE-BUILT GRAPHS DIRECTLY
# =============================================================================

class SimpleDTIModel(nn.Module):
    def __init__(self, protein_feat_dim, drug_scalar_dim, drug_vector_neighbor_dim, 
                 hidden_dim=128, num_side_effects=4817):
        super().__init__()
        
        # Simple encoders (no heavy GVP/GCN - just process the features we already have)
        self.protein_encoder = nn.Sequential(
            nn.Linear(protein_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Drug encoder uses the precomputed features directly
        total_drug_features = drug_scalar_dim + (drug_vector_neighbor_dim * 3)  # scalar + flattened vector
        self.drug_encoder = nn.Sequential(
            nn.Linear(total_drug_features, hidden_dim),
            nn.ReLU(), 
            nn.Dropout(0.2)
        )
        
        self.num_side_effects = num_side_effects
        
        # Interaction and prediction heads (same as before)
        self.interaction_net = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        self.binding_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        
        self.side_effect_head = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_side_effects),
            nn.Sigmoid()
        )
        
    def forward(self, protein_data, drug_data):
        # Simple protein encoding: just use mean of node features
        protein_embedding = self.protein_encoder(protein_data.x)
        protein_embedding = global_mean_pool(protein_embedding, protein_data.batch)
        
        # Simple drug encoding: flatten scalar + vector features
        drug_scalar = drug_data['scalar_feats']
        drug_vector_flat = drug_data['vector_feats'].reshape(drug_data['vector_feats'].size(0), -1)
        drug_combined = torch.cat([drug_scalar, drug_vector_flat], dim=1)
        drug_embedding = self.drug_encoder(drug_combined)
        drug_embedding = global_mean_pool(drug_embedding, drug_data['batch'])
        
        # Interaction and prediction (same as before)
        interaction_features = torch.cat([
            protein_embedding, drug_embedding, protein_embedding * drug_embedding
        ], dim=1)
        
        combined_rep = self.interaction_net(interaction_features)
        binding_score = self.binding_head(combined_rep).squeeze()
        side_effects = self.side_effect_head(combined_rep)
        
        return binding_score, side_effects

In [15]:
def collate_fn(batch):
    """Custom collate function for our heterogeneous data"""
    protein_data_list = []
    drug_data_list = []
    binding_labels = []
    side_effects_list = []
    
    protein_batch = []
    drug_batch = []
    
    for i, sample in enumerate(batch):
        # Protein data
        protein_data = sample['protein_data']
        protein_data.batch = torch.full((protein_data.x.size(0),), i, dtype=torch.long)
        protein_data_list.append(protein_data)
        
        # Drug data  
        drug_data = sample['drug_data'].copy()
        drug_data['batch'] = torch.full((drug_data['scalar_feats'].size(0),), i, dtype=torch.long)
        drug_data_list.append(drug_data)
        
        # Labels
        binding_labels.append(sample['binding_label'])
        side_effects_list.append(sample['side_effects'])
    
    # Batch protein data
    from torch_geometric.data import Batch as PyGBatch
    batched_protein = PyGBatch.from_data_list(protein_data_list)
    
    # Batch drug data manually
    batched_drug = {
        'scalar_feats': torch.cat([d['scalar_feats'] for d in drug_data_list], dim=0),
        'vector_feats': torch.cat([d['vector_feats'] for d in drug_data_list], dim=0),
        'batch': torch.cat([d['batch'] for d in drug_data_list], dim=0)
    }
    
    batched_binding = torch.stack(binding_labels)
    batched_side_effects = torch.stack(side_effects_list)
    
    return batched_protein, batched_drug, batched_binding, batched_side_effects

# Split dataset
from torch.utils.data import random_split

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

print(f"üìä Dataset splits:")
print(f"   üöÇ Training: {len(train_dataset):,} samples")
print(f"   üìã Validation: {len(val_dataset):,} samples") 
print(f"   üß™ Test: {len(test_dataset):,} samples")

# Create data loaders
from torch.utils.data import DataLoader

train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True, 
    collate_fn=collate_fn,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=32, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=32, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=0
)

# Test the data loader
print("\nüß™ Testing data loader...")
for batch_idx, (protein_batch, drug_batch, binding_batch, se_batch) in enumerate(train_loader):
    print(f"‚úÖ Batch {batch_idx} loaded successfully!")
    print(f"   üß™ Protein batch: {protein_batch}")
    print(f"   üíä Drug scalar: {drug_batch['scalar_feats'].shape}")
    print(f"   üíä Drug vector: {drug_batch['vector_feats'].shape}")
    print(f"   üéØ Binding labels: {binding_batch.shape}")
    print(f"   ü©∫ Side effects: {se_batch.shape}")
    break

üìä Dataset splits:
   üöÇ Training: 27,792 samples
   üìã Validation: 3,474 samples
   üß™ Test: 3,475 samples

üß™ Testing data loader...
‚úÖ Batch 0 loaded successfully!
   üß™ Protein batch: DataBatch(x=[22679, 22], edge_index=[2, 195932], uniprot_id=[32], sequence=[32], batch=[22679], ptr=[33])
   üíä Drug scalar: torch.Size([1734, 14])
   üíä Drug vector: torch.Size([1734, 4, 3])
   üéØ Binding labels: torch.Size([32])
   ü©∫ Side effects: torch.Size([32, 4817])


In [16]:
# Initialize the simplified model
model = SimpleDTIModel(
    protein_feat_dim=protein_feat_dim,
    drug_scalar_dim=drug_scalar_dim,
    drug_vector_neighbor_dim=drug_vector_neighbor_dim,
    hidden_dim=128,
    num_side_effects=dataset.side_effect_mapping['num_side_effects']
)

print(f"‚úÖ Simplified model created successfully!")
print(f"üìä Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Test model with a batch
print("\nüß™ Testing model forward pass...")
with torch.no_grad():
    binding_pred, side_effect_pred = model(protein_batch, drug_batch)

print(f"‚úÖ Model forward pass successful!")
print(f"   üéØ Binding prediction: {binding_pred.shape}")
print(f"   ü©∫ Side effect prediction: {side_effect_pred.shape}")

‚úÖ Simplified model created successfully!
üìä Model parameters: 2,683,346

üß™ Testing model forward pass...
‚úÖ Model forward pass successful!
   üéØ Binding prediction: torch.Size([32])
   ü©∫ Side effect prediction: torch.Size([32, 4817])


In [17]:
# =============================================================================
# STEP 4: LOSS FUNCTIONS AND METRICS
# =============================================================================

class MultiTaskLoss(nn.Module):
    def __init__(self, binding_weight=0.7, side_effect_weight=0.3):
        super().__init__()
        self.binding_weight = binding_weight
        self.side_effect_weight = side_effect_weight
        self.binding_loss = nn.BCELoss()
        self.side_effect_loss = nn.BCELoss()
    
    def forward(self, binding_pred, side_effect_pred, binding_true, side_effect_true):
        binding_loss = self.binding_loss(binding_pred, binding_true)
        
        # Only compute side effect loss if we have side effect labels
        if side_effect_pred.numel() > 0 and side_effect_true.numel() > 0:
            side_effect_loss = self.side_effect_loss(side_effect_pred, side_effect_true)
        else:
            side_effect_loss = torch.tensor(0.0, device=binding_pred.device)
        
        total_loss = (self.binding_weight * binding_loss + 
                     self.side_effect_weight * side_effect_loss)
        
        return total_loss, binding_loss, side_effect_loss

def calculate_metrics(binding_pred, binding_true, side_effect_pred=None, side_effect_true=None):
    """Calculate evaluation metrics"""
    import numpy as np
    from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
    
    binding_pred_np = binding_pred.detach().cpu().numpy()
    binding_true_np = binding_true.detach().cpu().numpy()
    
    metrics = {}
    
    # Binding metrics
    try:
        metrics['binding_auc'] = roc_auc_score(binding_true_np, binding_pred_np)
        metrics['binding_ap'] = average_precision_score(binding_true_np, binding_pred_np)
        metrics['binding_f1'] = f1_score(binding_true_np, binding_pred_np > 0.5)
    except Exception as e:
        metrics['binding_auc'] = 0.0
        metrics['binding_ap'] = 0.0
        metrics['binding_f1'] = 0.0
    
    # Side effect metrics (if available)
    if side_effect_pred is not None and side_effect_true is not None:
        se_pred_np = side_effect_pred.detach().cpu().numpy()
        se_true_np = side_effect_true.detach().cpu().numpy()
        
        try:
            # Micro-averaged metrics for multi-label classification
            metrics['side_effect_auc'] = roc_auc_score(se_true_np.ravel(), se_pred_np.ravel())
            metrics['side_effect_ap'] = average_precision_score(se_true_np.ravel(), se_pred_np.ravel())
            metrics['side_effect_f1'] = f1_score(se_true_np.ravel(), se_pred_np.ravel() > 0.5)
        except:
            metrics['side_effect_auc'] = 0.0
            metrics['side_effect_ap'] = 0.0
            metrics['side_effect_f1'] = 0.0
    
    return metrics

In [18]:
# =============================================================================
# STEP 5: TRAINING LOOP
# =============================================================================

import torch.optim as optim
from tqdm import tqdm

class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.model.to(device)
        
        # Training history
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_binding_auc': [], 'val_binding_auc': [],
            'train_side_effect_auc': [], 'val_side_effect_auc': []
        }
    
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        all_binding_preds = []
        all_binding_targets = []
        all_side_effect_preds = []
        all_side_effect_targets = []
        
        pbar = tqdm(self.train_loader, desc="üöÇ Training")
        for batch_idx, (protein_data, drug_data, binding_labels, se_labels) in enumerate(pbar):
            # Move data to device
            protein_data = protein_data.to(self.device)
            drug_data = {
                'scalar_feats': drug_data['scalar_feats'].to(self.device),
                'vector_feats': drug_data['vector_feats'].to(self.device),
                'batch': drug_data['batch'].to(self.device)
            }
            binding_labels = binding_labels.to(self.device)
            se_labels = se_labels.to(self.device)
            
            # Forward pass
            self.optimizer.zero_grad()
            binding_pred, se_pred = self.model(protein_data, drug_data)
            
            # Calculate loss
            loss, binding_loss, se_loss = self.criterion(binding_pred, se_pred, binding_labels, se_labels)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            
            total_loss += loss.item()
            
            # Collect predictions for metrics
            all_binding_preds.append(binding_pred.detach())
            all_binding_targets.append(binding_labels.detach())
            all_side_effect_preds.append(se_pred.detach())
            all_side_effect_targets.append(se_labels.detach())
            
            # Update progress bar
            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'B_Loss': f'{binding_loss.item():.4f}',
                'SE_Loss': f'{se_loss.item():.4f}'
            })
        
        # Calculate epoch metrics
        binding_preds = torch.cat(all_binding_preds)
        binding_targets = torch.cat(all_binding_targets)
        side_effect_preds = torch.cat(all_side_effect_preds)
        side_effect_targets = torch.cat(all_side_effect_targets)
        
        metrics = calculate_metrics(binding_preds, binding_targets, side_effect_preds, side_effect_targets)
        
        return total_loss / len(self.train_loader), metrics
    
    def validate(self):
        self.model.eval()
        total_loss = 0
        all_binding_preds = []
        all_binding_targets = []
        all_side_effect_preds = []
        all_side_effect_targets = []
        
        with torch.no_grad():
            for protein_data, drug_data, binding_labels, se_labels in tqdm(self.val_loader, desc="üìã Validating"):
                # Move data to device
                protein_data = protein_data.to(self.device)
                drug_data = {
                    'scalar_feats': drug_data['scalar_feats'].to(self.device),
                    'vector_feats': drug_data['vector_feats'].to(self.device),
                    'batch': drug_data['batch'].to(self.device)
                }
                binding_labels = binding_labels.to(self.device)
                se_labels = se_labels.to(self.device)
                
                # Forward pass
                binding_pred, se_pred = self.model(protein_data, drug_data)
                
                # Calculate loss
                loss, binding_loss, se_loss = self.criterion(binding_pred, se_pred, binding_labels, se_labels)
                total_loss += loss.item()
                
                # Collect predictions
                all_binding_preds.append(binding_pred)
                all_binding_targets.append(binding_labels)
                all_side_effect_preds.append(se_pred)
                all_side_effect_targets.append(se_labels)
        
        # Calculate metrics
        binding_preds = torch.cat(all_binding_preds)
        binding_targets = torch.cat(all_binding_targets)
        side_effect_preds = torch.cat(all_side_effect_preds)
        side_effect_targets = torch.cat(all_side_effect_targets)
        
        metrics = calculate_metrics(binding_preds, binding_targets, side_effect_preds, side_effect_targets)
        
        return total_loss / len(self.val_loader), metrics
    
    def train(self, num_epochs=50, save_path="best_model.pth"):
        print(f"üéØ Starting training for {num_epochs} epochs...")
        
        best_val_auc = 0
        patience = 10
        patience_counter = 0
        
        for epoch in range(num_epochs):
            print(f"\n‚è∞ Epoch {epoch+1}/{num_epochs}")
            print("-" * 50)
            
            # Train
            train_loss, train_metrics = self.train_epoch()
            
            # Validate
            val_loss, val_metrics = self.validate()
            
            # Update history
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_binding_auc'].append(train_metrics['binding_auc'])
            self.history['val_binding_auc'].append(val_metrics['binding_auc'])
            self.history['train_side_effect_auc'].append(train_metrics.get('side_effect_auc', 0))
            self.history['val_side_effect_auc'].append(val_metrics.get('side_effect_auc', 0))
            
            # Print results
            print(f"üìä Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            print(f"üéØ Train Binding AUC: {train_metrics['binding_auc']:.4f} | Val Binding AUC: {val_metrics['binding_auc']:.4f}")
            if 'side_effect_auc' in val_metrics:
                print(f"ü©∫ Train SE AUC: {train_metrics['side_effect_auc']:.4f} | Val SE AUC: {val_metrics['side_effect_auc']:.4f}")
            
            # Early stopping
            if val_metrics['binding_auc'] > best_val_auc:
                best_val_auc = val_metrics['binding_auc']
                patience_counter = 0
                torch.save(self.model.state_dict(), save_path)
                print(f"üíæ New best model saved with Val AUC: {best_val_auc:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"üõë Early stopping after {epoch+1} epochs")
                    break
        
        print(f"\n‚úÖ Training completed! Best Val AUC: {best_val_auc:.4f}")
        return self.history

In [19]:
# =============================================================================
# STEP 6: START TRAINING
# =============================================================================

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"üñ•Ô∏è Using device: {device}")

# Move model to device
model = model.to(device)

# Optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
criterion = MultiTaskLoss(binding_weight=0.7, side_effect_weight=0.3)

# Create trainer
trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device)

# Start training!
history = trainer.train(num_epochs=5, save_path="best_simple_model.pth")

print("üéâ Training pipeline complete!")

üñ•Ô∏è Using device: cuda
üéØ Starting training for 5 epochs...

‚è∞ Epoch 1/5
--------------------------------------------------


üöÇ Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 869/869 [01:43<00:00,  8.43it/s, Loss=0.4384, B_Loss=0.6017, SE_Loss=0.0575]
üìã Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [00:09<00:00, 11.15it/s]


üìä Train Loss: 0.4347 | Val Loss: 0.3858
üéØ Train Binding AUC: 0.7062 | Val Binding AUC: 0.7845
ü©∫ Train SE AUC: 0.8776 | Val SE AUC: 0.9144
üíæ New best model saved with Val AUC: 0.7845

‚è∞ Epoch 2/5
--------------------------------------------------


üöÇ Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 869/869 [01:20<00:00, 10.86it/s, Loss=0.3896, B_Loss=0.5322, SE_Loss=0.0570]
üìã Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [00:08<00:00, 13.62it/s]


üìä Train Loss: 0.3851 | Val Loss: 0.3740
üéØ Train Binding AUC: 0.7853 | Val Binding AUC: 0.8007
ü©∫ Train SE AUC: 0.9000 | Val SE AUC: 0.9194
üíæ New best model saved with Val AUC: 0.8007

‚è∞ Epoch 3/5
--------------------------------------------------


üöÇ Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 869/869 [01:20<00:00, 10.82it/s, Loss=0.2534, B_Loss=0.3387, SE_Loss=0.0545]
üìã Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [00:07<00:00, 13.67it/s]


üìä Train Loss: 0.3759 | Val Loss: 0.3647
üéØ Train Binding AUC: 0.7984 | Val Binding AUC: 0.8126
ü©∫ Train SE AUC: 0.9004 | Val SE AUC: 0.9204
üíæ New best model saved with Val AUC: 0.8126

‚è∞ Epoch 4/5
--------------------------------------------------


üöÇ Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 869/869 [01:16<00:00, 11.38it/s, Loss=0.3526, B_Loss=0.4779, SE_Loss=0.0603]
üìã Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [00:08<00:00, 13.15it/s]


üìä Train Loss: 0.3663 | Val Loss: 0.3756
üéØ Train Binding AUC: 0.8125 | Val Binding AUC: 0.8215
ü©∫ Train SE AUC: 0.9009 | Val SE AUC: 0.9195
üíæ New best model saved with Val AUC: 0.8215

‚è∞ Epoch 5/5
--------------------------------------------------


üöÇ Training: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 869/869 [01:16<00:00, 11.43it/s, Loss=0.3790, B_Loss=0.5147, SE_Loss=0.0624]
üìã Validating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 109/109 [00:07<00:00, 14.18it/s]


üìä Train Loss: 0.3632 | Val Loss: 0.3722
üéØ Train Binding AUC: 0.8156 | Val Binding AUC: 0.8256
ü©∫ Train SE AUC: 0.9020 | Val SE AUC: 0.9203
üíæ New best model saved with Val AUC: 0.8256

‚úÖ Training completed! Best Val AUC: 0.8256
üéâ Training pipeline complete!


In [24]:
# =============================================================================
# MANUAL PREDICTION PIPELINE
# =============================================================================

def predict_manual_input(protein_sequence, drug_smiles, model, protein_builder, drug_builder, 
                        side_effect_mapping, uniprot_id, device='cuda', threshold=0.5):
    """
    Make predictions for manual input of protein sequence and drug SMILES
    """
    print(f"üß™ Predicting for:")
    print(f"   üß¨ Protein: {protein_sequence[:50]}...")
    print(f"   üíä Drug: {drug_smiles}")
    
    # Build graphs for the new inputs
    with torch.no_grad():
        # Create protein graph
        protein_graph = protein_builder.build_protein_graph(uniprot_id, protein_sequence)
        protein_graph.batch = torch.zeros(protein_graph.x.size(0), dtype=torch.long)
        
        # Create drug data
        drug_data = drug_builder.smiles_to_gvp_data(drug_smiles, "manual_drug")
        if drug_data is None:
            print("‚ùå Failed to process drug SMILES")
            return None
        
        drug_data['batch'] = torch.zeros(drug_data['scalar_feats'].size(0), dtype=torch.long)
        
        # Move to device
        protein_graph = protein_graph.to(device)
        drug_data = {
            'scalar_feats': drug_data['scalar_feats'].to(device),
            'vector_feats': drug_data['vector_feats'].to(device),
            'batch': drug_data['batch'].to(device)
        }
        
        # Make prediction
        binding_score, side_effect_scores = model(protein_graph, drug_data)
        
        # Convert to probabilities
        binding_prob = binding_score.item()
        side_effect_probs = side_effect_scores.cpu().numpy().flatten()
        
        return binding_prob, side_effect_probs

def interpret_predictions(binding_prob, side_effect_probs, side_effect_mapping, threshold=0.5):
    """
    Convert model outputs to interpretable results
    """
    print("\n" + "="*60)
    print("üéØ PREDICTION RESULTS")
    print("="*60)
    
    # Binding prediction interpretation
    print(f"\nüíä DRUG-TARGET BINDING PREDICTION:")
    print(f"   Probability: {binding_prob:.4f}")
    
    if binding_prob >= threshold:
        print(f"   üü¢ PREDICTION: LIKELY BINDS (confidence: {binding_prob:.1%})")
        binding_strength = "Strong" if binding_prob > 0.8 else "Moderate" if binding_prob > 0.6 else "Weak"
        print(f"   üìä Binding Strength: {binding_strength}")
    else:
        print(f"   üî¥ PREDICTION: UNLIKELY TO BIND (confidence: {(1-binding_prob):.1%})")
    
    # Side effect predictions
    print(f"\nü©∫ PREDICTED SIDE EFFECTS (top 10):")
    
    # Get side effect IDs and names
    se_id_to_idx = side_effect_mapping['se_id_to_idx']
    idx_to_se_id = {v: k for k, v in se_id_to_idx.items()}
    
    # Get ADR data for side effect names
    adr_names = {}
    for _, row in adr_pd.iterrows():
        adr_names[row['meddra_id']] = row['meddra_name']
    
    # Get top predicted side effects
    top_indices = np.argsort(side_effect_probs)[-10:][::-1]  # Top 10 highest probability
    predicted_side_effects = []
    
    for idx in top_indices:
        prob = side_effect_probs[idx]
        if prob >= threshold:
            se_id = idx_to_se_id.get(idx)
            se_name = adr_names.get(se_id, f"Side Effect {se_id}")
            predicted_side_effects.append((se_name, prob))
    
    if predicted_side_effects:
        for i, (se_name, prob) in enumerate(predicted_side_effects, 1):
            confidence = "High" if prob > 0.8 else "Medium" if prob > 0.6 else "Low"
            print(f"   {i:2d}. {se_name:<40} {prob:.4f} ({confidence} confidence)")
    else:
        print("   No significant side effects predicted above threshold")
    
    # Summary statistics
    num_predicted_se = np.sum(side_effect_probs >= threshold)
    avg_se_prob = np.mean(side_effect_probs[side_effect_probs >= threshold]) if num_predicted_se > 0 else 0
    
    print(f"\nüìä SIDE EFFECT SUMMARY:")
    print(f"   Total predicted side effects: {num_predicted_se}")
    print(f"   Average confidence: {avg_se_prob:.4f}")
    
    return binding_prob, predicted_side_effects

# Load the best model for predictions
print("üîß Loading best model for predictions...")
model.load_state_dict(torch.load("best_simple_model.pth"))
model.eval()

# Example usage with some test data
def test_with_examples():
    """Test the prediction pipeline with some examples"""
    
    # Example 1: Get a real example from our dataset
    print("1. Testing with dataset example:")
    sample = dataset[0]
    protein_seq = sample['protein_data'].sequence
    drug_smiles = sample['drug_data']['smiles']
    
    results = predict_manual_input(protein_seq, drug_smiles, model, protein_builder, 
                                 drug_builder, dataset.side_effect_mapping, device)
    if results:
        binding_prob, side_effect_probs = results
        interpret_predictions(binding_prob, side_effect_probs, dataset.side_effect_mapping)
    
    print("\n" + "="*80)
    
    # Example 2: Let user input their own
    print("2. Now test with your own inputs:")

# Run the examples
# test_with_examples()

# Function for continuous manual input
def interactive_prediction():
    """Interactive mode for manual predictions"""
    print("\nüéÆ INTERACTIVE PREDICTION MODE")
    print("Enter 'quit' to exit")
    
    protein_seq = "MSLSFCGNNISSYNINDGVLQNSCFVDALNLVPHVFLLFITFPILFIGWGSQSSKVQIHHNTWLHFPGHNLRWILTFALLFVHVCEIAEGIVSDSRRESRHLHLFMPAVMGFVATTTSIVYYHNIETSNFPKLLLALFLYWVMAFITKTIKLVKYCQSGLDISNLRFCITGMMVILNGLLMAVEINVIRVRRYVFFMNPQKVKPPEDLQDLGVRFLQPFVNLLSKATYWWMNTLIISAHKKPIDLKAIGKLPIAMRAVTNYVCLKDAYEEQKKKVADHPNRTPSIWLAMYRAFGRPILLSSTFRYLADLLGFAGPLCISGIVQRVNETQNGTNNTTGISETLSSKEFLENAYVLAVLLFLALILQRTFLQASYYVTIETGINLRGALLAMIYNKILRLSTSNLSMGEMTLGQINNLVAIETNQLMWFLFLCPNLWAMPVQIIMGVILLYNLLGSSALVGAAVIVLLAPIQYFIATKLAEAQKSTLDYSTERLKKTNEILKGIKLLKLYAWEHIFCKSVEETRMKELSSLKTFALYTSLSIFMNAAIPIAAVLATFVTHAYASGNNLKPAEAFASLSLFHILVTPLFLLSTVVRFAVKAIISVQKLNEFLLSDEIGDDSWRTGESSLPFESCKKHTGVQPKTINRKQPGRYHLDSYEQSTRRLRPAETEDIAIKVTNGYFSWGSGLATLSNIDIRIPTGQLTMIVGQVGCGKSSLLLAILGEMQTLEGKVHWSNVNESEPSFEATRSRNRYSVAYAAQKPWLLNATVEENITFGSPFNKQRYKAVTDACSLQPDIDLLPFGDQTEIGERGINLSGGQRQRICVARALYQNTNIVFLDDPFSALDIHLSDHLMQEGILKFLQDDKRTLVLVTHKLQYLTHADWIIAMKDGSVLREGTLKDIQTKDVELYEHWKTLMNRQDQELEKDMEADQTTLERKTLRRAMYSREAKAQMEDEDEEEEEEEDEDDNMSTVMRLRTKMPWKTCWRYLTSGGFFLLILMIFSKLLKHSVIVAIDYWLATWTSEYSINNTGKADQTYYVAGFSILCGAGIFLCLVTSLTVEWMGLTAAKNLHHNLLNKIILGPIRFFDTTPLGLILNRFSADTNIIDQHIPPTLESLTRSTLLCLSAIGMISYATPVFLVALLPLGVAFYFIQKYFRVASKDLQELDDSTQLPLLCHFSETAEGLTTIRAFRHETRFKQRMLELTDTNNIAYLFLSAANRWLEVRTDYLGACIVLTASIASISGSSNSGLVGLGLLYALTITNYLNWVVRNLADLEVQMGAVKKVNSFLTMESENYEGTMDPSQVPEHWPQEGEIKIHDLCVRYENNLKPVLKHVKAYIKPGQKVGICGRTGSGKSSLSLAFFRMVDIFDGKIVIDGIDISKLPLHTLRSRLSIILQDPILFSGSIRFNLDPECKCTDDRLWEALEIAQLKNMVKSLPGGLDAVVTEGGENFSVGQRQLFCLARAFVRKSSILIMDEATASIDMATENILQKVVMTAFADRTVVTIAHRVSSIMDAGLVLVFSEGILVECDTVPNLLAHKNGLFSTLVMTNK"    
    drug_smiles = "CC1(C)Oc2ccc(C#N)cc2[C@@H](N2CCCC2=O)[C@@H]1O"
    uniprot_id = "O60706"
    try:
        results = predict_manual_input(protein_seq, drug_smiles, model, protein_builder, 
                                        drug_builder, dataset.side_effect_mapping, uniprot_id, device)
        if results:
            binding_prob, side_effect_probs = results
            interpret_predictions(binding_prob, side_effect_probs, dataset.side_effect_mapping)
    except Exception as e:
        print(f"‚ùå Error during prediction: {e}")

# Start interactive mode
interactive_prediction()

üîß Loading best model for predictions...

üéÆ INTERACTIVE PREDICTION MODE
Enter 'quit' to exit
üß™ Predicting for:
   üß¨ Protein: MSLSFCGNNISSYNINDGVLQNSCFVDALNLVPHVFLLFITFPILFIGWG...
   üíä Drug: CC1(C)Oc2ccc(C#N)cc2[C@@H](N2CCCC2=O)[C@@H]1O

üéØ PREDICTION RESULTS

üíä DRUG-TARGET BINDING PREDICTION:
   Probability: 0.5851
   üü¢ PREDICTION: LIKELY BINDS (confidence: 58.5%)
   üìä Binding Strength: Weak

ü©∫ PREDICTED SIDE EFFECTS (top 10):
    1. UTI                                      0.8925 (High confidence)
    2. Gas                                      0.8771 (High confidence)
    3. Nervous                                  0.8348 (High confidence)
    4. Nausea                                   0.8026 (High confidence)
    5. Gastrointestinal disorder                0.7564 (Medium confidence)
    6. Vomiting                                 0.7101 (Medium confidence)
    7. Rash                                     0.6810 (Medium confidence)
    8. Headache         

# This part is different

In [90]:
print(dataset[0]['protein_data'])

Data(x=[554, 22], edge_index=[2, 5418], uniprot_id='O15245', sequence='MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTPDHHCQSPGVAELSQRCGWSPAEELNYTVPGLGPAGEAFLGQCRRYEVDWNQSALSCVDPLASLATNRSHLPLGPCQDGWVYDTPGSSIVTEFNLVCADSWKLDLFQSCLNAGFLFGSLGVGYFADRFGRKLCLLGTVLVNAVSGVLMAFSPNYMSMLLFRLLQGLVSKGNWMAGYTLITEFVGSGSRRTVAIMYQMAFTVGLVALTGLAYALPHWRWLQLAVSLPTFLFLLYYWCVPESPRWLLSQKRNTEAIKIMDHIAQKNGKLPPADLKMLSLEEDVTEKLSPSFADLFRTPRLRKRTFILMYLWFTDSVLYQGLILHMGATSGNLYLDFLYSALVEIPGAFIALITIDRVGRIYPMAMSNLLAGAACLVMIFISPDLHWLNIIIMCVGRMGITIAIQMICLVNAELYPTFVRNLGVMVCSSLCDIGGIITPFIVFRLREVWQALPLILFAVLGLLAAGVTLLLPETKGVALPETMKDAENLGRKAKPKENTIYLKVQTSEPSGT', batch=[554])


In [70]:
class GVP(nn.Module):
    """Geometric Vector Perceptron - FIXED VERSION"""
    def __init__(self, in_scalar, in_vector, out_scalar, out_vector):
        super().__init__()
        self.in_scalar = in_scalar
        self.in_vector = in_vector
        self.out_scalar = out_scalar
        self.out_vector = out_vector
        
        # Scalar pathway (includes vector norms)
        self.W_h = nn.Linear(in_scalar + in_vector, out_scalar)
        
        # Vector pathway - FIXED: input should be [..., 3, in_vector]
        self.W_v = nn.Linear(in_vector, out_vector, bias=False)
        
    def forward(self, h, v):
        # h: [num_nodes, in_scalar] 
        # v: [num_nodes, 3, in_vector]
        
        # Compute vector norms and incorporate into scalar features
        v_norm = torch.norm(v, dim=1)  # [num_nodes, in_vector]
        
        # Enhanced scalar features
        h_input = torch.cat([h, v_norm], dim=-1)
        h_out = F.relu(self.W_h(h_input))
        
        # Transform vector features - FIXED
        # v shape: [num_nodes, 3, in_vector] -> we want to transform the vector dimension
        v_reshaped = v.reshape(-1, self.in_vector)  # [num_nodes * 3, in_vector]
        v_transformed = self.W_v(v_reshaped)  # [num_nodes * 3, out_vector]
        v_out = v_transformed.reshape(-1, 3, self.out_vector)  # [num_nodes, 3, out_vector]
        
        return h_out, v_out

class GVPBlock(nn.Module):
    def __init__(self, scalar_dim, vector_dim):
        super().__init__()
        self.gvp = GVP(scalar_dim, vector_dim, scalar_dim, vector_dim)
        self.ln_scalar = nn.LayerNorm(scalar_dim)
        self.ln_vector = nn.LayerNorm(vector_dim)
        
    def forward(self, h, v):
        h_res, v_res = h, v
        h, v = self.gvp(h, v)
        h = self.ln_scalar(h + h_res)
        v = self.ln_vector(v + v_res)
        return h, v

class DrugGVPEncoder(nn.Module):
    def __init__(self, scalar_dim, vector_neighbor_dim, hidden_dim=128, num_layers=3):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vector_neighbor_dim = vector_neighbor_dim
        
        # Scalar feature projection
        self.scalar_proj = nn.Linear(scalar_dim, hidden_dim)
        
        # Vector feature projection - FIXED: input should be neighbors * coordinates = 4 * 3 = 12
        total_vector_features = vector_neighbor_dim * 3  # 4 neighbors * 3 coordinates = 12
        self.vector_proj = nn.Linear(total_vector_features, hidden_dim)
        
        # GVP blocks
        self.gvp_blocks = nn.ModuleList([
            GVPBlock(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])
        
        # Readout layer
        self.readout = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        print(f"DrugGVPEncoder initialized:")
        print(f"  scalar_dim: {scalar_dim}")
        print(f"  vector_neighbor_dim: {vector_neighbor_dim}")
        print(f"  total_vector_features: {total_vector_features}")
        print(f"  vector_proj: {self.vector_proj}")
        
    def forward(self, scalar_feats, vector_feats, batch_mask=None):
        # scalar_feats: [58, 14]
        # vector_feats: [58, 4, 3]
        
        # Project scalar features
        h = self.scalar_proj(scalar_feats)  # [58, 128]
        
        # Project vector features - CORRECT
        v_flat = vector_feats.reshape(vector_feats.size(0), -1)  # [58, 4*3] = [58, 12]
        v_projected = self.vector_proj(v_flat)  # [58, 128]
        
        # Create vector representation for GVP
        v = v_projected.unsqueeze(1).repeat(1, 3, 1)  # [58, 3, 128]
        
        # Apply GVP blocks
        for gvp_block in self.gvp_blocks:
            h, v = gvp_block(h, v)
        
        # Readout
        v_norm = torch.norm(v, dim=1)  # [58, 128]
        combined = torch.cat([h, v_norm], dim=-1)  # [58, 256]
        
        # Global pooling
        if batch_mask is not None:
            drug_embedding = torch.zeros(batch_mask.max() + 1, combined.size(-1), device=combined.device)
            for i in range(batch_mask.max() + 1):
                mask = (batch_mask == i)
                if mask.any():
                    drug_embedding[i] = combined[mask].mean(dim=0)
        else:
            drug_embedding = combined.mean(dim=0, keepdim=True)  # [1, 256]
        
        drug_embedding = self.readout(drug_embedding)  # [1, 128]
        return drug_embedding

In [68]:
cfg.update_from_data(
    protein_feat_dim=22,           # From protein graph: [931, 22]
    drug_scalar_dim=14,            # From drug scalar: [58, 14]  
    drug_vector_neighbor_dim=4,    # From drug vector: [58, 4, 3]
    drug_vector_coord_dim=3        # From drug vector: [58, 4, 3]
)

Configuration updated from data:
  protein_feat_dim: 22
  drug_scalar_dim: 14
  drug_vector_neighbor_dim: 4
  drug_vector_coord_dim: 3


In [71]:
class GCN_GVP_Model(nn.Module):
    def __init__(self, protein_feat_dim, drug_scalar_dim, drug_vector_dim, 
                 hidden_dim=128, num_side_effects=4817):
        super().__init__()
        
        # Encoders
        self.protein_encoder = ProteinGCNEncoder(protein_feat_dim, hidden_dim)
        self.drug_encoder = DrugGVPEncoder(drug_scalar_dim, drug_vector_dim, hidden_dim)
        
        self.num_side_effects = num_side_effects
        
        # Interaction module
        self.interaction_net = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim * 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Prediction heads
        self.binding_head = nn.Sequential(
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        
        # Simpler side effect head to avoid overfitting
        self.side_effect_head = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_side_effects),
            nn.Sigmoid()
        )
        
        print(f"Model initialized with:")
        print(f"  - Protein feature dim: {protein_feat_dim}")
        print(f"  - Drug scalar dim: {drug_scalar_dim}")
        print(f"  - Drug vector dim: {drug_vector_dim}")
        print(f"  - Hidden dim: {hidden_dim}")
        print(f"  - Number of side effects: {self.num_side_effects}")
        
    def forward(self, protein_data, drug_data):
        # Encode protein
        protein_embedding = self.protein_encoder(
            protein_data.x, 
            protein_data.edge_index,
            protein_data.batch if hasattr(protein_data, 'batch') else None
        )
        
        # Encode drug
        drug_embedding = self.drug_encoder(
            drug_data['scalar_feats'],
            drug_data['vector_feats'],
            drug_data['batch'] if 'batch' in drug_data else None
        )
        
        # Make sure embeddings have the same batch size
        if protein_embedding.dim() == 1:
            protein_embedding = protein_embedding.unsqueeze(0)
        if drug_embedding.dim() == 1:
            drug_embedding = drug_embedding.unsqueeze(0)
            
        # Interaction features
        interaction_features = torch.cat([
            protein_embedding,
            drug_embedding,
            protein_embedding * drug_embedding  # Element-wise interaction
        ], dim=1)
        
        # Combined representation
        combined_rep = self.interaction_net(interaction_features)
        
        # Predictions
        binding_score = self.binding_head(combined_rep).squeeze()
        side_effects = self.side_effect_head(combined_rep)
        
        return binding_score, side_effects

# Reinitialize the model with fixed GVP
print("Reinitializing model with fixed GVP...")
model = GCN_GVP_Model(
    protein_feat_dim=cfg.protein_feat_dim,
    drug_scalar_dim=cfg.drug_scalar_dim, 
    drug_vector_dim=cfg.drug_vector_neighbor_dim,  # This is actually the neighbor count (4)
    hidden_dim=cfg.hidden_dim,
    num_side_effects=4817
)

# Test the model again
print("\nTesting fixed model forward pass...")
sample_protein = sample['protein_data']
sample_drug = sample['drug_data']

# Add batch dimension
sample_protein.batch = torch.zeros(sample_protein.x.size(0), dtype=torch.long)
sample_drug['batch'] = torch.zeros(sample_drug['scalar_feats'].size(0), dtype=torch.long)

print(f"Input shapes:")
print(f"  Protein features: {sample_protein.x.shape}")
print(f"  Drug scalar features: {sample_drug['scalar_feats'].shape}")
print(f"  Drug vector features: {sample_drug['vector_feats'].shape}")

# Forward pass
with torch.no_grad():
    binding_pred, side_effect_pred = model(sample_protein, sample_drug)

print(f"\nOutput shapes:")
print(f"Binding prediction: {binding_pred.shape} - {binding_pred.item():.4f}")
print(f"Side effect prediction: {side_effect_pred.shape}")
print(f"Actual binding label: {sample['binding_label'].item()}")
print(f"Actual side effects sum: {sample['side_effects'].sum().item()}")

Reinitializing model with fixed GVP...
DrugGVPEncoder initialized:
  scalar_dim: 14
  vector_neighbor_dim: 4
  total_vector_features: 12
  vector_proj: Linear(in_features=12, out_features=128, bias=True)
Model initialized with:
  - Protein feature dim: 22
  - Drug scalar dim: 14
  - Drug vector dim: 4
  - Hidden dim: 128
  - Number of side effects: 4817

Testing fixed model forward pass...
Input shapes:
  Protein features: torch.Size([931, 22])
  Drug scalar features: torch.Size([58, 14])
  Drug vector features: torch.Size([58, 4, 3])

Output shapes:
Binding prediction: torch.Size([]) - 0.5125
Side effect prediction: torch.Size([1, 4817])
Actual binding label: 1.0
Actual side effects sum: 77.0


In [72]:
from torch_geometric.data import Batch as PyGBatch

def collate_fn(batch):
    """Custom collate function for our heterogeneous data"""
    protein_data_list = []
    drug_data_list = []
    binding_labels = []
    side_effects_list = []
    
    protein_batch = []
    drug_batch = []
    
    for i, sample in enumerate(batch):
        # Protein data
        protein_data = sample['protein_data']
        protein_data.batch = torch.full((protein_data.x.size(0),), i, dtype=torch.long)
        protein_data_list.append(protein_data)
        
        # Drug data  
        drug_data = sample['drug_data'].copy()
        drug_data['batch'] = torch.full((drug_data['scalar_feats'].size(0),), i, dtype=torch.long)
        drug_data_list.append(drug_data)
        
        # Labels
        binding_labels.append(sample['binding_label'])
        side_effects_list.append(sample['side_effects'])
    
    # Batch protein data using PyG's Batch
    batched_protein = PyGBatch.from_data_list(protein_data_list)
    
    # Batch drug data manually
    batched_drug = {
        'scalar_feats': torch.cat([d['scalar_feats'] for d in drug_data_list], dim=0),
        'vector_feats': torch.cat([d['vector_feats'] for d in drug_data_list], dim=0),
        'batch': torch.cat([d['batch'] for d in drug_data_list], dim=0)
    }
    
    batched_binding = torch.stack(binding_labels)
    batched_side_effects = torch.stack(side_effects_list)
    
    return batched_protein, batched_drug, batched_binding, batched_side_effects

In [74]:
# Create data loaders with our working collate function
from torch.utils.data import DataLoader, random_split

# Split dataset
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

print(f"Dataset splits: Train {len(train_dataset)}, Val {len(val_dataset)}, Test {len(test_dataset)}")

# Create data loaders
train_loader = DataLoader(
    train_dataset, 
    batch_size=cfg.batch_size, 
    shuffle=True, 
    collate_fn=collate_fn,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset, 
    batch_size=cfg.batch_size, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=0
)

test_loader = DataLoader(
    test_dataset, 
    batch_size=cfg.batch_size, 
    shuffle=False, 
    collate_fn=collate_fn,
    num_workers=0
)

# Test the data loader with a batch
print("\nTesting data loader with batch...")
for batch_idx, (protein_batch, drug_batch, binding_batch, se_batch) in enumerate(train_loader):
    print(f"Batch {batch_idx}:")
    print(f"  Protein batch: {protein_batch}")
    print(f"  Drug scalar batch: {drug_batch['scalar_feats'].shape}")
    print(f"  Drug vector batch: {drug_batch['vector_feats'].shape}")
    print(f"  Binding labels: {binding_batch.shape}")
    print(f"  Side effects: {se_batch.shape}")
    
    # Test model with batch
    with torch.no_grad():
        batch_binding_pred, batch_se_pred = model(protein_batch, drug_batch)
    
    print(f"  Model outputs - Binding: {batch_binding_pred.shape}, Side effects: {batch_se_pred.shape}")
    
    if batch_idx == 0:  # Just test first batch
        break

Dataset splits: Train 27792, Val 3474, Test 3475

Testing data loader with batch...
Batch 0:
  Protein batch: DataBatch(x=[22348, 22], edge_index=[2, 190314], uniprot_id=[32], sequence=[32], batch=[22348], ptr=[33])
  Drug scalar batch: torch.Size([1851, 14])
  Drug vector batch: torch.Size([1851, 4, 3])
  Binding labels: torch.Size([32])
  Side effects: torch.Size([32, 4817])
  Model outputs - Binding: torch.Size([32]), Side effects: torch.Size([32, 4817])


In [73]:
class MultiTaskLoss(nn.Module):
    def __init__(self, binding_weight=0.7, side_effect_weight=0.3, pos_weight=None):
        super().__init__()
        self.binding_weight = binding_weight
        self.side_effect_weight = side_effect_weight
        
        # For binding classification
        self.binding_bce = nn.BCELoss()
        
        # For multi-label side effects
        if pos_weight is not None:
            self.side_effect_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        else:
            self.side_effect_bce = nn.BCEWithLogitsLoss()
    
    def forward(self, binding_pred, side_effect_pred, binding_true, side_effect_true):
        binding_loss = self.binding_bce(binding_pred, binding_true)
        
        # Side effect loss - only compute if we have side effect labels
        if side_effect_pred.numel() > 0 and side_effect_true.numel() > 0:
            side_effect_loss = self.side_effect_bce(side_effect_pred, side_effect_true)
        else:
            side_effect_loss = torch.tensor(0.0, device=binding_pred.device)
        
        total_loss = (self.binding_weight * binding_loss + 
                     self.side_effect_weight * side_effect_loss)
        
        return total_loss, binding_loss, side_effect_loss

In [75]:
import torch.optim as optim
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

class Trainer:
    def __init__(self, model, train_loader, val_loader, criterion, optimizer, device):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.criterion = criterion
        self.optimizer = optimizer
        self.device = device
        self.model.to(device)
        
    def train_epoch(self):
        self.model.train()
        total_loss = 0
        binding_preds = []
        binding_targets = []
        
        for batch_idx, (protein_data, drug_data, binding_labels, se_labels) in enumerate(self.train_loader):
            # Move data to device
            protein_data = protein_data.to(self.device)
            drug_data = {
                'scalar_feats': drug_data['scalar_feats'].to(self.device),
                'vector_feats': drug_data['vector_feats'].to(self.device),
                'batch': drug_data['batch'].to(self.device)
            }
            binding_labels = binding_labels.to(self.device)
            se_labels = se_labels.to(self.device)
            
            # Forward pass
            self.optimizer.zero_grad()
            binding_pred, se_pred = self.model(protein_data, drug_data)
            
            # Calculate loss
            loss, binding_loss, se_loss = self.criterion(binding_pred, se_pred, binding_labels, se_labels)
            
            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            
            total_loss += loss.item()
            binding_preds.extend(binding_pred.detach().cpu().numpy())
            binding_targets.extend(binding_labels.detach().cpu().numpy())
            
            if batch_idx % 100 == 0:
                print(f'  Batch {batch_idx}, Loss: {loss.item():.4f}')
        
        # Calculate metrics
        binding_auc = roc_auc_score(binding_targets, binding_preds)
        binding_ap = average_precision_score(binding_targets, binding_preds)
        
        return total_loss / len(self.train_loader), binding_auc, binding_ap
    
    def validate(self):
        self.model.eval()
        total_loss = 0
        binding_preds = []
        binding_targets = []
        se_preds = []
        se_targets = []
        
        with torch.no_grad():
            for protein_data, drug_data, binding_labels, se_labels in self.val_loader:
                # Move data to device
                protein_data = protein_data.to(self.device)
                drug_data = {
                    'scalar_feats': drug_data['scalar_feats'].to(self.device),
                    'vector_feats': drug_data['vector_feats'].to(self.device),
                    'batch': drug_data['batch'].to(self.device)
                }
                binding_labels = binding_labels.to(self.device)
                se_labels = se_labels.to(self.device)
                
                # Forward pass
                binding_pred, se_pred = self.model(protein_data, drug_data)
                
                # Calculate loss
                loss, binding_loss, se_loss = self.criterion(binding_pred, se_pred, binding_labels, se_labels)
                
                total_loss += loss.item()
                binding_preds.extend(binding_pred.detach().cpu().numpy())
                binding_targets.extend(binding_labels.detach().cpu().numpy())
                se_preds.extend(se_pred.detach().cpu().numpy())
                se_targets.extend(se_labels.detach().cpu().numpy())
        
        # Calculate metrics
        binding_auc = roc_auc_score(binding_targets, binding_preds)
        binding_ap = average_precision_score(binding_targets, binding_preds)
        
        # Side effect metrics (simplified)
        se_preds = np.array(se_preds)
        se_targets = np.array(se_targets)
        se_auc = roc_auc_score(se_targets.ravel(), se_preds.ravel()) if len(se_targets) > 0 else 0.0
        
        return total_loss / len(self.val_loader), binding_auc, binding_ap, se_auc

# Initialize training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Optimizer and loss
optimizer = optim.Adam(model.parameters(), lr=cfg.learning_rate, weight_decay=1e-4)
criterion = MultiTaskLoss(binding_weight=0.7, side_effect_weight=0.3)

# Trainer
trainer = Trainer(model, train_loader, val_loader, criterion, optimizer, device)

print("Training setup complete!")

Using device: cuda
Training setup complete!


In [76]:
# Training loop
def train_model(num_epochs=50):
    best_val_auc = 0
    patience = 10
    patience_counter = 0
    
    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        
        # Train
        train_loss, train_binding_auc, train_binding_ap = trainer.train_epoch()
        
        # Validate
        val_loss, val_binding_auc, val_binding_ap, val_se_auc = trainer.validate()
        
        print(f'Train Loss: {train_loss:.4f}, Train AUC: {train_binding_auc:.4f}')
        print(f'Val Loss: {val_loss:.4f}, Val AUC: {val_binding_auc:.4f}, Val SE AUC: {val_se_auc:.4f}')
        
        # Early stopping
        if val_binding_auc > best_val_auc:
            best_val_auc = val_binding_auc
            patience_counter = 0
            # Save best model
            torch.save(model.state_dict(), 'best_model.pth')
            print(f'New best model saved with AUC: {best_val_auc:.4f}')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'Early stopping after {epoch+1} epochs')
                break

# Start training
print("Starting training...")
train_model(num_epochs=1)

Starting training...

Epoch 1/1
  Batch 0, Loss: 0.7785


[16:31:51] UFFTYPER: Unrecognized charge state for atom: 14
[16:33:37] UFFTYPER: Unrecognized charge state for atom: 12
[16:33:51] UFFTYPER: Unrecognized charge state for atom: 0
[16:33:51] UFFTYPER: Unrecognized atom type: Zn+2 (0)
[16:34:32] UFFTYPER: Unrecognized atom type: Ca+2 (0)
[16:34:32] UFFTYPER: Unrecognized atom type: Ca+2 (0)
[16:34:32] UFFTYPER: Unrecognized atom type: Ca+2 (0)
[16:39:53] UFFTYPER: Unrecognized charge state for atom: 20
[16:40:48] UFFTYPER: Unrecognized charge state for atom: 4
[16:45:12] UFFTYPER: Unrecognized charge state for atom: 14


KeyboardInterrupt: 