In [1]:
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
import os

def load_and_process_data(prefix, folder):
    node_file = f'{folder}/{prefix}_nodes.csv'
    link_file = f'{folder}/{prefix}_links.csv'

    node_data = pd.read_csv(node_file)
    link_data = pd.read_csv(link_file)

    # Assuming the ground truth labels are in the same format
    labels = node_data['ground_truth'].values
    features = node_data.drop(columns=['ground_truth'])

    node_features = features[['atom_type', 'residue_type', 'radius', 'voromqa_sas_potential', 'residue_mean_sas_potential', 'residue_sum_sas_potential', 'residue_size', 'sas_area', 'voromqa_sas_energy', 'voromqa_depth', 'voromqa_score_a', 'voromqa_score_r', 'volume', 'volume_vdw', 'ufsr_a1', 'ufsr_a2', 'ufsr_c2', 'ufsr_c3', 'ev28', 'ev56']]
    link_features = link_data[['atom_index1', 'atom_index2','area', 'boundary', 'distance', 'voromqa_energy', 'seq_sep_class', 'covalent_bond', 'hbond']]

    edge_index = torch.tensor(np.array([link_features['atom_index1'].values, link_features['atom_index2'].values]), dtype=torch.long)

    self_links = torch.arange(0, len(node_features))
    edge_index = torch.cat([edge_index, torch.stack([self_links, self_links])], dim=1)
    edge_index = torch.cat([edge_index, edge_index[[1, 0], :]], dim=1)  # Add reverse direction

    node_features_tensor = torch.tensor(node_features.values, dtype=torch.float)
    labels_tensor = torch.tensor(labels, dtype=torch.float)

    data = Data(x=node_features_tensor, edge_index=edge_index, y=labels_tensor)

    return data

candidate_pairs_file = 'holo/candidate_pairs.txt'
candidate_pairs = pd.read_csv(candidate_pairs_file, delim_whitespace=True)

graphs = {}
for index, row in candidate_pairs.iterrows():
    holo_prefix = f"{row['holo_pdb_id']}_{row['holo_chain_id']}"
    graphs[holo_prefix] = load_and_process_data(holo_prefix, 'holo')

save_dir = 'sh'
os.makedirs(save_dir, exist_ok=True)

for prefix, graph in graphs.items():
    save_path = os.path.join(save_dir, f'{prefix}_graph.pt')
    torch.save(graph, save_path)

print(f"All HOLO graphs have been saved in the directory: {save_dir}")

All HOLO graphs have been saved in the directory: sh


In [4]:
import os
import pandas as pd
import torch
from torch_geometric.data import Data, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as geom_nn
import gvp
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Batch

class SiameseGVPModel(nn.Module):
    def __init__(self, node_in_dims):
        super(SiameseGVPModel, self).__init__()
        self.gvp1 = gvp.GVP(node_in_dims, (64, 0), vector_gate=True, activations=(F.relu, None))
        self.gvp2 = gvp.GVP((64, 0), (1, 0))
        self.fc1 = nn.Linear(2, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, apo_data, holo_data):
        apo_x = self.process_structure(apo_data)
        holo_x = self.process_structure(holo_data)
        combined = torch.cat([apo_x, holo_x], dim=1)
        combined = F.relu(self.fc1(combined))
        combined = self.fc2(combined)
        return combined

    def process_structure(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.gvp1(x)
        x = self.gvp2(x)
        print(f"x[0].shape: {x[0].shape}, data.batch.shape: {data.batch.shape}")
        x = geom_nn.global_mean_pool(x[0], data.batch)
        return x

class ProteinPairsDataset(Dataset):
    def __init__(self, data_pairs):
        self.data_pairs = data_pairs

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        return self.data_pairs[idx]

def load_data_pairs(apo_folder='sg', holo_folder='sh', pairs_file='apo/candidate_pairs.txt'):
    pairs = pd.read_csv(pairs_file, delim_whitespace=True)
    data_pairs = []

    for index, row in pairs.iterrows():
        apo_prefix = f"{row['apo_pdb_id']}_{row['apo_chain_id']}"
        holo_prefix = f"{row['holo_pdb_id']}_{row['holo_chain_id']}"

        apo_graph_path = os.path.join(apo_folder, f'{apo_prefix}_graph.pt')
        holo_graph_path = os.path.join(holo_folder, f'{holo_prefix}_graph.pt')

        if os.path.exists(apo_graph_path) and os.path.exists(holo_graph_path):
            apo_graph = torch.load(apo_graph_path)
            holo_graph = torch.load(holo_graph_path)
            data_pairs.append((apo_graph, holo_graph))

    return data_pairs

def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for apo_batch, holo_batch in loader:
        apo_batch, holo_batch = apo_batch.to(device), holo_batch.to(device)
        optimizer.zero_grad()
        output = model(apo_batch, holo_batch)
        loss = criterion(output.squeeze(), apo_batch.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)



def collate_data_pairs(batch):
    apo_data_list, holo_data_list = zip(*batch)
    apo_batch = Batch.from_data_list(apo_data_list)
    holo_batch = Batch.from_data_list(holo_data_list)
    return apo_batch, holo_batch

# Load data pairs
data_pairs = load_data_pairs()

# Create a dataset and data loader
dataset = ProteinPairsDataset(data_pairs)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_data_pairs)

# Model, optimizer, and loss function
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseGVPModel(node_in_dims=(20, 0)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Training loop
num_epochs = 10
loss_history = []

for epoch in range(num_epochs):
    loss = train(model, train_loader, optimizer, criterion, device)
    loss_history.append(loss)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}')

# Plotting the training loss
plt.plot(loss_history, label='Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss Over Epochs')
plt.legend()
plt.show()

RuntimeError: Expected index [80449] to be smaller than self [32] apart from dimension 0 and to be smaller size than src [1]