In [13]:
import torch
import torch.nn as nn
import glob
import csv
import numpy as np

class CNN2D(nn.Module):
    """Three-stage 2-D CNN that maps an input grid to a single raw logit.

    Each stage is ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` followed by a
    2x2 max-pool. The 3x3/padding=1 convolutions preserve the spatial size and each
    pool floor-halves it, so an (H, W) input leaves the last stage as
    ``(128, H // 8, W // 8)``. The head is ``Linear -> ReLU -> Dropout(0.5) ->
    Linear(128, 1)``; no sigmoid is applied here.

    The attribute names and ``nn.Sequential`` indices determine the state_dict keys
    (``conv1.0.weight``, ``conv1.1.running_mean``, ``fc1.weight``, ...), so they must
    stay as they are for previously saved checkpoints to keep loading.

    Args:
        input_channels: number of channels in the input tensor.
        fc_in_features: size of the flattened feature map fed to ``fc1``. It must
            equal ``128 * (H // 8) * (W // 8)`` for input height H and width W. The
            default ``128 * 4 * 18`` fits inputs with H in [32, 39] and
            W in [144, 151].
    """

    def __init__(self, input_channels, fc_in_features=128 * 4 * 18):
        super().__init__()
        # Assignment order is kept identical so parameter initialisation consumes
        # the RNG in the same order and state_dict keys keep the same order.
        self.conv1 = self._conv_block(input_channels, 32)
        self.pool1 = nn.MaxPool2d(2, 2)

        self.conv2 = self._conv_block(32, 64)
        self.pool2 = nn.MaxPool2d(2, 2)

        self.conv3 = self._conv_block(64, 128)
        self.pool3 = nn.MaxPool2d(2, 2)

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(fc_in_features, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Build one Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU stage (size-preserving)."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return raw logits of shape (batch, 1).

        Expects ``x`` of shape (batch, input_channels, H, W) whose flattened
        final feature map has exactly ``fc1.in_features`` elements.
        """
        x = self.pool1(self.conv1(x))  # (B, 32, H // 2, W // 2)
        x = self.pool2(self.conv2(x))  # (B, 64, H // 4, W // 4)
        x = self.pool3(self.conv3(x))  # (B, 128, H // 8, W // 8)

        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

# Run on the GPU when one is available; every model is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensemble of independently trained classifiers, evaluated together below.
models = []

k = 50  # ensemble size: checkpoints model_bin_1.pth ... model_bin_{k}.pth

# NOTE(review): the directory is named "GNN" but its checkpoints are loaded into
# CNN2D — confirm this is the intended model family.
MODEL_DIR = "../../../Models/Cholesterol/GNN/WeightedGNNModels-5A_exp5"

for i in range(1, k + 1):
    model = CNN2D(input_channels=1).to(device)
    # The checkpoints are loaded into the DataParallel wrapper, so their keys
    # presumably carry the "module." prefix DataParallel adds — TODO confirm.
    model = nn.DataParallel(model)

    model_path = f"{MODEL_DIR}/model_bin_{i}.pth"
    # weights_only=True restricts unpickling to tensors and plain containers (all a
    # state_dict needs). It avoids executing arbitrary code embedded in a
    # checkpoint, and avoids the warning recent PyTorch versions emit when the
    # argument is omitted.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()  # inference mode: BatchNorm uses running stats, Dropout is disabled

    models.append(model)


  model.load_state_dict(torch.load(model_path, map_location=device))


In [14]:
def evaluate_file(model, file_path, threshold=0.5, device=None):
    """Score one saved 2-D grid with a single binary classifier.

    Args:
        model: torch module mapping a (1, 1, H, W) float tensor to a logit tensor
            with exactly one element (``.item()`` is taken on it).
        file_path: path to a ``.npy`` file holding a 2-D array.
        threshold: probability at or above which the grid is labelled 1.
        device: device the input tensor is moved to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``(predicted_class, prob, non_padded_rows)``: the 0/1 label, the sigmoid
        probability, and the number of rows that contain at least one non-zero
        value.

    Raises:
        ValueError: if the stored array is not 2-D.
    """
    grid = np.load(file_path)
    if grid.ndim != 2:
        raise ValueError(f"Unexpected grid shape: {grid.shape}")

    # Rows that are entirely zero are treated as padding and not counted.
    non_padded_rows = int(np.any(grid != 0, axis=1).sum())

    if device is None:
        # Follow the model's own placement instead of relying on a global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    # Add batch and channel dims: (H, W) -> (1, 1, H, W).
    grid_tensor = torch.tensor(grid, dtype=torch.float32, device=device).unsqueeze(0).unsqueeze(0)

    model.eval()  # make sure BatchNorm/Dropout run in inference mode
    with torch.no_grad():
        output = model(grid_tensor).squeeze(1)

    prob = torch.sigmoid(output).item()
    predicted_class = int(prob >= threshold)

    return predicted_class, prob, non_padded_rows


In [15]:
def evaluate_directory(dir, csv_output, ensemble=None, atom_range=None, threshold=0.5):
    """Score every ``.npy`` grid in a directory with the model ensemble.

    For each file the per-model sigmoid probabilities are averaged. That average is
    printed, written to ``csv_output`` and accumulated into the directory-level
    capture rate, which is the mean of the per-file averages.

    Args:
        dir: directory searched (non-recursively) for ``*.npy`` files.
        csv_output: path of the CSV written with one row per file
            (``filename, average_score, number_atoms``).
        ensemble: sequence of models to evaluate; defaults to the notebook's
            global ``models`` list.
        atom_range: optional inclusive ``(low, high)`` bound on the number of
            non-padded rows (e.g. ``(55, 65)``, the bounds of the formerly
            commented-out filter). Only files inside it contribute to the printed
            "Overlapping Capture Rate". ``None`` (default) includes every file.
        threshold: probability cut-off used for the per-model positive/negative
            counts.

    Returns:
        The mean ensemble-averaged probability over all files in ``dir``.

    Raises:
        ValueError: if the ensemble is empty or ``dir`` contains no ``.npy`` files
            (previously a ZeroDivisionError after the CSV had been created).
    """
    if ensemble is None:
        ensemble = models
    if not ensemble:
        raise ValueError("No models to evaluate")

    files = sorted(glob.glob(f"{dir}/*.npy"))
    if not files:
        raise ValueError(f"No .npy files found in {dir!r}")

    n_models = len(ensemble)
    model_positive_counts = [0] * n_models
    model_negative_counts = [0] * n_models

    score_total = 0.0
    overlapping_total = 0.0
    overlapping_count = 0

    print(f"{'Filename':<120} {'IndividualCaptureRate':<25} NumberOfAtoms")

    with open(csv_output, "w", newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["filename", "average_score", "number_atoms"])

        for file in files:
            prob_sum = 0.0
            for model_index, model in enumerate(ensemble):
                predicted_class, prob, non_padded_rows = evaluate_file(model, file, threshold=threshold)
                prob_sum += prob

                if predicted_class == 1:
                    model_positive_counts[model_index] += 1
                else:
                    model_negative_counts[model_index] += 1

            # non_padded_rows depends only on the file, so the last model's value is valid.
            prediction = prob_sum / n_models
            score_total += prediction
            print(f"{file:<120} {prediction:<25} {non_padded_rows}")
            writer.writerow([file, prediction, non_padded_rows])

            if atom_range is None or atom_range[0] <= non_padded_rows <= atom_range[1]:
                overlapping_total += prediction
                overlapping_count += 1

    capture_rate = score_total / len(files)
    # NaN (rather than a crash) when atom_range excludes every file.
    overlapping_capture_rate = overlapping_total / overlapping_count if overlapping_count else float("nan")

    print("\nModel Predictions Summary:")
    for i, (pos, neg) in enumerate(zip(model_positive_counts, model_negative_counts), start=1):
        print(f"Model {i}: Positives = {pos}, Negatives = {neg}")

    print("Overlapping Capture Rate is", overlapping_capture_rate)

    return capture_rate

In [16]:
# Score every grid in the IvanTestSet positive directory; the returned value is the
# mean ensemble probability over those files.
spies = "../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive"
csv_output = "ivan_capture_rates.csv"

spy_capture_rate = evaluate_directory(spies, csv_output)

print("Spy Capture Rate is", spy_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/4HQJ-filtered_combined_matrix.npy                 0.7670954841375351        29
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/4RET-filtered_combined_matrix.npy                 0.9548534774780273        55
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/5OQT-filtered_combined_matrix.npy                 0.7469310593605042        51
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/5SY1-filtered_combined_matrix.npy                 0.9628125691413879        34
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/5WB2-filtered_combined_matrix.npy                 0.9721274650096894        38
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-graph-5A/positive/6AWN-filtered_comb

In [133]:
# Mean ensemble probability over the held-out Test/Positive grids.
test_positives = "../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive"
csv_output = "test_positive_capture_rates.csv"

test_positives_capture_rate = evaluate_directory(test_positives, csv_output)

print("Test Positives Capture Rate is", test_positives_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive/6CO7-filtered_combined_matrix.npy            0.9590863716602326        55
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive/7DDH-filtered_combined_matrix.npy            0.9758950531482696        56
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive/7F61-filtered_combined_matrix.npy            0.9235081505775452        59
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive/7FJD-filtered_combined_matrix.npy            0.9255039882659912        57
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive/7FJE-filtered_combined_matrix.npy            0.8120794075727463        58
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Positive/7P02-filtered

In [134]:
# Mean ensemble probability over the held-out Test/Unlabeled grids.
test_unlabeled = "../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled"
csv_output = "test_unlabeled_capture_rates.csv"

test_unlabeled_capture_rate = evaluate_directory(test_unlabeled, csv_output)

print("Test Unlabeled Capture Rate is", test_unlabeled_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled/3GKI-f3_combined_matrix.npy                 0.7529516047239304        61
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled/5XRA-f5_combined_matrix.npy                 0.113811110034585         60
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled/5ZM7-f3_combined_matrix.npy                 0.007383164752973244      59
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled/6DRZ-f4_combined_matrix.npy                 0.09051638813689351       64
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled/6GYH-f3_combined_matrix.npy                 0.015272298865020274      62
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/Unlabeled/6IDF-f3_comb

In [135]:
# Mean ensemble probability over the held-out Test/LikelyPositives grids.
test_likely_positives = "../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives"
# NOTE(review): the CSV name says "positive_unlabeled" while the directory is
# LikelyPositives — confirm this naming is intended.
csv_output = "test_positive_unlabeled_capture_rates.csv"

test_likely_positives_capture_rate = evaluate_directory(test_likely_positives, csv_output)

print("Test Likely Positives Capture Rate is", test_likely_positives_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives/3A3Y-f1-positive_combined_matrix.npy  0.7908233547210693        55
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives/3WGV-f4-positive_combined_matrix.npy  0.003840681263245642      65
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives/5AVR-f5-positive_combined_matrix.npy  0.5157500685751438        56
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives/5AVT-f2-positive_combined_matrix.npy  0.8376268994808197        59
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives/5AVU-f1-positive_combined_matrix.npy  0.9729719662666321        56
../../../Data/SplitData/Cholesterol/cholesterol-graph-5A_exp5/Test/LikelyPositives/5AVV-f