In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data
import glob
import matplotlib.pyplot as plt
import os
import csv
from torch_geometric.nn import GATConv, global_mean_pool

# Define GAT model for batched data
class GAT(nn.Module):
    """Single-head graph-attention layer followed by pooling and a linear head.

    The forward pass yields one scalar per graph in the batch together with
    the attention weights reported by the ``GATConv`` layer.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of the node embedding produced by the attention layer.
        dropout_p: Dropout probability used both before pooling and before the
            final linear layer.
    """

    def __init__(self, in_channels, out_channels, dropout_p=0.1):
        super().__init__()
        # One attention head; each edge carries a single scalar feature.
        self.gat = GATConv(
            in_channels,
            out_channels,
            heads=1,
            concat=True,
            edge_dim=1,
        )
        self.pool = global_mean_pool  # Can also use global_max_pool or global_add_pool
        self.dropout = nn.Dropout(p=dropout_p)
        self.norm = nn.BatchNorm1d(out_channels)
        self.linear = nn.Linear(out_channels, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return ``(per_graph_output, attention_weights)`` for a batch of graphs."""
        node_feats, attn_weights = self.gat(
            x, edge_index, edge_attr, return_attention_weights=True
        )
        # Dropout on node embeddings, then collapse the nodes of each graph.
        graph_feats = self.pool(self.dropout(node_feats), batch)
        # Normalize, apply dropout again, and project to a single value.
        graph_feats = self.dropout(self.norm(graph_feats))
        return self.linear(graph_feats), attn_weights

def organize_graph_and_add_weight(file_path, label):
    """Build a PyG ``Data`` graph from a saved ``.npy`` dictionary.

    The file must hold a pickled dict with the keys ``'inverse_distance'``
    (a square weight matrix) and ``'encoded_matrix'`` (the node features).
    The weight matrix is row-normalized and every strictly positive entry
    becomes a directed edge whose attribute is the normalized weight.

    Args:
        file_path: Path to the ``.npy`` file to load.
        label: Numeric graph label; stored as a one-element float tensor in ``y``.

    Returns:
        ``Data`` with ``x``, ``edge_index`` (shape ``[2, num_edges]``),
        ``edge_attr`` (one weight per edge) and ``y``.
    """
    # NOTE(security): allow_pickle=True unpickles arbitrary Python objects,
    # which can execute code. Only load files from a trusted source.
    data = np.load(file_path, allow_pickle=True).item()
    inverse_distance = data['inverse_distance']
    encoded_matrix = data['encoded_matrix']

    x = torch.tensor(encoded_matrix, dtype=torch.float32)
    adj = torch.tensor(inverse_distance, dtype=torch.float32)

    # Normalize adjacency (row-normalize); the epsilon guards all-zero rows.
    adj = adj / (adj.sum(dim=1, keepdim=True) + 1e-8)

    # Build the mask once so the edge indices and the edge weights are taken
    # from exactly the same entries, in the same order.
    edge_mask = adj > 0
    # .t() returns a non-contiguous view; PyG expects a contiguous edge_index.
    edge_index = edge_mask.nonzero(as_tuple=False).t().contiguous()
    edge_weight = adj[edge_mask]

    y = torch.tensor([label], dtype=torch.float32)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_weight, y=y)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

models_38 = []
models_12 = []
# NOTE(review): later cells pass a list named `models` to get_capture_rate.
# It is defined here as every loaded model in file order (model_bin_1 ..
# model_bin_k) so the notebook runs from a fresh kernel — confirm that this is
# the intended ensemble (as opposed to e.g. models_38 only).
models = []
k = 50

skip_zero = {1, 5, 8, 11, 17, 19, 23, 36, 37, 43, 47, 48}
skip_one = {i + 1 for i in skip_zero} # Convert to 1-based to match filenames model_bin_{i}.pth

for i in range(1, k + 1):
    # Initialize model
    model = GAT(in_channels=37, out_channels=32).to(device)

    # Load saved model weights. weights_only=True restricts unpickling to
    # tensors and plain containers, which is all a state_dict needs; it avoids
    # executing arbitrary pickled code and silences torch.load's FutureWarning.
    model_path = f"GATModels-5A_exp5v2/Models/model_bin_{i}.pth"  # Update with correct path if needed
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # Set the model to evaluation mode
    model.eval()

    # Models whose 1-based index is in skip_one go to the 12-model group,
    # all others to the 38-model group.
    if i in skip_one:
        models_12.append(model)
    else:
        models_38.append(model)
    models.append(model)


  model.load_state_dict(torch.load(model_path, map_location=device))


In [None]:
def get_capture_rate(dir, csv_output, models):
    """Average the ensemble's predicted probability over every graph in a directory.

    Each ``*.npy`` file directly under ``dir`` is processed in sorted order:
    the graph is built with ``organize_graph_and_add_weight``, every model in
    ``models`` is run on it, and the sigmoid of each model's output is averaged
    into a per-file score. One line per file (path, score, node count) is
    printed, followed by a per-model tally of how many files scored >= 0.5
    versus below 0.5.

    Args:
        dir: Directory containing the ``.npy`` graph files. (The name shadows
            the ``dir`` builtin; it is kept so existing callers keep working.)
        csv_output: Accepted for interface compatibility; currently unused —
            no CSV file is written.
        models: Non-empty sequence of models callable as
            ``model(x, edge_index, edge_attr, batch=...)`` that return a pair
            whose first element is a single-value tensor.

    Returns:
        The mean of the per-file average scores, as a float.

    Raises:
        ValueError: If ``models`` is empty.
        FileNotFoundError: If ``dir`` contains no ``.npy`` files.
    """
    if len(models) == 0:
        raise ValueError("models must contain at least one model")

    files = sorted(glob.glob(f"{dir}/*.npy"))
    if not files:
        # Previously this surfaced as a bare ZeroDivisionError at the end.
        raise FileNotFoundError(f"No .npy graph files found in {dir!r}")

    model_positive_counts = [0] * len(models)
    model_negative_counts = [0] * len(models)
    capture_rate = 0.0

    for file in files:
        # Load the graph once per file; it is identical for every model.
        graph = organize_graph_and_add_weight(file, label=0).to(device)
        non_padded_rows = graph.x.size(0)
        # All nodes belong to graph 0: each file is scored as a batch of one.
        batch = torch.zeros(non_padded_rows, dtype=torch.long, device=graph.x.device)

        prediction = 0.0
        with torch.no_grad():
            for model_index, model in enumerate(models):
                out, _ = model(graph.x, graph.edge_index, graph.edge_attr, batch=batch)
                prob = torch.sigmoid(out).item()
                prediction += prob

                if prob >= 0.5:
                    model_positive_counts[model_index] += 1
                else:
                    model_negative_counts[model_index] += 1

        prediction /= len(models)
        capture_rate += prediction
        print(f"{file:<120} {prediction:<25} {non_padded_rows}")

    capture_rate /= len(files)

    print("\nModel Predictions Summary:")
    for i, (pos, neg) in enumerate(zip(model_positive_counts, model_negative_counts), start=1):
        print(f"Model {i}: Positives = {pos}, Negatives = {neg}")

    return capture_rate

In [None]:
# Score the positive graphs stored under the IvanTestSet directory.
ivans = "../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive"
csv_output = "ivan_capture_rates.csv"  # forwarded to get_capture_rate as its second argument

# NOTE(review): `models` must be the list of loaded models to score with;
# make sure it is defined in the kernel before running this cell.
capture_rate = get_capture_rate(ivans, csv_output, models)

print("Ivan Capture Rate is", capture_rate)

../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/4HQJ-filtered_graphs.npy                0.8775387024879455        29
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/4RET-filtered_graphs.npy                0.42912747502326964       55
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/5OQT-filtered_graphs.npy                0.7849809539318084        51
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/5SY1-filtered_graphs.npy                0.9631420254707337        34
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/5WB2-filtered_graphs.npy                0.9446999228000641        38
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/6AWN-filtered_graphs.npy                0.905002304315567         22
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-separate-graphs-5A/positive/6AWO-filtered_graph

In [28]:
# Score the graphs in the Test/Positive directory.
test_positives = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive"
csv_output = "test_positive_capture_rates.csv"

# get_capture_rate takes the model list as a required third argument; the
# call previously omitted it, which raises TypeError.
# NOTE(review): `models` must already be defined in the kernel — confirm
# which ensemble is intended.
capture_rate = get_capture_rate(test_positives, csv_output, models)

print("Test Positive Capture Rate is", capture_rate)

../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/1LRI-filtered_graphs.npy           0.8888251280784607        71
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/2RH1-filtered_graphs.npy           0.831122887134552         53
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/3NY9-filtered_graphs.npy           0.9577220892906189        52
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/3NYA-filtered_graphs.npy           0.9596696841716766        49
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/4BOE-filtered_graphs.npy           0.4106355902552605        124
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/4OR2-filtered_graphs.npy           0.8734824097156525        40
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Positive/4XNU-filtered

In [29]:
# Score the graphs in the Test/Unlabeled directory.
test_unlabeled = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled"
csv_output = "test_unlabeled_capture_rates.csv"

# get_capture_rate takes the model list as a required third argument; the
# call previously omitted it, which raises TypeError.
# NOTE(review): `models` must already be defined in the kernel — confirm
# which ensemble is intended.
capture_rate = get_capture_rate(test_unlabeled, csv_output, models)

print("Test Unlabeled Capture Rate is", capture_rate)

../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/1ZHY-f3_graphs.npy                0.003320596047851723      111
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/3GKI-f5_graphs.npy                0.16226619333028794       67
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/3NY8-f1_graphs.npy                0.0602626090683043        70
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/3NY9-f2_graphs.npy                0.03468574482947588       66
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/3NYA-f5_graphs.npy                0.04198685772716999       95
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/3WGV-f3_graphs.npy                0.28095412731170655       72
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/Unlabeled/4BQU-f2_grap

In [30]:
# Score the graphs in the Test/LikelyPositives directory.
test_likely_positives = "../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives"
csv_output = "test_positive_unlabeled_capture_rates.csv"

# get_capture_rate takes the model list as a required third argument; the
# call previously omitted it, which raises TypeError.
# NOTE(review): `models` must already be defined in the kernel — confirm
# which ensemble is intended.
capture_rate = get_capture_rate(test_likely_positives, csv_output, models)

# The label previously read "Test Unlabeled", which misreported which set
# this number belongs to.
print("Test Likely Positives Capture Rate is", capture_rate)

../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/2RH1-f2-positive_graphs.npy 0.810255675315857         34
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/2RH1-f3-positive_graphs.npy 0.933477019071579         21
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/2ZXE-f1-positive_graphs.npy 0.19053895115852357       72
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/3A3Y-f1-positive_graphs.npy 0.3826242706179619        55
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/3WGU-f1-positive_graphs.npy 0.3070597265660763        46
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/3WGU-f2-positive_graphs.npy 0.7188967639207839        44
../../../Data/SplitData/Cholesterol/cholesterol-separate-graphs-5A_exp1/Test/LikelyPositives/3WGU-f3