In [1]:
from torch_geometric.datasets import QM9
from torch_geometric.transforms import AddSelfLoops, ToUndirected, Compose
from torch_geometric.loader import DataLoader
import torch.nn.functional as F
import torch.nn as nn #تغییر ویژگی اصلی و بیس کد داده شده
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool
import torch
from rdkit import Chem
from rdkit.Chem import Crippen, Descriptors
from rdkit import RDLogger

In [2]:
path = './qm9'
transform = Compose([ToUndirected(), AddSelfLoops()])
dataset = QM9(path, transform=transform)

In [3]:
RDLogger.DisableLog('rdApp.*')

def estimate_logS(mol):
    logP = Crippen.MolLogP(mol)
    mw = Descriptors.MolWt(mol)
    return -0.74 * logP - 0.006 * mw

def is_molecule_toxic(mol):
    smarts_nitro = Chem.MolFromSmarts('[N+](=O)[O-]')
    if mol.HasSubstructMatch(smarts_nitro):
        return True
    if any(atom.GetAtomicNum() == 16 for atom in mol.GetAtoms()):
        return True
    if Descriptors.MolWt(mol) > 150:
        return True
    if Crippen.MolLogP(mol) > 3:
        return True
    return False


hybridization_map = {
    Chem.rdchem.HybridizationType.SP: 1,
    Chem.rdchem.HybridizationType.SP2: 2,
    Chem.rdchem.HybridizationType.SP3: 3,
    Chem.rdchem.HybridizationType.SP3D: 4,
    Chem.rdchem.HybridizationType.SP3D2: 5
}


def preprocess_qm9(dataset):
    new_dataset = []
    skipped = 0

    for i, data in enumerate(dataset):
        smiles = data.smiles
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            skipped += 1
            continue

        try:
            Chem.SanitizeMol(mol)
            mol = Chem.RemoveHs(mol)

            hybridization_labels = []
            ring_labels = []
            aromaticity_labels = []

            for j in range(data.num_nodes):
                atom_num = data.z[j].item()

                if atom_num == 1:  # Hydrogen
                    hybridization_labels.append(0)
                    ring_labels.append(0)
                    aromaticity_labels.append(0)
                    continue

                atom = mol.GetAtomWithIdx(j)
                hyb = hybridization_map.get(atom.GetHybridization(), 0)
                hybridization_labels.append(hyb)
                ring_labels.append(1 if atom.IsInRing() else 0)
                aromaticity_labels.append(1 if atom.GetIsAromatic() else 0)

            data.hybridization_label = torch.tensor(hybridization_labels, dtype=torch.long)
            data.ring_label = torch.tensor(ring_labels, dtype=torch.long)
            data.aromaticity_label = torch.tensor(aromaticity_labels, dtype=torch.long)

            data.solubility = torch.tensor([[estimate_logS(mol)]], dtype=torch.float)
            data.is_toxic = torch.tensor([[1.0 if is_molecule_toxic(mol) else 0.0]], dtype=torch.float)

            new_dataset.append(data)

        except Exception as e:
            skipped += 1
            

    print(f"✅ Done: {len(new_dataset)} processed | {skipped} skipped")
    return new_dataset

In [4]:
processed_dataset = preprocess_qm9(dataset)
loader = DataLoader(processed_dataset, batch_size=32, shuffle=True)
sample = processed_dataset[0]
sample.x

✅ Done: 128596 processed | 2235 skipped


tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

In [5]:

train_dataset = processed_dataset[:8000]
test_dataset = processed_dataset[8000:12000]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)


class GNNMultiTask(nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        #implement your model here
        
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

        
        self.hybrid_head = nn.Linear(hidden_channels, 6)   
        self.ring_head = nn.Linear(hidden_channels, 2)      
        self.aromatic_head = nn.Linear(hidden_channels, 2)  

        
        self.graph_proj = nn.Linear(hidden_channels, hidden_channels)

        self.toxic_head = nn.Linear(hidden_channels, 1)    
        self.solubility_head = nn.Linear(hidden_channels, 1)  


    def forward(self, x, edge_index, batch):
        #change this part but keep the output format
        
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

         
        hybrid = self.hybrid_head(x)
        ring = self.ring_head(x)
        aromatic = self.aromatic_head(x)

        
        pooled = global_mean_pool(x, batch)  
        g = F.relu(self.graph_proj(pooled))

        toxic = torch.sigmoid(self.toxic_head(g))        
        solubility = self.solubility_head(g)             

        return {
            "hybrid": hybrid,         
            "ring": ring,             
            "aromatic": aromatic,     
            "toxic": toxic,           
            "solubility": solubility  
        }


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GNNMultiTask(in_channels=dataset.num_node_features, hidden_channels=64).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [6]:
import torch
import torch.nn.functional as F

def train(model, loader, optimizer, device):
    model.train()
    total_loss = 0
    #Complete this
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()

        
        out = model(data.x, data.edge_index, data.batch)

        
        loss_hybrid = F.cross_entropy(out["hybrid"], data.hybridization_label)

        loss_ring = F.cross_entropy(out["ring"], data.ring_label)
        loss_aromatic = F.cross_entropy(out["aromatic"], data.aromaticity_label)

        
        target_toxic = data.is_toxic.view(-1)  
        pred_toxic = out["toxic"].view(-1)     

        loss_toxic = F.binary_cross_entropy(pred_toxic, target_toxic)

        loss_solubility = F.mse_loss(out["solubility"], data.solubility)

        
        total = (
            loss_hybrid
            + loss_ring
            + loss_aromatic
            + loss_toxic
            + loss_solubility
        )

        
        total.backward()
        optimizer.step()

        total_loss += total.item()


    return total_loss / len(loader)

In [7]:
def evaluate(model, loader, device): #تغییر ویژگی اصلی و بیس کد داده شده
    model.eval()
    correct_hybrid = correct_ring = correct_aromatic = correct_toxic = 0
    total_hybrid = total_ring = total_aromatic = total_toxic = 0
    total_sol_mse = 0

    with torch.no_grad():
        for data in loader:
            data = data.to(device)

            out = model(data.x, data.edge_index, data.batch)

            
            pred_hybrid = out["hybrid"].argmax(dim=1)
            correct_hybrid += (pred_hybrid == data.hybridization_label).sum().item()
            total_hybrid += data.num_nodes

            
            pred_ring = out["ring"].argmax(dim=1)
            correct_ring += (pred_ring == data.ring_label).sum().item()
            total_ring += data.num_nodes

            
            pred_aromatic = out["aromatic"].argmax(dim=1)
            correct_aromatic += (pred_aromatic == data.aromaticity_label).sum().item()
            total_aromatic += data.num_nodes

            
            pred_toxic = (out["toxic"].view(-1) > 0.5).float()
            correct_toxic += (pred_toxic == data.is_toxic.view(-1)).sum().item()
            total_toxic += data.num_graphs

            
            sol_loss = F.mse_loss(out["solubility"], data.solubility, reduction="sum")
            total_sol_mse += sol_loss.item()

    return {
        "hybrid_acc": correct_hybrid / total_hybrid,
        "ring_acc": correct_ring / total_ring,
        "aromatic_acc": correct_aromatic / total_aromatic,
        "toxic_acc": correct_toxic / total_toxic,
        "solubility_mse": total_sol_mse / total_toxic  
    }

In [None]:
NUM_EPOCHS = 20

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train(model, train_loader, optimizer, device)
    test_metrics = evaluate(model, test_loader, device)

    print(f"Epoch {epoch:02d} | Loss: {train_loss:.4f} | "
          f"Hybrid Acc: {test_metrics['hybrid_acc']:.4f} | "
          f"Ring Acc: {test_metrics['ring_acc']:.4f} | "
          f"Aromatic Acc: {test_metrics['aromatic_acc']:.4f} | "
          f"Toxic Acc: {test_metrics['toxic_acc']:.4f} | "
          f"Solubility MSE: {test_metrics['solubility_mse']:.4f}")

Epoch 01 | Loss: 2.4960 | Hybrid Acc: 0.7393 | Ring Acc: 0.7839 | Aromatic Acc: 0.9995 | Toxic Acc: 0.9988 | Solubility MSE: 0.4068
Epoch 02 | Loss: 1.5933 | Hybrid Acc: 0.8012 | Ring Acc: 0.8160 | Aromatic Acc: 0.9737 | Toxic Acc: 0.9988 | Solubility MSE: 0.3237
Epoch 03 | Loss: 1.3678 | Hybrid Acc: 0.8238 | Ring Acc: 0.8069 | Aromatic Acc: 0.9913 | Toxic Acc: 0.9988 | Solubility MSE: 0.2820
Epoch 04 | Loss: 1.2776 | Hybrid Acc: 0.8223 | Ring Acc: 0.8229 | Aromatic Acc: 0.9920 | Toxic Acc: 0.9988 | Solubility MSE: 0.2400
Epoch 05 | Loss: 1.2175 | Hybrid Acc: 0.8140 | Ring Acc: 0.8245 | Aromatic Acc: 0.9746 | Toxic Acc: 0.9988 | Solubility MSE: 0.2043
Epoch 06 | Loss: 1.1749 | Hybrid Acc: 0.8163 | Ring Acc: 0.8218 | Aromatic Acc: 0.9748 | Toxic Acc: 0.9988 | Solubility MSE: 0.2006
Epoch 07 | Loss: 1.1444 | Hybrid Acc: 0.8116 | Ring Acc: 0.8265 | Aromatic Acc: 0.9655 | Toxic Acc: 0.9988 | Solubility MSE: 0.2560
Epoch 08 | Loss: 1.1165 | Hybrid Acc: 0.8081 | Ring Acc: 0.8182 | Aromatic A