In [None]:
!pip install torch==2.1.0
!pip install torch_geometric dgl dill

In [None]:
!pip install rdkit dnc

In [1]:

import dill
import os
import pandas as pd
import networkx as nx
import math
import glob
import numpy as np

Collecting torch_geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dgl
  Downloading dgl-2.1.0-cp310-cp310-manylinux1_x86_64.whl (8.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.5/8.5 MB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting dill
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
Collecting torchdata>=0.5.0 (from dgl)
  Downloading torchdata-0.7.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m50.7 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2->torchdata>=0.5.0->dgl)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-n

ModuleNotFoundError: No module named 'dill'

In [None]:
################                     Heterogeneous Graph Construction from EHR Data                      ###########################


# Visit-level EHR table. The code below reads its SUBJECT_ID, ICD9_CODE and NDC
# columns; ICD9_CODE and NDC are exploded, so each cell is expected to hold a
# list of codes.
data = pd.read_pickle('/content/drive/MyDrive/Carmen-main/data/data_final.pkl')


node_types = ['patient', 'diagnosis', 'drug']

# One (source type, target type, attributes) triple per relation. The names
# match the `edge_type` labels that Load_Into_Graph attaches to its edges.
edge_types = [
    ('patient', 'diagnosis', {'name': 'has'}),
    ('diagnosis', 'drug', {'name': 'treated_by'}),
    ('patient', 'drug', {'name': 'prescribed'}),
]

unique_subject_ids = data['SUBJECT_ID'].unique()
unique_icd9_codes = data['ICD9_CODE'].explode().unique()
unique_ndc = data['NDC'].explode().unique()

#selected_subject_ids = unique_subject_ids[:500]
# Filter the data to only include the selected subject IDs
#filtered_data = data[data['SUBJECT_ID'].isin(selected_subject_ids)]

print(f"Number of unique SUBJECT_ID: {len(unique_subject_ids)}")
#print(f"Number of unique SUBJECT_ID: {len(selected_subject_ids)}")

print(f"Number of unique ICD9_CODE: {len(unique_icd9_codes)}")
print(f"Number of unique NDC: {len(unique_ndc)}")

# Load the DDI matrix from the file 'ddi_A_final.pkl'.
# NOTE: dill.load can execute arbitrary code embedded in the pickle; only load
# files from a trusted source.
ddi_A_final ='/content/drive/MyDrive/Carmen-main/data/ddi_A_final.pkl'
with open(ddi_A_final, 'rb') as ddi_file:
    ddi_matrix = dill.load(ddi_file)

# Print the DDI matrix
print("DDI Matrix:")
print(ddi_matrix)

print(f"Number of drugs in the DDI matrix: {len(ddi_matrix)}")

# Collect every row/column index that takes part in at least one interaction.
# assumes ddi_matrix is a rectangular 2-D array-like of 0/1 flags -- TODO confirm
interacting_rows, interacting_cols = np.nonzero(np.asarray(ddi_matrix) == 1)
drugs_with_interactions = set(interacting_rows.tolist()) | set(interacting_cols.tolist())

# Print the number of drugs with interactions
print(f"Number of drugs with interactions: {len(drugs_with_interactions)}")

def Load_Into_Graph(data):
    """Build the undirected patient/diagnosis/drug graph for ``data``.

    Nodes come from the module-level ``unique_ndc``, ``unique_subject_ids`` and
    ``unique_icd9_codes`` arrays and carry a ``node_type`` attribute
    ('drug', 'patient' or 'diagnosis'). Edges carry an ``edge_type`` attribute:

    * 'ddi'        -- drug/drug pairs flagged with 1 in the module-level
                      ``ddi_matrix`` (also given ``weight=-1``),
    * 'has'        -- patient/diagnosis pairs found in a row of ``data``,
    * 'treated_by' -- every diagnosis/drug pair that co-occurs in a row,
    * 'prescribed' -- patient/drug pairs found in a row of ``data``.

    Args:
        data: DataFrame whose rows expose ``SUBJECT_ID`` plus iterable
            ``ICD9_CODE`` and ``NDC`` cells.

    Returns:
        The populated ``networkx.Graph``. Node and edge counts are printed as a
        side effect.
    """
    G = nx.Graph()
    patient_diagnosis_edges = 0
    diagnosis_drug_edges = 0
    drug_drug_edges = 0
    patient_drug_edges = 0

    # Add nodes for each unique drug NDC code
    for ndc_code in unique_ndc:
        G.add_node(ndc_code, node_type='drug')

    # Add edges for DDIs with negative weight.
    # NOTE(review): this assumes row/column i of ddi_matrix describes
    # unique_ndc[i]; nothing in this file establishes that ordering -- confirm
    # against the code that produced the DDI pickle.
    for i in range(len(ddi_matrix)):
        # Only scan the upper triangle so each undirected pair is counted once.
        for j in range(i + 1, len(ddi_matrix[i])):
            if ddi_matrix[i][j] == 1:
                G.add_edge(unique_ndc[i], unique_ndc[j], weight=-1, edge_type='ddi')
                drug_drug_edges += 1

    # Add patient nodes
    for subject_id in unique_subject_ids:
        G.add_node(subject_id, node_type='patient')

    # Add diagnosis nodes
    for icd9_code in unique_icd9_codes:
        G.add_node(icd9_code, node_type='diagnosis')

    # Add edges from data
    for row in data.itertuples(index=False):
        patient = row.SUBJECT_ID
        diagnoses = row.ICD9_CODE
        drugs = row.NDC

        for icd9_code in diagnoses:
            # Patient-diagnosis edge
            if not G.has_edge(patient, icd9_code):
                G.add_edge(patient, icd9_code, edge_type='has')
                patient_diagnosis_edges += 1

            # Diagnosis-drug edges: link *every* diagnosis of the visit to every
            # drug of the visit. (Previously this loop ran after the diagnosis
            # loop and reused its leftover variable, so only the last diagnosis
            # of each visit was connected.)
            for ndc_code in drugs:
                if not G.has_edge(icd9_code, ndc_code):
                    G.add_edge(icd9_code, ndc_code, edge_type='treated_by')
                    diagnosis_drug_edges += 1

        # Patient-drug edges
        for ndc_code in drugs:
            if not G.has_edge(patient, ndc_code):
                G.add_edge(patient, ndc_code, edge_type='prescribed')
                patient_drug_edges += 1

    # Print the number of nodes and edges
    print(f"Number of nodes in the graph: {G.number_of_nodes()}")
    print(f"Total number of edges in the graph: {G.number_of_edges()}")
    print(f"Patient-Diagnosis edges: {patient_diagnosis_edges}")
    print(f"Diagnosis-Drug edges: {diagnosis_drug_edges}")
    print(f"Drug-Drug edges: {drug_drug_edges}")
    print(f"Patient-Drug edges: {patient_drug_edges}")

    return G

# Build the graph once at module level; the meta-path helpers and later cells
# below read this global `G`.
G = Load_Into_Graph(data)

######################################                META-PATH  Construction          ###################################

def Heterogeneous_Graph(data):
    """Build and return the heterogeneous EHR graph for ``data``.

    Thin wrapper around ``Load_Into_Graph``; it rebuilds the graph from scratch
    (and therefore prints the node/edge statistics again) on every call.

    The meta-path schemas explored on this graph by the generator functions in
    this file are:

    * patient -> diagnosis -> drug
    * patient -> diagnosis -> patient
    * diagnosis -> drug -> diagnosis
    * drug -> diagnosis -> patient -> diagnosis -> drug
    * patient -> diagnosis -> drug -> diagnosis -> patient

    Args:
        data: DataFrame forwarded unchanged to ``Load_Into_Graph``.

    Returns:
        The populated ``networkx.Graph``.
    """
    return Load_Into_Graph(data)
# Meta-path: ['patient', 'diagnosis', 'drug']
def patient_diagnosis_drug(G):
    """Lazily yield every ``[patient, diagnosis, drug]`` walk in ``G``.

    A walk starts at a node whose ``node_type`` is 'patient', steps to a
    neighbouring 'diagnosis' node and ends at a 'drug' neighbour of that
    diagnosis. Nodes without a ``node_type`` attribute are skipped rather than
    raising ``KeyError``, consistent with the other meta-path generators.

    Args:
        G: Graph exposing ``nodes()``, ``nodes[n]`` (attribute dict) and
            ``neighbors(n)``, e.g. a ``networkx.Graph``.

    Yields:
        Three-element lists of node identifiers.
    """
    attrs = G.nodes
    for patient in G.nodes():
        if attrs[patient].get('node_type') != 'patient':
            continue
        for diagnosis in G.neighbors(patient):
            if attrs[diagnosis].get('node_type') != 'diagnosis':
                continue
            for drug in G.neighbors(diagnosis):
                if attrs[drug].get('node_type') == 'drug':
                    yield [patient, diagnosis, drug]

# Show a small sample of patient -> diagnosis -> drug walks; the generator is
# lazy, so stopping early avoids enumerating the whole graph.
print("**Meta-path: Patient -> Diagnosis -> Drug**")
max_paths_to_print = 5
count = 0

for count, path in enumerate(patient_diagnosis_drug(G), start=1):
    print(" -> ".join(map(str, path)))
    if count >= max_paths_to_print:
        break


# Meta-path: ['patient', 'diagnosis', 'patient']
def patient_diagnosis_patient(G):
    """Lazily yield ``[patient, diagnosis, other_patient]`` walks in ``G``.

    Two distinct 'patient' nodes are paired whenever they share a neighbouring
    'diagnosis' node. Each unordered pair therefore shows up twice, once from
    each end.

    Args:
        G: Graph exposing ``nodes()``, ``nodes[n]`` and ``neighbors(n)``.

    Yields:
        Three-element lists of node identifiers.
    """
    def kind(node):
        return G.nodes[node].get('node_type')

    for start in G.nodes():
        if kind(start) != 'patient':
            continue
        for shared in G.neighbors(start):
            if kind(shared) != 'diagnosis':
                continue
            for other in G.neighbors(shared):
                if other != start and kind(other) == 'patient':
                    yield [start, shared, other]


# Show a small sample of patient -> diagnosis -> patient walks.
print("[**Meta-path: Patient -> Diagnosis -> Patient**]")
max_paths_to_print = 5
count = 0

for count, path in enumerate(patient_diagnosis_patient(G), start=1):
    print(" -> ".join(map(str, path)))
    if count >= max_paths_to_print:
        break

# Meta-path: ['diagnosis', 'drug', 'diagnosis']
def diagnosis_drug_diagnosis(G):
    """Lazily yield ``[diagnosis, drug, other_diagnosis]`` walks in ``G``.

    Two distinct 'diagnosis' nodes are paired whenever they share a
    neighbouring 'drug' node.

    Args:
        G: Graph exposing ``nodes()``, ``nodes[n]`` and ``neighbors(n)``.

    Yields:
        Three-element lists of node identifiers.
    """
    def typed(nodes, wanted):
        # Keep only the nodes whose node_type attribute equals `wanted`.
        return (n for n in nodes if G.nodes[n].get('node_type') == wanted)

    for source in typed(G.nodes(), 'diagnosis'):
        for drug in typed(G.neighbors(source), 'drug'):
            for target in typed(G.neighbors(drug), 'diagnosis'):
                if target != source:
                    yield [source, drug, target]

# Show a small sample of diagnosis -> drug -> diagnosis walks.
print("[**Meta-path: Diagnosis -> Drug -> Diagnosis**]")
max_paths_to_print = 5
count = 0

for count, path in enumerate(diagnosis_drug_diagnosis(G), start=1):
    print(" -> ".join(map(str, path)))
    if count >= max_paths_to_print:
        break


def find_meta_path_drug_diagnosis_patient_diagnosis_drug(G):
    """Lazily yield drug -> diagnosis -> patient -> diagnosis -> drug walks.

    The two diagnoses must differ from each other, and so must the two drugs;
    the patient in the middle is the node that links both halves.

    Args:
        G: Graph exposing ``nodes()``, ``nodes[n]`` and ``neighbors(n)``.

    Yields:
        Five-element lists ``[drug1, diagnosis1, patient, diagnosis2, drug2]``.
    """
    def typed(nodes, wanted):
        # Keep only the nodes whose node_type attribute equals `wanted`.
        return (n for n in nodes if G.nodes[n].get('node_type') == wanted)

    for first_drug in typed(G.nodes(), 'drug'):
        for first_dx in typed(G.neighbors(first_drug), 'diagnosis'):
            for patient in typed(G.neighbors(first_dx), 'patient'):
                for second_dx in typed(G.neighbors(patient), 'diagnosis'):
                    if second_dx == first_dx:
                        continue
                    for second_drug in typed(G.neighbors(second_dx), 'drug'):
                        if second_drug != first_drug:
                            yield [first_drug, first_dx, patient, second_dx, second_drug]

# print paths
# Show a small sample of drug -> diagnosis -> patient -> diagnosis -> drug walks.
print("Meta-path: drug_diagnosis_patient_diagnosis_drug")
max_paths_to_print = 5
count = 0

for count, path in enumerate(find_meta_path_drug_diagnosis_patient_diagnosis_drug(G), start=1):
    print(" -> ".join(map(str, path)))
    if count >= max_paths_to_print:
        break

def find_meta_path_patient_diagnosis_drug_diagnosis_patient(G):
    """Lazily yield patient -> diagnosis -> drug -> diagnosis -> patient walks.

    Two distinct patients are connected when one of the first patient's
    diagnoses and one of the second patient's diagnoses are *different*
    diagnoses that share a drug. The previous implementation required the two
    diagnoses to be equal, which collapsed the walk back onto the same
    diagnosis and disagreed with the sibling drug-...-drug generator.

    Args:
        G: Graph exposing ``nodes()``, ``nodes[n]`` and ``neighbors(n)``.

    Yields:
        Five-element lists
        ``[patient1, diagnosis1, drug, diagnosis2, patient2]``.
    """
    attrs = G.nodes
    for patient1 in G.nodes():
        if attrs[patient1].get('node_type') != 'patient':
            continue
        for diagnosis1 in G.neighbors(patient1):
            if attrs[diagnosis1].get('node_type') != 'diagnosis':
                continue
            for drug in G.neighbors(diagnosis1):
                if attrs[drug].get('node_type') != 'drug':
                    continue
                for diagnosis2 in G.neighbors(drug):
                    if diagnosis2 == diagnosis1 or attrs[diagnosis2].get('node_type') != 'diagnosis':
                        continue
                    for patient2 in G.neighbors(diagnosis2):
                        if patient2 != patient1 and attrs[patient2].get('node_type') == 'patient':
                            yield [patient1, diagnosis1, drug, diagnosis2, patient2]

#  print paths
# Show a small sample of patient -> diagnosis -> drug -> diagnosis -> patient walks.
print("Meta-path: patient_diagnosis_drug_diagnosis_patient")
max_paths_to_print = 5
count = 0

for count, path in enumerate(find_meta_path_patient_diagnosis_drug_diagnosis_patient(G), start=1):
    print(" -> ".join(map(str, path)))
    if count >= max_paths_to_print:
        break

# Rebuilds the graph from `data` (printing its statistics again); the returned
# graph is not bound to a name here.
Heterogeneous_Graph(data)


In [None]:
##############               GraphSAGE MODEL                #####################


# Imports needed by this cell and by the model / training code that follows it.
import torch
from sklearn.preprocessing import StandardScaler
from torch.nn.utils.rnn import pad_sequence
from torch_geometric.data import HeteroData
from torch_geometric.nn import SAGEConv

# Load your data
data = data_with_labels = pd.read_pickle('/content/drive/MyDrive/Carmen-main/data/data_with_labels.pkl')
print(data.columns)

# One-hot encode patient features
# NOTE(review): `patient_features` (indexed by the row index of `data`) is
# joined onto a frame indexed by SUBJECT_ID, so the join keys only line up if
# those two index spaces coincide -- confirm this is intended.
patient_features = pd.get_dummies(data['SUBJECT_ID'].astype(str))
unique_patient_data = data.drop_duplicates(subset='SUBJECT_ID').set_index('SUBJECT_ID')
patient_features = unique_patient_data.join(patient_features)
patient_features = patient_features.apply(pd.to_numeric, errors='coerce')

# One-hot encode drug features
# NOTE(review): after the groupby each row describes one SUBJECT_ID (counts of
# its NDC codes), not one drug node; the same holds for the diagnosis features
# below. Confirm this matches what the 'drug' / 'diagnosis' node stores expect.
exploded_ndc = data[['SUBJECT_ID', 'NDC']].explode('NDC')
drug_features = pd.get_dummies(exploded_ndc['NDC'].astype(str), prefix='NDC')
drug_features = drug_features.groupby(exploded_ndc['SUBJECT_ID']).sum()

# One-hot encode diagnosis features
exploded_icd9 = data[['SUBJECT_ID', 'ICD9_CODE']].explode('ICD9_CODE')
diagnosis_features = pd.get_dummies(exploded_icd9['ICD9_CODE'].astype(str), prefix='ICD9')
diagnosis_features = diagnosis_features.groupby(exploded_icd9['SUBJECT_ID']).sum()

# Normalize features (fit_transform refits the scaler for each matrix)
scaler = StandardScaler()
patient_features = scaler.fit_transform(patient_features.fillna(0))
drug_features = scaler.fit_transform(drug_features.fillna(0))
diagnosis_features = scaler.fit_transform(diagnosis_features.fillna(0))

# Create your HeteroData object
hetero_data = HeteroData()
hetero_data['patient'].x = torch.tensor(patient_features, dtype=torch.float)
hetero_data['diagnosis'].x = torch.tensor(diagnosis_features, dtype=torch.float)
hetero_data['drug'].x = torch.tensor(drug_features, dtype=torch.float)

# 'G' is the NetworkX graph built by Load_Into_Graph.
# Global node -> integer mapping (kept for any code that still relies on it).
node_mapping = {node: i for i, node in enumerate(G.nodes())}

# HeteroData edge indices are local to each node type, so number the nodes of
# every type separately, in graph order.
local_index = {ntype: {} for ntype in ('patient', 'diagnosis', 'drug')}
for node, ntype in G.nodes(data='node_type'):
    if ntype in local_index:
        local_index[ntype][node] = len(local_index[ntype])

# (source type, target type, edge_type label used in G). The drug-drug label is
# 'ddi' because that is what Load_Into_Graph writes on interaction edges.
edge_types = [
    ('patient', 'diagnosis', 'has'),
    ('diagnosis', 'drug', 'treated_by'),
    ('drug', 'drug', 'ddi')
]

# Assuming 'num_classes' is the number of unique drug labels
# NOTE(review): the labels are also used as column indices below, so this
# presumes they are the contiguous integers 0..num_classes-1 -- confirm.
num_classes = len(set.union(*data['NDC_Labels'].apply(set)))

# Convert the lists to tensors and pad them to the same length
label_tensors = [torch.tensor(labels, dtype=torch.long) for labels in data['NDC_Labels']]
target_embeddings = pad_sequence(label_tensors, batch_first=True, padding_value=-1)

# Create a binary matrix for multi-label classification (one row per row of `data`)
target_embeddings_binary = torch.zeros(target_embeddings.size(0), num_classes)
for i, label_tensor in enumerate(label_tensors):

    valid_indices = label_tensor[label_tensor != -1]
    target_embeddings_binary[i, valid_indices] = 1

# Create a dictionary mapping 'node_type' to the corresponding binary tensor
target_embeddings_binary_dict = {
    'drug': target_embeddings_binary
}


# Convert edges to per-type integer identifiers and add them to HeteroData
for src_type, dst_type, relation in edge_types:
    pairs = []
    for u, v, edge_data in G.edges(data=True):
        if edge_data.get('edge_type') != relation:
            continue
        # G is undirected, so an edge may be reported in either orientation;
        # flip it when the endpoints arrive as (target, source).
        if u not in local_index[src_type] or v not in local_index[dst_type]:
            u, v = v, u
        pairs.append((local_index[src_type][u], local_index[dst_type][v]))
    # reshape(-1, 2) keeps the result a well-formed (2, 0) tensor when no edge
    # of this relation exists.
    edge_index = torch.tensor(pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    hetero_data[src_type, relation, dst_type].edge_index = edge_index
# Define the GraphSAGE model
class GraphSAGENet(torch.nn.Module):
    """Two-layer GraphSAGE applied to node types that have same-type edges.

    For every node type that owns an edge type whose source and target are
    that same node type (e.g. drug-drug), the features are passed through
    ``conv1 -> ReLU -> conv2``. Features of all other node types are returned
    unchanged.

    NOTE(review): one ``conv1`` is shared by all node types, so every node type
    that gets convolved must have ``in_channels`` feature columns.

    Args:
        hidden_channels: Width of the hidden GraphSAGE layer.
        out_channels: Width of the output layer.
        in_channels: Input feature width. Defaults to the column count of the
            module-level ``patient_features`` matrix, as before.
    """

    def __init__(self, hidden_channels, out_channels, in_channels=None):
        super().__init__()
        if in_channels is None:
            in_channels = patient_features.shape[1]
        self.conv1 = SAGEConv(in_channels, hidden_channels, normalize=True)
        self.conv2 = SAGEConv(hidden_channels, out_channels, normalize=True)

    @staticmethod
    def _self_edge_index(node_type, edge_index_dict):
        """Return the edge index linking ``node_type`` to itself, or ``None``."""
        for key, edge_index in edge_index_dict.items():
            # Accept both (src, dst) and (src, relation, dst) keys; the old
            # 2-tuple membership test never matched 3-tuple edge types.
            if isinstance(key, tuple) and key[0] == node_type and key[-1] == node_type:
                return edge_index
        return None

    def forward(self, x_dict, edge_index_dict):
        """Return a new ``{node_type: features}`` dict; inputs are not modified.

        Args:
            x_dict: Mapping from node type to its feature tensor.
            edge_index_dict: Mapping from edge-type tuple to edge index tensor.
        """
        # Work on a copy: overwriting the caller's dict would feed already
        # transformed features back into conv1 on the next call.
        out_dict = dict(x_dict)
        for node_type, x in x_dict.items():
            edge_index = self._self_edge_index(node_type, edge_index_dict)
            # No (or an empty) same-type edge set: leave the features as they are.
            if edge_index is None or edge_index.numel() == 0:
                continue

            x = torch.relu(self.conv1(x, edge_index))
            out_dict[node_type] = self.conv2(x, edge_index)

        return out_dict



# Instantiate the model

model = GraphSAGENet(hidden_channels=64, out_channels=num_classes)
# Prepare the data for the model
x_dict = {ntype: hetero_data[ntype].x for ntype in hetero_data.node_types}
# hetero_data.edge_types yields one tuple per edge type; reuse it as the key.
edge_index_dict = {edge_type: hetero_data[edge_type].edge_index
                   for edge_type in hetero_data.edge_types}


#  for multi-label classification
criterion = torch.nn.BCEWithLogitsLoss()  # Adjusted for binary classification
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


# Training loop: forward pass, loss, backward pass and optimizer step all happen
# once per epoch. (Previously only the forward pass was inside the loop, so the
# weights were updated a single time after 100 unused forward passes.)
for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    # Pass a shallow copy so the model cannot overwrite the original features.
    predictions = model(dict(x_dict), edge_index_dict)

    # Calculate total loss for the entire graph
    total_loss = 0
    for node_type in hetero_data.node_types:
        if node_type in predictions and node_type in target_embeddings_binary_dict:
            # NOTE(review): targets are truncated to the number of predicted
            # rows; confirm both tensors are ordered by the same entities.
            target = target_embeddings_binary_dict[node_type][:predictions[node_type].size(0), :]
            loss = criterion(predictions[node_type], target)
            total_loss = total_loss + loss

    # Targets must not require grad; if the loss has no gradient path, the
    # predictions never went through a trainable layer and training is a no-op.
    if not (torch.is_tensor(total_loss) and total_loss.requires_grad):
        raise RuntimeError(
            "Loss does not depend on any model parameter: no supervised node type "
            "was passed through the GraphSAGE layers (check the edge types)."
        )

    # Backpropagate and optimize based on the total loss
    total_loss.backward()
    optimizer.step()

# After training loop: run a single gradient-free forward pass and export it.
model.eval()
with torch.no_grad():
    final_outputs = model(dict(x_dict), edge_index_dict)
embeddings = {node_type: final_outputs[node_type].detach().cpu().numpy() for node_type in hetero_data.node_types}

# Print the embeddings for each node type
for node_type, emb in embeddings.items():
    print(f"Embeddings for {node_type} nodes:")
    print(emb)



# Save each node type's embeddings to '<node_type>_embeddings.npy' in the
# current working directory
for node_type, emb in embeddings.items():
    # Save the embeddings to a file
    np.save(f'{node_type}_embeddings.npy', emb)


# To load the embeddings for a specific node type
patient_embeddings = np.load('patient_embeddings.npy')




In [None]:
############################################################################  reshape embeddings and save

from sklearn.decomposition import PCA
import numpy as np

# Load the per-node-type embeddings written by the GraphSAGE cell (relative
# paths, i.e. the current working directory).
patient_embeddings = np.load('patient_embeddings.npy')
diagnosis_embeddings = np.load('diagnosis_embeddings.npy')
drug_embeddings = np.load('drug_embeddings.npy')

# Project the drug embeddings onto 64 principal components (columns).
# PCA needs at least 64 rows and 64 columns in its input for n_components=64.
pca_drug = PCA(n_components=64)
drug_embeddings_reduced = pca_drug.fit_transform(drug_embeddings)

# Same 64-component projection for the patient embeddings (separate PCA fit).
pca_patient = PCA(n_components=64)
patient_embeddings_reduced = pca_patient.fit_transform(patient_embeddings)

# Keep only the first 32 ROWS of each reduced matrix (all 64 columns).
# NOTE(review): this drops every row after the 32nd rather than shrinking the
# feature dimension -- confirm that truncating rows is intended.
drug_embeddings_final = drug_embeddings_reduced[:32, :]
patient_embeddings_final = patient_embeddings_reduced[:32, :]

# Print the shapes of the embeddings
print("Shape of patient_embeddings_final:", patient_embeddings_final.shape)
print("Shape of reshaped drug_embeddings_final:", drug_embeddings_final.shape)


# Save the reduced embeddings to the current working directory. The diagnosis
# embeddings are written back unchanged.
np.save('patient_embeddings_final.npy', patient_embeddings_final)
np.save('diagnosis_embeddings.npy', diagnosis_embeddings)
np.save('drug_embeddings_final.npy', drug_embeddings_final)

# Reload from disk; note that `patient_embeddings` now refers to the reduced,
# truncated matrix rather than the original one loaded above.
patient_embeddings = np.load('patient_embeddings_final.npy')
drug_embeddings_final = np.load('drug_embeddings_final.npy')
diagnosis_embeddings = np.load('diagnosis_embeddings.npy')

In [None]:
############## (preprocess_meta_paths ######################################


# Read the saved diagnosis / drug embedding matrices from the working directory.
diagnosis_embeddings = np.load('diagnosis_embeddings.npy')
drug_embeddings = np.load('drug_embeddings.npy')

# Position of every diagnosis / drug identifier within its `unique_*` array.
# NOTE(review): these positions are later used as row numbers into the
# embedding matrices; confirm the matrices are ordered the same way.
diagnosis_to_idx = dict(zip(unique_icd9_codes, range(len(unique_icd9_codes))))
drug_to_idx = dict(zip(unique_ndc, range(len(unique_ndc))))

def get_meta_path_embedding(diagnosis_embeddings, drug_embeddings, graph, diagnosis_to_idx, drug_to_idx):
    """Average the embeddings of all diagnosis -> drug -> diagnosis paths.

    For every path ``d1 - drug - d2`` in ``graph`` (with ``d1 != d2`` and node
    types checked via the ``node_type`` attribute) the three embedding rows are
    concatenated; the element-wise mean over all such vectors is returned.

    Args:
        diagnosis_embeddings: 2-D array indexed by ``diagnosis_to_idx`` values.
        drug_embeddings: 2-D array indexed by ``drug_to_idx`` values.
        graph: Graph exposing ``nodes(data=...)``, ``nodes[n]`` and
            ``neighbors(n)``, e.g. a ``networkx.Graph``.
        diagnosis_to_idx: Mapping from diagnosis node id to embedding row.
        drug_to_idx: Mapping from drug node id to embedding row.

    Returns:
        1-D float64 array of length ``2 * diagnosis_dim + drug_dim``.

    Raises:
        ValueError: If the graph contains no matching path (the mean would be
            undefined).
        KeyError: If a diagnosis or drug node is missing from its index mapping.
    """
    node_attrs = graph.nodes
    # Accumulate a running sum instead of storing every path vector: the number
    # of paths grows quadratically with the diagnoses sharing a drug.
    path_sum = None
    path_count = 0

    for diagnosis1, node_type in graph.nodes(data='node_type'):
        if node_type != 'diagnosis':
            continue
        for drug in graph.neighbors(diagnosis1):
            # Diagnoses also neighbour patients; only follow drug nodes.
            if node_attrs[drug].get('node_type') != 'drug':
                continue
            for diagnosis2 in graph.neighbors(drug):
                # Drugs also neighbour patients and other drugs; only close the
                # path on a different diagnosis.
                if diagnosis2 == diagnosis1 or node_attrs[diagnosis2].get('node_type') != 'diagnosis':
                    continue

                path_embedding = np.concatenate(
                    (diagnosis_embeddings[diagnosis_to_idx[diagnosis1]],
                     drug_embeddings[drug_to_idx[drug]],
                     diagnosis_embeddings[diagnosis_to_idx[diagnosis2]])
                )
                if path_sum is None:
                    path_sum = np.zeros(path_embedding.shape, dtype=np.float64)
                path_sum += path_embedding
                path_count += 1

    if path_count == 0:
        raise ValueError("graph contains no diagnosis -> drug -> diagnosis path to aggregate")

    # Aggregate embeddings by averaging
    return path_sum / path_count

# Example usage: aggregate the diagnosis -> drug -> diagnosis path embeddings
# of the global graph G into one averaged vector.
meta_path_embedding = get_meta_path_embedding(diagnosis_embeddings, drug_embeddings, G, diagnosis_to_idx, drug_to_idx)


In [None]:
#####   Our model with cross attention       ################################
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import MultiheadAttention
from torch.utils.data import DataLoader, TensorDataset
from torch_geometric.nn import SAGEConv

class SAGENet(nn.Module):
    """Drug-recommendation head that fuses precomputed embeddings via cross attention.

    The model loads patient, drug and diagnosis embeddings from ``.npy`` files,
    builds a DDI-aware drug memory from two ``GCN`` modules
    (``ehr_gcn() - ddi_gcn() * inter``), averages each source to a single
    vector, lets the patient vector attend over the drug vector (key) and the
    memory vector (value), and maps the concatenation of all four vectors to
    one logit per entry of ``vocab_size[2]``.

    NOTE(review): ``forward`` never reads its ``input`` argument, so the logits
    are the same for every admission -- confirm this is intended.

    Args:
        vocab_size: Sequence of vocabulary sizes; index 2 is used as the output
            (medication) vocabulary size.
        ehr_adj: EHR co-occurrence adjacency passed to the EHR ``GCN``.
        ddi_adj: Drug-drug-interaction adjacency passed to the DDI ``GCN`` and
            used for the interaction penalty.
        emb_dim: Embedding width; the loaded embeddings must have this many
            columns for the attention and output layers to accept them.
        device: Device for the loaded tensors and the DDI adjacency.
        ddi_in_memory: Stored on the instance; not read by this class.
    """

    def __init__(self, vocab_size, ehr_adj, ddi_adj, emb_dim=64, device=torch.device('cpu:0'), ddi_in_memory=True):
        super().__init__()
        K = len(vocab_size)
        self.K = K
        self.vocab_size = vocab_size
        self.device = device
        self.tensor_ddi_adj = torch.FloatTensor(ddi_adj).to(device)
        self.ddi_in_memory = ddi_in_memory
        # Code embeddings and the query projection are kept as registered
        # submodules; forward() does not use them.
        self.embeddings = nn.ModuleList(
            [nn.Embedding(vocab_size[i], emb_dim) for i in range(K-1)])
        self.dropout = nn.Dropout(p=0.5)

        self.query = nn.Sequential(
            nn.ReLU(),
            nn.Linear(emb_dim * 4, emb_dim),
        )

        # Load precomputed embeddings from GraphSAGE MODEL
        self.patient_embeddings = torch.tensor(np.load('/content/drive/MyDrive/Carmen-main/data/patient_embeddings_final.npy'), dtype=torch.float).to(device)
        self.drug_embeddings = torch.tensor(np.load('/content/drive/MyDrive/Carmen-main/data/drug_embeddings_final.npy'), dtype=torch.float).to(device)
        self.diagnosis_embeddings = torch.tensor(np.load('/content/drive/MyDrive/Carmen-main/data/diagnosis_embeddings.npy'), dtype=torch.float).to(device)
        self.ehr_gcn = GCN(voc_size=vocab_size[2], emb_dim=emb_dim, adj=ehr_adj, device=device)

        self.ddi_gcn = GCN(voc_size=vocab_size[2], emb_dim=emb_dim, adj=ddi_adj, device=device)
        # Learnable weight of the DDI term; filled in by init_weights().
        self.inter = nn.Parameter(torch.FloatTensor(1))

        # Four emb_dim-wide vectors are concatenated in forward(), hence emb_dim * 4.
        self.output = nn.Sequential(
            nn.ReLU(),
            nn.Linear(emb_dim * 4, emb_dim * 2),
            nn.ReLU(),
            nn.Linear(emb_dim * 2, vocab_size[2])
        )

        # Cross attention layer (default layout: (seq_len, batch, emb_dim))
        self.num_heads = 4
        self.cross_attn = MultiheadAttention(emb_dim, self.num_heads)

        self.init_weights()

    def forward(self, input):
        """Compute medication logits.

        Args:
            input: Admission records (adm, 3, codes); currently unused.

        Returns:
            In training mode a tuple ``(logits, batch_neg)`` where ``batch_neg``
            is the mean pairwise predicted probability weighted by the DDI
            adjacency; in eval mode only ``logits``. With 2-D stored embeddings
            ``logits`` has shape (1, vocab_size[2]).
        """
        # Collapse each embedding table to a single averaged row.
        # assumes the loaded embeddings are 2-D (rows, emb_dim) -- TODO confirm
        patient_embeddings_avg = self.patient_embeddings.mean(dim=0, keepdim=True)
        drug_embeddings_avg = self.drug_embeddings.mean(dim=0, keepdim=True)

        # DDI-aware memory: EHR graph embedding minus the weighted DDI embedding.
        ddi_embedding = self.ehr_gcn() - self.ddi_gcn() * self.inter
        ddi_embedding_avg = ddi_embedding.mean(dim=0, keepdim=True)

        # Add a length-1 axis so each tensor is (1, 1, emb_dim) for attention.
        patient_embeddings_avg = patient_embeddings_avg.unsqueeze(1)
        drug_embeddings_avg = drug_embeddings_avg.unsqueeze(1)
        ddi_embedding_avg = ddi_embedding_avg.unsqueeze(1)

        # Cross attention: query = patient, key = drug, value = DDI memory.
        cross_attn_output, cross_attn_weights = self.cross_attn(
            patient_embeddings_avg, drug_embeddings_avg, ddi_embedding_avg)
        cross_attn_output = cross_attn_output.transpose(0, 1)  # still (1, 1, emb_dim)

        # (1, 1, 4 * emb_dim) -> (1, 1, vocab_size[2])
        output = self.output(torch.cat(
            [patient_embeddings_avg, drug_embeddings_avg, ddi_embedding_avg, cross_attn_output],
            dim=-1))

        output = output.squeeze(0)  # (1, vocab_size[2])

        if self.training:
            neg_pred_prob = torch.sigmoid(output)
            # Outer product of the probabilities: (voc_size, voc_size)
            neg_pred_prob = neg_pred_prob.t() * neg_pred_prob
            batch_neg = neg_pred_prob.mul(self.tensor_ddi_adj).mean()

            return output, batch_neg
        else:
            return output

    def init_weights(self):
        """Initialize the code embeddings and ``inter`` uniformly in [-0.1, 0.1]."""
        initrange = 0.1
        for item in self.embeddings:
            item.weight.data.uniform_(-initrange, initrange)

        self.inter.data.uniform_(-initrange, initrange)
