In [5]:
import pandas as pd
import torch
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score

In [6]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using M3 GPU (MPS)")
else:
    device = torch.device("cpu")
    print("Falling back to CPU")


Using M3 GPU (MPS)


In [7]:
prot_alis=  pd.read_csv('9606.protein.aliases.v12.0.txt',sep='\t')

In [8]:
prot_alis

Unnamed: 0,#string_protein_id,alias,source
0,9606.ENSP00000000233,2B6H,Ensembl_PDB
1,9606.ENSP00000000233,2B6H,UniProt_DR_PDB
2,9606.ENSP00000000233,381,Ensembl_HGNC_entrez_id
3,9606.ENSP00000000233,381,KEGG_GENEID
4,9606.ENSP00000000233,381,KEGG_KEGGID_SHORT
...,...,...,...
3889202,9606.ENSP00000501317,regulatory factor X domain containing 2,Ensembl_HGNC_prev_name
3889203,9606.ENSP00000501317,"regulatory factor X, 7",Ensembl_HGNC_prev_name
3889204,9606.ENSP00000501317,regulatory factor X7,Ensembl_HGNC_name
3889205,9606.ENSP00000501317,uc059jng.1,Ensembl_HGNC_ucsc_id


In [9]:
prot= pd.read_csv('9606.protein.links.v12.0.txt',sep=' ')

In [10]:
prot

Unnamed: 0,protein1,protein2,combined_score
0,9606.ENSP00000000233,9606.ENSP00000356607,173
1,9606.ENSP00000000233,9606.ENSP00000427567,154
2,9606.ENSP00000000233,9606.ENSP00000253413,151
3,9606.ENSP00000000233,9606.ENSP00000493357,471
4,9606.ENSP00000000233,9606.ENSP00000324127,201
...,...,...,...
13715399,9606.ENSP00000501317,9606.ENSP00000475489,195
13715400,9606.ENSP00000501317,9606.ENSP00000370447,158
13715401,9606.ENSP00000501317,9606.ENSP00000312272,226
13715402,9606.ENSP00000501317,9606.ENSP00000402092,169


In [11]:
prot =prot[prot['combined_score'] >= 700]
prot_alis = prot_alis[prot_alis['source'] == 'Ensembl_HGNC']

In [16]:
prot_alis.head()

Unnamed: 0,#string_protein_id,alias,source
23,9606.ENSP00000000233,ARF5,Ensembl_HGNC
219,9606.ENSP00000000412,M6PR,Ensembl_HGNC
397,9606.ENSP00000001008,FKBP4,Ensembl_HGNC
527,9606.ENSP00000001146,CYP26B1,Ensembl_HGNC
759,9606.ENSP00000002125,NDUFAF7,Ensembl_HGNC


In [18]:
gene_map = prot_alis.set_index('#string_protein_id')['alias'].to_dict()

prot = prot[
    prot['protein1'].isin(gene_map) &
    prot['protein2'].isin(gene_map)
].copy()

prot['p1gene'] = prot['protein1'].map(gene_map)
prot['p2gene'] = prot['protein2'].map(gene_map)

prot = prot[['protein1', 'p1gene', 'protein2', 'p2gene', 'combined_score']]

In [35]:
prot = prot.drop(columns=['protein1','protein2'])

In [37]:
prot

Unnamed: 0,p1gene,p2gene,combined_score
85,ARF5,ACAP1,825
130,ARF5,COPA,718
160,ARF5,RAB11FIP3,952
197,ARF5,COPB2,752
268,ARF5,COPE,795
...,...,...,...
13715019,LDB1,ZFPM1,942
13715034,LDB1,LHX4,944
13715072,RFX7,RFX5,780
13715299,RFX7,RFXANK,978


In [20]:
reactome_df = pd.read_csv("UniProt2Reactome_All_Levels.txt", sep="\t", names=["UniProt_ID", "Pathway_ID", "URL","pathway_name","Evidence_Code","Species"])

In [53]:
react

In [55]:
human_gene = pd.read_csv('HUMAN_9606_idmapping.dat',sep='\t', header=None, names=["UniProt_ID", "DB_Name", "Gene_Symbol"])

In [56]:
human_gene.head()

Unnamed: 0,UniProt_ID,DB_Name,Gene_Symbol
0,P31946,UniProtKB-ID,1433B_HUMAN
1,P31946,Gene_Name,YWHAB
2,P31946,GI,4507949
3,P31946,GI,377656702
4,P31946,GI,67464628


In [57]:
gene_map = human_gene[human_gene["DB_Name"] == "Gene_Name"]
gene_map = gene_map[["UniProt_ID", "Gene_Symbol"]].rename(columns={"Gene_Symbol": "Value"})


In [58]:
uni_gene = dict(zip(gene_map["UniProt_ID"], gene_map["Value"]))

In [78]:
reactome_df = reactome_df.dropna(subset=["Gene_Symbol","Pathway_ID"])

In [79]:
reactome_df["Gene_Symbol"] = reactome_df["UniProt_ID"].map(uni_gene)

In [81]:
reactome=reactome_df.drop(columns=['URL','Evidence_Code','Species'])
reactome = reactome.drop_duplicates()

In [90]:
reactome

Unnamed: 0,UniProt_ID,Pathway_ID,pathway_name,Gene_Symbol
905,A0A075B6P5,R-HSA-109582,Hemostasis,IGKV2-28
906,A0A075B6P5,R-HSA-1280218,Adaptive Immune System,IGKV2-28
908,A0A075B6P5,R-HSA-1643685,Disease,IGKV2-28
910,A0A075B6P5,R-HSA-166658,Complement cascade,IGKV2-28
911,A0A075B6P5,R-HSA-166663,Initial triggering of complement,IGKV2-28
...,...,...,...,...
881989,Q9Y6Z7,R-HSA-166662,Lectin pathway of complement activation,COLEC10
881990,Q9Y6Z7,R-HSA-166663,Initial triggering of complement,COLEC10
881991,Q9Y6Z7,R-HSA-166786,Creation of C4 and C2 activators,COLEC10
881992,Q9Y6Z7,R-HSA-168249,Innate Immune System,COLEC10


In [82]:
path_count =reactome.groupby("Gene_Symbol")["Pathway_ID"].nunique()


In [85]:
label_df = path_count.apply(lambda x : 1 if x>=3 else 0).reset_index()
label_df.columns = ["gene", "label"]

In [86]:
label_df['label'].value_counts()

label
1    10847
0      446
Name: count, dtype: int64

In [91]:
label = dict(zip(label_df["gene"], label_df["label"]))
list_node_gene = list(set(prot['p1gene']).union(prot['p2gene']).union(reactome['Gene_Symbol']))
list_node_path = list(set(reactome['Pathway_ID']))
nodes = list_node_path + list_node_gene

In [109]:
node_id = {node_name : idx for idx, node_name in enumerate(nodes)}

In [111]:
gene_pathway = []

for _, row in reactome.iterrows():
    gene = row['Gene_Symbol']
    pathway = row['Pathway_ID']

    src = node_id[gene]
    dst = node_id[pathway]

    gene_pathway.append((src, dst))


In [112]:
rev_edges = [(dst, src) for (src, dst) in gene_pathway]
all_edges = gene_pathway + rev_edges


In [115]:
gene_gene = []

for _, row in prot.iterrows():
    gene = row['p1gene']
    gene1 = row['p2gene']

    src = node_id[gene]
    dst = node_id[gene1]

    gene_gene.append((src, dst))


In [116]:
rev_edges_gene = [(dst, src) for (src, dst) in gene_gene]
all_edges_gene = gene_gene + rev_edges_gene


In [117]:
alledges = all_edges_gene+ all_edges

In [118]:
edge_index = torch.tensor(alledges, dtype=torch.long).t()  # shape [2, num_edges]

In [131]:
num_nodes = len(node_id)
x = torch.zeros((num_nodes, 2))
for idx in range(num_nodes):
    x[idx, 0] = sum(1 for (gene, path) in gene_pathway if gene == idx)
for src, dst in alledges:
    x[src, 1] += 1
    x[dst, 1] += 1

x = (x - x.mean(dim=0)) / x.std(dim=0)

In [132]:
labels = torch.full((num_nodes,), -1)  # default is -1 (unlabeled)
for _, row in label_df.iterrows():
    gene = row['gene']
    if gene in node_id:
        labels[node_id[gene]] = row['label']


In [133]:
data = Data(
    x=x,
    edge_index=edge_index,
    y=labels,
)


In [142]:

# Get indices of gene nodes only
gene_nodes = [i for i, name in enumerate(nodes) if not name.startswith("R-HSA")]
train_idx, test_idx = train_test_split(gene_nodes, test_size=0.2, random_state=42)
train_idx, val_idx  = train_test_split(train_idx, test_size=0.25, random_state=42)

train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
val_mask   = torch.zeros(data.num_nodes, dtype=torch.bool)
test_mask  = torch.zeros(data.num_nodes, dtype=torch.bool)

train_mask[train_idx] = True
val_mask[val_idx]     = True
test_mask[test_idx]   = True

data.train_mask = train_mask
data.val_mask   = val_mask
data.test_mask  = test_mask
data = data.to(device)
valid_labels = data.y[data.y != -1]

num_pos = (valid_labels == 1).sum().item()
num_neg = (valid_labels == 0).sum().item()
class_weights = torch.tensor(
    [num_pos / num_neg, 1.0],  # weight class 0 more
    dtype=torch.float
).to(device)

In [143]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GAT(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out, heads=4):
        super().__init__()
        self.gat1 = GATConv(dim_in, dim_h, heads=heads, concat=True, dropout=0.5)
        self.gat2 = GATConv(dim_h * heads, dim_out, heads=1, concat=False, dropout=0.5)

    def forward(self, x, edge_index):
        h = self.gat1(x, edge_index)
        h = F.elu(h)
        h = self.gat2(h, edge_index)
        return F.log_softmax(h, dim=1)


In [146]:
model = GAT(dim_in=data.x.shape[1], dim_h=128, dim_out=2, heads=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)  # weighted loss if class imbalance

for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()

    out = model(data.x, data.edge_index)

    mask = data.train_mask & (data.y != -1)
    loss = criterion(out[mask], data.y[mask])
    loss.backward()
    optimizer.step()

    # Validation accuracy
    model.eval()
    with torch.no_grad():
        val_mask = data.val_mask & (data.y != -1)
        pred = out[val_mask].argmax(dim=1)
        val_acc = (pred == data.y[val_mask]).sum() / val_mask.sum()

    if epoch % 10 == 0:
        print(f"Epoch {epoch:03d} | Loss: {loss:.3f} | Val Acc: {val_acc:.3f}")


Epoch 010 | Loss: 0.618 | Val Acc: 0.694
Epoch 020 | Loss: 0.589 | Val Acc: 0.580
Epoch 030 | Loss: 0.572 | Val Acc: 0.631
Epoch 040 | Loss: 0.548 | Val Acc: 0.682
Epoch 050 | Loss: 0.510 | Val Acc: 0.675
Epoch 060 | Loss: 0.491 | Val Acc: 0.695
Epoch 070 | Loss: 0.479 | Val Acc: 0.648
Epoch 080 | Loss: 0.450 | Val Acc: 0.766
Epoch 090 | Loss: 0.415 | Val Acc: 0.687
Epoch 100 | Loss: 0.392 | Val Acc: 0.779
Epoch 110 | Loss: 0.368 | Val Acc: 0.759
Epoch 120 | Loss: 0.383 | Val Acc: 0.765
Epoch 130 | Loss: 0.345 | Val Acc: 0.789
Epoch 140 | Loss: 0.374 | Val Acc: 0.778
Epoch 150 | Loss: 0.343 | Val Acc: 0.767
Epoch 160 | Loss: 0.384 | Val Acc: 0.811
Epoch 170 | Loss: 0.345 | Val Acc: 0.765
Epoch 180 | Loss: 0.360 | Val Acc: 0.820
Epoch 190 | Loss: 0.348 | Val Acc: 0.792
Epoch 200 | Loss: 0.321 | Val Acc: 0.833


In [147]:
val_f1 = f1_score(
    data.y[val_mask].cpu(),
    pred.cpu(),
    average="binary"
)
print(f"Val Acc: {val_acc:.3f} | Val F1: {val_f1:.3f}")


Val Acc: 0.833 | Val F1: 0.905


In [148]:
test_mask = data.test_mask & (data.y != -1)
pred = out[test_mask].argmax(dim=1)

from sklearn.metrics import f1_score, precision_score, recall_score

test_f1 = f1_score(data.y[test_mask].cpu(), pred.cpu(), average="binary")
test_precision = precision_score(data.y[test_mask].cpu(), pred.cpu(), average="binary")
test_recall = recall_score(data.y[test_mask].cpu(), pred.cpu(), average="binary")

print(f"Test F1: {test_f1:.3f}, Precision: {test_precision:.3f}, Recall: {test_recall:.3f}")


Test F1: 0.906, Precision: 0.996, Recall: 0.831


In [149]:
probs = torch.softmax(out, dim=1)[:, 1]
gene_scores = [
    (gene_name, probs[idx].item())
    for gene_name, idx in node_id.items()
    if data.y[idx] != -1
]
top_genes = sorted(gene_scores, key=lambda x: x[1], reverse=True)[:20]
print(top_genes)


[('KLC2', 1.0), ('BRMS1', 1.0), ('SPTBN1', 1.0), ('HAUS6', 1.0), ('RAD50', 1.0), ('IGHV3-33', 1.0), ('RPL13A', 1.0), ('RMI2', 1.0), ('AHCTF1', 1.0), ('CCNK', 1.0), ('DUSP16', 1.0), ('AR', 1.0), ('LAT2', 1.0), ('RANGAP1', 1.0), ('SPTB', 1.0), ('FGD1', 1.0), ('SPTBN2', 1.0), ('IRAK2', 1.0), ('UBE2I', 1.0), ('BRCC3', 1.0)]
