In [17]:
import torch
import torch.nn as nn
import glob
import csv
import numpy as np

# 3D CNN Model
class CNN3D(nn.Module):
    """3D CNN mapping a 23-channel voxel grid to 2-class probabilities.

    Architecture:
      conv0 1x1x1 (23 -> 64) -> ReLU -> 2x2x2 max-pool
      conv1 3x3x3 (64 -> 64) -> ReLU -> 2x2x2 max-pool
      conv2 3x3x3 (64 -> 128) -> ReLU -> 2x2x2 max-pool
      fc1 (128*3*3*3 -> 256) -> ReLU -> fc2 (256 -> 128) -> ReLU -> fc3 (128 -> 2) -> softmax

    The three stride-2 pools shrink each spatial side to floor(side / 8), and
    fc1 expects a 3x3x3 map, so every spatial side of the input must lie in
    [24, 31] (presumably the grids used here are 24^3 — TODO confirm).

    Layer attribute names (conv0, conv1, conv2, fc1, fc2, fc3) are the keys of
    the saved ``model_bin_*.pth`` state dicts and must not be renamed.
    """

    # Flattened feature count after conv2 + three pools: 128 channels x 3x3x3.
    _FLAT_FEATURES = 128 * 3 * 3 * 3

    def __init__(self):
        super().__init__()

        # 1x1x1 conv mixes the 23 input channels per voxel without spatial context.
        self.conv0 = nn.Conv3d(in_channels=23, out_channels=64, kernel_size=1, stride=1, padding=0)
        self.conv1 = nn.Conv3d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool3d(kernel_size=2, stride=2, padding=0)

        self.fc1 = nn.Linear(self._FLAT_FEATURES, 256)
        self.fc2 = nn.Linear(256, 128)
        # Two outputs; evaluate_file reads column 1 as the positive-class probability.
        self.fc3 = nn.Linear(128, 2)

    def forward(self, x):
        """Return class probabilities of shape (batch, 2).

        Args:
            x: float tensor of shape (batch, 23, D, H, W); see the class
               docstring for the allowed spatial sizes.
        """
        x = self.pool(torch.relu(self.conv0(x)))
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool2(torch.relu(self.conv2(x)))

        # Flatten per sample, keeping dim 0. Reshaping with view(-1, 3456)
        # could silently fold several mis-sized samples into one row; this
        # way a size mismatch fails loudly in fc1 instead.
        x = torch.flatten(x, 1)

        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        # NOTE(review): softmax is applied inside the model, so outputs are
        # probabilities. If training fed these outputs to nn.CrossEntropyLoss,
        # softmax was effectively applied twice — training code not visible; confirm.
        return torch.softmax(self.fc3(x), dim=1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensemble size: checkpoints are numbered 1..k on disk.
k = 50
MODEL_PATH_TEMPLATE = "3DCholesterolModels-5A_exp5/Models/model_bin_{}.pth"

models = []

for i in range(1, k + 1):
    model = CNN3D().to(device)

    model_path = MODEL_PATH_TEMPLATE.format(i)
    # weights_only=True restricts unpickling to tensors and plain containers
    # (enough for a state dict). This avoids executing arbitrary code from the
    # checkpoint and silences the weights_only FutureWarning recent torch
    # versions emit for this call.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    model.eval()

    models.append(model)


  model.load_state_dict(torch.load(model_path, map_location=device))


In [18]:
def evaluate_file(model, file_path, threshold=0.5):
    """Run one model on a single saved grid and threshold its positive score.

    The .npy array must be 4-D with channels on the last axis; it is permuted
    to (1, C, X, Y, Z) and moved to the module-level ``device`` before
    inference.

    Args:
        model: network whose output at [0, 1] is read as the positive-class
            probability.
        file_path: path to the .npy grid.
        threshold: probability at or above which the file is labelled 1.

    Returns:
        (predicted_class, prob, non_padded_rows): the 0/1 label, the positive
        probability as a Python float, and the number of voxels with at least
        one non-zero channel (written out as "number_atoms" by
        evaluate_directory).
    """
    voxels = np.load(file_path)

    # A voxel counts as occupied when any of its channels is non-zero.
    occupied_voxels = np.any(voxels != 0, axis=3).sum()

    batch = torch.tensor(voxels, dtype=torch.float32)
    batch = batch.permute(3, 0, 1, 2).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        positive_prob = model(batch)[0, 1].item()

    label = 1 if positive_prob >= threshold else 0
    return label, positive_prob, occupied_voxels


In [19]:
def evaluate_directory(dir, csv_output, threshold=0.5, atom_range=None):
    """Score every ``*.npy`` grid in ``dir`` with the whole model ensemble.

    For each file (in sorted order), every model in the module-level
    ``models`` list is run through ``evaluate_file``. The file's score is the
    mean positive-class probability across models. One line per file is
    printed and written to ``csv_output`` (columns ``filename``,
    ``average_score``, ``number_atoms``). Per-model positive/negative tallies
    and the "Overlapping Capture Rate" are printed afterwards.

    Args:
        dir: Directory holding the ``.npy`` grids. The name is kept for existing
            callers even though it shadows the builtin.
        csv_output: Path of the CSV file to (over)write.
        threshold: Probability cut-off forwarded to ``evaluate_file`` for the
            per-model positive/negative tallies (default 0.5, as before).
        atom_range: Optional inclusive ``(low, high)`` bounds on
            ``number_atoms``. Only files inside the range contribute to the
            printed "Overlapping Capture Rate". ``None`` (default) uses every
            file, which reproduces the previous behaviour.

    Returns:
        Mean of the per-file ensemble scores over all files in ``dir``.

    Raises:
        ValueError: if ``models`` is empty.
        FileNotFoundError: if ``dir`` contains no ``.npy`` files.
    """
    if not models:
        raise ValueError("No models loaded; `models` is empty.")

    files = sorted(glob.glob(f"{dir}/*.npy"))
    if not files:
        # An empty glob usually means a mistyped path. Fail before overwriting
        # csv_output rather than dividing by zero afterwards.
        raise FileNotFoundError(f"No .npy files found in {dir!r}")

    model_positive_counts = [0] * len(models)
    model_negative_counts = [0] * len(models)
    score_total = 0.0
    overlapping_total = 0.0
    overlapping_count = 0

    print(f"{'Filename':<120} {'IndividualCaptureRate':<25} NumberOfAtoms")

    with open(csv_output, "w", newline='') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["filename", "average_score", "number_atoms"])

        for file in files:
            prob_sum = 0.0
            # NOTE: evaluate_file re-reads the .npy once per model
            # (len(models) loads per file); a candidate for hoisting if slow.
            for model_index, model in enumerate(models):
                predicted_class, prob, non_padded_rows = evaluate_file(
                    model, file, threshold=threshold
                )
                prob_sum += prob

                if predicted_class == 1:
                    model_positive_counts[model_index] += 1
                else:
                    model_negative_counts[model_index] += 1

            # Ensemble score for this file: mean positive-class probability.
            prediction = prob_sum / len(models)
            score_total += prediction
            print(f"{file:<120} {prediction:<25} {non_padded_rows}")
            writer.writerow([file, prediction, non_padded_rows])

            if atom_range is None or atom_range[0] <= non_padded_rows <= atom_range[1]:
                overlapping_total += prediction
                overlapping_count += 1

    capture_rate = score_total / len(files)
    # Report NaN rather than raising when no file falls inside atom_range.
    overlapping_capture_rate = (
        overlapping_total / overlapping_count if overlapping_count else float("nan")
    )

    print("\nModel Predictions Summary:")
    for i, (pos, neg) in enumerate(zip(model_positive_counts, model_negative_counts), start=1):
        print(f"Model {i}: Positives = {pos}, Negatives = {neg}")

    print("Overlapping Capture Rate is", overlapping_capture_rate)

    return capture_rate

In [20]:
# Ensemble capture rate on Ivan's set (grids under its "positive" directory).
# NOTE(review): `spy_capture_rate` actually holds the Ivan-set rate; the name is
# kept in case later cells reference it.
ivans = "../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive"
csv_output = "ivan_capture_rates.csv"

spy_capture_rate = evaluate_directory(ivans, csv_output=csv_output)
print("Ivan Capture Rate is", spy_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive/4HQJ-filtered_grid_0.npy                           0.9989777672290802        29
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive/4RET-filtered_grid_0.npy                           0.9340482950210571        55
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive/5OQT-filtered_grid_0.npy                           0.3551073710620403        51
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive/5SY1-filtered_grid_0.npy                           0.9991418302059174        34
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive/5WB2-filtered_grid_0.npy                           0.999446120262146         38
../../../Data/SplitData/Cholesterol/IvanTestSet/ivan-grid-5A/positive/6AWN-filtered_grid_

In [37]:
# Ensemble capture rate on the positive grids of the exp5 Test split.
test_positives = "../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive"
csv_output = "test_positive_capture_rates.csv"

test_positives_capture_rate = evaluate_directory(test_positives, csv_output=csv_output)
print("Test Positives Capture Rate is", test_positives_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/3N9Y-filtered_grid_0.npy                      0.9713197994232178        71
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/3N9Y-filtered_grid_1.npy                      0.9928898227214813        71
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/3N9Y-filtered_grid_2.npy                      0.9916872453689575        71
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/3N9Y-filtered_grid_3.npy                      0.9955240058898925        71
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/3N9Y-filtered_grid_4.npy                      0.9980100858211517        71
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/3WGV-filtered_

In [38]:
# Mean ensemble score on the unlabeled grids of the exp5 Test split.
test_unlabeled = "../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled"
csv_output = "test_unlabeled_capture_rates.csv"

test_unlabeled_capture_rate = evaluate_directory(test_unlabeled, csv_output=csv_output)
print("Test Unlabeled Capture Rate is", test_unlabeled_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled/1ZHY-f2_grid_0.npy                           0.0030731118854600936     78
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled/1ZHY-f2_grid_1.npy                           0.013825518573867157      78
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled/1ZHY-f2_grid_2.npy                           0.013162282207049429      78
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled/1ZHY-f2_grid_3.npy                           0.0760960202361457        78
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled/1ZHY-f2_grid_4.npy                           0.00619010612484999       78
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Unlabeled/1ZHY-f5_grid_

In [39]:
# Mean ensemble score on the "LikelyPositives" grids of the exp5 Test split.
test_likely_positives = "../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives"
csv_output = "test_positive_unlabeled_capture_rates.csv"

test_likely_positives_capture_rate = evaluate_directory(test_likely_positives, csv_output=csv_output)
print("Test Likely Positives Capture Rate is", test_likely_positives_capture_rate)

Filename                                                                                                                 IndividualCaptureRate     NumberOfAtoms
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives/2RH1-f2-positive_grid_0.npy            0.9986886751651763        34
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives/2RH1-f2-positive_grid_1.npy            0.9990948045253754        34
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives/2RH1-f2-positive_grid_2.npy            0.9986801433563233        34
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives/2RH1-f2-positive_grid_3.npy            0.9991277611255646        34
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives/2RH1-f2-positive_grid_4.npy            0.9990006232261658        34
../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/LikelyPositives/2RH1-f3

In [40]:
import pandas as pd
import re

# Known positives the ensemble misses: files whose mean positive-class
# probability falls below this decision threshold.
MISS_THRESHOLD = 0.5

df = pd.read_csv("test_positive_capture_rates.csv")
df.columns = df.columns.str.strip()  # remove any hidden spaces in column names
df['number_atoms'] = pd.to_numeric(df['number_atoms'], errors='coerce')  # ensure it's numeric

filtered_df = df[df['average_score'] < MISS_THRESHOLD]
pd.set_option('display.max_rows', None)  # None means unlimited
pd.set_option('display.width', 500)
pd.set_option('display.max_colwidth', 200)
print(filtered_df.shape)
print(filtered_df)
mean_score = filtered_df['average_score'].mean()
print(f"Mean average_score: {mean_score}")

# Pull the 4-character structure ID (e.g. "4BOE") out of each path.
# str.extract yields NaN for names that don't match, whereas
# re.search(...).group(1) would raise AttributeError. Report and skip those.
protein_ids = filtered_df['filename'].str.extract(r'/([A-Z0-9]{4})-filtered_grid', expand=False)
n_unmatched = int(protein_ids.isna().sum())
if n_unmatched:
    print(f"Warning: {n_unmatched} filename(s) did not match the '<ID>-filtered_grid' pattern and were skipped")
protein_names = sorted(set(protein_ids.dropna()))
print(len(protein_names), "is length of protein names")
for name in protein_names:
    print(name)


(40, 3)
                                                                                                filename  average_score  number_atoms
10   ../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/4BOE-filtered_grid_0.npy       0.000007           124
11   ../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/4BOE-filtered_grid_1.npy       0.000013           124
12   ../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/4BOE-filtered_grid_2.npy       0.000038           124
13   ../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/4BOE-filtered_grid_3.npy       0.000032           124
14   ../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/4BOE-filtered_grid_4.npy       0.000023           124
95   ../../../Data/SplitData/Cholesterol/cholesterol-grid-5A_exp5/Test/Positive/5L7D-filtered_grid_0.npy       0.257617            84
96   ../../../Data/SplitData/Cholesterol/cholesterol-g