In [34]:
import sqlite3
import pandas as pd
import numpy as np
import random


In [70]:
def create_mcq(df: pd.DataFrame, source_column: str, target_column: str, node_type: str,
               n_negatives: int = 4, random_state=None, shuffle_options: bool = True) -> pd.DataFrame:
    """Build two-hop multiple-choice questions from a source -> target association table.

    For every ordered pair of distinct sources (e.g. diseases) that share at least one
    target (e.g. gene), one question is produced asking which ``node_type`` is
    associated with both sources. One shared target is the correct answer. The
    distractors are drawn from targets linked to *neither* source, so no wrong option
    can accidentally also be a correct answer.

    Parameters
    ----------
    df : pd.DataFrame
        Association table with one row per (source, target) link.
    source_column, target_column : str
        Column names holding the source and target node names.
    node_type : str
        Noun inserted into the question text (e.g. "Gene").
    n_negatives : int, default 4
        Maximum number of distractors per question. Fewer are used when the pool of
        unrelated targets is smaller (instead of raising).
    random_state : int, numpy Generator or None, default None
        Seed for distractor sampling and option shuffling. Pass an int for
        reproducible output.
    shuffle_options : bool, default True
        Shuffle the options listed in the question text so the correct answer is not
        always first. With False the text lists the correct answer first, then the
        distractors.

    Returns
    -------
    pd.DataFrame
        Columns ``disease_pair`` (tuple), ``correct_node``, ``negative_samples``
        (comma-separated distractors), ``disease_1``, ``disease_2`` and ``text``.
        Pairs without a shared target are omitted.
    """
    from itertools import permutations

    rng = np.random.default_rng(random_state)

    # Compute each source's target set once instead of re-filtering df for every
    # pair. Rows with a missing source or target carry no usable link.
    links = df[[source_column, target_column]].dropna()
    targets_by_source = {
        source: set(targets)
        for source, targets in links.groupby(source_column)[target_column]
    }
    # Sorted so that seeded sampling does not depend on set iteration order
    # (string hashing is salted per process).
    all_targets = sorted(set(links[target_column]), key=str)

    # Ordered pairs, as in the original: both (A, B) and (B, A) yield a question.
    sources = list(links[source_column].unique())

    records = []
    for disease_1, disease_2 in permutations(sources, 2):
        common = targets_by_source[disease_1] & targets_by_source[disease_2]
        if not common:
            continue
        # Deterministic pick when several targets are shared.
        correct = min(common, key=str)

        # Distractors must not be linked to either source; otherwise a "wrong"
        # option could itself be associated with one or both of them.
        excluded = targets_by_source[disease_1] | targets_by_source[disease_2]
        pool = [t for t in all_targets if t not in excluded]
        k = min(n_negatives, len(pool))
        negatives = [pool[i] for i in rng.choice(len(pool), size=k, replace=False)] if k else []

        options = [correct, *negatives]
        if shuffle_options:
            options = [options[i] for i in rng.permutation(len(options))]

        text = (
            "Out of the given list, which " + str(node_type)
            + " is associated with both " + str(disease_1) + " and " + str(disease_2)
            + ". Given list is: " + ", ".join(map(str, options))
        )
        records.append((
            (disease_1, disease_2),
            correct,
            ", ".join(map(str, negatives)),
            disease_1,
            disease_2,
            text,
        ))

    return pd.DataFrame(
        records,
        columns=["disease_pair", "correct_node", "negative_samples", "disease_1", "disease_2", "text"],
    )


In [2]:
# Location of the DisGeNET 2020 SQLite file, resolved relative to the notebook's
# current working directory.
DB_PATH = (
    "../../../data/benchmark_datasets/disgenet/"
    "disgenet_2020.db"
)


In [3]:
# Open the database read-only. A plain sqlite3.connect() on a wrong or missing
# path silently creates an empty database file, and the error only shows up later
# as "no such table". With mode=ro a missing file fails here instead, and the
# benchmark DB cannot be modified by accident.
# NOTE: DB_PATH is embedded in a URI, so it must not contain '?', '#' or '%'.
conn = sqlite3.connect(f"file:{DB_PATH}?mode=ro", uri=True)
c = conn.cursor()


In [4]:

# Load the diseaseAttributes table. SELECT * returns columns positionally, so the
# names below are assumed to match the table's column order — verify if the DB schema changes.
table_name = "diseaseAttributes"
rows = c.execute(f"SELECT * FROM {table_name}").fetchall()
disease_df = pd.DataFrame(
    rows,
    columns=["diseaseNID", "diseaseId", "diseaseName", "type"],
)


In [5]:
# Load the geneAttributes table. Column names are applied positionally to the
# SELECT * result (assumed to match the table's column order).
table_name = "geneAttributes"
rows = c.execute(f"SELECT * FROM {table_name}").fetchall()
gene_df = pd.DataFrame(
    rows,
    columns=["geneNID", "geneId", "geneName", "geneDescription", "pLI", "DSI", "DPI"],
)


In [6]:
# Load the variantAttributes table (column names applied positionally to SELECT *).
# NOTE(review): variant_df is not used by the cells shown here.
table_name = "variantAttributes"
rows = c.execute(f"SELECT * FROM {table_name}").fetchall()
variant_df = pd.DataFrame(
    rows,
    columns=["variantNID", "variantId", "s", "chromosome", "coord", "most_severe_consequence", "DSI", "DPI"],
)


In [23]:
# Load the gene-disease association table (column names applied positionally to SELECT *).
table_name = "geneDiseaseNetwork"
c.execute("SELECT * FROM {}".format(table_name))

rows = c.fetchall()
disease_gene_df = pd.DataFrame(rows, columns=["NID", "diseaseNID", "geneNID", "source", "association", "associationType", "sentence", "pmid", "score", "EL", "EI", "year"])

# Keep only associations with score == 1 (assumed to be the maximum score — confirm against DisGeNET docs).
disease_gene_df_selected = disease_gene_df[disease_gene_df.score == 1]

# Attach human-readable disease and gene names via their NID keys.
disease_gene_df_selected_1 = pd.merge(disease_gene_df_selected, disease_df, on="diseaseNID")
disease_gene_df_selected_2 = pd.merge(disease_gene_df_selected_1, gene_df, on="geneNID")

# One row per distinct (diseaseName, geneName). geneCount is the number of such rows
# per gene, i.e. how many distinct disease names the gene is linked to. `.assign`
# builds a new frame instead of writing a column into a derived slice, which can
# trigger SettingWithCopyWarning.
disease_gene_df_selected_final = (
    disease_gene_df_selected_2[["diseaseName", "geneName"]]
    .drop_duplicates()
    .pipe(lambda pairs: pairs.assign(
        geneCount=pairs.groupby("geneName")["geneName"].transform("count")
    ))
)

# Only genes linked to more than one disease can answer a two-hop question.
disease_gene_df_selected_final_more_gene_count = disease_gene_df_selected_final[
    disease_gene_df_selected_final.geneCount > 1
]



In [82]:
# Build the two-hop questions: for a pair of diseases, which gene is linked to both.
disease_gene_mcq = create_mcq(
    disease_gene_df_selected_final_more_gene_count,
    source_column="diseaseName",
    target_column="geneName",
    node_type="Gene",
)


In [84]:
# Write the generated questions to the benchmark folder (header row, no index column).
disease_gene_mcq.to_csv(
    "../../../data/benchmark_datasets/test_questions_two_hop_mcq_from_disgenet.csv",
    header=True,
    index=False,
)



In [63]:
# Disabled: the same MCQ generation applied to the SemMedDB compound-treats-disease file.
# sem_df = pd.read_csv("../../../data/benchmark_datasets/semmeddb/compound_treats_disease_from_semmeddb.csv")


In [71]:
# sem_df_mcq = create_mcq(sem_df, "object", "subject", "Compound")


In [81]:
# sem_df_mcq.correct_node.unique()