In [41]:
import pandas as pd
import numpy as np
pd.set_option('display.max_columns', None)
import pandas as pd
import torch
import random

### Predicting Nutrient–Gene Regulatory Interactions Using Graph Neural Networks

In [2]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using M3 GPU (MPS)")
else:
    device = torch.device("cpu")
    print("Falling back to CPU")


Using M3 GPU (MPS)


In [3]:
column_names = [
    'ChemicalName',
    'ChemicalID',
    'CasRN',  # To be removed later
    'GeneSymbol',
    'GeneID',
    'GeneForms',
    'Organism',
    'OrganismID',
    'Interaction',
    'InteractionActions',
    'PubMedIDs'
]

In [4]:
chem_gene = pd.read_csv('CTD_chem_gene_ixns.tsv',sep='\t',comment='#', skiprows=7,header=None, names=column_names)
hmdb = pd.read_xml('hmdb_metabolites.xml',parser='lxml')
E_BioMart=pd.read_csv('mart_export.txt', sep=',')


In [9]:
# Filter for human genes only
chem_gene = chem_gene[chem_gene["Organism"] == "Homo sapiens"]

# # Keep only expression-related interactions (regulatory)
chem_gene = chem_gene[chem_gene["InteractionActions"].str.contains("expression", na=False)]
chem_gene.dropna()
chem_gene = chem_gene[["ChemicalName", "GeneSymbol"]].drop_duplicates()
# Standardize formatting
chem_gene["ChemicalName"] = chem_gene["ChemicalName"].str.strip().str.upper()
chem_gene["GeneSymbol"] = chem_gene["GeneSymbol"].str.strip().str.upper()

In [None]:
#Select only relevant columns
nutrients_df = hmdb[["accession", "name", "chemical_formula", "average_molecular_weight", "taxonomy"]]
nutrients_df["name"] = nutrients_df["name"].str.strip().str.upper()

nutrients_df.dropna(subset=["name"], inplace=True)
nutrients_df.drop_duplicates(subset=["name"], inplace=True)

In [14]:
E_BioMart = E_BioMart[E_BioMart["HGNC symbol"].notna()]


In [15]:
# Clean gene symbols
E_BioMart["HGNC symbol"] = E_BioMart["HGNC symbol"].str.strip().str.upper()
E_BioMart = E_BioMart.drop_duplicates(subset=["HGNC symbol"])


In [16]:
cleaned_genes = E_BioMart[["HGNC symbol", "Gene description"]]
cleaned_genes.columns = ["GeneSymbol", "GeneDescription"]
# Define standard chromosomes
standard_chroms = [str(i) for i in range(1, 23)] + ['X', 'Y', 'MT']

# Filter BioMart for only standard chromosomes
E_BioMart = E_BioMart[E_BioMart["Chromosome/scaffold name"].isin(standard_chroms)]
E_BioMart = E_BioMart.dropna(subset=["Chromosome/scaffold name"])



In [17]:
# e.g., convert chromosome to categorical features later
cleaned_genes["Chromosome"] = E_BioMart["Chromosome/scaffold name"]

In [18]:
cleaned_genes['Chromosome'].unique()

array(['MT', nan, 'Y', '21', '13', '18', '22', '20', 'X', '15', '8', '10',
       '14', '9', '4', '16', '5', '19', '7', '6', '12', '11', '17', '3',
       '1', '2'], dtype=object)

In [19]:
# Step 1: Join CTD gene info with gene metadata (BioMart)
ctd_gene_enriched = pd.merge(
    chem_gene,                      # from CTD
    cleaned_genes,           # from BioMart
    how='inner',
    on='GeneSymbol'          # already uppercased
)


In [None]:
# Standardize chemical names to match CTD format
nutrients_df["name"] = nutrients_df["name"].str.strip().str.upper()

nutrients_df = nutrients_df.drop_duplicates(subset=["name"])
chemicals_with_nutrients = pd.merge(
    ctd_gene_enriched,
    nutrients_df,
    how='inner',
    left_on="ChemicalName",
    right_on="name"
)


In [23]:
from sklearn.preprocessing import LabelEncoder

chemical_le = LabelEncoder()
gene_le = LabelEncoder()

chem_ids = chemical_le.fit_transform(chem_gene["ChemicalName"])
gene_ids = gene_le.fit_transform(chem_gene["GeneSymbol"])


In [None]:
num_nutrients = len(chemical_le.classes_)
gene_ids_offset = gene_ids + num_nutrients


edge_index = torch.tensor([chem_ids, gene_ids_offset], dtype=torch.long)


In [25]:
total_nodes = num_nutrients + len(gene_le.classes_)
x = torch.eye(total_nodes)  # identity matrix as node features


In [26]:
from torch_geometric.data import Data

data = Data(x=x, edge_index=edge_index)


In [27]:
positive_edges = edge_index.T  # shape: [num_edges, 2]


In [28]:
num_neg = positive_edges.shape[0]
negative_edges = set()

positive_edge_set = set(map(tuple, positive_edges.tolist()))

while len(negative_edges) < num_neg:
    c = random.randint(0, num_nutrients - 1)
    g = random.randint(num_nutrients, total_nodes - 1)
    if (c, g) not in positive_edge_set:
        negative_edges.add((c, g))

negative_edges = torch.tensor(list(negative_edges), dtype=torch.long)


In [29]:
# Combine edges
all_edges = torch.cat([positive_edges, negative_edges], dim=0)

# Labels: 1 for positive, 0 for negative
labels = torch.cat([
    torch.ones(positive_edges.shape[0]),
    torch.zeros(negative_edges.shape[0])
]).long()


In [30]:
from sklearn.model_selection import train_test_split

train_edges, test_edges, train_labels, test_labels = train_test_split(
    all_edges, labels, test_size=0.2, random_state=42
)


In [31]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class GCNLinkPredictor(torch.nn.Module):
    def __init__(self, num_nodes, hidden_dim=64):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_nodes, hidden_dim)
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.link_predictor = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1)
        )

    def forward(self, data, edge_index):
        x = self.embedding.weight
        x = self.conv1(x, data.edge_index).relu()
        x = self.conv2(x, data.edge_index).relu()
        return x

    def predict(self, x, edge_pairs):
        h1 = x[edge_pairs[:, 0]]
        h2 = x[edge_pairs[:, 1]]
        return torch.sigmoid(self.link_predictor(torch.cat([h1, h2], dim=1))).squeeze()


In [32]:

model = GCNLinkPredictor(num_nodes=total_nodes).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.BCELoss()

train_edges = train_edges.to(device)
train_labels = train_labels.to(device)
test_edges = test_edges.to(device)
test_labels = test_labels.to(device)

for epoch in range(1, 101):
    model.train()
    optimizer.zero_grad()
    x = model(data, data.edge_index)
    pred = model.predict(x, train_edges)
    loss = loss_fn(pred, train_labels.float())
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        with torch.no_grad():
            model.eval()
            test_pred = model.predict(x, test_edges)
            test_loss = loss_fn(test_pred, test_labels.float()).item()
            test_acc = ((test_pred > 0.5) == test_labels).float().mean().item()
            print(f"Epoch {epoch}: Train Loss = {loss.item():.4f}, Test Loss = {test_loss:.4f}, Test Acc = {test_acc:.4f}")


Epoch 10: Train Loss = 0.6049, Test Loss = 0.5936, Test Acc = 0.7381
Epoch 20: Train Loss = 0.3912, Test Loss = 0.3890, Test Acc = 0.8440
Epoch 30: Train Loss = 0.2108, Test Loss = 0.2103, Test Acc = 0.9218
Epoch 40: Train Loss = 0.1554, Test Loss = 0.1581, Test Acc = 0.9444
Epoch 50: Train Loss = 0.1404, Test Loss = 0.1445, Test Acc = 0.9493
Epoch 60: Train Loss = 0.1360, Test Loss = 0.1404, Test Acc = 0.9504
Epoch 70: Train Loss = 0.1343, Test Loss = 0.1391, Test Acc = 0.9499
Epoch 80: Train Loss = 0.1331, Test Loss = 0.1414, Test Acc = 0.9510
Epoch 90: Train Loss = 0.1331, Test Loss = 0.1383, Test Acc = 0.9512
Epoch 100: Train Loss = 0.1287, Test Loss = 0.1363, Test Acc = 0.9509


In [33]:
model.eval()
with torch.no_grad():
    x = model(data, data.edge_index)
    test_pred = model.predict(x, test_edges)
    test_loss = loss_fn(test_pred, test_labels.float()).item()
    test_acc = ((test_pred > 0.5) == test_labels).float().mean().item()

print(f"\nFinal Test Loss: {test_loss:.4f}")
print(f"Final Test Accuracy: {test_acc:.4f}")



Final Test Loss: 0.1375
Final Test Accuracy: 0.9501


In [34]:
from sklearn.metrics import roc_auc_score

test_pred_np = test_pred.cpu().numpy()
test_labels_np = test_labels.cpu().numpy()
auc = roc_auc_score(test_labels_np, test_pred_np)
print(f"Test ROC-AUC: {auc:.4f}")


Test ROC-AUC: 0.9867


In [None]:
# # Define all possible nutrient-gene combinations
# all_possible = set((i, j) for i in range(num_nutrients) for j in range(num_nutrients, total_nodes))
#
# # Remove known (training + testing) edges
# known = set(map(tuple, all_edges.tolist()))
# unknown = list(all_possible - known)


### To generate a random sample of negative edges  nutrient–gene pairs that are not known to interact in the training/testing data . so we can  test the model’s ability to predict unknown or “novel” interactions.

In [35]:

sample_size = 10000
unknown = set()

positive_set = set(map(tuple, all_edges.tolist()))

while len(unknown) < sample_size:
    c = random.randint(0, num_nutrients - 1)
    g = random.randint(num_nutrients, total_nodes - 1)
    if (c, g) not in positive_set:
        unknown.add((c, g))

unknown = list(unknown)


### to predict which new (unknown) nutrient–gene pairs might be true interactions — even though they've never been seen during training.

In [37]:

if len(unknown) > sample_size:
    unknown = random.sample(unknown, sample_size)


unknown_tensor = torch.tensor(unknown, dtype=torch.long)

# Get node embeddings once
with torch.no_grad():
    model.eval()
    node_reps = model(data, data.edge_index).cpu()

# Predict in batches
batch_size = 10000
scores = []

for i in range(0, len(unknown), batch_size):
    batch_edges = unknown_tensor[i:i+batch_size].to(device)
    batch_scores = model.predict(node_reps.to(device), batch_edges)
    scores.append(batch_scores.cpu())

scores = torch.cat(scores, dim=0)

# Top-K
topk = 100
top_scores, top_indices = torch.topk(scores, topk)
top_predictions = [unknown[i] for i in top_indices.numpy()]


In [39]:
top_predictions

[(5770, 15772),
 (2384, 12758),
 (2344, 9043),
 (3282, 12642),
 (3784, 9049),
 (5878, 11053),
 (1661, 14556),
 (5346, 21430),
 (4019, 21007),
 (970, 21430),
 (1444, 14556),
 (1477, 8294),
 (5549, 15321),
 (2794, 8152),
 (6115, 15987),
 (1368, 26934),
 (1412, 7977),
 (5375, 16994),
 (2513, 10725),
 (1060, 10972),
 (2250, 18621),
 (4085, 10625),
 (1418, 23410),
 (1493, 19238),
 (5555, 11731),
 (3686, 16742),
 (5555, 16825),
 (5346, 16604),
 (1468, 31122),
 (5490, 29408),
 (4593, 9000),
 (1368, 23975),
 (1477, 30393),
 (4226, 11038),
 (1661, 15714),
 (5581, 7435),
 (6131, 31430),
 (5375, 12761),
 (1853, 11649),
 (5549, 7702),
 (4748, 27123),
 (2513, 12244),
 (4805, 20985),
 (4038, 31821),
 (174, 24090),
 (5922, 24358),
 (1838, 7268),
 (1042, 11205),
 (4593, 31787),
 (3941, 28562),
 (3495, 25573),
 (5980, 8156),
 (615, 29087),
 (2868, 9545),
 (2262, 13154),
 (743, 12339),
 (1853, 25115),
 (5922, 30876),
 (4810, 29408),
 (1444, 32031),
 (743, 14671),
 (1853, 31803),
 (1853, 27483),
 (5807, 

In [40]:
for c, g in top_predictions:
    nutrient_name = chemical_le.inverse_transform([c])[0]
    gene_symbol = gene_le.inverse_transform([g - num_nutrients])[0]
    print(f"{nutrient_name} → {gene_symbol}")


TEMOZOLOMIDE → IFNG
CYCLOSPORINE → F3
CUPRIC CHLORIDE → CASP3
GOLD COMPOUNDS → ESR1
LEFLUNOMIDE → CASP8
THIRAM → CYP2A6
BETA-LAPACHONE → GREB1
ROTENONE → MTOR
MERCURIC BROMIDE → MMP9
7,8-DIHYDRO-7,8-DIHYDROXYBENZO(A)PYRENE 9,10-OXIDE → MTOR
ASBESTOS, CROCIDOLITE → GREB1
ATRAZINE → BMP2
SMOKE → HMMR
ENDOSULFAN → BCL2L1
URETHANE → IL1R1
ANTIRHEUMATIC AGENTS → RPS6KB1
ARISTOLOCHIC ACID I → AURKB
S-(1,2-DICHLOROVINYL)CYSTEINE → KRT8
DI-N-BUTYLPHOSPHORIC ACID → CRYAB
ACETAMINOPHEN → CXCR2
COBALTOUS CHLORIDE → LPL
METHYLEUGENOL → CPT1A
ARSENIC → PAX6
AVOBENZONE → MCL1
SODIUM ARSENITE → DNAJC7
K 7174 → KIT
SODIUM ARSENITE → KLK10
ROTENONE → KDR
ATAZANAVIR SULFATE → TP53
SEOCALCITOL → SQSTM1
OKADAIC ACID → CARD6
ANTIRHEUMATIC AGENTS → PIK3R3
ATRAZINE → TGFBR2
MRK 003 → CYP1B1
BETA-LAPACHONE → IDO1
SODIUM SELENITE → AREG
VALPROIC ACID → TRIO
S-(1,2-DICHLOROVINYL)CYSTEINE → F8
CADMIUM CHLORIDE → DNAAF1
SMOKE → ASCC3
PARTICULATE MATTER → RYR1
DI-N-BUTYLPHOSPHORIC ACID → EIF2AK2
PERFLUORO-N-NONANO