In [28]:
import os

import dgl
import numpy as np
import pandas as pd
import torch
from dgl.nn.pytorch import MetaPath2Vec
from torch.optim import SparseAdam
from torch.utils.data import DataLoader

In [29]:
# Tab-separated interaction table; the first column of the file supplies the row labels.
file_path = '../data/dgidb/preprocessed_34_10.tsv'
interaction_matrix = pd.read_csv(
    filepath_or_buffer=file_path,
    delimiter='\t',
    index_col=0,
)

In [30]:
# Generate node lists: columns of the interaction matrix are treated as drugs,
# rows as genes.
drugs = list(interaction_matrix.columns)
genes = list(interaction_matrix.index)

# Duplicate labels would silently collapse in the dicts below and leave gaps
# in the ID space, so fail early with a clear message instead.
assert interaction_matrix.columns.is_unique, "duplicate drug labels in interaction matrix"
assert interaction_matrix.index.is_unique, "duplicate gene labels in interaction matrix"

# Map drugs and genes to unique IDs (the ID is the label's position).
drug_to_id = {drug: idx for idx, drug in enumerate(drugs)}
gene_to_id = {gene: idx for idx, gene in enumerate(genes)}

# Extract edges: every cell that is != 0 becomes one (drug_id, gene_id) pair.
# np.nonzero scans the (gene x drug) array row-major, so the edges come out in
# the same order as a gene-by-gene, drug-by-drug loop, and with unique labels
# the positional indices are exactly the IDs stored in the mappings above.
gene_ids, drug_ids = np.nonzero(interaction_matrix.to_numpy() != 0)
edges = list(zip(drug_ids.tolist(), gene_ids.tolist()))  # drug to gene edges


In [31]:
# Build a heterogeneous graph with a single relation: drug --dg--> gene.
# `edges` holds (drug_id, gene_id) pairs.
drug_gene_relation = ('drug', 'dg', 'gene')
g = dgl.heterograph({drug_gene_relation: edges})

In [32]:
# Seed every RNG this notebook touches so that a Restart & Run All reproduces
# the same embeddings. NOTE(review): multi-worker sampling may still be
# non-deterministic — confirm if bit-exact repeatability matters.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
dgl.seed(SEED)

# metapath2vec model over the single 'dg' (drug -> gene) edge type.
# NOTE(review): the published DGL signature is
#   MetaPath2Vec(g, metapath, window_size, emb_dim=128, negative_size=5, sparse=True)
# and does not list walk_length / num_walks / workers — confirm that the
# installed DGL build accepts these keyword arguments.
model = MetaPath2Vec(g, ['dg'], window_size=5, walk_length=10, num_walks=80, workers=5)

  0%|          | 0/1236 [00:00<?, ?it/s]

100%|██████████| 1236/1236 [00:00<00:00, 7135.18it/s]


In [33]:
# Training examples are the drug node IDs 0..N-1; each shuffled batch of 128
# IDs is handed to model.sample, which acts as the collate function.
drug_node_ids = torch.arange(g.num_nodes('drug'))
dataloader = DataLoader(
    drug_node_ids,
    batch_size=128,
    shuffle=True,
    collate_fn=model.sample,
)

# Sparse variant of Adam over all model parameters.
optimizer = SparseAdam(model.parameters(), lr=0.025)

In [34]:
# A single pass over the batches: clear old gradients, compute the loss on the
# sampled (pos_u, pos_v, neg_v) triple, back-propagate, and update.
for batch in dataloader:
    pos_u, pos_v, neg_v = batch
    optimizer.zero_grad()
    loss = model(pos_u, pos_v, neg_v)
    loss.backward()
    optimizer.step()

In [35]:
# Look up the learned vectors for every drug node: translate the drug node IDs
# to the model's global node IDs, then index the embedding table with them.
drug_nids = torch.tensor(model.local_to_global_nid['drug'], dtype=torch.long)
drug_emb = model.node_embed(drug_nids)

In [36]:
# Look up the learned vectors for every gene node: translate the gene node IDs
# to the model's global node IDs, then index the embedding table with them.
gene_nids = torch.tensor(model.local_to_global_nid['gene'], dtype=torch.long)
gene_emb = model.node_embed(gene_nids)

In [37]:
# Invert the label -> ID dictionaries so that row positions can be turned back
# into drug / gene labels.
id_to_drug = {node_id: label for label, node_id in drug_to_id.items()}
id_to_gene = {node_id: label for label, node_id in gene_to_id.items()}

# Wrap each embedding matrix in a DataFrame whose i-th row is labelled with
# the drug / gene whose ID is i.
drug_matrix = drug_emb.detach().numpy()
drug_embeddings_df = pd.DataFrame(
    drug_matrix,
    index=[id_to_drug[row] for row in range(drug_matrix.shape[0])],
)

gene_matrix = gene_emb.detach().numpy()
gene_embeddings_df = pd.DataFrame(
    gene_matrix,
    index=[id_to_gene[row] for row in range(gene_matrix.shape[0])],
)

In [38]:
# Persist both embedding tables as CSV files without a header row; the row
# label (drug / gene name) is written as the first column.
save_path = '../data/dgidb/embeddings'

# to_csv does not create missing directories, so make sure the target exists
# rather than losing the training result to an OSError at the last step.
os.makedirs(save_path, exist_ok=True)

# header=False is the documented way to omit the header row.
drug_embeddings_df.to_csv(save_path + '/metapath2vec_drug_embeddings.csv', header=False)
gene_embeddings_df.to_csv(save_path + '/metapath2vec_gene_embeddings.csv', header=False)