In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
import glob
import matplotlib.pyplot as plt
import os
import csv
from torch_geometric.nn import GATConv, global_mean_pool

# Define GAT model for batched data
class GAT(torch.nn.Module):
    def __init__(self, in_channels, out_channels, dropout_p=0.1):
        super().__init__()
        self.gat = GATConv(in_channels, out_channels, heads=1, concat=True, edge_dim=1)
        self.pool = global_mean_pool  # Can also use global_max_pool or global_add_pool
        self.dropout = nn.Dropout(p=dropout_p)
        self.linear = torch.nn.Linear(out_channels, 1)
        self.activation = nn.Sigmoid()

    def forward(self, x, edge_index, edge_attr, batch):
        out, attn_weights = self.gat(x, edge_index, edge_attr, return_attention_weights=True)
        out = self.dropout(out)
        out = self.pool(out, batch)  # Pool over nodes in each graph
        out = self.dropout(out) 
        out = self.linear(out)
        out = self.activation(out)
        return out, attn_weights

def organize_graph_and_add_weight(file_path, label):
    data = np.load(file_path, allow_pickle=True).item()
    inverse_distance = data['inverse_distance']
    encoded_matrix = data['encoded_matrix']

    x = torch.tensor(encoded_matrix, dtype=torch.float32)
    adj = torch.tensor(inverse_distance, dtype=torch.float32)

    # Normalize adjacency (row-normalize)
    adj = adj / (adj.sum(dim=1, keepdim=True) + 1e-8)

    # Create edge_index and edge weights
    edge_index = (adj > 0).nonzero(as_tuple=False).t()
    edge_weight = adj[adj > 0]

    y = torch.tensor([label], dtype=torch.float32)
    
    return Data(x=x, edge_index=edge_index, edge_attr=edge_weight, y=y)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

models = []

k = 1
#version_number = 6

for i in range (1, (k + 1)):
    # Initialize model
    model = GAT(in_channels=37, out_channels=16).to(device)

    # Load saved model weights
    model_path = f"GATModels-5A-v2/Models/model_bin_{i}.pth"  # Update with correct path if needed
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Set the model to evaluation mode
    model.eval()

    models.append(model)


  model.load_state_dict(torch.load(model_path, map_location=device))


In [2]:
def get_capture_rate(dir, csv_output):
    predictions = []
    capture_rate = 0
    overlapping_capture_rate = 0
    overlapping_index = 0

    files = glob.glob(f"{dir}/*.npy")

    model_positive_counts = [0] * len(models)
    model_negative_counts = [0] * len(models)

    with open(csv_output, "w", newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["filename", "average_score", "number_atoms"]) 

        for file in sorted(files):    
            prediction = 0
            for model in models:
                graph = organize_graph_and_add_weight(file, label=0).to(device)
                non_padded_rows = graph.x.size(0)
                with torch.no_grad():
                    out, _ = model(graph.x, graph.edge_index, graph.edge_attr, batch=torch.zeros(graph.x.size(0), dtype=torch.long).to(device))
                    prob = out.item()
                    print(prob)
                    prediction += prob
            prediction /= len(models)
            predictions.append((file, prediction))
            capture_rate += prediction
            print(f"{file:<120} {prediction:<25} {non_padded_rows}")
            overlapping_capture_rate += prediction
            overlapping_index += 1
            writer.writerow([file, prediction, non_padded_rows])

    capture_rate /= len(files)
    overlapping_capture_rate /= overlapping_index

    print("\nModel Predictions Summary:")
    for i, (pos, neg) in enumerate(zip(model_positive_counts, model_negative_counts), start=1):
        print(f"Model {i}: Positives = {pos}, Negatives = {neg}")

    #print("Overlapping Capture Rate is", overlapping_capture_rate)

    return capture_rate

In [3]:
spies = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies"
csv_output = "spy_capture_rates.csv"

capture_rate = get_capture_rate(spies, csv_output)

print("Spy Capture Rate is", capture_rate)

0.8078364133834839
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies/1LRI-filtered_graphs.npy                        0.8078364133834839        71
0.8584931492805481
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies/2ZXE-filtered_graphs.npy                        0.8584931492805481        40
0.3873276114463806
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies/3GKI-filtered_graphs.npy                        0.3873276114463806        99
0.9557435512542725
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies/3NY8-filtered_graphs.npy                        0.9557435512542725        53
0.9768586754798889
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies/3NYA-filtered_graphs.npy                        0.9768586754798889        49
0.9225848913192749
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Spies/4DKL-filtered_graphs.npy                        0.9225848913192

In [4]:
test_positives = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive"
csv_output = "test_positive_capture_rates.csv"

capture_rate = get_capture_rate(test_positives, csv_output)

print("Test Positive Capture Rate is", capture_rate)

0.555027425289154
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive/4BOE-filtered_graphs.npy                0.555027425289154         124
0.26835373044013977
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive/4BQU-filtered_graphs.npy                0.26835373044013977       129
0.7980592846870422
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive/4M48-filtered_graphs.npy                0.7980592846870422        40
0.6854674220085144
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive/4XP1-filtered_graphs.npy                0.6854674220085144        39
0.37578579783439636
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive/4XPB-filtered_graphs.npy                0.37578579783439636       42
0.5400480628013611
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Positive/4XPG-filtered_graphs.npy                0.5400480628

In [5]:
test_unlabeled = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled"
csv_output = "test_unlabeled_capture_rates.csv"

capture_rate = get_capture_rate(test_unlabeled, csv_output)

print("Test Unlabeled Capture Rate is", capture_rate)

0.12428666651248932
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled/1LRI-f1_graphs.npy                     0.12428666651248932       70
0.1668241173028946
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled/3AM6-f3_graphs.npy                     0.1668241173028946        97
0.44968706369400024
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled/3AM6-f4_graphs.npy                     0.44968706369400024       84
0.14247646927833557
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled/3D4S-f2_graphs.npy                     0.14247646927833557       66
0.34178587794303894
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled/3GKI-f3_graphs.npy                     0.34178587794303894       61
0.15786592662334442
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/Unlabeled/3N9Y-f4_graphs.npy                     0.15786592

In [6]:
test_likely_positives = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives"
csv_output = "test_positive_unlabeled_capture_rates.csv"

capture_rate = get_capture_rate(test_likely_positives, csv_output)

print("Test Unlabeled Capture Rate is", capture_rate)

0.3306519687175751
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives/2RH1-f2-positive_graphs.npy      0.3306519687175751        34
0.9912856221199036
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives/2RH1-f3-positive_graphs.npy      0.9912856221199036        21
0.1465192437171936
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives/2ZXE-f1-positive_graphs.npy      0.1465192437171936        72
0.4762086570262909
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives/3A3Y-f1-positive_graphs.npy      0.4762086570262909        55
0.33195754885673523
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives/3WGU-f1-positive_graphs.npy      0.33195754885673523       46
0.7944732904434204
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A/Test/LikelyPositives/3WGU-f2-positive_graphs.npy      0.794473290443