In [2]:
from collections import defaultdict
from pathlib import Path

import pandas as pd
import torch

In [3]:
# 1. Load the raw DRKG triples: a tab-separated file with no header row and
#    three columns per line (head entity, relation, tail entity).
# NOTE(review): machine-specific absolute path — point DRKG_PATH at your local copy.
DRKG_PATH = Path(r"C:/Users/Manasa/drkg/drkg.tsv")
df = pd.read_csv(DRKG_PATH, sep="\t", header=None, names=["head", "relation", "tail"])


In [4]:
# Show the first rows of the raw triple table to confirm the column layout;
# the label and the frame preview are emitted on consecutive lines.
print("Sample triples:", df.head(), sep="\n")


Sample triples:
         head                        relation         tail
0  Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::2157
1  Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::5264
2  Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::2158
3  Gene::2157  bioarx::HumGenHumGen:Gene:Gene   Gene::3309
4  Gene::2157  bioarx::HumGenHumGen:Gene:Gene  Gene::28912


In [5]:
# Every distinct entity (from either end of a triple) and every distinct relation.
entities = set(df["head"]) | set(df["tail"])
relations = set(df["relation"])

# Contiguous integer ids assigned in sorted order, so the mapping is
# reproducible across runs regardless of row order in the source file.
sorted_entities = sorted(entities)
sorted_relations = sorted(relations)
entity2id = dict(zip(sorted_entities, range(len(sorted_entities))))
relation2id = dict(zip(sorted_relations, range(len(sorted_relations))))
del sorted_entities, sorted_relations

# Integer-encode each triple; columns are added in head_id, tail_id, rel_id order.
df["head_id"], df["tail_id"] = df["head"].map(entity2id), df["tail"].map(entity2id)
df["rel_id"] = df["relation"].map(relation2id)


In [7]:
# Persist the id mappings and the integer-encoded triples, then build the edge
# tensors for the graph model.
# NOTE(review): machine-specific absolute path — adjust PROCESSED_DIR for your setup.
PROCESSED_DIR = Path(r"C:/Users/Manasa/OneDrive/Desktop/Drug_Repurposing_Gnn/data/processed")
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories

pd.DataFrame(list(entity2id.items()), columns=["entity", "id"]).to_csv(PROCESSED_DIR / "entities.csv", index=False)
pd.DataFrame(list(relation2id.items()), columns=["relation", "id"]).to_csv(PROCESSED_DIR / "relations.csv", index=False)
df[["head_id", "rel_id", "tail_id"]].to_csv(PROCESSED_DIR / "triples.csv", index=False)

# edge_index has shape (2, num_edges): row 0 holds head ids, row 1 tail ids.
# Going through NumPy avoids materialising multi-million-element Python lists
# (the previous `.tolist()` route); torch.tensor copies, so the tensors do not
# alias df's memory, and .contiguous() guarantees a dense row-major layout.
edge_index = torch.tensor(df[["head_id", "tail_id"]].to_numpy().T, dtype=torch.long).contiguous()

# edge_type[i] is the relation id of edge i (same ordering as edge_index columns).
edge_type = torch.tensor(df["rel_id"].to_numpy(), dtype=torch.long)



In [8]:
# Report the graph tensor shapes, then fail loudly if they are inconsistent:
# edge_index must be (2, num_edges) and edge_type must have one entry per edge.
print("edge_index shape:", edge_index.shape)
print("edge_type shape:", edge_type.shape)

assert edge_index.dim() == 2 and edge_index.size(0) == 2, (
    f"edge_index must be (2, num_edges), got {tuple(edge_index.shape)}"
)
assert edge_type.dim() == 1 and edge_type.size(0) == edge_index.size(1), (
    f"edge_type must have one entry per edge: {tuple(edge_type.shape)} vs {edge_index.size(1)} edges"
)


edge_index shape: torch.Size([2, 5874261])
edge_type shape: torch.Size([5874261])


In [9]:
# Keep every triple that has a drug (Compound) or a Disease entity at either end.
# NOTE(review): this also keeps e.g. Compound–Gene and Disease–Gene edges, not only
# direct Compound–Disease edges — confirm that is the intended subgraph.


def has_entity_type(triples: pd.DataFrame, prefix: str) -> pd.Series:
    """Return a boolean mask of triples whose head or tail entity id starts with ``prefix``.

    Entity ids have the form ``<Type>::<id>`` (e.g. ``Gene::2157``), so an anchored
    literal prefix test is both more precise and much cheaper than an unanchored
    regex ``str.contains`` search over millions of rows.
    """
    return triples["head"].str.startswith(prefix) | triples["tail"].str.startswith(prefix)


drug_mask = has_entity_type(df, "Compound::")
disease_mask = has_entity_type(df, "Disease::")
drug_disease_subgraph = df[drug_mask | disease_mask]

# NOTE(review): machine-specific absolute path — adjust PROCESSED_DIR for your setup.
PROCESSED_DIR = Path(r"C:/Users/Manasa/OneDrive/Desktop/Drug_Repurposing_Gnn/data/processed")
PROCESSED_DIR.mkdir(parents=True, exist_ok=True)  # to_csv does not create missing directories
drug_disease_subgraph.to_csv(PROCESSED_DIR / "drug_disease_subgraph.csv", index=False)
print("Drug-Disease subgraph size:", drug_disease_subgraph.shape)

Drug-Disease subgraph size: (1967508, 6)
