In [41]:
# the final model is the last one (Model_5_dot)

In [1]:
import pandas as pd
import glob
import os
import torch
from torch import Tensor
import numpy as np
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, to_hetero
import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, f1_score, classification_report
import matplotlib.pyplot as plt
from torch.optim import Adam
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# this function will be used to build drug_feat
def to_float(value):
    try:
        return float(value)
    except ValueError:
        return float('nan')

In [3]:
# this function will be used to create edge_index_supervision and edge_label_supervision
def split_edges(edge_index, edge_label, split_ratio=0.7):
    num_edges = edge_index.size(1)
    random_permutation = torch.randperm(num_edges)
    split_edge = int(split_ratio * num_edges)
    edge_index_meg_pass = edge_index[:, random_permutation[:split_edge]]
    edge_index_supervision = edge_index[:, random_permutation[split_edge: ]]
    edge_label_supervision = edge_label[random_permutation[split_edge: ]]
    return edge_index_meg_pass, edge_index_supervision, edge_label_supervision

In [4]:
def create_trainData():
    # creating drug_feat
    file_path = './NE Dataset/drug_attr.csv'
    dfDrug = pd.read_csv(file_path)
    dfDrugFeat = dfDrug.drop(dfDrug.columns[:2], axis=1)
    dfDrugFeat.replace('#NAME?', np.nan, inplace=True) # replacing the missing values with 0
    dfDrugFeat.fillna(0, inplace=True)
    dfDrugFeat = dfDrugFeat.applymap(to_float)
    dfDrugFeat.replace([np.inf], 0, inplace=True) # replacing the inf value with zero
    scaler = MinMaxScaler(feature_range=(-1, 1)) # NORMALIZE HERE
    dfDrugFeat = scaler.fit_transform(dfDrugFeat)
    drug_feat = torch.from_numpy(dfDrugFeat).to(torch.float)
    # creating target_feat
    file_path = './NE Dataset/protein_attr.csv'
    dfTarget = pd.read_csv(file_path)
    dfTargetFeat = dfTarget.drop(dfTarget.columns[:2], axis=1)
    target_feat = torch.from_numpy(dfTargetFeat.values).to(torch.float)
    # unique drug and unique target mapping
    uniqueDrug = dfDrug['1'].unique()
    uniqueDrug = pd.DataFrame(data={
        'DrugId': uniqueDrug,
        'mappedId': pd.RangeIndex(len(uniqueDrug)),
    })
    uniqueTraget = dfTarget['#'].unique()
    uniqueTraget = pd.DataFrame(data={
        'TargetId': uniqueTraget,
        'mappedId': pd.RangeIndex(len(uniqueTraget)),
    })
    # creating edge_index_drug_to_drug
    file_path = './NE Dataset/train_ddi.csv'
    dfDDI = pd.read_csv(file_path)
    source_drug_id = pd.merge(dfDDI['Drug'], uniqueDrug,
                             left_on='Drug', right_on='DrugId', how='left')
    source_drug_id = torch.from_numpy(source_drug_id['mappedId'].values)
    destination_drug_id = pd.merge(dfDDI['Target'], uniqueDrug,
                                  left_on='Target', right_on='DrugId', how='left')
    destination_drug_id = torch.from_numpy(destination_drug_id['mappedId'].values)
    edge_index_drug_to_drug = torch.stack([source_drug_id, destination_drug_id], dim=0)
    # creating edge_index_target_to_target
    file_path = './NE Dataset/train_ppi.csv'
    dfPPI = pd.read_csv(file_path)
    source_target_id = pd.merge(dfPPI['Drug'], uniqueTraget,
                               left_on='Drug', right_on='TargetId', how='left')
    source_target_id = torch.from_numpy(source_target_id['mappedId'].values)
    destination_target_id = pd.merge(dfPPI['Target'], uniqueTraget,
                                    left_on='Target', right_on='TargetId', how='left')
    destination_target_id = torch.from_numpy(destination_target_id['mappedId'].values)
    edge_index_target_to_target = torch.stack([source_target_id, destination_target_id], dim=0)
    # creating edge_index_pos_drug_to_target, edge_index_neg_drug_to_target, pos_edge_label, neg_edge_label
    file_path = './NE Dataset/train_dti.csv'
    dfDTI = pd.read_csv(file_path)
    dfDTIpos = dfDTI[dfDTI['Label'] == 1]
    dfDTIneg = dfDTI[dfDTI['Label'] == 0]
    pos_interact_drug_id = pd.merge(dfDTIpos['Drug'], uniqueDrug,
                                   left_on='Drug', right_on='DrugId', how='left')
    pos_interact_drug_id = torch.from_numpy(pos_interact_drug_id['mappedId'].values)
    pos_interact_target_id = pd.merge(dfDTIpos['Target'], uniqueTraget,
                                      left_on='Target', right_on='TargetId', how='left')
    pos_interact_target_id = torch.from_numpy(pos_interact_target_id['mappedId'].values)
    edge_index_pos_drug_to_target = torch.stack([pos_interact_drug_id, pos_interact_target_id], dim=0) # edge_index_pos_drug_to_target
    neg_interact_drug_id = pd.merge(dfDTIneg['Drug'], uniqueDrug,
                                   left_on='Drug', right_on='DrugId', how='left')
    neg_interact_drug_id = torch.from_numpy(neg_interact_drug_id['mappedId'].values)
    neg_interact_target_id = pd.merge(dfDTIneg['Target'], uniqueTraget,
                                     left_on='Target', right_on='TargetId', how='left')
    neg_interact_target_id = torch.from_numpy(neg_interact_target_id['mappedId'].values)
    edge_index_neg_drug_to_target = torch.stack([neg_interact_drug_id, neg_interact_target_id], dim=0) # edge_index_neg_drug_to_target
    pos_edge_label = dfDTIpos['Label']
    pos_edge_label = torch.tensor(pos_edge_label.values) # pos_edge_label
    neg_edge_label = dfDTIneg['Label']
    neg_edge_label = torch.tensor(neg_edge_label.values) # neg_edge_label
    # creating edge_index_supervision and edge_label_supervision
    edge_index_meg_pass, pos_edge_index_supervision, pos_edge_label_supervision = split_edges(edge_index_pos_drug_to_target, pos_edge_label)
    dropped_edge_index, neg_edge_index_supervision, neg_edge_label_supervision = split_edges(edge_index_neg_drug_to_target, neg_edge_label)
    edge_index_supervision = torch.cat((pos_edge_index_supervision, neg_edge_index_supervision), dim=1)
    edge_label_supervision = torch.cat((pos_edge_label_supervision, neg_edge_label_supervision), dim=0)
    random_permutation = torch.randperm(1424)
    edge_index_supervision = edge_index_supervision[:, random_permutation]
    edge_label_supervision = edge_label_supervision[random_permutation]
    # creating trainData
    # adding node and features of drug
    trainData = HeteroData()
    trainData['drug'].node_id = torch.arange(drug_feat.size(0))
    trainData['drug'].x = drug_feat
    # adding node and features of target
    trainData['target'].node_id = torch.arange(target_feat.size(0))
    trainData['target'].x = target_feat
    # adding all the edges which will be used for message passing to the heterogeneous data
    trainData["drug", "interact", "target"].edge_index = edge_index_meg_pass
    trainData["drug", "similar", "drug"].edge_index = edge_index_drug_to_drug
    trainData["target", "similar", "target"].edge_index = edge_index_target_to_target
    trainData = T.ToUndirected()(trainData) # making all the message passing edges undirected
    # adding edge index and edge label for supervision to the data
    trainData["drug", "interact", "target"].edge_label = edge_label_supervision
    trainData["drug", "interact", "target"].edge_label_index = edge_index_supervision
    return trainData

In [5]:
train_data = create_trainData()

  dfDrug = pd.read_csv(file_path)


In [6]:
train_data

HeteroData(
  [1mdrug[0m={
    node_id=[1482],
    x=[1482, 1444]
  },
  [1mtarget[0m={
    node_id=[1408],
    x=[1408, 4159]
  },
  [1m(drug, interact, target)[0m={
    edge_index=[2, 1659],
    edge_label=[1424],
    edge_label_index=[2, 1424]
  },
  [1m(drug, similar, drug)[0m={ edge_index=[2, 1246640] },
  [1m(target, similar, target)[0m={ edge_index=[2, 1043321] },
  [1m(target, rev_interact, drug)[0m={ edge_index=[2, 1659] }
)

In [7]:
# for i in range(100, 1424):
#     print(train_data['drug', 'interact', 'target'].edge_label[i].item())

In [7]:
def create_valData():
    # creating drug_feat
    file_path = './NE Dataset/drug_attr.csv'
    dfDrug = pd.read_csv(file_path)
    dfDrugFeat = dfDrug.drop(dfDrug.columns[:2], axis=1)
    dfDrugFeat.replace('#NAME?', np.nan, inplace=True) # replacing the missing values with 0
    dfDrugFeat.fillna(0, inplace=True)
    dfDrugFeat = dfDrugFeat.applymap(to_float)
    dfDrugFeat.replace([np.inf], 0, inplace=True) # replacing the inf value with zero
    scaler = MinMaxScaler(feature_range=(-1, 1)) # NORMALIZE HERE
    dfDrugFeat = scaler.fit_transform(dfDrugFeat)
    drug_feat = torch.from_numpy(dfDrugFeat).to(torch.float)
    # creating target_feat
    file_path = './NE Dataset/protein_attr.csv'
    dfTarget = pd.read_csv(file_path)
    dfTargetFeat = dfTarget.drop(dfTarget.columns[:2], axis=1)
    target_feat = torch.from_numpy(dfTargetFeat.values).to(torch.float)
    # unique drug and unique target mapping
    uniqueDrug = dfDrug['1'].unique()
    uniqueDrug = pd.DataFrame(data={
        'DrugId': uniqueDrug,
        'mappedId': pd.RangeIndex(len(uniqueDrug)),
    })
    uniqueTraget = dfTarget['#'].unique()
    uniqueTraget = pd.DataFrame(data={
        'TargetId': uniqueTraget,
        'mappedId': pd.RangeIndex(len(uniqueTraget)),
    })
    # creating edge_index_drug_to_drug
    file_path = './NE Dataset/train_ddi.csv'
    dfDDI = pd.read_csv(file_path)
    source_drug_id = pd.merge(dfDDI['Drug'], uniqueDrug,
                             left_on='Drug', right_on='DrugId', how='left')
    source_drug_id = torch.from_numpy(source_drug_id['mappedId'].values)
    destination_drug_id = pd.merge(dfDDI['Target'], uniqueDrug,
                                  left_on='Target', right_on='DrugId', how='left')
    destination_drug_id = torch.from_numpy(destination_drug_id['mappedId'].values)
    edge_index_drug_to_drug = torch.stack([source_drug_id, destination_drug_id], dim=0)
    # creating edge_index_target_to_target
    file_path = './NE Dataset/train_ppi.csv'
    dfPPI = pd.read_csv(file_path)
    source_target_id = pd.merge(dfPPI['Drug'], uniqueTraget,
                               left_on='Drug', right_on='TargetId', how='left')
    source_target_id = torch.from_numpy(source_target_id['mappedId'].values)
    destination_target_id = pd.merge(dfPPI['Target'], uniqueTraget,
                                    left_on='Target', right_on='TargetId', how='left')
    destination_target_id = torch.from_numpy(destination_target_id['mappedId'].values)
    edge_index_target_to_target = torch.stack([source_target_id, destination_target_id], dim=0)
    # creating edge_index_pos_drug_to_target [these edges will be used for message passing]
    file_path = './NE Dataset/train_dti.csv'
    dfDTI = pd.read_csv(file_path)
    dfDTIpos = dfDTI[dfDTI['Label'] == 1]
    pos_interact_drug_id = pd.merge(dfDTIpos['Drug'], uniqueDrug,
                                   left_on='Drug', right_on='DrugId', how='left')
    pos_interact_drug_id = torch.from_numpy(pos_interact_drug_id['mappedId'].values)
    pos_interact_target_id = pd.merge(dfDTIpos['Target'], uniqueTraget,
                                      left_on='Target', right_on='TargetId', how='left')
    pos_interact_target_id = torch.from_numpy(pos_interact_target_id['mappedId'].values)
    edge_index_pos_drug_to_target = torch.stack([pos_interact_drug_id, pos_interact_target_id], dim=0)
    # creating supervision edge and label
    file_path = './NE Dataset/val_dti.csv'
    dfvalDTI = pd.read_csv(file_path)
    edge_index_supervision_drug = pd.merge(dfvalDTI['Drug'], uniqueDrug,
                                          left_on='Drug', right_on='DrugId', how='left')
    edge_index_supervision_drug = torch.from_numpy(edge_index_supervision_drug['mappedId'].values)
    edge_index_supervision_target = pd.merge(dfvalDTI['Target'], uniqueTraget,
                                            left_on='Target', right_on='TargetId', how='left')
    edge_index_supervision_target = torch.from_numpy(edge_index_supervision_target['mappedId'].values)
    edge_index_supervision = torch.stack([edge_index_supervision_drug, edge_index_supervision_target], dim=0)
    edge_label_supervision = dfvalDTI['Label']
    edge_label_supervision = torch.tensor(edge_label_supervision.values)
    random_permutationVal = torch.randperm(948)
    edge_index_supervision = edge_index_supervision[:, random_permutationVal]
    edge_label_supervision = edge_label_supervision[random_permutationVal]
    # creating the heterogeneous graph valData
    valData = HeteroData()
    # adding node and features of drug
    valData['drug'].node_id = torch.arange(drug_feat.size(0))
    valData['drug'].x = drug_feat
    # adding node and features of target
    valData['target'].node_id = torch.arange(target_feat.size(0))
    valData['target'].x = target_feat
    # adding all the edges which will be used for message passing to the heterogeneous data
    valData["drug", "interact", "target"].edge_index = edge_index_pos_drug_to_target
    valData["drug", "similar", "drug"].edge_index = edge_index_drug_to_drug
    valData["target", "similar", "target"].edge_index = edge_index_target_to_target
    # making all the edges undirected
    valData = T.ToUndirected()(valData)
    # adding edge index and edge label for supervision to the data
    valData["drug", "interact", "target"].edge_label = edge_label_supervision
    valData["drug", "interact", "target"].edge_label_index = edge_index_supervision
    return valData

In [8]:
val_data = create_valData()

  dfDrug = pd.read_csv(file_path)


In [9]:
val_data

HeteroData(
  [1mdrug[0m={
    node_id=[1482],
    x=[1482, 1444]
  },
  [1mtarget[0m={
    node_id=[1408],
    x=[1408, 4159]
  },
  [1m(drug, interact, target)[0m={
    edge_index=[2, 2371],
    edge_label=[948],
    edge_label_index=[2, 948]
  },
  [1m(drug, similar, drug)[0m={ edge_index=[2, 1246640] },
  [1m(target, similar, target)[0m={ edge_index=[2, 1043321] },
  [1m(target, rev_interact, drug)[0m={ edge_index=[2, 2371] }
)

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [10]:
train_data = train_data.to(device)
val_data = val_data.to(device)

In [13]:
# from torch_geometric.loader import LinkNeighborLoader

# # Define seed edges:
# edge_label_index = train_data["drug", "interact", "target"].edge_label_index
# edge_label = train_data["drug", "interact", "target"].edge_label

# train_loader = LinkNeighborLoader(
#     data=train_data,
#     num_neighbors=[20, 10],
#     neg_sampling_ratio=0.0,
#     edge_label_index=(("drug", "interact", "target"), edge_label_index),
#     edge_label=edge_label,
#     batch_size=8,
#     shuffle=True,
# )

In [14]:
# # Inspect a sample:
# sampled_data = next(iter(train_loader))

In [15]:
class GNN_3(torch.nn.Module):
    def __init__(self, final_dimension):
        super().__init__()
        self.conv1 = SAGEConv(final_dimension, final_dimension)
        self.conv2 = SAGEConv(final_dimension, final_dimension)
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
    
class Classifier(torch.nn.Module):
    def forward(self, x_drug: Tensor, x_target: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_drug = x_drug[edge_label_index[0]]
        edge_feat_target = x_target[edge_label_index[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (edge_feat_drug * edge_feat_target).sum(dim=-1)
    
class Model_3_dot(torch.nn.Module):
    def __init__(self, final_dimension):
        super().__init__()
        # embeddings ???
        # self.user_lin = torch.nn.Linear(final_dimension, final_dimension)
        # self.user_emb = torch.nn.Embedding(data["user"].num_nodes, final_dimension)
        # self.movie_emb = torch.nn.Embedding(data["movie"].num_nodes, final_dimension)
        
        self.drug_lin = torch.nn.Linear(1444, final_dimension)
        self.target_lin = torch.nn.Linear(4159, final_dimension)
        # Instantiate homogeneous GNN
        self.gnn = GNN_3(final_dimension)
        # Convert GNN model into a heterogeneous variant
        self.gnn = to_hetero(self.gnn, metadata=train_data.metadata())
        self.classifier = Classifier()
        
        self.dropout = torch.nn.Dropout(p=0.5)
        
    def forward(self, data: HeteroData) -> Tensor:
#         x_dict = {
#             "user": self.shared_lin(self.user_emb(data["user"].node_id)),
#             "movie": self.shared_lin(self.movie_lin(data["movie"].x) + self.movie_emb(data["movie"].node_id)),
        
#         }

        x_dict = {
            "drug": self.drug_lin(data["drug"].x),
            "target": self.target_lin(data["target"].x),
        
        }
        x_dict = self.gnn(x_dict, train_data.edge_index_dict)
        return x_dict

    def predict_link(self, drug_emb, target_emb, edge_label_index):
        return self.classifier(drug_emb, target_emb, edge_label_index)

In [16]:
model3_dot = Model_3_dot(final_dimension=5000)
model3_dot = model3_dot.to(device)

In [17]:
# Optimizer and loss function
optimizer_dot = Adam(model3_dot.parameters(), lr=0.001)
loss_fn_dot = torch.nn.BCEWithLogitsLoss() # ????BCEwithLogit

In [18]:
# train_data

In [19]:
def train_dot():
    model3_dot.train()
    for epoch in range(1, 200):
        optimizer_dot.zero_grad()
        x_dict = model3_dot(train_data)
        edge_label_index = train_data["drug", "interact", "target"].edge_label_index
        ground_truth = train_data["drug", "interact", "target"].edge_label
        ground_truth = ground_truth.float()
        pred = model3_dot.predict_link( x_dict['drug'], x_dict['target'], edge_label_index).squeeze()
        # print(type(pred[0].item()))
        # print(type(ground_truth[0].item()))
        loss = loss_fn_dot(pred, ground_truth)
        loss.backward()
        optimizer_dot.step()
        torch.cuda.empty_cache()
        print(f"Epoch: {epoch:03d}, Loss: {loss}")

In [20]:
train_dot()

Epoch: 001, Loss: 492.58856201171875
Epoch: 002, Loss: 30413854.0
Epoch: 003, Loss: 119852952.0
Epoch: 004, Loss: 2533610.75
Epoch: 005, Loss: 3540533.5
Epoch: 006, Loss: 3676059.25
Epoch: 007, Loss: 127357.3359375
Epoch: 008, Loss: 1592586.75
Epoch: 009, Loss: 2450857.75
Epoch: 010, Loss: 883291.5
Epoch: 011, Loss: 648540.5625
Epoch: 012, Loss: 1782074.25
Epoch: 013, Loss: 134228.28125
Epoch: 014, Loss: 139038.1875
Epoch: 015, Loss: 6377420.5
Epoch: 016, Loss: 1505187.125
Epoch: 017, Loss: 281080.3125
Epoch: 018, Loss: 118502.0703125
Epoch: 019, Loss: 218558.234375
Epoch: 020, Loss: 132943.390625
Epoch: 021, Loss: 146154.421875
Epoch: 022, Loss: 103381.6875
Epoch: 023, Loss: 44247.6484375
Epoch: 024, Loss: 50762.03515625
Epoch: 025, Loss: 25281.853515625
Epoch: 026, Loss: 53893.5234375
Epoch: 027, Loss: 23381.67578125
Epoch: 028, Loss: 12042.5478515625
Epoch: 029, Loss: 13747.646484375
Epoch: 030, Loss: 5961.2783203125
Epoch: 031, Loss: 2180.607177734375
Epoch: 032, Loss: 3198.6096191

In [21]:
x_dict = model3_dot(train_data.to(device))
edge_label_index = train_data["drug", "interact", "target"].edge_label_index
pred = model3_dot.predict_link( x_dict['drug'], x_dict['target'], edge_label_index).squeeze()

In [22]:
pred_prob = torch.sigmoid(pred)

In [23]:
# pred_prob.max()

In [24]:
pred_prob_binary = pred_prob >= 0.5
pred_prob_binary = pred_prob_binary.float()
pred_prob_binary_np = pred_prob_binary.cpu().numpy()

In [25]:
ground_truth = train_data["drug", "interact", "target"].edge_label
ground_truth = ground_truth.float()
ground_truth_np = ground_truth.cpu().numpy()

In [26]:
report = classification_report(ground_truth_np, pred_prob_binary_np)
print(report)

              precision    recall  f1-score   support

         0.0       0.55      0.86      0.67       712
         1.0       0.68      0.29      0.41       712

    accuracy                           0.58      1424
   macro avg       0.61      0.58      0.54      1424
weighted avg       0.61      0.58      0.54      1424



In [27]:
val_data = val_data.to(device)

In [28]:
x_dict_val = model3_dot(val_data)
edge_label_index_val = val_data["drug", "interact", "target"].edge_label_index
pred_val = model3_dot.predict_link(x_dict_val['drug'], x_dict_val['target'], edge_label_index_val).squeeze()

In [29]:
pred_val_prob = torch.sigmoid(pred_val)
pred_prob_val_binary = pred_val_prob >= 0.5
pred_prob_val_binary = pred_prob_val_binary.float()
pred_prob_val_binary_np = pred_prob_val_binary.cpu().numpy()
ground_truth_val = val_data["drug", "interact", "target"].edge_label
ground_truth_val_np = ground_truth_val.cpu().numpy()

In [30]:
report = classification_report(ground_truth_val_np, pred_prob_val_binary_np)
print(report)

              precision    recall  f1-score   support

           0       0.49      0.66      0.56       474
           1       0.48      0.31      0.37       474

    accuracy                           0.49       948
   macro avg       0.48      0.49      0.47       948
weighted avg       0.48      0.49      0.47       948



In [31]:
pred_prob_val_binary

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1.,
        0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 1., 1.,
        1., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0.,
        1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0.,
        1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 0., 1., 1., 

In [32]:
# I will create another model with shared lin

In [12]:
class GNN_4(torch.nn.Module):
    def __init__(self, final_dimension):
        super().__init__()
        self.conv1 = SAGEConv(final_dimension, final_dimension)
        self.conv2 = SAGEConv(final_dimension, final_dimension)
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
    
class Classifier(torch.nn.Module):
    def forward(self, x_drug: Tensor, x_target: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_drug = x_drug[edge_label_index[0]]
        edge_feat_target = x_target[edge_label_index[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (edge_feat_drug * edge_feat_target).sum(dim=-1)
    
class Model_4_dot(torch.nn.Module):
    def __init__(self, final_dimension):
        super().__init__()
        # embeddings ???
        # self.user_lin = torch.nn.Linear(final_dimension, final_dimension)
        # self.user_emb = torch.nn.Embedding(data["user"].num_nodes, final_dimension)
        # self.movie_emb = torch.nn.Embedding(data["movie"].num_nodes, final_dimension)
        
        self.drug_lin = torch.nn.Linear(1444, final_dimension)
        self.target_lin = torch.nn.Linear(4159, final_dimension)
        self.shared_lin = torch.nn.Linear(final_dimension, final_dimension)
        # Instantiate homogeneous GNN
        self.gnn = GNN_4(final_dimension)
        # Convert GNN model into a heterogeneous variant
        self.gnn = to_hetero(self.gnn, metadata=train_data.metadata())
        self.classifier = Classifier()
        
        self.dropout = torch.nn.Dropout(p=0.5)
        
    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
            "drug": self.shared_lin(self.drug_lin(data["drug"].x)),
            "target": self.shared_lin(self.target_lin(data["target"].x)),
        
        }

        x_dict = self.gnn(x_dict, train_data.edge_index_dict)
        return x_dict

    def predict_link(self, drug_emb, target_emb, edge_label_index):
        return self.classifier(drug_emb, target_emb, edge_label_index)

In [13]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [14]:
train_data = train_data.to(device)

In [15]:
val_data = val_data.to(device)

In [16]:
model4_dot = Model_4_dot(final_dimension=5000)
model4_dot = model4_dot.to(device)

In [17]:
# Optimizer and loss function
optimizer_dot = Adam(model4_dot.parameters(), lr=0.001)
loss_fn_dot = torch.nn.BCEWithLogitsLoss() # ????BCEwithLogit

In [18]:
def train_dot():
    model4_dot.train()
    for epoch in range(1, 300):
        optimizer_dot.zero_grad()
        x_dict = model4_dot(train_data)
        edge_label_index = train_data["drug", "interact", "target"].edge_label_index
        ground_truth = train_data["drug", "interact", "target"].edge_label
        ground_truth = ground_truth.float()
        pred = model4_dot.predict_link( x_dict['drug'], x_dict['target'], edge_label_index).squeeze()
        # print(type(pred[0].item()))
        # print(type(ground_truth[0].item()))
        loss = loss_fn_dot(pred, ground_truth)
        loss.backward()
        optimizer_dot.step()
        torch.cuda.empty_cache()
        print(f"Epoch: {epoch:03d}, Loss: {loss}")

In [19]:
train_dot()

Epoch: 001, Loss: 76.97655487060547
Epoch: 002, Loss: 14295515.0
Epoch: 003, Loss: 4810861056.0
Epoch: 004, Loss: 905814208.0
Epoch: 005, Loss: 51440992.0
Epoch: 006, Loss: 1693360768.0
Epoch: 007, Loss: 408601856.0
Epoch: 008, Loss: 299470592.0
Epoch: 009, Loss: 139751584.0
Epoch: 010, Loss: 650130496.0
Epoch: 011, Loss: 1428370048.0
Epoch: 012, Loss: 56385232896.0
Epoch: 013, Loss: 4971673088.0
Epoch: 014, Loss: 30785531904.0
Epoch: 015, Loss: 22881927168.0
Epoch: 016, Loss: 678736384.0
Epoch: 017, Loss: 758128896.0
Epoch: 018, Loss: 1078195840.0
Epoch: 019, Loss: 668963072.0
Epoch: 020, Loss: 406938432.0
Epoch: 021, Loss: 621747392.0
Epoch: 022, Loss: 309612512.0
Epoch: 023, Loss: 124594600.0
Epoch: 024, Loss: 101180472.0
Epoch: 025, Loss: 382756768.0
Epoch: 026, Loss: 372196928.0
Epoch: 027, Loss: 126243808.0
Epoch: 028, Loss: 140264992.0
Epoch: 029, Loss: 114259512.0
Epoch: 030, Loss: 21320776.0
Epoch: 031, Loss: 19001060.0
Epoch: 032, Loss: 10679451.0
Epoch: 033, Loss: 12800734.0

Epoch: 245, Loss: 927.036376953125
Epoch: 246, Loss: 863.291259765625
Epoch: 247, Loss: 830.4476928710938
Epoch: 248, Loss: 953.1025390625
Epoch: 249, Loss: 820.6620483398438
Epoch: 250, Loss: 804.8634033203125
Epoch: 251, Loss: 898.9557495117188
Epoch: 252, Loss: 812.3843383789062
Epoch: 253, Loss: 780.9647216796875
Epoch: 254, Loss: 902.6033935546875
Epoch: 255, Loss: 804.8889770507812
Epoch: 256, Loss: 781.6464233398438
Epoch: 257, Loss: 891.2623901367188
Epoch: 258, Loss: 767.72998046875
Epoch: 259, Loss: 746.0284423828125
Epoch: 260, Loss: 866.3286743164062
Epoch: 261, Loss: 775.3165283203125
Epoch: 262, Loss: 743.881591796875
Epoch: 263, Loss: 868.7384033203125
Epoch: 264, Loss: 744.3464965820312
Epoch: 265, Loss: 718.401611328125
Epoch: 266, Loss: 845.1798095703125
Epoch: 267, Loss: 738.0054321289062
Epoch: 268, Loss: 710.663818359375
Epoch: 269, Loss: 844.6620483398438
Epoch: 270, Loss: 709.8916625976562
Epoch: 271, Loss: 684.0540161132812
Epoch: 272, Loss: 826.8541259765625
Ep

In [20]:
x_dict = model4_dot(train_data.to(device))
edge_label_index = train_data["drug", "interact", "target"].edge_label_index
pred = model4_dot.predict_link( x_dict['drug'], x_dict['target'], edge_label_index).squeeze()

pred_prob = torch.sigmoid(pred)

pred_prob_binary = pred_prob >= 0.5
pred_prob_binary = pred_prob_binary.float()
pred_prob_binary_np = pred_prob_binary.cpu().numpy()

ground_truth = train_data["drug", "interact", "target"].edge_label
ground_truth = ground_truth.float()
ground_truth_np = ground_truth.cpu().numpy()

report = classification_report(ground_truth_np, pred_prob_binary_np)
print(report)

              precision    recall  f1-score   support

         0.0       0.55      0.81      0.65       712
         1.0       0.64      0.34      0.44       712

    accuracy                           0.57      1424
   macro avg       0.59      0.57      0.55      1424
weighted avg       0.59      0.57      0.55      1424



In [21]:
x_dict_val = model4_dot(val_data)
edge_label_index_val = val_data["drug", "interact", "target"].edge_label_index
pred_val = model4_dot.predict_link(x_dict_val['drug'], x_dict_val['target'], edge_label_index_val).squeeze()

pred_val_prob = torch.sigmoid(pred_val)
pred_prob_val_binary = pred_val_prob >= 0.5
pred_prob_val_binary = pred_prob_val_binary.float()
pred_prob_val_binary_np = pred_prob_val_binary.cpu().numpy()
ground_truth_val = val_data["drug", "interact", "target"].edge_label
ground_truth_val_np = ground_truth_val.cpu().numpy()

report = classification_report(ground_truth_val_np, pred_prob_val_binary_np)
print(report)

              precision    recall  f1-score   support

           0       0.49      0.63      0.55       474
           1       0.48      0.34      0.40       474

    accuracy                           0.48       948
   macro avg       0.48      0.48      0.47       948
weighted avg       0.48      0.48      0.47       948



In [22]:
# I will create another model which has torch.nn.Embedding

In [23]:
train_data

HeteroData(
  [1mdrug[0m={
    node_id=[1482],
    x=[1482, 1444]
  },
  [1mtarget[0m={
    node_id=[1408],
    x=[1408, 4159]
  },
  [1m(drug, interact, target)[0m={
    edge_index=[2, 1659],
    edge_label=[1424],
    edge_label_index=[2, 1424]
  },
  [1m(drug, similar, drug)[0m={ edge_index=[2, 1246640] },
  [1m(target, similar, target)[0m={ edge_index=[2, 1043321] },
  [1m(target, rev_interact, drug)[0m={ edge_index=[2, 1659] }
)

In [11]:
class GNN_5(torch.nn.Module):
    def __init__(self, final_dimension):
        super().__init__()
        self.conv1 = SAGEConv(final_dimension, final_dimension)
        self.conv2 = SAGEConv(final_dimension, final_dimension)
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x
    
class Classifier(torch.nn.Module):
    def forward(self, x_drug: Tensor, x_target: Tensor, edge_label_index: Tensor) -> Tensor:
        # Convert node embeddings to edge-level representations:
        edge_feat_drug = x_drug[edge_label_index[0]]
        edge_feat_target = x_target[edge_label_index[1]]

        # Apply dot-product to get a prediction per supervision edge:
        return (edge_feat_drug * edge_feat_target).sum(dim=-1)
    
class Model_5_dot(torch.nn.Module):
    def __init__(self, final_dimension):
        super().__init__()
        # embeddings ???
        # self.user_lin = torch.nn.Linear(final_dimension, final_dimension)
        self.drug_emb = torch.nn.Embedding(train_data["drug"].num_nodes, final_dimension)
        self.target_emb = torch.nn.Embedding(train_data["target"].num_nodes, final_dimension)
        
        self.drug_lin = torch.nn.Linear(1444, final_dimension)
        self.target_lin = torch.nn.Linear(4159, final_dimension)
        self.shared_lin = torch.nn.Linear(final_dimension, final_dimension)
        # Instantiate homogeneous GNN
        self.gnn = GNN_5(final_dimension)
        # Convert GNN model into a heterogeneous variant
        self.gnn = to_hetero(self.gnn, metadata=train_data.metadata())
        self.classifier = Classifier()
        
        self.dropout = torch.nn.Dropout(p=0.5)
        
    def forward(self, data: HeteroData) -> Tensor:
        x_dict = {
            "drug": self.shared_lin(self.drug_lin(data["drug"].x) + (self.drug_emb(data['drug'].node_id))),
            "target": self.shared_lin(self.target_lin(data["target"].x) + (self.target_emb(data['target'].node_id))),
        
        }

        x_dict = self.gnn(x_dict, data.edge_index_dict)
        return x_dict

    def predict_link(self, drug_emb, target_emb, edge_label_index):
        return self.classifier(drug_emb, target_emb, edge_label_index)

In [22]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")

Device: 'cuda'


In [23]:
train_data = train_data.to(device)
val_data = val_data.to(device)

In [24]:
model5_dot = Model_5_dot(final_dimension=5000)
model5_dot = model5_dot.to(device)

In [25]:
# Optimizer and loss function
optimizer_dot = Adam(model5_dot.parameters(), lr=0.001)
loss_fn_dot = torch.nn.BCEWithLogitsLoss() # ????BCEwithLogit

In [26]:
def train_dot():
    model5_dot.train()
    for epoch in range(1, 300):
        optimizer_dot.zero_grad()
        x_dict = model5_dot(train_data)
        edge_label_index = train_data["drug", "interact", "target"].edge_label_index
        ground_truth = train_data["drug", "interact", "target"].edge_label
        ground_truth = ground_truth.float()
        pred = model5_dot.predict_link( x_dict['drug'], x_dict['target'], edge_label_index).squeeze()
        # print(type(pred[0].item()))
        # print(type(ground_truth[0].item()))
        loss = loss_fn_dot(pred, ground_truth)
        loss.backward()
        optimizer_dot.step()
        torch.cuda.empty_cache()
        print(f"Epoch: {epoch:03d}, Loss: {loss}")

In [27]:
train_dot()

Epoch: 001, Loss: 60.737754821777344
Epoch: 002, Loss: 15420651.0
Epoch: 003, Loss: 7756938240.0
Epoch: 004, Loss: 482417504.0
Epoch: 005, Loss: 71965840.0
Epoch: 006, Loss: 545572544.0
Epoch: 007, Loss: 194074368.0
Epoch: 008, Loss: 2349987584.0
Epoch: 009, Loss: 274188832.0
Epoch: 010, Loss: 2452683008.0
Epoch: 011, Loss: 209993859072.0
Epoch: 012, Loss: 1093994112.0
Epoch: 013, Loss: 15195159552.0
Epoch: 014, Loss: 2947631872.0
Epoch: 015, Loss: 1063228288.0
Epoch: 016, Loss: 26001235968.0
Epoch: 017, Loss: 31820266.0
Epoch: 018, Loss: 2393083136.0
Epoch: 019, Loss: 9467177984.0
Epoch: 020, Loss: 5389392896.0
Epoch: 021, Loss: 669510016.0
Epoch: 022, Loss: 625241.875
Epoch: 023, Loss: 1584435.25
Epoch: 024, Loss: 13199970.0
Epoch: 025, Loss: 101206112.0
Epoch: 026, Loss: 480967040.0
Epoch: 027, Loss: 395929120.0
Epoch: 028, Loss: 585915840.0
Epoch: 029, Loss: 79868528.0
Epoch: 030, Loss: 43474020.0
Epoch: 031, Loss: 318076000.0
Epoch: 032, Loss: 9544590.0
Epoch: 033, Loss: 28067064.

Epoch: 266, Loss: 60932.203125
Epoch: 267, Loss: 60106.6953125
Epoch: 268, Loss: 47818.95703125
Epoch: 269, Loss: 89859.6328125
Epoch: 270, Loss: 36502.8203125
Epoch: 271, Loss: 54431.515625
Epoch: 272, Loss: 83736.2890625
Epoch: 273, Loss: 30925.689453125
Epoch: 274, Loss: 55234.62890625
Epoch: 275, Loss: 49019.44921875
Epoch: 276, Loss: 33332.8828125
Epoch: 277, Loss: 33784.5859375
Epoch: 278, Loss: 53843.3046875
Epoch: 279, Loss: 58206.05859375
Epoch: 280, Loss: 36239.71484375
Epoch: 281, Loss: 40771.51171875
Epoch: 282, Loss: 42320.17578125
Epoch: 283, Loss: 29839.607421875
Epoch: 284, Loss: 40549.640625
Epoch: 285, Loss: 28941.146484375
Epoch: 286, Loss: 33118.421875
Epoch: 287, Loss: 35462.828125
Epoch: 288, Loss: 28943.80859375
Epoch: 289, Loss: 33578.08984375
Epoch: 290, Loss: 44349.48828125
Epoch: 291, Loss: 52001.4375
Epoch: 292, Loss: 34344.34375
Epoch: 293, Loss: 62717.05859375
Epoch: 294, Loss: 50779.8203125
Epoch: 295, Loss: 42287.56640625
Epoch: 296, Loss: 94725.9765625


In [28]:
x_dict = model5_dot(train_data.to(device))
edge_label_index = train_data["drug", "interact", "target"].edge_label_index
pred = model5_dot.predict_link( x_dict['drug'], x_dict['target'], edge_label_index).squeeze()

pred_prob = torch.sigmoid(pred)

pred_prob_binary = pred_prob >= 0.5
pred_prob_binary = pred_prob_binary.float()
pred_prob_binary_np = pred_prob_binary.cpu().numpy()

ground_truth = train_data["drug", "interact", "target"].edge_label
ground_truth = ground_truth.float()
ground_truth_np = ground_truth.cpu().numpy()

report = classification_report(ground_truth_np, pred_prob_binary_np)
print(report)

              precision    recall  f1-score   support

         0.0       0.55      0.09      0.15       712
         1.0       0.50      0.93      0.65       712

    accuracy                           0.51      1424
   macro avg       0.53      0.51      0.40      1424
weighted avg       0.53      0.51      0.40      1424



In [29]:
x_dict_val = model5_dot(val_data)
edge_label_index_val = val_data["drug", "interact", "target"].edge_label_index
pred_val = model5_dot.predict_link(x_dict_val['drug'], x_dict_val['target'], edge_label_index_val).squeeze()

pred_val_prob = torch.sigmoid(pred_val)
pred_prob_val_binary = pred_val_prob >= 0.5
pred_prob_val_binary = pred_prob_val_binary.float()
pred_prob_val_binary_np = pred_prob_val_binary.cpu().numpy()
ground_truth_val = val_data["drug", "interact", "target"].edge_label
ground_truth_val_np = ground_truth_val.cpu().numpy()

report = classification_report(ground_truth_val_np, pred_prob_val_binary_np)
print(report)

              precision    recall  f1-score   support

           0       0.50      0.49      0.49       474
           1       0.50      0.51      0.50       474

    accuracy                           0.50       948
   macro avg       0.50      0.50      0.50       948
weighted avg       0.50      0.50      0.50       948

