In [132]:
#!pip install graphein MDAnalysis torch_geometric torchmetrics wandb
#!apt-get install dssp

In [133]:
import os
import random
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import h5py
import collections
import wandb

import networkx as nx
import numpy as np
import pandas as pd

from tqdm import tqdm
from functools import partial
from enum import Enum

from torch.utils.data import DataLoader
from torch_geometric.data import Dataset, Batch, Data
from torch_geometric.nn import GATv2Conv, global_mean_pool
from torch_geometric.utils import to_undirected

import MDAnalysis as mda
from MDAnalysis.lib.distances import calc_dihedrals

from graphein.protein.config import ProteinGraphConfig, DSSPConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import (
    add_aromatic_interactions, add_disulfide_interactions,
    add_hydrogen_bond_interactions, add_peptide_bonds,
    add_hydrophobic_interactions, add_ionic_interactions,
    add_k_nn_edges, add_distance_threshold
)
from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot
from graphein.protein.features.nodes import asa, rsa
from graphein.protein.features.nodes.dssp import secondary_structure

In [134]:
# ============================================================================
# GLOBAL CONFIGURATION
# ============================================================================

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Dataset location. The default is the original Colab/Drive mount point; set the
# PROTEIN_TRIPLETS_DATA_ROOT environment variable to run the notebook elsewhere
# without editing this cell.
DATA_ROOT = os.environ.get(
    "PROTEIN_TRIPLETS_DATA_ROOT",
    "/content/drive/MyDrive/protein_triplets_data",
)

print(f"Device: {DEVICE}")

Device: cuda


In [135]:
# ============================================================================
# CONFIGURATION CLASS
# ============================================================================

class ProjectConfig:
    """Centralized configuration for features, edges, and hyperparameters.

    Every setting is a plain instance attribute, so ``vars(cfg)`` yields a flat
    dictionary of the whole configuration. ``input_dim`` is derived from the
    node-feature switches and is therefore read-only.
    """

    def __init__(self):
        # Node Features
        self.use_coords = False
        self.use_b_factor = False
        self.use_amino_acid = True
        self.use_asa = False
        self.use_rsa = False
        self.use_ss = False
        self.use_backbone_dh = False
        self.use_sidechain_dh = False
        self.use_embedding = False
        self.esm_dim = 1280

        # Edge Types
        self.edge_peptide = False
        self.edge_aromatic = False
        self.edge_disulfide = False
        self.edge_hydrogen = False
        self.edge_hydrophobic = False
        self.edge_ionic = False
        self.edge_knn = True
        self.edge_distance = False
        self.knn_k = 10
        # Spatial cutoff used for distance-threshold edges.
        self.dist_threshold = 8.0
        # Sequence-separation cutoff forwarded to ``add_distance_threshold``.
        # NOTE(review): 1 is chosen so that only a residue paired with itself is
        # filtered out — confirm against the installed Graphein version.
        self.dist_long_interaction_threshold = 1

        # Model Hyperparameters
        self.hidden_dim = 64
        self.output_dim = 256
        self.heads = 4
        self.dropout = 0.1
        self.lr = 0.0005
        self.batch_size = 8
        self.epochs = 50
        self.margin = 0.2

    @property
    def input_dim(self):
        """Calculate input dimension based on active features.

        Returns:
            int: sum of the widths of every enabled node feature (coords 3,
            b-factor 1, amino acid 20, ASA 1, RSA 1, secondary structure 8,
            backbone dihedrals 3, side-chain dihedrals 5, embedding ``esm_dim``).
        """
        dim = 0
        if self.use_coords: dim += 3
        if self.use_b_factor: dim += 1
        if self.use_amino_acid: dim += 20
        if self.use_asa: dim += 1
        if self.use_rsa: dim += 1
        if self.use_ss: dim += 8
        if self.use_backbone_dh: dim += 3
        if self.use_sidechain_dh: dim += 5
        if self.use_embedding: dim += self.esm_dim
        return dim

    def get_active_edge_funcs(self):
        """Return list of active edge construction functions.

        Parameterised functions (k-NN, distance threshold) are returned as
        ``functools.partial`` objects with the configured values bound.
        """
        edge_funcs = []
        if self.edge_peptide: edge_funcs.append(add_peptide_bonds)
        if self.edge_aromatic: edge_funcs.append(add_aromatic_interactions)
        if self.edge_disulfide: edge_funcs.append(add_disulfide_interactions)
        if self.edge_hydrogen: edge_funcs.append(add_hydrogen_bond_interactions)
        if self.edge_hydrophobic: edge_funcs.append(add_hydrophobic_interactions)
        if self.edge_ionic: edge_funcs.append(add_ionic_interactions)
        if self.edge_knn: edge_funcs.append(partial(add_k_nn_edges, k=self.knn_k))
        if self.edge_distance:
            # In Graphein's signature the spatial cutoff is ``threshold``;
            # ``long_interaction_threshold`` is a sequence-separation cutoff.
            # Binding dist_threshold to the latter left the spatial cutoff at
            # the library default, so both are now passed explicitly.
            edge_funcs.append(partial(
                add_distance_threshold,
                long_interaction_threshold=self.dist_long_interaction_threshold,
                threshold=self.dist_threshold,
            ))
        return edge_funcs

    def get_active_node_metadata_funcs(self):
        """Return the Graphein node-metadata functions for the selected node features."""
        node_funcs = []

        if self.use_amino_acid:
            node_funcs.append(amino_acid_one_hot)

        if self.use_asa:
            node_funcs.append(asa)

        if self.use_rsa:
            node_funcs.append(rsa)

        if self.use_ss:
            node_funcs.append(secondary_structure)

        return node_funcs

    def get_node_attributes_list(self):
        """Return list of active node attributes.

        The names are the node-data keys read when the feature matrix is
        assembled; their order matches the order used by ``input_dim``.
        """
        attrs = []
        if self.use_coords: attrs.append("coords")
        if self.use_b_factor: attrs.append("b_factor")
        if self.use_amino_acid: attrs.append("amino_acid_one_hot")
        if self.use_asa: attrs.append("asa")
        if self.use_rsa: attrs.append("rsa")
        if self.use_ss: attrs.append("ss")
        if self.use_backbone_dh: attrs.append("backbone_dihedral_radians")
        if self.use_sidechain_dh: attrs.append("sidechain_dihedral_radians")
        if self.use_embedding: attrs.append("embedding")
        return attrs

In [None]:
# ============================================================================
# DATA MAPPER
# ============================================================================

class TripletDataPathMapper:
    """Maps protein file structure to anchor-positive-negative triplets with train/val/test splits.

    Expected layout under ``root_dir``::

        originals/<protein_id>.pdb
        positives/<protein_id>/*.pdb
        negatives/<protein_id>/*.pdb

    A protein family is kept only when it has at least one positive and one
    negative structure. Families (not individual files) are what gets split.
    """

    def __init__(self, root_dir, train_ratio=0.9, val_ratio=0.05, test_ratio=0.05, seed=42):
        """Scan ``root_dir`` and split the discovered families.

        Args:
            root_dir: directory containing ``originals``/``positives``/``negatives``.
            train_ratio: fraction of families assigned to the training split.
            val_ratio: fraction of families assigned to the validation split.
            test_ratio: stored for reference; the test split is whatever
                remains after the train and validation splits are taken.
            seed: seed for the shuffle that precedes the split.
        """
        self.root_dir = root_dir
        self.triplets = []
        self.train_ratio = train_ratio
        self.val_ratio = val_ratio
        self.test_ratio = test_ratio
        self.seed = seed

        # Split indices
        self.train_triplets = []
        self.val_triplets = []
        self.test_triplets = []

        self._map_data()
        self._split_data()

    def _map_data(self):
        """Populate ``self.triplets`` with one entry per usable protein family."""
        # glob returns files in filesystem order, which differs between machines;
        # sorting makes the seeded shuffle in _split_data reproducible.
        originals = sorted(glob.glob(os.path.join(self.root_dir, 'originals', "*.pdb")))

        for anchor in originals:
            prot_id = os.path.splitext(os.path.basename(anchor))[0]
            pos_dir = os.path.join(self.root_dir, 'positives', prot_id)
            neg_dir = os.path.join(self.root_dir, 'negatives', prot_id)

            p_files = sorted(glob.glob(os.path.join(pos_dir, "*.pdb")))
            n_files = sorted(glob.glob(os.path.join(neg_dir, "*.pdb")))

            # Families without both a positive and a negative cannot form a triplet.
            if p_files and n_files:
                self.triplets.append({
                    'anchor': anchor,
                    'positives': p_files,
                    'negatives': n_files,
                    'protein_id': prot_id
                })

        print(f"Found {len(self.triplets)} protein families")

    def _split_data(self):
        """Split protein families into train/val/test sets.

        Note: this seeds the module-level ``random`` generator, so later
        ``random`` calls in the same process are affected as well.
        """
        random.seed(self.seed)

        # Shuffle protein families
        indices = list(range(len(self.triplets)))
        random.shuffle(indices)

        n_total = len(indices)
        n_train = int(n_total * self.train_ratio)
        n_val = int(n_total * self.val_ratio)

        train_indices = indices[:n_train]
        val_indices = indices[n_train:n_train + n_val]
        # Everything left over (including rounding remainders) becomes the test split.
        test_indices = indices[n_train + n_val:]

        self.train_triplets = [self.triplets[i] for i in train_indices]
        self.val_triplets = [self.triplets[i] for i in val_indices]
        self.test_triplets = [self.triplets[i] for i in test_indices]

        print(f"Split: Train={len(self.train_triplets)}, Val={len(self.val_triplets)}, Test={len(self.test_triplets)}")
        print(f"Train proteins: {[t['protein_id'] for t in self.train_triplets]}")
        print(f"Val proteins: {[t['protein_id'] for t in self.val_triplets]}")
        print(f"Test proteins: {[t['protein_id'] for t in self.test_triplets]}")

    def get_split(self, split='train'):
        """Get triplets for a specific split.

        Args:
            split: one of ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: if ``split`` is not one of the three known names.
        """
        if split == 'train':
            return self.train_triplets
        elif split == 'val':
            return self.val_triplets
        elif split == 'test':
            return self.test_triplets
        else:
            raise ValueError(f"Unknown split: {split}. Use 'train', 'val', or 'test'")

In [None]:
# ============================================================================
# DATASET
# ============================================================================

class TripletProteinGraphDataset(Dataset):
    """PyTorch Geometric Dataset for protein triplets"""

    def __init__(self, mapper, root, config: ProjectConfig, split='train', esm2_embedding_path: str = None, force: bool = False):
        """Prepare the dataset for one split and trigger processing.

        Args:
            mapper: object exposing ``get_split(split)`` and ``triplets``.
            root: dataset root; processed graphs live in ``<root>/processed``.
            config: feature/edge configuration used to build the graphs.
            split: which split's triplets are served by ``get``.
            esm2_embedding_path: optional HDF5 file with per-residue embeddings;
                only read when ``config.use_embedding`` is set and the file exists.
            force: when True, cached ``.pt`` files under ``<root>/processed`` are
                deleted first so that every protein is rebuilt.
        """
        self.mapper = mapper
        self.config = config
        self.split = split
        self.triplets = mapper.get_split(split) # triplets for the specified split
        self.esm2_embedding_path = esm2_embedding_path

        processed_dir = os.path.join(root, "processed")
        os.makedirs(processed_dir, exist_ok=True)
        if force:
            # Clear the cache that belongs to *this* root, not the global DATA_ROOT.
            self._clear_processed_files(processed_dir)

        # Load embeddings if needed
        # NOTE(review): if use_embedding is set but the file is missing, no
        # embedding is attached to the nodes while config.input_dim still counts
        # esm_dim — confirm whether that should be an error.
        self.esm2_embeddings = {}
        if self.config.use_embedding and self.esm2_embedding_path and os.path.exists(self.esm2_embedding_path):
            self._load_embeddings(self.esm2_embedding_path)

        # Get active features
        self.edge_types = self.config.get_active_edge_funcs()
        self.node_metadata_funcs = self.config.get_active_node_metadata_funcs()
        self.node_attributes = self.config.get_node_attributes_list()
        self.edge_attributes = ['kind', 'edge_attr', 'euclidean_distance']

        # Store all triplets for processing (every protein is processed regardless
        # of split, so that the .pt files exist for all splits)
        self.all_triplets = mapper.triplets # triplets for all splits together

        # Must come last: the base-class constructor needs the attributes above.
        super().__init__(root)

    def _clear_processed_files(self, processed_dir=None):
        """Delete cached ``.pt`` files from ``processed_dir``.

        Args:
            processed_dir: directory to clear. Defaults to
                ``<DATA_ROOT>/processed`` for callers that pass nothing.
        """
        if processed_dir is None:
            processed_dir = os.path.join(DATA_ROOT, "processed")
        for path in glob.glob(os.path.join(processed_dir, "*.pt")):
            os.remove(path)
            
    @property
    def processed_file_names(self):
        # Process all proteins regardless of split
        unique_paths = set()
        for t in self.all_triplets:
            unique_paths.add(t['anchor'])
            unique_paths.update(t['positives'])
            unique_paths.update(t['negatives'])
        return [os.path.basename(p).replace(".pdb", ".pt") for p in unique_paths]

    @property
    def raw_file_names(self):
        return []

    def len(self) -> int:
        return len(self.triplets) * 10

    def _load_embeddings(self, path):
        """Read an HDF5 embedding file into ``self.esm2_embeddings``.

        The file is walked two levels deep: each top-level key becomes a
        dictionary mapping its child keys to NumPy arrays.
        """
        with h5py.File(path, "r") as store:
            for group_name in store.keys():
                group = store[group_name]
                per_sequence = {}
                self.esm2_embeddings[group_name] = per_sequence
                for seq_key in group.keys():
                    per_sequence[seq_key] = np.array(group[seq_key])

    def process(self):
        """Build and cache a graph for every distinct PDB path of every split.

        For each protein two files are written to the processed directory: a
        pickle of the graph object and a ``.pt`` file with the converted data
        object. Proteins that already have both files are skipped.
        """
        unique_paths = set()
        for triplet in self.all_triplets:
            unique_paths.add(triplet['anchor'])
            unique_paths.update(triplet['positives'])
            unique_paths.update(triplet['negatives'])

        print(f"Processing {len(unique_paths)} unique proteins...")

        for pdb_path in tqdm(list(unique_paths), desc="Processing"):
            stem, _ext = os.path.splitext(os.path.basename(pdb_path))
            pt_path = os.path.join(self.processed_dir, f"{stem}.pt")
            pickle_path = os.path.join(self.processed_dir, f"{stem}.pickle")

            # Both artefacts must be present for the protein to count as cached.
            already_cached = os.path.exists(pt_path) and os.path.exists(pickle_path)
            if already_cached:
                continue

            graph = self._build_graph(pdb_path)
            if graph is None:
                continue

            pyg_data = self._create_pyg_data(graph)

            with open(pickle_path, "wb") as handle:
                pickle.dump(graph, handle)
            torch.save(pyg_data, pt_path)

    def _build_graph(self, path: str):
        """Build protein graph from PDB file.

        Args:
            path: path to the ``.pdb`` file.

        Returns:
            The processed graph, or ``None`` when graph construction yields no
            nodes. Callers already treat ``None`` as "skip this protein".
        """
        config = ProteinGraphConfig(
            edge_construction_functions=self.edge_types,
            node_metadata_functions=self.node_metadata_funcs,
            verbose=False
        )

        g = construct_graph(config=config, path=path, verbose=False)
        # An empty graph has no first node to read the chain id from; report it
        # instead of failing with an IndexError.
        if g is None or g.number_of_nodes() == 0:
            print(f"Warning: no nodes built from {path}, skipping")
            return None

        # The chain id is taken from the first node id, which is split on ":"
        # (the same "<chain>:<residue>:<number>" layout _process_graph relies on).
        first_node = next(iter(g.nodes()))
        chain_id = first_node.split(":")[0]
        # splitext matches how the cache file names are derived elsewhere.
        pdb_code = os.path.splitext(os.path.basename(path))[0]
        return self._process_graph(g, chain_id, path, pdb_code)

    def _process_graph(self, g, chain_id, pdb_path, pdb_code):
        """Process graph features"""
        unique_edge_types = ["peptide_bond","aromatic","disulfide","hydrogen_bond","hydrophobic","ionic","k_nn","distance_threshold"]
        sequence = g.graph.get(f"sequence_{chain_id}")

        # Process nodes
        for index, (n, d) in enumerate(g.nodes(data=True)):
            aa = n.split(":")[1]
            d['chain_id'] = chain_id
            d['residue_name'] = aa
            d['residue_number'] = int(n.split(":")[2])

            """
            # Clean attributes
            for key in ["asa", "rsa", "ss"]:
                if isinstance(d.get(key), pd.core.series.Series):
                    val = d.get(key).dropna()
                    val = list(val[val != 0].to_dict().values()) if key != "ss" else list(val.unique())
                    d[key] = val[0] if val else (0.0 if key != "ss" else "-")

            if not d.get("asa"): d["asa"] = 0
            if not d.get("rsa"): d["rsa"] = 0.0
            if not d.get("ss"): d["ss"] = "-"

            d["ss"] = self._one_hot_encode([d["ss"]], unique_ss)[0].tolist()
            d["backbone_dihedral_radians"] = self._calc_backbone_dihedrals(pdb_path, d)
            d["sidechain_dihedral_radians"] = self._calc_sidechain_dihedrals(pdb_path, d)
            """
            if self.esm2_embeddings and sequence:
                key = f"{pdb_code}_{chain_id}"
                if key not in self.esm2_embeddings: key = pdb_code
                if key in self.esm2_embeddings:
                    d["embedding"] = self.esm2_embeddings[key][sequence][index]

        # Process edges
        for s, t, d in g.edges(data=True):
            edge_type = list(d["kind"])
            if "knn" in edge_type and len(edge_type) > 1:
                edge_type.remove("knn")

            d["edge_attr"] = [self._one_hot_encode([_type], unique_edge_types)[0].tolist() for _type in edge_type]
            d["kind"] = edge_type

            source_coords = g.nodes[s]["coords"]
            target_coords = g.nodes[t]["coords"]
            d["euclidean_distance"] = round(np.sqrt(np.sum(np.square(source_coords - target_coords))).item(), 5)

        #g = self._scale_graph(g)
        return g

    def _create_pyg_data(self, g, to_undirected_graph=True):
        """Convert NetworkX graph to PyTorch Geometric Data object.

        Node features (``x``) are the configured node attributes concatenated in
        ``self.node_attributes`` order; ``pos`` holds the node coordinates. Each
        edge contributes one row per entry of its ``kind`` list, with attributes
        made of that kind's one-hot vector followed by the Euclidean distance.

        Args:
            g: processed graph whose nodes carry ``coords`` and whose edges carry
                ``kind``, ``edge_attr`` and ``euclidean_distance``.
            to_undirected_graph: when True, add the reverse of every edge.

        Returns:
            A ``Data`` object with ``x``, ``pos``, ``edge_index`` and ``edge_attr``.
        """
        node_indexes_mapping = {}
        node_features = collections.defaultdict(list)

        for index, (n, d) in enumerate(g.nodes(data=True)):
            _list = []
            for k in self.node_attributes:
                v = d.get(k)
                # NOTE(review): a missing attribute is skipped rather than
                # padded, so rows can end up with different lengths — confirm
                # that every enabled attribute is always present.
                if v is None: continue
                if isinstance(v, (list, np.ndarray)):
                    _list.extend(list(v))
                else:
                    _list.append(v)

            node_features["x"].append(_list)
            node_features["pos"].append(d["coords"].tolist())
            node_indexes_mapping[n] = index

        edge_features = collections.defaultdict(list)
        for s, t, d in g.edges(data=True):
            # One edge row per kind, so a multi-kind edge repeats the same pair.
            for index, _ in enumerate(d["kind"]):
                edge_attr = []
                edge_features["edge_index"].append([node_indexes_mapping[s], node_indexes_mapping[t]])
                edge_attr.extend(d["edge_attr"][index])
                edge_attr.append(d["euclidean_distance"])
                edge_features["edge_attr"].append(edge_attr)

        data = Data()
        data.x = torch.tensor(node_features["x"], dtype=torch.float)
        data.pos = torch.tensor(node_features["pos"], dtype=torch.float)

        if edge_features["edge_index"]:
            data.edge_index = torch.tensor(edge_features["edge_index"], dtype=torch.long).t().contiguous()
            data.edge_attr = torch.tensor(edge_features["edge_attr"], dtype=torch.float)
        else:
            # NOTE(review): the empty edge_attr has zero columns, unlike the
            # attribute width of graphs that do have edges — confirm that
            # batching such graphs together works.
            data.edge_index = torch.empty((2, 0), dtype=torch.long)
            data.edge_attr = torch.empty((0, 0), dtype=torch.float)

        if to_undirected_graph and data.edge_index.numel() > 0:
            # to_undirected merges duplicate (source, target) pairs. The default
            # "add" reduction would sum the distance channel of multi-kind edges;
            # "max" keeps the distance and yields the union of the one-hot kinds.
            data.edge_index, data.edge_attr = to_undirected(
                data.edge_index, data.edge_attr, reduce="max"
            )

        return data

    def _one_hot_encode(self, classes, class_labels):
        encoding = np.zeros((len(classes), len(class_labels)))
        for i, class_ in enumerate(classes):
            if class_ in class_labels:
                encoding[i, class_labels.index(class_)] = 1
        return encoding
    """
    def _scale_graph(self, g, scale_attributes=None):
        if scale_attributes is None:
            scale_attributes = ["b_factor", "asa", "rsa"]

        index_to_aa = {index: n for index, n in enumerate(g.nodes(data=False))}
        for attr in scale_attributes:
            vals = [d.get(attr, 0) for _, d in g.nodes(data=True)]
            min_val, max_val = min(vals), max(vals)
            scaled_values = [
                round((val - min_val) / (max_val - min_val), 5) if max_val - min_val != 0 else 0
                for val in vals
            ]
            scaled_dict = {index_to_aa[index]: value for index, value in enumerate(scaled_values)}
            for n, d in g.nodes(data=True):
                d[attr] = scaled_dict[n]
        return g
    """
    """
    def _calc_backbone_dihedrals(self, pdb_path: str, aa_props: dict, normalize: bool = True) -> list:
        u = mda.Universe(pdb_path)
        for res in u.residues:
            if (res.resid == aa_props["residue_number"] and
                res.resname == aa_props["residue_name"] and
                res.segid == aa_props["chain_id"]):

                backbone_dihedrals_dict = {
                    "phi": res.phi_selection(),
                    "psi": res.psi_selection(),
                    "omega": res.omega_selection()
                }

                backbone_dihedral_radians = []
                for dihedral_selection in backbone_dihedrals_dict.values():
                    if dihedral_selection:
                        coords = [a.position for a in dihedral_selection.atoms]
                        radian = calc_dihedrals(coords[0], coords[1], coords[2], coords[3])
                        value = round(radian.item(), 5)
                        if normalize:
                            value = (value - (-np.pi)) / (np.pi - (-np.pi))
                        backbone_dihedral_radians.append(round(value, 5))
                    else:
                        backbone_dihedral_radians.append(0.0)

                return backbone_dihedral_radians

        return [0.0, 0.0, 0.0]
"""
    """
    def _calc_sidechain_dihedrals(self, pdb_path: str, aa_props: dict, normalize: bool = True) -> list:
        chi_atoms_dict = dict(
            chi1=dict(ARG=['N', 'CA', 'CB', 'CG'], ASN=['N', 'CA', 'CB', 'CG'], ASP=['N', 'CA', 'CB', 'CG'],
                     CYS=['N', 'CA', 'CB', 'SG'], GLN=['N', 'CA', 'CB', 'CG'], GLU=['N', 'CA', 'CB', 'CG'],
                     HIS=['N', 'CA', 'CB', 'CG'], ILE=['N', 'CA', 'CB', 'CG1'], LEU=['N', 'CA', 'CB', 'CG'],
                     LYS=['N', 'CA', 'CB', 'CG'], MET=['N', 'CA', 'CB', 'CG'], PHE=['N', 'CA', 'CB', 'CG'],
                     PRO=['N', 'CA', 'CB', 'CG'], SER=['N', 'CA', 'CB', 'OG'], THR=['N', 'CA', 'CB', 'OG1'],
                     TRP=['N', 'CA', 'CB', 'CG'], TYR=['N', 'CA', 'CB', 'CG'], VAL=['N', 'CA', 'CB', 'CG1']),
            chi2=dict(ARG=['CA', 'CB', 'CG', 'CD'], ASN=['CA', 'CB', 'CG', 'OD1'], ASP=['CA', 'CB', 'CG', 'OD1'],
                     GLN=['CA', 'CB', 'CG', 'CD'], GLU=['CA', 'CB', 'CG', 'CD'], HIS=['CA', 'CB', 'CG', 'ND1'],
                     ILE=['CA', 'CB', 'CG1', 'CD1'], LEU=['CA', 'CB', 'CG', 'CD1'], LYS=['CA', 'CB', 'CG', 'CD'],
                     MET=['CA', 'CB', 'CG', 'SD'], PHE=['CA', 'CB', 'CG', 'CD1'], PRO=['CA', 'CB', 'CG', 'CD'],
                     TRP=['CA', 'CB', 'CG', 'CD1'], TYR=['CA', 'CB', 'CG', 'CD1']),
            chi3=dict(ARG=['CB', 'CG', 'CD', 'NE'], GLN=['CB', 'CG', 'CD', 'OE1'], GLU=['CB', 'CG', 'CD', 'OE1'],
                     LYS=['CB', 'CG', 'CD', 'CE'], MET=['CB', 'CG', 'SD', 'CE']),
            chi4=dict(ARG=['CG', 'CD', 'NE', 'CZ'], LYS=['CG', 'CD', 'CE', 'NZ']),
            chi5=dict(ARG=['CD', 'NE', 'CZ', 'NH1'])
            )

        u = mda.Universe(pdb_path)
        for res in u.residues:
            if (res.resid == aa_props["residue_number"] and
                res.resname == aa_props["residue_name"] and
                res.segid == aa_props["chain_id"]):

                chi_radians = []
                for chi_res in chi_atoms_dict.values():
                    if chi_res.get(res.resname) and set(chi_res[res.resname]).issubset(set(a.name for a in res.atoms)):
                        chi_selected_atoms = dict.fromkeys(chi_res[res.resname], 1)
                        for a in res.atoms:
                            if chi_selected_atoms.get(a.name) is not None and not isinstance(chi_selected_atoms.get(a.name), np.ndarray):
                                chi_selected_atoms[a.name] = a.position

                        coords = list(chi_selected_atoms.values())
                        radian = calc_dihedrals(coords[0], coords[1], coords[2], coords[3])
                        value = round(radian.item(), 5)
                        if normalize:
                            value = (value - (-np.pi)) / (np.pi - (-np.pi))
                        chi_radians.append(round(value, 5))
                    else:
                        chi_radians.append(0.0)

                while len(chi_radians) < 5:
                    chi_radians.append(0.0)

                return chi_radians

        return [0.0] * 5
        """
    def _load_processed_graph(self, path):
        """Return the data object for ``path``, from cache when available.

        If the cached ``.pt`` file is missing, the graph is built on the fly
        (with a warning) and converted; ``None`` is returned when that build
        produces nothing usable.
        """
        pdb_code, _ext = os.path.splitext(os.path.basename(path))
        pt_path = os.path.join(self.processed_dir, f"{pdb_code}.pt")

        if not os.path.exists(pt_path):
            print(f"Warning: {pdb_code} not found, processing on the fly")
            g = self._build_graph(path)
            if g:
                return self._create_pyg_data(g)
            return None

        # weights_only=False performs a full unpickle, so only load cache files
        # that this pipeline wrote itself.
        return torch.load(pt_path, weights_only=False)

    def get(self, idx):
        real_idx = idx % len(self.triplets)
        t = self.triplets[real_idx]

        data_a = self._load_processed_graph(t["anchor"])
        data_p = self._load_processed_graph(random.choice(t["positives"]))
        data_n = self._load_processed_graph(random.choice(t["negatives"]))

        return data_a, data_p, data_n

    def download(self):
        """Intentionally a no-op: this dataset does not download anything."""
        return None

In [138]:
# ============================================================================
# MODEL
# ============================================================================

class DeepProteinGAT(nn.Module):
    """3-layer GATv2 model for protein embedding.

    Two multi-head layers with concatenated heads are followed by a
    single-head layer, mean pooling over each graph, a linear projection and
    L2 normalisation, so every graph maps to a unit-length vector.

    Args:
        input_dim: width of the node feature vectors.
        hidden_dim: per-head width of the first two layers.
        output_dim: width of the final embedding.
        heads: number of attention heads in the first two layers.
        edge_dim: width of the edge attribute vectors. The default 9 matches
            the 8-way edge-kind one-hot plus one distance value built by the
            dataset in this notebook.
        dropout: dropout value handed to every GATv2 layer. Defaults to 0.0,
            which is what was previously hard-coded.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, heads=4, edge_dim = 9, dropout=0.0):
        super().__init__()

        self.conv1 = GATv2Conv(input_dim, hidden_dim, heads=heads, concat=True, dropout=dropout, edge_dim = edge_dim)
        # concat=True makes the previous layer's output hidden_dim * heads wide.
        self.conv2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=heads, concat=True, dropout=dropout, edge_dim = edge_dim)
        self.conv3 = GATv2Conv(hidden_dim * heads, output_dim, heads=1, concat=False, dropout=dropout, edge_dim = edge_dim)
        self.projection = nn.Linear(output_dim, output_dim)

    def forward(self, data):
        """Embed a batch of graphs.

        Args:
            data: object exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.

        Returns:
            Tensor with one L2-normalised row per graph.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr , data.batch

        x = x.float()
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = F.elu(self.conv2(x, edge_index, edge_attr))
        x = self.conv3(x, edge_index, edge_attr)

        # Collapse node embeddings to one vector per graph.
        x = global_mean_pool(x, batch)
        x = self.projection(x)
        x = F.normalize(x, p=2, dim=1)
        return x

In [139]:
# ============================================================================
# UTILITIES
# ============================================================================

def triplet_collate(data_list):
    """Collate function for triplet batches.

    Args:
        data_list: samples, each an ``(anchor, positive, negative)`` tuple or
            ``None``.

    Returns:
        ``(batch_anchor, batch_positive, batch_negative)``, or ``None`` when no
        complete triplet is left after filtering.
    """
    # A sample is a tuple, so checking the sample alone never catches a failed
    # graph: any of its three members may be None. Drop those triplets too.
    complete = [
        triplet for triplet in data_list
        if triplet is not None and all(member is not None for member in triplet)
    ]
    if not complete:
        return None

    batch_a = Batch.from_data_list([x[0] for x in complete])
    batch_p = Batch.from_data_list([x[1] for x in complete])
    batch_n = Batch.from_data_list([x[2] for x in complete])

    return batch_a, batch_p, batch_n

In [None]:
# ============================================================================
# TRAINING PIPELINE
# ============================================================================

"""
def evaluate(model, loader, criterion, device):
    #Evaluate model on a given dataloader
    model.eval()
    total_loss = 0
    valid_batches = 0
    all_pos_dist = []
    all_neg_dist = []
    
    with torch.no_grad():
        for batch in loader:
            if batch is None:
                continue
            
            ba, bp, bn = batch
            ba, bp, bn = ba.to(device), bp.to(device), bn.to(device)
            
            ea = model(ba)
            ep = model(bp)
            en = model(bn)
            
            loss = criterion(ea, ep, en)
            total_loss += loss.item()
            valid_batches += 1
            
            dist_pos = F.pairwise_distance(ea, ep)
            dist_neg = F.pairwise_distance(ea, en)
            all_pos_dist.extend(dist_pos.cpu().tolist())
            all_neg_dist.extend(dist_neg.cpu().tolist())
    
    avg_loss = total_loss / valid_batches if valid_batches > 0 else 0
    avg_pos_dist = np.mean(all_pos_dist) if all_pos_dist else 0
    avg_neg_dist = np.mean(all_neg_dist) if all_neg_dist else 0
    
    return avg_loss, avg_pos_dist, avg_neg_dist
"""


def _apply_config_overrides(cfg, overrides):
    """Copy entries of ``overrides`` onto ``cfg``; keys cfg does not define are ignored."""
    if not overrides:
        return
    for key, value in overrides.items():
        # Only plain instance attributes can be overridden; derived read-only
        # properties such as input_dim would make setattr raise.
        if key in vars(cfg):
            setattr(cfg, key, value)


def _train_one_epoch(model, loader, optimizer, criterion, description):
    """Run one optimisation pass over ``loader``; return (summed loss, batches used)."""
    # Set train mode every epoch so the loop stays correct once an eval-mode
    # validation pass is interleaved.
    model.train()
    total_loss = 0
    valid_batches = 0

    pbar = tqdm(loader, desc=description, leave=False)

    for batch in pbar:
        if batch is None:
            continue

        ba, bp, bn = batch
        ba, bp, bn = ba.to(DEVICE), bp.to(DEVICE), bn.to(DEVICE)

        optimizer.zero_grad()
        ea = model(ba)
        ep = model(bp)
        en = model(bn)

        loss = criterion(ea, ep, en)
        loss.backward()
        optimizer.step()

        dist_pos = F.pairwise_distance(ea, ep)
        dist_neg = F.pairwise_distance(ea, en)

        batch_loss = loss.item()
        wandb.log({
            "batch_loss": batch_loss,
            "avg_pos_dist": dist_pos.mean().item(),
            "avg_neg_dist": dist_neg.mean().item()
        })

        total_loss += batch_loss
        valid_batches += 1
        pbar.set_postfix({'loss': batch_loss})

    return total_loss, valid_batches


def train_pipeline(config=None, force=False):
    """Main training pipeline with train/val/test splits.

    Args:
        config: optional mapping of ``ProjectConfig`` attribute names to values
            that override the defaults for this run.
        force: when True, cached processed graphs are rebuilt.

    Validation and test evaluation are currently disabled (kept as comments),
    and all data goes to the training split. The best model by training loss is
    written to ``model_best_loss.pt`` and logged as a WandB artifact.
    """

    # Initialize config and apply overrides *before* WandB sees it, so the
    # logged configuration matches what is actually trained.
    # NOTE(review): values injected by a WandB sweep (wandb.config) are not read
    # back here; only the explicit ``config`` argument is applied.
    cfg = ProjectConfig()
    _apply_config_overrides(cfg, config)

    # Initialize WandB
    wandb.init(
        project="ContVAR-Project",
        config=vars(cfg),
        reinit=True
    )

    # try/finally guarantees the run is closed on errors and early returns.
    try:
        print(f"Training with LR: {cfg.lr}, Hidden: {cfg.hidden_dim}, Heads: {cfg.heads}")

        # Load data with splits
        mapper = TripletDataPathMapper(DATA_ROOT, train_ratio=1.0, val_ratio=0.0, test_ratio=0.0)
        if not mapper.triplets:
            print("No data found!")
            return

        # Create datasets for each split
        train_dataset = TripletProteinGraphDataset(mapper, root=DATA_ROOT, config=cfg, split='train', force=force)
        # val_dataset = TripletProteinGraphDataset(mapper, root=DATA_ROOT, config=cfg, split='val', force=False)
        # test_dataset = TripletProteinGraphDataset(mapper, root=DATA_ROOT, config=cfg, split='test', force=False)

        # Create dataloaders for each split
        train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                  collate_fn=triplet_collate, num_workers=0)
        # val_loader = DataLoader(val_dataset, batch_size=cfg.batch_size, shuffle=False,
        #                         collate_fn=triplet_collate, num_workers=0)
        # test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False,
        #                          collate_fn=triplet_collate, num_workers=0)

        print(f"Dataset size - Train: {len(train_dataset)}")

        # Initialize model
        model = DeepProteinGAT(
            input_dim=cfg.input_dim,
            hidden_dim=cfg.hidden_dim,
            output_dim=cfg.output_dim,
            heads=cfg.heads
        ).to(DEVICE)

        wandb.watch(model, log="gradients", log_freq=50)

        # Optimizer and loss
        optimizer = optim.Adam(model.parameters(), lr=cfg.lr)
        criterion = nn.TripletMarginLoss(margin=cfg.margin, p=2, swap=True)

        # Training loop
        print("Starting training...")
        best_train_loss = float('inf')

        for epoch in range(cfg.epochs):
            total_loss, valid_batches = _train_one_epoch(
                model, train_loader, optimizer, criterion,
                description=f"Epoch {epoch+1}/{cfg.epochs}",
            )

            # An epoch without any usable batch has no meaningful loss. NaN
            # never compares below the best loss, so no checkpoint is written.
            avg_train_loss = total_loss / valid_batches if valid_batches > 0 else float('nan')

            # Validation phase (commented out for initial overfitting check)
            # val_loss, val_pos_dist, val_neg_dist = evaluate(model, val_loader, criterion, DEVICE)

            # Logging
            log_dict = {
                "epoch": epoch + 1,
                "train_loss": avg_train_loss,
                # "val_loss": val_loss,
                # "val_pos_dist": val_pos_dist,
                # "val_neg_dist": val_neg_dist,
                "lr": cfg.lr
            }

            # Save best model based on training loss (use val_loss when validation is enabled)
            if avg_train_loss < best_train_loss:
                best_train_loss = avg_train_loss
                model_name = "model_best_loss.pt"
                torch.save(model.state_dict(), model_name)

                artifact = wandb.Artifact(
                    name=f"ContVAR-Best-Model-{wandb.run.id}",
                    type="model",
                    description=f"Best model at epoch {epoch+1} with train_loss {avg_train_loss:.4f}"
                )
                artifact.add_file(model_name)
                wandb.log_artifact(artifact)
                log_dict["best_model_saved"] = True

            wandb.log(log_dict)
            print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} {'(Saved)' if log_dict.get('best_model_saved') else ''}")

        # Final test evaluation (commented out for initial overfitting check)
        # print("\nEvaluating on test set...")
        # model.load_state_dict(torch.load("model_best_loss.pt", weights_only=False))
        # test_loss, test_pos_dist, test_neg_dist = evaluate(model, test_loader, criterion, DEVICE)
        #
        # print(f"Test Results - Loss: {test_loss:.4f}, Pos Dist: {test_pos_dist:.4f}, Neg Dist: {test_neg_dist:.4f}")
        #
        # wandb.log({
        #     "test_loss": test_loss,
        #     "test_pos_dist": test_pos_dist,
        #     "test_neg_dist": test_neg_dist
        # })
    finally:
        wandb.finish()

    print("Training completed!")

In [141]:
# ============================================================================
# VISUALIZATION
# ============================================================================

def visualize_graph(protein_id=None):
    """Plot one processed protein graph and log it to Weights & Biases.

    The anchor structure of a triplet is selected, its pickled graph is read
    from ``<DATA_ROOT>/processed/<pdb_code>.pickle``, and an interactive
    Plotly figure plus basic graph statistics are logged to a W&B run.

    Args:
        protein_id: Optional string. When given, the first triplet whose
            ``'anchor'`` path contains this substring is used; otherwise a
            triplet is picked at random.

    Returns:
        None. Progress and errors are reported via ``print``.
    """
    from graphein.protein.visualisation import plotly_protein_structure_graph

    print("Starting visualization...")

    cfg = ProjectConfig()
    mapper = TripletDataPathMapper(DATA_ROOT)

    if not mapper.triplets:
        print("No data found!")
        return

    # Select protein
    if protein_id:
        choice = next((t for t in mapper.triplets if protein_id in t['anchor']), None)
        if not choice:
            print(f"Protein {protein_id} not found!")
            return
    else:
        choice = random.choice(mapper.triplets)

    pdb_path = choice['anchor']
    pdb_code = os.path.splitext(os.path.basename(pdb_path))[0]
    pickle_path = os.path.join(DATA_ROOT, "processed", f"{pdb_code}.pickle")

    print(f"Visualizing: {pdb_code}")

    if not os.path.exists(pickle_path):
        print("Graph not processed yet. Run training first!")
        return

    # Open the W&B run only once there is something to log, so that failed
    # lookups do not leave empty runs behind in the project.
    wandb.init(
        project="ContVAR-Project",
        name="Graph-Visualization",
        job_type="visualization",
        config=vars(cfg)
    )

    try:
        # Load graph.
        # NOTE: pickle.load can execute arbitrary code; only load files that
        # were written by this pipeline into the processed directory.
        with open(pickle_path, "rb") as f:
            g = pickle.load(f)

        # Graph statistics
        num_nodes = g.number_of_nodes()
        num_edges = g.number_of_edges()
        density = nx.density(g)

        print(f"Nodes: {num_nodes}, Edges: {num_edges}, Density: {density:.4f}")

        # Create interactive plot
        fig = plotly_protein_structure_graph(
            g,
            colour_edges_by="kind",
            label_node_ids=False,
            node_size_multiplier=1
        )

        fig.update_layout(title=f"Graph Topology: {pdb_code}")

        # Log to WandB
        wandb.log({
            "Interactive_Graph": fig,
            "num_nodes": num_nodes,
            "num_edges": num_edges,
            "graph_density": density
        })

        print("Visualization uploaded successfully!")

    except Exception as e:
        # Best-effort: report the failure without aborting the notebook.
        print(f"Error: {e}")
    finally:
        # Always close the run, including on KeyboardInterrupt.
        wandb.finish()

In [142]:
# ============================================================================
# MAIN EXECUTION
# ============================================================================

# Login to WandB.
# Do not paste the API key into the notebook: wandb.login() reads it from the
# WANDB_API_KEY environment variable or prompts for it interactively.
# Any key that was ever committed to this notebook should be revoked and regenerated.
wandb.login()

#MODE = "train"  # Options: "train", "visualize"

#if MODE == "train":
    #train_pipeline()
#elif MODE == "visualize":
    #visualize_graph()  # or visualize_graph("specific_protein_id")
#else:
    #print("Invalid mode! Choose 'train' or 'visualize'")

True

In [143]:
# Pass force=True to regenerate the .pt files.
train_pipeline(force=False)

Training with LR: 0.0005, Hidden: 64, Heads: 4
Found 1 protein families
Starting training...




Epoch 1 | Loss: 0.1935 (Saved)




Epoch 2 | Loss: 0.1774 (Saved)




Epoch 3 | Loss: 0.1362 (Saved)




Epoch 4 | Loss: 0.0318 (Saved)




Epoch 5 | Loss: 0.0552 




Epoch 6 | Loss: 0.0218 (Saved)




Epoch 7 | Loss: 0.0096 (Saved)




Epoch 8 | Loss: 0.0000 (Saved)




Epoch 9 | Loss: 0.0016 




Epoch 10 | Loss: 0.0112 




Epoch 11 | Loss: 0.0000 




Epoch 12 | Loss: 0.0000 




Epoch 13 | Loss: 0.0000 




Epoch 14 | Loss: 0.0000 




Epoch 15 | Loss: 0.0000 




Epoch 16 | Loss: 0.0000 




Epoch 17 | Loss: 0.0033 




Epoch 18 | Loss: 0.0389 




Epoch 19 | Loss: 0.0388 




Epoch 20 | Loss: 0.0113 




Epoch 21 | Loss: 0.0000 




Epoch 22 | Loss: 0.0000 




Epoch 23 | Loss: 0.0183 




Epoch 24 | Loss: 0.0000 




Epoch 25 | Loss: 0.0204 




Epoch 26 | Loss: 0.0000 




Epoch 27 | Loss: 0.0000 




Epoch 28 | Loss: 0.0244 




Epoch 29 | Loss: 0.0000 




Epoch 30 | Loss: 0.0081 




Epoch 31 | Loss: 0.0000 




Epoch 32 | Loss: 0.0349 




Epoch 33 | Loss: 0.0183 




Epoch 34 | Loss: 0.0514 




Epoch 35 | Loss: 0.1170 




Epoch 36 | Loss: 0.0441 




Epoch 37 | Loss: 0.0390 




Epoch 38 | Loss: 0.1179 




Epoch 39 | Loss: 0.1139 




Epoch 40 | Loss: 0.0390 




Epoch 41 | Loss: 0.0348 




Epoch 42 | Loss: 0.1083 




Epoch 43 | Loss: 0.0685 




Epoch 44 | Loss: 0.0000 




Epoch 45 | Loss: 0.0890 




Epoch 46 | Loss: 0.0860 




Epoch 47 | Loss: 0.0000 




Epoch 48 | Loss: 0.0000 




Epoch 49 | Loss: 0.0362 


                                                                  

Epoch 50 | Loss: 0.0000 




0,1
avg_neg_dist,▁▁▁▂▄▇▇██▇▆▅▅▇▄▆▆▇▅▄▄▇▇▄▇▃▄▂▂▃▃▂▃▃▅▅▃▃▆▄
avg_pos_dist,▂▅█▅▅▆▅▄▅▅▅▅▄▄▄▄▄▄▃▃▃▄▃▂▃▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁
batch_loss,█▇▃▄▃▁▂▁▁▁▁▁▁▄▁▁▁▂▁▁▃▁▁▁▁▃▂▁▅▄▅▁▅▄▅▁▁▂▂▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇▆▂▃▁▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁▂▁▁▂▃▅▃▂▅▂▂▅▃▄▄▁▁▁

0,1
avg_neg_dist,0.25442
avg_pos_dist,0.00103
batch_loss,0
best_model_saved,True
epoch,50
lr,0.0005
train_loss,0


Training completed!


In [144]:
# No protein_id is passed, so a randomly chosen triplet's anchor graph is visualized.
visualize_graph()

Starting visualization...
Found 1 protein families


Visualizing: p53
Nodes: 776, Edges: 4510, Density: 0.0150
Visualization uploaded successfully!


0,1
graph_density,▁
num_edges,▁
num_nodes,▁

0,1
graph_density,0.015
num_edges,4510.0
num_nodes,776.0
