In [51]:
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

In [52]:
canonical_file="../dataset/canonical_trainset.csv"
canonical_smiles_df=pd.read_csv(canonical_file)

In [53]:
import os
import sys
import nbimporter

project_root = os.path.join(os.getcwd(), '..')
sys.path.append(project_root)

from datapreparation import Data_Processing_3D as nb3


graph_data = nb3.preprocess_smiles_with_labels_3d(canonical_smiles_df["SMILES"][:50].values, canonical_smiles_df["Label"][:50].values)

# Process voxels and label

In [54]:
voxels = np.array([item['voxels'] for item in graph_data])  # [num_samples, depth, height, width]
labels = np.array([item['label'] for item in graph_data])  # [num_samples,]
voxels = torch.tensor(voxels, dtype=torch.float).unsqueeze(1)  # [num_samples, 1, depth, height, width]
labels = torch.tensor(labels, dtype=torch.long)

(50, 20, 20, 20)
(50,)
torch.Size([50, 1, 20, 20, 20])
torch.Size([50])


# load dataset

In [55]:
from torch.utils.data import TensorDataset, DataLoader

voxels_train, voxels_test, labels_train, labels_test = train_test_split(voxels, labels, test_size=0.2, random_state=42)
train_dataset = TensorDataset(voxels_train, labels_train)
test_dataset = TensorDataset(voxels_test, labels_test)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)

# Build model

In [56]:
class Simple3DCNN(nn.Module):
    def __init__(self):
        super(Simple3DCNN, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool3d(2)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(32 * 5 * 5 * 5, 128)
        self.fc2 = nn.Linear(128, 2)

    def forward(self, x, feature_extract=False):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        if feature_extract:
            x = x.view(x.size(0), -1)  
            x = F.relu(self.fc1(x))
            return x
        x = x.view(-1, 32 * 5 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# train and evaluate model

In [57]:
from sklearn.metrics import accuracy_score

model = Simple3DCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

def train(model, train_loader, optimizer, criterion):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for voxels_batch, labels_batch in train_loader:
        optimizer.zero_grad()
        outputs = model(voxels_batch)
        loss = criterion(outputs, labels_batch)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels_batch.size(0)
        correct += (predicted == labels_batch).sum().item()
    avg_loss = total_loss / len(train_loader)
    accuracy = correct / total
    return avg_loss, accuracy

def validate(model, test_loader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for voxels_batch, labels_batch in test_loader:
            outputs = model(voxels_batch)
            loss = criterion(outputs, labels_batch)
            
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels_batch.size(0)
            correct += (predicted == labels_batch).sum().item()
    avg_loss = total_loss / len(test_loader)
    accuracy = correct / total
    return avg_loss, accuracy

num_epochs = 10
best_val_accuracy = 0

for epoch in range(num_epochs):
    train_loss, train_accuracy = train(model, train_loader, optimizer, criterion)
    val_loss, val_accuracy = validate(model, test_loader, criterion)
    
    print(f'Epoch {epoch+1}, '
          f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, '
          f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')
    
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save(model.state_dict(), 'best3_model.pth')
        print("Model saved at Epoch", epoch+1)



Epoch 1, Train Loss: 0.4690, Train Accuracy: 0.8500, Validation Loss: 0.0001, Validation Accuracy: 1.0000
Model saved at Epoch 1
Epoch 2, Train Loss: 0.3851, Train Accuracy: 0.9500, Validation Loss: 0.0939, Validation Accuracy: 1.0000
Epoch 3, Train Loss: 0.2184, Train Accuracy: 0.9500, Validation Loss: 0.0720, Validation Accuracy: 1.0000
Epoch 4, Train Loss: 0.2126, Train Accuracy: 0.9500, Validation Loss: 0.0575, Validation Accuracy: 1.0000
Epoch 5, Train Loss: 0.2073, Train Accuracy: 0.9500, Validation Loss: 0.0614, Validation Accuracy: 1.0000
Epoch 6, Train Loss: 0.2248, Train Accuracy: 0.9500, Validation Loss: 0.0392, Validation Accuracy: 1.0000
Epoch 7, Train Loss: 0.2192, Train Accuracy: 0.9500, Validation Loss: 0.0820, Validation Accuracy: 1.0000
Epoch 8, Train Loss: 0.1925, Train Accuracy: 0.9500, Validation Loss: 0.0406, Validation Accuracy: 1.0000
Epoch 9, Train Loss: 0.2096, Train Accuracy: 0.9500, Validation Loss: 0.0456, Validation Accuracy: 1.0000
Epoch 10, Train Loss: 0

# extract features

In [58]:
def extract_features_3d(model, loader):
    model.eval()  
    features = []
    labels = []
    with torch.no_grad():  
        for voxels, label in loader:
            feature = model(voxels, feature_extract=True)
            features.append(feature.cpu().numpy())
            labels.append(label.cpu().numpy())
    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels, axis=0)
    return features, labels


In [59]:
model.load_state_dict(torch.load('best3_model.pth'))
all_loader = DataLoader(TensorDataset(voxels, labels), batch_size=4, shuffle=False)
all_features, all_labels = extract_features_3d(model, all_loader)

In [60]:
print(all_features.shape)
print(all_labels.shape)

(50, 128)
(50,)
