In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from ase import Atoms
from ase.io import read
from torch_geometric.data import Data
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import to_networkx
from torch_geometric.data import DataLoader
from sklearn.model_selection import train_test_split

import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem

In [None]:
## TODO: Decompress all Tox21 data files from their .gz archives before loading.

In [None]:
class GraphConvolution(MessagePassing):
    """Sum-aggregating graph convolution followed by a linear projection.

    For every node, the features of its neighbours (as described by
    ``edge_index``) are summed and the sum is passed through ``self.lin``.

    NOTE(review): ``message`` returns only ``x_j``, so a node's own features
    are not part of its update unless ``edge_index`` contains self-loops.

    Args:
        in_channels: width of the incoming node features.
        out_channels: width of the projected output features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(aggr='add')
        self.lin = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # propagate() runs message -> aggregate ('add') -> update.
        aggregated = self.propagate(edge_index, x=x)
        return aggregated

    def message(self, x_j):
        # PyG matches the ``x_j`` name to the ``x=`` kwarg above; keep it.
        return x_j

    def update(self, aggr_out):
        # The projection is applied after the neighbour sum.
        projected = self.lin(aggr_out)
        return projected

class GraphEncoder(nn.Module):
    """Two stacked ``GraphConvolution`` layers with a ReLU in between.

    Args:
        input_dim: width of ``data.x`` (node features).
        hidden_dim: width of both convolution outputs.

    ``forward(data)`` reads ``data.x`` and ``data.edge_index`` and returns the
    second convolution's output (one row per node, ``hidden_dim`` columns).
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.conv1 = GraphConvolution(input_dim, hidden_dim)
        self.conv2 = GraphConvolution(hidden_dim, hidden_dim)

    def forward(self, data):
        edges = data.edge_index
        hidden = torch.relu(self.conv1(data.x, edges))
        return self.conv2(hidden, edges)


In [None]:
def load_data(csv_paths, sdf_paths):
    """Build one graph per (CSV, SDF) file pair.

    Args:
        csv_paths: paths to CSV files with node features; each file is read
            with ``pd.read_csv`` and handed to ``mol_to_graph`` as a NumPy array.
            NOTE(review): assumes one numeric row per atom — confirm against
            the actual Tox21 files.
        sdf_paths: paths to SDF files; only the first molecule in each file
            is used. Paired positionally with ``csv_paths``.

    Returns:
        list of graphs, in the same order as the input paths.

    Raises:
        ValueError: if the two path lists have different lengths, or an SDF
            file has no first molecule RDKit can parse.
    """
    csv_paths = list(csv_paths)
    sdf_paths = list(sdf_paths)
    # zip() would silently drop the tail of the longer list.
    if len(csv_paths) != len(sdf_paths):
        raise ValueError(
            f"csv_paths and sdf_paths must have the same length "
            f"({len(csv_paths)} != {len(sdf_paths)})"
        )

    graphs = []
    for csv_path, sdf_path in zip(csv_paths, sdf_paths):
        features = pd.read_csv(csv_path).to_numpy()

        # RDKit yields None for records it cannot parse; fail here with the
        # offending path instead of deep inside the graph conversion.
        mol = next(iter(Chem.SDMolSupplier(sdf_path)), None)
        if mol is None:
            raise ValueError(f"no parseable molecule at the start of {sdf_path!r}")

        graphs.append(mol_to_graph(mol, features))
    return graphs

def mol_to_graph(mol, features=None, num_elements=119):
    """Convert an RDKit molecule into a graph ``Data(x=..., edge_index=...)``.

    Node features are ``[features | one_hot(atomic_number)]``. Every bond is
    stored as two directed edges (i -> j and j -> i).

    Args:
        mol: an RDKit molecule (uses ``GetAtoms``/``GetBonds``).
        features: optional per-atom feature matrix with one row per atom
            (a 1-D array is treated as a single column). ``None`` means the
            node features are the atomic-number one-hot only.
        num_elements: width of the atomic-number one-hot. It is fixed (not
            derived from the molecule) so every graph has the same node-feature
            width, which a single ``nn.Linear`` input layer requires.

    Returns:
        ``Data`` with ``x`` float32 of shape ``(num_atoms, F)`` and ``edge_index``
        int64 of shape ``(2, 2 * num_bonds)``. The shape is ``(2, 0)`` for a
        molecule without bonds.

    Raises:
        ValueError: if ``features`` does not have one row per atom, or an
            atomic number does not fit in ``num_elements`` columns.
    """
    atomic_numbers = np.array(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=np.int64
    )
    num_atoms = len(atomic_numbers)
    if num_atoms and atomic_numbers.max() >= num_elements:
        raise ValueError(
            f"atomic number {atomic_numbers.max()} does not fit in a one-hot "
            f"of width {num_elements}"
        )
    # One-hot of the atomic number (not of the atom index), so the width does
    # not depend on the molecule's size.
    one_hot = np.eye(num_elements, dtype=np.float32)[atomic_numbers]

    if features is None:
        node_features = one_hot
    else:
        # float32 conversion fails loudly on non-numeric columns.
        features = np.asarray(features, dtype=np.float32)
        if features.ndim == 1:
            features = features[:, None]
        if features.shape[0] != num_atoms:
            raise ValueError(
                f"features has {features.shape[0]} rows but the molecule has "
                f"{num_atoms} atoms"
            )
        node_features = np.concatenate([features, one_hot], axis=1)

    # Bonds are undirected; message passing follows edge direction, so store both.
    src, dst = [], []
    for bond in mol.GetBonds():
        begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        src.extend((begin, end))
        dst.extend((end, begin))
    # Built directly as [2, E]; empty lists give a (2, 0) tensor.
    edge_index = torch.tensor([src, dst], dtype=torch.long)

    x = torch.from_numpy(node_features)
    return Data(x=x, edge_index=edge_index)

In [None]:
class ToxicClassifier(nn.Module):
    """Graph-level binary classifier: encoder -> mean pooling -> linear -> sigmoid.

    Args:
        input_dim: width of each node feature vector (``data.x.size(1)``).
        hidden_dim: width of the node embeddings produced by the encoder.

    ``forward(data)`` returns one probability per graph, shape ``(num_graphs,)``.
    That shape lines up with a per-molecule target for ``nn.BCELoss``. Without
    pooling, the output would be one probability per atom, shape ``(num_nodes, 1)``.
    """

    def __init__(self, input_dim, hidden_dim):
        super(ToxicClassifier, self).__init__()
        self.encoder = GraphEncoder(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)

    @staticmethod
    def _mean_pool(node_emb, batch=None):
        """Average node embeddings per graph -> ``(num_graphs, hidden_dim)``.

        ``batch[i]`` is the graph index of node ``i``. This is the vector PyG
        attaches to mini-batches. When ``batch`` is None, all nodes belong to
        one graph.
        """
        if batch is None:
            return node_emb.mean(dim=0, keepdim=True)
        num_graphs = int(batch.max()) + 1
        sums = node_emb.new_zeros(num_graphs, node_emb.size(-1))
        sums.index_add_(0, batch, node_emb)
        # clamp guards against graph ids with no nodes (division by zero).
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        return sums / counts.unsqueeze(-1).to(node_emb.dtype)

    def forward(self, data):
        node_emb = self.encoder(data)
        graph_emb = self._mean_pool(node_emb, getattr(data, "batch", None))
        logits = self.fc(graph_emb).squeeze(-1)  # (num_graphs,)
        return torch.sigmoid(logits)

# Load data
# TODO: replace these placeholders with the real Tox21 file lists.
# load_data pairs csv_paths[i] (node features) with sdf_paths[i] (the
# molecule), so both lists need the same length and order.
csv_paths = ['mol1.csv', 'mol2.csv']
sdf_paths = ['mol1.sdf', 'mol2.sdf']
graphs = load_data(csv_paths, sdf_paths)

# Split data (random_state makes the split reproducible)
train_graphs, test_graphs = train_test_split(graphs, test_size=0.2, random_state=42)

# Initialize model
torch.manual_seed(42)  # reproducible weight initialisation

# The model's first layer has a single fixed input width (input_dim), so
# every graph must have the same node-feature width.
feature_widths = {g.x.size(1) for g in graphs}
if len(feature_widths) != 1:
    raise ValueError(f'graphs have differing node-feature widths: {sorted(feature_widths)}')
input_dim = graphs[0].x.size(1)
hidden_dim = 64
model = ToxicClassifier(input_dim, hidden_dim)

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train model
# Labels are read from graph.y, the same attribute the evaluation loop reads.
# TODO: load_data/mol_to_graph do not set graph.y. Attach the Tox21 assay
# label to each graph (and confirm the label column name in the Tox21 data).
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_used = 0
    for graph in train_graphs:
        if getattr(graph, 'y', None) is None:
            raise ValueError('training graph has no label: set graph.y before training')
        target = torch.as_tensor(graph.y, dtype=torch.float).reshape(-1)
        # NOTE(review): assumes missing labels are encoded as NaN — confirm.
        # A NaN target would turn the loss (and every later update) into NaN.
        if torch.isnan(target).any():
            continue
        optimizer.zero_grad()
        output = model(graph)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_used += 1
    # Report the mean over the epoch, not just the last sample's loss.
    mean_loss = epoch_loss / num_used if num_used else float('nan')
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}')

# Evaluate model: accuracy of the rounded probabilities.
# torch.round rounds half to even, so an output of exactly 0.5 counts as 0.
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for graph in test_graphs:
        if getattr(graph, 'y', None) is None:
            raise ValueError('test graph has no label: set graph.y before evaluation')
        target = torch.as_tensor(graph.y, dtype=torch.float).reshape(-1)
        # NOTE(review): assumes missing labels are encoded as NaN — confirm.
        if torch.isnan(target).any():
            continue
        output = model(graph)
        predicted = torch.round(output).reshape(-1)
        total += 1
        # torch.equal avoids the ambiguous truth value of a multi-element `==`.
        if torch.equal(predicted, target):
            correct += 1

    accuracy = correct / total if total else float('nan')
    print(f'Accuracy: {accuracy}')