In [1]:
# Temel kütüphane ve GPU ayarları
import os
import random
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# PyTorch Geometric importları
from torch.utils.data import DataLoader
from torch_geometric.data import Dataset, Batch
from torch_geometric.nn import GATv2Conv, global_add_pool, global_max_pool

# Graphein importları
from graphein.protein.config import ProteinGraphConfig
from graphein.protein.graphs import construct_graph
from graphein.protein.edges.distance import add_k_nn_edges, add_peptide_bonds
from graphein.protein.features.nodes.amino_acid import amino_acid_one_hot
from graphein.protein.tensor.io import protein_to_pyg

# Model konfigürasyonları (One-Hot 20 aminoasit ile)
INPUT_FEATURE_DIM = 20 
HIDDEN_DIM = 64
EMBEDDING_DIM = 256
MARGIN = 0.2
BATCH_SIZE = 8
EPOCHS = 50 
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

DATA_ROOT = "protein_triplets_data"

print(f"Çalışma Ortamı: {DEVICE}")

Çalışma Ortamı: cuda


In [2]:
# Graphein: Protein graf yapım pipeline
# - Node features: Only One-Hot (mutasyon duyarlılığı için)
# - Edge: Peptide bond + kNN graph (yapıyı korur)

graphein_config = ProteinGraphConfig(
    node_metadata_functions=[amino_acid_one_hot],  # Aminoasit One-Hot
    edge_construction_functions=[
        add_peptide_bonds,  # Backbone
        add_k_nn_edges      # 3D komşuluk
    ],
    graph_metadata={
        "k": 10  # kNN komşu sayısı
    }
)

print("Graphein Config Hazır: One-Hot + k-NN(10) + Peptide Bonds")

Graphein Config Hazır: One-Hot + k-NN(10) + Peptide Bonds


In [3]:
class TripletDataPathMapper:
    """
    Protein klasör yapısını okuyup anchor–positive–negative üçlülerini çıkarır.
    originals/{id}.pdb
    positives/{id}/*.pdb
    negatives/{id}/*.pdb
    """
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.triplets = []
        self._map_data()

    def _map_data(self):
        originals = glob.glob(os.path.join(self.root_dir, 'originals', "*.pdb"))
        
        for anchor in originals:
            prot_id = os.path.splitext(os.path.basename(anchor))[0]

            pos_dir = os.path.join(self.root_dir, 'positives', prot_id)
            neg_dir = os.path.join(self.root_dir, 'negatives', prot_id)

            p_files = glob.glob(os.path.join(pos_dir, "*.pdb"))
            n_files = glob.glob(os.path.join(neg_dir, "*.pdb"))

            # Hem positive hem negative varsa triplet olarak ekleniyor
            if p_files and n_files:
                self.triplets.append({
                    'anchor': anchor,
                    'positives': p_files,
                    'negatives': n_files
                })

        print(f"Bulunan Protein Ailesi: {len(self.triplets)}")

In [4]:
class TripletDatasetGraphein(Dataset):
    """
    Her çağrıda (get) Graphein ile graf oluşturup 
    PyTorch Geometric objesine çevirir.
    """
    def __init__(self, mapper, config):
        super().__init__()
        self.triplets = mapper.triplets
        self.config = config

    def len(self):
        # Her proteinin epoch başına 10 kez görünmesi: data augmentation
        return len(self.triplets) * 10

    def get(self, idx):
        real_idx = idx % len(self.triplets)
        t = self.triplets[real_idx]

        try:
            # Graphein → NetworkX Graph
            g_a = construct_graph(config=self.config, pdb_path=t['anchor'], verbose=False)
            g_p = construct_graph(config=self.config, pdb_path=random.choice(t['positives']), verbose=False)
            g_n = construct_graph(config=self.config, pdb_path=random.choice(t['negatives']), verbose=False)

            # NetworkX → PyTorch Geometric
            data_a = protein_to_pyg(g_a)
            data_p = protein_to_pyg(g_p)
            data_n = protein_to_pyg(g_n)

            return data_a, data_p, data_n

        except Exception:
            # Bozuk PDB vs varsa atlanıyor
            return None

In [5]:
def triplet_collate(data_list):
    """
    Dataset'ten dönen tripletleri batch'e dönüştürür.
    None dönenler (bozuk PDB) ayıklanır.
    """
    data_list = [x for x in data_list if x is not None]
    if not data_list:
        return None

    batch_a = Batch.from_data_list([x[0] for x in data_list])
    batch_p = Batch.from_data_list([x[1] for x in data_list])
    batch_n = Batch.from_data_list([x[2] for x in data_list])

    return batch_a, batch_p, batch_n

In [6]:
class DeepProteinGAT(nn.Module):
    """
    3 katmanlı GATv2 tabanlı protein embedding modeli.
    Pooling: global_add
    Output normalize edilir (metric learning için).
    """
    def __init__(self, input_dim, hidden_dim, output_dim, heads=4):
        super().__init__()

        # GATv2 katmanları — Dropout kapalı (sinyal kaybetmemek için)
        self.conv1 = GATv2Conv(input_dim, hidden_dim, heads=heads, concat=True, dropout=0.0)
        self.conv2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=heads, concat=True, dropout=0.0)
        self.conv3 = GATv2Conv(hidden_dim * heads, output_dim, heads=1, concat=False, dropout=0.0)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch

        x = x.float()  # One-hot float olmalı
        x = F.elu(self.conv1(x, edge_index))
        x = F.elu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Hybrid pooling: sum + max
        x = global_add_pool(x, batch)
        x = self.projection(x)       # optional, linear layer
        # L2 normalize (Triplet Loss için kritik)
        x = F.normalize(x, p=2, dim=1)
        return x

In [7]:
def train_pipeline():
    # Dataset yüklenir
    mapper = TripletDataPathMapper(DATA_ROOT)
    if not mapper.triplets:
        print("Veri bulunamadı!")
        return

    dataset = TripletDatasetGraphein(mapper, graphein_config)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=triplet_collate)

    # Model & Loss
    model = DeepProteinGAT(INPUT_FEATURE_DIM, HIDDEN_DIM, EMBEDDING_DIM).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.0005)  
    criterion = nn.TripletMarginLoss(margin=MARGIN, p=2)

    print("Eğitim Başlıyor...")
    model.train()

    for epoch in range(EPOCHS):
        total_loss = 0
        valid_batches = 0

        for batch in loader:
            if batch is None:
                continue

            ba, bp, bn = batch
            ba, bp, bn = ba.to(DEVICE), bp.to(DEVICE), bn.to(DEVICE)

            optimizer.zero_grad()

            # Forward pass
            ea = model(ba)
            ep = model(bp)
            en = model(bn)

            # Triplet loss
            loss = criterion(ea, ep, en)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            valid_batches += 1

        avg_loss = total_loss / valid_batches if valid_batches > 0 else 0
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f}")

In [8]:
if __name__ == "__main__":
    train_pipeline()

Bulunan Protein Ailesi: 1
Eğitim Başlıyor...
Epoch 10/50 | Loss: 0.0000
Epoch 20/50 | Loss: 0.0000
Epoch 30/50 | Loss: 0.0000
Epoch 40/50 | Loss: 0.0000
Epoch 50/50 | Loss: 0.0000
