In [None]:
!pip install torch torchvision torchaudio
!pip install torch-geometric
!pip install rdkit-pypi
# Imports
import torch
import torch.nn.functional as F
from torch_geometric.data import DataLoader
# from torch_geometric.datasets import Tox21
from torch_geometric.datasets import ZINC
from torch_geometric.nn import GINConv, GCNConv, global_add_pool
from torch_geometric.utils import train_test_split_edges
from torch_geometric.nn.models import VGAE
from rdkit import Chem
import numpy as np
import networkx as nx

In [None]:
# The import cell only uses `from torch_geometric... import ...`, which does not bind
# the name `torch_geometric` itself, so import the package before reading its version.
import torch_geometric

torch_geometric.__version__

In [None]:
# Load the ZINC dataset
# NOTE(review): the GIN training / evaluation cells further down use `dataset.num_tasks`,
# a binary cross-entropy loss and ROC-AUC, i.e. they expect multi-task binary labels.
# Confirm that ZINC (rather than Tox21) is really the dataset intended for them.
path = 'data/ZINC'
dataset = ZINC(path)

# Shuffle once, then cut into 80% / 10% / 10% train / validation / test slices.
# The two boundaries are computed a single time so the three slices cannot drift apart.
dataset = dataset.shuffle()
num_graphs = len(dataset)
train_end = int(num_graphs * 0.8)
val_end = int(num_graphs * 0.9)
train_dataset = dataset[:train_end]
val_dataset   = dataset[train_end:val_end]
test_dataset  = dataset[val_end:]

# Only the training loader reshuffles between epochs
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=32)
test_loader  = DataLoader(test_dataset, batch_size=32)

In [None]:
# Inspect one graph taken directly from the loaded dataset. The loop variable `data`
# only exists after the VGAE training cell further down has run, so relying on it here
# raises NameError under Restart & Run All.
sample_graph = dataset[0]
sample_graph.x, sample_graph.edge_index

In [None]:
# Scratch check for a `pos_edge_index` attribute. Nothing visible in this notebook sets
# one, and `data` is not defined yet at this point, so look the attribute up defensively
# on a dataset graph (evaluates to None when it is absent) instead of raising.
getattr(dataset[0], 'pos_edge_index', None)

In [None]:
# `generator` is only created (and its encoder run) by the VGAE training cell further
# down, so on a fresh kernel there is nothing to inspect yet. Skip the check in that
# case instead of raising NameError and aborting Run All.
generator.kl_loss() if 'generator' in globals() else None

In [None]:
# Pick the compute device, then load the ZINC subset that feeds the graph generator
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
gen_dataset = ZINC(root='data/ZINC', subset=True)
gen_loader = DataLoader(dataset=gen_dataset, batch_size=64, shuffle=True)

# Encoder definition
class GCNEncoder(torch.nn.Module):
    """GCN encoder: a two-layer shared trunk followed by two parallel GCN heads.

    `forward` returns the two head outputs as a tuple; the notebook wraps this
    module in `VGAE`, which consumes them as the latent distribution parameters.

    NOTE(review): PyG's `VGAE` is documented to treat the second tensor as a log
    standard deviation rather than a log variance, despite the `conv_logvar`
    name used here — confirm against the PyG documentation.

    Args:
        in_feats: width of the input node features.
        hidden_dim: width of both trunk layers and of both heads.
    """

    def __init__(self, in_feats, hidden_dim):
        super().__init__()
        # Shared trunk
        self.conv1 = GCNConv(in_feats, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        # Parallel output heads
        self.conv_mu = GCNConv(hidden_dim, hidden_dim)
        self.conv_logvar = GCNConv(hidden_dim, hidden_dim)

    def forward(self, x, edge_index):
        """Run the trunk, then return `(conv_mu output, conv_logvar output)`."""
        hidden = self.conv1(x, edge_index).relu()
        hidden = self.conv2(hidden, edge_index).relu()
        mu = self.conv_mu(hidden, edge_index)
        logvar = self.conv_logvar(hidden, edge_index)
        return mu, logvar

# Instantiate VGAE
generator = VGAE(GCNEncoder(gen_dataset.num_features, 64)).to(device)
opt_gen = torch.optim.Adam(generator.parameters(), lr=1e-3)
# Epochs over which the KL weight ramps linearly up to 1.
# Note the ramp starts at 1/warmup_epochs in epoch 1, not at 0.
warmup_epochs = 10

# Training loop for VGAE (adjacency reconstruction)
for epoch in range(1, 41):
    generator.train()
    total_loss = 0
    # The KL weight depends only on the epoch, so compute it once per epoch
    # rather than once per batch.
    beta = min(1.0, epoch / warmup_epochs)
    for data in gen_loader:
        data = data.to(device)
        data.x = data.x.float()  # cast node features to float before the GCN layers
        opt_gen.zero_grad()
        z = generator.encode(data.x, data.edge_index)

        recon_loss = generator.recon_loss(z, data.edge_index)
        # NOTE(review): PyG's reference VGAE example scales the KL term by
        # 1 / num_nodes; here it is weighted by `beta` only — confirm this is intended.
        kl_loss = generator.kl_loss()
        loss = recon_loss + beta * kl_loss
        loss.backward()
        opt_gen.step()
        total_loss += loss.item()
    # Report the mean per-batch loss for the epoch
    print(f"VGAE Epoch {epoch:02d} | Loss: {total_loss/len(gen_loader):.4f}")

# Sampling new graphs
generator.eval()
with torch.no_grad():
    # Sample latent vectors from the standard-normal prior
    num_nodes = 100
    # The latent width is the output width of the `mu` head (it only happens to equal
    # conv1's width because every layer of the encoder uses the same hidden size).
    hidden_dim = generator.encoder.conv_mu.out_channels
    z_sample = torch.randn((num_nodes, hidden_dim), device=device)

    # Construct the candidate edge index over all ordered node pairs ...
    row = torch.arange(num_nodes, device=device).unsqueeze(1).repeat(1, num_nodes).view(-1)
    col = torch.arange(num_nodes, device=device).unsqueeze(0).repeat(num_nodes, 1).view(-1)
    # ... but drop the diagonal: for a pair (i, i) the inner-product score is ||z_i||^2,
    # whose sigmoid is always above 0.5, so every node would otherwise get a self-loop.
    off_diagonal = row != col
    full_edge_index = torch.stack([row[off_diagonal], col[off_diagonal]], dim=0)

    # Decode edge probabilities for all remaining node pairs
    edge_probs = generator.decoder(z_sample, full_edge_index, sigmoid=True)

    # Threshold to select likely edges
    threshold = 0.5
    mask = edge_probs > threshold
    sampled_edge_index = full_edge_index[:, mask]

    # The count below is over directed pairs: both (i, j) and (j, i) are candidates.
    print(f"Sampled graph with {sampled_edge_index.size(1)} edges (threshold={threshold})")
    # (Further processing to convert sampled_edge_index into a valid molecule would follow)

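# TODO: convert the sampled graph into a molecule. The sketch below is not runnable as-is:
# `adj` is never defined in this notebook, and `nx.from_numpy_matrix` was removed in
# NetworkX 3.0 (its replacement is `nx.from_numpy_array`).
# G = nx.from_numpy_array(adj.cpu().numpy())
# # Attempt to convert networkx to SMILES (placeholder)
# smiles = Chem.MolToSmiles(Chem.RWMol())
# print("Sampled SMILES:", smiles)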

In [None]:
# IPython introspection (development aid): show the source of `generator.encode`.
??generator.encode

In [None]:
# IPython introspection (development aid): show the source of `generator.recon_loss`.
??generator.recon_loss

In [None]:
# IPython introspection (development aid): show the source of `generator.decoder`.
??generator.decoder

In [None]:
class GIN(torch.nn.Module):
    """Stack of GINConv layers with batch norm, sum pooling and a linear head.

    Args:
        num_layers: number of GINConv layers (each followed by BatchNorm1d + ReLU).
        hidden_dim: output width of every GINConv layer.
        num_tasks: number of outputs produced per graph by the final linear layer.
        in_dim: width of the input node features. When left as ``None`` the
            first layer falls back to the notebook-level ``dataset.num_features``,
            which is what this class always did before the parameter existed.
    """

    def __init__(self, num_layers, hidden_dim, num_tasks, in_dim=None):
        super(GIN, self).__init__()
        self.convs = torch.nn.ModuleList()
        self.bns   = torch.nn.ModuleList()

        for i in range(num_layers):
            if i == 0:
                # Only the first layer sees raw node features; resolve the global
                # fallback lazily so an explicit `in_dim` never touches `dataset`.
                layer_in = dataset.num_features if in_dim is None else in_dim
            else:
                layer_in = hidden_dim
            nn_lin = torch.nn.Sequential(
                torch.nn.Linear(layer_in, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, hidden_dim)
            )
            conv = GINConv(nn_lin)
            self.convs.append(conv)
            self.bns.append(torch.nn.BatchNorm1d(hidden_dim))

        self.linear = torch.nn.Linear(hidden_dim, num_tasks)

    def forward(self, x, edge_index, batch):
        """Return one row of `num_tasks` outputs per graph in the batch.

        x: node features, edge_index: graph connectivity, batch: batch vector
        """
        # The Linear layers need floating-point input. The VGAE cell in this notebook
        # casts `data.x` to float for the same reason; do it here so callers need not.
        x = x.float()
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = F.relu(x)

        # Global pooling: sum node embeddings per graph
        x = global_add_pool(x, batch)
        return self.linear(x)

# Instantiate model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# `num_tasks` is not an attribute PyG datasets are documented to provide, so reading
# `dataset.num_tasks` directly can raise AttributeError. Use it when present and
# otherwise fall back to the number of label values stored on the first graph.
num_tasks = getattr(dataset, 'num_tasks', None)
if num_tasks is None:
    num_tasks = dataset[0].y.reshape(1, -1).size(1)
model = GIN(num_layers=5, hidden_dim=64, num_tasks=num_tasks).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Binary Cross-Entropy with missing label masking
def train():
    """Run one training epoch over `train_loader` and return the average loss.

    Uses the notebook-level `model`, `optimizer`, `device`, `train_loader` and
    `train_dataset`. Missing labels are expected to be encoded as NaN and are
    excluded from the loss.

    NOTE(review): binary cross-entropy presumes labels in [0, 1]; confirm that
    the dataset loaded at the top of the notebook actually provides such labels.

    Returns:
        Sum of per-batch losses weighted by the number of graphs in each batch,
        divided by `len(train_dataset)`.
    """
    model.train()
    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        # Cast node features to float so integer-coded features do not break the
        # model's Linear layers (a no-op when they are already float).
        out = model(data.x.float(), data.edge_index, data.batch)
        # Align the labels with the logits: a flat (batch,) label vector becomes
        # (batch, 1), so mask, logits and targets always share one shape.
        y = data.y.to(torch.float).view(out.shape)
        # Mask missing labels (NaN) — same convention as `evaluate`
        mask = ~torch.isnan(y)
        if not mask.any():
            # No valid label in this batch: the mean over an empty selection would
            # be NaN and poison the weights, so skip the batch.
            continue
        loss = F.binary_cross_entropy_with_logits(out[mask], y[mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_dataset)

# Validation
def evaluate(loader):
    """Return the mean per-task ROC-AUC of the notebook-level `model` on `loader`.

    Labels that are NaN are treated as missing and ignored. A task is skipped
    when it has no valid label or when all of its valid labels are identical,
    because ROC-AUC is undefined for a single class.

    Args:
        loader: iterable of batches exposing `x`, `edge_index`, `batch` and `y`.

    Returns:
        Mean ROC-AUC over the tasks that could be scored, or NaN if none could.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            # Cast node features to float so integer-coded features do not break
            # the model's Linear layers (a no-op when they are already float).
            out = model(data.x.float(), data.edge_index, data.batch)
            # Give labels the same (batch, tasks) shape as the predictions so the
            # per-task column indexing below works for flat label vectors too.
            y_true.append(data.y.to(torch.float).view(out.shape).cpu())
            y_pred.append(torch.sigmoid(out).cpu())
    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)
    # Compute ROC-AUC per task
    from sklearn.metrics import roc_auc_score
    scores = []
    # Take the task count from the label matrix itself rather than from a
    # dataset attribute that may not exist.
    for i in range(y_true.size(1)):
        mask = ~torch.isnan(y_true[:, i])
        if mask.sum() == 0:
            continue
        labels = y_true[mask, i]
        if bool((labels == labels[0]).all()):
            # Only one class present: roc_auc_score would raise ValueError.
            continue
        scores.append(roc_auc_score(labels, y_pred[mask, i]))
    return np.mean(scores) if scores else float('nan')

In [None]:
# Leftover tab-completion stub: a bare `ZINC.` is a SyntaxError and aborts Run All.
# List the public attributes of the dataset class explicitly instead.
[name for name in dir(ZINC) if not name.startswith('_')]

In [None]:
# Run training: one optimisation pass plus one validation pass per epoch
num_epochs = 30
epoch = 1
while epoch <= num_epochs:
    loss = train()
    val_auc = evaluate(loader=val_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}')
    epoch += 1

In [None]:
# Final score on the held-out test split
test_auc = evaluate(loader=test_loader)
print(f'Test ROC-AUC: {test_auc:.4f}')